use std::time::Duration;
use thiserror::Error;
use crate::agent::AgentResult;
#[derive(Debug, Error)]
pub enum AgentError {
#[error("max turns ({turns}) reached without completion")]
MaxTurnsReached {
turns: usize,
partial: Box<AgentResult>,
},
#[error("provider error: {source}")]
Provider {
#[source]
source: ProviderError,
partial: Box<AgentResult>,
},
#[error("cancelled")]
Cancelled { partial: Box<AgentResult> },
#[error("tool '{tool_name}' failed: {source}")]
Tool {
tool_name: String,
#[source]
source: ToolError,
partial: Box<AgentResult>,
},
}
impl AgentError {
pub fn partial(&self) -> &AgentResult {
match self {
AgentError::MaxTurnsReached { partial, .. }
| AgentError::Provider { partial, .. }
| AgentError::Cancelled { partial }
| AgentError::Tool { partial, .. } => partial,
}
}
pub fn into_partial(self) -> AgentResult {
match self {
AgentError::MaxTurnsReached { partial, .. }
| AgentError::Provider { partial, .. }
| AgentError::Cancelled { partial }
| AgentError::Tool { partial, .. } => *partial,
}
}
}
#[derive(Debug, Error)]
pub enum ProviderError {
#[error("HTTP error: {0}")]
Http(#[from] reqwest::Error),
#[error("API error ({status}): {message}")]
Api {
status: u16,
message: String,
retryable: bool,
},
#[error("overloaded (retry after: {retry_after_ms:?}ms)")]
Overloaded { retry_after_ms: Option<u64> },
#[error("rate limited (retry after: {retry_after_ms:?}ms)")]
RateLimit { retry_after_ms: Option<u64> },
#[error("deserialization error: {0}")]
Deserialization(#[from] serde_json::Error),
#[error("batch not ready (status: {status})")]
BatchNotReady { status: String },
#[error("{0}")]
Other(String),
}
impl ProviderError {
    /// Reports whether the failed request is worth retrying.
    ///
    /// The rules, per variant:
    /// - Overload and rate-limit responses are always retryable.
    /// - API errors defer to their `retryable` flag.
    /// - Transport errors are classified by `is_transient_reqwest`.
    /// - Deserialization, batch-not-ready and `Other` errors are never retryable.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Overloaded { .. } | Self::RateLimit { .. } => true,
            Self::Api { retryable, .. } => *retryable,
            Self::Http(err) => is_transient_reqwest(err),
            Self::Deserialization(_) | Self::BatchNotReady { .. } | Self::Other(_) => false,
        }
    }
    /// Returns the back-off hint carried by this error, if any.
    ///
    /// Only `Overloaded` and `RateLimit` can carry a hint (in milliseconds);
    /// every other variant yields `None`.
    pub fn retry_after(&self) -> Option<Duration> {
        let hint_ms = match self {
            Self::Overloaded { retry_after_ms } | Self::RateLimit { retry_after_ms } => {
                *retry_after_ms
            }
            _ => None,
        };
        hint_ms.map(Duration::from_millis)
    }
}
fn is_transient_reqwest(e: &reqwest::Error) -> bool {
if e.is_decode() || e.is_builder() || e.is_redirect() {
return false;
}
e.is_timeout() || e.is_connect() || e.is_body() || e.is_request()
}
#[derive(Debug, Error)]
pub enum ToolError {
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("invalid input: {0}")]
InvalidInput(String),
#[error("cancelled")]
Cancelled,
#[error("{0}")]
Execution(String),
}
#[cfg(test)]
mod tests {
    // Pins the retry classification and retry-hint extraction of `ProviderError`.
    use super::*;
    #[test]
    fn api_retryable_flag_respected() {
        let retryable = ProviderError::Api {
            status: 500,
            message: "internal".into(),
            retryable: true,
        };
        assert!(retryable.is_retryable());
        assert!(retryable.retry_after().is_none());
        let fatal = ProviderError::Api {
            status: 400,
            message: "bad request".into(),
            retryable: false,
        };
        assert!(!fatal.is_retryable());
        // API errors never carry a back-off hint, whatever their flag says.
        assert!(fatal.retry_after().is_none());
    }
    #[test]
    fn overloaded_always_retryable() {
        let with_hint = ProviderError::Overloaded {
            retry_after_ms: Some(5_000),
        };
        assert!(with_hint.is_retryable());
        assert_eq!(with_hint.retry_after(), Some(Duration::from_millis(5_000)));
        let without = ProviderError::Overloaded {
            retry_after_ms: None,
        };
        assert!(without.is_retryable());
        assert_eq!(without.retry_after(), None);
    }
    #[test]
    fn rate_limit_always_retryable() {
        let rl = ProviderError::RateLimit {
            retry_after_ms: Some(1_500),
        };
        assert!(rl.is_retryable());
        assert_eq!(rl.retry_after(), Some(Duration::from_millis(1_500)));
        // Mirror the overloaded test: a missing hint must not affect retryability.
        let without = ProviderError::RateLimit {
            retry_after_ms: None,
        };
        assert!(without.is_retryable());
        assert_eq!(without.retry_after(), None);
    }
    #[test]
    fn deserialization_never_retryable() {
        let err: serde_json::Error = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let e = ProviderError::Deserialization(err);
        assert!(!e.is_retryable());
        assert!(e.retry_after().is_none());
    }
    #[test]
    fn batch_not_ready_never_retryable() {
        let e = ProviderError::BatchNotReady {
            status: "in_progress".into(),
        };
        assert!(!e.is_retryable());
        assert!(e.retry_after().is_none());
    }
    #[test]
    fn other_never_retryable() {
        let e = ProviderError::Other("weird".into());
        assert!(!e.is_retryable());
        assert!(e.retry_after().is_none());
    }
}