use std::{collections::HashSet, future::Future};
use crate::{
cancellation::CancellationToken,
error::Result,
session::{SessionManager, SessionState},
};
/// Outcome of a single session executed by `run_all_with_parallelism`.
#[derive(Debug)]
pub struct BatchSessionResult {
    /// Position in the scheduling order (0 = first session spawned).
    /// Results are sorted by this field, not by completion order.
    pub batch_index: usize,
    /// Identifier of the session that was executed.
    pub session_id: String,
    /// Result returned by the per-session `run_fn`; per-session failures
    /// are captured here rather than failing the whole batch.
    pub outcome: Result<()>,
}
/// Runs every remaining session known to `manager`, at most `parallelism`
/// at a time, until no further work is discovered or `cancel_token` fires.
///
/// Scheduling model:
/// - Sessions are spawned onto a [`tokio::task::JoinSet`]; `batch_index`
///   records the order in which they were *scheduled*, not completed.
/// - After every completion, `manager.run_all_remaining(&seen)` is polled
///   again, so sessions created while the batch is running are picked up
///   (presumably `run_all_remaining` excludes ids in `seen` — the loop
///   additionally guards against re-execution itself).
///
/// Cancellation:
/// - If the token is already cancelled, returns `Ok(vec![])` without ever
///   calling `run_fn`.
/// - Once cancelled mid-run, no new sessions are scheduled, but tasks
///   already spawned are awaited and their results still collected.
///
/// # Errors
/// Returns an error if `parallelism == 0` or if `run_all_remaining` fails.
/// Per-session failures are *not* propagated; they appear in
/// [`BatchSessionResult::outcome`]. Results are sorted by `batch_index`.
pub async fn run_all_with_parallelism<F, Fut>(
    manager: &SessionManager,
    parallelism: usize,
    cancel_token: CancellationToken,
    run_fn: F,
) -> Result<Vec<BatchSessionResult>>
where
    F: Fn(SessionState, CancellationToken) -> Fut + Clone + Send + 'static,
    Fut: Future<Output = Result<()>> + Send + 'static,
{
    // A parallelism of zero could never spawn anything; reject it up front
    // instead of silently doing no work.
    if parallelism == 0 {
        return Err(crate::error::CruiseError::Other(
            "run_all_with_parallelism: parallelism must be ≥ 1 (got 0)".to_string(),
        ));
    }
    // Already cancelled: report no work done, never call run_fn.
    if cancel_token.is_cancelled() {
        return Ok(Vec::new());
    }
    // Loop invariants:
    // - `seen`:   ids that have been spawned at least once (never re-run).
    // - `queued`: ids currently sitting in `candidates` (dedups re-polls).
    let mut seen: HashSet<String> = HashSet::new();
    let mut queued: HashSet<String> = HashSet::new();
    // Monotonic scheduling counter; becomes `batch_index` in the results.
    let mut next_batch_index: usize = 0;
    let mut completed: Vec<BatchSessionResult> = Vec::new();
    // Each worker task resolves to (batch_index, session_id, outcome).
    let mut join_set: tokio::task::JoinSet<(usize, String, Result<()>)> =
        tokio::task::JoinSet::new();
    let initial = manager.run_all_remaining(&seen)?;
    for s in &initial {
        queued.insert(s.id.clone());
    }
    let mut candidates: std::collections::VecDeque<SessionState> = initial.into_iter().collect();
    loop {
        // Fill the worker pool up to `parallelism`, unless cancelled.
        while join_set.len() < parallelism && !cancel_token.is_cancelled() {
            let Some(session) = candidates.pop_front() else {
                break;
            };
            let session_id = session.id.clone();
            queued.remove(&session_id);
            // A candidate may have been re-discovered after it was already
            // spawned; never execute the same session twice.
            if seen.contains(&session_id) {
                continue;
            }
            let batch_index = next_batch_index;
            next_batch_index += 1;
            seen.insert(session_id.clone());
            let run_fn_clone = run_fn.clone();
            let token_clone = cancel_token.clone();
            join_set.spawn(async move {
                let outcome = run_fn_clone(session, token_clone).await;
                (batch_index, session_id, outcome)
            });
        }
        // Nothing running and nothing was spawnable: the batch is done.
        if join_set.is_empty() {
            break;
        }
        let Some(task_result) = join_set.join_next().await else {
            break;
        };
        match task_result {
            Ok((batch_index, session_id, outcome)) => {
                completed.push(BatchSessionResult {
                    batch_index,
                    session_id,
                    outcome,
                });
            }
            Err(join_err) => {
                // A panicked worker loses its result slot; log it and keep
                // draining the remaining tasks rather than aborting the batch.
                eprintln!("batch_run: worker task panicked: {join_err}");
            }
        }
        if cancel_token.is_cancelled() {
            // Stop scheduling new work, but keep looping so tasks still in
            // `join_set` are awaited and their results collected.
            candidates.clear();
            queued.clear();
            continue;
        }
        // Re-poll for sessions created while this batch was running.
        let fresh = manager.run_all_remaining(&seen)?;
        for s in fresh {
            if !queued.contains(&s.id) {
                queued.insert(s.id.clone());
                candidates.push_back(s);
            }
        }
    }
    // Report in scheduling order, regardless of completion order.
    completed.sort_by_key(|r| r.batch_index);
    Ok(completed)
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Mutex};
use tempfile::TempDir;
use crate::{
cancellation::CancellationToken,
error::CruiseError,
session::{SessionManager, SessionPhase, SessionState, WorkspaceMode},
};
/// Creates and persists a session in the `Planned` phase so the batch
/// runner will discover it as a candidate.
fn make_planned_session(manager: &SessionManager, id: &str, base_dir: &std::path::Path) {
    let mut session = SessionState::new(
        id.to_string(),
        base_dir.to_path_buf(),
        "test.yaml".to_string(),
        format!("task for {id}"),
    );
    session.workspace_mode = WorkspaceMode::Worktree;
    session.phase = SessionPhase::Planned;
    manager
        .create(&session)
        .unwrap_or_else(|e| panic!("create session {id}: {e}"));
}
/// Builds a run-fn that immediately marks every session it receives as
/// `Completed`, ignoring the cancellation token.
fn instant_completer(
    manager: Arc<SessionManager>,
) -> impl Fn(
    SessionState,
    CancellationToken,
) -> std::pin::Pin<Box<dyn Future<Output = Result<()>> + Send>>
+ Clone
+ Send
+ 'static {
    move |session, _cancel| {
        let mgr = Arc::clone(&manager);
        Box::pin(async move {
            // Reload from storage, flip the phase, and write it back.
            let mut state = mgr
                .load(&session.id)
                .unwrap_or_else(|e| panic!("load {}: {e}", session.id));
            state.phase = SessionPhase::Completed;
            mgr.save(&state)
                .unwrap_or_else(|e| panic!("save {}: {e}", session.id));
            Ok(())
        })
    }
}
/// Builds a run-fn that records each executed session id into `log`,
/// fails with `Interrupted` if the token is cancelled, and otherwise
/// marks the session as `Completed`.
fn recording_completer(
    manager: Arc<SessionManager>,
    log: Arc<Mutex<Vec<String>>>,
) -> impl Fn(
    SessionState,
    CancellationToken,
) -> std::pin::Pin<Box<dyn Future<Output = Result<()>> + Send>>
+ Clone
+ Send
+ 'static {
    move |session, cancel| {
        let mgr = Arc::clone(&manager);
        let recorder = Arc::clone(&log);
        Box::pin(async move {
            // Record the execution before doing any work.
            recorder
                .lock()
                .unwrap_or_else(|e| panic!("{e}"))
                .push(session.id.clone());
            if cancel.is_cancelled() {
                return Err(CruiseError::Interrupted);
            }
            let mut state = mgr
                .load(&session.id)
                .unwrap_or_else(|e| panic!("load {}: {e}", session.id));
            state.phase = SessionPhase::Completed;
            mgr.save(&state)
                .unwrap_or_else(|e| panic!("save {}: {e}", session.id));
            Ok(())
        })
    }
}
/// A manager with no sessions at all must yield an empty result set.
#[tokio::test]
async fn test_empty_candidate_list_returns_empty_results() {
    let dir = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(dir.path().to_path_buf()));
    let token = CancellationToken::new();
    let completer = instant_completer(Arc::clone(&manager));
    let results = run_all_with_parallelism(&manager, 1, token, completer)
        .await
        .unwrap_or_else(|e| panic!("expected Ok, got: {e}"));
    assert!(
        results.is_empty(),
        "expected no results for empty candidate list"
    );
}
/// One planned session, parallelism 1: exactly one Ok result comes back.
#[tokio::test]
async fn test_single_session_is_executed_with_parallelism_one() {
    let dir = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(dir.path().to_path_buf()));
    make_planned_session(&manager, "20260101000001", dir.path());
    let token = CancellationToken::new();
    let completer = instant_completer(Arc::clone(&manager));
    let results = run_all_with_parallelism(&manager, 1, token, completer)
        .await
        .unwrap_or_else(|e| panic!("expected Ok, got: {e}"));
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].session_id, "20260101000001");
    assert!(results[0].outcome.is_ok(), "expected Ok outcome");
}
/// Three planned sessions with parallelism 2: every session runs once.
#[tokio::test]
async fn test_multiple_sessions_all_executed() {
    let dir = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(dir.path().to_path_buf()));
    let session_ids = ["20260101000001", "20260101000002", "20260101000003"];
    for id in session_ids {
        make_planned_session(&manager, id, dir.path());
    }
    let token = CancellationToken::new();
    let completer = instant_completer(Arc::clone(&manager));
    let results = run_all_with_parallelism(&manager, 2, token, completer)
        .await
        .unwrap_or_else(|e| panic!("expected Ok, got: {e}"));
    assert_eq!(results.len(), 3);
    // Completion order is nondeterministic with parallelism 2, so compare
    // the sorted id set instead of positions.
    let mut executed: Vec<_> = results.iter().map(|r| r.session_id.as_str()).collect();
    executed.sort_unstable();
    assert_eq!(executed, session_ids);
}
/// Results must come back ordered by `batch_index` (scheduling order).
#[tokio::test]
async fn test_results_are_sorted_by_batch_index_ascending() {
    let dir = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(dir.path().to_path_buf()));
    for id in ["20260101000001", "20260101000002"] {
        make_planned_session(&manager, id, dir.path());
    }
    let token = CancellationToken::new();
    let completer = instant_completer(Arc::clone(&manager));
    let results = run_all_with_parallelism(&manager, 2, token, completer)
        .await
        .unwrap_or_else(|e| panic!("expected Ok, got: {e}"));
    assert_eq!(results[0].batch_index, 0);
    assert_eq!(results[1].batch_index, 1);
}
/// Results must stay in scheduling order even when the sessions *complete*
/// in the opposite order.
#[tokio::test]
async fn test_results_maintain_scheduling_order_when_completions_are_out_of_order() {
    let tmp = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(tmp.path().to_path_buf()));
    // NOTE(review): assumes sessions are discovered in ascending id order,
    // so session-1 gets batch_index 0 — same assumption as before.
    make_planned_session(&manager, "20260101000001", tmp.path());
    make_planned_session(&manager, "20260101000002", tmp.path());
    // Fix: the previous barrier only synchronized the *start* of the two
    // tasks (both branches of its `if slow` were identical waits), so the
    // completion order remained nondeterministic and the test did not
    // reliably exercise out-of-order completion. Instead, the first-
    // scheduled session now waits on this signal, and the second session
    // releases it only after it has saved its completed state — making
    // "session-2 finishes first" deterministic. `Notify::notify_one`
    // stores a permit when no task is waiting yet, so the signal cannot
    // be lost.
    let release_first = Arc::new(tokio::sync::Notify::new());
    let run_fn = {
        let manager = Arc::clone(&manager);
        let release_first = Arc::clone(&release_first);
        move |session: SessionState, _cancel: CancellationToken| {
            let manager = Arc::clone(&manager);
            let release_first = Arc::clone(&release_first);
            let id = session.id.clone();
            let is_first = id == "20260101000001";
            Box::pin(async move {
                if is_first {
                    // Block until session-2 has fully completed.
                    release_first.notified().await;
                }
                let mut state = manager
                    .load(&id)
                    .unwrap_or_else(|e| panic!("load {id}: {e}"));
                state.phase = SessionPhase::Completed;
                manager
                    .save(&state)
                    .unwrap_or_else(|e| panic!("save {id}: {e}"));
                if !is_first {
                    // Session-2 is done: let session-1 proceed.
                    release_first.notify_one();
                }
                Ok(())
            }) as std::pin::Pin<Box<dyn Future<Output = Result<()>> + Send>>
        }
    };
    let cancel = CancellationToken::new();
    // parallelism = 2 so both sessions run concurrently; with 1, the
    // first session would deadlock waiting for the second to start.
    let results = run_all_with_parallelism(&manager, 2, cancel, run_fn)
        .await
        .unwrap_or_else(|e| panic!("expected Ok, got: {e}"));
    assert_eq!(results.len(), 2);
    assert_eq!(
        results[0].session_id, "20260101000001",
        "first-scheduled session must be at index 0"
    );
    assert_eq!(
        results[1].session_id, "20260101000002",
        "second-scheduled session must be at index 1"
    );
    assert_eq!(results[0].batch_index, 0);
    assert_eq!(results[1].batch_index, 1);
}
/// A session created while the batch is already running must still be
/// discovered and executed.
#[tokio::test]
async fn test_session_added_while_first_is_running_is_picked_up() {
    let tmp = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(tmp.path().to_path_buf()));
    make_planned_session(&manager, "20260101000001", tmp.path());
    // Fix: the previous version used a single Notify for both handshake
    // directions. Session-1's own `notified().await` could consume the
    // permit stored by its own `notify_one()` (if the adder task had not
    // polled yet), letting session-1 finish before session-2 existed and
    // leaving the adder awaiting forever — a timing-dependent hang. Two
    // dedicated one-shot signals make each direction race-free:
    // `notify_one` stores a permit when nobody is waiting yet.
    let started = Arc::new(tokio::sync::Notify::new());
    let added = Arc::new(tokio::sync::Notify::new());
    let run_fn = {
        let manager = Arc::clone(&manager);
        let started = Arc::clone(&started);
        let added = Arc::clone(&added);
        move |session: SessionState, _cancel: CancellationToken| {
            let manager = Arc::clone(&manager);
            let started = Arc::clone(&started);
            let added = Arc::clone(&added);
            let id = session.id.clone();
            Box::pin(async move {
                if id == "20260101000001" {
                    // Tell the adder we are running, then wait until the
                    // second session has actually been created.
                    started.notify_one();
                    added.notified().await;
                }
                let mut state = manager
                    .load(&id)
                    .unwrap_or_else(|e| panic!("load {id}: {e}"));
                state.phase = SessionPhase::Completed;
                manager
                    .save(&state)
                    .unwrap_or_else(|e| panic!("save {id}: {e}"));
                Ok(())
            }) as std::pin::Pin<Box<dyn Future<Output = Result<()>> + Send>>
        }
    };
    let adder = {
        let manager = Arc::clone(&manager);
        let started = Arc::clone(&started);
        let added = Arc::clone(&added);
        let base_dir = tmp.path().to_path_buf();
        tokio::spawn(async move {
            started.notified().await;
            make_planned_session(&manager, "20260101000002", &base_dir);
            added.notify_one();
        })
    };
    let cancel = CancellationToken::new();
    let results = run_all_with_parallelism(&manager, 1, cancel, run_fn)
        .await
        .unwrap_or_else(|e| panic!("expected Ok, got: {e}"));
    adder.await.unwrap_or_else(|e| panic!("{e}"));
    assert_eq!(
        results.len(),
        2,
        "session added mid-run must be picked up; got IDs: {:?}",
        results.iter().map(|r| &r.session_id).collect::<Vec<_>>()
    );
}
/// Even with spare parallelism and repeated discovery polls, a session id
/// must be executed at most once.
#[tokio::test]
async fn test_session_is_not_executed_twice() {
    let dir = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(dir.path().to_path_buf()));
    make_planned_session(&manager, "20260101000001", dir.path());
    let execution_count = Arc::new(Mutex::new(0usize));
    let run_fn = {
        let manager = Arc::clone(&manager);
        let count = Arc::clone(&execution_count);
        move |session: SessionState, _cancel: CancellationToken| {
            let manager = Arc::clone(&manager);
            let count = Arc::clone(&count);
            let id = session.id.clone();
            Box::pin(async move {
                *count.lock().unwrap_or_else(|e| panic!("{e}")) += 1;
                let mut state = manager
                    .load(&id)
                    .unwrap_or_else(|e| panic!("load {id}: {e}"));
                state.phase = SessionPhase::Completed;
                manager
                    .save(&state)
                    .unwrap_or_else(|e| panic!("save {id}: {e}"));
                Ok(())
            }) as std::pin::Pin<Box<dyn Future<Output = Result<()>> + Send>>
        }
    };
    let token = CancellationToken::new();
    run_all_with_parallelism(&manager, 2, token, run_fn)
        .await
        .unwrap_or_else(|e| panic!("expected Ok, got: {e}"));
    let total = *execution_count.lock().unwrap_or_else(|e| panic!("{e}"));
    assert_eq!(total, 1, "session must not be executed twice");
}
/// A token cancelled before the run starts must prevent any execution.
#[tokio::test]
async fn test_cancellation_before_start_returns_empty_results() {
    let dir = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(dir.path().to_path_buf()));
    make_planned_session(&manager, "20260101000001", dir.path());
    let token = CancellationToken::new();
    token.cancel();
    let execution_log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
    let completer = recording_completer(Arc::clone(&manager), Arc::clone(&execution_log));
    let results = run_all_with_parallelism(&manager, 1, token, completer)
        .await
        .unwrap_or_else(|e| panic!("expected Ok, got: {e}"));
    assert!(
        results.is_empty(),
        "pre-cancelled run must not execute any sessions"
    );
    let recorded = execution_log.lock().unwrap_or_else(|e| panic!("{e}"));
    assert!(
        recorded.is_empty(),
        "run_fn must not be called when already cancelled"
    );
}
/// Cancelling mid-run (from inside the first session) must stop any
/// not-yet-scheduled session from being spawned.
#[tokio::test]
async fn test_cancellation_stops_scheduling_new_sessions() {
    let dir = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(dir.path().to_path_buf()));
    make_planned_session(&manager, "20260101000001", dir.path());
    make_planned_session(&manager, "20260101000002", dir.path());
    let execution_log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
    let token = CancellationToken::new();
    let run_fn = {
        let manager = Arc::clone(&manager);
        let log = Arc::clone(&execution_log);
        let token = token.clone();
        move |session: SessionState, _cancel_arg: CancellationToken| {
            let manager = Arc::clone(&manager);
            let log = Arc::clone(&log);
            let token = token.clone();
            let id = session.id.clone();
            Box::pin(async move {
                log.lock()
                    .unwrap_or_else(|e| panic!("{e}"))
                    .push(id.clone());
                // The first session cancels the whole batch mid-flight.
                if id == "20260101000001" {
                    token.cancel();
                }
                let mut state = manager
                    .load(&id)
                    .unwrap_or_else(|e| panic!("load {id}: {e}"));
                state.phase = SessionPhase::Completed;
                manager
                    .save(&state)
                    .unwrap_or_else(|e| panic!("save {id}: {e}"));
                Ok(())
            }) as std::pin::Pin<Box<dyn Future<Output = Result<()>> + Send>>
        }
    };
    let results = run_all_with_parallelism(&manager, 1, token, run_fn)
        .await
        .unwrap_or_else(|e| panic!("expected Ok, got: {e}"));
    let executed = execution_log
        .lock()
        .unwrap_or_else(|e| panic!("{e}"))
        .clone();
    assert!(
        executed.contains(&"20260101000001".to_string()),
        "session-1 must have run"
    );
    assert!(
        !executed.contains(&"20260101000002".to_string()),
        "session-2 must NOT be scheduled after cancellation"
    );
    assert_eq!(results.len(), 1);
}
/// `parallelism == 0` is rejected with an error, not silently ignored.
#[tokio::test]
async fn test_zero_parallelism_returns_error() {
    let dir = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(dir.path().to_path_buf()));
    make_planned_session(&manager, "20260101000001", dir.path());
    let token = CancellationToken::new();
    let completer = instant_completer(Arc::clone(&manager));
    let outcome = run_all_with_parallelism(&manager, 0, token, completer).await;
    assert!(outcome.is_err(), "expected error for parallelism=0, got Ok");
}
/// A failing session must surface its error inside the per-session
/// outcome, not as a batch-level error.
#[tokio::test]
async fn test_failed_session_outcome_is_captured_not_propagated() {
    let dir = TempDir::new().unwrap_or_else(|e| panic!("{e:?}"));
    let manager = Arc::new(SessionManager::new(dir.path().to_path_buf()));
    make_planned_session(&manager, "20260101000001", dir.path());
    let token = CancellationToken::new();
    // Every invocation fails; the batch itself must still return Ok.
    let failing = |_session: SessionState, _cancel: CancellationToken| {
        Box::pin(async { Err(CruiseError::Other("step failed".to_string())) })
            as std::pin::Pin<Box<dyn Future<Output = Result<()>> + Send>>
    };
    let results = run_all_with_parallelism(&manager, 1, token, failing)
        .await
        .unwrap_or_else(|e| panic!("batch error must not propagate, got: {e}"));
    assert_eq!(results.len(), 1);
    assert!(
        results[0].outcome.is_err(),
        "failed session outcome must be captured inside BatchSessionResult"
    );
}
}