use super::ids::StepId;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// High-level classification of an [`ExecutionError`].
///
/// The category determines whether an error is fatal
/// ([`ExecutionErrorCategory::is_fatal`]), which retry policy a freshly created error
/// starts with ([`ExecutionErrorCategory::default_retry_policy`]) and the fallback HTTP
/// status ([`ExecutionError::to_http_status`]). Serialized with PascalCase variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub enum ExecutionErrorCategory {
/// LLM call failure, usually detailed by an [`LlmErrorCode`]. Retryable by default; HTTP 502.
LlmError,
/// Tool invocation failure, usually detailed by a [`ToolErrorCode`]. Retryable by default; HTTP 500.
ToolError,
/// Policy violation. Fatal; HTTP 403.
PolicyViolation,
/// Timed-out operation (also produced from I/O `TimedOut` and reqwest timeouts). Retryable once; HTTP 504.
Timeout,
/// Exceeded quota. Fatal; HTTP 429.
QuotaExceeded,
/// Internal failure; also the fallback for unclassified `std::io::Error`s. Fatal; HTTP 500.
KernelInternal,
/// Invalid input or data, including JSON (de)serialization errors. Fatal; HTTP 400.
ValidationError,
/// Connection or HTTP transport failure. Retryable by default; HTTP 503.
NetworkError,
}
impl ExecutionErrorCategory {
    /// Returns `true` for categories that are never retried: policy violations, exceeded
    /// quotas, kernel-internal faults and validation errors.
    pub fn is_fatal(&self) -> bool {
        match self {
            Self::PolicyViolation
            | Self::QuotaExceeded
            | Self::KernelInternal
            | Self::ValidationError => true,
            Self::LlmError | Self::ToolError | Self::Timeout | Self::NetworkError => false,
        }
    }

    /// Returns the retry policy a new error of this category starts with.
    ///
    /// Fatal categories get [`RetryPolicy::fatal`]. Every other category is retryable with a
    /// category-specific retry count, backoff strategy, delays and idempotency requirement.
    pub fn default_retry_policy(&self) -> RetryPolicy {
        // Columns: (max_retries, backoff, base delay ms, max delay ms, idempotency key)
        let (max_retries, backoff_strategy, base_ms, max_ms, requires_idempotency_key) =
            match self {
                Self::PolicyViolation
                | Self::QuotaExceeded
                | Self::KernelInternal
                | Self::ValidationError => return RetryPolicy::fatal(),
                Self::LlmError => (3, BackoffStrategy::Exponential, 1_000, 30_000, false),
                Self::ToolError => (2, BackoffStrategy::Constant, 500, 5_000, true),
                // Timeouts are retried immediately, exactly once.
                Self::Timeout => (1, BackoffStrategy::Constant, 0, 0, true),
                Self::NetworkError => (3, BackoffStrategy::Exponential, 500, 15_000, true),
            };
        RetryPolicy {
            retryable: true,
            max_retries,
            backoff_strategy,
            base_delay: Duration::from_millis(base_ms),
            max_delay: Duration::from_millis(max_ms),
            requires_idempotency_key,
        }
    }
}
impl std::fmt::Display for ExecutionErrorCategory {
    /// Writes the bare variant name (e.g. `LlmError`), the same spelling used by serde.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::LlmError => "LlmError",
            Self::ToolError => "ToolError",
            Self::PolicyViolation => "PolicyViolation",
            Self::Timeout => "Timeout",
            Self::QuotaExceeded => "QuotaExceeded",
            Self::KernelInternal => "KernelInternal",
            Self::ValidationError => "ValidationError",
            Self::NetworkError => "NetworkError",
        };
        f.write_str(name)
    }
}
/// How the delay between retry attempts grows (see [`BackoffStrategy::calculate_delay`]).
///
/// Serialized as lowercase strings (`"none"`, `"constant"`, `"linear"`, `"exponential"`).
/// Defaults to [`BackoffStrategy::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum BackoffStrategy {
/// Zero delay for every attempt.
#[default]
None,
/// The base delay for every attempt.
Constant,
/// The base delay multiplied by the attempt number.
Linear,
/// The base delay doubled for each attempt after the first.
Exponential,
}
impl BackoffStrategy {
    /// Computes the delay before retrying after attempt number `attempt`, capped at `max`.
    ///
    /// - `None`: always zero.
    /// - `Constant`: `base`.
    /// - `Linear`: `base * attempt`.
    /// - `Exponential`: `base * 2^(attempt - 1)` (attempt `0` is treated like attempt `1`).
    ///
    /// All multiplications saturate, so huge attempt numbers or base delays are clamped to
    /// `max` instead of panicking or wrapping around to a small delay.
    pub fn calculate_delay(&self, base: Duration, attempt: u32, max: Duration) -> Duration {
        let delay = match self {
            Self::None => Duration::ZERO,
            Self::Constant => base,
            // `Duration * u32` panics on overflow; saturate instead.
            Self::Linear => base.saturating_mul(attempt),
            Self::Exponential => {
                // Compute the multiplier directly in u32 so it saturates at u32::MAX.
                // A u64 result cast with `as u32` truncated 2^32..2^63 to 0, which made the
                // backoff collapse to zero delay for attempts 33..=64.
                let multiplier = 2u32.saturating_pow(attempt.saturating_sub(1));
                base.saturating_mul(multiplier)
            }
        };
        delay.min(max)
    }
}
/// Retry configuration attached to an [`ExecutionError`].
///
/// `base_delay` and `max_delay` are (de)serialized as integer milliseconds.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RetryPolicy {
/// Whether retries are allowed at all.
pub retryable: bool,
/// Retry budget: `should_retry(attempt)` holds while `attempt <= max_retries`.
pub max_retries: u32,
/// How the delay grows from one attempt to the next.
pub backoff_strategy: BackoffStrategy,
/// Starting delay fed into the backoff calculation.
#[serde(with = "duration_millis")]
pub base_delay: Duration,
/// Upper bound applied to every computed delay.
#[serde(with = "duration_millis")]
pub max_delay: Duration,
/// Flag that retries need an idempotency key. Not consulted in this module;
/// presumably enforced by the retrying caller (TODO confirm).
pub requires_idempotency_key: bool,
}
impl RetryPolicy {
    /// A policy that never retries: no retries, no backoff, zero delays.
    pub fn fatal() -> Self {
        let zero = Duration::ZERO;
        Self {
            retryable: false,
            max_retries: 0,
            backoff_strategy: BackoffStrategy::None,
            base_delay: zero,
            max_delay: zero,
            requires_idempotency_key: false,
        }
    }

    /// A retryable policy with `max_retries` retries and exponential backoff from 1s,
    /// capped at 30s.
    pub fn retryable(max_retries: u32) -> Self {
        Self {
            max_retries,
            retryable: true,
            backoff_strategy: BackoffStrategy::Exponential,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            requires_idempotency_key: false,
        }
    }

    /// Delay to wait before retrying after attempt number `attempt`.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        let Self {
            backoff_strategy,
            base_delay,
            max_delay,
            ..
        } = self;
        backoff_strategy.calculate_delay(*base_delay, attempt, *max_delay)
    }

    /// Whether another try is allowed after attempt number `attempt` failed.
    pub fn should_retry(&self, attempt: u32) -> bool {
        if !self.retryable {
            return false;
        }
        attempt <= self.max_retries
    }
}
impl Default for RetryPolicy {
fn default() -> Self {
Self::fatal()
}
}
/// Detailed reason attached to [`ExecutionErrorCategory::LlmError`] errors.
///
/// Serialized and displayed in snake_case (e.g. `"rate_limit"`). Only `RateLimit`,
/// `ModelUnavailable` and `ProviderError` are retryable (see [`LlmErrorCode::is_retryable`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmErrorCode {
/// `rate_limit`; retryable.
RateLimit,
/// `context_overflow`; not retryable.
ContextOverflow,
/// `content_filtered`; not retryable.
ContentFiltered,
/// `invalid_request`; not retryable.
InvalidRequest,
/// `auth_failed`; not retryable.
AuthFailed,
/// `model_unavailable`; retryable.
ModelUnavailable,
/// `provider_error`; retryable.
ProviderError,
}
impl LlmErrorCode {
    /// Returns whether this code is classified as retryable.
    ///
    /// `RateLimit`, `ModelUnavailable` and `ProviderError` are; every other code is not.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::RateLimit | Self::ModelUnavailable | Self::ProviderError => true,
            Self::ContextOverflow
            | Self::ContentFiltered
            | Self::InvalidRequest
            | Self::AuthFailed => false,
        }
    }
}
impl std::fmt::Display for LlmErrorCode {
    /// Writes the snake_case code, matching the serde representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::RateLimit => "rate_limit",
            Self::ContextOverflow => "context_overflow",
            Self::ContentFiltered => "content_filtered",
            Self::InvalidRequest => "invalid_request",
            Self::AuthFailed => "auth_failed",
            Self::ModelUnavailable => "model_unavailable",
            Self::ProviderError => "provider_error",
        })
    }
}
/// Detailed reason attached to [`ExecutionErrorCategory::ToolError`] errors.
///
/// Serialized and displayed in snake_case (e.g. `"not_found"`). Only `ExecutionFailed`
/// and `Timeout` are retryable (see [`ToolErrorCode::is_retryable`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolErrorCode {
/// `not_found`; not retryable.
NotFound,
/// `permission_denied`; not retryable.
PermissionDenied,
/// `invalid_input`; not retryable.
InvalidInput,
/// `execution_failed`; retryable.
ExecutionFailed,
/// `timeout`; retryable.
Timeout,
/// `output_invalid`; not retryable.
OutputInvalid,
}
impl ToolErrorCode {
    /// Returns whether this code is classified as retryable; only `Timeout` and
    /// `ExecutionFailed` are.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Timeout | Self::ExecutionFailed => true,
            Self::NotFound
            | Self::PermissionDenied
            | Self::InvalidInput
            | Self::OutputInvalid => false,
        }
    }
}
impl std::fmt::Display for ToolErrorCode {
    /// Writes the snake_case code, matching the serde representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::NotFound => "not_found",
            Self::PermissionDenied => "permission_denied",
            Self::InvalidInput => "invalid_input",
            Self::ExecutionFailed => "execution_failed",
            Self::Timeout => "timeout",
            Self::OutputInvalid => "output_invalid",
        };
        f.write_str(text)
    }
}
/// An error produced during execution, carrying its classification, retry state and
/// optional diagnostic context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionError {
/// Broad classification; drives fatality, the initial retry policy and the fallback HTTP status.
pub category: ExecutionErrorCategory,
/// Human-readable description.
pub message: String,
/// Retry policy currently in effect for this error.
pub retry_policy: RetryPolicy,
/// Optional machine-readable code, e.g. the snake_case form of an `LlmErrorCode`/`ToolErrorCode`.
pub code: Option<String>,
/// Attempt number on which the error occurred; errors built via `new` start at 1.
pub attempt: u32,
/// Step in which the error occurred, if known.
pub step_id: Option<StepId>,
/// Provider involved in the failure, if any (e.g. `"azure"`).
pub provider: Option<String>,
/// Explicit HTTP status; when set it overrides the category default in `to_http_status`.
pub http_status: Option<u16>,
/// Arbitrary structured details.
pub details: Option<serde_json::Value>,
/// Timestamp of the latest occurrence, in milliseconds since the Unix epoch (UTC).
pub occurred_at: i64,
}
impl ExecutionError {
pub fn new(category: ExecutionErrorCategory, message: impl Into<String>) -> Self {
Self {
category,
message: message.into(),
retry_policy: category.default_retry_policy(),
code: None,
attempt: 1,
step_id: None,
provider: None,
http_status: None,
details: None,
occurred_at: chrono::Utc::now().timestamp_millis(),
}
}
pub fn with_code(mut self, code: impl Into<String>) -> Self {
self.code = Some(code.into());
self
}
pub fn with_attempt(mut self, attempt: u32) -> Self {
self.attempt = attempt;
self
}
pub fn with_step_id(mut self, step_id: StepId) -> Self {
self.step_id = Some(step_id);
self
}
pub fn with_provider(mut self, provider: impl Into<String>) -> Self {
self.provider = Some(provider.into());
self
}
pub fn with_http_status(mut self, status: u16) -> Self {
self.http_status = Some(status);
self
}
pub fn with_details(mut self, details: serde_json::Value) -> Self {
self.details = Some(details);
self
}
pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
self.retry_policy = policy;
self
}
pub fn llm(code: LlmErrorCode, message: impl Into<String>) -> Self {
Self::new(ExecutionErrorCategory::LlmError, message).with_code(code.to_string())
}
pub fn tool(code: ToolErrorCode, message: impl Into<String>) -> Self {
Self::new(ExecutionErrorCategory::ToolError, message).with_code(code.to_string())
}
pub fn policy_violation(message: impl Into<String>) -> Self {
Self::new(ExecutionErrorCategory::PolicyViolation, message)
}
pub fn timeout(message: impl Into<String>) -> Self {
Self::new(ExecutionErrorCategory::Timeout, message)
}
pub fn quota_exceeded(message: impl Into<String>) -> Self {
Self::new(ExecutionErrorCategory::QuotaExceeded, message)
}
pub fn kernel_internal(message: impl Into<String>) -> Self {
Self::new(ExecutionErrorCategory::KernelInternal, message)
}
pub fn validation(message: impl Into<String>) -> Self {
Self::new(ExecutionErrorCategory::ValidationError, message)
}
pub fn network(message: impl Into<String>) -> Self {
Self::new(ExecutionErrorCategory::NetworkError, message)
}
pub fn is_retryable(&self) -> bool {
self.retry_policy.retryable
}
pub fn is_fatal(&self) -> bool {
self.category.is_fatal()
}
pub fn should_retry(&self) -> bool {
self.retry_policy.should_retry(self.attempt)
}
pub fn retry_delay(&self) -> Duration {
self.retry_policy.delay_for_attempt(self.attempt)
}
pub fn next_attempt(mut self) -> Self {
self.attempt += 1;
self.occurred_at = chrono::Utc::now().timestamp_millis();
self
}
pub fn to_http_status(&self) -> u16 {
if let Some(status) = self.http_status {
return status;
}
match self.category {
ExecutionErrorCategory::LlmError => 502, ExecutionErrorCategory::ToolError => 500, ExecutionErrorCategory::PolicyViolation => 403, ExecutionErrorCategory::Timeout => 504, ExecutionErrorCategory::QuotaExceeded => 429, ExecutionErrorCategory::KernelInternal => 500, ExecutionErrorCategory::ValidationError => 400, ExecutionErrorCategory::NetworkError => 503, }
}
}
impl std::fmt::Display for ExecutionError {
    /// Formats as `[Category] message`. A ` (code)` suffix is added when a code is set, and
    /// an ` [attempt N]` suffix for any attempt after the first.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            category,
            message,
            code,
            attempt,
            ..
        } = self;
        write!(f, "[{category}] {message}")?;
        if let Some(code) = code {
            write!(f, " ({code})")?;
        }
        match *attempt {
            0 | 1 => Ok(()),
            n => write!(f, " [attempt {n}]"),
        }
    }
}
/// Marks `ExecutionError` as a standard error type; it relies on the trait's default
/// `source()` (no underlying cause is exposed).
impl std::error::Error for ExecutionError {}
impl From<reqwest::Error> for ExecutionError {
    /// Classifies a `reqwest` failure:
    /// - timeouts become [`ExecutionErrorCategory::Timeout`];
    /// - connection failures become [`ExecutionErrorCategory::NetworkError`];
    /// - everything else is also a network error. For status errors, the HTTP status is
    ///   recorded, falling back to 500 if it is unexpectedly absent.
    fn from(err: reqwest::Error) -> Self {
        if err.is_timeout() {
            return Self::timeout(format!("HTTP request timed out: {err}"));
        }
        if err.is_connect() {
            return Self::network(format!("Connection failed: {err}"));
        }
        let error = Self::network(format!("HTTP error: {err}"));
        if err.is_status() {
            let status = err.status().map_or(500, |s| s.as_u16());
            error.with_http_status(status)
        } else {
            error
        }
    }
}
impl From<serde_json::Error> for ExecutionError {
    /// Any JSON (de)serialization failure is reported as a validation error.
    fn from(err: serde_json::Error) -> Self {
        let message = format!("JSON error: {err}");
        Self::validation(message)
    }
}
impl From<std::io::Error> for ExecutionError {
    /// Maps I/O errors by kind:
    /// - `TimedOut` becomes a timeout;
    /// - refused, reset or aborted connections become network errors;
    /// - every other kind becomes a kernel-internal error.
    fn from(err: std::io::Error) -> Self {
        use std::io::ErrorKind;

        match err.kind() {
            ErrorKind::TimedOut => Self::timeout(format!("IO timeout: {err}")),
            ErrorKind::ConnectionRefused
            | ErrorKind::ConnectionReset
            | ErrorKind::ConnectionAborted => Self::network(format!("Connection error: {err}")),
            _ => Self::kernel_internal(format!("IO error: {err}")),
        }
    }
}
/// Serde adapter (for `#[serde(with = "duration_millis")]`) that stores a `Duration` as an
/// integer number of milliseconds.
///
/// Sub-millisecond precision is dropped on serialization. Durations whose millisecond count
/// does not fit in a `u64` are clamped to `u64::MAX` rather than silently wrapped.
mod duration_millis {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Serializes `duration` as whole milliseconds, saturating at `u64::MAX`.
    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `as_millis` returns u128; a bare `as u64` would truncate the high bits (e.g. for
        // `Duration::MAX`) and produce an unrelated small value. Clamp first, then cast.
        let millis = duration.as_millis().min(u128::from(u64::MAX)) as u64;
        serializer.serialize_u64(millis)
    }

    /// Deserializes a `u64` millisecond count into a `Duration`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        let millis = u64::deserialize(deserializer)?;
        Ok(Duration::from_millis(millis))
    }
}
// Unit tests for error classification, retry policies, backoff and conversions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_categories_fatality() {
        assert!(ExecutionErrorCategory::PolicyViolation.is_fatal());
        assert!(ExecutionErrorCategory::QuotaExceeded.is_fatal());
        assert!(ExecutionErrorCategory::KernelInternal.is_fatal());
        assert!(ExecutionErrorCategory::ValidationError.is_fatal());
        assert!(!ExecutionErrorCategory::LlmError.is_fatal());
        assert!(!ExecutionErrorCategory::ToolError.is_fatal());
        assert!(!ExecutionErrorCategory::Timeout.is_fatal());
        assert!(!ExecutionErrorCategory::NetworkError.is_fatal());
    }

    #[test]
    fn test_default_retry_policies() {
        let llm_policy = ExecutionErrorCategory::LlmError.default_retry_policy();
        assert!(llm_policy.retryable);
        assert_eq!(llm_policy.max_retries, 3);
        assert_eq!(llm_policy.backoff_strategy, BackoffStrategy::Exponential);
        let fatal_policy = ExecutionErrorCategory::PolicyViolation.default_retry_policy();
        assert!(!fatal_policy.retryable);
        assert_eq!(fatal_policy.max_retries, 0);
    }

    #[test]
    fn test_default_retry_policy_matches_fatality() {
        // Every category's default policy must agree with `is_fatal`.
        let categories = [
            ExecutionErrorCategory::LlmError,
            ExecutionErrorCategory::ToolError,
            ExecutionErrorCategory::PolicyViolation,
            ExecutionErrorCategory::Timeout,
            ExecutionErrorCategory::QuotaExceeded,
            ExecutionErrorCategory::KernelInternal,
            ExecutionErrorCategory::ValidationError,
            ExecutionErrorCategory::NetworkError,
        ];
        for cat in categories {
            let policy = cat.default_retry_policy();
            assert_eq!(policy.retryable, !cat.is_fatal(), "category {cat}");
            if cat.is_fatal() {
                assert_eq!(policy, RetryPolicy::fatal(), "category {cat}");
            }
        }
    }

    #[test]
    fn test_exponential_backoff() {
        let strategy = BackoffStrategy::Exponential;
        let base = Duration::from_millis(1000);
        let max = Duration::from_millis(30000);
        assert_eq!(
            strategy.calculate_delay(base, 1, max),
            Duration::from_millis(1000)
        );
        assert_eq!(
            strategy.calculate_delay(base, 2, max),
            Duration::from_millis(2000)
        );
        assert_eq!(
            strategy.calculate_delay(base, 3, max),
            Duration::from_millis(4000)
        );
        assert_eq!(
            strategy.calculate_delay(base, 4, max),
            Duration::from_millis(8000)
        );
        assert_eq!(
            strategy.calculate_delay(base, 5, max),
            Duration::from_millis(16000)
        );
        assert_eq!(
            strategy.calculate_delay(base, 6, max),
            Duration::from_millis(30000)
        );
    }

    #[test]
    fn test_exponential_backoff_large_attempt_capped() {
        let base = Duration::from_millis(1000);
        let max = Duration::from_millis(30000);
        assert_eq!(BackoffStrategy::Exponential.calculate_delay(base, 20, max), max);
    }

    #[test]
    fn test_execution_error_creation() {
        let error = ExecutionError::llm(LlmErrorCode::RateLimit, "Too many requests")
            .with_provider("azure")
            .with_http_status(429);
        assert_eq!(error.category, ExecutionErrorCategory::LlmError);
        assert_eq!(error.code, Some("rate_limit".to_string()));
        assert_eq!(error.provider, Some("azure".to_string()));
        assert_eq!(error.http_status, Some(429));
        assert!(error.is_retryable());
        assert!(!error.is_fatal());
    }

    #[test]
    fn test_should_retry() {
        // LlmError allows 3 retries: attempts 1..=3 may retry, attempt 4 may not.
        let mut error = ExecutionError::llm(LlmErrorCode::RateLimit, "Rate limited");
        assert!(error.should_retry());
        error = error.next_attempt();
        assert!(error.should_retry());
        error = error.next_attempt();
        assert!(error.should_retry());
        error = error.next_attempt();
        assert!(!error.should_retry());
    }

    #[test]
    fn test_fatal_error_never_retries() {
        let error = ExecutionError::policy_violation("Content blocked");
        assert!(!error.is_retryable());
        assert!(error.is_fatal());
        assert!(!error.should_retry());
    }

    #[test]
    fn test_http_status_mapping() {
        assert_eq!(
            ExecutionError::policy_violation("test").to_http_status(),
            403
        );
        assert_eq!(ExecutionError::quota_exceeded("test").to_http_status(), 429);
        assert_eq!(ExecutionError::timeout("test").to_http_status(), 504);
        assert_eq!(ExecutionError::validation("test").to_http_status(), 400);
    }

    #[test]
    fn test_error_serialization() {
        let error = ExecutionError::llm(LlmErrorCode::RateLimit, "Too many requests")
            .with_provider("azure");
        let json = serde_json::to_string(&error).unwrap();
        let parsed: ExecutionError = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.category, error.category);
        assert_eq!(parsed.message, error.message);
        assert_eq!(parsed.code, error.code);
        assert_eq!(parsed.provider, error.provider);
        assert_eq!(parsed.retry_policy, error.retry_policy);
        assert_eq!(parsed.attempt, error.attempt);
        assert_eq!(parsed.occurred_at, error.occurred_at);
    }

    #[test]
    fn test_error_category_display() {
        assert_eq!(format!("{}", ExecutionErrorCategory::LlmError), "LlmError");
        assert_eq!(
            format!("{}", ExecutionErrorCategory::ToolError),
            "ToolError"
        );
        assert_eq!(
            format!("{}", ExecutionErrorCategory::PolicyViolation),
            "PolicyViolation"
        );
        assert_eq!(format!("{}", ExecutionErrorCategory::Timeout), "Timeout");
        assert_eq!(
            format!("{}", ExecutionErrorCategory::QuotaExceeded),
            "QuotaExceeded"
        );
        assert_eq!(
            format!("{}", ExecutionErrorCategory::KernelInternal),
            "KernelInternal"
        );
        assert_eq!(
            format!("{}", ExecutionErrorCategory::ValidationError),
            "ValidationError"
        );
        assert_eq!(
            format!("{}", ExecutionErrorCategory::NetworkError),
            "NetworkError"
        );
    }

    #[test]
    fn test_error_category_serde() {
        let categories = vec![
            ExecutionErrorCategory::LlmError,
            ExecutionErrorCategory::ToolError,
            ExecutionErrorCategory::PolicyViolation,
            ExecutionErrorCategory::Timeout,
            ExecutionErrorCategory::QuotaExceeded,
            ExecutionErrorCategory::KernelInternal,
            ExecutionErrorCategory::ValidationError,
            ExecutionErrorCategory::NetworkError,
        ];
        for cat in categories {
            let json = serde_json::to_string(&cat).unwrap();
            let parsed: ExecutionErrorCategory = serde_json::from_str(&json).unwrap();
            assert_eq!(cat, parsed);
        }
    }

    #[test]
    fn test_backoff_none() {
        let strategy = BackoffStrategy::None;
        let base = Duration::from_millis(1000);
        let max = Duration::from_millis(30000);
        assert_eq!(strategy.calculate_delay(base, 1, max), Duration::ZERO);
        assert_eq!(strategy.calculate_delay(base, 5, max), Duration::ZERO);
    }

    #[test]
    fn test_backoff_constant() {
        let strategy = BackoffStrategy::Constant;
        let base = Duration::from_millis(500);
        let max = Duration::from_millis(10000);
        assert_eq!(
            strategy.calculate_delay(base, 1, max),
            Duration::from_millis(500)
        );
        assert_eq!(
            strategy.calculate_delay(base, 5, max),
            Duration::from_millis(500)
        );
    }

    #[test]
    fn test_backoff_linear() {
        let strategy = BackoffStrategy::Linear;
        let base = Duration::from_millis(1000);
        let max = Duration::from_millis(30000);
        assert_eq!(
            strategy.calculate_delay(base, 1, max),
            Duration::from_millis(1000)
        );
        assert_eq!(
            strategy.calculate_delay(base, 2, max),
            Duration::from_millis(2000)
        );
        assert_eq!(
            strategy.calculate_delay(base, 3, max),
            Duration::from_millis(3000)
        );
        assert_eq!(
            strategy.calculate_delay(base, 100, max),
            Duration::from_millis(30000)
        );
    }

    #[test]
    fn test_backoff_default() {
        assert_eq!(BackoffStrategy::default(), BackoffStrategy::None);
    }

    #[test]
    fn test_backoff_serde() {
        let strategies = vec![
            BackoffStrategy::None,
            BackoffStrategy::Constant,
            BackoffStrategy::Linear,
            BackoffStrategy::Exponential,
        ];
        for strat in strategies {
            let json = serde_json::to_string(&strat).unwrap();
            let parsed: BackoffStrategy = serde_json::from_str(&json).unwrap();
            assert_eq!(strat, parsed);
        }
    }

    #[test]
    fn test_retry_policy_fatal() {
        let policy = RetryPolicy::fatal();
        assert!(!policy.retryable);
        assert_eq!(policy.max_retries, 0);
        assert!(!policy.should_retry(1));
    }

    #[test]
    fn test_retry_policy_retryable() {
        let policy = RetryPolicy::retryable(5);
        assert!(policy.retryable);
        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.backoff_strategy, BackoffStrategy::Exponential);
    }

    #[test]
    fn test_retry_policy_delay_for_attempt() {
        let policy = RetryPolicy {
            retryable: true,
            max_retries: 3,
            backoff_strategy: BackoffStrategy::Constant,
            base_delay: Duration::from_millis(500),
            max_delay: Duration::from_millis(5000),
            requires_idempotency_key: false,
        };
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(500));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(500));
    }

    #[test]
    fn test_retry_policy_should_retry() {
        let policy = RetryPolicy::retryable(3);
        assert!(policy.should_retry(1));
        assert!(policy.should_retry(2));
        assert!(policy.should_retry(3));
        assert!(!policy.should_retry(4));
    }

    #[test]
    fn test_retry_policy_default() {
        let policy = RetryPolicy::default();
        assert!(!policy.retryable);
        assert_eq!(policy.max_retries, 0);
    }

    #[test]
    fn test_retry_policy_serde() {
        let policy = RetryPolicy::retryable(3);
        let json = serde_json::to_string(&policy).unwrap();
        let parsed: RetryPolicy = serde_json::from_str(&json).unwrap();
        // Compare every field, including the millisecond-encoded durations.
        assert_eq!(policy, parsed);
    }

    #[test]
    fn test_retry_policy_serializes_durations_as_millis() {
        let value = serde_json::to_value(RetryPolicy::retryable(3)).unwrap();
        assert_eq!(value["base_delay"], 1000);
        assert_eq!(value["max_delay"], 30000);
        assert_eq!(value["backoff_strategy"], "exponential");
    }

    #[test]
    fn test_llm_error_code_retryable() {
        assert!(LlmErrorCode::RateLimit.is_retryable());
        assert!(LlmErrorCode::ModelUnavailable.is_retryable());
        assert!(LlmErrorCode::ProviderError.is_retryable());
        assert!(!LlmErrorCode::ContextOverflow.is_retryable());
        assert!(!LlmErrorCode::ContentFiltered.is_retryable());
        assert!(!LlmErrorCode::InvalidRequest.is_retryable());
        assert!(!LlmErrorCode::AuthFailed.is_retryable());
    }

    #[test]
    fn test_llm_error_code_display() {
        assert_eq!(format!("{}", LlmErrorCode::RateLimit), "rate_limit");
        assert_eq!(
            format!("{}", LlmErrorCode::ContextOverflow),
            "context_overflow"
        );
        assert_eq!(
            format!("{}", LlmErrorCode::ContentFiltered),
            "content_filtered"
        );
        assert_eq!(
            format!("{}", LlmErrorCode::InvalidRequest),
            "invalid_request"
        );
        assert_eq!(format!("{}", LlmErrorCode::AuthFailed), "auth_failed");
        assert_eq!(
            format!("{}", LlmErrorCode::ModelUnavailable),
            "model_unavailable"
        );
        assert_eq!(format!("{}", LlmErrorCode::ProviderError), "provider_error");
    }

    #[test]
    fn test_llm_error_code_serde() {
        let codes = vec![
            LlmErrorCode::RateLimit,
            LlmErrorCode::ContextOverflow,
            LlmErrorCode::ContentFiltered,
            LlmErrorCode::InvalidRequest,
            LlmErrorCode::AuthFailed,
            LlmErrorCode::ModelUnavailable,
            LlmErrorCode::ProviderError,
        ];
        for code in codes {
            let json = serde_json::to_string(&code).unwrap();
            let parsed: LlmErrorCode = serde_json::from_str(&json).unwrap();
            assert_eq!(code, parsed);
        }
    }

    #[test]
    fn test_tool_error_code_retryable() {
        assert!(ToolErrorCode::Timeout.is_retryable());
        assert!(ToolErrorCode::ExecutionFailed.is_retryable());
        assert!(!ToolErrorCode::NotFound.is_retryable());
        assert!(!ToolErrorCode::PermissionDenied.is_retryable());
        assert!(!ToolErrorCode::InvalidInput.is_retryable());
        assert!(!ToolErrorCode::OutputInvalid.is_retryable());
    }

    #[test]
    fn test_tool_error_code_display() {
        assert_eq!(format!("{}", ToolErrorCode::NotFound), "not_found");
        assert_eq!(
            format!("{}", ToolErrorCode::PermissionDenied),
            "permission_denied"
        );
        assert_eq!(format!("{}", ToolErrorCode::InvalidInput), "invalid_input");
        assert_eq!(
            format!("{}", ToolErrorCode::ExecutionFailed),
            "execution_failed"
        );
        assert_eq!(format!("{}", ToolErrorCode::Timeout), "timeout");
        assert_eq!(
            format!("{}", ToolErrorCode::OutputInvalid),
            "output_invalid"
        );
    }

    #[test]
    fn test_tool_error_code_serde() {
        let codes = vec![
            ToolErrorCode::NotFound,
            ToolErrorCode::PermissionDenied,
            ToolErrorCode::InvalidInput,
            ToolErrorCode::ExecutionFailed,
            ToolErrorCode::Timeout,
            ToolErrorCode::OutputInvalid,
        ];
        for code in codes {
            let json = serde_json::to_string(&code).unwrap();
            let parsed: ToolErrorCode = serde_json::from_str(&json).unwrap();
            assert_eq!(code, parsed);
        }
    }

    #[test]
    fn test_execution_error_new() {
        let error = ExecutionError::new(ExecutionErrorCategory::LlmError, "Test message");
        assert_eq!(error.category, ExecutionErrorCategory::LlmError);
        assert_eq!(error.message, "Test message");
        assert_eq!(error.attempt, 1);
        assert!(error.retry_policy.retryable);
    }

    #[test]
    fn test_execution_error_with_code() {
        let error =
            ExecutionError::new(ExecutionErrorCategory::ToolError, "Test").with_code("custom_code");
        assert_eq!(error.code, Some("custom_code".to_string()));
    }

    #[test]
    fn test_execution_error_with_attempt() {
        let error = ExecutionError::new(ExecutionErrorCategory::LlmError, "Test").with_attempt(3);
        assert_eq!(error.attempt, 3);
    }

    #[test]
    fn test_execution_error_with_step_id() {
        let step_id = StepId::from_string("step_test");
        let error = ExecutionError::new(ExecutionErrorCategory::ToolError, "Test")
            .with_step_id(step_id.clone());
        assert_eq!(error.step_id.unwrap().as_str(), "step_test");
    }

    #[test]
    fn test_execution_error_with_provider() {
        let error = ExecutionError::llm(LlmErrorCode::RateLimit, "Test").with_provider("openai");
        assert_eq!(error.provider, Some("openai".to_string()));
    }

    #[test]
    fn test_execution_error_with_details() {
        let details = serde_json::json!({"key": "value"});
        let error = ExecutionError::new(ExecutionErrorCategory::ToolError, "Test")
            .with_details(details.clone());
        assert_eq!(error.details, Some(details));
    }

    #[test]
    fn test_execution_error_with_retry_policy() {
        let policy = RetryPolicy::retryable(10);
        let error = ExecutionError::new(ExecutionErrorCategory::ToolError, "Test")
            .with_retry_policy(policy.clone());
        assert_eq!(error.retry_policy.max_retries, 10);
        assert_eq!(error.retry_policy, policy);
    }

    #[test]
    fn test_execution_error_convenience_constructors() {
        let llm = ExecutionError::llm(LlmErrorCode::RateLimit, "LLM error");
        assert_eq!(llm.category, ExecutionErrorCategory::LlmError);
        let tool = ExecutionError::tool(ToolErrorCode::NotFound, "Tool error");
        assert_eq!(tool.category, ExecutionErrorCategory::ToolError);
        let policy = ExecutionError::policy_violation("Policy error");
        assert_eq!(policy.category, ExecutionErrorCategory::PolicyViolation);
        let timeout = ExecutionError::timeout("Timeout error");
        assert_eq!(timeout.category, ExecutionErrorCategory::Timeout);
        let quota = ExecutionError::quota_exceeded("Quota error");
        assert_eq!(quota.category, ExecutionErrorCategory::QuotaExceeded);
        let kernel = ExecutionError::kernel_internal("Kernel error");
        assert_eq!(kernel.category, ExecutionErrorCategory::KernelInternal);
        let validation = ExecutionError::validation("Validation error");
        assert_eq!(validation.category, ExecutionErrorCategory::ValidationError);
        let network = ExecutionError::network("Network error");
        assert_eq!(network.category, ExecutionErrorCategory::NetworkError);
    }

    #[test]
    fn test_execution_error_is_retryable() {
        assert!(ExecutionError::llm(LlmErrorCode::RateLimit, "").is_retryable());
        assert!(ExecutionError::tool(ToolErrorCode::Timeout, "").is_retryable());
        assert!(ExecutionError::timeout("").is_retryable());
        assert!(ExecutionError::network("").is_retryable());
        assert!(!ExecutionError::policy_violation("").is_retryable());
        assert!(!ExecutionError::quota_exceeded("").is_retryable());
        assert!(!ExecutionError::kernel_internal("").is_retryable());
        assert!(!ExecutionError::validation("").is_retryable());
    }

    #[test]
    fn test_execution_error_is_fatal() {
        assert!(!ExecutionError::llm(LlmErrorCode::RateLimit, "").is_fatal());
        assert!(!ExecutionError::timeout("").is_fatal());
        assert!(ExecutionError::policy_violation("").is_fatal());
        assert!(ExecutionError::quota_exceeded("").is_fatal());
        assert!(ExecutionError::kernel_internal("").is_fatal());
        assert!(ExecutionError::validation("").is_fatal());
    }

    #[test]
    fn test_execution_error_retry_delay() {
        // LlmError: exponential backoff from 1s, so the first retry waits exactly 1s.
        let error = ExecutionError::new(ExecutionErrorCategory::LlmError, "Test");
        let delay = error.retry_delay();
        assert_eq!(delay, Duration::from_millis(1000));
    }

    #[test]
    fn test_execution_error_next_attempt() {
        let error = ExecutionError::llm(LlmErrorCode::RateLimit, "Test");
        assert_eq!(error.attempt, 1);
        let error2 = error.next_attempt();
        assert_eq!(error2.attempt, 2);
        let error3 = error2.next_attempt();
        assert_eq!(error3.attempt, 3);
    }

    #[test]
    fn test_execution_error_to_http_status_all_categories() {
        assert_eq!(
            ExecutionError::llm(LlmErrorCode::RateLimit, "").to_http_status(),
            502
        );
        assert_eq!(
            ExecutionError::tool(ToolErrorCode::NotFound, "").to_http_status(),
            500
        );
        assert_eq!(ExecutionError::policy_violation("").to_http_status(), 403);
        assert_eq!(ExecutionError::timeout("").to_http_status(), 504);
        assert_eq!(ExecutionError::quota_exceeded("").to_http_status(), 429);
        assert_eq!(ExecutionError::kernel_internal("").to_http_status(), 500);
        assert_eq!(ExecutionError::validation("").to_http_status(), 400);
        assert_eq!(ExecutionError::network("").to_http_status(), 503);
    }

    #[test]
    fn test_execution_error_to_http_status_override() {
        let error = ExecutionError::network("Test").with_http_status(418);
        assert_eq!(error.to_http_status(), 418);
    }

    #[test]
    fn test_execution_error_display() {
        let error = ExecutionError::llm(LlmErrorCode::RateLimit, "Too many requests");
        let display = format!("{}", error);
        assert!(display.contains("LlmError"));
        assert!(display.contains("Too many requests"));
        assert!(display.contains("rate_limit"));
    }

    #[test]
    fn test_execution_error_display_with_attempt() {
        let error = ExecutionError::llm(LlmErrorCode::RateLimit, "Test").with_attempt(3);
        let display = format!("{}", error);
        assert!(display.contains("[attempt 3]"));
    }

    #[test]
    fn test_execution_error_display_no_attempt_shown_for_first() {
        let error = ExecutionError::llm(LlmErrorCode::RateLimit, "Test");
        let display = format!("{}", error);
        assert!(!display.contains("attempt"));
    }

    #[test]
    fn test_from_serde_json_error() {
        let json_err = serde_json::from_str::<String>("invalid json").unwrap_err();
        let error: ExecutionError = json_err.into();
        assert_eq!(error.category, ExecutionErrorCategory::ValidationError);
        assert!(error.message.contains("JSON error"));
    }

    #[test]
    fn test_from_io_error_timeout() {
        let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, "timed out");
        let error: ExecutionError = io_err.into();
        assert_eq!(error.category, ExecutionErrorCategory::Timeout);
    }

    #[test]
    fn test_from_io_error_connection_refused() {
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
        let error: ExecutionError = io_err.into();
        assert_eq!(error.category, ExecutionErrorCategory::NetworkError);
    }

    #[test]
    fn test_from_io_error_connection_reset() {
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset");
        let error: ExecutionError = io_err.into();
        assert_eq!(error.category, ExecutionErrorCategory::NetworkError);
    }

    #[test]
    fn test_from_io_error_connection_aborted() {
        let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionAborted, "aborted");
        let error: ExecutionError = io_err.into();
        assert_eq!(error.category, ExecutionErrorCategory::NetworkError);
    }

    #[test]
    fn test_from_io_error_other() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "not found");
        let error: ExecutionError = io_err.into();
        assert_eq!(error.category, ExecutionErrorCategory::KernelInternal);
    }

    #[test]
    fn test_default_retry_policy_tool_error() {
        let policy = ExecutionErrorCategory::ToolError.default_retry_policy();
        assert!(policy.retryable);
        assert_eq!(policy.max_retries, 2);
        assert_eq!(policy.backoff_strategy, BackoffStrategy::Constant);
        assert!(policy.requires_idempotency_key);
    }

    #[test]
    fn test_default_retry_policy_timeout() {
        let policy = ExecutionErrorCategory::Timeout.default_retry_policy();
        assert!(policy.retryable);
        assert_eq!(policy.max_retries, 1);
        assert!(policy.requires_idempotency_key);
    }

    #[test]
    fn test_default_retry_policy_network() {
        let policy = ExecutionErrorCategory::NetworkError.default_retry_policy();
        assert!(policy.retryable);
        assert_eq!(policy.max_retries, 3);
        assert_eq!(policy.backoff_strategy, BackoffStrategy::Exponential);
        assert!(policy.requires_idempotency_key);
    }
}