1pub mod stream;
2
3use std::{path::Path, time::Duration};
4
5use anyhow::{Context, Result};
6use tokio::{io::AsyncWriteExt, process::Command};
7use tracing::warn;
8
9use self::stream::parse_stream;
10use crate::agents::AgentInvocation;
11
/// Backoff applied before each retry of a transient agent failure: entry `i`
/// is the delay slept before retry `i + 1`.
const RETRY_DELAYS: [Duration; 2] = [Duration::from_secs(5), Duration::from_secs(15)];

/// Number of retries allowed after the first attempt.
///
/// Derived from [`RETRY_DELAYS`] so every retry has a configured delay and the
/// two constants cannot drift apart.
const MAX_RETRIES: u32 = RETRY_DELAYS.len() as u32;

/// Substrings that mark an error message as transient (worth retrying).
///
/// `is_transient_error` lowercases the message before matching, so every entry
/// here must be lowercase. Bare status codes are wrapped in spaces or
/// parentheses so that numbers like ports or record ids (e.g. `5029`, `4291`)
/// do not match.
const TRANSIENT_PATTERNS: &[&str] = &[
    "connection reset",
    "connection refused",
    "timed out",
    "timeout",
    "rate limit",
    "rate_limit",
    "http 502",
    "http 503",
    "http 429",
    " 502 ",
    " 503 ",
    " 429 ",
    "(502)",
    "(503)",
    "(429)",
    "overloaded",
    // errno-style codes, as printed by e.g. Node-based tooling.
    "econnrefused",
    "econnreset",
    "etimedout",
];
33
/// Outcome of a single agent run.
///
/// In the real runner, every field except `success` is copied from the result
/// parsed out of the CLI's `stream-json` output; `success` reflects the
/// process exit status.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct AgentResult {
    /// Cost of the run as reported in the output stream (USD, per the field name).
    pub cost_usd: f64,
    /// Duration of the run as reported in the output stream.
    pub duration: Duration,
    /// Number of turns as reported in the output stream.
    pub turns: u32,
    /// Output text extracted from the stream.
    pub output: String,
    /// Session identifier reported in the stream.
    pub session_id: String,
    /// Whether the process exited with a success status (not whether parsing succeeded).
    pub success: bool,
}
44
/// Captured result of a plain (non-agent) command invocation such as `gh`.
///
/// The real runner decodes both streams lossily as UTF-8.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CommandOutput {
    /// Everything the command wrote to standard output.
    pub stdout: String,
    /// Everything the command wrote to standard error.
    pub stderr: String,
    /// Whether the command exited with a success status.
    pub success: bool,
}
52
/// Abstraction over the external CLIs this module shells out to (`claude`, `gh`).
///
/// Under `cfg(test)`, `mockall::automock` generates a mock implementation of
/// this trait so callers can be tested without spawning real processes.
/// Implementors must be `Send + Sync`, and every returned future must be `Send`.
#[cfg_attr(test, mockall::automock)]
pub trait CommandRunner: Send + Sync {
    /// Runs an agent with `prompt` in `working_dir`.
    ///
    /// `allowed_tools` lists the tools the agent may use. `max_turns` and
    /// `model` are optional overrides; in `RealCommandRunner`, `None` means the
    /// corresponding CLI flag is simply not passed.
    fn run_claude(
        &self,
        prompt: &str,
        allowed_tools: &[String],
        working_dir: &Path,
        max_turns: Option<u32>,
        model: Option<String>,
    ) -> impl std::future::Future<Output = Result<AgentResult>> + Send;

    /// Runs `gh` with `args` in `working_dir` and returns its captured output.
    fn run_gh(
        &self,
        args: &[String],
        working_dir: &Path,
    ) -> impl std::future::Future<Output = Result<CommandOutput>> + Send;
}
74
/// [`CommandRunner`] that spawns the real `claude` and `gh` executables.
#[derive(Debug, Clone, Copy, Default)]
pub struct RealCommandRunner;
77
78impl CommandRunner for RealCommandRunner {
79 async fn run_claude(
80 &self,
81 prompt: &str,
82 allowed_tools: &[String],
83 working_dir: &Path,
84 max_turns: Option<u32>,
85 model: Option<String>,
86 ) -> Result<AgentResult> {
87 let tools_arg = allowed_tools.join(",");
88
89 let mut cmd = Command::new("claude");
90 cmd.args(["-p", "--verbose", "--output-format", "stream-json"])
91 .args(["--allowedTools", &tools_arg]);
92
93 if let Some(ref model) = model {
94 cmd.args(["--model", model]);
95 }
96
97 if let Some(turns) = max_turns {
98 cmd.args(["--max-turns", &turns.to_string()]);
99 }
100
101 let mut child = cmd
102 .current_dir(working_dir)
103 .stdin(std::process::Stdio::piped())
104 .stdout(std::process::Stdio::piped())
105 .stderr(std::process::Stdio::piped())
106 .kill_on_drop(true)
107 .spawn()
108 .context("spawning claude")?;
109
110 let mut stdin = child.stdin.take().context("capturing claude stdin")?;
112 stdin.write_all(prompt.as_bytes()).await.context("writing prompt to claude stdin")?;
113 stdin.shutdown().await.context("closing claude stdin")?;
114 drop(stdin);
115
116 let stdout = child.stdout.take().context("capturing claude stdout")?;
117 let result = parse_stream(stdout).await?;
118 let status = child.wait().await.context("waiting for claude")?;
119
120 Ok(AgentResult {
121 cost_usd: result.cost_usd,
122 duration: result.duration,
123 turns: result.turns,
124 output: result.output,
125 session_id: result.session_id,
126 success: status.success(),
127 })
128 }
129
130 async fn run_gh(&self, args: &[String], working_dir: &Path) -> Result<CommandOutput> {
131 let output = Command::new("gh")
132 .args(args)
133 .current_dir(working_dir)
134 .kill_on_drop(true)
135 .output()
136 .await
137 .context("spawning gh")?;
138
139 Ok(CommandOutput {
140 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
141 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
142 success: output.status.success(),
143 })
144 }
145}
146
147pub fn is_transient_error(msg: &str) -> bool {
149 let lower = msg.to_lowercase();
150 TRANSIENT_PATTERNS.iter().any(|p| lower.contains(p))
151}
152
153pub async fn run_with_retry<R: CommandRunner>(
158 runner: &R,
159 invocation: &AgentInvocation,
160) -> Result<AgentResult> {
161 let mut last_err = None;
162 for attempt in 0..=MAX_RETRIES {
163 match crate::agents::invoke_agent(runner, invocation).await {
164 Ok(result) => return Ok(result),
165 Err(e) => {
166 let msg = format!("{e:#}");
167 if attempt < MAX_RETRIES && is_transient_error(&msg) {
168 let delay = RETRY_DELAYS[attempt as usize];
169 warn!(
170 attempt = attempt + 1,
171 max = MAX_RETRIES,
172 delay_secs = delay.as_secs(),
173 error = %msg,
174 "transient agent failure, retrying"
175 );
176 tokio::time::sleep(delay).await;
177 last_err = Some(e);
178 } else {
179 return Err(e);
180 }
181 }
182 }
183 }
184 Err(last_err.unwrap_or_else(|| anyhow::anyhow!("agent invocation failed after retries")))
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `T` can be shared and sent across threads.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn agent_result_is_send_sync() {
        assert_send_sync::<AgentResult>();
        assert_send_sync::<CommandOutput>();
    }

    #[test]
    fn real_command_runner_is_send_sync() {
        assert_send_sync::<RealCommandRunner>();
    }

    #[test]
    fn every_retry_has_a_delay() {
        // run_with_retry picks the backoff for retry `i` out of RETRY_DELAYS, so
        // the table must have at least one entry per allowed retry.
        assert!(RETRY_DELAYS.len() >= MAX_RETRIES as usize);
    }

    #[test]
    fn transient_error_detection() {
        let transient = [
            "connection reset by peer",
            "Connection Refused",
            "request timed out after 30s",
            "rate limit exceeded",
            "rate_limit_error",
            "HTTP 502 Bad Gateway",
            "Service Unavailable (503)",
            "HTTP 429 Too Many Requests",
            "server is overloaded",
            "ECONNREFUSED 127.0.0.1:443",
        ];
        for msg in transient {
            assert!(is_transient_error(msg), "expected transient: {msg:?}");
        }
    }

    #[test]
    fn non_transient_errors_not_matched() {
        let permanent = [
            "file not found",
            "permission denied",
            "invalid JSON in response",
            "authentication failed",
            // Status-code digits embedded in larger numbers must not match.
            "listening on port 5029",
            "record id 4291 not found",
            "",
        ];
        for msg in permanent {
            assert!(!is_transient_error(msg), "expected non-transient: {msg:?}");
        }
    }
}