Skip to main content

claude_pool/
chain.rs

//! Chain execution — sequential pipelines of tasks.
//!
//! A chain runs steps in order, feeding each step's output as context
//! to the next. Steps can reference skills or use inline prompts.
//!
//! Chains can be run inline (awaited directly by the caller) via [`execute_chain`],
//! or submitted for background execution via
//! [`Pool::submit_chain`](crate::Pool::submit_chain).

9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use serde::{Deserialize, Serialize};
14
15use crate::pool::Pool;
16use crate::skill::SkillRegistry;
17use crate::store::PoolStore;
18use crate::types::{SlotConfig, TaskId, TaskState};
19
20/// A step in a chain pipeline.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct ChainStep {
23    /// Step name (for logging and result tracking).
24    pub name: String,
25
26    /// Either an inline prompt or a skill reference.
27    pub action: StepAction,
28
29    /// Per-step config overrides (model, effort, etc.).
30    pub config: Option<SlotConfig>,
31
32    /// Failure policy for this step.
33    #[serde(default)]
34    pub failure_policy: StepFailurePolicy,
35
36    /// Extract named values from this step's JSON output for use in later steps.
37    ///
38    /// Key = variable name, Value = dot-path into the JSON output.
39    /// Use `"."` or `""` for the whole output. Use `"key"` for a top-level field.
40    /// Use `"a.b.c"` for nested access. String values are returned as-is; other
41    /// JSON types are serialized to their JSON representation.
42    ///
43    /// Extracted values are available in subsequent step prompts as
44    /// `{steps.STEP_NAME.VAR_NAME}`.
45    #[serde(default)]
46    pub output_vars: HashMap<String, String>,
47}
48
49/// What a chain step does.
50#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(tag = "type", rename_all = "snake_case")]
52pub enum StepAction {
53    /// Run an inline prompt. `{previous_output}` is replaced with
54    /// the output from the prior step.
55    Prompt {
56        /// The prompt template.
57        prompt: String,
58    },
59    /// Run a registered skill with the given arguments.
60    /// The special argument `_previous_output` is automatically set
61    /// to the output from the prior step.
62    Skill {
63        /// Skill name.
64        skill: String,
65        /// Skill arguments.
66        #[serde(default)]
67        arguments: HashMap<String, String>,
68    },
69}
70
71/// Per-step failure handling policy.
72#[derive(Debug, Clone, Default, Serialize, Deserialize)]
73pub struct StepFailurePolicy {
74    /// Number of retries before giving up or recovering (default: 0).
75    #[serde(default)]
76    pub retries: u32,
77    /// If set, run this prompt on failure instead of failing the chain.
78    /// `{error}` is replaced with the error message, `{previous_output}`
79    /// with the last successful step's output.
80    pub recovery_prompt: Option<String>,
81}
82
83/// Isolation mode for a chain execution.
84#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
85#[serde(rename_all = "snake_case")]
86pub enum ChainIsolation {
87    /// Use the slot's working directory (no isolation).
88    None,
89    /// Create a temporary git worktree shared by all steps in the chain (default).
90    #[default]
91    Worktree,
92    /// Create a full clone with `git clone --local --shared` for complete isolation.
93    Clone,
94}
95
96/// Options for chain execution.
97#[derive(Debug, Clone, Default, Serialize, Deserialize)]
98pub struct ChainOptions {
99    /// Tags for the chain task (used when submitted async).
100    #[serde(default)]
101    pub tags: Vec<String>,
102    /// Isolation mode for this chain.
103    #[serde(default)]
104    pub isolation: ChainIsolation,
105}
106
107/// Result of a single chain step.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct StepResult {
110    /// Step name.
111    pub name: String,
112    /// Output text from this step.
113    pub output: String,
114    /// Whether the step succeeded.
115    pub success: bool,
116    /// Cost in microdollars.
117    pub cost_microdollars: u64,
118    /// Number of retries used.
119    #[serde(default)]
120    pub retries_used: u32,
121    /// Whether this step was skipped due to chain cancellation.
122    #[serde(default)]
123    pub skipped: bool,
124}
125
126/// Result of a full chain execution.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ChainResult {
129    /// Per-step results in execution order.
130    pub steps: Vec<StepResult>,
131    /// Final output (from the last step).
132    pub final_output: String,
133    /// Total cost across all steps.
134    pub total_cost_microdollars: u64,
135    /// Whether all steps succeeded.
136    pub success: bool,
137}
138
139/// Progress of an in-flight chain.
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct ChainProgress {
142    /// Total number of steps.
143    pub total_steps: usize,
144    /// Index of the currently running step (0-based), or None if done.
145    pub current_step: Option<usize>,
146    /// Name of the currently running step.
147    pub current_step_name: Option<String>,
148    /// Live partial output from the currently running step.
149    ///
150    /// Updated incrementally as streaming output arrives. `None` when
151    /// no step is running (chain completed or not yet started).
152    #[serde(skip_serializing_if = "Option::is_none")]
153    pub current_step_partial_output: Option<String>,
154    /// Unix timestamp (seconds) when the current step started.
155    ///
156    /// Callers can compute elapsed time as `now - started_at`.
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub current_step_started_at: Option<u64>,
159    /// Completed step results so far.
160    pub completed_steps: Vec<StepResult>,
161    /// Overall status.
162    pub status: ChainStatus,
163}
164
165/// Status of a chain.
166#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
167#[serde(rename_all = "snake_case")]
168pub enum ChainStatus {
169    /// Chain is running.
170    Running,
171    /// All steps completed successfully.
172    Completed,
173    /// A step failed and the chain stopped.
174    Failed,
175    /// Chain was cancelled; remaining steps were skipped.
176    Cancelled,
177}
178
/// Shared, thread-safe callback invoked with each chunk of partial output while a
/// step streams its result.
pub type OnOutputChunk = Arc<dyn for<'chunk> Fn(&'chunk str) + Sync + Send>;
181
182fn extract_json_path(json_str: &str, path: &str) -> Option<String> {
183    if path == "." || path.is_empty() {
184        return Some(json_str.to_string());
185    }
186    let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
187    let mut current = &value;
188    for key in path.split('.') {
189        current = current.get(key)?;
190    }
191    Some(match current {
192        serde_json::Value::String(s) => s.clone(),
193        other => other.to_string(),
194    })
195}
196
/// Expand `{steps.KEY}` references in `text` using `step_context`.
///
/// `KEY` is everything between `{steps.` and the next `}`, normally
/// `STEP_NAME.VAR_NAME`. References whose key is not in `step_context`, and an
/// unterminated `{steps.`, are left untouched.
///
/// The text is scanned once, left to right. Substituted values are inserted
/// verbatim and never re-scanned. As a result, a value that itself contains
/// `{steps.…}` cannot pull in another variable, and the outcome does not depend on
/// `HashMap` iteration order.
fn expand_step_refs(text: String, step_context: &HashMap<String, String>) -> String {
    const OPEN: &str = "{steps.";
    if step_context.is_empty() || !text.contains(OPEN) {
        return text;
    }

    let mut out = String::with_capacity(text.len());
    let mut rest = text.as_str();
    while let Some(start) = rest.find(OPEN) {
        out.push_str(&rest[..start]);
        let candidate = &rest[start..];
        let replacement = candidate[OPEN.len()..].find('}').and_then(|end| {
            let key = &candidate[OPEN.len()..OPEN.len() + end];
            step_context
                .get(key)
                .map(|value| (value, OPEN.len() + end + 1))
        });
        match replacement {
            Some((value, consumed)) => {
                out.push_str(value);
                rest = &candidate[consumed..];
            }
            None => {
                // Not a known reference: emit the '{' literally and resume scanning
                // just after it, so a reference nested inside is still found.
                out.push('{');
                rest = &candidate[1..];
            }
        }
    }
    out.push_str(rest);
    out
}
203
/// Current Unix time in whole seconds.
///
/// A system clock set before 1970 yields `0` instead of an error, so progress
/// timestamps never fail.
fn unix_secs_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
210
211/// Execute a chain of steps against the pool.
212pub async fn execute_chain<S: PoolStore + 'static>(
213    pool: &Pool<S>,
214    skills: &SkillRegistry,
215    steps: &[ChainStep],
216) -> crate::Result<ChainResult> {
217    execute_chain_with_progress(pool, skills, steps, None, None).await
218}
219
220/// Execute a chain with optional progress tracking.
221///
222/// If `chain_task_id` is provided, intermediate progress is stored so callers
223/// can poll for status. When a chain task ID is present, steps execute with
224/// streaming output so partial results are visible via
225/// [`Pool::chain_progress`](crate::Pool::chain_progress). If `working_dir`
226/// is provided, all steps use that directory instead of the slot's default.
227pub async fn execute_chain_with_progress<S: PoolStore + 'static>(
228    pool: &Pool<S>,
229    skills: &SkillRegistry,
230    steps: &[ChainStep],
231    chain_task_id: Option<&TaskId>,
232    working_dir: Option<&std::path::Path>,
233) -> crate::Result<ChainResult> {
234    let mut step_results = Vec::with_capacity(steps.len());
235    let mut previous_output = String::new();
236    let mut total_cost = 0u64;
237    let mut step_context: HashMap<String, String> = HashMap::new();
238
239    for (step_idx, step) in steps.iter().enumerate() {
240        // Check for cancellation before starting each step.
241        if let Some(task_id) = chain_task_id
242            && let Ok(Some(task)) = pool.store().get_task(task_id).await
243            && task.state == TaskState::Cancelled
244        {
245            for s in &steps[step_idx..] {
246                step_results.push(StepResult {
247                    name: s.name.clone(),
248                    output: String::new(),
249                    success: false,
250                    cost_microdollars: 0,
251                    retries_used: 0,
252                    skipped: true,
253                });
254            }
255            update_chain_progress_final(
256                pool,
257                Some(task_id),
258                steps.len(),
259                &step_results,
260                ChainStatus::Cancelled,
261            )
262            .await;
263            return Ok(ChainResult {
264                final_output: previous_output,
265                steps: step_results,
266                total_cost_microdollars: total_cost,
267                success: false,
268            });
269        }
270
271        // Update progress in the store if we have a task ID.
272        if let Some(task_id) = chain_task_id {
273            let progress = ChainProgress {
274                total_steps: steps.len(),
275                current_step: Some(step_idx),
276                current_step_name: Some(step.name.clone()),
277                current_step_partial_output: Some(String::new()),
278                current_step_started_at: Some(unix_secs_now()),
279                completed_steps: step_results.clone(),
280                status: ChainStatus::Running,
281            };
282            pool.set_chain_progress(task_id, progress).await;
283        }
284
285        let prompt = render_step_prompt(step, &previous_output, skills, &step_context)?;
286
287        // Build an output callback that updates chain progress when we have a task ID.
288        let on_output: Option<OnOutputChunk> = chain_task_id.map(|tid| {
289            let pool = pool.clone();
290            let tid = tid.clone();
291            Arc::new(move |chunk: &str| {
292                pool.append_chain_partial_output(&tid, chunk);
293            }) as OnOutputChunk
294        });
295
296        let (step_result, step_cost) = execute_step_with_retries(
297            pool,
298            step,
299            &prompt,
300            &previous_output,
301            skills,
302            on_output.clone(),
303            working_dir,
304            &step_context,
305        )
306        .await;
307
308        total_cost += step_cost;
309
310        match step_result {
311            Ok(result) => {
312                previous_output = result.output.clone();
313
314                if result.success {
315                    for (var_name, path) in &step.output_vars {
316                        match extract_json_path(&result.output, path) {
317                            Some(extracted) => {
318                                step_context
319                                    .insert(format!("{}.{}", step.name, var_name), extracted);
320                            }
321                            None => {
322                                tracing::warn!(
323                                    step = %step.name,
324                                    var = %var_name,
325                                    path = %path,
326                                    "output_var extraction failed (output not JSON or path not found)"
327                                );
328                            }
329                        }
330                    }
331                }
332
333                step_results.push(result);
334
335                if !step_results.last().unwrap().success {
336                    update_chain_progress_final(
337                        pool,
338                        chain_task_id,
339                        steps.len(),
340                        &step_results,
341                        ChainStatus::Failed,
342                    )
343                    .await;
344                    return Ok(ChainResult {
345                        final_output: previous_output,
346                        steps: step_results,
347                        total_cost_microdollars: total_cost,
348                        success: false,
349                    });
350                }
351            }
352            Err(output) => {
353                step_results.push(StepResult {
354                    name: step.name.clone(),
355                    output: output.clone(),
356                    success: false,
357                    cost_microdollars: 0,
358                    retries_used: step.failure_policy.retries,
359                    skipped: false,
360                });
361                update_chain_progress_final(
362                    pool,
363                    chain_task_id,
364                    steps.len(),
365                    &step_results,
366                    ChainStatus::Failed,
367                )
368                .await;
369                return Ok(ChainResult {
370                    final_output: output,
371                    steps: step_results,
372                    total_cost_microdollars: total_cost,
373                    success: false,
374                });
375            }
376        }
377    }
378
379    update_chain_progress_final(
380        pool,
381        chain_task_id,
382        steps.len(),
383        &step_results,
384        ChainStatus::Completed,
385    )
386    .await;
387
388    Ok(ChainResult {
389        final_output: previous_output,
390        steps: step_results,
391        total_cost_microdollars: total_cost,
392        success: true,
393    })
394}
395
396/// Render the prompt for a step, substituting `{previous_output}` and step refs.
397fn render_step_prompt(
398    step: &ChainStep,
399    previous_output: &str,
400    skills: &SkillRegistry,
401    step_context: &HashMap<String, String>,
402) -> crate::Result<String> {
403    match &step.action {
404        StepAction::Prompt { prompt } => {
405            let rendered = prompt.replace("{previous_output}", previous_output);
406            Ok(expand_step_refs(rendered, step_context))
407        }
408        StepAction::Skill { skill, arguments } => {
409            let skill_def = skills
410                .get(skill)
411                .ok_or_else(|| crate::Error::Store(format!("skill not found: {skill}")))?;
412            let mut args = arguments.clone();
413            if !previous_output.is_empty() {
414                args.entry("_previous_output".into())
415                    .or_insert(previous_output.to_string());
416            }
417            let rendered = skill_def.render(&args)?;
418            Ok(expand_step_refs(rendered, step_context))
419        }
420    }
421}
422
423/// Execute a step with retry and recovery support.
424///
425/// Returns `Ok(StepResult)` on success (or successful recovery), or
426/// `Err(error_message)` if all retries and recovery are exhausted.
427#[allow(clippy::too_many_arguments)]
428async fn execute_step_with_retries<S: PoolStore + 'static>(
429    pool: &Pool<S>,
430    step: &ChainStep,
431    initial_prompt: &str,
432    previous_output: &str,
433    skills: &SkillRegistry,
434    on_output: Option<OnOutputChunk>,
435    working_dir: Option<&std::path::Path>,
436    step_context: &HashMap<String, String>,
437) -> (std::result::Result<StepResult, String>, u64) {
438    let max_attempts = 1 + step.failure_policy.retries;
439    let mut total_cost = 0u64;
440    let mut last_error = String::new();
441
442    for attempt in 0..max_attempts {
443        let prompt = if attempt == 0 {
444            initial_prompt.to_string()
445        } else {
446            // Re-render the prompt for retries (same prompt, fresh attempt).
447            match render_step_prompt(step, previous_output, skills, step_context) {
448                Ok(p) => p,
449                Err(e) => return (Err(e.to_string()), total_cost),
450            }
451        };
452
453        let result = pool
454            .run_with_config_streaming(
455                &prompt,
456                step.config.clone(),
457                on_output.clone(),
458                working_dir.map(|p| p.to_path_buf()),
459            )
460            .await;
461
462        match result {
463            Ok(task_result) => {
464                total_cost += task_result.cost_microdollars;
465                if task_result.success {
466                    return (
467                        Ok(StepResult {
468                            name: step.name.clone(),
469                            output: task_result.output,
470                            success: true,
471                            cost_microdollars: total_cost,
472                            retries_used: attempt,
473                            skipped: false,
474                        }),
475                        total_cost,
476                    );
477                }
478                // Task ran but reported failure.
479                last_error = task_result.output;
480            }
481            Err(e) => {
482                last_error = e.to_string();
483            }
484        }
485
486        tracing::warn!(
487            step = %step.name,
488            attempt = attempt + 1,
489            max_attempts,
490            "chain step failed, will retry"
491        );
492    }
493
494    // All retries exhausted. Try recovery prompt if configured.
495    if let Some(ref recovery_template) = step.failure_policy.recovery_prompt {
496        let recovery_prompt = expand_step_refs(
497            recovery_template
498                .replace("{error}", &last_error)
499                .replace("{previous_output}", previous_output),
500            step_context,
501        );
502
503        tracing::info!(step = %step.name, "attempting recovery prompt");
504
505        let result = pool
506            .run_with_config_streaming(
507                &recovery_prompt,
508                step.config.clone(),
509                on_output,
510                working_dir.map(|p| p.to_path_buf()),
511            )
512            .await;
513
514        match result {
515            Ok(task_result) => {
516                total_cost += task_result.cost_microdollars;
517                return (
518                    Ok(StepResult {
519                        name: step.name.clone(),
520                        output: task_result.output,
521                        success: task_result.success,
522                        cost_microdollars: total_cost,
523                        retries_used: max_attempts,
524                        skipped: false,
525                    }),
526                    total_cost,
527                );
528            }
529            Err(e) => {
530                last_error = e.to_string();
531            }
532        }
533    }
534
535    (Err(last_error), total_cost)
536}
537
538/// Update chain progress to a terminal state.
539async fn update_chain_progress_final<S: PoolStore + 'static>(
540    pool: &Pool<S>,
541    chain_task_id: Option<&TaskId>,
542    total_steps: usize,
543    completed_steps: &[StepResult],
544    status: ChainStatus,
545) {
546    if let Some(task_id) = chain_task_id {
547        let progress = ChainProgress {
548            total_steps,
549            current_step: None,
550            current_step_name: None,
551            current_step_partial_output: None,
552            current_step_started_at: None,
553            completed_steps: completed_steps.to_vec(),
554            status,
555        };
556        pool.set_chain_progress(task_id, progress).await;
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prompt_step_replaces_previous_output() {
        let step = ChainStep {
            name: "step1".into(),
            action: StepAction::Prompt {
                prompt: "Based on: {previous_output}\nDo more.".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: Default::default(),
        };

        // Fail loudly instead of silently passing if the action is not a prompt.
        let StepAction::Prompt { prompt } = &step.action else {
            panic!("expected a prompt action");
        };
        let rendered = prompt.replace("{previous_output}", "hello world");
        assert_eq!(rendered, "Based on: hello world\nDo more.");
    }

    #[test]
    fn chain_result_serializes() {
        let result = ChainResult {
            steps: vec![StepResult {
                name: "step1".into(),
                output: "done".into(),
                success: true,
                cost_microdollars: 1000,
                retries_used: 0,
                skipped: false,
            }],
            final_output: "done".into(),
            total_cost_microdollars: 1000,
            success: true,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("step1"));
    }

    #[test]
    fn step_failure_policy_defaults() {
        let policy = StepFailurePolicy::default();
        assert_eq!(policy.retries, 0);
        assert!(policy.recovery_prompt.is_none());
    }

    #[test]
    fn step_failure_policy_deserializes_from_empty_object() {
        let policy: StepFailurePolicy = serde_json::from_str("{}").unwrap();
        assert_eq!(policy.retries, 0);
        assert!(policy.recovery_prompt.is_none());
    }

    #[test]
    fn chain_options_defaults() {
        let opts = ChainOptions::default();
        assert!(opts.tags.is_empty());
        assert_eq!(opts.isolation, ChainIsolation::Worktree);
    }

    #[test]
    fn chain_isolation_serde_roundtrip() {
        let worktree = ChainIsolation::Worktree;
        let json = serde_json::to_string(&worktree).unwrap();
        assert_eq!(json, r#""worktree""#);

        let none = ChainIsolation::None;
        let json = serde_json::to_string(&none).unwrap();
        assert_eq!(json, r#""none""#);

        let parsed: ChainIsolation = serde_json::from_str(r#""worktree""#).unwrap();
        assert_eq!(parsed, ChainIsolation::Worktree);

        let parsed: ChainIsolation = serde_json::from_str(r#""none""#).unwrap();
        assert_eq!(parsed, ChainIsolation::None);
    }

    #[test]
    fn chain_isolation_clone_serde_roundtrip() {
        assert_eq!(
            serde_json::to_string(&ChainIsolation::Clone).unwrap(),
            r#""clone""#
        );
        let parsed: ChainIsolation = serde_json::from_str(r#""clone""#).unwrap();
        assert_eq!(parsed, ChainIsolation::Clone);
    }

    #[test]
    fn chain_options_with_isolation_serializes() {
        let opts = ChainOptions {
            tags: vec!["test".into()],
            isolation: ChainIsolation::Worktree,
        };
        let json = serde_json::to_string(&opts).unwrap();
        let parsed: ChainOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.isolation, ChainIsolation::Worktree);
        assert_eq!(parsed.tags, vec!["test"]);
    }

    #[test]
    fn chain_status_uses_snake_case_names() {
        for (status, name) in [
            (ChainStatus::Running, "running"),
            (ChainStatus::Completed, "completed"),
            (ChainStatus::Failed, "failed"),
            (ChainStatus::Cancelled, "cancelled"),
        ] {
            assert_eq!(serde_json::to_string(&status).unwrap(), format!("\"{name}\""));
        }
    }

    #[test]
    fn chain_progress_serializes_with_partial_output() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(1),
            current_step_name: Some("implement".into()),
            current_step_partial_output: Some("partial text".into()),
            current_step_started_at: Some(1700000000),
            completed_steps: vec![StepResult {
                name: "plan".into(),
                output: "planned".into(),
                success: true,
                cost_microdollars: 500,
                retries_used: 0,
                skipped: false,
            }],
            status: ChainStatus::Running,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("implement"));
        assert!(json.contains("running"));
        assert!(json.contains("partial text"));
        assert!(json.contains("1700000000"));
    }

    #[test]
    fn chain_progress_omits_none_fields() {
        let progress = ChainProgress {
            total_steps: 2,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![],
            status: ChainStatus::Completed,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(!json.contains("current_step_partial_output"));
        assert!(!json.contains("current_step_started_at"));
    }

    #[test]
    fn chain_progress_empty_partial_output_when_step_starts() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(0),
            current_step_name: Some("plan".into()),
            current_step_partial_output: Some(String::new()),
            current_step_started_at: Some(1700000000),
            completed_steps: vec![],
            status: ChainStatus::Running,
        };

        let json = serde_json::to_string(&progress).unwrap();
        // Empty string is still serialized (not None).
        assert!(json.contains("\"current_step_partial_output\":\"\""));
    }

    #[test]
    fn cancelled_status_serializes() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![
                StepResult {
                    name: "plan".into(),
                    output: "planned".into(),
                    success: true,
                    cost_microdollars: 500,
                    retries_used: 0,
                    skipped: false,
                },
                StepResult {
                    name: "implement".into(),
                    output: String::new(),
                    success: false,
                    cost_microdollars: 0,
                    retries_used: 0,
                    skipped: true,
                },
                StepResult {
                    name: "review".into(),
                    output: String::new(),
                    success: false,
                    cost_microdollars: 0,
                    retries_used: 0,
                    skipped: true,
                },
            ],
            status: ChainStatus::Cancelled,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("cancelled"));
        assert!(json.contains("\"skipped\":true"));
    }

    #[test]
    fn skipped_defaults_to_false_on_deserialize() {
        let json =
            r#"{"name":"s","output":"o","success":true,"cost_microdollars":0,"retries_used":0}"#;
        let result: StepResult = serde_json::from_str(json).unwrap();
        assert!(!result.skipped);
    }

    #[test]
    fn extract_json_path_whole_output() {
        let json = r#"{"a": 1}"#;
        assert_eq!(extract_json_path(json, "."), Some(json.to_string()));
        assert_eq!(extract_json_path(json, ""), Some(json.to_string()));
    }

    #[test]
    fn extract_json_path_top_level_key() {
        let json = r#"{"summary": "all good"}"#;
        assert_eq!(
            extract_json_path(json, "summary"),
            Some("all good".to_string())
        );
    }

    #[test]
    fn extract_json_path_nested() {
        let json = r#"{"result": {"count": 42}}"#;
        assert_eq!(
            extract_json_path(json, "result.count"),
            Some("42".to_string())
        );
    }

    #[test]
    fn extract_json_path_non_string_values_are_json_encoded() {
        let json = r#"{"ok": true, "list": [1, 2], "nothing": null}"#;
        assert_eq!(extract_json_path(json, "ok"), Some("true".to_string()));
        assert_eq!(extract_json_path(json, "list"), Some("[1,2]".to_string()));
        assert_eq!(extract_json_path(json, "nothing"), Some("null".to_string()));
    }

    #[test]
    fn extract_json_path_not_json() {
        assert_eq!(extract_json_path("not json", "key"), None);
    }

    #[test]
    fn extract_json_path_missing_key() {
        let json = r#"{"a": 1}"#;
        assert_eq!(extract_json_path(json, "b"), None);
    }

    #[test]
    fn expand_step_refs_substitutes() {
        let mut ctx = HashMap::new();
        ctx.insert("plan.summary".into(), "do stuff".into());
        let text = "Based on {steps.plan.summary}, implement it.".to_string();
        assert_eq!(
            expand_step_refs(text, &ctx),
            "Based on do stuff, implement it."
        );
    }

    #[test]
    fn expand_step_refs_replaces_every_occurrence() {
        let mut ctx = HashMap::new();
        ctx.insert("a.x".to_string(), "v".to_string());
        assert_eq!(
            expand_step_refs("{steps.a.x}/{steps.a.x}".to_string(), &ctx),
            "v/v"
        );
    }

    #[test]
    fn expand_step_refs_unknown_left_as_is() {
        let ctx = HashMap::new();
        let text = "Use {steps.missing.var} here.".to_string();
        assert_eq!(expand_step_refs(text.clone(), &ctx), text);
    }

    #[test]
    fn chain_step_output_vars_defaults_empty() {
        let json = r#"{"name":"s","action":{"type":"prompt","prompt":"hi"}}"#;
        let step: ChainStep = serde_json::from_str(json).unwrap();
        assert!(step.output_vars.is_empty());
    }

    #[test]
    fn skill_action_arguments_default_to_empty() {
        let json = r#"{"type":"skill","skill":"review"}"#;
        let action: StepAction = serde_json::from_str(json).unwrap();
        let StepAction::Skill { skill, arguments } = action else {
            panic!("expected a skill action");
        };
        assert_eq!(skill, "review");
        assert!(arguments.is_empty());
    }

    #[test]
    fn chain_step_serializes_output_vars() {
        let mut vars = HashMap::new();
        vars.insert("summary".into(), "result.summary".into());
        let step = ChainStep {
            name: "s".into(),
            action: StepAction::Prompt {
                prompt: "hi".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: vars,
        };
        let json = serde_json::to_string(&step).unwrap();
        assert!(json.contains("output_vars"));
        assert!(json.contains("result.summary"));

        let parsed: ChainStep = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.output_vars.get("summary").unwrap(), "result.summary");
    }
}