1use super::checkpoint::WorkflowCheckpoint;
9use super::executor::{AgentExecutor, AgentStepSpec, StepOutcome};
10use crate::agent::AgentEvent;
11use crate::ordered_parallel::run_ordered_parallel_with_limit;
12use crate::store::SessionStore;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio::sync::broadcast;
16
/// Current wall-clock time expressed as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time earlier than the epoch.
fn now_epoch_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_millis() as u64,
        // A pre-epoch clock has no meaningful timestamp; fall back to zero.
        Err(_) => 0,
    }
}
23
/// One stage of a per-item pipeline, shared behind an `Arc`.
///
/// As called by `execute_pipeline`, a stage receives the outcome of the
/// previous stage for the same item (`None` before any stage has run) plus a
/// reference to the item, and returns the spec to execute next. Returning
/// `None` ends that item's chain.
pub type PipelineStage<I> =
    Arc<dyn Fn(Option<&StepOutcome>, &I) -> Option<AgentStepSpec> + Send + Sync>;
32
33pub async fn execute_pipeline<I>(
47 executor: Arc<dyn AgentExecutor>,
48 items: Vec<I>,
49 stages: Vec<PipelineStage<I>>,
50 event_tx: Option<broadcast::Sender<AgentEvent>>,
51) -> Vec<Option<StepOutcome>>
52where
53 I: Send + 'static,
54{
55 let limit = executor.concurrency_hint();
56 let stages = Arc::new(stages);
57
58 let results = run_ordered_parallel_with_limit(items, limit, move |_idx, item| {
59 let executor = Arc::clone(&executor);
60 let stages = Arc::clone(&stages);
61 let event_tx = event_tx.clone();
62 async move {
63 let mut prev: Option<StepOutcome> = None;
64 for stage in stages.iter() {
65 let Some(spec) = stage(prev.as_ref(), &item) else {
66 break;
67 };
68 let outcome = executor.execute_step(spec, event_tx.clone()).await;
69 let succeeded = outcome.success;
70 prev = Some(outcome);
71 if !succeeded {
72 break;
73 }
74 }
75 prev
76 }
77 })
78 .await;
79
80 results
83 .into_iter()
84 .map(|result| result.output.unwrap_or(None))
85 .collect()
86}
87
88pub async fn execute_steps_parallel_resumable(
103 executor: Arc<dyn AgentExecutor>,
104 specs: Vec<AgentStepSpec>,
105 workflow_id: &str,
106 store: Arc<dyn SessionStore>,
107 event_tx: Option<broadcast::Sender<AgentEvent>>,
108) -> Vec<StepOutcome> {
109 let done: HashMap<String, StepOutcome> = match store.load_workflow_checkpoint(workflow_id).await
115 {
116 Ok(Some(cp)) => cp.completed(),
117 Ok(None) => HashMap::new(),
118 Err(e) => {
119 tracing::warn!(
120 workflow_id = %workflow_id,
121 error = %e,
122 "workflow checkpoint unreadable; re-running the workflow from scratch"
123 );
124 HashMap::new()
125 }
126 };
127
128 let pending: Vec<AgentStepSpec> = specs
129 .iter()
130 .filter(|s| !done.contains_key(&s.task_id))
131 .cloned()
132 .collect();
133 let labels: Vec<(String, String)> = pending
134 .iter()
135 .map(|s| (s.task_id.clone(), s.agent.clone()))
136 .collect();
137
138 let acc = Arc::new(tokio::sync::Mutex::new(done.clone()));
140 let limit = executor.concurrency_hint();
141 let workflow_id_owned = workflow_id.to_string();
142 let store_steps = Arc::clone(&store);
143
144 let results = run_ordered_parallel_with_limit(pending, limit, move |_idx, spec| {
145 let executor = Arc::clone(&executor);
146 let event_tx = event_tx.clone();
147 let acc = Arc::clone(&acc);
148 let store = Arc::clone(&store_steps);
149 let workflow_id = workflow_id_owned.clone();
150 async move {
151 let outcome = executor.execute_step(spec, event_tx).await;
152 if outcome.success {
156 let mut guard = acc.lock().await;
157 guard.insert(outcome.task_id.clone(), outcome.clone());
158 let checkpoint =
159 WorkflowCheckpoint::from_completed(&workflow_id, &guard, now_epoch_ms());
160 if let Err(e) = store
161 .save_workflow_checkpoint(&workflow_id, &checkpoint)
162 .await
163 {
164 tracing::warn!(
166 workflow_id = %workflow_id,
167 error = %e,
168 "workflow checkpoint save failed; run continues"
169 );
170 }
171 }
172 outcome
173 }
174 })
175 .await;
176
177 let mut fresh: HashMap<String, StepOutcome> = HashMap::new();
178 for result in results {
179 match result.output {
180 Ok(outcome) => {
181 fresh.insert(outcome.task_id.clone(), outcome);
182 }
183 Err(error) => {
184 if let Some((task_id, agent)) = labels.get(result.index).cloned() {
185 fresh.insert(
186 task_id.clone(),
187 StepOutcome::failed(task_id, agent, error.to_string()),
188 );
189 }
190 }
191 }
192 }
193
194 let merged: Vec<StepOutcome> = specs
196 .iter()
197 .map(|s| {
198 done.get(&s.task_id)
199 .cloned()
200 .or_else(|| fresh.remove(&s.task_id))
201 .unwrap_or_else(|| {
202 StepOutcome::failed(
203 s.task_id.clone(),
204 s.agent.clone(),
205 "step produced no outcome",
206 )
207 })
208 })
209 .collect();
210
211 if merged.iter().all(|o| o.success) {
212 let _ = store.delete_workflow_checkpoint(workflow_id).await;
213 }
214 merged
215}
216
217#[cfg(test)]
218mod tests {
219 use super::*;
220 use async_trait::async_trait;
221 use std::sync::atomic::{AtomicUsize, Ordering};
222 use std::time::Duration;
223
224 struct EchoExecutor {
227 active: Arc<AtomicUsize>,
228 max_active: Arc<AtomicUsize>,
229 }
230
231 impl EchoExecutor {
232 fn new() -> Self {
233 Self {
234 active: Arc::new(AtomicUsize::new(0)),
235 max_active: Arc::new(AtomicUsize::new(0)),
236 }
237 }
238 }
239
240 #[async_trait]
241 impl AgentExecutor for EchoExecutor {
242 async fn execute_step(
243 &self,
244 spec: AgentStepSpec,
245 _event_tx: Option<broadcast::Sender<AgentEvent>>,
246 ) -> StepOutcome {
247 let now = self.active.fetch_add(1, Ordering::SeqCst) + 1;
248 self.max_active.fetch_max(now, Ordering::SeqCst);
249 tokio::time::sleep(Duration::from_millis(15)).await;
250 self.active.fetch_sub(1, Ordering::SeqCst);
251 assert!(spec.agent != "boom", "boom");
252 StepOutcome {
253 task_id: spec.task_id.clone(),
254 session_id: format!("task-run-{}", spec.task_id),
255 agent: spec.agent.clone(),
256 output: spec.prompt.clone(),
257 success: spec.agent != "fail",
258 structured: None,
259 }
260 }
261 fn concurrency_hint(&self) -> usize {
262 4
263 }
264 }
265
/// Wraps a closure in an `Arc` so it can be used as a [`PipelineStage`].
///
/// Lets tests list stages as plain closures without spelling out the
/// trait-object type at each call site.
fn stage<I, F>(f: F) -> PipelineStage<I>
where
    F: Fn(Option<&StepOutcome>, &I) -> Option<AgentStepSpec> + Send + Sync + 'static,
{
    Arc::new(f)
}
272
273 #[tokio::test]
274 async fn each_item_chains_through_stages_and_later_stages_see_prior_output() {
275 let exec: Arc<dyn AgentExecutor> = Arc::new(EchoExecutor::new());
276 let stages = vec![
279 stage(|_prev: Option<&StepOutcome>, item: &&str| {
280 Some(AgentStepSpec::new("s1", "explore", "d", *item))
281 }),
282 stage(|prev: Option<&StepOutcome>, _item: &&str| {
283 let prior = prev.map(|o| o.output.clone()).unwrap_or_default();
284 Some(AgentStepSpec::new(
285 "s2",
286 "review",
287 "d",
288 format!("review of: {prior}"),
289 ))
290 }),
291 ];
292 let out = execute_pipeline(exec, vec!["alpha", "beta"], stages, None).await;
293
294 assert_eq!(out.len(), 2, "one result per item, order preserved");
295 assert_eq!(out[0].as_ref().unwrap().output, "review of: alpha");
298 assert_eq!(out[1].as_ref().unwrap().output, "review of: beta");
299 assert!(out.iter().all(|o| o.as_ref().unwrap().success));
300 }
301
302 #[tokio::test]
303 async fn chain_stops_on_failure_and_on_none_stage() {
304 let exec: Arc<dyn AgentExecutor> = Arc::new(EchoExecutor::new());
305 let stages = vec![
308 stage(|_p: Option<&StepOutcome>, item: &&str| {
309 let agent = if *item == "x" { "fail" } else { "explore" };
310 Some(AgentStepSpec::new("s1", agent, "d", *item))
311 }),
312 stage(|_p: Option<&StepOutcome>, item: &&str| {
313 if *item == "y" {
314 None } else {
316 Some(AgentStepSpec::new("s2", "review", "d", "second"))
317 }
318 }),
319 ];
320 let out = execute_pipeline(exec, vec!["x", "y"], stages, None).await;
321
322 let first = out[0].as_ref().unwrap();
323 assert!(!first.success, "failed stage 1 surfaces");
324 assert_eq!(
325 first.output, "x",
326 "stage 2 did not run after stage 1 failed"
327 );
328
329 let second = out[1].as_ref().unwrap();
330 assert!(second.success);
331 assert_eq!(
332 second.output, "y",
333 "stage 2 returned None → chain stopped at stage 1"
334 );
335 }
336
337 #[tokio::test]
338 async fn no_barrier_between_stages_bounded_by_hint() {
339 let echo = EchoExecutor::new();
340 let max_active = Arc::clone(&echo.max_active);
341 let exec: Arc<dyn AgentExecutor> = Arc::new(echo);
342 let stages = vec![
343 stage(|_p: Option<&StepOutcome>, item: &usize| {
344 Some(AgentStepSpec::new(
345 format!("s1-{item}"),
346 "explore",
347 "d",
348 "p",
349 ))
350 }),
351 stage(|_p: Option<&StepOutcome>, item: &usize| {
352 Some(AgentStepSpec::new(format!("s2-{item}"), "review", "d", "p"))
353 }),
354 ];
355 let items: Vec<usize> = (0..8).collect();
356 let out = execute_pipeline(exec, items, stages, None).await;
357 assert_eq!(out.len(), 8);
358 assert!(out.iter().all(|o| o.is_some()));
359 assert!(
361 max_active.load(Ordering::SeqCst) <= 4,
362 "concurrency never exceeds the executor's hint"
363 );
364 }
365
366 #[tokio::test]
367 async fn panicking_stage_isolates_to_its_chain() {
368 let exec: Arc<dyn AgentExecutor> = Arc::new(EchoExecutor::new());
369 let stages = vec![stage(|_p: Option<&StepOutcome>, item: &&str| {
370 Some(AgentStepSpec::new("s1", *item, "d", "p"))
372 })];
373 let out = execute_pipeline(exec, vec!["explore", "boom", "review"], stages, None).await;
374 assert_eq!(out.len(), 3);
375 assert!(out[0].as_ref().unwrap().success);
376 assert!(out[1].is_none(), "panicked chain becomes None, not a drop");
377 assert!(out[2].as_ref().unwrap().success, "later chains unaffected");
378 }
379
380 struct RecordingExecutor {
382 ran: Arc<tokio::sync::Mutex<Vec<String>>>,
383 }
384
385 #[async_trait]
386 impl AgentExecutor for RecordingExecutor {
387 async fn execute_step(
388 &self,
389 spec: AgentStepSpec,
390 _event_tx: Option<broadcast::Sender<AgentEvent>>,
391 ) -> StepOutcome {
392 self.ran.lock().await.push(spec.task_id.clone());
393 StepOutcome {
394 task_id: spec.task_id.clone(),
395 session_id: format!("task-run-{}", spec.task_id),
396 agent: spec.agent.clone(),
397 output: format!("ran:{}", spec.task_id),
398 success: true,
399 structured: None,
400 }
401 }
402 fn concurrency_hint(&self) -> usize {
403 4
404 }
405 }
406
407 #[tokio::test]
408 async fn resumable_skips_completed_then_clears_on_success() {
409 use crate::store::MemorySessionStore;
410 let store: Arc<dyn SessionStore> = Arc::new(MemorySessionStore::new());
411
412 let mut done = std::collections::HashMap::new();
415 done.insert(
416 "a".to_string(),
417 StepOutcome {
418 task_id: "a".into(),
419 session_id: "task-run-a".into(),
420 agent: "explore".into(),
421 output: "cached-a".into(),
422 success: true,
423 structured: None,
424 },
425 );
426 store
427 .save_workflow_checkpoint(
428 "wf-1",
429 &WorkflowCheckpoint::from_completed("wf-1", &done, 1),
430 )
431 .await
432 .unwrap();
433
434 let ran = Arc::new(tokio::sync::Mutex::new(Vec::new()));
437 let exec: Arc<dyn AgentExecutor> = Arc::new(RecordingExecutor {
438 ran: Arc::clone(&ran),
439 });
440 let specs = vec![
441 AgentStepSpec::new("a", "explore", "d", "pa"),
442 AgentStepSpec::new("b", "review", "d", "pb"),
443 ];
444
445 let out =
446 execute_steps_parallel_resumable(exec, specs, "wf-1", Arc::clone(&store), None).await;
447
448 assert_eq!(
449 *ran.lock().await,
450 vec!["b".to_string()],
451 "only the not-yet-completed step runs"
452 );
453 assert_eq!(out.len(), 2);
454 assert_eq!(out[0].task_id, "a");
455 assert_eq!(
456 out[0].output, "cached-a",
457 "completed step returns its cached outcome, unchanged"
458 );
459 assert_eq!(out[1].task_id, "b");
460 assert!(out.iter().all(|o| o.success));
461 assert!(
462 store
463 .load_workflow_checkpoint("wf-1")
464 .await
465 .unwrap()
466 .is_none(),
467 "a fully-succeeded workflow clears its checkpoint"
468 );
469 }
470
471 #[tokio::test]
472 async fn resumable_retains_checkpoint_recording_only_successes_on_partial_failure() {
473 use crate::store::MemorySessionStore;
474 let store: Arc<dyn SessionStore> = Arc::new(MemorySessionStore::new());
475 let exec: Arc<dyn AgentExecutor> = Arc::new(EchoExecutor::new());
477 let specs = vec![
478 AgentStepSpec::new("ok", "explore", "d", "p"),
479 AgentStepSpec::new("bad", "fail", "d", "p"),
480 ];
481
482 let out =
483 execute_steps_parallel_resumable(exec, specs, "wf-2", Arc::clone(&store), None).await;
484 assert!(out[0].success);
485 assert!(!out[1].success);
486
487 let cp = store
490 .load_workflow_checkpoint("wf-2")
491 .await
492 .unwrap()
493 .expect("checkpoint retained on partial failure");
494 let completed = cp.completed();
495 assert!(completed.contains_key("ok"), "succeeded step is recorded");
496 assert!(
497 !completed.contains_key("bad"),
498 "failed step is NOT recorded → it retries on resume"
499 );
500 }
501
502 struct ZeroHintExecutor;
503 #[async_trait]
504 impl AgentExecutor for ZeroHintExecutor {
505 async fn execute_step(
506 &self,
507 spec: AgentStepSpec,
508 _event_tx: Option<broadcast::Sender<AgentEvent>>,
509 ) -> StepOutcome {
510 StepOutcome {
511 task_id: spec.task_id.clone(),
512 session_id: format!("task-run-{}", spec.task_id),
513 agent: spec.agent.clone(),
514 output: "ok".to_string(),
515 success: true,
516 structured: None,
517 }
518 }
519 fn concurrency_hint(&self) -> usize {
520 0
521 }
522 }
523
524 #[tokio::test]
525 async fn empty_inputs_return_empty() {
526 let exec: Arc<dyn AgentExecutor> = Arc::new(EchoExecutor::new());
527 assert!(
528 crate::orchestration::execute_steps_parallel(Arc::clone(&exec), vec![], None)
529 .await
530 .is_empty()
531 );
532 let stages: Vec<PipelineStage<&str>> =
533 vec![stage(|_p: Option<&StepOutcome>, item: &&str| {
534 Some(AgentStepSpec::new("s", "explore", "d", *item))
535 })];
536 assert!(execute_pipeline(exec, Vec::<&str>::new(), stages, None)
537 .await
538 .is_empty());
539 }
540
541 #[tokio::test]
542 async fn zero_concurrency_hint_still_makes_progress() {
543 let exec: Arc<dyn AgentExecutor> = Arc::new(ZeroHintExecutor);
546 let specs = vec![
547 AgentStepSpec::new("a", "explore", "d", "p"),
548 AgentStepSpec::new("b", "explore", "d", "p"),
549 AgentStepSpec::new("c", "explore", "d", "p"),
550 ];
551 let out = crate::orchestration::execute_steps_parallel(exec, specs, None).await;
552 assert_eq!(
553 out.iter().map(|o| o.task_id.as_str()).collect::<Vec<_>>(),
554 vec!["a", "b", "c"]
555 );
556 assert!(out.iter().all(|o| o.success));
557 }
558
559 #[tokio::test]
560 async fn pipeline_first_stage_none_yields_none_outcome() {
561 let exec: Arc<dyn AgentExecutor> = Arc::new(EchoExecutor::new());
562 let stages: Vec<PipelineStage<&str>> =
563 vec![stage(|_p: Option<&StepOutcome>, item: &&str| {
564 if *item == "skip" {
565 None
566 } else {
567 Some(AgentStepSpec::new("s", "explore", "d", *item))
568 }
569 })];
570 let out = execute_pipeline(exec, vec!["skip", "run"], stages, None).await;
571 assert!(
572 out[0].is_none(),
573 "a first-stage None yields a None outcome (chain never started)"
574 );
575 assert!(out[1].as_ref().unwrap().success);
576 }
577
578 fn cached(task_id: &str, agent: &str, output: &str) -> StepOutcome {
579 StepOutcome {
580 task_id: task_id.to_string(),
581 session_id: format!("task-run-{task_id}"),
582 agent: agent.to_string(),
583 output: output.to_string(),
584 success: true,
585 structured: None,
586 }
587 }
588
589 #[tokio::test]
590 async fn resumable_reruns_all_when_checkpoint_load_errors() {
591 use crate::store::MemorySessionStore;
592 let store: Arc<dyn SessionStore> = Arc::new(MemorySessionStore::new());
593
594 let mut done = std::collections::HashMap::new();
599 done.insert("a".to_string(), cached("a", "explore", "old"));
600 let mut cp = WorkflowCheckpoint::from_completed("wf-err", &done, 1);
601 cp.schema_version = crate::orchestration::WORKFLOW_CHECKPOINT_SCHEMA_VERSION + 1;
602 store.save_workflow_checkpoint("wf-err", &cp).await.unwrap();
603
604 let ran = Arc::new(tokio::sync::Mutex::new(Vec::new()));
605 let exec: Arc<dyn AgentExecutor> = Arc::new(RecordingExecutor {
606 ran: Arc::clone(&ran),
607 });
608 let specs = vec![
609 AgentStepSpec::new("a", "explore", "d", "pa"),
610 AgentStepSpec::new("b", "review", "d", "pb"),
611 ];
612 let out =
613 execute_steps_parallel_resumable(exec, specs, "wf-err", Arc::clone(&store), None).await;
614
615 let mut ran_ids = ran.lock().await.clone();
616 ran_ids.sort();
617 assert_eq!(
618 ran_ids,
619 vec!["a".to_string(), "b".to_string()],
620 "an unreadable (future-version) checkpoint is ignored → all steps re-run"
621 );
622 assert_eq!(out.len(), 2);
623 assert!(out.iter().all(|o| o.success));
624 }
625
626 #[tokio::test]
627 async fn resumable_ignores_checkpointed_steps_absent_from_new_specs() {
628 use crate::store::MemorySessionStore;
629 let store: Arc<dyn SessionStore> = Arc::new(MemorySessionStore::new());
630
631 let mut done = std::collections::HashMap::new();
635 done.insert("a".to_string(), cached("a", "explore", "cached-a"));
636 done.insert("b".to_string(), cached("b", "review", "cached-b"));
637 store
638 .save_workflow_checkpoint(
639 "wf-x",
640 &WorkflowCheckpoint::from_completed("wf-x", &done, 1),
641 )
642 .await
643 .unwrap();
644
645 let ran = Arc::new(tokio::sync::Mutex::new(Vec::new()));
646 let exec: Arc<dyn AgentExecutor> = Arc::new(RecordingExecutor {
647 ran: Arc::clone(&ran),
648 });
649 let specs = vec![
650 AgentStepSpec::new("b", "review", "d", "pb"),
651 AgentStepSpec::new("c", "plan", "d", "pc"),
652 ];
653 let out =
654 execute_steps_parallel_resumable(exec, specs, "wf-x", Arc::clone(&store), None).await;
655
656 assert_eq!(
657 *ran.lock().await,
658 vec!["c".to_string()],
659 "cached b reused, stale a dropped, only new c runs"
660 );
661 assert_eq!(out.len(), 2);
662 assert_eq!(out[0].task_id, "b");
663 assert_eq!(out[0].output, "cached-b");
664 assert_eq!(out[1].task_id, "c");
665 assert!(out.iter().all(|o| o.success));
666 }
667}