1use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use serde::{Deserialize, Serialize};
14
15use crate::pool::Pool;
16use crate::store::PoolStore;
17use crate::types::{TaskId, TaskOverrides, TaskState};
18
/// One step of a prompt chain.
///
/// Steps run in order. A step's prompt can reference the previous step's
/// output and values captured from earlier steps through `output_vars`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainStep {
    /// Step name. It is copied into [`StepResult::name`] and is the prefix of
    /// the `{steps.<name>.<var>}` references produced by `output_vars`.
    pub name: String,

    /// What the step executes.
    pub action: StepAction,

    /// Task overrides passed to the pool on every run of this step (each
    /// attempt and the recovery prompt). NOTE(review): `None` presumably means
    /// "use pool defaults" — confirm against `Pool::run_with_config_streaming`.
    pub config: Option<TaskOverrides>,

    /// Retry and recovery behaviour on failure. When omitted, it defaults to
    /// no retries and no recovery prompt.
    #[serde(default)]
    pub failure_policy: StepFailurePolicy,

    /// Values to capture from this step's output after it succeeds.
    ///
    /// Each entry maps a variable name to a dot-separated path into the output
    /// parsed as JSON; `"."` or `""` captures the whole output verbatim.
    /// Later steps reference a captured value as
    /// `{steps.<step name>.<var name>}`. Paths that don't resolve are logged
    /// and skipped. Defaults to empty.
    #[serde(default)]
    pub output_vars: HashMap<String, String>,
}
47
/// The work a [`ChainStep`] performs.
///
/// Serialized with an internal, snake_case `type` tag, e.g.
/// `{"type":"prompt","prompt":"..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StepAction {
    /// Run a prompt through the pool.
    Prompt {
        /// Prompt template. `{previous_output}` is replaced with the previous
        /// step's output (empty for the first step), and
        /// `{steps.<step>.<var>}` with captured output variables. References
        /// with no captured value are left unchanged.
        prompt: String,
    },
}
59
/// How a chain step reacts to failure.
///
/// The default is no retries and no recovery prompt.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StepFailurePolicy {
    /// Extra attempts after the first failed one (0 = run exactly once).
    #[serde(default)]
    pub retries: u32,
    /// Optional prompt that runs once after every attempt has failed.
    ///
    /// Supported placeholders:
    /// - `{error}`: text of the last failure;
    /// - `{previous_output}`;
    /// - `{steps.<step>.<var>}`.
    ///
    /// If the recovery run returns a task result, that result becomes the
    /// step's result, whether or not it succeeded.
    pub recovery_prompt: Option<String>,
}
71
/// Isolation mode requested for a chain's working environment.
///
/// NOTE(review): this module never interprets the variants. The per-variant
/// notes below are inferred from the names and should be confirmed with
/// whatever launches the chain.
/// Serialized as snake_case strings: `"none"`, `"worktree"`, `"clone"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChainIsolation {
    /// No isolation (presumably runs in the caller's directory).
    None,
    /// The default. Presumably a dedicated git worktree per chain.
    #[default]
    Worktree,
    /// Presumably a full repository clone per chain.
    Clone,
}
84
/// Options attached to a chain. The executor functions in this module do not
/// read them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChainOptions {
    /// Caller-defined tags. Defaults to empty.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Requested isolation mode. Defaults to [`ChainIsolation::Worktree`].
    #[serde(default)]
    pub isolation: ChainIsolation,
}
95
/// Outcome of a single chain step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
    /// Name of the step (copied from [`ChainStep::name`]).
    pub name: String,
    /// Step output. Holds the last error text when the step failed outright,
    /// and is empty for skipped steps.
    pub output: String,
    /// Whether the step succeeded.
    pub success: bool,
    /// Cost attributed to this step, in microdollars.
    pub cost_microdollars: u64,
    /// Number of retries consumed. When the recovery prompt produced the
    /// result, this is `retries + 1`. Defaults to 0 when absent.
    #[serde(default)]
    pub retries_used: u32,
    /// `true` for steps that never ran because the chain was cancelled.
    /// Defaults to `false` when absent.
    #[serde(default)]
    pub skipped: bool,
}
114
/// Outcome of a whole chain run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainResult {
    /// Per-step results in execution order. On cancellation this includes the
    /// skipped steps.
    pub steps: Vec<StepResult>,
    /// Output of the last step that ran, or its error text if it failed with
    /// an error. Empty if no step ran.
    pub final_output: String,
    /// Cost of every run in the chain, including failed attempts, in
    /// microdollars.
    pub total_cost_microdollars: u64,
    /// `true` only when every step completed successfully.
    pub success: bool,
}
127
/// Progress snapshot for a running chain, stored via
/// `Pool::set_chain_progress`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainProgress {
    /// Number of steps in the chain.
    pub total_steps: usize,
    /// Zero-based index of the running step. `None` in the final snapshot.
    pub current_step: Option<usize>,
    /// Name of the running step. `None` in the final snapshot.
    pub current_step_name: Option<String>,
    /// Output produced so far by the running step.
    ///
    /// Set to `Some("")` when a step starts. Presumably it is then extended by
    /// `Pool::append_chain_partial_output` — confirm. Omitted from serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub current_step_partial_output: Option<String>,
    /// Unix timestamp (seconds) at which the running step started. Omitted
    /// from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub current_step_started_at: Option<u64>,
    /// Results recorded so far. On cancellation this also includes the
    /// skipped steps.
    pub completed_steps: Vec<StepResult>,
    /// Overall chain status.
    pub status: ChainStatus,
}
153
/// Lifecycle state of a chain, serialized as a snake_case string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChainStatus {
    /// A step is executing.
    Running,
    /// Every step succeeded.
    Completed,
    /// A step failed; the remaining steps were not run.
    Failed,
    /// The chain's task was cancelled; the remaining steps were marked skipped.
    Cancelled,
}
167
/// Shared callback invoked with each chunk of streamed step output.
pub type OnOutputChunk = Arc<dyn Fn(&str) + Send + Sync>;
170
171fn extract_json_path(json_str: &str, path: &str) -> Option<String> {
172 if path == "." || path.is_empty() {
173 return Some(json_str.to_string());
174 }
175 let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
176 let mut current = &value;
177 for key in path.split('.') {
178 current = current.get(key)?;
179 }
180 Some(match current {
181 serde_json::Value::String(s) => s.clone(),
182 other => other.to_string(),
183 })
184}
185
/// Expands `{steps.<step>.<var>}` references in `text` using `step_context`,
/// whose keys have the form `"<step>.<var>"`.
///
/// Expansion is a single left-to-right pass. Substituted values are copied
/// verbatim and never re-scanned, so a captured value that itself contains a
/// `{steps.…}` reference stays literal. This makes the result independent of
/// `HashMap` iteration order.
///
/// A reference ends at the first `}` after `{steps.`. References whose key is
/// not in `step_context`, and unterminated references, are left unchanged.
fn expand_step_refs(text: String, step_context: &HashMap<String, String>) -> String {
    const OPEN: &str = "{steps.";

    // Fast path: nothing can be expanded, so hand the input back untouched.
    if step_context.is_empty() || !text.contains(OPEN) {
        return text;
    }

    let mut out = String::with_capacity(text.len());
    let mut rest = text.as_str();
    while let Some(start) = rest.find(OPEN) {
        out.push_str(&rest[..start]);
        let after_open = &rest[start + OPEN.len()..];
        let resolved = after_open.find('}').and_then(|end| {
            step_context
                .get(&after_open[..end])
                .map(|value| (value, &after_open[end + 1..]))
        });
        match resolved {
            Some((value, tail)) => {
                out.push_str(value);
                rest = tail;
            }
            None => {
                // Not a known reference: emit the `{` literally and resume
                // scanning right after it, so a reference nested inside this
                // text can still match.
                out.push('{');
                rest = &rest[start + 1..];
            }
        }
    }
    out.push_str(rest);
    out
}
192
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn unix_secs_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
199
200pub async fn execute_chain<S: PoolStore + 'static>(
202 pool: &Pool<S>,
203 steps: &[ChainStep],
204) -> crate::Result<ChainResult> {
205 execute_chain_with_progress(pool, steps, None, None).await
206}
207
208pub async fn execute_chain_with_progress<S: PoolStore + 'static>(
216 pool: &Pool<S>,
217 steps: &[ChainStep],
218 chain_task_id: Option<&TaskId>,
219 working_dir: Option<&std::path::Path>,
220) -> crate::Result<ChainResult> {
221 let mut step_results = Vec::with_capacity(steps.len());
222 let mut previous_output = String::new();
223 let mut total_cost = 0u64;
224 let mut step_context: HashMap<String, String> = HashMap::new();
225
226 for (step_idx, step) in steps.iter().enumerate() {
227 if let Some(task_id) = chain_task_id
229 && let Ok(Some(task)) = pool.store().get_task(task_id).await
230 && task.state == TaskState::Cancelled
231 {
232 for s in &steps[step_idx..] {
233 step_results.push(StepResult {
234 name: s.name.clone(),
235 output: String::new(),
236 success: false,
237 cost_microdollars: 0,
238 retries_used: 0,
239 skipped: true,
240 });
241 }
242 update_chain_progress_final(
243 pool,
244 Some(task_id),
245 steps.len(),
246 &step_results,
247 ChainStatus::Cancelled,
248 )
249 .await;
250 return Ok(ChainResult {
251 final_output: previous_output,
252 steps: step_results,
253 total_cost_microdollars: total_cost,
254 success: false,
255 });
256 }
257
258 if let Some(task_id) = chain_task_id {
260 let progress = ChainProgress {
261 total_steps: steps.len(),
262 current_step: Some(step_idx),
263 current_step_name: Some(step.name.clone()),
264 current_step_partial_output: Some(String::new()),
265 current_step_started_at: Some(unix_secs_now()),
266 completed_steps: step_results.clone(),
267 status: ChainStatus::Running,
268 };
269 pool.set_chain_progress(task_id, progress).await;
270 }
271
272 let prompt = render_step_prompt(step, &previous_output, &step_context)?;
273
274 let on_output: Option<OnOutputChunk> = chain_task_id.map(|tid| {
276 let pool = pool.clone();
277 let tid = tid.clone();
278 Arc::new(move |chunk: &str| {
279 pool.append_chain_partial_output(&tid, chunk);
280 }) as OnOutputChunk
281 });
282
283 let (step_result, step_cost) = execute_step_with_retries(
284 pool,
285 step,
286 &prompt,
287 &previous_output,
288 on_output.clone(),
289 working_dir,
290 &step_context,
291 )
292 .await;
293
294 total_cost += step_cost;
295
296 match step_result {
297 Ok(result) => {
298 previous_output = result.output.clone();
299
300 if result.success {
301 for (var_name, path) in &step.output_vars {
302 match extract_json_path(&result.output, path) {
303 Some(extracted) => {
304 step_context
305 .insert(format!("{}.{}", step.name, var_name), extracted);
306 }
307 None => {
308 tracing::warn!(
309 step = %step.name,
310 var = %var_name,
311 path = %path,
312 "output_var extraction failed (output not JSON or path not found)"
313 );
314 }
315 }
316 }
317 }
318
319 step_results.push(result);
320
321 if !step_results.last().unwrap().success {
322 update_chain_progress_final(
323 pool,
324 chain_task_id,
325 steps.len(),
326 &step_results,
327 ChainStatus::Failed,
328 )
329 .await;
330 return Ok(ChainResult {
331 final_output: previous_output,
332 steps: step_results,
333 total_cost_microdollars: total_cost,
334 success: false,
335 });
336 }
337 }
338 Err(output) => {
339 step_results.push(StepResult {
340 name: step.name.clone(),
341 output: output.clone(),
342 success: false,
343 cost_microdollars: 0,
344 retries_used: step.failure_policy.retries,
345 skipped: false,
346 });
347 update_chain_progress_final(
348 pool,
349 chain_task_id,
350 steps.len(),
351 &step_results,
352 ChainStatus::Failed,
353 )
354 .await;
355 return Ok(ChainResult {
356 final_output: output,
357 steps: step_results,
358 total_cost_microdollars: total_cost,
359 success: false,
360 });
361 }
362 }
363 }
364
365 update_chain_progress_final(
366 pool,
367 chain_task_id,
368 steps.len(),
369 &step_results,
370 ChainStatus::Completed,
371 )
372 .await;
373
374 Ok(ChainResult {
375 final_output: previous_output,
376 steps: step_results,
377 total_cost_microdollars: total_cost,
378 success: true,
379 })
380}
381
382fn render_step_prompt(
384 step: &ChainStep,
385 previous_output: &str,
386 step_context: &HashMap<String, String>,
387) -> crate::Result<String> {
388 let StepAction::Prompt { prompt } = &step.action;
389 let rendered = prompt.replace("{previous_output}", previous_output);
390 Ok(expand_step_refs(rendered, step_context))
391}
392
393async fn execute_step_with_retries<S: PoolStore + 'static>(
398 pool: &Pool<S>,
399 step: &ChainStep,
400 initial_prompt: &str,
401 previous_output: &str,
402 on_output: Option<OnOutputChunk>,
403 working_dir: Option<&std::path::Path>,
404 step_context: &HashMap<String, String>,
405) -> (std::result::Result<StepResult, String>, u64) {
406 let max_attempts = 1 + step.failure_policy.retries;
407 let mut total_cost = 0u64;
408 let mut last_error = String::new();
409
410 for attempt in 0..max_attempts {
411 let prompt = if attempt == 0 {
412 initial_prompt.to_string()
413 } else {
414 match render_step_prompt(step, previous_output, step_context) {
416 Ok(p) => p,
417 Err(e) => return (Err(e.to_string()), total_cost),
418 }
419 };
420
421 let result = pool
422 .run_with_config_streaming(
423 &prompt,
424 step.config.clone(),
425 on_output.clone(),
426 working_dir.map(|p| p.to_path_buf()),
427 )
428 .await;
429
430 match result {
431 Ok(task_result) => {
432 total_cost += task_result.cost_microdollars;
433 if task_result.success {
434 return (
435 Ok(StepResult {
436 name: step.name.clone(),
437 output: task_result.output,
438 success: true,
439 cost_microdollars: total_cost,
440 retries_used: attempt,
441 skipped: false,
442 }),
443 total_cost,
444 );
445 }
446 last_error = task_result.output;
448 }
449 Err(e) => {
450 last_error = e.to_string();
451 }
452 }
453
454 tracing::warn!(
455 step = %step.name,
456 attempt = attempt + 1,
457 max_attempts,
458 "chain step failed, will retry"
459 );
460 }
461
462 if let Some(ref recovery_template) = step.failure_policy.recovery_prompt {
464 let recovery_prompt = expand_step_refs(
465 recovery_template
466 .replace("{error}", &last_error)
467 .replace("{previous_output}", previous_output),
468 step_context,
469 );
470
471 tracing::info!(step = %step.name, "attempting recovery prompt");
472
473 let result = pool
474 .run_with_config_streaming(
475 &recovery_prompt,
476 step.config.clone(),
477 on_output,
478 working_dir.map(|p| p.to_path_buf()),
479 )
480 .await;
481
482 match result {
483 Ok(task_result) => {
484 total_cost += task_result.cost_microdollars;
485 return (
486 Ok(StepResult {
487 name: step.name.clone(),
488 output: task_result.output,
489 success: task_result.success,
490 cost_microdollars: total_cost,
491 retries_used: max_attempts,
492 skipped: false,
493 }),
494 total_cost,
495 );
496 }
497 Err(e) => {
498 last_error = e.to_string();
499 }
500 }
501 }
502
503 (Err(last_error), total_cost)
504}
505
506async fn update_chain_progress_final<S: PoolStore + 'static>(
508 pool: &Pool<S>,
509 chain_task_id: Option<&TaskId>,
510 total_steps: usize,
511 completed_steps: &[StepResult],
512 status: ChainStatus,
513) {
514 if let Some(task_id) = chain_task_id {
515 let progress = ChainProgress {
516 total_steps,
517 current_step: None,
518 current_step_name: None,
519 current_step_partial_output: None,
520 current_step_started_at: None,
521 completed_steps: completed_steps.to_vec(),
522 status,
523 };
524 pool.set_chain_progress(task_id, progress).await;
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prompt_step_replaces_previous_output() {
        let step = ChainStep {
            name: "step1".into(),
            action: StepAction::Prompt {
                prompt: "Based on: {previous_output}\nDo more.".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: Default::default(),
        };

        // Exercise the production renderer instead of re-implementing it here.
        let rendered = render_step_prompt(&step, "hello world", &HashMap::new()).ok();
        assert_eq!(rendered.as_deref(), Some("Based on: hello world\nDo more."));
    }

    #[test]
    fn prompt_step_expands_step_refs() {
        let step = ChainStep {
            name: "implement".into(),
            action: StepAction::Prompt {
                prompt: "Plan: {steps.plan.summary}".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: Default::default(),
        };
        let mut ctx = HashMap::new();
        ctx.insert("plan.summary".to_string(), "do stuff".to_string());

        let rendered = render_step_prompt(&step, "", &ctx).ok();
        assert_eq!(rendered.as_deref(), Some("Plan: do stuff"));
    }

    #[test]
    fn chain_result_serializes() {
        let result = ChainResult {
            steps: vec![StepResult {
                name: "step1".into(),
                output: "done".into(),
                success: true,
                cost_microdollars: 1000,
                retries_used: 0,
                skipped: false,
            }],
            final_output: "done".into(),
            total_cost_microdollars: 1000,
            success: true,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("step1"));
    }

    #[test]
    fn step_failure_policy_defaults() {
        let policy = StepFailurePolicy::default();
        assert_eq!(policy.retries, 0);
        assert!(policy.recovery_prompt.is_none());
    }

    #[test]
    fn chain_options_defaults() {
        let opts = ChainOptions::default();
        assert!(opts.tags.is_empty());
        assert_eq!(opts.isolation, ChainIsolation::Worktree);
    }

    #[test]
    fn chain_isolation_serde_roundtrip() {
        let worktree = ChainIsolation::Worktree;
        let json = serde_json::to_string(&worktree).unwrap();
        assert_eq!(json, r#""worktree""#);

        let none = ChainIsolation::None;
        let json = serde_json::to_string(&none).unwrap();
        assert_eq!(json, r#""none""#);

        let full_clone = ChainIsolation::Clone;
        let json = serde_json::to_string(&full_clone).unwrap();
        assert_eq!(json, r#""clone""#);

        let parsed: ChainIsolation = serde_json::from_str(r#""worktree""#).unwrap();
        assert_eq!(parsed, ChainIsolation::Worktree);

        let parsed: ChainIsolation = serde_json::from_str(r#""none""#).unwrap();
        assert_eq!(parsed, ChainIsolation::None);

        let parsed: ChainIsolation = serde_json::from_str(r#""clone""#).unwrap();
        assert_eq!(parsed, ChainIsolation::Clone);
    }

    #[test]
    fn chain_options_with_isolation_serializes() {
        let opts = ChainOptions {
            tags: vec!["test".into()],
            isolation: ChainIsolation::Worktree,
        };
        let json = serde_json::to_string(&opts).unwrap();
        let parsed: ChainOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.isolation, ChainIsolation::Worktree);
        assert_eq!(parsed.tags, vec!["test"]);
    }

    #[test]
    fn chain_progress_serializes_with_partial_output() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(1),
            current_step_name: Some("implement".into()),
            current_step_partial_output: Some("partial text".into()),
            current_step_started_at: Some(1700000000),
            completed_steps: vec![StepResult {
                name: "plan".into(),
                output: "planned".into(),
                success: true,
                cost_microdollars: 500,
                retries_used: 0,
                skipped: false,
            }],
            status: ChainStatus::Running,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("implement"));
        assert!(json.contains("running"));
        assert!(json.contains("partial text"));
        assert!(json.contains("1700000000"));
    }

    #[test]
    fn chain_progress_omits_none_fields() {
        let progress = ChainProgress {
            total_steps: 2,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![],
            status: ChainStatus::Completed,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(!json.contains("current_step_partial_output"));
        assert!(!json.contains("current_step_started_at"));
    }

    #[test]
    fn chain_progress_empty_partial_output_when_step_starts() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(0),
            current_step_name: Some("plan".into()),
            current_step_partial_output: Some(String::new()),
            current_step_started_at: Some(1700000000),
            completed_steps: vec![],
            status: ChainStatus::Running,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("\"current_step_partial_output\":\"\""));
    }

    #[test]
    fn cancelled_status_serializes() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![
                StepResult {
                    name: "plan".into(),
                    output: "planned".into(),
                    success: true,
                    cost_microdollars: 500,
                    retries_used: 0,
                    skipped: false,
                },
                StepResult {
                    name: "implement".into(),
                    output: String::new(),
                    success: false,
                    cost_microdollars: 0,
                    retries_used: 0,
                    skipped: true,
                },
                StepResult {
                    name: "review".into(),
                    output: String::new(),
                    success: false,
                    cost_microdollars: 0,
                    retries_used: 0,
                    skipped: true,
                },
            ],
            status: ChainStatus::Cancelled,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("cancelled"));
        assert!(json.contains("\"skipped\":true"));
    }

    #[test]
    fn skipped_defaults_to_false_on_deserialize() {
        let json =
            r#"{"name":"s","output":"o","success":true,"cost_microdollars":0,"retries_used":0}"#;
        let result: StepResult = serde_json::from_str(json).unwrap();
        assert!(!result.skipped);
    }

    #[test]
    fn extract_json_path_whole_output() {
        let json = r#"{"a": 1}"#;
        assert_eq!(extract_json_path(json, "."), Some(json.to_string()));
        assert_eq!(extract_json_path(json, ""), Some(json.to_string()));
    }

    #[test]
    fn extract_json_path_top_level_key() {
        let json = r#"{"summary": "all good"}"#;
        assert_eq!(
            extract_json_path(json, "summary"),
            Some("all good".to_string())
        );
    }

    #[test]
    fn extract_json_path_nested() {
        let json = r#"{"result": {"count": 42}}"#;
        assert_eq!(
            extract_json_path(json, "result.count"),
            Some("42".to_string())
        );
    }

    #[test]
    fn extract_json_path_not_json() {
        assert_eq!(extract_json_path("not json", "key"), None);
    }

    #[test]
    fn extract_json_path_missing_key() {
        let json = r#"{"a": 1}"#;
        assert_eq!(extract_json_path(json, "b"), None);
    }

    #[test]
    fn expand_step_refs_substitutes() {
        let mut ctx = HashMap::new();
        ctx.insert("plan.summary".into(), "do stuff".into());
        let text = "Based on {steps.plan.summary}, implement it.".to_string();
        assert_eq!(
            expand_step_refs(text, &ctx),
            "Based on do stuff, implement it."
        );
    }

    #[test]
    fn expand_step_refs_unknown_left_as_is() {
        let ctx = HashMap::new();
        let text = "Use {steps.missing.var} here.".to_string();
        assert_eq!(expand_step_refs(text.clone(), &ctx), text);
    }

    #[test]
    fn chain_step_output_vars_defaults_empty() {
        let json = r#"{"name":"s","action":{"type":"prompt","prompt":"hi"}}"#;
        let step: ChainStep = serde_json::from_str(json).unwrap();
        assert!(step.output_vars.is_empty());
        // An omitted failure policy falls back to "no retries, no recovery".
        assert_eq!(step.failure_policy.retries, 0);
        assert!(step.failure_policy.recovery_prompt.is_none());
    }

    #[test]
    fn chain_step_serializes_output_vars() {
        let mut vars = HashMap::new();
        vars.insert("summary".into(), "result.summary".into());
        let step = ChainStep {
            name: "s".into(),
            action: StepAction::Prompt {
                prompt: "hi".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: vars,
        };
        let json = serde_json::to_string(&step).unwrap();
        assert!(json.contains("output_vars"));
        assert!(json.contains("result.summary"));

        let parsed: ChainStep = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.output_vars.get("summary").unwrap(), "result.summary");
    }
}
803}