1use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use serde::{Deserialize, Serialize};
14
15use crate::pool::Pool;
16use crate::skill::SkillRegistry;
17use crate::store::PoolStore;
18use crate::types::{SlotConfig, TaskId, TaskState};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ChainStep {
23 pub name: String,
25
26 pub action: StepAction,
28
29 pub config: Option<SlotConfig>,
31
32 #[serde(default)]
34 pub failure_policy: StepFailurePolicy,
35
36 #[serde(default)]
46 pub output_vars: HashMap<String, String>,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(tag = "type", rename_all = "snake_case")]
52pub enum StepAction {
53 Prompt {
56 prompt: String,
58 },
59 Skill {
63 skill: String,
65 #[serde(default)]
67 arguments: HashMap<String, String>,
68 },
69}
70
71#[derive(Debug, Clone, Default, Serialize, Deserialize)]
73pub struct StepFailurePolicy {
74 #[serde(default)]
76 pub retries: u32,
77 pub recovery_prompt: Option<String>,
81}
82
83#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
85#[serde(rename_all = "snake_case")]
86pub enum ChainIsolation {
87 None,
89 #[default]
91 Worktree,
92}
93
94#[derive(Debug, Clone, Default, Serialize, Deserialize)]
96pub struct ChainOptions {
97 #[serde(default)]
99 pub tags: Vec<String>,
100 #[serde(default)]
102 pub isolation: ChainIsolation,
103}
104
105#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct StepResult {
108 pub name: String,
110 pub output: String,
112 pub success: bool,
114 pub cost_microdollars: u64,
116 #[serde(default)]
118 pub retries_used: u32,
119 #[serde(default)]
121 pub skipped: bool,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct ChainResult {
127 pub steps: Vec<StepResult>,
129 pub final_output: String,
131 pub total_cost_microdollars: u64,
133 pub success: bool,
135}
136
137#[derive(Debug, Clone, Serialize, Deserialize)]
139pub struct ChainProgress {
140 pub total_steps: usize,
142 pub current_step: Option<usize>,
144 pub current_step_name: Option<String>,
146 #[serde(skip_serializing_if = "Option::is_none")]
151 pub current_step_partial_output: Option<String>,
152 #[serde(skip_serializing_if = "Option::is_none")]
156 pub current_step_started_at: Option<u64>,
157 pub completed_steps: Vec<StepResult>,
159 pub status: ChainStatus,
161}
162
163#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
165#[serde(rename_all = "snake_case")]
166pub enum ChainStatus {
167 Running,
169 Completed,
171 Failed,
173 Cancelled,
175}
176
/// Shared, thread-safe callback invoked with each chunk of streamed step output.
pub type OnOutputChunk = Arc<dyn for<'a> Fn(&'a str) + Send + Sync + 'static>;
179
180fn extract_json_path(json_str: &str, path: &str) -> Option<String> {
181 if path == "." || path.is_empty() {
182 return Some(json_str.to_string());
183 }
184 let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
185 let mut current = &value;
186 for key in path.split('.') {
187 current = current.get(key)?;
188 }
189 Some(match current {
190 serde_json::Value::String(s) => s.clone(),
191 other => other.to_string(),
192 })
193}
194
/// Expands `{steps.<step>.<var>}` placeholders in `text` from `step_context`,
/// whose keys have the form `"<step>.<var>"`.
///
/// The text is scanned once, left to right. A placeholder ends at the first `}`
/// after `{steps.`, and placeholders whose key is not in `step_context` (or that
/// are never closed) are left untouched. Substituted values are copied verbatim
/// and are not scanned again, so the result is independent of `HashMap`
/// iteration order even when a value itself contains `{steps.…}` text.
fn expand_step_refs(text: String, step_context: &HashMap<String, String>) -> String {
    const OPEN: &str = "{steps.";

    // Fast path: nothing can be substituted, hand back the original allocation.
    if step_context.is_empty() || !text.contains(OPEN) {
        return text;
    }

    let mut expanded = String::with_capacity(text.len());
    let mut rest = text.as_str();
    while let Some(start) = rest.find(OPEN) {
        expanded.push_str(&rest[..start]);
        let after_open = &rest[start + OPEN.len()..];
        let hit = after_open
            .find('}')
            .and_then(|end| step_context.get(&after_open[..end]).map(|value| (end, value)));
        match hit {
            Some((end, value)) => {
                expanded.push_str(value);
                rest = &after_open[end + 1..];
            }
            None => {
                // Unknown or unterminated reference: keep the prefix literally and
                // resume scanning right after it.
                expanded.push_str(OPEN);
                rest = after_open;
            }
        }
    }
    expanded.push_str(rest);
    expanded
}
201
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
fn unix_secs_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
208
209pub async fn execute_chain<S: PoolStore + 'static>(
211 pool: &Pool<S>,
212 skills: &SkillRegistry,
213 steps: &[ChainStep],
214) -> crate::Result<ChainResult> {
215 execute_chain_with_progress(pool, skills, steps, None, None).await
216}
217
218pub async fn execute_chain_with_progress<S: PoolStore + 'static>(
226 pool: &Pool<S>,
227 skills: &SkillRegistry,
228 steps: &[ChainStep],
229 chain_task_id: Option<&TaskId>,
230 working_dir: Option<&std::path::Path>,
231) -> crate::Result<ChainResult> {
232 let mut step_results = Vec::with_capacity(steps.len());
233 let mut previous_output = String::new();
234 let mut total_cost = 0u64;
235 let mut step_context: HashMap<String, String> = HashMap::new();
236
237 for (step_idx, step) in steps.iter().enumerate() {
238 if let Some(task_id) = chain_task_id
240 && let Ok(Some(task)) = pool.store().get_task(task_id).await
241 && task.state == TaskState::Cancelled
242 {
243 for s in &steps[step_idx..] {
244 step_results.push(StepResult {
245 name: s.name.clone(),
246 output: String::new(),
247 success: false,
248 cost_microdollars: 0,
249 retries_used: 0,
250 skipped: true,
251 });
252 }
253 update_chain_progress_final(
254 pool,
255 Some(task_id),
256 steps.len(),
257 &step_results,
258 ChainStatus::Cancelled,
259 )
260 .await;
261 return Ok(ChainResult {
262 final_output: previous_output,
263 steps: step_results,
264 total_cost_microdollars: total_cost,
265 success: false,
266 });
267 }
268
269 if let Some(task_id) = chain_task_id {
271 let progress = ChainProgress {
272 total_steps: steps.len(),
273 current_step: Some(step_idx),
274 current_step_name: Some(step.name.clone()),
275 current_step_partial_output: Some(String::new()),
276 current_step_started_at: Some(unix_secs_now()),
277 completed_steps: step_results.clone(),
278 status: ChainStatus::Running,
279 };
280 pool.set_chain_progress(task_id, progress).await;
281 }
282
283 let prompt = render_step_prompt(step, &previous_output, skills, &step_context)?;
284
285 let on_output: Option<OnOutputChunk> = chain_task_id.map(|tid| {
287 let pool = pool.clone();
288 let tid = tid.clone();
289 Arc::new(move |chunk: &str| {
290 pool.append_chain_partial_output(&tid, chunk);
291 }) as OnOutputChunk
292 });
293
294 let (step_result, step_cost) = execute_step_with_retries(
295 pool,
296 step,
297 &prompt,
298 &previous_output,
299 skills,
300 on_output.clone(),
301 working_dir,
302 &step_context,
303 )
304 .await;
305
306 total_cost += step_cost;
307
308 match step_result {
309 Ok(result) => {
310 previous_output = result.output.clone();
311
312 if result.success {
313 for (var_name, path) in &step.output_vars {
314 match extract_json_path(&result.output, path) {
315 Some(extracted) => {
316 step_context
317 .insert(format!("{}.{}", step.name, var_name), extracted);
318 }
319 None => {
320 tracing::warn!(
321 step = %step.name,
322 var = %var_name,
323 path = %path,
324 "output_var extraction failed (output not JSON or path not found)"
325 );
326 }
327 }
328 }
329 }
330
331 step_results.push(result);
332
333 if !step_results.last().unwrap().success {
334 update_chain_progress_final(
335 pool,
336 chain_task_id,
337 steps.len(),
338 &step_results,
339 ChainStatus::Failed,
340 )
341 .await;
342 return Ok(ChainResult {
343 final_output: previous_output,
344 steps: step_results,
345 total_cost_microdollars: total_cost,
346 success: false,
347 });
348 }
349 }
350 Err(output) => {
351 step_results.push(StepResult {
352 name: step.name.clone(),
353 output: output.clone(),
354 success: false,
355 cost_microdollars: 0,
356 retries_used: step.failure_policy.retries,
357 skipped: false,
358 });
359 update_chain_progress_final(
360 pool,
361 chain_task_id,
362 steps.len(),
363 &step_results,
364 ChainStatus::Failed,
365 )
366 .await;
367 return Ok(ChainResult {
368 final_output: output,
369 steps: step_results,
370 total_cost_microdollars: total_cost,
371 success: false,
372 });
373 }
374 }
375 }
376
377 update_chain_progress_final(
378 pool,
379 chain_task_id,
380 steps.len(),
381 &step_results,
382 ChainStatus::Completed,
383 )
384 .await;
385
386 Ok(ChainResult {
387 final_output: previous_output,
388 steps: step_results,
389 total_cost_microdollars: total_cost,
390 success: true,
391 })
392}
393
394fn render_step_prompt(
396 step: &ChainStep,
397 previous_output: &str,
398 skills: &SkillRegistry,
399 step_context: &HashMap<String, String>,
400) -> crate::Result<String> {
401 match &step.action {
402 StepAction::Prompt { prompt } => {
403 let rendered = prompt.replace("{previous_output}", previous_output);
404 Ok(expand_step_refs(rendered, step_context))
405 }
406 StepAction::Skill { skill, arguments } => {
407 let skill_def = skills
408 .get(skill)
409 .ok_or_else(|| crate::Error::Store(format!("skill not found: {skill}")))?;
410 let mut args = arguments.clone();
411 if !previous_output.is_empty() {
412 args.entry("_previous_output".into())
413 .or_insert(previous_output.to_string());
414 }
415 let rendered = skill_def.render(&args)?;
416 Ok(expand_step_refs(rendered, step_context))
417 }
418 }
419}
420
421#[allow(clippy::too_many_arguments)]
426async fn execute_step_with_retries<S: PoolStore + 'static>(
427 pool: &Pool<S>,
428 step: &ChainStep,
429 initial_prompt: &str,
430 previous_output: &str,
431 skills: &SkillRegistry,
432 on_output: Option<OnOutputChunk>,
433 working_dir: Option<&std::path::Path>,
434 step_context: &HashMap<String, String>,
435) -> (std::result::Result<StepResult, String>, u64) {
436 let max_attempts = 1 + step.failure_policy.retries;
437 let mut total_cost = 0u64;
438 let mut last_error = String::new();
439
440 for attempt in 0..max_attempts {
441 let prompt = if attempt == 0 {
442 initial_prompt.to_string()
443 } else {
444 match render_step_prompt(step, previous_output, skills, step_context) {
446 Ok(p) => p,
447 Err(e) => return (Err(e.to_string()), total_cost),
448 }
449 };
450
451 let result = pool
452 .run_with_config_streaming(
453 &prompt,
454 step.config.clone(),
455 on_output.clone(),
456 working_dir.map(|p| p.to_path_buf()),
457 )
458 .await;
459
460 match result {
461 Ok(task_result) => {
462 total_cost += task_result.cost_microdollars;
463 if task_result.success {
464 return (
465 Ok(StepResult {
466 name: step.name.clone(),
467 output: task_result.output,
468 success: true,
469 cost_microdollars: total_cost,
470 retries_used: attempt,
471 skipped: false,
472 }),
473 total_cost,
474 );
475 }
476 last_error = task_result.output;
478 }
479 Err(e) => {
480 last_error = e.to_string();
481 }
482 }
483
484 tracing::warn!(
485 step = %step.name,
486 attempt = attempt + 1,
487 max_attempts,
488 "chain step failed, will retry"
489 );
490 }
491
492 if let Some(ref recovery_template) = step.failure_policy.recovery_prompt {
494 let recovery_prompt = expand_step_refs(
495 recovery_template
496 .replace("{error}", &last_error)
497 .replace("{previous_output}", previous_output),
498 step_context,
499 );
500
501 tracing::info!(step = %step.name, "attempting recovery prompt");
502
503 let result = pool
504 .run_with_config_streaming(
505 &recovery_prompt,
506 step.config.clone(),
507 on_output,
508 working_dir.map(|p| p.to_path_buf()),
509 )
510 .await;
511
512 match result {
513 Ok(task_result) => {
514 total_cost += task_result.cost_microdollars;
515 return (
516 Ok(StepResult {
517 name: step.name.clone(),
518 output: task_result.output,
519 success: task_result.success,
520 cost_microdollars: total_cost,
521 retries_used: max_attempts,
522 skipped: false,
523 }),
524 total_cost,
525 );
526 }
527 Err(e) => {
528 last_error = e.to_string();
529 }
530 }
531 }
532
533 (Err(last_error), total_cost)
534}
535
536async fn update_chain_progress_final<S: PoolStore + 'static>(
538 pool: &Pool<S>,
539 chain_task_id: Option<&TaskId>,
540 total_steps: usize,
541 completed_steps: &[StepResult],
542 status: ChainStatus,
543) {
544 if let Some(task_id) = chain_task_id {
545 let progress = ChainProgress {
546 total_steps,
547 current_step: None,
548 current_step_name: None,
549 current_step_partial_output: None,
550 current_step_started_at: None,
551 completed_steps: completed_steps.to_vec(),
552 status,
553 };
554 pool.set_chain_progress(task_id, progress).await;
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `StepResult` fixture with no retries; `cost` is in microdollars.
    fn step_result(name: &str, output: &str, success: bool, cost: u64, skipped: bool) -> StepResult {
        StepResult {
            name: name.to_string(),
            output: output.to_string(),
            success,
            cost_microdollars: cost,
            retries_used: 0,
            skipped,
        }
    }

    /// Builds a `ChainProgress` fixture with no step currently running.
    fn idle_progress(
        total_steps: usize,
        completed_steps: Vec<StepResult>,
        status: ChainStatus,
    ) -> ChainProgress {
        ChainProgress {
            total_steps,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps,
            status,
        }
    }

    #[test]
    fn prompt_step_replaces_previous_output() {
        let step = ChainStep {
            name: "step1".to_string(),
            action: StepAction::Prompt {
                prompt: "Based on: {previous_output}\nDo more.".to_string(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: HashMap::new(),
        };

        let StepAction::Prompt { prompt } = &step.action else {
            panic!("step1 was built with a prompt action");
        };
        assert_eq!(
            prompt.replace("{previous_output}", "hello world"),
            "Based on: hello world\nDo more."
        );
    }

    #[test]
    fn chain_result_serializes() {
        let result = ChainResult {
            steps: vec![step_result("step1", "done", true, 1000, false)],
            final_output: "done".to_string(),
            total_cost_microdollars: 1000,
            success: true,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("step1"));
    }

    #[test]
    fn step_failure_policy_defaults() {
        let StepFailurePolicy {
            retries,
            recovery_prompt,
        } = StepFailurePolicy::default();
        assert_eq!(retries, 0);
        assert!(recovery_prompt.is_none());
    }

    #[test]
    fn chain_options_defaults() {
        let defaults = ChainOptions::default();
        assert!(defaults.tags.is_empty());
        assert_eq!(defaults.isolation, ChainIsolation::Worktree);
    }

    #[test]
    fn chain_isolation_serde_roundtrip() {
        let cases = [
            (ChainIsolation::Worktree, r#""worktree""#),
            (ChainIsolation::None, r#""none""#),
        ];
        for (variant, wire) in cases {
            assert_eq!(serde_json::to_string(&variant).unwrap(), wire);
            let parsed: ChainIsolation = serde_json::from_str(wire).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn chain_options_with_isolation_serializes() {
        let original = ChainOptions {
            tags: vec!["test".to_string()],
            isolation: ChainIsolation::Worktree,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ChainOptions = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.isolation, ChainIsolation::Worktree);
        assert_eq!(decoded.tags, vec!["test"]);
    }

    #[test]
    fn chain_progress_serializes_with_partial_output() {
        let progress = ChainProgress {
            current_step: Some(1),
            current_step_name: Some("implement".to_string()),
            current_step_partial_output: Some("partial text".to_string()),
            current_step_started_at: Some(1_700_000_000),
            ..idle_progress(
                3,
                vec![step_result("plan", "planned", true, 500, false)],
                ChainStatus::Running,
            )
        };

        let json = serde_json::to_string(&progress).unwrap();
        for needle in ["implement", "running", "partial text", "1700000000"] {
            assert!(json.contains(needle), "missing {needle:?} in {json}");
        }
    }

    #[test]
    fn chain_progress_omits_none_fields() {
        let json = serde_json::to_string(&idle_progress(2, Vec::new(), ChainStatus::Completed)).unwrap();
        assert!(!json.contains("current_step_partial_output"));
        assert!(!json.contains("current_step_started_at"));
    }

    #[test]
    fn chain_progress_empty_partial_output_when_step_starts() {
        let progress = ChainProgress {
            current_step: Some(0),
            current_step_name: Some("plan".to_string()),
            current_step_partial_output: Some(String::new()),
            current_step_started_at: Some(1_700_000_000),
            ..idle_progress(3, Vec::new(), ChainStatus::Running)
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("\"current_step_partial_output\":\"\""));
    }

    #[test]
    fn cancelled_status_serializes() {
        let mut completed = vec![step_result("plan", "planned", true, 500, false)];
        completed.extend(
            ["implement", "review"]
                .into_iter()
                .map(|name| step_result(name, "", false, 0, true)),
        );

        let json = serde_json::to_string(&idle_progress(3, completed, ChainStatus::Cancelled)).unwrap();
        assert!(json.contains("cancelled"));
        assert!(json.contains("\"skipped\":true"));
    }

    #[test]
    fn skipped_defaults_to_false_on_deserialize() {
        let json =
            r#"{"name":"s","output":"o","success":true,"cost_microdollars":0,"retries_used":0}"#;
        let parsed: StepResult = serde_json::from_str(json).unwrap();
        assert!(!parsed.skipped);
    }

    #[test]
    fn extract_json_path_whole_output() {
        let json = r#"{"a": 1}"#;
        for whole in [".", ""] {
            assert_eq!(extract_json_path(json, whole), Some(json.to_string()));
        }
    }

    #[test]
    fn extract_json_path_top_level_key() {
        assert_eq!(
            extract_json_path(r#"{"summary": "all good"}"#, "summary").as_deref(),
            Some("all good")
        );
    }

    #[test]
    fn extract_json_path_nested() {
        assert_eq!(
            extract_json_path(r#"{"result": {"count": 42}}"#, "result.count").as_deref(),
            Some("42")
        );
    }

    #[test]
    fn extract_json_path_not_json() {
        assert!(extract_json_path("not json", "key").is_none());
    }

    #[test]
    fn extract_json_path_missing_key() {
        assert!(extract_json_path(r#"{"a": 1}"#, "b").is_none());
    }

    #[test]
    fn expand_step_refs_substitutes() {
        let ctx: HashMap<String, String> =
            HashMap::from([("plan.summary".to_string(), "do stuff".to_string())]);
        assert_eq!(
            expand_step_refs("Based on {steps.plan.summary}, implement it.".to_string(), &ctx),
            "Based on do stuff, implement it."
        );
    }

    #[test]
    fn expand_step_refs_unknown_left_as_is() {
        let ctx: HashMap<String, String> = HashMap::new();
        let text = "Use {steps.missing.var} here.";
        assert_eq!(expand_step_refs(text.to_string(), &ctx), text);
    }

    #[test]
    fn chain_step_output_vars_defaults_empty() {
        let step: ChainStep =
            serde_json::from_str(r#"{"name":"s","action":{"type":"prompt","prompt":"hi"}}"#).unwrap();
        assert!(step.output_vars.is_empty());
    }

    #[test]
    fn chain_step_serializes_output_vars() {
        let step = ChainStep {
            name: "s".to_string(),
            action: StepAction::Prompt {
                prompt: "hi".to_string(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: HashMap::from([("summary".to_string(), "result.summary".to_string())]),
        };

        let json = serde_json::to_string(&step).unwrap();
        assert!(json.contains("output_vars"));
        assert!(json.contains("result.summary"));

        let roundtripped: ChainStep = serde_json::from_str(&json).unwrap();
        assert_eq!(
            roundtripped.output_vars.get("summary").map(String::as_str),
            Some("result.summary")
        );
    }
}