use std::collections::HashMap;
use std::fmt;
use std::path::Path;
use std::process::ExitStatus;
use std::process::Stdio;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use anyhow::Context as _;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use bytesize::ByteSize;
use crankshaft::engine::service::name::GeneratorIterator;
use crankshaft::engine::service::name::UniqueAlphanumeric;
use crankshaft::events::Event as CrankshaftEvent;
use crankshaft::events::send_event;
use futures::FutureExt;
use futures::future::BoxFuture;
use itertools::Itertools;
use nonempty::NonEmpty;
use tokio::fs;
use tokio::fs::File;
use tokio::process::Command;
use tokio::select;
use tokio::sync::Semaphore;
use tokio::sync::oneshot;
use tokio::time::MissedTickBehavior;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use tracing::error;
use tracing::trace;
use tracing::warn;
use super::ApptainerRuntime;
use super::TaskExecutionBackend;
use crate::CancellationContext;
use crate::EvaluationPath;
use crate::Events;
use crate::ONE_GIBIBYTE;
use crate::PrimitiveValue;
use crate::TaskInputs;
use crate::Value;
use crate::backend::ExecuteTaskRequest;
use crate::backend::INITIAL_EXPECTED_NAMES;
use crate::backend::TaskExecutionConstraints;
use crate::backend::TaskExecutionResult;
use crate::config::Config;
use crate::config::SlurmApptainerBackendConfig;
use crate::config::TaskResourceLimitBehavior;
use crate::http::Transferer;
use crate::v1::requirements;
/// Name of the file, written into a task's attempt directory, containing the
/// generated Apptainer invocation script that is submitted to Slurm.
const APPTAINER_COMMAND_FILE_NAME: &str = "apptainer_command";
/// Default interval (in seconds) between `sacct` polls of tracked job state.
const DEFAULT_MONITOR_INTERVAL: u64 = 30;
/// Default maximum number of concurrent Slurm CLI invocations
/// (`sbatch` submissions / `scancel` cancellations).
const DEFAULT_MAX_CONCURRENCY: u32 = 10;
/// A state of a Slurm job, as reported by `sacct`.
///
/// Parsed from the state-name prefixes of `sacct` output (see the `FromStr`
/// implementation below).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum JobState {
    /// The job's node failed to boot.
    BootFail,
    /// The job was canceled.
    Canceled,
    /// The job completed.
    Completed,
    /// The job's deadline was reached.
    Deadline,
    /// The job failed.
    Failed,
    /// The job's node failed.
    NodeFail,
    /// The job ran out of memory.
    OutOfMemory,
    /// The job is pending.
    Pending,
    /// The job was preempted.
    Preempted,
    /// The job is running.
    Running,
    /// The job was requeued.
    Requeued,
    /// The job is resizing.
    Resizing,
    /// The job was revoked.
    Revoked,
    /// The job is suspended.
    Suspended,
    /// The job timed out.
    Timeout,
}
impl JobState {
    /// Returns `true` if this state is terminal, i.e. the job has finished
    /// and will not transition to any further state.
    fn terminated(&self) -> bool {
        // Exhaustive match so adding a new variant forces a decision here.
        match self {
            Self::BootFail
            | Self::Canceled
            | Self::Completed
            | Self::Deadline
            | Self::Failed
            | Self::NodeFail
            | Self::OutOfMemory
            | Self::Preempted
            | Self::Timeout => true,
            Self::Pending
            | Self::Running
            | Self::Requeued
            | Self::Resizing
            | Self::Revoked
            | Self::Suspended => false,
        }
    }
}
impl fmt::Display for JobState {
    /// Writes a short human-readable description of the job state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::BootFail => "node boot failure",
            Self::Canceled => "canceled",
            Self::Completed => "completed",
            Self::Deadline => "deadline reached",
            Self::Failed => "failed",
            Self::NodeFail => "node failure",
            Self::OutOfMemory => "out of memory",
            Self::Pending => "pending",
            Self::Preempted => "preempted",
            Self::Running => "running",
            Self::Requeued => "requeued",
            Self::Resizing => "resizing",
            Self::Revoked => "revoked",
            Self::Suspended => "suspended",
            Self::Timeout => "timeout",
        };
        f.write_str(description)
    }
}
impl FromStr for JobState {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
for (prefix, state) in [
("BOOT_FAIL", Self::BootFail),
("CANCELLED", Self::Canceled),
("COMPLETED", Self::Completed),
("DEADLINE", Self::Deadline),
("FAILED", Self::Failed),
("NODE_FAIL", Self::NodeFail),
("OUT_OF_MEMORY", Self::OutOfMemory),
("PENDING", Self::Pending),
("PREEMPTED", Self::Preempted),
("RUNNING", Self::Running),
("REQUEUED", Self::Requeued),
("RESIZING", Self::Resizing),
("REVOKED", Self::Revoked),
("SUSPENDED", Self::Suspended),
("TIMEOUT", Self::Timeout),
] {
if s.starts_with(prefix) {
return Ok(state);
}
}
bail!("unknown Slurm job state `{s}`");
}
}
/// The exit code and signal of a terminated Slurm job, parsed from `sacct`'s
/// `ExitCode` field (formatted as `<exit_code>:<signal>`).
#[derive(Debug, Clone, Copy)]
struct JobExitCode {
    /// The process exit code of the job.
    exit_code: u8,
    /// The signal that terminated the job; zero when the job was not
    /// terminated by a signal.
    signal: u8,
}
impl JobExitCode {
    /// Returns a single shell-style exit code: `128 + signal` when the job
    /// was terminated by a signal, otherwise the job's own exit code.
    fn code(&self) -> u8 {
        if self.signal > 0 {
            // Mask to 7 bits so `128 + signal` cannot overflow a `u8`.
            128 + (self.signal & 0x7F)
        } else {
            self.exit_code
        }
    }

    /// Converts this exit code into a platform-native [`ExitStatus`].
    fn into_exit_status(self) -> ExitStatus {
        #[cfg(unix)]
        use std::os::unix::process::ExitStatusExt as _;
        #[cfg(windows)]
        use std::os::windows::process::ExitStatusExt as _;
        // On Unix, reconstruct a raw `wait` status: the terminating signal in
        // the low 7 bits, or the exit code shifted into the next byte.
        #[cfg(unix)]
        let status = if self.signal > 0 {
            ExitStatus::from_raw((self.signal as i32) & 0x7F)
        } else {
            ExitStatus::from_raw((self.exit_code as i32) << 8)
        };
        // Windows has no signals; the raw status is the exit code itself.
        #[cfg(windows)]
        let status = ExitStatus::from_raw(self.exit_code as u32);
        status
    }
}
impl FromStr for JobExitCode {
type Err = anyhow::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let (exit_code, signal) = s
.split_once(':')
.with_context(|| format!("invalid Slurm exit code `{s}`"))?;
Ok(Self {
exit_code: exit_code
.parse()
.with_context(|| format!("invalid exit code `{exit_code}`"))?,
signal: signal
.parse()
.with_context(|| format!("invalid signal number `{signal}`"))?,
})
}
}
impl fmt::Display for JobExitCode {
    /// Formats as the terminating signal number when the job was signaled,
    /// otherwise as the exit code.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.signal {
            0 => write!(f, "exit code `{code}`", code = self.exit_code),
            signal => write!(f, "signal number `{signal}`", signal = signal & 0x7F),
        }
    }
}
/// A single job record (row) parsed from pipe-delimited `sacct` output.
///
/// The string fields borrow directly from the `sacct` output line.
#[derive(Debug)]
struct JobRecord<'a> {
    /// The Slurm job identifier.
    job_id: u64,
    /// The job's current state.
    state: JobState,
    /// The job's exit code; only present once the job has terminated.
    exit_code: Option<JobExitCode>,
    /// Total CPU time, as reported by `sacct`.
    total_cpu: &'a str,
    /// System CPU time, as reported by `sacct`.
    system_cpu: &'a str,
    /// User CPU time, as reported by `sacct`.
    user_cpu: &'a str,
    /// Maximum virtual memory size, as reported by `sacct`.
    max_vm_size: &'a str,
    /// Average virtual memory size, as reported by `sacct`.
    avg_vm_size: &'a str,
}
impl<'a> JobRecord<'a> {
    /// Parses a record from the remaining pipe-separated fields of a `sacct`
    /// output line; the job id field has already been consumed and parsed by
    /// the caller.
    ///
    /// The expected field order matches [`Self::fields`].
    fn new(job_id: u64, mut fields: impl Iterator<Item = &'a str>) -> Result<Self> {
        // Every field is mandatory; missing fields produce the given error.
        let mut next_field = |err: &'static str| fields.next().context(err);

        let state: JobState = next_field("`sacct` output is missing job state")?.parse()?;
        let exit_code = next_field("`sacct` output is missing exit code")?;
        // Only terminated jobs carry a meaningful exit code.
        let exit_code = state.terminated().then(|| exit_code.parse()).transpose()?;
        let total_cpu = next_field("`sacct` output missing total CPU")?;
        let system_cpu = next_field("`sacct` output missing system CPU")?;
        let user_cpu = next_field("`sacct` output missing user CPU")?;
        let max_vm_size = next_field("`sacct` output missing maximum virtual memory size")?;
        let avg_vm_size = next_field("`sacct` output missing average virtual memory size")?;

        Ok(Self {
            job_id,
            state,
            exit_code,
            total_cpu,
            system_cpu,
            user_cpu,
            max_vm_size,
            avg_vm_size,
        })
    }

    /// The `--format` field list passed to `sacct`, in the order consumed by
    /// [`Self::new`] (with `JobID` handled by the caller).
    fn fields() -> &'static str {
        "JobID,State,ExitCode,TotalCPU,SystemCPU,UserCPU,MaxVMSize,AveVMSize"
    }
}
/// A Slurm job tracked by the monitor.
#[derive(Debug)]
struct Job {
    /// The Crankshaft event identifier of the task backing this job.
    crankshaft_id: u64,
    /// The last observed state of the job.
    state: JobState,
    /// Notifies the submitter with the job's exit code (or an error) once the
    /// job reaches a terminal state.
    completed: oneshot::Sender<Result<JobExitCode>>,
}
/// Mutable state shared between the monitor handle and its background task.
#[derive(Debug)]
struct MonitorState {
    /// Generator of unique alphanumeric task name suffixes.
    names: GeneratorIterator<UniqueAlphanumeric>,
    /// Currently tracked jobs, keyed by Slurm job identifier.
    jobs: HashMap<u64, Job>,
}
impl MonitorState {
    /// Creates a new, empty monitor state.
    fn new() -> Self {
        Self {
            names: GeneratorIterator::new(
                UniqueAlphanumeric::default_with_expected_generations(INITIAL_EXPECTED_NAMES),
                INITIAL_EXPECTED_NAMES,
            ),
            jobs: HashMap::new(),
        }
    }

    /// Starts tracking a submitted Slurm job in the `Pending` state.
    ///
    /// `completed` will be notified with the job's exit code (or an error)
    /// once the job reaches a terminal state.
    fn add_job(
        &mut self,
        job_id: u64,
        crankshaft_id: u64,
        completed: oneshot::Sender<Result<JobExitCode>>,
    ) {
        let prev = self.jobs.insert(
            job_id,
            Job {
                crankshaft_id,
                state: JobState::Pending,
                completed,
            },
        );
        if prev.is_some() {
            warn!(
                "encountered duplicate Slurm job id `{job_id}`: tasks may not be monitored \
                 correctly"
            );
        }
    }

    /// Updates tracked jobs from pipe-delimited `sacct` output, emitting
    /// Crankshaft events on state transitions and completing jobs that have
    /// reached a terminal state.
    fn update_jobs(&mut self, output: &str, events: &Events) {
        for line in output.lines() {
            let mut fields = line.split('|');
            let Some(job_id) = fields.next() else {
                continue;
            };
            // Skip job step records (e.g. `123.batch`); only whole-job
            // records are of interest.
            if job_id.contains('.') {
                continue;
            }
            let Ok(job_id) = job_id.parse() else {
                continue;
            };
            let record = match JobRecord::new(job_id, fields) {
                Ok(record) => record,
                Err(e) => {
                    // A malformed record means the job can no longer be
                    // monitored; fail it rather than leaving it pending
                    // forever. The id may not be tracked (e.g. a record for a
                    // job we never submitted), so don't unwrap the removal.
                    if let Some(job) = self.jobs.remove(&job_id) {
                        let _ = job.completed.send(Err(e));
                    }
                    continue;
                }
            };
            let Some(job) = self.jobs.get_mut(&job_id) else {
                continue;
            };
            if record.state != job.state {
                if record.state == JobState::Running {
                    send_event!(
                        events.crankshaft(),
                        CrankshaftEvent::TaskStarted {
                            id: job.crankshaft_id
                        },
                    );
                }
                if record.state.terminated() {
                    // If the job terminated without ever being observed
                    // running, emit the started event now so consumers always
                    // see a start before a completion.
                    if job.state != JobState::Running {
                        send_event!(
                            events.crankshaft(),
                            CrankshaftEvent::TaskStarted {
                                id: job.crankshaft_id
                            },
                        );
                    }
                    let exit_code = record
                        .exit_code
                        .expect("terminated job should have exit code");
                    debug!(
                        "Slurm job `{job_id}` has exited with {exit_code}: average virtual memory \
                         size `{avg_mem}`, maximum virtual memory size `{max_mem}`, total CPU \
                         used `{total_cpu}`, system CPU time `{system_cpu}`, user CPU time \
                         `{user_cpu}`",
                        job_id = record.job_id,
                        avg_mem = record.avg_vm_size,
                        max_mem = record.max_vm_size,
                        total_cpu = record.total_cpu,
                        system_cpu = record.system_cpu,
                        user_cpu = record.user_cpu,
                    );
                    // The `get_mut` borrow has ended; the entry must still be
                    // present as it was just looked up above.
                    let job = self
                        .jobs
                        .remove(&job_id)
                        .expect("job was present just above");
                    let _ = job.completed.send(Ok(exit_code));
                    continue;
                } else {
                    debug!(
                        "Slurm job `{id}` is now in the `{state}` state",
                        id = record.job_id,
                        state = record.state
                    );
                }
                job.state = record.state;
            }
        }
    }
}
/// A job that was successfully submitted to Slurm with `sbatch`.
#[derive(Debug)]
struct SubmittedJob {
    /// The Slurm job identifier.
    id: u64,
    /// The generated unique name of the task.
    task_name: String,
    /// Receives the job's exit code (or an error) once the monitor observes
    /// the job terminate.
    completed: oneshot::Receiver<Result<JobExitCode>>,
}
/// A handle to the background task that polls `sacct` for job state.
///
/// Clones share the same monitor; the background task shuts down when the
/// last handle is dropped.
#[derive(Debug, Clone)]
struct Monitor {
    /// State shared with the background monitoring task.
    state: Arc<Mutex<MonitorState>>,
    /// Dropping the last clone closes this channel, signaling the background
    /// task to shut down.
    _drop: Arc<oneshot::Sender<()>>,
}
impl Monitor {
    /// Creates a new monitor and spawns its background polling task.
    ///
    /// The background task polls `sacct` every `interval` and stops when the
    /// returned handle (and all of its clones) are dropped.
    fn new(interval: Duration, events: Events) -> Self {
        let (tx, rx) = oneshot::channel();
        let state = Arc::new(Mutex::new(MonitorState::new()));
        tokio::spawn(Self::monitor(state.clone(), interval, events, rx));
        Self {
            state,
            _drop: Arc::new(tx),
        }
    }

    /// Submits a task to Slurm with `sbatch` and registers the resulting job
    /// with the monitor.
    ///
    /// `command_path` is the script Slurm will execute (the generated
    /// Apptainer command file). Returns the submitted job, which carries a
    /// channel that receives the job's exit code once it terminates.
    async fn submit_job(
        &self,
        config: &SlurmApptainerBackendConfig,
        request: &ExecuteTaskRequest<'_>,
        crankshaft_id: u64,
        command_path: &Path,
    ) -> Result<SubmittedJob> {
        // Generate a unique task name; the lock is scoped so it is not held
        // across any await point below.
        let task_name = {
            let mut state = self.state.lock().expect("failed to lock state");
            let task_name = format!(
                "{id}-{generated}",
                id = request.id,
                generated = state
                    .names
                    .next()
                    .expect("generator should never be exhausted")
            );
            task_name
        };
        let mut command = Command::new("sbatch");
        if let Some(partition) =
            config.slurm_partition_for_task(request.requirements, request.hints)
        {
            command.arg("--partition").arg(&partition.name);
        }
        if let Some(gpu_count) =
            requirements::gpu(request.inputs, request.requirements, request.hints)
        {
            command.arg(format!("--gpus-per-task={gpu_count}"));
        }
        // User-configured extra arguments come after the computed ones.
        if let Some(args) = &config.extra_sbatch_args {
            command.args(args);
        }
        let job_name = format!(
            "{prefix}{sep}{task_name}",
            prefix = config.job_name_prefix.as_deref().unwrap_or(""),
            sep = if config.job_name_prefix.is_some() {
                "-"
            } else {
                ""
            }
        );
        // Slurm's own stdout/stderr are captured into the attempt directory,
        // separate from the task's stdout/stderr files.
        let slurm_stdout_path = request.attempt_dir.join("slurm.stdout");
        let slurm_stderr_path = request.attempt_dir.join("slurm.stderr");
        command
            .arg("--job-name")
            .arg(&job_name)
            .arg("-o")
            .arg(slurm_stdout_path)
            .arg("-e")
            .arg(slurm_stderr_path)
            .arg("--ntasks=1")
            .arg(format!(
                "--cpus-per-task={}",
                request.constraints.cpu.ceil() as u64
            ))
            .arg(format!(
                "--mem={}M",
                (request.constraints.memory as f64 / bytesize::MIB as f64).ceil() as u64
            ))
            .arg(command_path)
            .stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        trace!(?command, "spawning `sbatch` to queue task");
        let child = command.spawn().context("failed to spawn `sbatch`")?;
        let output = child
            .wait_with_output()
            .await
            .context("failed to wait for `sbatch` to exit")?;
        if !output.status.success() {
            bail!(
                "failed to submit Slurm job with `sbatch` ({status})\n{stderr}",
                status = output.status,
                stderr = str::from_utf8(&output.stderr)
                    .unwrap_or("<output not UTF-8>")
                    .trim()
            );
        }
        // Extract the job id from `sbatch`'s "Submitted batch job <id>" line.
        let stdout =
            str::from_utf8(&output.stdout).map_err(|_| anyhow!("`sbatch` output was not UTF-8"))?;
        let mut job_id = None;
        for line in stdout.lines() {
            if let Some(id) = line.trim().strip_prefix("Submitted batch job ") {
                job_id = Some(
                    id.parse()
                        .context("`sbatch` returned an invalid job identifier")?,
                );
            }
        }
        let job_id = job_id.context("`sbatch` did not output a job identifier")?;
        debug!("task `{task_name}` was queued as Slurm job `{job_id}`");
        let (tx, rx) = oneshot::channel();
        let mut state = self.state.lock().expect("failed to lock state");
        state.add_job(job_id, crankshaft_id, tx);
        drop(state);
        Ok(SubmittedJob {
            id: job_id,
            task_name,
            completed: rx,
        })
    }

    /// The background monitoring loop: polls `sacct` every `interval` for the
    /// state of all tracked jobs until `drop` resolves (i.e. every `Monitor`
    /// handle has been dropped).
    async fn monitor(
        state: Arc<Mutex<MonitorState>>,
        interval: Duration,
        events: Events,
        mut drop: oneshot::Receiver<()>,
    ) {
        debug!(
            "Slurm task monitor is starting with polling interval of {interval} seconds",
            interval = interval.as_secs()
        );
        let mut timer = tokio::time::interval(interval);
        // If a poll overruns the interval, just delay the next tick instead
        // of bursting to catch up.
        timer.set_missed_tick_behavior(MissedTickBehavior::Delay);
        loop {
            select! {
                _ = &mut drop => break,
                _ = timer.tick() => {
                    // Snapshot the tracked job ids; the lock is scoped so it
                    // is not held across the `sacct` invocation below.
                    let jobs = {
                        let state = state.lock().expect("failed to lock state");
                        if state.jobs.is_empty() {
                            continue;
                        }
                        state.jobs.keys().join(",")
                    };
                    match Self::read_jobs(&jobs).await.and_then(|output| String::from_utf8(output).context("`sacct` output was not UTF-8")) {
                        Ok(output) => {
                            let mut state = state.lock().expect("failed to lock state");
                            state.update_jobs(&output, &events);
                        }
                        Err(e) => {
                            // Polling failures are logged and retried on the
                            // next tick rather than tearing down the monitor.
                            error!("failed to read Slurm job state: {e:#}");
                        }
                    }
                }
            }
        }
        debug!("Slurm task monitor has shut down");
    }

    /// Runs `sacct` for the given comma-separated list of job ids and returns
    /// its raw stdout, which is parsed by [`MonitorState::update_jobs`].
    async fn read_jobs(jobs: &str) -> Result<Vec<u8>> {
        let mut command = Command::new("sacct");
        let command = command
            .arg("-P")
            .arg("-n")
            .arg("--format")
            .arg(JobRecord::fields())
            .arg("-j")
            .arg(jobs)
            .stdin(Stdio::null())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        trace!(?command, "spawning `sacct` to monitor tasks");
        let child = command.spawn().context("failed to spawn `sacct` command")?;
        let output = child
            .wait_with_output()
            .await
            .context("failed to wait for `sacct` to exit")?;
        if !output.status.success() {
            bail!(
                "`sacct` failed: {status}: {stderr}",
                status = output.status,
                stderr = str::from_utf8(&output.stderr)
                    .unwrap_or("<output not UTF-8>")
                    .trim()
            );
        }
        Ok(output.stdout)
    }
}
/// A task execution backend that submits tasks to Slurm via `sbatch`,
/// running each task's command through a generated Apptainer script.
pub struct SlurmApptainerBackend {
    /// The evaluation configuration.
    config: Arc<Config>,
    /// Channels used to broadcast task lifecycle events.
    events: Events,
    /// Context used to observe cancellation requests.
    cancellation: CancellationContext,
    /// The Apptainer runtime used to generate container invocation scripts.
    apptainer: ApptainerRuntime,
    /// The monitor tracking submitted Slurm jobs via `sacct`.
    monitor: Monitor,
    /// Limits concurrent Slurm CLI invocations (`sbatch`/`scancel`).
    permits: Semaphore,
}
impl SlurmApptainerBackend {
    /// Creates a new Slurm + Apptainer backend from the given configuration.
    ///
    /// Spawns the background job monitor and constructs the Apptainer runtime
    /// rooted at `run_root_dir`.
    pub fn new(
        config: Arc<Config>,
        run_root_dir: &Path,
        events: Events,
        cancellation: CancellationContext,
    ) -> Result<Self> {
        let backend = config.backend()?;
        let slurm_config = backend
            .as_slurm_apptainer()
            .context("configured backend is not Slurm Apptainer")?;

        // Start polling `sacct` at the configured (or default) interval.
        let poll_secs = slurm_config.interval.unwrap_or(DEFAULT_MONITOR_INTERVAL);
        let monitor = Monitor::new(Duration::from_secs(poll_secs), events.clone());

        // Bound the number of concurrent Slurm CLI invocations.
        let concurrency = slurm_config
            .max_concurrency
            .unwrap_or(DEFAULT_MAX_CONCURRENCY);
        let permits = Semaphore::new(concurrency as usize);

        let apptainer = ApptainerRuntime::new(
            run_root_dir,
            slurm_config.apptainer_config.image_cache_dir.as_deref(),
        )?;

        Ok(Self {
            config,
            events,
            cancellation,
            apptainer,
            monitor,
            permits,
        })
    }

    /// Cancels a Slurm job by invoking `scancel`.
    async fn kill_job(&self, job_id: u64) -> Result<()> {
        let mut command = Command::new("scancel");
        command
            .arg(job_id.to_string())
            .stdin(Stdio::null())
            .stdout(Stdio::null())
            .stderr(Stdio::null());

        // Respect the backend's limit on concurrent Slurm CLI invocations.
        let _permit = self
            .permits
            .acquire()
            .await
            .context("failed to acquire permit for canceling job")?;

        trace!(?command, "spawning `scancel` to cancel task");
        let status = command
            .spawn()
            .context("failed to spawn `scancel` command")?
            .wait()
            .await
            .context("failed to wait for `scancel`")?;
        if status.success() {
            Ok(())
        } else {
            bail!("`scancel` failed: {status}")
        }
    }
}
impl TaskExecutionBackend for SlurmApptainerBackend {
    /// Computes the execution constraints for a task, clamping or denying
    /// CPU/memory requests that exceed the selected Slurm partition's limits
    /// according to the configured resource limit behaviors.
    fn constraints(
        &self,
        inputs: &TaskInputs,
        requirements: &HashMap<String, Value>,
        hints: &HashMap<String, crate::Value>,
    ) -> Result<TaskExecutionConstraints> {
        let mut required_cpu = requirements::cpu(inputs, requirements);
        let mut required_memory = ByteSize::b(requirements::memory(inputs, requirements)? as u64);
        let backend_config = self.config.backend()?;
        let backend_config = backend_config
            .as_slurm_apptainer()
            .expect("configured backend is not Slurm Apptainer");
        if let Some(partition) = backend_config.slurm_partition_for_task(requirements, hints) {
            // Clamp or deny the CPU request against the partition limit.
            if let Some(max_cpu) = partition.max_cpu_per_task
                && required_cpu > max_cpu as f64
            {
                // Omit backend-specific details when configured to do so
                // (e.g. for reproducible output).
                let env_specific = if self.config.suppress_env_specific_output {
                    String::new()
                } else {
                    format!(", but the execution backend has a maximum of {max_cpu}",)
                };
                match self.config.task.cpu_limit_behavior {
                    TaskResourceLimitBehavior::TryWithMax => {
                        warn!(
                            "task requires at least {required_cpu} CPU{s}{env_specific}",
                            s = if required_cpu == 1.0 { "" } else { "s" },
                        );
                        required_cpu = max_cpu as f64;
                    }
                    TaskResourceLimitBehavior::Deny => {
                        bail!(
                            "task requires at least {required_cpu} CPU{s}{env_specific}",
                            s = if required_cpu == 1.0 { "" } else { "s" },
                        );
                    }
                }
            }
            // Clamp or deny the memory request against the partition limit.
            if let Some(max_memory) = partition.max_memory_per_task
                && required_memory > max_memory
            {
                let env_specific = if self.config.suppress_env_specific_output {
                    String::new()
                } else {
                    format!(
                        ", but the execution backend has a maximum of {max_memory} GiB",
                        max_memory = max_memory.as_u64() as f64 / ONE_GIBIBYTE
                    )
                };
                match self.config.task.memory_limit_behavior {
                    TaskResourceLimitBehavior::TryWithMax => {
                        warn!(
                            "task requires at least {required_memory} GiB of memory{env_specific}",
                            required_memory = required_memory.as_u64() as f64 / ONE_GIBIBYTE
                        );
                        required_memory = max_memory;
                    }
                    TaskResourceLimitBehavior::Deny => {
                        bail!(
                            "task requires at least {required_memory} GiB of memory{env_specific}",
                            required_memory = required_memory.as_u64() as f64 / ONE_GIBIBYTE
                        );
                    }
                }
            }
        }
        let containers = requirements::container(inputs, requirements, &self.config.task.container);
        Ok(super::TaskExecutionConstraints {
            container: Some(containers),
            cpu: required_cpu,
            memory: required_memory.as_u64(),
            gpu: Default::default(),
            fpga: Default::default(),
            disks: Default::default(),
        })
    }

    /// Executes a task by writing its command and generated Apptainer script
    /// to the attempt directory, submitting the script to Slurm, and waiting
    /// for the job to complete or be canceled.
    ///
    /// Returns `Ok(None)` when evaluation was canceled before completion.
    fn execute<'a>(
        &'a self,
        _: &'a Arc<dyn Transferer>,
        request: ExecuteTaskRequest<'a>,
    ) -> BoxFuture<'a, Result<Option<TaskExecutionResult>>> {
        async move {
            let backend_config = self.config.backend()?;
            let backend_config = backend_config
                .as_slurm_apptainer()
                .expect("configured backend is not Slurm Apptainer");
            // Set up the attempt's working directory and empty stdout/stderr
            // files before submission.
            let work_dir = request.work_dir();
            fs::create_dir_all(&work_dir).await.with_context(|| {
                format!(
                    "failed to create working directory `{path}`",
                    path = work_dir.display()
                )
            })?;
            let stdout_path = request.stdout_path();
            let _ = File::create(&stdout_path).await.with_context(|| {
                format!(
                    "failed to create stdout file `{path}`",
                    path = stdout_path.display()
                )
            })?;
            let stderr_path = request.stderr_path();
            let _ = File::create(&stderr_path).await.with_context(|| {
                format!(
                    "failed to create stderr file `{path}`",
                    path = stderr_path.display()
                )
            })?;
            // Write the task's command script.
            let command_path = request.command_path();
            fs::write(&command_path, request.command)
                .await
                .with_context(|| {
                    format!(
                        "failed to write command contents to `{path}`",
                        path = command_path.display()
                    )
                })?;
            // Generate the Apptainer invocation script; `None` indicates
            // evaluation was canceled during generation.
            let Some((apptainer_script, container)) = self
                .apptainer
                .generate_script(
                    &backend_config.apptainer_config,
                    &self.config.task.shell,
                    &request,
                    self.cancellation.first(),
                )
                .await?
            else {
                return Ok(None);
            };
            let apptainer_command_path = request.attempt_dir.join(APPTAINER_COMMAND_FILE_NAME);
            fs::write(&apptainer_command_path, apptainer_script)
                .await
                .with_context(|| {
                    format!(
                        "failed to write Apptainer command file `{}`",
                        apptainer_command_path.display()
                    )
                })?;
            // Both scripts must be executable so Slurm can run them.
            #[cfg(unix)]
            {
                use std::fs::Permissions;
                use std::os::unix::fs::PermissionsExt;
                fs::set_permissions(&command_path, Permissions::from_mode(0o770)).await?;
                fs::set_permissions(&apptainer_command_path, Permissions::from_mode(0o770)).await?;
            }
            let crankshaft_id = crankshaft::events::next_task_id();
            // The permit limits concurrent `sbatch` invocations only; it is
            // released as soon as the job has been submitted.
            let permit = self
                .permits
                .acquire()
                .await
                .context("failed to acquire permit for submitting job")?;
            let job = self.monitor.submit_job(backend_config, &request, crankshaft_id, &apptainer_command_path).await?;
            drop(permit);
            let name = job.task_name;
            let job_id = job.id;
            let task_token = CancellationToken::new();
            send_event!(
                self.events.crankshaft(),
                CrankshaftEvent::TaskCreated {
                    id: crankshaft_id,
                    name: name.clone(),
                    tes_id: None,
                    token: task_token.clone(),
                },
            );
            // Shared cancellation path: emit the canceled event, then
            // `scancel` the job. Awaited by at most one `select!` arm.
            let cancelled = async {
                send_event!(
                    self.events.crankshaft(),
                    CrankshaftEvent::TaskCanceled { id: crankshaft_id },
                );
                self.kill_job(job_id).await
            };
            let token = self.cancellation.second();
            // Race task-level cancellation, run-level cancellation, and job
            // completion reported by the monitor.
            let exit_code = tokio::select! {
                _ = task_token.cancelled() => {
                    if let Err(e) = cancelled.await {
                        error!("failed to cancel task `{name}` (Slurm job `{job_id}`): {e:#}");
                    }
                    return Ok(None);
                }
                _ = token.cancelled() => {
                    if let Err(e) = cancelled.await {
                        error!("failed to cancel task `{name}` (Slurm job `{job_id}`): {e:#}");
                    }
                    return Ok(None);
                }
                result = job.completed => match result.context("failed to wait for task to complete")? {
                    Ok(exit_code) => {
                        let exit_status = exit_code.into_exit_status();
                        send_event!(
                            self.events.crankshaft(),
                            CrankshaftEvent::TaskCompleted {
                                id: crankshaft_id,
                                exit_statuses: NonEmpty::new(exit_status),
                            }
                        );
                        exit_code.code()
                    },
                    Err(e) => {
                        send_event!(
                            self.events.crankshaft(),
                            CrankshaftEvent::TaskFailed {
                                id: crankshaft_id,
                                message: format!("{e:#}"),
                            },
                        );
                        return Err(e);
                    }
                }
            };
            Ok(Some(TaskExecutionResult {
                container: Some(container),
                exit_code: exit_code as i32,
                work_dir: EvaluationPath::from_local_path(work_dir),
                stdout: PrimitiveValue::new_file(
                    stdout_path
                        .into_os_string()
                        .into_string()
                        .expect("path should be UTF-8"),
                )
                .into(),
                stderr: PrimitiveValue::new_file(
                    stderr_path
                        .into_os_string()
                        .into_string()
                        .expect("path should be UTF-8"),
                )
                .into(),
            }))
        }
        .boxed()
    }
}