1use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use serde::{Deserialize, Serialize};
14
15use crate::pool::Pool;
16use crate::skill::SkillRegistry;
17use crate::store::PoolStore;
18use crate::types::{SlotConfig, TaskId, TaskState};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ChainStep {
23 pub name: String,
25
26 pub action: StepAction,
28
29 pub config: Option<SlotConfig>,
31
32 #[serde(default)]
34 pub failure_policy: StepFailurePolicy,
35
36 #[serde(default)]
46 pub output_vars: HashMap<String, String>,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(tag = "type", rename_all = "snake_case")]
52pub enum StepAction {
53 Prompt {
56 prompt: String,
58 },
59 Skill {
63 skill: String,
65 #[serde(default)]
67 arguments: HashMap<String, String>,
68 },
69}
70
71#[derive(Debug, Clone, Default, Serialize, Deserialize)]
73pub struct StepFailurePolicy {
74 #[serde(default)]
76 pub retries: u32,
77 pub recovery_prompt: Option<String>,
81}
82
83#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
85#[serde(rename_all = "snake_case")]
86pub enum ChainIsolation {
87 None,
89 #[default]
91 Worktree,
92 Clone,
94}
95
96#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98pub struct ChainOptions {
99 #[serde(default)]
101 pub tags: Vec<String>,
102 #[serde(default)]
104 pub isolation: ChainIsolation,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct StepResult {
110 pub name: String,
112 pub output: String,
114 pub success: bool,
116 pub cost_microdollars: u64,
118 #[serde(default)]
120 pub retries_used: u32,
121 #[serde(default)]
123 pub skipped: bool,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ChainResult {
129 pub steps: Vec<StepResult>,
131 pub final_output: String,
133 pub total_cost_microdollars: u64,
135 pub success: bool,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct ChainProgress {
142 pub total_steps: usize,
144 pub current_step: Option<usize>,
146 pub current_step_name: Option<String>,
148 #[serde(skip_serializing_if = "Option::is_none")]
153 pub current_step_partial_output: Option<String>,
154 #[serde(skip_serializing_if = "Option::is_none")]
158 pub current_step_started_at: Option<u64>,
159 pub completed_steps: Vec<StepResult>,
161 pub status: ChainStatus,
163}
164
165#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
167#[serde(rename_all = "snake_case")]
168pub enum ChainStatus {
169 Running,
171 Completed,
173 Failed,
175 Cancelled,
177}
178
/// Shared, thread-safe callback invoked with each chunk of output a running
/// step produces.
pub type OnOutputChunk = Arc<dyn Fn(&str) + Send + Sync>;
181
182fn extract_json_path(json_str: &str, path: &str) -> Option<String> {
183 if path == "." || path.is_empty() {
184 return Some(json_str.to_string());
185 }
186 let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
187 let mut current = &value;
188 for key in path.split('.') {
189 current = current.get(key)?;
190 }
191 Some(match current {
192 serde_json::Value::String(s) => s.clone(),
193 other => other.to_string(),
194 })
195}
196
/// Replaces every `{steps.<key>}` placeholder in `text` with the value
/// stored under `<key>` in `step_context`.
///
/// The text is scanned once, left to right. Substituted values are copied
/// to the result verbatim and never scanned again, so a captured value that
/// itself contains `{steps.…}` cannot trigger further expansion and the
/// result does not depend on the map's iteration order. Placeholders whose
/// key is not in `step_context` (and unterminated placeholders) are left
/// untouched.
fn expand_step_refs(text: String, step_context: &HashMap<String, String>) -> String {
    const OPEN: &str = "{steps.";

    // Fast path: nothing can be substituted, so hand the input back as-is.
    if step_context.is_empty() || !text.contains(OPEN) {
        return text;
    }

    // A closing brace further away than the longest key cannot end a known
    // placeholder; this bounds the lookups done per occurrence of `OPEN`.
    let longest_key = step_context.keys().map(String::len).max().unwrap_or(0);

    let mut expanded = String::with_capacity(text.len());
    let mut rest = text.as_str();
    while let Some(start) = rest.find(OPEN) {
        expanded.push_str(&rest[..start]);
        let candidate = &rest[start + OPEN.len()..];

        // Try each closing brace in turn so that keys which themselves
        // contain `}` still match, as they did with exact string replacement.
        let hit = candidate
            .match_indices('}')
            .map(|(end, _)| end)
            .take_while(|&end| end <= longest_key)
            .find_map(|end| step_context.get(&candidate[..end]).map(|value| (end, value)));

        match hit {
            Some((end, value)) => {
                expanded.push_str(value);
                rest = &candidate[end + 1..];
            }
            None => {
                // Unknown or unterminated placeholder: keep the opener
                // literally and continue scanning right after it.
                expanded.push_str(OPEN);
                rest = candidate;
            }
        }
    }
    expanded.push_str(rest);
    expanded
}
203
/// Returns the current wall-clock time as whole seconds since the Unix
/// epoch, or `0` if the system clock reports a time before the epoch.
fn unix_secs_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
210
211pub async fn execute_chain<S: PoolStore + 'static>(
213 pool: &Pool<S>,
214 skills: &SkillRegistry,
215 steps: &[ChainStep],
216) -> crate::Result<ChainResult> {
217 execute_chain_with_progress(pool, skills, steps, None, None).await
218}
219
220pub async fn execute_chain_with_progress<S: PoolStore + 'static>(
228 pool: &Pool<S>,
229 skills: &SkillRegistry,
230 steps: &[ChainStep],
231 chain_task_id: Option<&TaskId>,
232 working_dir: Option<&std::path::Path>,
233) -> crate::Result<ChainResult> {
234 let mut step_results = Vec::with_capacity(steps.len());
235 let mut previous_output = String::new();
236 let mut total_cost = 0u64;
237 let mut step_context: HashMap<String, String> = HashMap::new();
238
239 for (step_idx, step) in steps.iter().enumerate() {
240 if let Some(task_id) = chain_task_id
242 && let Ok(Some(task)) = pool.store().get_task(task_id).await
243 && task.state == TaskState::Cancelled
244 {
245 for s in &steps[step_idx..] {
246 step_results.push(StepResult {
247 name: s.name.clone(),
248 output: String::new(),
249 success: false,
250 cost_microdollars: 0,
251 retries_used: 0,
252 skipped: true,
253 });
254 }
255 update_chain_progress_final(
256 pool,
257 Some(task_id),
258 steps.len(),
259 &step_results,
260 ChainStatus::Cancelled,
261 )
262 .await;
263 return Ok(ChainResult {
264 final_output: previous_output,
265 steps: step_results,
266 total_cost_microdollars: total_cost,
267 success: false,
268 });
269 }
270
271 if let Some(task_id) = chain_task_id {
273 let progress = ChainProgress {
274 total_steps: steps.len(),
275 current_step: Some(step_idx),
276 current_step_name: Some(step.name.clone()),
277 current_step_partial_output: Some(String::new()),
278 current_step_started_at: Some(unix_secs_now()),
279 completed_steps: step_results.clone(),
280 status: ChainStatus::Running,
281 };
282 pool.set_chain_progress(task_id, progress).await;
283 }
284
285 let prompt = render_step_prompt(step, &previous_output, skills, &step_context)?;
286
287 let on_output: Option<OnOutputChunk> = chain_task_id.map(|tid| {
289 let pool = pool.clone();
290 let tid = tid.clone();
291 Arc::new(move |chunk: &str| {
292 pool.append_chain_partial_output(&tid, chunk);
293 }) as OnOutputChunk
294 });
295
296 let (step_result, step_cost) = execute_step_with_retries(
297 pool,
298 step,
299 &prompt,
300 &previous_output,
301 skills,
302 on_output.clone(),
303 working_dir,
304 &step_context,
305 )
306 .await;
307
308 total_cost += step_cost;
309
310 match step_result {
311 Ok(result) => {
312 previous_output = result.output.clone();
313
314 if result.success {
315 for (var_name, path) in &step.output_vars {
316 match extract_json_path(&result.output, path) {
317 Some(extracted) => {
318 step_context
319 .insert(format!("{}.{}", step.name, var_name), extracted);
320 }
321 None => {
322 tracing::warn!(
323 step = %step.name,
324 var = %var_name,
325 path = %path,
326 "output_var extraction failed (output not JSON or path not found)"
327 );
328 }
329 }
330 }
331 }
332
333 step_results.push(result);
334
335 if !step_results.last().unwrap().success {
336 update_chain_progress_final(
337 pool,
338 chain_task_id,
339 steps.len(),
340 &step_results,
341 ChainStatus::Failed,
342 )
343 .await;
344 return Ok(ChainResult {
345 final_output: previous_output,
346 steps: step_results,
347 total_cost_microdollars: total_cost,
348 success: false,
349 });
350 }
351 }
352 Err(output) => {
353 step_results.push(StepResult {
354 name: step.name.clone(),
355 output: output.clone(),
356 success: false,
357 cost_microdollars: 0,
358 retries_used: step.failure_policy.retries,
359 skipped: false,
360 });
361 update_chain_progress_final(
362 pool,
363 chain_task_id,
364 steps.len(),
365 &step_results,
366 ChainStatus::Failed,
367 )
368 .await;
369 return Ok(ChainResult {
370 final_output: output,
371 steps: step_results,
372 total_cost_microdollars: total_cost,
373 success: false,
374 });
375 }
376 }
377 }
378
379 update_chain_progress_final(
380 pool,
381 chain_task_id,
382 steps.len(),
383 &step_results,
384 ChainStatus::Completed,
385 )
386 .await;
387
388 Ok(ChainResult {
389 final_output: previous_output,
390 steps: step_results,
391 total_cost_microdollars: total_cost,
392 success: true,
393 })
394}
395
396fn render_step_prompt(
398 step: &ChainStep,
399 previous_output: &str,
400 skills: &SkillRegistry,
401 step_context: &HashMap<String, String>,
402) -> crate::Result<String> {
403 match &step.action {
404 StepAction::Prompt { prompt } => {
405 let rendered = prompt.replace("{previous_output}", previous_output);
406 Ok(expand_step_refs(rendered, step_context))
407 }
408 StepAction::Skill { skill, arguments } => {
409 let skill_def = skills
410 .get(skill)
411 .ok_or_else(|| crate::Error::Store(format!("skill not found: {skill}")))?;
412 let mut args = arguments.clone();
413 if !previous_output.is_empty() {
414 args.entry("_previous_output".into())
415 .or_insert(previous_output.to_string());
416 }
417 let rendered = skill_def.render(&args)?;
418 Ok(expand_step_refs(rendered, step_context))
419 }
420 }
421}
422
423#[allow(clippy::too_many_arguments)]
428async fn execute_step_with_retries<S: PoolStore + 'static>(
429 pool: &Pool<S>,
430 step: &ChainStep,
431 initial_prompt: &str,
432 previous_output: &str,
433 skills: &SkillRegistry,
434 on_output: Option<OnOutputChunk>,
435 working_dir: Option<&std::path::Path>,
436 step_context: &HashMap<String, String>,
437) -> (std::result::Result<StepResult, String>, u64) {
438 let max_attempts = 1 + step.failure_policy.retries;
439 let mut total_cost = 0u64;
440 let mut last_error = String::new();
441
442 for attempt in 0..max_attempts {
443 let prompt = if attempt == 0 {
444 initial_prompt.to_string()
445 } else {
446 match render_step_prompt(step, previous_output, skills, step_context) {
448 Ok(p) => p,
449 Err(e) => return (Err(e.to_string()), total_cost),
450 }
451 };
452
453 let result = pool
454 .run_with_config_streaming(
455 &prompt,
456 step.config.clone(),
457 on_output.clone(),
458 working_dir.map(|p| p.to_path_buf()),
459 )
460 .await;
461
462 match result {
463 Ok(task_result) => {
464 total_cost += task_result.cost_microdollars;
465 if task_result.success {
466 return (
467 Ok(StepResult {
468 name: step.name.clone(),
469 output: task_result.output,
470 success: true,
471 cost_microdollars: total_cost,
472 retries_used: attempt,
473 skipped: false,
474 }),
475 total_cost,
476 );
477 }
478 last_error = task_result.output;
480 }
481 Err(e) => {
482 last_error = e.to_string();
483 }
484 }
485
486 tracing::warn!(
487 step = %step.name,
488 attempt = attempt + 1,
489 max_attempts,
490 "chain step failed, will retry"
491 );
492 }
493
494 if let Some(ref recovery_template) = step.failure_policy.recovery_prompt {
496 let recovery_prompt = expand_step_refs(
497 recovery_template
498 .replace("{error}", &last_error)
499 .replace("{previous_output}", previous_output),
500 step_context,
501 );
502
503 tracing::info!(step = %step.name, "attempting recovery prompt");
504
505 let result = pool
506 .run_with_config_streaming(
507 &recovery_prompt,
508 step.config.clone(),
509 on_output,
510 working_dir.map(|p| p.to_path_buf()),
511 )
512 .await;
513
514 match result {
515 Ok(task_result) => {
516 total_cost += task_result.cost_microdollars;
517 return (
518 Ok(StepResult {
519 name: step.name.clone(),
520 output: task_result.output,
521 success: task_result.success,
522 cost_microdollars: total_cost,
523 retries_used: max_attempts,
524 skipped: false,
525 }),
526 total_cost,
527 );
528 }
529 Err(e) => {
530 last_error = e.to_string();
531 }
532 }
533 }
534
535 (Err(last_error), total_cost)
536}
537
538async fn update_chain_progress_final<S: PoolStore + 'static>(
540 pool: &Pool<S>,
541 chain_task_id: Option<&TaskId>,
542 total_steps: usize,
543 completed_steps: &[StepResult],
544 status: ChainStatus,
545) {
546 if let Some(task_id) = chain_task_id {
547 let progress = ChainProgress {
548 total_steps,
549 current_step: None,
550 current_step_name: None,
551 current_step_partial_output: None,
552 current_step_started_at: None,
553 completed_steps: completed_steps.to_vec(),
554 status,
555 };
556 pool.set_chain_progress(task_id, progress).await;
557 }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful, non-skipped step result with no retries.
    fn finished(name: &str, output: &str, cost_microdollars: u64) -> StepResult {
        StepResult {
            name: name.into(),
            output: output.into(),
            success: true,
            cost_microdollars,
            retries_used: 0,
            skipped: false,
        }
    }

    /// The placeholder result recorded for a step skipped by cancellation.
    fn skipped(name: &str) -> StepResult {
        StepResult {
            name: name.into(),
            output: String::new(),
            success: false,
            cost_microdollars: 0,
            retries_used: 0,
            skipped: true,
        }
    }

    /// A prompt step with no config and the default failure policy.
    fn prompt_step(name: &str, prompt: &str, output_vars: HashMap<String, String>) -> ChainStep {
        ChainStep {
            name: name.into(),
            action: StepAction::Prompt {
                prompt: prompt.into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars,
        }
    }

    /// A `Running` snapshot taken while step `index` is executing.
    fn progress_in_step(
        total_steps: usize,
        index: usize,
        name: &str,
        partial_output: &str,
        completed_steps: Vec<StepResult>,
    ) -> ChainProgress {
        ChainProgress {
            total_steps,
            current_step: Some(index),
            current_step_name: Some(name.into()),
            current_step_partial_output: Some(partial_output.into()),
            current_step_started_at: Some(1700000000),
            completed_steps,
            status: ChainStatus::Running,
        }
    }

    /// A snapshot with no current step, as published when a chain ends.
    fn progress_at_end(
        total_steps: usize,
        completed_steps: Vec<StepResult>,
        status: ChainStatus,
    ) -> ChainProgress {
        ChainProgress {
            total_steps,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps,
            status,
        }
    }

    #[test]
    fn prompt_step_replaces_previous_output() {
        let step = prompt_step(
            "step1",
            "Based on: {previous_output}\nDo more.",
            Default::default(),
        );

        if let StepAction::Prompt { prompt } = &step.action {
            let rendered = prompt.replace("{previous_output}", "hello world");
            assert_eq!(rendered, "Based on: hello world\nDo more.");
        }
    }

    #[test]
    fn chain_result_serializes() {
        let result = ChainResult {
            steps: vec![finished("step1", "done", 1000)],
            final_output: "done".into(),
            total_cost_microdollars: 1000,
            success: true,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("step1"));
    }

    #[test]
    fn step_failure_policy_defaults() {
        let policy = StepFailurePolicy::default();
        assert_eq!(policy.retries, 0);
        assert!(policy.recovery_prompt.is_none());
    }

    #[test]
    fn chain_options_defaults() {
        let opts = ChainOptions::default();
        assert!(opts.tags.is_empty());
        assert_eq!(opts.isolation, ChainIsolation::Worktree);
    }

    #[test]
    fn chain_isolation_serde_roundtrip() {
        let cases = [
            (ChainIsolation::Worktree, r#""worktree""#),
            (ChainIsolation::None, r#""none""#),
        ];

        for (variant, wire) in cases {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, wire);
        }
        for (variant, wire) in cases {
            let parsed: ChainIsolation = serde_json::from_str(wire).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn chain_options_with_isolation_serializes() {
        let opts = ChainOptions {
            tags: vec!["test".into()],
            isolation: ChainIsolation::Worktree,
        };
        let json = serde_json::to_string(&opts).unwrap();
        let parsed: ChainOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.isolation, ChainIsolation::Worktree);
        assert_eq!(parsed.tags, vec!["test"]);
    }

    #[test]
    fn chain_progress_serializes_with_partial_output() {
        let progress = progress_in_step(
            3,
            1,
            "implement",
            "partial text",
            vec![finished("plan", "planned", 500)],
        );

        let json = serde_json::to_string(&progress).unwrap();
        for expected in ["implement", "running", "partial text", "1700000000"] {
            assert!(json.contains(expected));
        }
    }

    #[test]
    fn chain_progress_omits_none_fields() {
        let progress = progress_at_end(2, vec![], ChainStatus::Completed);

        let json = serde_json::to_string(&progress).unwrap();
        assert!(!json.contains("current_step_partial_output"));
        assert!(!json.contains("current_step_started_at"));
    }

    #[test]
    fn chain_progress_empty_partial_output_when_step_starts() {
        let progress = progress_in_step(3, 0, "plan", "", vec![]);

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("\"current_step_partial_output\":\"\""));
    }

    #[test]
    fn cancelled_status_serializes() {
        let progress = progress_at_end(
            3,
            vec![
                finished("plan", "planned", 500),
                skipped("implement"),
                skipped("review"),
            ],
            ChainStatus::Cancelled,
        );

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("cancelled"));
        assert!(json.contains("\"skipped\":true"));
    }

    #[test]
    fn skipped_defaults_to_false_on_deserialize() {
        let json =
            r#"{"name":"s","output":"o","success":true,"cost_microdollars":0,"retries_used":0}"#;
        let result: StepResult = serde_json::from_str(json).unwrap();
        assert!(!result.skipped);
    }

    #[test]
    fn extract_json_path_whole_output() {
        let json = r#"{"a": 1}"#;
        for root_path in [".", ""] {
            assert_eq!(extract_json_path(json, root_path), Some(json.to_string()));
        }
    }

    #[test]
    fn extract_json_path_top_level_key() {
        let json = r#"{"summary": "all good"}"#;
        assert_eq!(
            extract_json_path(json, "summary"),
            Some("all good".to_string())
        );
    }

    #[test]
    fn extract_json_path_nested() {
        let json = r#"{"result": {"count": 42}}"#;
        assert_eq!(
            extract_json_path(json, "result.count"),
            Some("42".to_string())
        );
    }

    #[test]
    fn extract_json_path_not_json() {
        assert_eq!(extract_json_path("not json", "key"), None);
    }

    #[test]
    fn extract_json_path_missing_key() {
        let json = r#"{"a": 1}"#;
        assert_eq!(extract_json_path(json, "b"), None);
    }

    #[test]
    fn expand_step_refs_substitutes() {
        let ctx: HashMap<String, String> =
            HashMap::from([("plan.summary".to_string(), "do stuff".to_string())]);
        let text = "Based on {steps.plan.summary}, implement it.".to_string();
        assert_eq!(
            expand_step_refs(text, &ctx),
            "Based on do stuff, implement it."
        );
    }

    #[test]
    fn expand_step_refs_unknown_left_as_is() {
        let ctx = HashMap::new();
        let text = "Use {steps.missing.var} here.".to_string();
        assert_eq!(expand_step_refs(text.clone(), &ctx), text);
    }

    #[test]
    fn chain_step_output_vars_defaults_empty() {
        let json = r#"{"name":"s","action":{"type":"prompt","prompt":"hi"}}"#;
        let step: ChainStep = serde_json::from_str(json).unwrap();
        assert!(step.output_vars.is_empty());
    }

    #[test]
    fn chain_step_serializes_output_vars() {
        let vars: HashMap<String, String> =
            HashMap::from([("summary".to_string(), "result.summary".to_string())]);
        let step = prompt_step("s", "hi", vars);

        let json = serde_json::to_string(&step).unwrap();
        assert!(json.contains("output_vars"));
        assert!(json.contains("result.summary"));

        let parsed: ChainStep = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.output_vars.get("summary").unwrap(), "result.summary");
    }
}