use crate::config::{ResolvedJob, SlurmConfig};
use crate::error::{FlecheError, Result};
use crate::registry::{JobStatus, LiveStatus};
use crate::ssh::{SshClient, shell_escape};
use serde::Serialize;
pub fn generate_sbatch_script(
job_id: &str,
job: &ResolvedJob,
workspace: &str,
job_dir: &str,
) -> String {
let mut script = String::new();
script.push_str("#!/bin/bash\n");
script.push_str(&format!(
"#SBATCH --job-name={}\n",
truncate_job_name(job_id)
));
script.push_str(&format!("#SBATCH --output={job_dir}/job.out\n"));
script.push_str(&format!("#SBATCH --error={job_dir}/job.err\n"));
let slurm = &job.slurm;
if let Some(ref partition) = slurm.partition {
script.push_str(&format!("#SBATCH --partition={partition}\n"));
}
if let Some(ref time) = slurm.time {
script.push_str(&format!("#SBATCH --time={time}\n"));
}
if let Some(gpus) = slurm.gpus {
if gpus > 0 {
script.push_str(&format!("#SBATCH --gpus={gpus}\n"));
}
}
if let Some(cpus) = slurm.cpus {
script.push_str(&format!("#SBATCH --cpus-per-task={cpus}\n"));
}
if let Some(ref memory) = slurm.memory {
script.push_str(&format!("#SBATCH --mem={memory}\n"));
}
if let Some(ref constraint) = slurm.constraint {
if !constraint.is_empty() {
script.push_str(&format!("#SBATCH --constraint={constraint}\n"));
}
}
if let Some(nodes) = slurm.nodes {
script.push_str(&format!("#SBATCH --nodes={nodes}\n"));
}
if let Some(ref exclude) = slurm.exclude {
if !exclude.is_empty() {
script.push_str(&format!("#SBATCH --exclude={exclude}\n"));
}
}
script.push('\n');
let valid_env: Vec<_> = job
.env
.iter()
.filter(|(k, _)| is_valid_env_var_name(k))
.collect();
if !valid_env.is_empty() {
script.push_str("# Environment variables\n");
for (key, value) in valid_env {
script.push_str(&format!("export {key}=\"{}\"\n", escape_bash_value(value)));
}
script.push('\n');
}
script.push_str("# Change to workspace\n");
script.push_str(&format!("cd {}\n\n", shell_escape(workspace)));
script.push_str("# Execute command\n");
script.push_str(&job.command);
if !job.command.ends_with('\n') {
script.push('\n');
}
script
}
/// Escapes `s` so it can be placed inside a double-quoted bash string.
///
/// Each backslash, double quote, `$` and backtick is prefixed with a backslash;
/// every other character is copied unchanged.
fn escape_bash_value(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '"' | '$' | '`') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
/// Upper bound, in bytes, on the job name emitted as `#SBATCH --job-name` (enforced by `truncate_job_name`).
const MAX_JOB_NAME_LENGTH: usize = 200;
/// Reports whether `name` can be used as a shell variable name in an `export` line.
///
/// The first character must be an ASCII letter or `_`. Every following character must be
/// an ASCII letter, digit or `_`. The empty string is rejected.
fn is_valid_env_var_name(name: &str) -> bool {
    match name.as_bytes().split_first() {
        None => false,
        Some((&head, tail)) => {
            let head_ok = head == b'_' || head.is_ascii_alphabetic();
            head_ok && tail.iter().all(|&b| b == b'_' || b.is_ascii_alphanumeric())
        }
    }
}
/// Returns `name` limited to at most [`MAX_JOB_NAME_LENGTH`] bytes, for use as the Slurm job name.
///
/// If the byte limit falls inside a multi-byte UTF-8 character, the cut moves back to the
/// preceding character boundary. That character is dropped instead of causing a slicing panic.
fn truncate_job_name(name: &str) -> &str {
    if name.len() <= MAX_JOB_NAME_LENGTH {
        return name;
    }
    // Index 0 is always a char boundary, so this loop terminates.
    let mut end = MAX_JOB_NAME_LENGTH;
    while !name.is_char_boundary(end) {
        end -= 1;
    }
    &name[..end]
}
/// Submits the `job.sbatch` script in `remote_dir` and returns the new Slurm job id.
///
/// When `dependency` is given, the job is submitted with `--dependency=afterok:<id>`.
/// All caller-supplied values are shell-escaped before they reach the remote command line.
///
/// # Errors
///
/// Propagates errors from `ssh.exec`. Returns [`FlecheError::SbatchFailed`] when no
/// `Submitted batch job <id>` line is found in the `sbatch` output.
pub async fn submit_job(
    ssh: &SshClient,
    remote_dir: &str,
    dependency: Option<&str>,
) -> Result<String> {
    // The dependency id ends up on a shell command line, so quote it like
    // every other interpolated value instead of splicing it in raw.
    let dep_flag = dependency
        .map(|slurm_id| {
            let spec = format!("afterok:{slurm_id}");
            format!(" --dependency={}", shell_escape(spec.as_str()))
        })
        .unwrap_or_default();
    let output = ssh
        .exec(&format!(
            "cd {} && sbatch{dep_flag} job.sbatch",
            shell_escape(remote_dir)
        ))
        .await?;
    let slurm_id = output
        .lines()
        .find_map(|line| {
            line.trim_start()
                .strip_prefix("Submitted batch job")
                .and_then(|rest| rest.split_whitespace().next())
                .map(str::to_string)
        })
        .ok_or_else(|| {
            FlecheError::SbatchFailed(format!("Could not parse sbatch output: {output}"))
        })?;
    Ok(slurm_id)
}
/// Maps a state name printed by `squeue -o %T` onto a [`JobStatus`], ignoring case.
///
/// These queue-side states become [`JobStatus::Pending`]: `PENDING`, `CONFIGURING`,
/// `RESV_DEL_HOLD`, `REQUEUE_FED`, `REQUEUE_HOLD`, `REQUEUED` and `SPECIAL_EXIT`.
/// Every other name is reported as [`JobStatus::Running`]. That includes the active
/// states (`RUNNING`, `COMPLETING`, `SIGNALING`, `STAGE_OUT`, `STOPPED`, `SUSPENDED`)
/// and anything unrecognized.
fn parse_squeue_state(state: &str) -> JobStatus {
    let normalized = state.to_uppercase();
    let is_queued = matches!(
        normalized.as_str(),
        "PENDING"
            | "CONFIGURING"
            | "RESV_DEL_HOLD"
            | "REQUEUE_FED"
            | "REQUEUE_HOLD"
            | "REQUEUED"
            | "SPECIAL_EXIT"
    );
    if is_queued {
        JobStatus::Pending
    } else {
        JobStatus::Running
    }
}
/// Maps a `sacct` `State` field onto a [`JobStatus`], ignoring case.
///
/// Only the first whitespace-separated word is examined, so annotated values such as
/// `"CANCELLED by 12345"` resolve to [`JobStatus::Cancelled`]. Unrecognized states are
/// treated as [`JobStatus::Failed`].
// The failure states are listed explicitly for readability even though the catch-all
// arm also yields `Failed`; that duplication is what the clippy allowance is for.
#[allow(clippy::match_same_arms)]
fn parse_sacct_state(state: &str) -> JobStatus {
    let normalized = state.to_uppercase();
    let keyword = match normalized.split_whitespace().next() {
        Some(word) => word,
        None => normalized.as_str(),
    };
    match keyword {
        "PENDING" => JobStatus::Pending,
        "RUNNING" => JobStatus::Running,
        "COMPLETED" => JobStatus::Completed,
        "CANCELLED" => JobStatus::Cancelled,
        "FAILED" | "TIMEOUT" | "OUT_OF_MEMORY" | "NODE_FAIL" | "PREEMPTED" | "BOOT_FAIL"
        | "DEADLINE" => JobStatus::Failed,
        _ => JobStatus::Failed,
    }
}
/// Converts a `sacct` `ExitCode` field of the form `"<code>:<signal>"` into one exit status.
///
/// If `<signal>` is non-zero, the job is treated as signal-terminated and the result is
/// `128 + signal` (e.g. `"0:9"` gives `137`). Otherwise `<code>` is returned unchanged.
///
/// Returns `None` in two cases:
/// - either half is missing or is not an integer;
/// - `128 + signal` does not fit in an `i32`.
pub fn parse_sacct_exit_code(raw: &str) -> Option<i32> {
    let (code_str, signal_str) = raw.split_once(':')?;
    let code: i32 = code_str.parse().ok()?;
    let signal: i32 = signal_str.parse().ok()?;
    if signal == 0 {
        Some(code)
    } else {
        // Checked so a malformed, huge signal value cannot overflow (a panic in debug builds).
        128_i32.checked_add(signal)
    }
}
/// Returns the current [`LiveStatus`] of Slurm job `slurm_id`.
///
/// `squeue` is queried first. While the job is queued or active, the state from the first
/// output line is mapped with `parse_squeue_state`.
///
/// Slurm also keeps recently finished jobs in `squeue` output (for `MinJobAge`) with
/// terminal states such as `COMPLETED` or `FAILED`. Those states, and jobs `squeue` no
/// longer lists, are resolved through `sacct`, which also yields the exit code. If `sacct`
/// returns nothing but `squeue` did report a terminal state, that state is reported
/// without an exit code.
///
/// # Errors
///
/// Returns [`FlecheError::SlurmQueryFailed`] when neither command reports a state.
/// Propagates errors from the SSH calls themselves.
pub async fn get_job_status(ssh: &SshClient, slurm_id: &str) -> Result<LiveStatus> {
    let escaped_id = shell_escape(slurm_id);

    let (success, stdout, _) = ssh
        .exec_allow_failure(&format!("squeue -j {escaped_id} -h -o %T"))
        .await?;
    // squeue prints one line per job (array tasks included); the first non-empty one decides.
    let queue_state = if success {
        stdout
            .lines()
            .map(str::trim)
            .find(|line| !line.is_empty())
            .map(str::to_string)
    } else {
        None
    };
    if let Some(state) = queue_state.as_deref() {
        if !is_terminal_squeue_state(state) {
            return Ok(LiveStatus::new(parse_squeue_state(state)));
        }
    }

    let (success, stdout, _) = ssh
        .exec_allow_failure(&format!(
            "sacct -j {escaped_id} -n -o State,ExitCode --parsable2 | head -1"
        ))
        .await?;
    if success && !stdout.trim().is_empty() {
        let fields: Vec<&str> = stdout.trim().split('|').collect();
        let raw_state = fields.first().unwrap_or(&"").to_string();
        let status = parse_sacct_state(&raw_state);
        let raw_exit = fields.get(1).map(|s| (*s).to_string());
        let exit_code = fields.get(1).and_then(|s| parse_sacct_exit_code(s));
        return Ok(LiveStatus {
            status,
            exit_code,
            slurm_state: Some(raw_state),
            sacct_exit_code: raw_exit,
        });
    }

    // Accounting gave nothing, but squeue already reported a terminal state. squeue and
    // sacct share state names, so map it with the sacct parser instead of failing.
    if let Some(state) = queue_state {
        return Ok(LiveStatus::new(parse_sacct_state(&state)));
    }

    Err(FlecheError::SlurmQueryFailed(slurm_id.to_string()))
}

/// Whether a `squeue -o %T` state means the job has already finished (successfully or not).
fn is_terminal_squeue_state(state: &str) -> bool {
    matches!(
        state.to_uppercase().as_str(),
        "COMPLETED"
            | "FAILED"
            | "CANCELLED"
            | "TIMEOUT"
            | "OUT_OF_MEMORY"
            | "NODE_FAIL"
            | "PREEMPTED"
            | "BOOT_FAIL"
            | "DEADLINE"
    )
}
/// Cancels Slurm job `slurm_id` by running `scancel` on the remote host.
///
/// The command's output is discarded.
///
/// # Errors
///
/// Propagates any error returned by `ssh.exec` for the `scancel` command.
pub async fn cancel_job(ssh: &SshClient, slurm_id: &str) -> Result<()> {
    let scancel_cmd = format!("scancel {}", shell_escape(slurm_id));
    ssh.exec(&scancel_cmd).await?;
    Ok(())
}
/// Resource usage of one Slurm job, as read from `sacct --parsable2` output.
///
/// Every field holds the raw column text from `sacct`, unparsed. A field is empty when
/// the corresponding column was blank or missing from the returned row.
#[derive(Debug, Default, Serialize)]
pub struct JobResourceUsage {
/// `sacct` `Elapsed` column.
pub elapsed: String,
/// `sacct` `TotalCPU` column.
pub total_cpu: String,
/// `sacct` `MaxRSS` column.
pub max_rss: String,
/// `sacct` `AllocTRES` column.
pub alloc_tres: String,
/// `sacct` `NodeList` column.
pub node_list: String,
}
/// Fetches elapsed time, CPU time, peak RSS, allocated TRES and node list for `slurm_id`
/// from `sacct`.
///
/// The job's `.batch` step is queried first, because per-step figures such as `MaxRSS`
/// are recorded on steps. If that query fails or returns no row, the job allocation
/// itself is queried instead. Columns missing from the chosen row are left empty.
///
/// # Errors
///
/// Propagates errors from the SSH calls.
pub async fn get_job_resource_usage(ssh: &SshClient, slurm_id: &str) -> Result<JobResourceUsage> {
    const COLUMNS: &str = "Elapsed,TotalCPU,MaxRSS,AllocTRES,NodeList";
    let escaped_id = shell_escape(slurm_id);

    let (batch_ok, batch_output, _) = ssh
        .exec_allow_failure(&format!(
            "sacct -j {escaped_id}.batch -n -o {COLUMNS} --parsable2 2>/dev/null"
        ))
        .await?;
    let batch_line = batch_output.lines().next().unwrap_or("");

    // sacct exits 0 with empty output when no record matches, so the fallback is
    // chosen from the output itself rather than from the exit status.
    let line = if batch_ok && !batch_line.trim().is_empty() {
        batch_line.to_string()
    } else {
        let job_output = ssh
            .exec(&format!(
                "sacct -j {escaped_id} -n -o {COLUMNS} --parsable2 | head -1"
            ))
            .await?;
        job_output.lines().next().unwrap_or("").to_string()
    };

    let fields: Vec<&str> = line.split('|').collect();
    let column = |index: usize| fields.get(index).copied().unwrap_or("").to_string();
    Ok(JobResourceUsage {
        elapsed: column(0),
        total_cpu: column(1),
        max_rss: column(2),
        alloc_tres: column(3),
        node_list: column(4),
    })
}
/// Assembles a [`SlurmConfig`] from individually supplied options.
///
/// There is one argument per `SlurmConfig` field, and each is moved into its field unchanged.
// One parameter per config field keeps call sites a flat list of option values, which
// exceeds clippy's default argument-count threshold (7) by design.
#[allow(clippy::too_many_arguments)]
pub fn slurm_config_from_cli(
    partition: Option<String>,
    time: Option<String>,
    gpus: Option<u32>,
    cpus: Option<u32>,
    memory: Option<String>,
    constraint: Option<String>,
    nodes: Option<u32>,
    exclude: Option<String>,
) -> SlurmConfig {
    SlurmConfig {
        partition,
        time,
        gpus,
        cpus,
        memory,
        constraint,
        nodes,
        exclude,
    }
}
// Unit tests for the pure helpers in this module: script generation, bash escaping,
// state / exit-code parsing and job-name truncation. The SSH-backed functions are not
// exercised here.
#[cfg(test)]
mod tests {
use super::*;
use indexmap::IndexMap;
#[test]
fn test_escape_bash_value() {
assert_eq!(escape_bash_value("simple"), "simple");
assert_eq!(escape_bash_value("with\"quote"), "with\\\"quote");
assert_eq!(escape_bash_value("with$var"), "with\\$var");
assert_eq!(escape_bash_value("with`cmd`"), "with\\`cmd\\`");
assert_eq!(escape_bash_value("with\\backslash"), "with\\\\backslash");
assert_eq!(
escape_bash_value("all\"$`\\special"),
"all\\\"\\$\\`\\\\special"
);
}
#[test]
fn test_parse_squeue_state_pending() {
assert_eq!(parse_squeue_state("PENDING"), JobStatus::Pending);
assert_eq!(parse_squeue_state("pending"), JobStatus::Pending);
assert_eq!(parse_squeue_state("CONFIGURING"), JobStatus::Pending);
assert_eq!(parse_squeue_state("REQUEUED"), JobStatus::Pending);
}
#[test]
fn test_parse_squeue_state_running() {
assert_eq!(parse_squeue_state("RUNNING"), JobStatus::Running);
assert_eq!(parse_squeue_state("running"), JobStatus::Running);
assert_eq!(parse_squeue_state("COMPLETING"), JobStatus::Running);
assert_eq!(parse_squeue_state("SUSPENDED"), JobStatus::Running);
}
// Pins the deliberate fallback: squeue states not listed in the parser map to Running.
#[test]
fn test_parse_squeue_state_unknown_defaults_to_running() {
assert_eq!(parse_squeue_state("UNKNOWN"), JobStatus::Running);
assert_eq!(parse_squeue_state("WEIRD_STATE"), JobStatus::Running);
}
#[test]
fn test_parse_sacct_state_completed() {
assert_eq!(parse_sacct_state("COMPLETED"), JobStatus::Completed);
assert_eq!(parse_sacct_state("completed"), JobStatus::Completed);
}
#[test]
fn test_parse_sacct_state_failed() {
assert_eq!(parse_sacct_state("FAILED"), JobStatus::Failed);
assert_eq!(parse_sacct_state("TIMEOUT"), JobStatus::Failed);
assert_eq!(parse_sacct_state("OUT_OF_MEMORY"), JobStatus::Failed);
assert_eq!(parse_sacct_state("NODE_FAIL"), JobStatus::Failed);
}
// sacct may append detail after the state word (e.g. "CANCELLED by <uid>"); only the
// first word is significant.
#[test]
fn test_parse_sacct_state_cancelled() {
assert_eq!(parse_sacct_state("CANCELLED"), JobStatus::Cancelled);
assert_eq!(
parse_sacct_state("CANCELLED by 12345"),
JobStatus::Cancelled
);
}
// Pins the deliberate fallback: sacct states not listed in the parser map to Failed.
#[test]
fn test_parse_sacct_state_unknown_defaults_to_failed() {
assert_eq!(parse_sacct_state("UNKNOWN"), JobStatus::Failed);
assert_eq!(parse_sacct_state("WEIRD_STATE"), JobStatus::Failed);
}
#[test]
fn test_generate_sbatch_script_basic() {
let job = ResolvedJob {
name: "test".to_string(),
command: "echo hello".to_string(),
inputs: vec![],
outputs: vec![],
slurm: SlurmConfig::default(),
env: IndexMap::new(),
host: "test".to_string(),
exec: false,
};
let script = generate_sbatch_script("test-123", &job, "/workspace", "/jobs/test-123");
assert!(script.starts_with("#!/bin/bash\n"));
assert!(script.contains("#SBATCH --job-name=test-123"));
assert!(script.contains("#SBATCH --output=/jobs/test-123/job.out"));
assert!(script.contains("#SBATCH --error=/jobs/test-123/job.err"));
assert!(script.contains("cd '/workspace'"));
assert!(script.contains("echo hello"));
}
#[test]
fn test_generate_sbatch_script_with_slurm_options() {
let job = ResolvedJob {
name: "test".to_string(),
command: "python train.py".to_string(),
inputs: vec![],
outputs: vec![],
slurm: SlurmConfig {
partition: Some("gpu".to_string()),
time: Some("8:00:00".to_string()),
gpus: Some(2),
cpus: Some(16),
memory: Some("64G".to_string()),
constraint: Some("a100".to_string()),
nodes: Some(1),
exclude: Some("node01".to_string()),
},
env: IndexMap::new(),
host: "test".to_string(),
exec: false,
};
let script = generate_sbatch_script("train-456", &job, "/workspace", "/jobs/train-456");
assert!(script.contains("#SBATCH --partition=gpu"));
assert!(script.contains("#SBATCH --time=8:00:00"));
assert!(script.contains("#SBATCH --gpus=2"));
assert!(script.contains("#SBATCH --cpus-per-task=16"));
assert!(script.contains("#SBATCH --mem=64G"));
assert!(script.contains("#SBATCH --constraint=a100"));
assert!(script.contains("#SBATCH --nodes=1"));
assert!(script.contains("#SBATCH --exclude=node01"));
}
#[test]
fn test_generate_sbatch_script_with_env_vars() {
let mut env = IndexMap::new();
env.insert("FOO".to_string(), "bar".to_string());
env.insert("PATH_VAR".to_string(), "/some/path".to_string());
let job = ResolvedJob {
name: "test".to_string(),
command: "echo $FOO".to_string(),
inputs: vec![],
outputs: vec![],
slurm: SlurmConfig::default(),
env,
host: "test".to_string(),
exec: false,
};
let script = generate_sbatch_script("test-789", &job, "/ws", "/jobs/test-789");
assert!(script.contains("export FOO=\"bar\""));
assert!(script.contains("export PATH_VAR=\"/some/path\""));
}
#[test]
fn test_generate_sbatch_script_escapes_env_values() {
let mut env = IndexMap::new();
env.insert("QUOTED".to_string(), "value\"with\"quotes".to_string());
let job = ResolvedJob {
name: "test".to_string(),
command: "echo test".to_string(),
inputs: vec![],
outputs: vec![],
slurm: SlurmConfig::default(),
env,
host: "test".to_string(),
exec: false,
};
let script = generate_sbatch_script("test-esc", &job, "/ws", "/jobs/test-esc");
assert!(script.contains("export QUOTED=\"value\\\"with\\\"quotes\""));
}
#[test]
fn test_parse_squeue_real_running() {
assert_eq!(parse_squeue_state("RUNNING"), JobStatus::Running);
}
// get_job_status queries squeue with `-o %T` (long state names). The short code "PD"
// is therefore not recognized and falls through to the Running default.
#[test]
fn test_parse_squeue_real_pending() {
assert_eq!(parse_squeue_state("PD"), JobStatus::Running); assert_eq!(parse_squeue_state("PENDING"), JobStatus::Pending);
}
#[test]
fn test_parse_sacct_real_completed() {
assert_eq!(parse_sacct_state("COMPLETED"), JobStatus::Completed);
}
#[test]
fn test_parse_sacct_real_failed() {
assert_eq!(parse_sacct_state("FAILED"), JobStatus::Failed);
}
#[test]
fn test_parse_sacct_real_pending() {
assert_eq!(parse_sacct_state("PENDING"), JobStatus::Pending);
}
#[test]
fn test_truncate_job_name_short() {
assert_eq!(truncate_job_name("my-job"), "my-job");
}
#[test]
fn test_truncate_job_name_exact() {
let name = "a".repeat(MAX_JOB_NAME_LENGTH);
assert_eq!(truncate_job_name(&name), name);
}
#[test]
fn test_truncate_job_name_long() {
let name = "a".repeat(MAX_JOB_NAME_LENGTH + 50);
let truncated = truncate_job_name(&name);
assert_eq!(truncated.len(), MAX_JOB_NAME_LENGTH);
}
#[test]
fn test_parse_sacct_exit_code_success() {
assert_eq!(parse_sacct_exit_code("0:0"), Some(0));
}
#[test]
fn test_parse_sacct_exit_code_failure() {
assert_eq!(parse_sacct_exit_code("1:0"), Some(1));
assert_eq!(parse_sacct_exit_code("2:0"), Some(2));
assert_eq!(parse_sacct_exit_code("127:0"), Some(127));
}
// A non-zero signal half maps to the shell convention 128 + signal (9 -> 137, 15 -> 143).
#[test]
fn test_parse_sacct_exit_code_signal() {
assert_eq!(parse_sacct_exit_code("0:9"), Some(137));
assert_eq!(parse_sacct_exit_code("0:15"), Some(143));
}
#[test]
fn test_parse_sacct_exit_code_invalid() {
assert_eq!(parse_sacct_exit_code(""), None);
assert_eq!(parse_sacct_exit_code("abc"), None);
assert_eq!(parse_sacct_exit_code("1"), None);
assert_eq!(parse_sacct_exit_code("a:b"), None);
}
}