use crate::multicall::gbatch::cli;
use anyhow::{anyhow, Context, Result};
use clap::Parser;
use gflow::client::Client;
use gflow::core::job::{GpuSharingMode, Job, JobNotifications};
use gflow::utils::parsers::parse_array_spec;
use gflow::utils::{generate_param_combinations, parse_param_spec};
use lettre::message::Mailbox;
use std::{collections::HashMap, env, fs, io::Read, path::PathBuf};
/// Apply the configured project policy to `job.project`, rewriting the
/// field in place with the normalized project name (or clearing it).
fn validate_project(job: &mut Job, config: &gflow::config::Config) -> Result<()> {
    let policy_result =
        gflow::utils::validate_project_policy(job.project.as_deref(), &config.projects)?;
    job.project = policy_result.map(|name| name.into());
    Ok(())
}
/// Choose the project name for a job: an explicit CLI `--project` wins
/// over any project declared by the script's `# GFLOW` directives.
fn resolve_project(args: &cli::AddArgs, script_args: &Option<cli::AddArgs>) -> Option<String> {
    match args.project.clone() {
        Some(project) => Some(project),
        None => script_args
            .as_ref()
            .and_then(|script| script.project.clone()),
    }
}
/// Terminal job states that trigger an email when the user supplies
/// `--notify-email` without an explicit `--notify-on` list.
fn default_per_job_notification_events() -> Vec<String> {
    ["job_completed", "job_failed", "job_timeout", "job_cancelled"]
        .iter()
        .map(|event| event.to_string())
        .collect()
}
/// Merge notification settings from the CLI and optional script directives.
///
/// Emails accumulate from both sources (script first, then CLI). Events are
/// taken from the first non-empty source in priority order: CLI `--notify-on`,
/// script `--notify-on`, then the default terminal-event set. When no email
/// is configured at all, notifications stay disabled.
///
/// Fix: the argument to `validate_job_notifications` had been corrupted to
/// `¬ifications` (mangled `&notifications`), which does not compile.
fn resolve_job_notifications(
    args: &cli::AddArgs,
    script_args: Option<&cli::AddArgs>,
) -> Result<JobNotifications> {
    let mut emails = script_args
        .map(|s| s.notify_email.clone())
        .unwrap_or_default();
    emails.extend(args.notify_email.iter().cloned());
    if emails.is_empty() {
        // No recipients: ignore any --notify-on and keep notifications off.
        return Ok(JobNotifications::default());
    }
    let events = if !args.notify_on.is_empty() {
        args.notify_on.clone()
    } else if let Some(script_args) = script_args {
        if !script_args.notify_on.is_empty() {
            script_args.notify_on.clone()
        } else {
            default_per_job_notification_events()
        }
    } else {
        default_per_job_notification_events()
    };
    let notifications = JobNotifications::normalized(emails, events);
    validate_job_notifications(&notifications)?;
    Ok(notifications)
}
/// Ensure every configured notification address parses as a valid mailbox.
///
/// Fix: the loop iterated over `¬ifications.emails` (mangled
/// `&notifications.emails`), which does not compile.
fn validate_job_notifications(notifications: &JobNotifications) -> Result<()> {
    for email in &notifications.emails {
        email
            .as_str()
            .parse::<Mailbox>()
            .map_err(|e| anyhow!("Invalid --notify-email '{}': {e}", email))?;
    }
    Ok(())
}
/// Reject shared-GPU jobs that do not cap their GPU memory usage.
fn validate_shared_requires_gpu_memory(job: &Job) -> Result<()> {
    let shared_without_cap =
        job.gpu_sharing_mode == GpuSharingMode::Shared && job.gpu_memory_limit_mb.is_none();
    if shared_without_cap {
        anyhow::bail!(
            "Shared jobs must set a GPU memory limit. Use --gpu-memory (alias: --max-gpu-mem) with --shared."
        );
    }
    Ok(())
}
/// Substitute `{name}` placeholders with raw parameter values.
///
/// Used only for human-facing `--dry-run` output; no sanitization is applied.
fn preview_substitute(command: &str, parameters: &HashMap<String, String>) -> String {
    parameters
        .iter()
        .fold(command.to_string(), |rendered, (name, value)| {
            rendered.replace(&format!("{{{}}}", name), value)
        })
}
/// Substitute `{name}` placeholders into a run-name template.
///
/// Each value is sanitized through the tmux session-name normalizer before
/// substitution so the resulting run name is usable as a session name.
fn substitute_template(template: &str, parameters: &HashMap<String, String>) -> String {
    parameters
        .iter()
        .fold(template.to_string(), |rendered, (name, value)| {
            let safe_value = gflow::tmux::normalize_session_name(value);
            rendered.replace(&format!("{{{}}}", name), &safe_value)
        })
}
/// Load parameter sets from a CSV file.
///
/// The first row supplies the parameter names; every following row becomes
/// one `HashMap` of name → value (both sides whitespace-trimmed). Errors on
/// a missing header row or a file with no data rows.
fn parse_param_file(path: &PathBuf) -> Result<Vec<HashMap<String, String>>> {
    let mut reader = csv::Reader::from_path(path).context("Failed to read parameter file")?;
    let headers = reader
        .headers()?
        .iter()
        .map(|h| h.trim().to_string())
        .collect::<Vec<String>>();
    if headers.is_empty() {
        return Err(anyhow!("CSV file must have a header row"));
    }
    let mut param_sets = Vec::new();
    for record in reader.records() {
        let record = record?;
        // Pair each field with its header; extra fields beyond the header
        // count are dropped, short rows simply yield fewer keys.
        let params: HashMap<String, String> = headers
            .iter()
            .zip(record.iter())
            .map(|(header, value)| (header.clone(), value.trim().to_string()))
            .collect();
        param_sets.push(params);
    }
    if param_sets.is_empty() {
        return Err(anyhow!("CSV file contains no data rows"));
    }
    Ok(param_sets)
}
pub(crate) async fn handle_add(
config: &gflow::config::Config,
add_args: cli::AddArgs,
use_stdin: bool,
) -> Result<()> {
let client = Client::build(config).context("Failed to build client")?;
let stdin_content = if use_stdin {
let mut buffer = String::new();
std::io::stdin()
.read_to_string(&mut buffer)
.context("Failed to read from stdin")?;
if buffer.trim().is_empty() {
anyhow::bail!("No content provided via stdin");
}
Some(buffer)
} else {
None
};
if !add_args.param.is_empty() && add_args.array.is_some() {
anyhow::bail!("Cannot use both --param and --array together");
}
if add_args.param_file.is_some() && add_args.array.is_some() {
anyhow::bail!("Cannot use both --param-file and --array together");
}
if let Some(ref param_file) = add_args.param_file {
let mut param_combinations = parse_param_file(param_file)?;
if !add_args.param.is_empty() {
let mut param_specs = Vec::new();
for spec in &add_args.param {
param_specs.push(parse_param_spec(spec)?);
}
let cli_combinations = generate_param_combinations(¶m_specs);
let mut merged = Vec::new();
for file_params in ¶m_combinations {
for cli_params in &cli_combinations {
let mut combined = file_params.clone();
combined.extend(cli_params.clone());
merged.push(combined);
}
}
param_combinations = merged;
}
let group_id = if add_args.max_concurrent.is_some() {
Some(uuid::Uuid::new_v4())
} else {
None
};
if add_args.dry_run {
println!("Would submit {} batch job(s):", param_combinations.len());
for (idx, params) in param_combinations.iter().enumerate() {
let job = build_job_with_params(
&add_args,
params.clone(),
&client,
stdin_content.as_ref(),
)
.await?;
let mut cmd = if let Some(c) = &job.command {
c.to_string()
} else if let Some(s) = &job.script {
s.to_string_lossy().to_string()
} else {
String::new()
};
cmd = preview_substitute(&cmd, params);
println!(" [{}] {} (GPUs: {})", idx + 1, cmd, job.gpus);
}
return Ok(());
}
let mut jobs = Vec::with_capacity(param_combinations.len());
for params in param_combinations {
let mut job =
build_job_with_params(&add_args, params, &client, stdin_content.as_ref()).await?;
validate_project(&mut job, config)?;
job.group_id = group_id;
job.max_concurrent = add_args.max_concurrent;
jobs.push(job);
}
let responses = client
.add_jobs(jobs)
.await
.context("Failed to add batch jobs")?;
if let Some(ref gid) = group_id {
println!(
"Submitted {} batch jobs with group_id: {}",
responses.len(),
gid
);
println!(
" (You can update the limit with: gctl set-limit {} <N>)",
responses[0].id
);
}
for response in responses {
println!(
"Submitted batch job {} ({})",
response.id, response.run_name
);
}
return Ok(());
}
if !add_args.param.is_empty() {
let mut param_specs = Vec::new();
for spec in &add_args.param {
param_specs.push(parse_param_spec(spec)?);
}
let param_combinations = generate_param_combinations(¶m_specs);
let group_id = if add_args.max_concurrent.is_some() {
Some(uuid::Uuid::new_v4())
} else {
None
};
if add_args.dry_run {
println!("Would submit {} batch job(s):", param_combinations.len());
for (idx, params) in param_combinations.iter().enumerate() {
let job = build_job_with_params(
&add_args,
params.clone(),
&client,
stdin_content.as_ref(),
)
.await?;
let mut cmd = if let Some(c) = &job.command {
c.to_string()
} else if let Some(s) = &job.script {
s.to_string_lossy().to_string()
} else {
String::new()
};
cmd = preview_substitute(&cmd, params);
println!(" [{}] {} (GPUs: {})", idx + 1, cmd, job.gpus);
}
return Ok(());
}
let mut jobs = Vec::with_capacity(param_combinations.len());
for params in param_combinations {
let mut job =
build_job_with_params(&add_args, params, &client, stdin_content.as_ref()).await?;
validate_project(&mut job, config)?;
job.group_id = group_id;
job.max_concurrent = add_args.max_concurrent;
jobs.push(job);
}
let responses = client
.add_jobs(jobs)
.await
.context("Failed to add batch jobs")?;
if let Some(ref gid) = group_id {
println!(
"Submitted {} batch jobs with group_id: {}",
responses.len(),
gid
);
println!(
" (You can update the limit with: gctl set-limit {} <N>)",
responses[0].id
);
}
for response in responses {
println!(
"Submitted batch job {} ({})",
response.id, response.run_name
);
}
return Ok(());
}
if let Some(array_spec) = &add_args.array {
let task_ids = parse_array_spec(array_spec)?;
let group_id = if add_args.max_concurrent.is_some() {
Some(uuid::Uuid::new_v4())
} else {
None
};
if add_args.dry_run {
println!("Would submit {} array job(s):", task_ids.len());
for (idx, task_id) in task_ids.iter().enumerate() {
let job =
build_job(&add_args, Some(*task_id), &client, stdin_content.as_ref()).await?;
let cmd = if let Some(c) = &job.command {
c.to_string()
} else if let Some(s) = &job.script {
s.to_string_lossy().to_string()
} else {
String::new()
};
println!(
" [{}] {} (GPUs: {}, task_id: {})",
idx + 1,
cmd,
job.gpus,
task_id
);
}
return Ok(());
}
let mut jobs = Vec::with_capacity(task_ids.len());
for task_id in task_ids {
let mut job =
build_job(&add_args, Some(task_id), &client, stdin_content.as_ref()).await?;
validate_project(&mut job, config)?;
job.group_id = group_id;
job.max_concurrent = add_args.max_concurrent;
jobs.push(job);
}
let responses = client
.add_jobs(jobs)
.await
.context("Failed to add batch jobs")?;
if let Some(ref gid) = group_id {
println!(
"Submitted {} batch jobs with group_id: {}",
responses.len(),
gid
);
println!(
" (You can update the limit with: gctl set-limit {} <N>)",
responses[0].id
);
}
for response in responses {
println!(
"Submitted batch job {} ({})",
response.id, response.run_name
);
}
return Ok(());
}
if add_args.dry_run {
let job = build_job(&add_args, None, &client, stdin_content.as_ref()).await?;
println!("Would submit 1 batch job:");
let cmd = if let Some(c) = &job.command {
c.to_string()
} else if let Some(s) = &job.script {
s.to_string_lossy().to_string()
} else {
String::new()
};
println!(" [1] {} (GPUs: {})", cmd, job.gpus);
return Ok(());
}
let mut job = build_job(&add_args, None, &client, stdin_content.as_ref()).await?;
validate_project(&mut job, config)?;
let response = client.add_job(job).await.context("Failed to add job")?;
println!(
"Submitted batch job {} ({})",
response.id, response.run_name
);
Ok(())
}
/// Name of the conda environment active in the submitting shell, if any.
///
/// Treats an empty `CONDA_DEFAULT_ENV` the same as an unset one.
fn detect_current_conda_env() -> Option<String> {
    match env::var("CONDA_DEFAULT_ENV") {
        Ok(name) if !name.is_empty() => Some(name),
        _ => None,
    }
}
/// Build a [`Job`] from CLI arguments for one (possibly array-task) submission.
///
/// For every tunable option (gpus, shared, priority, conda env, project,
/// notifications, time/memory/GPU-memory limits) the precedence is:
/// explicit CLI flag, then any `# GFLOW` directive in the script (stdin or
/// file), then a built-in default. `task_id` is `Some` only for `--array`
/// expansions. The job runs from the caller's current working directory.
async fn build_job(
    args: &cli::AddArgs,
    task_id: Option<u32>,
    client: &Client,
    stdin_content: Option<&String>,
) -> Result<Job> {
    let mut builder = Job::builder();
    // Jobs execute from the directory the user submitted in.
    let run_dir = std::env::current_dir().context("Failed to get current directory")?;
    builder = builder.run_dir(run_dir);
    builder = builder.task_id(task_id);
    let username = gflow::platform::get_current_username();
    builder = builder.submitted_by(username);
    builder = builder.run_name(args.name.clone());
    // Parse CLI-level limits up front; script-level fallbacks are applied
    // per-branch below only when the CLI did not set a value.
    let time_limit = if let Some(time_str) = &args.time {
        Some(gflow::utils::parse_time_limit(time_str)?)
    } else {
        None
    };
    let memory_limit_mb = if let Some(memory_str) = &args.memory {
        Some(gflow::utils::parse_memory_limit(memory_str)?)
    } else {
        None
    };
    let gpu_memory_limit_mb = if let Some(memory_str) = &args.gpu_memory {
        Some(gflow::utils::parse_memory_limit(memory_str)?)
    } else {
        None
    };
    // Dependency flags in priority order: --depends-on-all, --depends-on-any,
    // then single --depends-on (treated as mode All). `@name` references are
    // resolved through the server.
    let (depends_on_ids, dependency_mode) = if let Some(ref deps_all) = args.depends_on_all {
        let ids = parse_dependency_list(deps_all, client).await?;
        (ids, Some(gflow::core::job::DependencyMode::All))
    } else if let Some(ref deps_any) = args.depends_on_any {
        let ids = parse_dependency_list(deps_any, client).await?;
        (ids, Some(gflow::core::job::DependencyMode::Any))
    } else if let Some(ref dep) = args.depends_on {
        let dep_id = resolve_dependency(Some(dep.clone()), client).await?;
        if let Some(id) = dep_id {
            (vec![id], Some(gflow::core::job::DependencyMode::All))
        } else {
            (vec![], None)
        }
    } else {
        (vec![], None)
    };
    builder = builder.depends_on_ids(depends_on_ids.clone());
    builder = builder.dependency_mode(dependency_mode);
    builder = builder.auto_cancel_on_dependency_failure(!args.no_auto_cancel);
    builder = builder.max_retries(args.max_retries.unwrap_or(0));
    // Default (disabled) notifications; overwritten below once script
    // directives have been consulted.
    builder = builder.notifications(JobNotifications::default());
    if depends_on_ids.len() == 1 {
        // Keep the legacy single-dependency field populated for compatibility.
        builder = builder.depends_on(Some(depends_on_ids[0]));
    }
    if let Some(content) = stdin_content {
        // Stdin script: mine it for # GFLOW directives, then persist it to a
        // temp file so the scheduler can execute it later.
        let script_args = parse_script_content_for_args(content)?;
        let temp_script = save_stdin_to_temp_file(content)?;
        builder = builder.script(temp_script);
        builder = builder.gpus(args.gpus.or(script_args.gpus).unwrap_or(0));
        builder = builder.shared(args.shared || script_args.shared);
        builder = builder.priority(args.priority.or(script_args.priority).unwrap_or(10));
        builder = builder.project(resolve_project(args, &Some(script_args.clone())));
        builder = builder.notifications(resolve_job_notifications(args, Some(&script_args))?);
        builder = builder.conda_env(args.conda_env.clone().or(script_args.conda_env));
        let final_time_limit = if time_limit.is_some() {
            time_limit
        } else if let Some(script_time_str) = &script_args.time {
            Some(gflow::utils::parse_time_limit(script_time_str)?)
        } else {
            None
        };
        builder = builder.time_limit(final_time_limit);
        let final_memory_limit = if memory_limit_mb.is_some() {
            memory_limit_mb
        } else if let Some(script_memory_str) = &script_args.memory {
            Some(gflow::utils::parse_memory_limit(script_memory_str)?)
        } else {
            None
        };
        builder = builder.memory_limit_mb(final_memory_limit);
        let final_gpu_memory_limit = if gpu_memory_limit_mb.is_some() {
            gpu_memory_limit_mb
        } else if let Some(script_gpu_memory_str) = &script_args.gpu_memory {
            Some(gflow::utils::parse_memory_limit(script_gpu_memory_str)?)
        } else {
            None
        };
        builder = builder.gpu_memory_limit_mb(final_gpu_memory_limit);
    } else {
        // A single positional argument naming an existing file is treated as
        // a script; anything else is a shell command line.
        let is_script =
            args.script_or_command.len() == 1 && PathBuf::from(&args.script_or_command[0]).exists();
        if is_script {
            let script_path = make_absolute_path(PathBuf::from(&args.script_or_command[0]))?;
            let script_args = parse_script_for_args(&script_path)?;
            builder = builder.script(script_path);
            builder = builder.gpus(args.gpus.or(script_args.gpus).unwrap_or(0));
            builder = builder.shared(args.shared || script_args.shared);
            builder = builder.priority(args.priority.or(script_args.priority).unwrap_or(10));
            builder = builder.notifications(resolve_job_notifications(args, Some(&script_args))?);
            builder = builder.conda_env(args.conda_env.clone().or(script_args.conda_env));
            let final_project = args.project.clone().or(script_args.project);
            builder = builder.project(final_project);
            let final_time_limit = if time_limit.is_some() {
                time_limit
            } else if let Some(script_time_str) = &script_args.time {
                Some(gflow::utils::parse_time_limit(script_time_str)?)
            } else {
                None
            };
            builder = builder.time_limit(final_time_limit);
            let final_memory_limit = if memory_limit_mb.is_some() {
                memory_limit_mb
            } else if let Some(script_memory_str) = &script_args.memory {
                Some(gflow::utils::parse_memory_limit(script_memory_str)?)
            } else {
                None
            };
            builder = builder.memory_limit_mb(final_memory_limit);
            let final_gpu_memory_limit = if gpu_memory_limit_mb.is_some() {
                gpu_memory_limit_mb
            } else if let Some(script_gpu_memory_str) = &script_args.gpu_memory {
                Some(gflow::utils::parse_memory_limit(script_gpu_memory_str)?)
            } else {
                None
            };
            builder = builder.gpu_memory_limit_mb(final_gpu_memory_limit);
        } else {
            // Command form: shell-escape each argument and join into one line.
            let command = args
                .script_or_command
                .iter()
                .map(|arg| shell_escape::escape(arg.into()))
                .collect::<Vec<_>>()
                .join(" ");
            builder = builder.command(command);
            builder = builder.gpus(args.gpus.unwrap_or(0));
            builder = builder.shared(args.shared);
            builder = builder.priority(args.priority.unwrap_or(10));
            // No script to consult: inherit the submitter's active conda env.
            let conda_env = args.conda_env.clone().or_else(detect_current_conda_env);
            builder = builder.conda_env(conda_env);
            builder = builder.project(resolve_project(args, &None));
            builder = builder.notifications(resolve_job_notifications(args, None)?);
            builder = builder.time_limit(time_limit);
            builder = builder.memory_limit_mb(memory_limit_mb);
            builder = builder.gpu_memory_limit_mb(gpu_memory_limit_mb);
        }
    }
    builder = builder.auto_close_tmux(args.auto_close);
    let job = builder.build();
    // Invariant enforced at build time: shared jobs must cap GPU memory.
    validate_shared_requires_gpu_memory(&job)?;
    Ok(job)
}
/// Build a [`Job`] for one parameter combination of a `--param` /
/// `--param-file` batch.
///
/// Same option precedence as `build_job` (CLI flag, then `# GFLOW` script
/// directive, then default), plus: the combination is stored on the job as
/// its `parameters`, and `--name-template` placeholders are substituted
/// (tmux-sanitized) to produce the run name.
///
/// Fix: the `substitute_template` call received `¶meters` (mangled
/// `&parameters`), which does not compile.
async fn build_job_with_params(
    args: &cli::AddArgs,
    parameters: HashMap<String, String>,
    client: &Client,
    stdin_content: Option<&String>,
) -> Result<Job> {
    let mut builder = Job::builder();
    // Jobs execute from the directory the user submitted in.
    let run_dir = std::env::current_dir().context("Failed to get current directory")?;
    builder = builder.run_dir(run_dir);
    builder = builder.task_id(None);
    builder = builder.parameters(parameters.clone());
    let username = gflow::platform::get_current_username();
    builder = builder.submitted_by(username);
    // --name-template beats --name; placeholders come from this combination.
    let run_name = if let Some(ref template) = args.name_template {
        Some(substitute_template(template, &parameters))
    } else {
        args.name.clone()
    };
    builder = builder.run_name(run_name);
    // Parse CLI-level limits up front; script-level fallbacks below.
    let time_limit = if let Some(time_str) = &args.time {
        Some(gflow::utils::parse_time_limit(time_str)?)
    } else {
        None
    };
    let memory_limit_mb = if let Some(memory_str) = &args.memory {
        Some(gflow::utils::parse_memory_limit(memory_str)?)
    } else {
        None
    };
    let gpu_memory_limit_mb = if let Some(memory_str) = &args.gpu_memory {
        Some(gflow::utils::parse_memory_limit(memory_str)?)
    } else {
        None
    };
    // Dependency flags in priority order: --depends-on-all, --depends-on-any,
    // then single --depends-on (treated as mode All).
    let (depends_on_ids, dependency_mode) = if let Some(ref deps_all) = args.depends_on_all {
        let ids = parse_dependency_list(deps_all, client).await?;
        (ids, Some(gflow::core::job::DependencyMode::All))
    } else if let Some(ref deps_any) = args.depends_on_any {
        let ids = parse_dependency_list(deps_any, client).await?;
        (ids, Some(gflow::core::job::DependencyMode::Any))
    } else if let Some(ref dep) = args.depends_on {
        let dep_id = resolve_dependency(Some(dep.clone()), client).await?;
        if let Some(id) = dep_id {
            (vec![id], Some(gflow::core::job::DependencyMode::All))
        } else {
            (vec![], None)
        }
    } else {
        (vec![], None)
    };
    builder = builder.depends_on_ids(depends_on_ids.clone());
    builder = builder.dependency_mode(dependency_mode);
    builder = builder.auto_cancel_on_dependency_failure(!args.no_auto_cancel);
    builder = builder.max_retries(args.max_retries.unwrap_or(0));
    // Default (disabled) notifications; overwritten below once script
    // directives have been consulted.
    builder = builder.notifications(JobNotifications::default());
    if depends_on_ids.len() == 1 {
        // Keep the legacy single-dependency field populated for compatibility.
        builder = builder.depends_on(Some(depends_on_ids[0]));
    }
    if let Some(content) = stdin_content {
        // Stdin script: mine it for # GFLOW directives, then persist it to a
        // temp file so the scheduler can execute it later.
        let script_args = parse_script_content_for_args(content)?;
        let temp_script = save_stdin_to_temp_file(content)?;
        builder = builder.script(temp_script);
        builder = builder.gpus(args.gpus.or(script_args.gpus).unwrap_or(0));
        builder = builder.shared(args.shared || script_args.shared);
        builder = builder.priority(args.priority.or(script_args.priority).unwrap_or(10));
        builder = builder.project(resolve_project(args, &Some(script_args.clone())));
        builder = builder.notifications(resolve_job_notifications(args, Some(&script_args))?);
        builder = builder.conda_env(args.conda_env.clone().or(script_args.conda_env));
        let final_time_limit = if time_limit.is_some() {
            time_limit
        } else if let Some(script_time_str) = &script_args.time {
            Some(gflow::utils::parse_time_limit(script_time_str)?)
        } else {
            None
        };
        builder = builder.time_limit(final_time_limit);
        let final_memory_limit = if memory_limit_mb.is_some() {
            memory_limit_mb
        } else if let Some(script_memory_str) = &script_args.memory {
            Some(gflow::utils::parse_memory_limit(script_memory_str)?)
        } else {
            None
        };
        builder = builder.memory_limit_mb(final_memory_limit);
        let final_gpu_memory_limit = if gpu_memory_limit_mb.is_some() {
            gpu_memory_limit_mb
        } else if let Some(script_gpu_memory_str) = &script_args.gpu_memory {
            Some(gflow::utils::parse_memory_limit(script_gpu_memory_str)?)
        } else {
            None
        };
        builder = builder.gpu_memory_limit_mb(final_gpu_memory_limit);
    } else {
        // A single positional argument naming an existing file is treated as
        // a script; anything else is a shell command line.
        let is_script =
            args.script_or_command.len() == 1 && PathBuf::from(&args.script_or_command[0]).exists();
        if is_script {
            let script_path = make_absolute_path(PathBuf::from(&args.script_or_command[0]))?;
            let script_args = parse_script_for_args(&script_path)?;
            builder = builder.script(script_path);
            builder = builder.gpus(args.gpus.or(script_args.gpus).unwrap_or(0));
            builder = builder.shared(args.shared || script_args.shared);
            builder = builder.priority(args.priority.or(script_args.priority).unwrap_or(10));
            builder = builder.notifications(resolve_job_notifications(args, Some(&script_args))?);
            builder = builder.conda_env(args.conda_env.clone().or(script_args.conda_env));
            let final_project = args.project.clone().or(script_args.project);
            builder = builder.project(final_project);
            let final_time_limit = if time_limit.is_some() {
                time_limit
            } else if let Some(script_time_str) = &script_args.time {
                Some(gflow::utils::parse_time_limit(script_time_str)?)
            } else {
                None
            };
            builder = builder.time_limit(final_time_limit);
            let final_memory_limit = if memory_limit_mb.is_some() {
                memory_limit_mb
            } else if let Some(script_memory_str) = &script_args.memory {
                Some(gflow::utils::parse_memory_limit(script_memory_str)?)
            } else {
                None
            };
            builder = builder.memory_limit_mb(final_memory_limit);
            let final_gpu_memory_limit = if gpu_memory_limit_mb.is_some() {
                gpu_memory_limit_mb
            } else if let Some(script_gpu_memory_str) = &script_args.gpu_memory {
                Some(gflow::utils::parse_memory_limit(script_gpu_memory_str)?)
            } else {
                None
            };
            builder = builder.gpu_memory_limit_mb(final_gpu_memory_limit);
        } else {
            // Command form: shell-escape each argument and join into one line.
            let command = args
                .script_or_command
                .iter()
                .map(|arg| shell_escape::escape(arg.into()))
                .collect::<Vec<_>>()
                .join(" ");
            builder = builder.command(command);
            builder = builder.gpus(args.gpus.unwrap_or(0));
            builder = builder.shared(args.shared);
            builder = builder.priority(args.priority.unwrap_or(10));
            // No script to consult: inherit the submitter's active conda env.
            let conda_env = args.conda_env.clone().or_else(detect_current_conda_env);
            builder = builder.conda_env(conda_env);
            builder = builder.project(resolve_project(args, &None));
            builder = builder.notifications(resolve_job_notifications(args, None)?);
            builder = builder.time_limit(time_limit);
            builder = builder.memory_limit_mb(memory_limit_mb);
            builder = builder.gpu_memory_limit_mb(gpu_memory_limit_mb);
        }
    }
    builder = builder.auto_close_tmux(args.auto_close);
    let job = builder.build();
    // Invariant enforced at build time: shared jobs must cap GPU memory.
    validate_shared_requires_gpu_memory(&job)?;
    Ok(job)
}
/// Read a script file and extract the arguments declared by its `# GFLOW`
/// directive lines.
fn parse_script_for_args(script_path: &PathBuf) -> Result<cli::AddArgs> {
    let content = fs::read_to_string(script_path).context("Failed to read script file")?;
    parse_script_content_for_args(content.as_str())
}
/// Parse `# GFLOW <flags>` directive lines out of script content.
///
/// All directive lines are concatenated and fed through the regular clap
/// parser (with a "dummy" positional standing in for the script argument),
/// so directives accept exactly the same syntax as the CLI.
///
/// NOTE(review): the synthesized command line is split on whitespace, so
/// quoted directive values containing spaces will not survive — confirm
/// whether that is acceptable for directive syntax.
fn parse_script_content_for_args(content: &str) -> Result<cli::AddArgs> {
    let directives: Vec<&str> = content
        .lines()
        .filter(|line| line.starts_with("# GFLOW"))
        .map(|line| line.trim_start_matches("# GFLOW").trim())
        .collect();
    if directives.is_empty() {
        // No directives: return the all-defaults argument set explicitly so
        // adding a CLI field forces this fallback to be revisited.
        return Ok(cli::AddArgs {
            script_or_command: vec![],
            conda_env: None,
            gpus: None,
            shared: false,
            priority: None,
            depends_on: None,
            depends_on_all: None,
            depends_on_any: None,
            no_auto_cancel: false,
            array: None,
            time: None,
            memory: None,
            gpu_memory: None,
            name: None,
            auto_close: false,
            param: vec![],
            dry_run: false,
            max_concurrent: None,
            max_retries: None,
            param_file: None,
            name_template: None,
            project: None,
            notify_email: vec![],
            notify_on: vec![],
        });
    }
    let synthesized = format!("gbatch {} dummy", directives.join(" "));
    let parsed = cli::GBatch::try_parse_from(synthesized.split_whitespace())?;
    Ok(parsed.add_args)
}
/// Persist stdin content to an executable temporary script and return its path.
///
/// Fixes: the file name now embeds the process id alongside the microsecond
/// timestamp so two concurrent submissions cannot clobber each other's
/// script, and a pre-epoch system clock no longer panics (the `.unwrap()`
/// on `duration_since` is replaced with a zero fallback).
fn save_stdin_to_temp_file(content: &str) -> Result<PathBuf> {
    use std::io::Write;
    let timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_micros())
        .unwrap_or_default();
    let temp_path = std::env::temp_dir().join(format!(
        "gflow_stdin_{}_{}.sh",
        std::process::id(),
        timestamp
    ));
    let mut file =
        fs::File::create(&temp_path).context("Failed to create temporary script file")?;
    file.write_all(content.as_bytes())
        .context("Failed to write to temporary script file")?;
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        // The scheduler executes the script directly, so mark it rwxr-xr-x.
        let mut perms = file.metadata()?.permissions();
        perms.set_mode(0o755);
        fs::set_permissions(&temp_path, perms)?;
    }
    Ok(temp_path)
}
fn make_absolute_path(path: PathBuf) -> Result<PathBuf> {
if path.is_absolute() {
Ok(path)
} else {
std::env::current_dir()
.map(|pwd| pwd.join(path))
.context("Failed to get current directory")
}
}
/// Resolve a single `--depends-on` value to a job id.
///
/// `@name` references are resolved through the server for the current user;
/// any other value must be a literal numeric job id. Empty (after trimming)
/// values are rejected; `None` passes through as `Ok(None)`.
async fn resolve_dependency(depends_on: Option<String>, client: &Client) -> Result<Option<u32>> {
    let raw = match depends_on {
        Some(raw) => raw,
        None => return Ok(None),
    };
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return Err(anyhow!("Dependency value cannot be empty"));
    }
    let dependency_id = if trimmed.starts_with('@') {
        let username = gflow::platform::get_current_username();
        client
            .resolve_dependency(&username, trimmed)
            .await
            .with_context(|| format!("Failed to resolve dependency '{}'", trimmed))?
    } else {
        trimmed
            .parse::<u32>()
            .map_err(|_| anyhow!("Invalid dependency value: {trimmed}"))?
    };
    Ok(Some(dependency_id))
}
/// Resolve a comma-separated dependency list (e.g. `12,@run-name, 7`) to
/// job ids, skipping empty entries; errors if nothing usable remains.
async fn parse_dependency_list(deps_str: &str, client: &Client) -> Result<Vec<u32>> {
    let username = gflow::platform::get_current_username();
    let mut resolved = Vec::new();
    for entry in deps_str.split(',').map(str::trim) {
        if entry.is_empty() {
            // Tolerate stray commas and surrounding whitespace.
            continue;
        }
        let dep_id = if entry.starts_with('@') {
            client
                .resolve_dependency(&username, entry)
                .await
                .with_context(|| format!("Failed to resolve dependency '{}'", entry))?
        } else {
            entry
                .parse::<u32>()
                .with_context(|| format!("Invalid dependency ID: {}", entry))?
        };
        resolved.push(dep_id);
    }
    if resolved.is_empty() {
        anyhow::bail!("Dependency list cannot be empty");
    }
    Ok(resolved)
}
#[cfg(test)]
mod tests {
    use super::*;
    // With emails configured but no --notify-on, notifications should fall
    // back to the full set of terminal events, in the canonical order.
    #[test]
    fn resolve_job_notifications_defaults_terminal_events() {
        // Every AddArgs field is spelled out so this fixture breaks loudly
        // when a new CLI flag is added.
        let args = cli::AddArgs {
            script_or_command: vec!["python".to_string(), "train.py".to_string()],
            conda_env: None,
            gpus: None,
            shared: false,
            priority: None,
            depends_on: None,
            depends_on_all: None,
            depends_on_any: None,
            no_auto_cancel: false,
            array: None,
            time: None,
            memory: None,
            gpu_memory: None,
            name: None,
            auto_close: false,
            param: vec![],
            dry_run: false,
            max_concurrent: None,
            max_retries: None,
            param_file: None,
            name_template: None,
            project: None,
            notify_email: vec!["alice@example.com".to_string()],
            notify_on: vec![],
        };
        let notifications = resolve_job_notifications(&args, None).unwrap();
        assert_eq!(notifications.emails.len(), 1);
        assert_eq!(notifications.emails[0].as_str(), "alice@example.com");
        assert_eq!(
            notifications
                .events
                .iter()
                .map(|event| event.as_str())
                .collect::<Vec<_>>(),
            vec![
                "job_completed",
                "job_failed",
                "job_timeout",
                "job_cancelled"
            ]
        );
    }
    // Multiple # GFLOW directive lines accumulate --notify-email values and
    // pick up a comma-separated --notify-on list.
    #[test]
    fn parse_script_content_supports_notification_directives() {
        let args = parse_script_content_for_args(
            r#"#!/bin/bash
# GFLOW --notify-email=alice@example.com
# GFLOW --notify-email=ops@example.com
# GFLOW --notify-on=job_failed,job_timeout
python train.py
"#,
        )
        .unwrap();
        assert_eq!(
            args.notify_email,
            vec![
                "alice@example.com".to_string(),
                "ops@example.com".to_string()
            ]
        );
        assert_eq!(
            args.notify_on,
            vec!["job_failed".to_string(), "job_timeout".to_string()]
        );
    }
}