Skip to main content

claude_pool/
chain.rs

1//! Chain execution — sequential pipelines of tasks.
2//!
3//! A chain runs steps in order, feeding each step's output as context
4//! to the next. Steps can reference skills or use inline prompts.
5//!
6//! Chains can be run synchronously via [`execute_chain`] or submitted
7//! for async execution via [`Pool::submit_chain`](crate::Pool::submit_chain).
8
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use serde::{Deserialize, Serialize};
14
15use crate::pool::Pool;
16use crate::skill::SkillRegistry;
17use crate::store::PoolStore;
18use crate::types::{SlotConfig, TaskId, TaskState};
19
/// A step in a chain pipeline.
///
/// Steps run in order. Each step can see the previous step's output (see
/// [`StepAction`]) and any values extracted by earlier steps through
/// [`ChainStep::output_vars`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainStep {
    /// Step name (for logging and result tracking).
    ///
    /// Also used as the `STEP_NAME` part of `{steps.STEP_NAME.VAR_NAME}`
    /// references to this step's extracted `output_vars`.
    pub name: String,

    /// Either an inline prompt or a skill reference.
    pub action: StepAction,

    /// Per-step config overrides (model, effort, etc.).
    ///
    /// Passed unchanged to the pool for every attempt of this step, including
    /// the recovery prompt. May be omitted when deserializing.
    pub config: Option<SlotConfig>,

    /// Failure policy for this step (retries and optional recovery prompt).
    #[serde(default)]
    pub failure_policy: StepFailurePolicy,

    /// Extract named values from this step's JSON output for use in later steps.
    ///
    /// Key = variable name, Value = dot-path into the JSON output.
    /// Use `"."` or `""` for the whole output. Use `"key"` for a top-level field.
    /// Use `"a.b.c"` for nested access. String values are returned as-is; other
    /// JSON types are serialized to their JSON representation.
    ///
    /// Extracted values are available in subsequent step prompts as
    /// `{steps.STEP_NAME.VAR_NAME}`. Extraction only happens when the step
    /// succeeds. A path that cannot be resolved is logged as a warning, and
    /// the variable is left undefined, so its placeholder stays in later
    /// prompts verbatim.
    #[serde(default)]
    pub output_vars: HashMap<String, String>,
}
48
/// What a chain step does.
///
/// Serialized as an internally tagged object: `{"type": "prompt", ...}` or
/// `{"type": "skill", ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StepAction {
    /// Run an inline prompt. `{previous_output}` is replaced with
    /// the output from the prior step (empty for the first step).
    /// `{steps.STEP_NAME.VAR_NAME}` references are expanded afterwards.
    Prompt {
        /// The prompt template.
        prompt: String,
    },
    /// Run a registered skill with the given arguments.
    /// The special argument `_previous_output` is automatically set
    /// to the output from the prior step. It is only set when that output is
    /// non-empty and is never overwritten if the caller supplied it
    /// explicitly. `{steps.STEP_NAME.VAR_NAME}` references are expanded in
    /// the rendered skill prompt.
    Skill {
        /// Skill name (looked up in the [`SkillRegistry`]).
        skill: String,
        /// Skill arguments.
        #[serde(default)]
        arguments: HashMap<String, String>,
    },
}
70
/// Per-step failure handling policy.
///
/// The default is no retries and no recovery: the first failure stops the chain.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StepFailurePolicy {
    /// Number of retries before giving up or recovering (default: 0).
    ///
    /// The step is attempted at most `1 + retries` times.
    #[serde(default)]
    pub retries: u32,
    /// If set, run this prompt on failure instead of failing the chain.
    /// `{error}` is replaced with the error message, `{previous_output}`
    /// with the last successful step's output. `{steps.STEP_NAME.VAR_NAME}`
    /// references are expanded as well. The recovery prompt runs once, after
    /// all retries are exhausted.
    pub recovery_prompt: Option<String>,
}
82
/// Isolation mode for a chain execution.
///
/// Serialized as `"none"` / `"worktree"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChainIsolation {
    /// Use the slot's working directory (no isolation).
    None,
    /// Create a temporary git worktree shared by all steps in the chain (default).
    #[default]
    Worktree,
}
93
/// Options for chain execution.
///
/// Both fields may be omitted when deserializing and fall back to their defaults.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChainOptions {
    /// Tags for the chain task (used when submitted async).
    #[serde(default)]
    pub tags: Vec<String>,
    /// Isolation mode for this chain (defaults to [`ChainIsolation::Worktree`]).
    #[serde(default)]
    pub isolation: ChainIsolation,
}
104
/// Result of a single chain step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
    /// Step name.
    pub name: String,
    /// Output text from this step.
    ///
    /// For a step that failed without producing a task result, this holds the
    /// error message. For skipped steps it is empty.
    pub output: String,
    /// Whether the step succeeded.
    pub success: bool,
    /// Cost in microdollars attributed to this step.
    pub cost_microdollars: u64,
    /// Number of retries used.
    #[serde(default)]
    pub retries_used: u32,
    /// Whether this step was skipped due to chain cancellation.
    #[serde(default)]
    pub skipped: bool,
}
123
/// Result of a full chain execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainResult {
    /// Per-step results in execution order (including skipped steps).
    pub steps: Vec<StepResult>,
    /// Final output (from the last step that ran).
    ///
    /// On failure this is the failing step's output or error message. On
    /// cancellation it is the output of the last step that completed.
    pub final_output: String,
    /// Total cost across all steps.
    pub total_cost_microdollars: u64,
    /// Whether all steps succeeded (`false` on failure or cancellation).
    pub success: bool,
}
136
/// Progress of an in-flight chain.
///
/// Snapshots are stored while a chain with a task ID runs, and once more when
/// it reaches a terminal [`ChainStatus`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainProgress {
    /// Total number of steps.
    pub total_steps: usize,
    /// Index of the currently running step (0-based), or None if done.
    pub current_step: Option<usize>,
    /// Name of the currently running step (`None` once the chain is done).
    pub current_step_name: Option<String>,
    /// Live partial output from the currently running step.
    ///
    /// Set to an empty string when a step starts and updated incrementally as
    /// streaming output arrives. `None` when no step is running (chain
    /// completed or not yet started). Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub current_step_partial_output: Option<String>,
    /// Unix timestamp (seconds) when the current step started.
    ///
    /// Callers can compute elapsed time as `now - started_at`. Omitted from
    /// JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub current_step_started_at: Option<u64>,
    /// Completed step results so far.
    pub completed_steps: Vec<StepResult>,
    /// Overall status.
    pub status: ChainStatus,
}
162
/// Status of a chain.
///
/// Serialized in snake_case (`"running"`, `"completed"`, `"failed"`, `"cancelled"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChainStatus {
    /// Chain is running.
    Running,
    /// All steps completed successfully.
    Completed,
    /// A step failed and the chain stopped.
    Failed,
    /// Chain was cancelled; remaining steps were skipped.
    Cancelled,
}
176
/// Callback for receiving partial output chunks during streaming execution.
///
/// Shared via `Arc` so the same callback can be handed to every attempt of a
/// step (retries and recovery); must be `Send + Sync`.
pub type OnOutputChunk = Arc<dyn Fn(&str) + Send + Sync>;
179
180fn extract_json_path(json_str: &str, path: &str) -> Option<String> {
181    if path == "." || path.is_empty() {
182        return Some(json_str.to_string());
183    }
184    let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
185    let mut current = &value;
186    for key in path.split('.') {
187        current = current.get(key)?;
188    }
189    Some(match current {
190        serde_json::Value::String(s) => s.clone(),
191        other => other.to_string(),
192    })
193}
194
/// Expand `{steps.STEP.VAR}` placeholders in `text` using `step_context`.
///
/// `step_context` maps `"STEP.VAR"` keys to their extracted values. The text
/// is scanned once, left to right:
/// * each `{steps.KEY}` whose `KEY` is in the context is replaced by its value;
/// * placeholders with unknown keys, or without a closing `}`, are left untouched.
///
/// Substituted values are copied verbatim and never re-scanned. A value that
/// itself contains `{steps.…}` (e.g. text echoed by a model) is therefore not
/// expanded further, and the result never depends on hash-map iteration order.
fn expand_step_refs(text: String, step_context: &HashMap<String, String>) -> String {
    const OPEN: &str = "{steps.";
    // Fast path: nothing can be expanded, so hand back the original allocation.
    if step_context.is_empty() || !text.contains(OPEN) {
        return text;
    }
    let mut out = String::with_capacity(text.len());
    let mut rest = text.as_str();
    while let Some(start) = rest.find(OPEN) {
        out.push_str(&rest[..start]);
        let after_open = &rest[start + OPEN.len()..];
        let resolved = after_open
            .find('}')
            .and_then(|end| step_context.get(&after_open[..end]).map(|value| (end, value)));
        match resolved {
            Some((end, value)) => {
                out.push_str(value);
                rest = &after_open[end + 1..];
            }
            None => {
                // Unknown or unterminated reference: keep the opener literally
                // and resume scanning right after it.
                out.push_str(OPEN);
                rest = after_open;
            }
        }
    }
    out.push_str(rest);
    out
}
201
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of panicking.
fn unix_secs_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
208
209/// Execute a chain of steps against the pool.
210pub async fn execute_chain<S: PoolStore + 'static>(
211    pool: &Pool<S>,
212    skills: &SkillRegistry,
213    steps: &[ChainStep],
214) -> crate::Result<ChainResult> {
215    execute_chain_with_progress(pool, skills, steps, None, None).await
216}
217
218/// Execute a chain with optional progress tracking.
219///
220/// If `chain_task_id` is provided, intermediate progress is stored so callers
221/// can poll for status. When a chain task ID is present, steps execute with
222/// streaming output so partial results are visible via
223/// [`Pool::chain_progress`](crate::Pool::chain_progress). If `working_dir`
224/// is provided, all steps use that directory instead of the slot's default.
225pub async fn execute_chain_with_progress<S: PoolStore + 'static>(
226    pool: &Pool<S>,
227    skills: &SkillRegistry,
228    steps: &[ChainStep],
229    chain_task_id: Option<&TaskId>,
230    working_dir: Option<&std::path::Path>,
231) -> crate::Result<ChainResult> {
232    let mut step_results = Vec::with_capacity(steps.len());
233    let mut previous_output = String::new();
234    let mut total_cost = 0u64;
235    let mut step_context: HashMap<String, String> = HashMap::new();
236
237    for (step_idx, step) in steps.iter().enumerate() {
238        // Check for cancellation before starting each step.
239        if let Some(task_id) = chain_task_id
240            && let Ok(Some(task)) = pool.store().get_task(task_id).await
241            && task.state == TaskState::Cancelled
242        {
243            for s in &steps[step_idx..] {
244                step_results.push(StepResult {
245                    name: s.name.clone(),
246                    output: String::new(),
247                    success: false,
248                    cost_microdollars: 0,
249                    retries_used: 0,
250                    skipped: true,
251                });
252            }
253            update_chain_progress_final(
254                pool,
255                Some(task_id),
256                steps.len(),
257                &step_results,
258                ChainStatus::Cancelled,
259            )
260            .await;
261            return Ok(ChainResult {
262                final_output: previous_output,
263                steps: step_results,
264                total_cost_microdollars: total_cost,
265                success: false,
266            });
267        }
268
269        // Update progress in the store if we have a task ID.
270        if let Some(task_id) = chain_task_id {
271            let progress = ChainProgress {
272                total_steps: steps.len(),
273                current_step: Some(step_idx),
274                current_step_name: Some(step.name.clone()),
275                current_step_partial_output: Some(String::new()),
276                current_step_started_at: Some(unix_secs_now()),
277                completed_steps: step_results.clone(),
278                status: ChainStatus::Running,
279            };
280            pool.set_chain_progress(task_id, progress).await;
281        }
282
283        let prompt = render_step_prompt(step, &previous_output, skills, &step_context)?;
284
285        // Build an output callback that updates chain progress when we have a task ID.
286        let on_output: Option<OnOutputChunk> = chain_task_id.map(|tid| {
287            let pool = pool.clone();
288            let tid = tid.clone();
289            Arc::new(move |chunk: &str| {
290                pool.append_chain_partial_output(&tid, chunk);
291            }) as OnOutputChunk
292        });
293
294        let (step_result, step_cost) = execute_step_with_retries(
295            pool,
296            step,
297            &prompt,
298            &previous_output,
299            skills,
300            on_output.clone(),
301            working_dir,
302            &step_context,
303        )
304        .await;
305
306        total_cost += step_cost;
307
308        match step_result {
309            Ok(result) => {
310                previous_output = result.output.clone();
311
312                if result.success {
313                    for (var_name, path) in &step.output_vars {
314                        match extract_json_path(&result.output, path) {
315                            Some(extracted) => {
316                                step_context
317                                    .insert(format!("{}.{}", step.name, var_name), extracted);
318                            }
319                            None => {
320                                tracing::warn!(
321                                    step = %step.name,
322                                    var = %var_name,
323                                    path = %path,
324                                    "output_var extraction failed (output not JSON or path not found)"
325                                );
326                            }
327                        }
328                    }
329                }
330
331                step_results.push(result);
332
333                if !step_results.last().unwrap().success {
334                    update_chain_progress_final(
335                        pool,
336                        chain_task_id,
337                        steps.len(),
338                        &step_results,
339                        ChainStatus::Failed,
340                    )
341                    .await;
342                    return Ok(ChainResult {
343                        final_output: previous_output,
344                        steps: step_results,
345                        total_cost_microdollars: total_cost,
346                        success: false,
347                    });
348                }
349            }
350            Err(output) => {
351                step_results.push(StepResult {
352                    name: step.name.clone(),
353                    output: output.clone(),
354                    success: false,
355                    cost_microdollars: 0,
356                    retries_used: step.failure_policy.retries,
357                    skipped: false,
358                });
359                update_chain_progress_final(
360                    pool,
361                    chain_task_id,
362                    steps.len(),
363                    &step_results,
364                    ChainStatus::Failed,
365                )
366                .await;
367                return Ok(ChainResult {
368                    final_output: output,
369                    steps: step_results,
370                    total_cost_microdollars: total_cost,
371                    success: false,
372                });
373            }
374        }
375    }
376
377    update_chain_progress_final(
378        pool,
379        chain_task_id,
380        steps.len(),
381        &step_results,
382        ChainStatus::Completed,
383    )
384    .await;
385
386    Ok(ChainResult {
387        final_output: previous_output,
388        steps: step_results,
389        total_cost_microdollars: total_cost,
390        success: true,
391    })
392}
393
394/// Render the prompt for a step, substituting `{previous_output}` and step refs.
395fn render_step_prompt(
396    step: &ChainStep,
397    previous_output: &str,
398    skills: &SkillRegistry,
399    step_context: &HashMap<String, String>,
400) -> crate::Result<String> {
401    match &step.action {
402        StepAction::Prompt { prompt } => {
403            let rendered = prompt.replace("{previous_output}", previous_output);
404            Ok(expand_step_refs(rendered, step_context))
405        }
406        StepAction::Skill { skill, arguments } => {
407            let skill_def = skills
408                .get(skill)
409                .ok_or_else(|| crate::Error::Store(format!("skill not found: {skill}")))?;
410            let mut args = arguments.clone();
411            if !previous_output.is_empty() {
412                args.entry("_previous_output".into())
413                    .or_insert(previous_output.to_string());
414            }
415            let rendered = skill_def.render(&args)?;
416            Ok(expand_step_refs(rendered, step_context))
417        }
418    }
419}
420
421/// Execute a step with retry and recovery support.
422///
423/// Returns `Ok(StepResult)` on success (or successful recovery), or
424/// `Err(error_message)` if all retries and recovery are exhausted.
425#[allow(clippy::too_many_arguments)]
426async fn execute_step_with_retries<S: PoolStore + 'static>(
427    pool: &Pool<S>,
428    step: &ChainStep,
429    initial_prompt: &str,
430    previous_output: &str,
431    skills: &SkillRegistry,
432    on_output: Option<OnOutputChunk>,
433    working_dir: Option<&std::path::Path>,
434    step_context: &HashMap<String, String>,
435) -> (std::result::Result<StepResult, String>, u64) {
436    let max_attempts = 1 + step.failure_policy.retries;
437    let mut total_cost = 0u64;
438    let mut last_error = String::new();
439
440    for attempt in 0..max_attempts {
441        let prompt = if attempt == 0 {
442            initial_prompt.to_string()
443        } else {
444            // Re-render the prompt for retries (same prompt, fresh attempt).
445            match render_step_prompt(step, previous_output, skills, step_context) {
446                Ok(p) => p,
447                Err(e) => return (Err(e.to_string()), total_cost),
448            }
449        };
450
451        let result = pool
452            .run_with_config_streaming(
453                &prompt,
454                step.config.clone(),
455                on_output.clone(),
456                working_dir.map(|p| p.to_path_buf()),
457            )
458            .await;
459
460        match result {
461            Ok(task_result) => {
462                total_cost += task_result.cost_microdollars;
463                if task_result.success {
464                    return (
465                        Ok(StepResult {
466                            name: step.name.clone(),
467                            output: task_result.output,
468                            success: true,
469                            cost_microdollars: total_cost,
470                            retries_used: attempt,
471                            skipped: false,
472                        }),
473                        total_cost,
474                    );
475                }
476                // Task ran but reported failure.
477                last_error = task_result.output;
478            }
479            Err(e) => {
480                last_error = e.to_string();
481            }
482        }
483
484        tracing::warn!(
485            step = %step.name,
486            attempt = attempt + 1,
487            max_attempts,
488            "chain step failed, will retry"
489        );
490    }
491
492    // All retries exhausted. Try recovery prompt if configured.
493    if let Some(ref recovery_template) = step.failure_policy.recovery_prompt {
494        let recovery_prompt = expand_step_refs(
495            recovery_template
496                .replace("{error}", &last_error)
497                .replace("{previous_output}", previous_output),
498            step_context,
499        );
500
501        tracing::info!(step = %step.name, "attempting recovery prompt");
502
503        let result = pool
504            .run_with_config_streaming(
505                &recovery_prompt,
506                step.config.clone(),
507                on_output,
508                working_dir.map(|p| p.to_path_buf()),
509            )
510            .await;
511
512        match result {
513            Ok(task_result) => {
514                total_cost += task_result.cost_microdollars;
515                return (
516                    Ok(StepResult {
517                        name: step.name.clone(),
518                        output: task_result.output,
519                        success: task_result.success,
520                        cost_microdollars: total_cost,
521                        retries_used: max_attempts,
522                        skipped: false,
523                    }),
524                    total_cost,
525                );
526            }
527            Err(e) => {
528                last_error = e.to_string();
529            }
530        }
531    }
532
533    (Err(last_error), total_cost)
534}
535
536/// Update chain progress to a terminal state.
537async fn update_chain_progress_final<S: PoolStore + 'static>(
538    pool: &Pool<S>,
539    chain_task_id: Option<&TaskId>,
540    total_steps: usize,
541    completed_steps: &[StepResult],
542    status: ChainStatus,
543) {
544    if let Some(task_id) = chain_task_id {
545        let progress = ChainProgress {
546            total_steps,
547            current_step: None,
548            current_step_name: None,
549            current_step_partial_output: None,
550            current_step_started_at: None,
551            completed_steps: completed_steps.to_vec(),
552            status,
553        };
554        pool.set_chain_progress(task_id, progress).await;
555    }
556}
557
// Unit tests for the pure helpers and the serde shape of the chain types.
// The async execution paths need a live `Pool` and are not covered here.
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): this test re-implements the `{previous_output}` replacement
    // inline rather than calling `render_step_prompt`, so it does not exercise
    // the production code path. Consider calling `render_step_prompt` once a
    // `SkillRegistry` can be constructed in tests.
    #[test]
    fn prompt_step_replaces_previous_output() {
        let step = ChainStep {
            name: "step1".into(),
            action: StepAction::Prompt {
                prompt: "Based on: {previous_output}\nDo more.".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: Default::default(),
        };

        if let StepAction::Prompt { prompt } = &step.action {
            let rendered = prompt.replace("{previous_output}", "hello world");
            assert_eq!(rendered, "Based on: hello world\nDo more.");
        }
    }

    // Smoke test: a ChainResult serializes and includes the step name.
    #[test]
    fn chain_result_serializes() {
        let result = ChainResult {
            steps: vec![StepResult {
                name: "step1".into(),
                output: "done".into(),
                success: true,
                cost_microdollars: 1000,
                retries_used: 0,
                skipped: false,
            }],
            final_output: "done".into(),
            total_cost_microdollars: 1000,
            success: true,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("step1"));
    }

    // Default policy: no retries, no recovery prompt.
    #[test]
    fn step_failure_policy_defaults() {
        let policy = StepFailurePolicy::default();
        assert_eq!(policy.retries, 0);
        assert!(policy.recovery_prompt.is_none());
    }

    // Default options: no tags, worktree isolation.
    #[test]
    fn chain_options_defaults() {
        let opts = ChainOptions::default();
        assert!(opts.tags.is_empty());
        assert_eq!(opts.isolation, ChainIsolation::Worktree);
    }

    // ChainIsolation uses snake_case string representations in both directions.
    #[test]
    fn chain_isolation_serde_roundtrip() {
        let worktree = ChainIsolation::Worktree;
        let json = serde_json::to_string(&worktree).unwrap();
        assert_eq!(json, r#""worktree""#);

        let none = ChainIsolation::None;
        let json = serde_json::to_string(&none).unwrap();
        assert_eq!(json, r#""none""#);

        let parsed: ChainIsolation = serde_json::from_str(r#""worktree""#).unwrap();
        assert_eq!(parsed, ChainIsolation::Worktree);

        let parsed: ChainIsolation = serde_json::from_str(r#""none""#).unwrap();
        assert_eq!(parsed, ChainIsolation::None);
    }

    // ChainOptions round-trips through JSON without losing fields.
    #[test]
    fn chain_options_with_isolation_serializes() {
        let opts = ChainOptions {
            tags: vec!["test".into()],
            isolation: ChainIsolation::Worktree,
        };
        let json = serde_json::to_string(&opts).unwrap();
        let parsed: ChainOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.isolation, ChainIsolation::Worktree);
        assert_eq!(parsed.tags, vec!["test"]);
    }

    // An in-flight progress snapshot includes its optional live fields.
    #[test]
    fn chain_progress_serializes_with_partial_output() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(1),
            current_step_name: Some("implement".into()),
            current_step_partial_output: Some("partial text".into()),
            current_step_started_at: Some(1700000000),
            completed_steps: vec![StepResult {
                name: "plan".into(),
                output: "planned".into(),
                success: true,
                cost_microdollars: 500,
                retries_used: 0,
                skipped: false,
            }],
            status: ChainStatus::Running,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("implement"));
        assert!(json.contains("running"));
        assert!(json.contains("partial text"));
        assert!(json.contains("1700000000"));
    }

    // `skip_serializing_if = "Option::is_none"` drops the optional live fields.
    #[test]
    fn chain_progress_omits_none_fields() {
        let progress = ChainProgress {
            total_steps: 2,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![],
            status: ChainStatus::Completed,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(!json.contains("current_step_partial_output"));
        assert!(!json.contains("current_step_started_at"));
    }

    // A freshly started step reports `Some("")`, which must still be serialized.
    #[test]
    fn chain_progress_empty_partial_output_when_step_starts() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(0),
            current_step_name: Some("plan".into()),
            current_step_partial_output: Some(String::new()),
            current_step_started_at: Some(1700000000),
            completed_steps: vec![],
            status: ChainStatus::Running,
        };

        let json = serde_json::to_string(&progress).unwrap();
        // Empty string is still serialized (not None).
        assert!(json.contains("\"current_step_partial_output\":\"\""));
    }

    // Cancelled chains serialize their status and the skipped step markers.
    #[test]
    fn cancelled_status_serializes() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![
                StepResult {
                    name: "plan".into(),
                    output: "planned".into(),
                    success: true,
                    cost_microdollars: 500,
                    retries_used: 0,
                    skipped: false,
                },
                StepResult {
                    name: "implement".into(),
                    output: String::new(),
                    success: false,
                    cost_microdollars: 0,
                    retries_used: 0,
                    skipped: true,
                },
                StepResult {
                    name: "review".into(),
                    output: String::new(),
                    success: false,
                    cost_microdollars: 0,
                    retries_used: 0,
                    skipped: true,
                },
            ],
            status: ChainStatus::Cancelled,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("cancelled"));
        assert!(json.contains("\"skipped\":true"));
    }

    // Older serialized StepResults without `skipped` still deserialize.
    #[test]
    fn skipped_defaults_to_false_on_deserialize() {
        let json =
            r#"{"name":"s","output":"o","success":true,"cost_microdollars":0,"retries_used":0}"#;
        let result: StepResult = serde_json::from_str(json).unwrap();
        assert!(!result.skipped);
    }

    // "." and "" return the raw output unchanged.
    #[test]
    fn extract_json_path_whole_output() {
        let json = r#"{"a": 1}"#;
        assert_eq!(extract_json_path(json, "."), Some(json.to_string()));
        assert_eq!(extract_json_path(json, ""), Some(json.to_string()));
    }

    // String leaves come back without JSON quotes.
    #[test]
    fn extract_json_path_top_level_key() {
        let json = r#"{"summary": "all good"}"#;
        assert_eq!(
            extract_json_path(json, "summary"),
            Some("all good".to_string())
        );
    }

    // Non-string leaves come back in their JSON representation.
    #[test]
    fn extract_json_path_nested() {
        let json = r#"{"result": {"count": 42}}"#;
        assert_eq!(
            extract_json_path(json, "result.count"),
            Some("42".to_string())
        );
    }

    #[test]
    fn extract_json_path_not_json() {
        assert_eq!(extract_json_path("not json", "key"), None);
    }

    #[test]
    fn extract_json_path_missing_key() {
        let json = r#"{"a": 1}"#;
        assert_eq!(extract_json_path(json, "b"), None);
    }

    // Context keys are "STEP.VAR"; placeholders are "{steps.STEP.VAR}".
    #[test]
    fn expand_step_refs_substitutes() {
        let mut ctx = HashMap::new();
        ctx.insert("plan.summary".into(), "do stuff".into());
        let text = "Based on {steps.plan.summary}, implement it.".to_string();
        assert_eq!(
            expand_step_refs(text, &ctx),
            "Based on do stuff, implement it."
        );
    }

    // Unknown references are not an error; they are left in the text.
    #[test]
    fn expand_step_refs_unknown_left_as_is() {
        let ctx = HashMap::new();
        let text = "Use {steps.missing.var} here.".to_string();
        assert_eq!(expand_step_refs(text.clone(), &ctx), text);
    }

    // `output_vars` is optional in serialized steps.
    #[test]
    fn chain_step_output_vars_defaults_empty() {
        let json = r#"{"name":"s","action":{"type":"prompt","prompt":"hi"}}"#;
        let step: ChainStep = serde_json::from_str(json).unwrap();
        assert!(step.output_vars.is_empty());
    }

    // `output_vars` survives a serialize/deserialize round trip.
    #[test]
    fn chain_step_serializes_output_vars() {
        let mut vars = HashMap::new();
        vars.insert("summary".into(), "result.summary".into());
        let step = ChainStep {
            name: "s".into(),
            action: StepAction::Prompt {
                prompt: "hi".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: vars,
        };
        let json = serde_json::to_string(&step).unwrap();
        assert!(json.contains("output_vars"));
        assert!(json.contains("result.summary"));

        let parsed: ChainStep = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.output_vars.get("summary").unwrap(), "result.summary");
    }
}