Skip to main content

claude_pool/
chain.rs

1//! Chain execution — sequential pipelines of tasks.
2//!
3//! A chain runs steps in order, feeding each step's output as context
4//! to the next. Steps can reference skills or use inline prompts.
5//!
//! Chains can be run to completion by awaiting [`execute_chain`], or submitted
//! for asynchronous execution via [`Pool::submit_chain`](crate::Pool::submit_chain).
8
9use std::collections::HashMap;
10
11use serde::{Deserialize, Serialize};
12
13use crate::pool::Pool;
14use crate::skill::SkillRegistry;
15use crate::store::PoolStore;
16use crate::types::{SlotConfig, TaskId};
17
18/// A step in a chain pipeline.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct ChainStep {
21    /// Step name (for logging and result tracking).
22    pub name: String,
23
24    /// Either an inline prompt or a skill reference.
25    pub action: StepAction,
26
27    /// Per-step config overrides (model, effort, etc.).
28    pub config: Option<SlotConfig>,
29
30    /// Failure policy for this step.
31    #[serde(default)]
32    pub failure_policy: StepFailurePolicy,
33}
34
35/// What a chain step does.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(tag = "type", rename_all = "snake_case")]
38pub enum StepAction {
39    /// Run an inline prompt. `{previous_output}` is replaced with
40    /// the output from the prior step.
41    Prompt {
42        /// The prompt template.
43        prompt: String,
44    },
45    /// Run a registered skill with the given arguments.
46    /// The special argument `_previous_output` is automatically set
47    /// to the output from the prior step.
48    Skill {
49        /// Skill name.
50        skill: String,
51        /// Skill arguments.
52        #[serde(default)]
53        arguments: HashMap<String, String>,
54    },
55}
56
57/// Per-step failure handling policy.
58#[derive(Debug, Clone, Default, Serialize, Deserialize)]
59pub struct StepFailurePolicy {
60    /// Number of retries before giving up or recovering (default: 0).
61    #[serde(default)]
62    pub retries: u32,
63    /// If set, run this prompt on failure instead of failing the chain.
64    /// `{error}` is replaced with the error message, `{previous_output}`
65    /// with the last successful step's output.
66    pub recovery_prompt: Option<String>,
67}
68
69/// Options for chain execution.
70#[derive(Debug, Clone, Default, Serialize, Deserialize)]
71pub struct ChainOptions {
72    /// Tags for the chain task (used when submitted async).
73    #[serde(default)]
74    pub tags: Vec<String>,
75}
76
77/// Result of a single chain step.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct StepResult {
80    /// Step name.
81    pub name: String,
82    /// Output text from this step.
83    pub output: String,
84    /// Whether the step succeeded.
85    pub success: bool,
86    /// Cost in microdollars.
87    pub cost_microdollars: u64,
88    /// Number of retries used.
89    #[serde(default)]
90    pub retries_used: u32,
91}
92
93/// Result of a full chain execution.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct ChainResult {
96    /// Per-step results in execution order.
97    pub steps: Vec<StepResult>,
98    /// Final output (from the last step).
99    pub final_output: String,
100    /// Total cost across all steps.
101    pub total_cost_microdollars: u64,
102    /// Whether all steps succeeded.
103    pub success: bool,
104}
105
106/// Progress of an in-flight chain.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct ChainProgress {
109    /// Total number of steps.
110    pub total_steps: usize,
111    /// Index of the currently running step (0-based), or None if done.
112    pub current_step: Option<usize>,
113    /// Name of the currently running step.
114    pub current_step_name: Option<String>,
115    /// Completed step results so far.
116    pub completed_steps: Vec<StepResult>,
117    /// Overall status.
118    pub status: ChainStatus,
119}
120
121/// Status of a chain.
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
123#[serde(rename_all = "snake_case")]
124pub enum ChainStatus {
125    /// Chain is running.
126    Running,
127    /// All steps completed successfully.
128    Completed,
129    /// A step failed and the chain stopped.
130    Failed,
131}
132
133/// Execute a chain of steps against the pool.
134pub async fn execute_chain<S: PoolStore + 'static>(
135    pool: &Pool<S>,
136    skills: &SkillRegistry,
137    steps: &[ChainStep],
138) -> crate::Result<ChainResult> {
139    execute_chain_with_progress(pool, skills, steps, None).await
140}
141
142/// Execute a chain with optional progress tracking.
143///
144/// If `chain_task_id` is provided, intermediate progress is stored so callers
145/// can poll for status.
146pub async fn execute_chain_with_progress<S: PoolStore + 'static>(
147    pool: &Pool<S>,
148    skills: &SkillRegistry,
149    steps: &[ChainStep],
150    chain_task_id: Option<&TaskId>,
151) -> crate::Result<ChainResult> {
152    let mut step_results = Vec::with_capacity(steps.len());
153    let mut previous_output = String::new();
154    let mut total_cost = 0u64;
155
156    for (step_idx, step) in steps.iter().enumerate() {
157        // Update progress in the store if we have a task ID.
158        if let Some(task_id) = chain_task_id {
159            let progress = ChainProgress {
160                total_steps: steps.len(),
161                current_step: Some(step_idx),
162                current_step_name: Some(step.name.clone()),
163                completed_steps: step_results.clone(),
164                status: ChainStatus::Running,
165            };
166            pool.set_chain_progress(task_id, progress).await;
167        }
168
169        let prompt = render_step_prompt(step, &previous_output, skills)?;
170
171        let (step_result, step_cost) =
172            execute_step_with_retries(pool, step, &prompt, &previous_output, skills).await;
173
174        total_cost += step_cost;
175
176        match step_result {
177            Ok(result) => {
178                previous_output = result.output.clone();
179                step_results.push(result);
180
181                if !step_results.last().unwrap().success {
182                    update_chain_progress_final(
183                        pool,
184                        chain_task_id,
185                        steps.len(),
186                        &step_results,
187                        ChainStatus::Failed,
188                    )
189                    .await;
190                    return Ok(ChainResult {
191                        final_output: previous_output,
192                        steps: step_results,
193                        total_cost_microdollars: total_cost,
194                        success: false,
195                    });
196                }
197            }
198            Err(output) => {
199                step_results.push(StepResult {
200                    name: step.name.clone(),
201                    output: output.clone(),
202                    success: false,
203                    cost_microdollars: 0,
204                    retries_used: step.failure_policy.retries,
205                });
206                update_chain_progress_final(
207                    pool,
208                    chain_task_id,
209                    steps.len(),
210                    &step_results,
211                    ChainStatus::Failed,
212                )
213                .await;
214                return Ok(ChainResult {
215                    final_output: output,
216                    steps: step_results,
217                    total_cost_microdollars: total_cost,
218                    success: false,
219                });
220            }
221        }
222    }
223
224    update_chain_progress_final(
225        pool,
226        chain_task_id,
227        steps.len(),
228        &step_results,
229        ChainStatus::Completed,
230    )
231    .await;
232
233    Ok(ChainResult {
234        final_output: previous_output,
235        steps: step_results,
236        total_cost_microdollars: total_cost,
237        success: true,
238    })
239}
240
241/// Render the prompt for a step, substituting `{previous_output}`.
242fn render_step_prompt(
243    step: &ChainStep,
244    previous_output: &str,
245    skills: &SkillRegistry,
246) -> crate::Result<String> {
247    match &step.action {
248        StepAction::Prompt { prompt } => Ok(prompt.replace("{previous_output}", previous_output)),
249        StepAction::Skill { skill, arguments } => {
250            let skill_def = skills
251                .get(skill)
252                .ok_or_else(|| crate::Error::Store(format!("skill not found: {skill}")))?;
253            let mut args = arguments.clone();
254            if !previous_output.is_empty() {
255                args.entry("_previous_output".into())
256                    .or_insert(previous_output.to_string());
257            }
258            skill_def.render(&args)
259        }
260    }
261}
262
263/// Execute a step with retry and recovery support.
264///
265/// Returns `Ok(StepResult)` on success (or successful recovery), or
266/// `Err(error_message)` if all retries and recovery are exhausted.
267async fn execute_step_with_retries<S: PoolStore + 'static>(
268    pool: &Pool<S>,
269    step: &ChainStep,
270    initial_prompt: &str,
271    previous_output: &str,
272    skills: &SkillRegistry,
273) -> (std::result::Result<StepResult, String>, u64) {
274    let max_attempts = 1 + step.failure_policy.retries;
275    let mut total_cost = 0u64;
276    let mut last_error = String::new();
277
278    for attempt in 0..max_attempts {
279        let prompt = if attempt == 0 {
280            initial_prompt.to_string()
281        } else {
282            // Re-render the prompt for retries (same prompt, fresh attempt).
283            match render_step_prompt(step, previous_output, skills) {
284                Ok(p) => p,
285                Err(e) => return (Err(e.to_string()), total_cost),
286            }
287        };
288
289        match pool.run_with_config(&prompt, step.config.clone()).await {
290            Ok(task_result) => {
291                total_cost += task_result.cost_microdollars;
292                if task_result.success {
293                    return (
294                        Ok(StepResult {
295                            name: step.name.clone(),
296                            output: task_result.output,
297                            success: true,
298                            cost_microdollars: total_cost,
299                            retries_used: attempt,
300                        }),
301                        total_cost,
302                    );
303                }
304                // Task ran but reported failure.
305                last_error = task_result.output;
306            }
307            Err(e) => {
308                last_error = e.to_string();
309            }
310        }
311
312        tracing::warn!(
313            step = %step.name,
314            attempt = attempt + 1,
315            max_attempts,
316            "chain step failed, will retry"
317        );
318    }
319
320    // All retries exhausted. Try recovery prompt if configured.
321    if let Some(ref recovery_template) = step.failure_policy.recovery_prompt {
322        let recovery_prompt = recovery_template
323            .replace("{error}", &last_error)
324            .replace("{previous_output}", previous_output);
325
326        tracing::info!(step = %step.name, "attempting recovery prompt");
327
328        match pool
329            .run_with_config(&recovery_prompt, step.config.clone())
330            .await
331        {
332            Ok(task_result) => {
333                total_cost += task_result.cost_microdollars;
334                return (
335                    Ok(StepResult {
336                        name: step.name.clone(),
337                        output: task_result.output,
338                        success: task_result.success,
339                        cost_microdollars: total_cost,
340                        retries_used: max_attempts,
341                    }),
342                    total_cost,
343                );
344            }
345            Err(e) => {
346                last_error = e.to_string();
347            }
348        }
349    }
350
351    (Err(last_error), total_cost)
352}
353
354/// Update chain progress to a terminal state.
355async fn update_chain_progress_final<S: PoolStore + 'static>(
356    pool: &Pool<S>,
357    chain_task_id: Option<&TaskId>,
358    total_steps: usize,
359    completed_steps: &[StepResult],
360    status: ChainStatus,
361) {
362    if let Some(task_id) = chain_task_id {
363        let progress = ChainProgress {
364            total_steps,
365            current_step: None,
366            current_step_name: None,
367            completed_steps: completed_steps.to_vec(),
368            status,
369        };
370        pool.set_chain_progress(task_id, progress).await;
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prompt_step_replaces_previous_output() {
        // Checks the `{previous_output}` placeholder contract only;
        // `render_step_prompt` itself also needs a `SkillRegistry`.
        let step = ChainStep {
            name: "step1".into(),
            action: StepAction::Prompt {
                prompt: "Based on: {previous_output}\nDo more.".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
        };

        match &step.action {
            StepAction::Prompt { prompt } => {
                let rendered = prompt.replace("{previous_output}", "hello world");
                assert_eq!(rendered, "Based on: hello world\nDo more.");
            }
            other => panic!("expected a prompt action, got {other:?}"),
        }
    }

    #[test]
    fn chain_result_serializes() {
        let result = ChainResult {
            steps: vec![StepResult {
                name: "step1".into(),
                output: "done".into(),
                success: true,
                cost_microdollars: 1000,
                retries_used: 0,
            }],
            final_output: "done".into(),
            total_cost_microdollars: 1000,
            success: true,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("step1"));
    }

    #[test]
    fn step_failure_policy_defaults() {
        let policy = StepFailurePolicy::default();
        assert_eq!(policy.retries, 0);
        assert!(policy.recovery_prompt.is_none());
    }

    #[test]
    fn chain_options_defaults() {
        let opts = ChainOptions::default();
        assert!(opts.tags.is_empty());
    }

    #[test]
    fn chain_progress_serializes() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(1),
            current_step_name: Some("implement".into()),
            completed_steps: vec![StepResult {
                name: "plan".into(),
                output: "planned".into(),
                success: true,
                cost_microdollars: 500,
                retries_used: 0,
            }],
            status: ChainStatus::Running,
        };

        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("implement"));
        assert!(json.contains("running"));
    }

    #[test]
    fn chain_step_deserializes_with_defaults() {
        // `config` and `failure_policy` may be omitted from user-written chains.
        let json = r#"{"name":"plan","action":{"type":"prompt","prompt":"Plan: {previous_output}"}}"#;
        let step: ChainStep = serde_json::from_str(json).unwrap();

        assert_eq!(step.name, "plan");
        assert!(step.config.is_none());
        assert_eq!(step.failure_policy.retries, 0);
        assert!(step.failure_policy.recovery_prompt.is_none());
        match step.action {
            StepAction::Prompt { prompt } => assert_eq!(prompt, "Plan: {previous_output}"),
            other => panic!("expected a prompt action, got {other:?}"),
        }
    }

    #[test]
    fn skill_action_arguments_default_to_empty() {
        let action: StepAction = serde_json::from_str(r#"{"type":"skill","skill":"review"}"#).unwrap();
        match action {
            StepAction::Skill { skill, arguments } => {
                assert_eq!(skill, "review");
                assert!(arguments.is_empty());
            }
            other => panic!("expected a skill action, got {other:?}"),
        }
    }

    #[test]
    fn step_result_retries_used_defaults_to_zero() {
        let json = r#"{"name":"a","output":"b","success":true,"cost_microdollars":7}"#;
        let result: StepResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.retries_used, 0);
        assert_eq!(result.cost_microdollars, 7);
    }

    #[test]
    fn chain_status_uses_snake_case() {
        assert_eq!(
            serde_json::to_string(&ChainStatus::Completed).unwrap(),
            "\"completed\""
        );
        let status: ChainStatus = serde_json::from_str("\"failed\"").unwrap();
        assert_eq!(status, ChainStatus::Failed);
    }
}