use anyhow::{anyhow, Context, Result};
use gflow::client::Client;
use gflow::core::job::{GpuSharingMode, Job, JobState, JobStateReason};
use gflow::print_field;
use std::collections::{HashMap, HashSet, VecDeque};
use std::path::PathBuf;
#[allow(clippy::too_many_arguments)]
pub async fn handle_redo(
config_path: &Option<PathBuf>,
job_id_str: &str,
gpus_override: Option<u32>,
priority_override: Option<u8>,
depends_on_override: Option<String>,
time_override: Option<String>,
memory_override: Option<String>,
gpu_memory_override: Option<String>,
conda_env_override: Option<String>,
clear_deps: bool,
cascade: bool,
) -> Result<()> {
let client = gflow::create_client(config_path)?;
let job_id = crate::multicall::gjob::utils::resolve_job_id(&client, job_id_str).await?;
let original_job = match client.get_job(job_id).await? {
Some(job) => job,
None => {
return Err(anyhow!("Job {} not found.", job_id));
}
};
match original_job.state {
JobState::Queued | JobState::Hold => {
return Err(anyhow!(
"Job {} is still in {} state. Use `gjob update` to modify its parameters.",
original_job.id,
original_job.state
));
}
JobState::Running => {
return Err(anyhow!(
"Job {} is still running. Wait for it to finish or cancel it first.",
original_job.id
));
}
_ => {}
}
println!("Resubmitting job {} with parameters:", original_job.id);
let mut builder = Job::builder();
if let Some(ref script) = original_job.script {
builder = builder.script((**script).clone());
print_field!("Script", "{}", script.display());
}
if let Some(ref command) = original_job.command {
builder = builder.command(command.clone());
print_field!("Command", "{}", command);
}
let gpus = gpus_override.unwrap_or(original_job.gpus);
builder = builder.gpus(gpus);
print_field!("GPUs", "{}", gpus);
builder = builder.gpu_sharing_mode(original_job.gpu_sharing_mode);
if original_job.gpu_sharing_mode == GpuSharingMode::Shared {
print_field!("GPUSharing", "shared");
}
let priority = priority_override.unwrap_or(original_job.priority);
builder = builder.priority(priority);
print_field!("Priority", "{}", priority);
let conda_env = if let Some(ref override_env) = conda_env_override {
Some(override_env.clone())
} else {
original_job.conda_env.as_ref().map(|s| s.to_string())
};
builder = builder.conda_env(conda_env.clone());
if let Some(ref env) = conda_env {
print_field!("CondaEnv", "{}", env);
}
let time_limit = if let Some(ref time_str) = time_override {
Some(gflow::utils::parse_time_limit(time_str)?)
} else {
original_job.time_limit
};
builder = builder.time_limit(time_limit);
if let Some(limit) = time_limit {
print_field!("TimeLimit", "{}", gflow::utils::format_duration(limit));
}
let memory_limit_mb = if let Some(ref memory_str) = memory_override {
Some(gflow::utils::parse_memory_limit(memory_str)?)
} else {
original_job.memory_limit_mb
};
builder = builder.memory_limit_mb(memory_limit_mb);
if let Some(memory_mb) = memory_limit_mb {
print_field!("MemoryLimit", "{}", gflow::utils::format_memory(memory_mb));
}
let gpu_memory_limit_mb = if let Some(ref memory_str) = gpu_memory_override {
Some(gflow::utils::parse_memory_limit(memory_str)?)
} else {
original_job.gpu_memory_limit_mb
};
builder = builder.gpu_memory_limit_mb(gpu_memory_limit_mb);
if let Some(memory_mb) = gpu_memory_limit_mb {
print_field!(
"GPUMemoryLimit",
"{}",
gflow::utils::format_memory(memory_mb)
);
}
let depends_on = if clear_deps {
println!(" Dependencies=(cleared)");
None
} else if let Some(ref dep_str) = depends_on_override {
let resolved_dep =
crate::multicall::gjob::utils::resolve_dependency(&client, dep_str).await?;
print_field!("DependsOn", "{}", resolved_dep);
Some(resolved_dep)
} else {
if let Some(dep) = original_job.depends_on {
print_field!("DependsOn", "{}", dep);
}
original_job.depends_on
};
builder = builder.depends_on(depends_on);
builder = builder.run_dir(original_job.run_dir.clone());
builder = builder.task_id(original_job.task_id);
builder = builder.auto_close_tmux(original_job.auto_close_tmux);
builder = builder.parameters_compact(original_job.parameters.clone());
if !original_job.parameters.is_empty() {
println!(" Parameters:");
for (key, value) in &original_job.parameters {
print_field!(key, "{}", value);
}
}
builder = builder.redone_from(Some(original_job.id));
let username = gflow::platform::get_current_username();
builder = builder.submitted_by(username);
let new_job = builder.build();
let response = client
.add_job(new_job)
.await
.context("Failed to submit job")?;
println!(
"\nSubmitted batch job {} ({})",
response.id, response.run_name
);
if cascade {
let cascade_jobs = find_cascade_jobs(&client, original_job.id).await?;
if !cascade_jobs.is_empty() {
println!("\nCascading to {} dependent job(s)...", cascade_jobs.len());
let _mapping =
redo_with_cascade(&client, &original_job, response.id, &cascade_jobs).await?;
println!("\nCascade complete.");
} else {
println!("\nNo dependent jobs to cascade.");
}
}
Ok(())
}
async fn find_cascade_jobs(client: &Client, parent_job_id: u32) -> Result<Vec<Job>> {
let mut queue = VecDeque::new();
let mut visited = HashSet::new();
let mut cascade_jobs = Vec::new();
queue.push_back(parent_job_id);
visited.insert(parent_job_id);
let all_jobs = client.list_jobs().await?;
while let Some(current_id) = queue.pop_front() {
for job in &all_jobs {
if visited.contains(&job.id) {
continue;
}
if job.state == JobState::Cancelled {
if let Some(JobStateReason::DependencyFailed(failed_dep_id)) = job.reason.as_deref()
{
if *failed_dep_id == current_id {
visited.insert(job.id);
queue.push_back(job.id);
cascade_jobs.push(job.clone());
}
}
}
}
}
cascade_jobs.sort_by_key(|job| job.id);
Ok(cascade_jobs)
}
async fn redo_with_cascade(
client: &Client,
original_parent: &Job,
new_parent_id: u32,
cascade_jobs: &[Job],
) -> Result<HashMap<u32, u32>> {
let mut id_mapping = HashMap::new();
id_mapping.insert(original_parent.id, new_parent_id);
for cascade_job in cascade_jobs {
let mut builder = Job::builder();
if let Some(ref script) = cascade_job.script {
builder = builder.script((**script).clone());
}
if let Some(ref command) = cascade_job.command {
builder = builder.command(command.clone());
}
builder = builder.gpus(cascade_job.gpus);
builder = builder.gpu_sharing_mode(cascade_job.gpu_sharing_mode);
builder = builder.gpu_memory_limit_mb(cascade_job.gpu_memory_limit_mb);
builder = builder.priority(cascade_job.priority);
builder = builder.conda_env(cascade_job.conda_env.as_ref().map(|s| s.to_string()));
builder = builder.time_limit(cascade_job.time_limit);
builder = builder.memory_limit_mb(cascade_job.memory_limit_mb);
let updated_depends_on_ids: Vec<u32> = cascade_job
.depends_on_ids
.iter()
.map(|old_id| *id_mapping.get(old_id).unwrap_or(old_id))
.collect();
builder = builder.depends_on_ids(updated_depends_on_ids);
builder = builder.dependency_mode(cascade_job.dependency_mode);
builder = builder
.auto_cancel_on_dependency_failure(cascade_job.auto_cancel_on_dependency_failure);
builder = builder.run_dir(cascade_job.run_dir.clone());
builder = builder.task_id(cascade_job.task_id);
builder = builder.auto_close_tmux(cascade_job.auto_close_tmux);
builder = builder.parameters_compact(cascade_job.parameters.clone());
builder = builder.group_id_uuid(cascade_job.group_id);
builder = builder.max_concurrent(cascade_job.max_concurrent);
builder = builder.redone_from(Some(cascade_job.id));
let username = gflow::platform::get_current_username();
builder = builder.submitted_by(username);
let new_job = builder.build();
let response = client.add_job(new_job).await.context(format!(
"Failed to submit cascade job for {}",
cascade_job.id
))?;
id_mapping.insert(cascade_job.id, response.id);
println!(
" Job {} → Job {} ({})",
cascade_job.id, response.id, response.run_name
);
}
Ok(id_mapping)
}
#[cfg(test)]
mod tests {
use super::*;
use gflow::core::job::{JobBuilder, JobState, JobStateReason};
#[test]
fn test_cascade_job_identification() {
let parent_job = JobBuilder::new()
.submitted_by("test".to_string())
.run_dir("/tmp")
.build();
let cancelled_job = JobBuilder::new()
.submitted_by("test".to_string())
.run_dir("/tmp")
.build();
assert_eq!(parent_job.state, JobState::Queued);
assert_eq!(cancelled_job.state, JobState::Queued);
}
#[test]
fn test_dependency_update_logic() {
let mut id_mapping = HashMap::new();
id_mapping.insert(100, 200);
id_mapping.insert(101, 201);
let old_deps = [100, 101];
let new_deps: Vec<u32> = old_deps
.iter()
.map(|old_id| *id_mapping.get(old_id).unwrap_or(old_id))
.collect();
assert_eq!(new_deps, vec![200, 201]);
}
#[test]
fn test_dependency_update_with_unmapped_ids() {
let mut id_mapping = HashMap::new();
id_mapping.insert(100, 200);
let old_deps = [100, 102]; let new_deps: Vec<u32> = old_deps
.iter()
.map(|old_id| *id_mapping.get(old_id).unwrap_or(old_id))
.collect();
assert_eq!(new_deps, vec![200, 102]);
}
#[test]
fn test_job_state_reason_matching() {
let reason = JobStateReason::DependencyFailed(100);
if let JobStateReason::DependencyFailed(failed_id) = reason {
assert_eq!(failed_id, 100);
} else {
panic!("Expected DependencyFailed reason");
}
}
}