use clap::Parser;
use clap_complete::Shell;
use gflow::build_info::version;
// Top-level command-line interface of the `gbatch` binary.
//
// Plain `//` comments are used deliberately here: on clap-derived items a `///`
// doc comment is turned into about/help text and would change `--help` output.
#[derive(Debug, Parser)]
#[command(
    name = "gbatch",
    author,
    version = version(),
    about = "Submits jobs to the gflow scheduler. Inspired by sbatch.",
    styles = gflow::utils::STYLES
)]
pub struct GBatch {
    // Optional subcommand; `None` when gbatch is invoked without one.
    #[command(subcommand)]
    pub commands: Option<Commands>,

    // Job-submission options and the script/command, parsed at the top level.
    #[command(flatten)]
    pub add_args: AddArgs,

    // Hidden `--config <PATH>`; `global = true` makes clap also accept it
    // after a subcommand name.
    #[arg(long, global = true, help = "Path to the config file", hide = true)]
    pub config: Option<std::path::PathBuf>,
}
// Subcommands of `gbatch`. clap's derive names each variant in kebab-case,
// giving `gbatch new` and `gbatch completion`.
#[derive(Parser, Debug)]
pub enum Commands {
    // `gbatch new <NAME>`; arguments are defined by `NewArgs`.
    New(NewArgs),

    // `gbatch completion <SHELL>`; the value must be one of the variants of
    // clap_complete's `Shell` enum.
    Completion {
        #[arg(value_enum)]
        shell: Shell,
    },
}
// Arguments of `gbatch new`: one required positional value (a non-`Option`
// field without `short`/`long` is a required positional in clap's derive).
#[derive(Parser, Debug)]
pub struct NewArgs {
    pub name: String,
}
// Job-submission options. They are flattened into `GBatch`, so they are given
// directly as `gbatch [OPTIONS] <SCRIPT_OR_COMMAND>...`. Field order is also the
// order in which clap lists the options in `--help`.
//
// Plain `//` comments are used deliberately: on clap-derived fields a `///` doc
// comment becomes help text, so adding one would change the CLI's output.
// Most values are kept as raw strings; any further validation of them happens
// outside this struct.
#[derive(Debug, Parser, Clone)]
pub struct AddArgs {
    // Script path or command line to submit. With `trailing_var_arg`, once the
    // first value is seen every remaining token, including flag-like ones, is
    // captured here. gbatch's own options must therefore come before the script.
    #[arg(trailing_var_arg = true, allow_hyphen_values = true, value_hint = clap::ValueHint::CommandWithArguments)]
    pub script_or_command: Vec<String>,

    // `-c/--conda-env <CONDA_ENV>`.
    #[arg(short, long, value_hint = clap::ValueHint::Other)]
    pub conda_env: Option<String>,

    // `-g/--gpus <NUMS>`, also accepted as `--gres`.
    // `value_name` only sets the help placeholder. Do not use `name`/`id` for
    // that: it renames the argument itself, and whether the derived `-g`/`--gpus`
    // survive then depends on the order of the attributes in this list.
    #[arg(short, long, visible_alias = "gres", value_name = "NUMS")]
    pub gpus: Option<u32>,

    // `--shared` boolean flag (false when absent).
    #[arg(long)]
    pub shared: bool,

    // `-p/--priority`, also accepted as `--nice`. Parsed as `u8`, so clap
    // rejects values outside 0..=255.
    #[arg(short = 'p', long, visible_alias = "nice")]
    pub priority: Option<u8>,

    // Dependency specification, in one of three mutually exclusive forms.
    // Exclusivity is enforced by the `conflicts_with*` attributes; clap checks
    // conflicts in both directions. The spec string is not interpreted here.
    #[arg(short = 'd', long, visible_alias = "dependency", value_hint = clap::ValueHint::Other)]
    pub depends_on: Option<String>,
    #[arg(long, value_hint = clap::ValueHint::Other, conflicts_with = "depends_on")]
    pub depends_on_all: Option<String>,
    #[arg(long, value_hint = clap::ValueHint::Other, conflicts_with_all = ["depends_on", "depends_on_all"])]
    pub depends_on_any: Option<String>,

    // `--no-auto-cancel` boolean flag.
    #[arg(long)]
    pub no_auto_cancel: bool,

    // `--array <ARRAY>`, kept as a raw string.
    #[arg(long, value_hint = clap::ValueHint::Other)]
    pub array: Option<String>,

    // `-t/--time`, also `--time-limit` / `--timelimit` (e.g. "2:00:00").
    #[arg(
        short = 't',
        long,
        visible_aliases = ["time-limit", "timelimit"],
        value_hint = clap::ValueHint::Other
    )]
    pub time: Option<String>,

    // `-m/--memory`, also `--max-mem` / `--max-memory` (e.g. "8G").
    #[arg(
        short = 'm',
        long,
        visible_aliases = ["max-mem", "max-memory"],
        value_hint = clap::ValueHint::Other
    )]
    pub memory: Option<String>,

    // `--gpu-memory`, also `--max-gpu-mem` / `--max-gpu-memory` (e.g. "24G").
    #[arg(
        long = "gpu-memory",
        visible_aliases = ["max-gpu-mem", "max-gpu-memory"],
        value_hint = clap::ValueHint::Other
    )]
    pub gpu_memory: Option<String>,

    // `-n/--name`, also the visible `--job-name` and the hidden short alias `-J`,
    // matching sbatch's job-name flags.
    #[arg(
        short = 'n',
        short_alias = 'J',
        long,
        visible_alias = "job-name",
        value_hint = clap::ValueHint::Other
    )]
    pub name: Option<String>,

    // `--auto-close` boolean flag.
    #[arg(long)]
    pub auto_close: bool,

    // `--param <PARAM>`; repeatable, and each occurrence appends one entry.
    #[arg(long, value_hint = clap::ValueHint::Other)]
    pub param: Vec<String>,

    // `--dry-run` boolean flag.
    #[arg(long)]
    pub dry_run: bool,

    // `--max-concurrent <N>`.
    #[arg(long, value_hint = clap::ValueHint::Other)]
    pub max_concurrent: Option<usize>,

    // `--param-file <PATH>`.
    #[arg(long, value_hint = clap::ValueHint::FilePath)]
    pub param_file: Option<std::path::PathBuf>,

    // `--name-template <TEMPLATE>`.
    #[arg(long, value_hint = clap::ValueHint::Other)]
    pub name_template: Option<String>,

    // `-P/--project` (uppercase; `-p` is priority).
    #[arg(short = 'P', long, value_hint = clap::ValueHint::Other)]
    pub project: Option<String>,
}
// Parser-level tests for the gbatch CLI definitions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_slurm_compatible_aliases() {
        let args = GBatch::try_parse_from([
            "gbatch",
            "--time-limit",
            "2:00:00",
            "--nice",
            "10",
            "--job-name",
            "train",
            "--gres",
            "2",
            "--dependency",
            "@",
            "script.sh",
        ])
        .expect("should parse SLURM-compatible aliases");
        assert_eq!(args.add_args.time.as_deref(), Some("2:00:00"));
        assert_eq!(args.add_args.priority, Some(10));
        assert_eq!(args.add_args.name.as_deref(), Some("train"));
        assert_eq!(args.add_args.gpus, Some(2));
        assert!(!args.add_args.shared);
        assert_eq!(args.add_args.depends_on.as_deref(), Some("@"));
        assert_eq!(
            args.add_args.script_or_command,
            vec!["script.sh".to_string()]
        );
    }

    #[test]
    fn parses_shared_flag() {
        let args = GBatch::try_parse_from(["gbatch", "--shared", "script.sh"])
            .expect("should parse --shared flag");
        assert!(args.add_args.shared);
    }

    #[test]
    fn parses_max_mem_alias() {
        let args = GBatch::try_parse_from(["gbatch", "--max-mem", "8G", "script.sh"])
            .expect("should parse --max-mem alias");
        assert_eq!(args.add_args.memory.as_deref(), Some("8G"));
    }

    #[test]
    fn parses_max_gpu_mem_alias() {
        let args = GBatch::try_parse_from(["gbatch", "--max-gpu-mem", "24G", "script.sh"])
            .expect("should parse --max-gpu-mem alias");
        assert_eq!(args.add_args.gpu_memory.as_deref(), Some("24G"));
    }

    // The primary GPU flags must stay `--gpus` / `-g`, whatever value
    // placeholder the argument uses in help output.
    #[test]
    fn parses_primary_gpu_flags() {
        let long = GBatch::try_parse_from(["gbatch", "--gpus", "2", "script.sh"])
            .expect("should parse --gpus");
        assert_eq!(long.add_args.gpus, Some(2));

        let short = GBatch::try_parse_from(["gbatch", "-g", "1", "script.sh"])
            .expect("should parse -g");
        assert_eq!(short.add_args.gpus, Some(1));
    }

    #[test]
    fn parses_hidden_short_job_name_alias() {
        let args = GBatch::try_parse_from(["gbatch", "-J", "train", "script.sh"])
            .expect("should parse -J alias");
        assert_eq!(args.add_args.name.as_deref(), Some("train"));
    }

    // Only one dependency form may be given at a time.
    #[test]
    fn rejects_multiple_dependency_modes() {
        let err = GBatch::try_parse_from([
            "gbatch",
            "--depends-on",
            "1",
            "--depends-on-all",
            "2",
            "script.sh",
        ])
        .unwrap_err();
        assert_eq!(err.kind(), clap::error::ErrorKind::ArgumentConflict);

        let err = GBatch::try_parse_from([
            "gbatch",
            "--depends-on-all",
            "1",
            "--depends-on-any",
            "2",
            "script.sh",
        ])
        .unwrap_err();
        assert_eq!(err.kind(), clap::error::ErrorKind::ArgumentConflict);
    }

    // Everything after the script belongs to the script, even flag-like tokens.
    #[test]
    fn flags_after_script_are_passed_through() {
        let args = GBatch::try_parse_from(["gbatch", "script.sh", "--gpus", "2"])
            .expect("trailing tokens should be captured");
        assert_eq!(args.add_args.gpus, None);
        assert_eq!(
            args.add_args.script_or_command,
            vec![
                "script.sh".to_string(),
                "--gpus".to_string(),
                "2".to_string()
            ]
        );
    }
}