Skip to main content

oven_cli/process/
mod.rs

1pub mod stream;
2
3use std::{path::Path, time::Duration};
4
5use anyhow::{Context, Result};
6use tokio::{io::AsyncWriteExt, process::Command};
7use tracing::warn;
8
9use self::stream::parse_stream;
10use crate::agents::AgentInvocation;
11
/// Number of extra attempts `run_with_retry` makes after the first attempt fails
/// with a transient error (so at most `MAX_RETRIES + 1` attempts in total).
const MAX_RETRIES: u32 = 2;

/// Back-off slept before each retry: entry `i` is slept after failed attempt `i`
/// (0-based), i.e. before retry number `i + 1`.
///
/// The array length is tied to `MAX_RETRIES` so the two cannot drift apart:
/// changing the retry count without supplying a matching delay is a compile
/// error instead of an out-of-bounds panic at runtime.
const RETRY_DELAYS: [Duration; MAX_RETRIES as usize] =
    [Duration::from_secs(5), Duration::from_secs(15)];

/// Substrings that mark an error message as transient.
///
/// `is_transient_error` lower-cases the message before matching, so every entry
/// here must itself be lower-case. The status-code entries are delimited by
/// spaces/parentheses so bare numbers (ports, IDs such as `5029`) do not match.
const TRANSIENT_PATTERNS: &[&str] = &[
    "connection reset",
    "connection refused",
    "timed out",
    "timeout",
    "rate limit",
    "rate_limit",
    "http 502",
    "http 503",
    "http 429",
    " 502 ",
    " 503 ",
    " 429 ",
    "(502)",
    "(503)",
    "(429)",
    "overloaded",
    "econnrefused",
];
33
/// Outcome of a single Claude agent invocation.
#[derive(Clone, Debug)]
pub struct AgentResult {
    /// Cost reported for the run, in USD.
    pub cost_usd: f64,
    /// Duration reported for the run.
    pub duration: Duration,
    /// Number of turns reported for the run.
    pub turns: u32,
    /// Output text extracted from the agent's stream.
    pub output: String,
    /// Session identifier reported for the run.
    pub session_id: String,
    /// Whether the run succeeded; `RealCommandRunner` derives this from the
    /// `claude` process exit status.
    pub success: bool,
}
44
/// Captured result of a plain (non-agent) command such as the `gh` CLI.
#[derive(Clone, Debug)]
pub struct CommandOutput {
    /// Standard output; `RealCommandRunner` decodes it lossily as UTF-8.
    pub stdout: String,
    /// Standard error; `RealCommandRunner` decodes it lossily as UTF-8.
    pub stderr: String,
    /// Whether the process exited with a success status.
    pub success: bool,
}
52
53/// Trait for running external commands.
54///
55/// Enables mocking in tests so we never call real CLIs.
56/// Uses `String` slices rather than `&str` slices for mockall compatibility.
57#[cfg_attr(test, mockall::automock)]
58pub trait CommandRunner: Send + Sync {
59    fn run_claude(
60        &self,
61        prompt: &str,
62        allowed_tools: &[String],
63        working_dir: &Path,
64        max_turns: Option<u32>,
65        model: Option<String>,
66    ) -> impl std::future::Future<Output = Result<AgentResult>> + Send;
67
68    fn run_gh(
69        &self,
70        args: &[String],
71        working_dir: &Path,
72    ) -> impl std::future::Future<Output = Result<CommandOutput>> + Send;
73}
74
/// Production `CommandRunner` that spawns real `claude` and `gh` subprocesses.
///
/// It carries no state, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct RealCommandRunner;
77
78impl CommandRunner for RealCommandRunner {
79    async fn run_claude(
80        &self,
81        prompt: &str,
82        allowed_tools: &[String],
83        working_dir: &Path,
84        max_turns: Option<u32>,
85        model: Option<String>,
86    ) -> Result<AgentResult> {
87        let tools_arg = allowed_tools.join(",");
88
89        let mut cmd = Command::new("claude");
90        cmd.args(["-p", "--verbose", "--output-format", "stream-json"])
91            .args(["--allowedTools", &tools_arg]);
92
93        if let Some(ref model) = model {
94            cmd.args(["--model", model]);
95        }
96
97        if let Some(turns) = max_turns {
98            cmd.args(["--max-turns", &turns.to_string()]);
99        }
100
101        let mut child = cmd
102            .current_dir(working_dir)
103            .stdin(std::process::Stdio::piped())
104            .stdout(std::process::Stdio::piped())
105            .stderr(std::process::Stdio::piped())
106            .kill_on_drop(true)
107            .spawn()
108            .context("spawning claude")?;
109
110        // Pass prompt via stdin to avoid leaking it in process listings (ps aux).
111        let mut stdin = child.stdin.take().context("capturing claude stdin")?;
112        stdin.write_all(prompt.as_bytes()).await.context("writing prompt to claude stdin")?;
113        stdin.shutdown().await.context("closing claude stdin")?;
114        drop(stdin);
115
116        let stdout = child.stdout.take().context("capturing claude stdout")?;
117        let result = parse_stream(stdout).await?;
118        let status = child.wait().await.context("waiting for claude")?;
119
120        Ok(AgentResult {
121            cost_usd: result.cost_usd,
122            duration: result.duration,
123            turns: result.turns,
124            output: result.output,
125            session_id: result.session_id,
126            success: status.success(),
127        })
128    }
129
130    async fn run_gh(&self, args: &[String], working_dir: &Path) -> Result<CommandOutput> {
131        let output = Command::new("gh")
132            .args(args)
133            .current_dir(working_dir)
134            .kill_on_drop(true)
135            .output()
136            .await
137            .context("spawning gh")?;
138
139        Ok(CommandOutput {
140            stdout: String::from_utf8_lossy(&output.stdout).to_string(),
141            stderr: String::from_utf8_lossy(&output.stderr).to_string(),
142            success: output.status.success(),
143        })
144    }
145}
146
147/// Check whether an error message indicates a transient failure worth retrying.
148pub fn is_transient_error(msg: &str) -> bool {
149    let lower = msg.to_lowercase();
150    TRANSIENT_PATTERNS.iter().any(|p| lower.contains(p))
151}
152
153/// Invoke an agent with retry logic for transient failures.
154///
155/// Retries up to `MAX_RETRIES` times (with backoff) when the error message
156/// matches known transient patterns (connection resets, rate limits, 5xx, etc.).
157pub async fn run_with_retry<R: CommandRunner>(
158    runner: &R,
159    invocation: &AgentInvocation,
160) -> Result<AgentResult> {
161    let mut last_err = None;
162    for attempt in 0..=MAX_RETRIES {
163        match crate::agents::invoke_agent(runner, invocation).await {
164            Ok(result) => return Ok(result),
165            Err(e) => {
166                let msg = format!("{e:#}");
167                if attempt < MAX_RETRIES && is_transient_error(&msg) {
168                    let delay = RETRY_DELAYS[attempt as usize];
169                    warn!(
170                        attempt = attempt + 1,
171                        max = MAX_RETRIES,
172                        delay_secs = delay.as_secs(),
173                        error = %msg,
174                        "transient agent failure, retrying"
175                    );
176                    tokio::time::sleep(delay).await;
177                    last_err = Some(e);
178                } else {
179                    return Err(e);
180                }
181            }
182        }
183    }
184    Err(last_err.unwrap_or_else(|| anyhow::anyhow!("agent invocation failed after retries")))
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time probe: only instantiable for types that are `Send + Sync`.
    fn require_send_sync<T: Send + Sync>() {}

    #[test]
    fn agent_result_is_send_sync() {
        require_send_sync::<AgentResult>();
        require_send_sync::<CommandOutput>();
    }

    #[test]
    fn real_command_runner_is_send_sync() {
        require_send_sync::<RealCommandRunner>();
    }

    #[test]
    fn transient_error_detection() {
        let transient = [
            "connection reset by peer",
            "Connection Refused",
            "request timed out after 30s",
            "rate limit exceeded",
            "rate_limit_error",
            "HTTP 502 Bad Gateway",
            "Service Unavailable (503)",
            "HTTP 429 Too Many Requests",
            "server is overloaded",
            "ECONNREFUSED 127.0.0.1:443",
        ];
        for msg in transient {
            assert!(is_transient_error(msg), "expected {msg:?} to be transient");
        }
    }

    #[test]
    fn non_transient_errors_not_matched() {
        let permanent = [
            "file not found",
            "permission denied",
            "invalid JSON in response",
            "authentication failed",
            // Bare numbers should not match (e.g. port numbers, IDs)
            "listening on port 5029",
            "record id 4291 not found",
            "",
        ];
        for msg in permanent {
            assert!(!is_transient_error(msg), "expected {msg:?} to be non-transient");
        }
    }
}