1pub mod stream;
2
3use std::{path::Path, time::Duration};
4
5use anyhow::{Context, Result};
6use tokio::{io::AsyncWriteExt, process::Command};
7use tracing::warn;
8
9use self::stream::parse_stream;
10use crate::agents::AgentInvocation;
11
/// Maximum number of retries after the initial attempt (so up to
/// `MAX_RETRIES + 1` invocations in total).
const MAX_RETRIES: u32 = 2;

/// Delay slept before retry `n` (0-based).
///
/// The array length is tied to `MAX_RETRIES`, so changing the retry count
/// without supplying a matching delay is a compile error rather than an
/// out-of-bounds panic at the indexing site.
const RETRY_DELAYS: [Duration; MAX_RETRIES as usize] =
    [Duration::from_secs(5), Duration::from_secs(15)];

/// Substrings that mark an error message as transient.
///
/// Entries must be lowercase: they are matched against a lowercased copy of
/// the error message. Bare status codes are wrapped in spaces or parentheses
/// so they do not match inside longer numbers (e.g. a port or record id).
const TRANSIENT_PATTERNS: &[&str] = &[
    "connection reset",
    "connection refused",
    "timed out",
    "timeout",
    "rate limit",
    "rate_limit",
    "http 502",
    "http 503",
    "http 429",
    " 502 ",
    " 503 ",
    " 429 ",
    "(502)",
    "(503)",
    "(429)",
    "overloaded",
    "econnrefused",
];
33
/// Outcome of a single agent run.
///
/// When produced by `RealCommandRunner::run_claude`, every field except
/// `success` is copied from the parsed output stream, and `success` is the
/// exit status of the child process.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Cost of the run as reported by the stream (US dollars, per the name).
    pub cost_usd: f64,
    /// Run duration as reported by the stream.
    pub duration: Duration,
    /// Turn count as reported by the stream.
    pub turns: u32,
    /// Output text of the run.
    pub output: String,
    /// Session identifier reported by the stream.
    pub session_id: String,
    /// Whether the run succeeded.
    pub success: bool,
}
44
/// Captured result of a finished external command.
#[derive(Debug, Clone)]
pub struct CommandOutput {
    /// Everything the command wrote to standard output, as text.
    pub stdout: String,
    /// Everything the command wrote to standard error, as text.
    pub stderr: String,
    /// Whether the command reported success.
    pub success: bool,
}
52
53#[cfg_attr(test, mockall::automock)]
58pub trait CommandRunner: Send + Sync {
59 fn run_claude(
60 &self,
61 prompt: &str,
62 allowed_tools: &[String],
63 working_dir: &Path,
64 max_turns: Option<u32>,
65 ) -> impl std::future::Future<Output = Result<AgentResult>> + Send;
66
67 fn run_gh(
68 &self,
69 args: &[String],
70 working_dir: &Path,
71 ) -> impl std::future::Future<Output = Result<CommandOutput>> + Send;
72}
73
/// [`CommandRunner`] implementation that spawns the real `claude` and `gh`
/// executables as child processes.
pub struct RealCommandRunner;
76
77impl CommandRunner for RealCommandRunner {
78 async fn run_claude(
79 &self,
80 prompt: &str,
81 allowed_tools: &[String],
82 working_dir: &Path,
83 max_turns: Option<u32>,
84 ) -> Result<AgentResult> {
85 let tools_arg = allowed_tools.join(",");
86
87 let mut cmd = Command::new("claude");
88 cmd.args(["-p", "--verbose", "--output-format", "stream-json"])
89 .args(["--allowedTools", &tools_arg]);
90
91 if let Some(turns) = max_turns {
92 cmd.args(["--max-turns", &turns.to_string()]);
93 }
94
95 let mut child = cmd
96 .current_dir(working_dir)
97 .stdin(std::process::Stdio::piped())
98 .stdout(std::process::Stdio::piped())
99 .stderr(std::process::Stdio::piped())
100 .kill_on_drop(true)
101 .spawn()
102 .context("spawning claude")?;
103
104 let mut stdin = child.stdin.take().context("capturing claude stdin")?;
106 stdin.write_all(prompt.as_bytes()).await.context("writing prompt to claude stdin")?;
107 stdin.shutdown().await.context("closing claude stdin")?;
108 drop(stdin);
109
110 let stdout = child.stdout.take().context("capturing claude stdout")?;
111 let result = parse_stream(stdout).await?;
112 let status = child.wait().await.context("waiting for claude")?;
113
114 Ok(AgentResult {
115 cost_usd: result.cost_usd,
116 duration: result.duration,
117 turns: result.turns,
118 output: result.output,
119 session_id: result.session_id,
120 success: status.success(),
121 })
122 }
123
124 async fn run_gh(&self, args: &[String], working_dir: &Path) -> Result<CommandOutput> {
125 let output = Command::new("gh")
126 .args(args)
127 .current_dir(working_dir)
128 .kill_on_drop(true)
129 .output()
130 .await
131 .context("spawning gh")?;
132
133 Ok(CommandOutput {
134 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
135 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
136 success: output.status.success(),
137 })
138 }
139}
140
141pub fn is_transient_error(msg: &str) -> bool {
143 let lower = msg.to_lowercase();
144 TRANSIENT_PATTERNS.iter().any(|p| lower.contains(p))
145}
146
147pub async fn run_with_retry<R: CommandRunner>(
152 runner: &R,
153 invocation: &AgentInvocation,
154) -> Result<AgentResult> {
155 let mut last_err = None;
156 for attempt in 0..=MAX_RETRIES {
157 match crate::agents::invoke_agent(runner, invocation).await {
158 Ok(result) => return Ok(result),
159 Err(e) => {
160 let msg = format!("{e:#}");
161 if attempt < MAX_RETRIES && is_transient_error(&msg) {
162 let delay = RETRY_DELAYS[attempt as usize];
163 warn!(
164 attempt = attempt + 1,
165 max = MAX_RETRIES,
166 delay_secs = delay.as_secs(),
167 error = %msg,
168 "transient agent failure, retrying"
169 );
170 tokio::time::sleep(delay).await;
171 last_err = Some(e);
172 } else {
173 return Err(e);
174 }
175 }
176 }
177 }
178 Err(last_err.unwrap_or_else(|| anyhow::anyhow!("agent invocation failed after retries")))
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: only type-checks when `T` is `Send + Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn agent_result_is_send_sync() {
        assert_send_sync::<AgentResult>();
        assert_send_sync::<CommandOutput>();
    }

    #[test]
    fn real_command_runner_is_send_sync() {
        assert_send_sync::<RealCommandRunner>();
    }

    #[test]
    fn transient_error_detection() {
        let transient = [
            "connection reset by peer",
            "Connection Refused",
            "request timed out after 30s",
            "rate limit exceeded",
            "rate_limit_error",
            "HTTP 502 Bad Gateway",
            "Service Unavailable (503)",
            "HTTP 429 Too Many Requests",
            "server is overloaded",
            "ECONNREFUSED 127.0.0.1:443",
        ];
        for msg in transient {
            assert!(is_transient_error(msg), "expected transient: {msg:?}");
        }
    }

    #[test]
    fn non_transient_errors_not_matched() {
        let permanent = [
            "file not found",
            "permission denied",
            "invalid JSON in response",
            "authentication failed",
            // Status-code digits embedded in longer numbers must not match.
            "listening on port 5029",
            "record id 4291 not found",
            "",
        ];
        for msg in permanent {
            assert!(!is_transient_error(msg), "expected non-transient: {msg:?}");
        }
    }
}