use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::pool::Pool;
use crate::store::PoolStore;
use crate::types::{TaskId, TaskOverrides, TaskState};
/// One step of a prompt chain; steps are run in order by
/// `execute_chain_with_progress`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainStep {
/// Step name. It is copied into the step's `StepResult` and used as the
/// prefix of the keys under which `output_vars` are stored (`<name>.<var>`).
pub name: String,
/// What the step does (currently only a prompt template).
pub action: StepAction,
/// Optional per-step overrides, handed to the pool when the step is run.
pub config: Option<TaskOverrides>,
/// Retry / recovery behavior when the step fails. Defaults to no retries and
/// no recovery prompt when omitted from the serialized form.
#[serde(default)]
pub failure_policy: StepFailurePolicy,
/// Maps a variable name to a dot-separated JSON path. After a successful run
/// the path is extracted from the step output and later prompts can refer to
/// it as `{steps.<name>.<var>}`. Defaults to empty.
#[serde(default)]
pub output_vars: HashMap<String, String>,
}
/// The action a chain step performs. Serialized with an internal `type` tag
/// whose value is the snake_case variant name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StepAction {
/// Run a prompt. When rendered, `{previous_output}` is replaced with the
/// prior step's output and `{steps.<key>}` references are replaced with
/// previously extracted output variables.
Prompt {
/// Prompt template text.
prompt: String,
},
}
/// Controls what happens when a chain step fails.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StepFailurePolicy {
/// Number of additional attempts after the first one (0 = run once).
#[serde(default)]
pub retries: u32,
/// Optional prompt template run once after every attempt has failed.
/// `{error}`, `{previous_output}` and `{steps.<key>}` placeholders are
/// substituted before it is sent.
pub recovery_prompt: Option<String>,
}
/// Isolation mode selectable for a chain; serialized in snake_case.
///
/// NOTE(review): the variants are not interpreted in this part of the file, so
/// their exact semantics should be confirmed where they are consumed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChainIsolation {
/// Serialized as `"none"`.
None,
/// The default variant; serialized as `"worktree"`.
#[default]
Worktree,
/// Serialized as `"clone"`.
Clone,
}
/// Chain-level options. Both fields fall back to their defaults when missing
/// from the serialized form.
///
/// NOTE(review): neither field is read in this part of the file; confirm how
/// they are applied at the call site.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ChainOptions {
/// Free-form tags; empty by default.
#[serde(default)]
pub tags: Vec<String>,
/// Requested isolation mode; defaults to `ChainIsolation::Worktree`.
#[serde(default)]
pub isolation: ChainIsolation,
}
/// Outcome of a single chain step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
/// Name of the step this result belongs to.
pub name: String,
/// Output text of the step (or the error text when it failed; empty for
/// skipped steps).
pub output: String,
/// Whether the step succeeded.
pub success: bool,
/// Cost attributed to this step, in microdollars.
pub cost_microdollars: u64,
/// Number of retries consumed; defaults to 0 when absent on deserialize.
#[serde(default)]
pub retries_used: u32,
/// True when the step was never run because the chain was cancelled first;
/// defaults to false when absent on deserialize.
#[serde(default)]
pub skipped: bool,
}
/// Aggregate outcome of a whole chain run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainResult {
/// Per-step results, in execution order (including skipped steps).
pub steps: Vec<StepResult>,
/// Output of the last step that produced one, or the error text when the
/// chain stopped on a step error.
pub final_output: String,
/// Sum of the cost reported for every executed step, in microdollars.
pub total_cost_microdollars: u64,
/// True only when every step ran and succeeded.
pub success: bool,
}
/// Snapshot of a running (or finished) chain, published through the pool so
/// observers can follow progress.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChainProgress {
/// Total number of steps in the chain.
pub total_steps: usize,
/// Zero-based index of the step currently running; `None` once finished.
pub current_step: Option<usize>,
/// Name of the step currently running; `None` once finished.
pub current_step_name: Option<String>,
/// Output streamed so far by the current step. Starts as an empty string
/// when a step begins; omitted from the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub current_step_partial_output: Option<String>,
/// Unix timestamp (seconds) at which the current step started; omitted from
/// the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub current_step_started_at: Option<u64>,
/// Results of the steps that have already finished (or were skipped).
pub completed_steps: Vec<StepResult>,
/// Overall chain status.
pub status: ChainStatus,
}
/// Lifecycle state of a chain; serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChainStatus {
/// A step is currently being executed.
Running,
/// Every step finished successfully.
Completed,
/// A step failed and the chain stopped.
Failed,
/// The chain task was found in the cancelled state before a step started.
Cancelled,
}
/// Shared, `Send + Sync` callback invoked with each chunk of streamed step output.
pub type OnOutputChunk = Arc<dyn Fn(&str) + Send + Sync>;
/// Pulls a value out of `json_str` by following a dot-separated `path`.
///
/// A path of `"."` or `""` returns the input verbatim without parsing it.
/// Otherwise the input is parsed as JSON and every path segment is resolved
/// with `serde_json::Value::get`, one level at a time.
///
/// JSON strings are returned without surrounding quotes; any other value is
/// returned in its JSON text form. Returns `None` when the input is not valid
/// JSON or a segment cannot be resolved.
fn extract_json_path(json_str: &str, path: &str) -> Option<String> {
    // Whole-output shortcut: no parsing, so non-JSON output is accepted here.
    if matches!(path, "" | ".") {
        return Some(json_str.to_owned());
    }

    let root: serde_json::Value = serde_json::from_str(json_str).ok()?;
    let leaf = path
        .split('.')
        .try_fold(&root, |node, segment| node.get(segment))?;

    let rendered = if let serde_json::Value::String(text) = leaf {
        text.clone()
    } else {
        leaf.to_string()
    };
    Some(rendered)
}
/// Replaces every `{steps.<key>}` placeholder in `text` with the matching
/// value from `step_context`.
///
/// The text is scanned once, left to right, and substituted values are never
/// rescanned. A value that itself contains `{steps.…}` is therefore inserted
/// literally, and the result does not depend on `HashMap` iteration order.
///
/// Placeholders whose key is not in `step_context`, and an unterminated
/// `{steps.` with no closing `}`, are left in the text unchanged. A key is
/// taken to end at the first `}` after `{steps.`.
fn expand_step_refs(text: String, step_context: &HashMap<String, String>) -> String {
    const OPEN: &str = "{steps.";

    // Fast path: nothing can be substituted, so hand the input back untouched.
    if step_context.is_empty() || !text.contains(OPEN) {
        return text;
    }

    let mut expanded = String::with_capacity(text.len());
    let mut rest = text.as_str();

    while let Some(start) = rest.find(OPEN) {
        expanded.push_str(&rest[..start]);
        let after_open = &rest[start + OPEN.len()..];

        let Some(close) = after_open.find('}') else {
            // Unterminated placeholder: keep the remainder verbatim.
            expanded.push_str(&rest[start..]);
            rest = "";
            break;
        };

        match step_context.get(&after_open[..close]) {
            Some(value) => {
                expanded.push_str(value);
                rest = &after_open[close + 1..];
            }
            None => {
                // Unknown key: keep the opener and continue scanning right
                // after it so a nested, resolvable placeholder is still found.
                expanded.push_str(OPEN);
                rest = after_open;
            }
        }
    }

    expanded.push_str(rest);
    expanded
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to `0` if the system clock reports a time before the epoch.
fn unix_secs_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
pub async fn execute_chain<S: PoolStore + 'static>(
pool: &Pool<S>,
steps: &[ChainStep],
) -> crate::Result<ChainResult> {
execute_chain_with_progress(pool, steps, None, None).await
}
pub async fn execute_chain_with_progress<S: PoolStore + 'static>(
pool: &Pool<S>,
steps: &[ChainStep],
chain_task_id: Option<&TaskId>,
working_dir: Option<&std::path::Path>,
) -> crate::Result<ChainResult> {
let mut step_results = Vec::with_capacity(steps.len());
let mut previous_output = String::new();
let mut total_cost = 0u64;
let mut step_context: HashMap<String, String> = HashMap::new();
for (step_idx, step) in steps.iter().enumerate() {
if let Some(task_id) = chain_task_id
&& let Ok(Some(task)) = pool.store().get_task(task_id).await
&& task.state == TaskState::Cancelled
{
for s in &steps[step_idx..] {
step_results.push(StepResult {
name: s.name.clone(),
output: String::new(),
success: false,
cost_microdollars: 0,
retries_used: 0,
skipped: true,
});
}
update_chain_progress_final(
pool,
Some(task_id),
steps.len(),
&step_results,
ChainStatus::Cancelled,
)
.await;
return Ok(ChainResult {
final_output: previous_output,
steps: step_results,
total_cost_microdollars: total_cost,
success: false,
});
}
if let Some(task_id) = chain_task_id {
let progress = ChainProgress {
total_steps: steps.len(),
current_step: Some(step_idx),
current_step_name: Some(step.name.clone()),
current_step_partial_output: Some(String::new()),
current_step_started_at: Some(unix_secs_now()),
completed_steps: step_results.clone(),
status: ChainStatus::Running,
};
pool.set_chain_progress(task_id, progress).await;
}
let prompt = render_step_prompt(step, &previous_output, &step_context)?;
let on_output: Option<OnOutputChunk> = chain_task_id.map(|tid| {
let pool = pool.clone();
let tid = tid.clone();
Arc::new(move |chunk: &str| {
pool.append_chain_partial_output(&tid, chunk);
}) as OnOutputChunk
});
let (step_result, step_cost) = execute_step_with_retries(
pool,
step,
&prompt,
&previous_output,
on_output.clone(),
working_dir,
&step_context,
)
.await;
total_cost += step_cost;
match step_result {
Ok(result) => {
previous_output = result.output.clone();
if result.success {
for (var_name, path) in &step.output_vars {
match extract_json_path(&result.output, path) {
Some(extracted) => {
step_context
.insert(format!("{}.{}", step.name, var_name), extracted);
}
None => {
tracing::warn!(
step = %step.name,
var = %var_name,
path = %path,
"output_var extraction failed (output not JSON or path not found)"
);
}
}
}
}
step_results.push(result);
if !step_results.last().unwrap().success {
update_chain_progress_final(
pool,
chain_task_id,
steps.len(),
&step_results,
ChainStatus::Failed,
)
.await;
return Ok(ChainResult {
final_output: previous_output,
steps: step_results,
total_cost_microdollars: total_cost,
success: false,
});
}
}
Err(output) => {
step_results.push(StepResult {
name: step.name.clone(),
output: output.clone(),
success: false,
cost_microdollars: 0,
retries_used: step.failure_policy.retries,
skipped: false,
});
update_chain_progress_final(
pool,
chain_task_id,
steps.len(),
&step_results,
ChainStatus::Failed,
)
.await;
return Ok(ChainResult {
final_output: output,
steps: step_results,
total_cost_microdollars: total_cost,
success: false,
});
}
}
}
update_chain_progress_final(
pool,
chain_task_id,
steps.len(),
&step_results,
ChainStatus::Completed,
)
.await;
Ok(ChainResult {
final_output: previous_output,
steps: step_results,
total_cost_microdollars: total_cost,
success: true,
})
}
/// Builds the concrete prompt text for `step`.
///
/// `{previous_output}` is replaced with `previous_output` first; the result is
/// then passed through `expand_step_refs` to resolve `{steps.<key>}`
/// references from `step_context`. Always returns `Ok` for the action kinds
/// that currently exist.
fn render_step_prompt(
    step: &ChainStep,
    previous_output: &str,
    step_context: &HashMap<String, String>,
) -> crate::Result<String> {
    match &step.action {
        StepAction::Prompt { prompt } => {
            let with_previous = prompt.replace("{previous_output}", previous_output);
            Ok(expand_step_refs(with_previous, step_context))
        }
    }
}
/// Runs a single step, retrying and optionally recovering on failure.
///
/// The step is attempted up to `1 + failure_policy.retries` times. The first
/// attempt uses `initial_prompt`; later attempts re-render the prompt from the
/// step definition. An attempt fails when the pool returns an error or a task
/// result with `success == false`.
///
/// When every attempt fails and a `recovery_prompt` is configured, it is run
/// once with `{error}`, `{previous_output}` and `{steps.<key>}` substituted.
/// Its result is returned as the step result, whether it succeeded or not.
///
/// # Returns
/// A pair of:
/// - `Ok(StepResult)` when an attempt succeeded or the recovery prompt
///   produced a task result, otherwise `Err` with the last error text;
/// - the total cost in microdollars of all runs that reported a cost.
async fn execute_step_with_retries<S: PoolStore + 'static>(
    pool: &Pool<S>,
    step: &ChainStep,
    initial_prompt: &str,
    previous_output: &str,
    on_output: Option<OnOutputChunk>,
    working_dir: Option<&std::path::Path>,
    step_context: &HashMap<String, String>,
) -> (std::result::Result<StepResult, String>, u64) {
    // Saturate so `retries == u32::MAX` cannot overflow (a wrap to zero in a
    // release build would skip the step entirely).
    let max_attempts = step.failure_policy.retries.saturating_add(1);
    let mut total_cost = 0u64;
    let mut last_error = String::new();

    for attempt in 0..max_attempts {
        let prompt = if attempt == 0 {
            initial_prompt.to_string()
        } else {
            match render_step_prompt(step, previous_output, step_context) {
                Ok(p) => p,
                Err(e) => return (Err(e.to_string()), total_cost),
            }
        };

        let result = pool
            .run_with_config_streaming(
                &prompt,
                step.config.clone(),
                on_output.clone(),
                working_dir.map(|p| p.to_path_buf()),
            )
            .await;

        match result {
            Ok(task_result) => {
                total_cost += task_result.cost_microdollars;
                if task_result.success {
                    return (
                        Ok(StepResult {
                            name: step.name.clone(),
                            output: task_result.output,
                            success: true,
                            cost_microdollars: total_cost,
                            retries_used: attempt,
                            skipped: false,
                        }),
                        total_cost,
                    );
                }
                last_error = task_result.output;
            }
            Err(e) => {
                last_error = e.to_string();
            }
        }

        // `attempt < max_attempts`, so this cannot overflow.
        let attempt_number = attempt + 1;
        if attempt_number < max_attempts {
            tracing::warn!(
                step = %step.name,
                attempt = attempt_number,
                max_attempts,
                "chain step failed, will retry"
            );
        } else {
            // Only claim a retry when one will actually happen.
            tracing::warn!(
                step = %step.name,
                attempt = attempt_number,
                max_attempts,
                "chain step failed, no retries left"
            );
        }
    }

    if let Some(ref recovery_template) = step.failure_policy.recovery_prompt {
        let recovery_prompt = expand_step_refs(
            recovery_template
                .replace("{error}", &last_error)
                .replace("{previous_output}", previous_output),
            step_context,
        );
        tracing::info!(step = %step.name, "attempting recovery prompt");

        let result = pool
            .run_with_config_streaming(
                &recovery_prompt,
                step.config.clone(),
                on_output,
                working_dir.map(|p| p.to_path_buf()),
            )
            .await;

        match result {
            Ok(task_result) => {
                total_cost += task_result.cost_microdollars;
                return (
                    Ok(StepResult {
                        name: step.name.clone(),
                        output: task_result.output,
                        success: task_result.success,
                        cost_microdollars: total_cost,
                        retries_used: max_attempts,
                        skipped: false,
                    }),
                    total_cost,
                );
            }
            Err(e) => {
                last_error = e.to_string();
            }
        }
    }

    (Err(last_error), total_cost)
}
/// Publishes a terminal progress snapshot for the chain.
///
/// Does nothing when `chain_task_id` is `None`. Otherwise stores a
/// [`ChainProgress`] with no current step, a copy of `completed_steps`, and the
/// given final `status`.
async fn update_chain_progress_final<S: PoolStore + 'static>(
    pool: &Pool<S>,
    chain_task_id: Option<&TaskId>,
    total_steps: usize,
    completed_steps: &[StepResult],
    status: ChainStatus,
) {
    let Some(task_id) = chain_task_id else {
        return;
    };

    let final_snapshot = ChainProgress {
        total_steps,
        current_step: None,
        current_step_name: None,
        current_step_partial_output: None,
        current_step_started_at: None,
        completed_steps: completed_steps.to_vec(),
        status,
    };
    pool.set_chain_progress(task_id, final_snapshot).await;
}
#[cfg(test)]
mod tests {
    //! Unit tests for prompt rendering, (de)serialization of the chain types,
    //! JSON path extraction and step-reference expansion.
    use super::*;

    /// `{previous_output}` in a prompt template is replaced when the step is
    /// rendered through the production `render_step_prompt` path.
    #[test]
    fn prompt_step_replaces_previous_output() {
        let step = ChainStep {
            name: "step1".into(),
            action: StepAction::Prompt {
                prompt: "Based on: {previous_output}\nDo more.".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: Default::default(),
        };
        let Ok(rendered) = render_step_prompt(&step, "hello world", &HashMap::new()) else {
            panic!("rendering a prompt step should not fail");
        };
        assert_eq!(rendered, "Based on: hello world\nDo more.");
    }

    #[test]
    fn chain_result_serializes() {
        let result = ChainResult {
            steps: vec![StepResult {
                name: "step1".into(),
                output: "done".into(),
                success: true,
                cost_microdollars: 1000,
                retries_used: 0,
                skipped: false,
            }],
            final_output: "done".into(),
            total_cost_microdollars: 1000,
            success: true,
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("step1"));
    }

    #[test]
    fn step_failure_policy_defaults() {
        let policy = StepFailurePolicy::default();
        assert_eq!(policy.retries, 0);
        assert!(policy.recovery_prompt.is_none());
    }

    #[test]
    fn chain_options_defaults() {
        let opts = ChainOptions::default();
        assert!(opts.tags.is_empty());
        assert_eq!(opts.isolation, ChainIsolation::Worktree);
    }

    #[test]
    fn chain_isolation_serde_roundtrip() {
        let worktree = ChainIsolation::Worktree;
        let json = serde_json::to_string(&worktree).unwrap();
        assert_eq!(json, r#""worktree""#);
        let none = ChainIsolation::None;
        let json = serde_json::to_string(&none).unwrap();
        assert_eq!(json, r#""none""#);
        let parsed: ChainIsolation = serde_json::from_str(r#""worktree""#).unwrap();
        assert_eq!(parsed, ChainIsolation::Worktree);
        let parsed: ChainIsolation = serde_json::from_str(r#""none""#).unwrap();
        assert_eq!(parsed, ChainIsolation::None);
    }

    #[test]
    fn chain_options_with_isolation_serializes() {
        let opts = ChainOptions {
            tags: vec!["test".into()],
            isolation: ChainIsolation::Worktree,
        };
        let json = serde_json::to_string(&opts).unwrap();
        let parsed: ChainOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.isolation, ChainIsolation::Worktree);
        assert_eq!(parsed.tags, vec!["test"]);
    }

    #[test]
    fn chain_progress_serializes_with_partial_output() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(1),
            current_step_name: Some("implement".into()),
            current_step_partial_output: Some("partial text".into()),
            current_step_started_at: Some(1700000000),
            completed_steps: vec![StepResult {
                name: "plan".into(),
                output: "planned".into(),
                success: true,
                cost_microdollars: 500,
                retries_used: 0,
                skipped: false,
            }],
            status: ChainStatus::Running,
        };
        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("implement"));
        assert!(json.contains("running"));
        assert!(json.contains("partial text"));
        assert!(json.contains("1700000000"));
    }

    #[test]
    fn chain_progress_omits_none_fields() {
        let progress = ChainProgress {
            total_steps: 2,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![],
            status: ChainStatus::Completed,
        };
        let json = serde_json::to_string(&progress).unwrap();
        assert!(!json.contains("current_step_partial_output"));
        assert!(!json.contains("current_step_started_at"));
    }

    #[test]
    fn chain_progress_empty_partial_output_when_step_starts() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: Some(0),
            current_step_name: Some("plan".into()),
            current_step_partial_output: Some(String::new()),
            current_step_started_at: Some(1700000000),
            completed_steps: vec![],
            status: ChainStatus::Running,
        };
        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("\"current_step_partial_output\":\"\""));
    }

    #[test]
    fn cancelled_status_serializes() {
        let progress = ChainProgress {
            total_steps: 3,
            current_step: None,
            current_step_name: None,
            current_step_partial_output: None,
            current_step_started_at: None,
            completed_steps: vec![
                StepResult {
                    name: "plan".into(),
                    output: "planned".into(),
                    success: true,
                    cost_microdollars: 500,
                    retries_used: 0,
                    skipped: false,
                },
                StepResult {
                    name: "implement".into(),
                    output: String::new(),
                    success: false,
                    cost_microdollars: 0,
                    retries_used: 0,
                    skipped: true,
                },
                StepResult {
                    name: "review".into(),
                    output: String::new(),
                    success: false,
                    cost_microdollars: 0,
                    retries_used: 0,
                    skipped: true,
                },
            ],
            status: ChainStatus::Cancelled,
        };
        let json = serde_json::to_string(&progress).unwrap();
        assert!(json.contains("cancelled"));
        assert!(json.contains("\"skipped\":true"));
    }

    #[test]
    fn skipped_defaults_to_false_on_deserialize() {
        let json =
            r#"{"name":"s","output":"o","success":true,"cost_microdollars":0,"retries_used":0}"#;
        let result: StepResult = serde_json::from_str(json).unwrap();
        assert!(!result.skipped);
    }

    #[test]
    fn extract_json_path_whole_output() {
        let json = r#"{"a": 1}"#;
        assert_eq!(extract_json_path(json, "."), Some(json.to_string()));
        assert_eq!(extract_json_path(json, ""), Some(json.to_string()));
    }

    #[test]
    fn extract_json_path_top_level_key() {
        let json = r#"{"summary": "all good"}"#;
        assert_eq!(
            extract_json_path(json, "summary"),
            Some("all good".to_string())
        );
    }

    #[test]
    fn extract_json_path_nested() {
        let json = r#"{"result": {"count": 42}}"#;
        assert_eq!(
            extract_json_path(json, "result.count"),
            Some("42".to_string())
        );
    }

    #[test]
    fn extract_json_path_not_json() {
        assert_eq!(extract_json_path("not json", "key"), None);
    }

    #[test]
    fn extract_json_path_missing_key() {
        let json = r#"{"a": 1}"#;
        assert_eq!(extract_json_path(json, "b"), None);
    }

    #[test]
    fn expand_step_refs_substitutes() {
        let mut ctx = HashMap::new();
        ctx.insert("plan.summary".into(), "do stuff".into());
        let text = "Based on {steps.plan.summary}, implement it.".to_string();
        assert_eq!(
            expand_step_refs(text, &ctx),
            "Based on do stuff, implement it."
        );
    }

    #[test]
    fn expand_step_refs_unknown_left_as_is() {
        let ctx = HashMap::new();
        let text = "Use {steps.missing.var} here.".to_string();
        assert_eq!(expand_step_refs(text.clone(), &ctx), text);
    }

    #[test]
    fn chain_step_output_vars_defaults_empty() {
        let json = r#"{"name":"s","action":{"type":"prompt","prompt":"hi"}}"#;
        let step: ChainStep = serde_json::from_str(json).unwrap();
        assert!(step.output_vars.is_empty());
    }

    #[test]
    fn chain_step_serializes_output_vars() {
        let mut vars = HashMap::new();
        vars.insert("summary".into(), "result.summary".into());
        let step = ChainStep {
            name: "s".into(),
            action: StepAction::Prompt {
                prompt: "hi".into(),
            },
            config: None,
            failure_policy: StepFailurePolicy::default(),
            output_vars: vars,
        };
        let json = serde_json::to_string(&step).unwrap();
        assert!(json.contains("output_vars"));
        assert!(json.contains("result.summary"));
        let parsed: ChainStep = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.output_vars.get("summary").unwrap(), "result.summary");
    }
}