use crate::abstractions::{ClaudeClient, RealClaudeClient};
use anyhow::{Context, Result};
use once_cell::sync::Lazy;
use std::sync::Arc;
use std::time::Duration;
use tokio::process::Command;
use tokio::time::sleep;
/// Process-wide Claude client used by [`check_claude_cli`].
///
/// Built lazily on first access and stored behind an `Arc` so it can be shared.
static CLAUDE_CLIENT: Lazy<Arc<RealClaudeClient>> = Lazy::new(|| {
    let client = RealClaudeClient::new();
    Arc::new(client)
});
/// Runs `command`, re-running it when a failure looks transient.
///
/// The command is executed at most `max_retries + 1` times; every attempt after
/// the first is preceded by the backoff sleep in `await_retry_delay`.
///
/// A process that ran to completion is returned as `Ok` — even with a non-zero
/// exit status — unless `should_retry_output` classifies it as transient while
/// attempts remain, in which case it is run again. The output of the final
/// attempt is returned as `Ok` as well, so callers must still check
/// `output.status`.
///
/// # Errors
///
/// Returns an error when the executable is not found (never retried), or when
/// spawning the process fails with an I/O error and no attempts remain.
pub async fn execute_with_retry(
    mut command: Command,
    description: &str,
    max_retries: u32,
    verbose: bool,
) -> Result<std::process::Output> {
    let mut last_error: Option<String> = None;

    // A bounded `for` makes termination structural: every pass consumes exactly
    // one attempt, whatever the helpers decide.
    for attempt in 0..=max_retries {
        if attempt > 0 {
            await_retry_delay(attempt, description, max_retries, verbose).await;
        }

        match command.output().await {
            Ok(output) => match should_retry_output(&output, attempt, max_retries) {
                None => return Ok(output),
                Some(retry_reason) => {
                    if verbose {
                        eprintln!("⚠️ Transient error detected: {retry_reason}");
                    }
                    last_error = Some(retry_reason);
                }
            },
            Err(e) => {
                // Fatal I/O errors propagate via `?`; a retryable one yields its message.
                if let Some(error_msg) =
                    handle_command_error(e, description, attempt, max_retries, verbose)?
                {
                    last_error = Some(error_msg);
                }
            }
        }
    }

    // Defensive fallback: the helpers never request a retry on the final
    // attempt, so this is only reached if that contract changes.
    Err(anyhow::anyhow!(
        "Failed {} after {} retries. Last error: {}",
        description,
        max_retries,
        last_error.unwrap_or_else(|| "Unknown error".to_string())
    ))
}
/// Sleeps before retry number `attempt`, using capped exponential backoff.
///
/// The delay is `2^attempt` seconds with the exponent capped at 3, so retries
/// 1, 2, 3, ... wait 2s, 4s, then 8s for every later attempt. When `verbose`
/// is set, a progress line is written to stderr — the same stream used by the
/// other retry diagnostics in this module — rather than to stdout.
async fn await_retry_delay(attempt: u32, description: &str, max_retries: u32, verbose: bool) {
    // Cap the exponent so a single wait never exceeds 8 seconds.
    let delay = Duration::from_secs(2u64.pow(attempt.min(3)));
    if verbose {
        eprintln!("⏳ Retrying {description} after {delay:?} (attempt {attempt}/{max_retries})");
    }
    sleep(delay).await;
}
fn should_retry_output(
output: &std::process::Output,
attempt: u32,
max_retries: u32,
) -> Option<String> {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
if is_transient_error(&stderr) && attempt < max_retries {
return Some(stderr.lines().next().unwrap_or("Unknown error").to_string());
}
}
None
}
fn handle_command_error(
error: std::io::Error,
description: &str,
attempt: u32,
max_retries: u32,
verbose: bool,
) -> Result<Option<String>> {
if error.kind() == std::io::ErrorKind::NotFound {
return Err(error).context(format!("Command not found for {description}"));
}
if attempt < max_retries {
if verbose {
eprintln!("⚠️ IO error: {error}");
}
return Ok(Some(error.to_string()));
}
Err(error).context(format!("Failed to execute {description}"))
}
/// Returns `true` when `stderr` looks like a transient failure worth retrying.
///
/// Matching is case-insensitive. Textual markers such as "rate limit" or
/// "timeout" may appear anywhere in the text.
///
/// The HTTP status codes `429` and `503` only count when they stand alone as a
/// number. They do not count inside a longer digit run such as "line 1503" or
/// "pid 42900", which would otherwise trigger spurious retries.
pub fn is_transient_error(stderr: &str) -> bool {
    const TEXT_PATTERNS: &[&str] = &[
        "rate limit",
        "timeout",
        "connection refused",
        "temporary failure",
        "network",
        "could not connect",
        "broken pipe",
    ];
    const STATUS_CODES: &[&str] = &["503", "429"];

    // True when `needle` occurs in `haystack` with no ASCII digit directly
    // before or after it.
    fn contains_standalone_number(haystack: &str, needle: &str) -> bool {
        haystack.match_indices(needle).any(|(start, found)| {
            let before = haystack[..start].chars().next_back();
            let after = haystack[start + found.len()..].chars().next();
            !matches!(before, Some(c) if c.is_ascii_digit())
                && !matches!(after, Some(c) if c.is_ascii_digit())
        })
    }

    let stderr_lower = stderr.to_lowercase();
    TEXT_PATTERNS
        .iter()
        .any(|pattern| stderr_lower.contains(pattern))
        || STATUS_CODES
            .iter()
            .any(|code| contains_standalone_number(&stderr_lower, code))
}
/// Verifies that the Claude CLI can be used, via the process-wide client.
///
/// # Errors
///
/// Returns whatever error the client's `check_availability` reports.
pub async fn check_claude_cli() -> Result<()> {
    let client = &*CLAUDE_CLIENT;
    client.check_availability().await
}
/// Builds a human-readable description of a failed subprocess.
///
/// The message always names `command` and includes the exit code when one is
/// known. It then adds the trimmed stderr. stdout is only included when stderr
/// is blank, since it is then the only diagnostic available.
///
/// Finally, at most one hint is appended for a few well-known failure modes
/// (authentication, missing command, rate limiting). These are detected
/// case-insensitively, so messages such as the OS's "Permission denied" or an
/// HTTP "401 Unauthorized" still produce the hint.
pub fn format_subprocess_error(
    command: &str,
    exit_code: Option<i32>,
    stderr: &str,
    stdout: &str,
) -> String {
    use std::fmt::Write as _;

    let stderr_trimmed = stderr.trim();
    let stdout_trimmed = stdout.trim();

    // Writing into a `String` cannot fail, so the `fmt::Result`s are ignored.
    let mut error_msg = format!("Command '{command}' failed");
    if let Some(code) = exit_code {
        let _ = write!(error_msg, " with exit code {code}");
    }
    if !stderr_trimmed.is_empty() {
        let _ = write!(error_msg, "\n\nError output:\n{stderr_trimmed}");
    } else if !stdout_trimmed.is_empty() {
        let _ = write!(error_msg, "\n\nOutput:\n{stdout_trimmed}");
    }

    let stderr_lower = stderr.to_lowercase();
    if stderr_lower.contains("permission denied") || stderr_lower.contains("unauthorized") {
        error_msg.push_str("\n\nHint: Check that you have authenticated with 'claude auth'");
    } else if stderr_lower.contains("not found") && stderr_lower.contains("command") {
        let _ = write!(
            error_msg,
            "\n\nHint: The '{command}' command may not be installed or not in PATH"
        );
    } else if stderr_lower.contains("rate limit") {
        error_msg.push_str(
            "\n\nHint: You may have hit the API rate limit. Please wait a moment and try again.",
        );
    }
    error_msg
}
// Unit tests for the pure helpers, plus process-spawning tests for
// `execute_with_retry`. The async tests run real `echo`, `sh` and `which`
// binaries, so they assume a Unix-like host with those on PATH.
#[cfg(test)]
mod tests {
    use super::*;

    // Each pattern in the transient list is recognised; ordinary
    // non-transient failures are not.
    #[test]
    fn test_transient_error_detection() {
        assert!(is_transient_error("Error: rate limit exceeded"));
        assert!(is_transient_error("Connection timeout"));
        assert!(is_transient_error("HTTP 503 Service Unavailable"));
        assert!(is_transient_error("Error 429: Too Many Requests"));
        assert!(is_transient_error("Connection refused by server"));
        assert!(is_transient_error("Temporary failure in name resolution"));
        assert!(is_transient_error("Network is unreachable"));
        assert!(is_transient_error("Could not connect to API"));
        assert!(is_transient_error("Broken pipe error"));
        assert!(!is_transient_error("Syntax error in file"));
        assert!(!is_transient_error("Command not found"));
        assert!(!is_transient_error("Invalid argument"));
    }

    // `is_transient_error` lowercases its input before matching.
    #[test]
    fn test_transient_error_case_insensitive() {
        assert!(is_transient_error("RATE LIMIT EXCEEDED"));
        assert!(is_transient_error("Rate Limit Exceeded"));
        assert!(is_transient_error("RaTe LiMiT"));
    }

    // Permission-denied stderr yields the exit code, the stderr text and the
    // `claude auth` hint.
    #[test]
    fn test_error_formatting() {
        let error = format_subprocess_error("claude", Some(1), "Error: permission denied", "");
        assert!(error.contains("exit code 1"));
        assert!(error.contains("permission denied"));
        assert!(error.contains("claude auth"));
    }

    // Without an exit code the " with exit code" clause is omitted.
    #[test]
    fn test_error_formatting_no_exit_code() {
        let error = format_subprocess_error("claude", None, "Something went wrong", "");
        assert!(error.contains("Command 'claude' failed"));
        assert!(error.contains("Something went wrong"));
        assert!(!error.contains("exit code"));
    }

    // stdout is reported under "Output:" only because stderr is empty.
    #[test]
    fn test_error_formatting_with_stdout_only() {
        let error = format_subprocess_error("claude", Some(1), "", "Error in stdout");
        assert!(error.contains("Error in stdout"));
        assert!(error.contains("Output:"));
    }

    #[test]
    fn test_error_formatting_command_not_found() {
        let error = format_subprocess_error("unknown-cmd", Some(127), "command not found", "");
        assert!(error.contains("may not be installed or not in PATH"));
    }

    #[test]
    fn test_error_formatting_rate_limit() {
        let error = format_subprocess_error("claude", Some(1), "Error: rate limit exceeded", "");
        assert!(error.contains("You may have hit the API rate limit"));
    }

    // Whitespace-only stderr/stdout are trimmed to empty, so neither the
    // "Error output:" nor the "Output:" section is emitted.
    #[test]
    fn test_error_formatting_empty_outputs() {
        let error = format_subprocess_error("claude", Some(1), " ", " ");
        assert!(error.contains("Command 'claude' failed"));
        assert!(!error.contains("Error output:"));
        assert!(!error.contains("Output:"));
    }

    // A successful command is returned on the first attempt with its stdout.
    #[tokio::test]
    async fn test_execute_with_retry_success() {
        let mut cmd = Command::new("echo");
        cmd.arg("hello");
        let output = execute_with_retry(cmd, "echo test", 3, false)
            .await
            .unwrap();
        assert!(output.status.success());
        assert_eq!(String::from_utf8_lossy(&output.stdout).trim(), "hello");
    }

    // A missing executable is reported immediately as an Err carrying the
    // "Command not found" context; it is not retried.
    #[tokio::test]
    async fn test_execute_with_retry_command_not_found() {
        let cmd = Command::new("this-command-does-not-exist");
        let result = execute_with_retry(cmd, "nonexistent command", 3, false).await;
        assert!(result.is_err());
        let error = result.unwrap_err().to_string();
        assert!(error.contains("Command not found"));
    }

    // A non-zero exit with non-transient stderr is handed back as Ok(output)
    // for the caller to inspect, not retried and not turned into an Err.
    #[tokio::test]
    async fn test_execute_with_retry_non_transient_failure() {
        let mut cmd = Command::new("sh");
        cmd.arg("-c").arg("echo 'Fatal error' >&2; exit 1");
        let output = execute_with_retry(cmd, "failing command", 3, false)
            .await
            .unwrap();
        assert!(!output.status.success());
        assert!(String::from_utf8_lossy(&output.stderr).contains("Fatal error"));
    }

    // Transient stderr on every attempt: with max_retries = 2 the command is
    // run three times, and the last attempt's failing output is still
    // returned as Ok. Runs with verbose = true to exercise the log paths.
    #[tokio::test]
    async fn test_execute_with_retry_exhausted_retries() {
        let mut cmd = Command::new("sh");
        cmd.arg("-c").arg("echo 'connection refused' >&2; exit 1");
        let output = execute_with_retry(cmd, "transient error test", 2, true)
            .await
            .unwrap();
        assert!(!output.status.success());
        let stderr = String::from_utf8_lossy(&output.stderr);
        assert!(stderr.contains("connection refused"));
    }

    // Two retries back off for 2s and then 4s (`await_retry_delay`), so at
    // least 6s must elapse. This test is intentionally slow.
    #[tokio::test]
    async fn test_execute_with_retry_exponential_backoff() {
        use std::time::Instant;
        let mut cmd = Command::new("sh");
        cmd.arg("-c").arg("echo 'rate limit' >&2; exit 1");
        let start = Instant::now();
        let output = execute_with_retry(cmd, "backoff test", 2, false)
            .await
            .unwrap();
        let elapsed = start.elapsed();
        assert!(!output.status.success());
        let stderr = String::from_utf8_lossy(&output.stderr);
        assert!(stderr.contains("rate limit"));
        assert!(elapsed.as_secs() >= 6);
    }

    // NOTE(review): despite its name, this never calls `check_claude_cli`; it
    // only checks that `which` fails for an unknown binary.
    #[tokio::test]
    async fn test_check_claude_cli_when_missing() {
        let mut cmd = Command::new("which");
        cmd.arg("nonexistent-command-xyz");
        let output = cmd.output().await.unwrap();
        assert!(!output.status.success());
    }

    // When stderr has content it is reported in full (all lines) and stdout
    // is left out of the message.
    #[test]
    fn test_format_subprocess_error_all_fields() {
        let error = format_subprocess_error(
            "test-cmd",
            Some(42),
            "Error details\nMultiline error",
            "Some stdout content",
        );
        assert!(error.contains("'test-cmd'"));
        assert!(error.contains("exit code 42"));
        assert!(error.contains("Error details"));
        assert!(error.contains("Multiline error"));
        assert!(!error.contains("Some stdout content"));
    }

    #[test]
    fn test_format_subprocess_error_unauthorized() {
        let error = format_subprocess_error("claude", Some(1), "Error: unauthorized access", "");
        assert!(error.contains("Check that you have authenticated"));
        assert!(error.contains("unauthorized access"));
        assert!(error.contains("claude"));
    }

    // NOTE(review): these cases only assert that stderr is echoed back into
    // the message; none of them assert that a hint is appended.
    #[test]
    fn test_format_subprocess_error_unauthorized_variations() {
        let test_cases = vec![
            ("Error: 401 Unauthorized", "401 Unauthorized"),
            ("API key invalid", "API key"),
            ("Authentication failed", "Authentication"),
        ];
        for (stderr, expected) in test_cases {
            let error = format_subprocess_error("claude", Some(1), stderr, "");
            assert!(error.contains(expected));
        }
    }

    // Patterns are matched anywhere inside a longer message.
    #[test]
    fn test_is_transient_error_partial_matches() {
        assert!(is_transient_error("Error occurred: rate limit hit"));
        assert!(is_transient_error("Request timeout after 30s"));
        assert!(is_transient_error("Could not connect to host"));
        assert!(is_transient_error("Error: broken pipe while sending"));
    }

    // max_retries = 1: one retry (after a 2s backoff), then the still-failing
    // output is returned as Ok.
    #[tokio::test]
    async fn test_execute_with_retry_max_attempts_reached() {
        let mut cmd = Command::new("sh");
        cmd.arg("-c").arg("echo 'temporary failure' >&2; exit 1");
        let output = execute_with_retry(cmd, "max retries test", 1, false)
            .await
            .unwrap();
        assert!(!output.status.success());
        let stderr = String::from_utf8_lossy(&output.stderr);
        assert!(stderr.contains("temporary failure"));
    }

    // Smoke test: the outcome depends on whether the CLI is available on the
    // host, so only a non-panicking call is checked.
    #[tokio::test]
    async fn test_check_claude_cli_fallback() {
        let result = check_claude_cli().await;
        let _ = result;
    }
}
// Further edge-case tests for the retry and formatting helpers. Like `tests`,
// the async cases spawn real `sh`, `echo` and `sleep` processes and assume a
// Unix-like host.
#[cfg(test)]
mod additional_tests {
    use super::*;

    // NOTE(review): `sh -c 'exit 0'` succeeds on the first attempt, so no
    // I/O-error recovery path is exercised despite the test's name.
    #[tokio::test]
    async fn test_execute_with_retry_io_error_recovery() {
        let mut cmd = Command::new("sh");
        cmd.arg("-c").arg("exit 0");
        let output = execute_with_retry(cmd, "io test", 1, false).await.unwrap();
        assert!(output.status.success());
    }

    // Each hint category is triggered by its keyword. This test awaits
    // nothing, so a plain `#[test]` would suffice.
    #[tokio::test]
    async fn test_format_subprocess_error_all_hints() {
        let patterns = vec![
            ("permission denied in /usr/local", "claude auth"),
            ("command not found: xyz", "may not be installed"),
            ("rate limit exceeded for API", "hit the API rate limit"),
        ];
        for (stderr, expected_hint) in patterns {
            let error = format_subprocess_error("test", Some(1), stderr, "");
            assert!(
                error.contains(expected_hint),
                "Expected hint '{expected_hint}' for error '{stderr}'"
            );
        }
    }

    // Empty and whitespace-only input is not transient; a pattern after a long
    // prefix, or several patterns together, still match.
    #[test]
    fn test_is_transient_error_edge_cases() {
        assert!(!is_transient_error(""));
        assert!(!is_transient_error(" \n "));
        let long_error = format!("{}rate limit", "x".repeat(1000));
        assert!(is_transient_error(&long_error));
        assert!(is_transient_error("timeout and connection refused"));
    }

    // Verbose mode on a first-try success must not change the result.
    #[tokio::test]
    async fn test_execute_with_retry_verbose_output() {
        let mut cmd = Command::new("echo");
        cmd.arg("verbose test");
        let output = execute_with_retry(cmd, "verbose test", 1, true)
            .await
            .unwrap();
        assert!(output.status.success());
    }

    // Multi-line stderr is kept; stdout is dropped because stderr is non-empty.
    #[test]
    fn test_format_subprocess_error_with_newlines() {
        let error = format_subprocess_error(
            "multi-line",
            Some(2),
            "Error:\n Line 1\n Line 2\n",
            "Output:\n Data 1\n Data 2\n",
        );
        assert!(error.contains("Line 1"));
        assert!(error.contains("Line 2"));
        assert!(!error.contains("Data 1")); }

    // NOTE(review): `execute_with_retry` applies no timeout, so `sleep 10`
    // runs to completion and returns Ok. The assertion is therefore always
    // satisfied, and the test costs ~10s.
    #[tokio::test]
    async fn test_execute_with_retry_network_timeout() {
        let mut cmd = Command::new("sleep");
        cmd.arg("10");
        let start = std::time::Instant::now();
        let result = execute_with_retry(cmd, "timeout test", 2, false).await;
        assert!(result.is_ok() || start.elapsed().as_secs() > 5);
    }

    // NOTE(review): the spawned task only sleeps and never signals the child,
    // so the TERM trap never fires. `sleep 10` exits normally and the call
    // returns Ok after ~10s; no signal handling is actually tested.
    #[tokio::test]
    async fn test_execute_with_retry_signal_interruption() {
        let mut cmd = Command::new("sh");
        cmd.arg("-c").arg("trap 'exit 1' TERM; sleep 10");
        tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(100)).await;
        });
        let result = execute_with_retry(cmd, "signal test", 3, true).await;
        assert!(result.is_ok());
    }

    // Only when the availability check fails on this host is its error text
    // inspected; on hosts where it succeeds the test asserts nothing.
    #[tokio::test]
    async fn test_check_claude_cli_error_cases() {
        let result = check_claude_cli().await;
        if result.is_err() {
            let error_msg = result.unwrap_err().to_string();
            assert!(
                error_msg.contains("claude") || error_msg.contains("install"),
                "Error message should mention 'claude' or 'install': {error_msg}"
            );
        }
    }
}