1use std::collections::HashMap;
10
11use serde::{Deserialize, Serialize};
12
13use crate::pool::Pool;
14use crate::skill::SkillRegistry;
15use crate::store::PoolStore;
16use crate::types::{SlotConfig, TaskId};
17
/// One step of a chain: a named action plus optional run configuration and
/// failure handling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainStep {
    /// Step name; copied into the step's [`StepResult`] and attached to log events.
    pub name: String,

    /// What the step runs: an inline prompt or a skill looked up in the registry.
    pub action: StepAction,

    /// Optional per-step configuration, handed unchanged to `Pool::run_with_config`.
    pub config: Option<SlotConfig>,

    /// Retry and recovery behaviour for this step. When omitted from the
    /// serialized form it defaults to zero retries and no recovery prompt.
    #[serde(default)]
    pub failure_policy: StepFailurePolicy,
}
34
/// The work performed by a chain step.
///
/// Serialized as an internally tagged enum: the `type` field carries the
/// snake_case variant name (`prompt` or `skill`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StepAction {
    /// Run an inline prompt.
    Prompt {
        /// Prompt text. Every `{previous_output}` occurrence is replaced with
        /// the previous step's output before the prompt is run.
        prompt: String,
    },
    /// Render a skill from the registry and run the rendered text.
    Skill {
        /// Name used to look the skill up in the `SkillRegistry`.
        skill: String,
        /// Arguments passed to the skill's `render`. When the previous output
        /// is non-empty it is added under `_previous_output`, unless that key
        /// is already present. Defaults to an empty map when omitted.
        #[serde(default)]
        arguments: HashMap<String, String>,
    },
}
56
/// How a step reacts to failure. The default is no retries and no recovery prompt.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StepFailurePolicy {
    /// Number of additional attempts made after the first one fails.
    #[serde(default)]
    pub retries: u32,
    /// Optional prompt template run once after every attempt has failed.
    /// `{error}` is replaced with the last error text and `{previous_output}`
    /// with the previous step's output.
    pub recovery_prompt: Option<String>,
}
68
/// Options attached to a chain.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChainOptions {
    /// Tags for the chain; empty when omitted.
    /// NOTE(review): not read by the executor functions in this part of the
    /// file — confirm the intended meaning against the caller.
    #[serde(default)]
    pub tags: Vec<String>,
}
76
/// Outcome of a single chain step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
    /// Name of the step this result belongs to.
    pub name: String,
    /// Output text of the run, or the error text when the step failed.
    pub output: String,
    /// Whether the step succeeded.
    pub success: bool,
    /// Cost recorded for this step, in microdollars.
    pub cost_microdollars: u64,
    /// Retry count recorded for the step; 0 when the first attempt succeeded.
    /// Defaults to 0 when omitted from the serialized form.
    #[serde(default)]
    pub retries_used: u32,
}
92
/// Outcome of a whole chain run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainResult {
    /// Results of the steps that ran, in execution order. When the chain
    /// fails, the failing step is the last entry.
    pub steps: Vec<StepResult>,
    /// Output of the last step that ran; empty when the chain had no steps.
    pub final_output: String,
    /// Sum of the costs reported for every step that ran, in microdollars.
    pub total_cost_microdollars: u64,
    /// `true` only when every step succeeded.
    pub success: bool,
}
105
/// Progress snapshot of a chain, stored through `Pool::set_chain_progress`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainProgress {
    /// Total number of steps in the chain.
    pub total_steps: usize,
    /// Zero-based index of the step currently running; `None` in the final
    /// snapshot written when the chain completes or fails.
    pub current_step: Option<usize>,
    /// Name of the step currently running; `None` in the final snapshot.
    pub current_step_name: Option<String>,
    /// Results of the steps recorded so far.
    pub completed_steps: Vec<StepResult>,
    /// Overall state of the chain.
    pub status: ChainStatus,
}
120
/// Overall state of a chain; serialized in snake_case (`running`, `completed`, `failed`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChainStatus {
    /// A step is currently being executed.
    Running,
    /// Every step ran and succeeded.
    Completed,
    /// A step failed and the chain stopped.
    Failed,
}
132
133pub async fn execute_chain<S: PoolStore + 'static>(
135 pool: &Pool<S>,
136 skills: &SkillRegistry,
137 steps: &[ChainStep],
138) -> crate::Result<ChainResult> {
139 execute_chain_with_progress(pool, skills, steps, None).await
140}
141
142pub async fn execute_chain_with_progress<S: PoolStore + 'static>(
147 pool: &Pool<S>,
148 skills: &SkillRegistry,
149 steps: &[ChainStep],
150 chain_task_id: Option<&TaskId>,
151) -> crate::Result<ChainResult> {
152 let mut step_results = Vec::with_capacity(steps.len());
153 let mut previous_output = String::new();
154 let mut total_cost = 0u64;
155
156 for (step_idx, step) in steps.iter().enumerate() {
157 if let Some(task_id) = chain_task_id {
159 let progress = ChainProgress {
160 total_steps: steps.len(),
161 current_step: Some(step_idx),
162 current_step_name: Some(step.name.clone()),
163 completed_steps: step_results.clone(),
164 status: ChainStatus::Running,
165 };
166 pool.set_chain_progress(task_id, progress).await;
167 }
168
169 let prompt = render_step_prompt(step, &previous_output, skills)?;
170
171 let (step_result, step_cost) =
172 execute_step_with_retries(pool, step, &prompt, &previous_output, skills).await;
173
174 total_cost += step_cost;
175
176 match step_result {
177 Ok(result) => {
178 previous_output = result.output.clone();
179 step_results.push(result);
180
181 if !step_results.last().unwrap().success {
182 update_chain_progress_final(
183 pool,
184 chain_task_id,
185 steps.len(),
186 &step_results,
187 ChainStatus::Failed,
188 )
189 .await;
190 return Ok(ChainResult {
191 final_output: previous_output,
192 steps: step_results,
193 total_cost_microdollars: total_cost,
194 success: false,
195 });
196 }
197 }
198 Err(output) => {
199 step_results.push(StepResult {
200 name: step.name.clone(),
201 output: output.clone(),
202 success: false,
203 cost_microdollars: 0,
204 retries_used: step.failure_policy.retries,
205 });
206 update_chain_progress_final(
207 pool,
208 chain_task_id,
209 steps.len(),
210 &step_results,
211 ChainStatus::Failed,
212 )
213 .await;
214 return Ok(ChainResult {
215 final_output: output,
216 steps: step_results,
217 total_cost_microdollars: total_cost,
218 success: false,
219 });
220 }
221 }
222 }
223
224 update_chain_progress_final(
225 pool,
226 chain_task_id,
227 steps.len(),
228 &step_results,
229 ChainStatus::Completed,
230 )
231 .await;
232
233 Ok(ChainResult {
234 final_output: previous_output,
235 steps: step_results,
236 total_cost_microdollars: total_cost,
237 success: true,
238 })
239}
240
241fn render_step_prompt(
243 step: &ChainStep,
244 previous_output: &str,
245 skills: &SkillRegistry,
246) -> crate::Result<String> {
247 match &step.action {
248 StepAction::Prompt { prompt } => Ok(prompt.replace("{previous_output}", previous_output)),
249 StepAction::Skill { skill, arguments } => {
250 let skill_def = skills
251 .get(skill)
252 .ok_or_else(|| crate::Error::Store(format!("skill not found: {skill}")))?;
253 let mut args = arguments.clone();
254 if !previous_output.is_empty() {
255 args.entry("_previous_output".into())
256 .or_insert(previous_output.to_string());
257 }
258 skill_def.render(&args)
259 }
260 }
261}
262
263async fn execute_step_with_retries<S: PoolStore + 'static>(
268 pool: &Pool<S>,
269 step: &ChainStep,
270 initial_prompt: &str,
271 previous_output: &str,
272 skills: &SkillRegistry,
273) -> (std::result::Result<StepResult, String>, u64) {
274 let max_attempts = 1 + step.failure_policy.retries;
275 let mut total_cost = 0u64;
276 let mut last_error = String::new();
277
278 for attempt in 0..max_attempts {
279 let prompt = if attempt == 0 {
280 initial_prompt.to_string()
281 } else {
282 match render_step_prompt(step, previous_output, skills) {
284 Ok(p) => p,
285 Err(e) => return (Err(e.to_string()), total_cost),
286 }
287 };
288
289 match pool.run_with_config(&prompt, step.config.clone()).await {
290 Ok(task_result) => {
291 total_cost += task_result.cost_microdollars;
292 if task_result.success {
293 return (
294 Ok(StepResult {
295 name: step.name.clone(),
296 output: task_result.output,
297 success: true,
298 cost_microdollars: total_cost,
299 retries_used: attempt,
300 }),
301 total_cost,
302 );
303 }
304 last_error = task_result.output;
306 }
307 Err(e) => {
308 last_error = e.to_string();
309 }
310 }
311
312 tracing::warn!(
313 step = %step.name,
314 attempt = attempt + 1,
315 max_attempts,
316 "chain step failed, will retry"
317 );
318 }
319
320 if let Some(ref recovery_template) = step.failure_policy.recovery_prompt {
322 let recovery_prompt = recovery_template
323 .replace("{error}", &last_error)
324 .replace("{previous_output}", previous_output);
325
326 tracing::info!(step = %step.name, "attempting recovery prompt");
327
328 match pool
329 .run_with_config(&recovery_prompt, step.config.clone())
330 .await
331 {
332 Ok(task_result) => {
333 total_cost += task_result.cost_microdollars;
334 return (
335 Ok(StepResult {
336 name: step.name.clone(),
337 output: task_result.output,
338 success: task_result.success,
339 cost_microdollars: total_cost,
340 retries_used: max_attempts,
341 }),
342 total_cost,
343 );
344 }
345 Err(e) => {
346 last_error = e.to_string();
347 }
348 }
349 }
350
351 (Err(last_error), total_cost)
352}
353
354async fn update_chain_progress_final<S: PoolStore + 'static>(
356 pool: &Pool<S>,
357 chain_task_id: Option<&TaskId>,
358 total_steps: usize,
359 completed_steps: &[StepResult],
360 status: ChainStatus,
361) {
362 if let Some(task_id) = chain_task_id {
363 let progress = ChainProgress {
364 total_steps,
365 current_step: None,
366 current_step_name: None,
367 completed_steps: completed_steps.to_vec(),
368 status,
369 };
370 pool.set_chain_progress(task_id, progress).await;
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful, retry-free step result used by the serialization tests.
    fn ok_step(name: &str, output: &str, cost_microdollars: u64) -> StepResult {
        StepResult {
            name: name.to_string(),
            output: output.to_string(),
            success: true,
            cost_microdollars,
            retries_used: 0,
        }
    }

    /// The `{previous_output}` placeholder of a prompt step is substituted.
    #[test]
    fn prompt_step_replaces_previous_output() {
        let action = StepAction::Prompt {
            prompt: String::from("Based on: {previous_output}\nDo more."),
        };
        let step = ChainStep {
            name: String::from("step1"),
            action,
            config: None,
            failure_policy: StepFailurePolicy::default(),
        };

        if let StepAction::Prompt { prompt } = &step.action {
            assert_eq!(
                prompt.replace("{previous_output}", "hello world"),
                "Based on: hello world\nDo more."
            );
        }
    }

    /// A chain result can be serialized and carries the step name.
    #[test]
    fn chain_result_serializes() {
        let result = ChainResult {
            steps: vec![ok_step("step1", "done", 1000)],
            final_output: String::from("done"),
            total_cost_microdollars: 1000,
            success: true,
        };

        let encoded = serde_json::to_string(&result).unwrap();
        assert!(encoded.contains("step1"));
    }

    /// The default failure policy has no retries and no recovery prompt.
    #[test]
    fn step_failure_policy_defaults() {
        let StepFailurePolicy {
            retries,
            recovery_prompt,
        } = StepFailurePolicy::default();
        assert_eq!(retries, 0);
        assert!(recovery_prompt.is_none());
    }

    /// Default chain options carry no tags.
    #[test]
    fn chain_options_defaults() {
        let ChainOptions { tags } = ChainOptions::default();
        assert!(tags.is_empty());
    }

    /// A progress snapshot serializes the step name and the snake_case status.
    #[test]
    fn chain_progress_serializes() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(1),
            current_step_name: Some(String::from("implement")),
            completed_steps: vec![ok_step("plan", "planned", 500)],
            status: ChainStatus::Running,
        };

        let encoded = serde_json::to_string(&progress).unwrap();
        assert!(encoded.contains("implement"));
        assert!(encoded.contains("running"));
    }
}