use serde::{Deserialize, Serialize};
use crate::error::ProviderError;
/// Category of a provider failure, used to choose a recovery strategy.
///
/// The recovery policy for each variant is defined by the methods in
/// `impl FailoverReason` (`is_retryable`, `should_compress`, `should_fallback`,
/// `should_rotate_credential`, `recommended_backoff_ms`); the per-variant notes
/// below summarise that policy. Which concrete provider errors map to each
/// variant is decided by the classifier implementation, not visible here.
///
/// Variant order determines the implicit discriminant values, so prefer
/// appending new variants — NOTE(review): confirm whether anything relies on
/// discriminants or positional serialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FailoverReason {
/// Authentication error; the only variant for which
/// `should_rotate_credential` returns `true`. Retryable.
Auth,
/// Authentication error not fixed by retrying: not retryable, triggers
/// provider fallback.
AuthPermanent,
/// Billing error: not retryable, triggers provider fallback.
Billing,
/// Rate-limit error: retryable, triggers fallback, 5000 ms recommended backoff.
RateLimit,
/// Provider overloaded: retryable, triggers fallback, 3000 ms recommended backoff.
Overloaded,
/// Server-side error: retryable, 1000 ms recommended backoff, no fallback.
ServerError,
/// Request timed out: retryable, 2000 ms recommended backoff, no fallback.
Timeout,
/// Context exceeded the model limit: retryable, triggers compression.
ContextOverflow,
/// Payload too large: retryable, triggers compression.
PayloadTooLarge,
/// Requested model not found: not retryable, triggers provider fallback.
ModelNotFound,
/// Format error: not retryable and does not trigger fallback.
FormatError,
/// NOTE(review): exact meaning not established here (presumably an invalid
/// "thinking" block signature). Policy-wise: retryable, no other action.
ThinkingSignature,
/// NOTE(review): presumably a long-context tier/limit error — confirm.
/// Retryable, triggers compression.
LongContextTier,
/// Unclassified failure: retryable, no other recovery action.
Unknown,
}
impl FailoverReason {
    /// Returns `true` unless this reason is one that repeating the request
    /// cannot fix: `AuthPermanent`, `Billing`, `ModelNotFound`, `FormatError`.
    ///
    /// Every other reason, including `Unknown`, is treated as retryable.
    pub fn is_retryable(&self) -> bool {
        !matches!(
            self,
            Self::AuthPermanent | Self::Billing | Self::ModelNotFound | Self::FormatError
        )
    }

    /// Returns `true` when the context should be compressed before retrying:
    /// `ContextOverflow`, `PayloadTooLarge`, `LongContextTier`.
    pub fn should_compress(&self) -> bool {
        matches!(
            self,
            Self::ContextOverflow | Self::PayloadTooLarge | Self::LongContextTier
        )
    }

    /// Returns `true` when the caller should switch to another provider:
    /// `Billing`, `RateLimit`, `Overloaded`, `ModelNotFound`, `AuthPermanent`.
    pub fn should_fallback(&self) -> bool {
        matches!(
            self,
            Self::Billing
                | Self::RateLimit
                | Self::Overloaded
                | Self::ModelNotFound
                | Self::AuthPermanent
        )
    }

    /// Returns `true` only for `Auth`, the one reason handled by rotating
    /// to a different credential.
    pub fn should_rotate_credential(&self) -> bool {
        matches!(self, Self::Auth)
    }

    /// Recommended delay in milliseconds before retrying, or `None` when no
    /// backoff is recommended for this reason.
    ///
    /// The match is deliberately exhaustive (no `_` arm). Adding a variant
    /// then forces an explicit backoff decision at compile time instead of
    /// silently defaulting to `None`.
    pub fn recommended_backoff_ms(&self) -> Option<u64> {
        match self {
            Self::RateLimit => Some(5000),
            Self::Overloaded => Some(3000),
            Self::Timeout => Some(2000),
            Self::ServerError => Some(1000),
            Self::Auth
            | Self::AuthPermanent
            | Self::Billing
            | Self::ContextOverflow
            | Self::PayloadTooLarge
            | Self::ModelNotFound
            | Self::FormatError
            | Self::ThinkingSignature
            | Self::LongContextTier
            | Self::Unknown => None,
        }
    }
}
/// Outcome of classifying a provider error: the failure category plus the
/// recovery flags a caller can act on.
///
/// The boolean flags are stored independently of `reason`. NOTE(review):
/// they presumably mirror the matching `FailoverReason` methods
/// (`is_retryable`, `should_compress`, ...). Confirm that every place
/// building this struct keeps them in sync.
#[derive(Debug, Clone)]
pub struct ClassifiedError {
/// Failure category.
pub reason: FailoverReason,
/// Status code of the failed response, if any; `summary` prints "N/A" when absent.
pub status_code: Option<u16>,
/// Error message text; not included in `summary`.
pub message: String,
/// Whether the request may be retried (lowest-priority action in `summary`).
pub retryable: bool,
/// Whether the context should be compressed (highest-priority action in `summary`).
pub should_compress: bool,
/// Whether to switch to another provider (second priority in `summary`).
pub should_fallback: bool,
/// Whether to rotate credentials (third priority in `summary`).
pub should_rotate_credential: bool,
}
impl ClassifiedError {
    /// Renders a one-line, human-readable summary of this error.
    ///
    /// The summary contains:
    /// - the `Debug` form of `reason`;
    /// - the status code, or `N/A` when there is none;
    /// - exactly one recovery action, chosen by fixed priority:
    ///   compress > fall back > rotate credential > retry.
    ///
    /// When none of the flags is set, the error is reported as unrecoverable.
    pub fn summary(&self) -> String {
        // Arms are tested top to bottom, so the tuple order encodes the priority.
        let strategy = match (
            self.should_compress,
            self.should_fallback,
            self.should_rotate_credential,
            self.retryable,
        ) {
            (true, _, _, _) => "需要压缩上下文",
            (_, true, _, _) => "建议切换 Provider",
            (_, _, true, _) => "需要轮换凭证",
            (_, _, _, true) => "可重试",
            (false, false, false, false) => "不可恢复",
        };

        let status = match self.status_code {
            Some(code) => code.to_string(),
            None => String::from("N/A"),
        };

        format!(
            "错误原因:{:?} | 状态码:{} | 策略:{}",
            self.reason, status, strategy
        )
    }
}
/// Optional metadata about a failed request. It is passed to
/// `ErrorClassifier::classify` and `ProviderErrorExt::classify` alongside
/// the error itself.
///
/// Every field is optional; the derived `ErrorContext::default()` leaves
/// all of them `None`.
#[derive(Debug, Clone, Default)]
pub struct ErrorContext {
/// Status code of the failed response, if one was received.
pub status_code: Option<u16>,
/// Name of the provider that produced the error, if known.
pub provider: Option<String>,
/// Model the request targeted, if known.
pub model: Option<String>,
/// Size of the outgoing request — unit not established here (presumably bytes); TODO confirm.
pub request_size: Option<usize>,
/// Size of the response — unit not established here (presumably bytes); TODO confirm.
pub response_size: Option<usize>,
/// Presumably the number of tokens in the request context — confirm with callers.
pub context_tokens: Option<usize>,
/// Presumably the model's maximum context size in tokens — confirm with callers.
pub context_window: Option<usize>,
}
/// Strategy that maps a [`ProviderError`] plus request metadata to a
/// [`ClassifiedError`].
///
/// The `Send + Sync` supertraits allow a classifier to be shared across
/// threads.
pub trait ErrorClassifier: Send + Sync {
/// Classifies `error`, using `context` for additional hints about the request.
fn classify(&self, error: &ProviderError, context: &ErrorContext) -> ClassifiedError;
}
/// Extension trait that enables method-call syntax for classification,
/// e.g. `err.classify(&ctx)`.
///
/// NOTE(review): presumably implemented for `ProviderError` elsewhere; no
/// impl is visible in this file.
pub trait ProviderErrorExt {
/// Classifies `self` using the supplied request metadata.
fn classify(&self, context: &ErrorContext) -> ClassifiedError;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ClassifiedError` with the given flags; the message is irrelevant to `summary`.
    fn classified(
        reason: FailoverReason,
        status_code: Option<u16>,
        retryable: bool,
        should_compress: bool,
        should_fallback: bool,
        should_rotate_credential: bool,
    ) -> ClassifiedError {
        ClassifiedError {
            reason,
            status_code,
            message: "test".to_string(),
            retryable,
            should_compress,
            should_fallback,
            should_rotate_credential,
        }
    }

    #[test]
    fn test_failover_reason_is_retryable() {
        assert!(FailoverReason::RateLimit.is_retryable());
        assert!(FailoverReason::Timeout.is_retryable());
        assert!(FailoverReason::Unknown.is_retryable());
        assert!(!FailoverReason::Billing.is_retryable());
        assert!(!FailoverReason::AuthPermanent.is_retryable());
        assert!(!FailoverReason::ModelNotFound.is_retryable());
        assert!(!FailoverReason::FormatError.is_retryable());
    }

    #[test]
    fn test_failover_reason_should_compress() {
        assert!(FailoverReason::ContextOverflow.should_compress());
        assert!(FailoverReason::PayloadTooLarge.should_compress());
        assert!(FailoverReason::LongContextTier.should_compress());
        assert!(!FailoverReason::RateLimit.should_compress());
    }

    #[test]
    fn test_failover_reason_should_fallback() {
        assert!(FailoverReason::Billing.should_fallback());
        assert!(FailoverReason::RateLimit.should_fallback());
        assert!(FailoverReason::Overloaded.should_fallback());
        assert!(FailoverReason::ModelNotFound.should_fallback());
        assert!(FailoverReason::AuthPermanent.should_fallback());
        assert!(!FailoverReason::ContextOverflow.should_fallback());
        assert!(!FailoverReason::Auth.should_fallback());
        assert!(!FailoverReason::ServerError.should_fallback());
    }

    #[test]
    fn test_failover_reason_should_rotate_credential() {
        assert!(FailoverReason::Auth.should_rotate_credential());
        assert!(!FailoverReason::AuthPermanent.should_rotate_credential());
        assert!(!FailoverReason::RateLimit.should_rotate_credential());
    }

    #[test]
    fn test_failover_reason_backoff_ms() {
        assert_eq!(
            FailoverReason::RateLimit.recommended_backoff_ms(),
            Some(5000)
        );
        assert_eq!(
            FailoverReason::Overloaded.recommended_backoff_ms(),
            Some(3000)
        );
        assert_eq!(FailoverReason::Timeout.recommended_backoff_ms(), Some(2000));
        assert_eq!(
            FailoverReason::ServerError.recommended_backoff_ms(),
            Some(1000)
        );
        assert_eq!(FailoverReason::Billing.recommended_backoff_ms(), None);
        assert_eq!(FailoverReason::Unknown.recommended_backoff_ms(), None);
    }

    #[test]
    fn test_classified_error_summary_priority() {
        // Compression wins even when every other flag is also set.
        assert_eq!(
            classified(FailoverReason::PayloadTooLarge, Some(413), true, true, true, true)
                .summary(),
            "错误原因:PayloadTooLarge | 状态码:413 | 策略:需要压缩上下文"
        );
        assert_eq!(
            classified(FailoverReason::RateLimit, Some(429), true, false, true, false).summary(),
            "错误原因:RateLimit | 状态码:429 | 策略:建议切换 Provider"
        );
        assert_eq!(
            classified(FailoverReason::Auth, Some(401), true, false, false, true).summary(),
            "错误原因:Auth | 状态码:401 | 策略:需要轮换凭证"
        );
        assert_eq!(
            classified(FailoverReason::Timeout, None, true, false, false, false).summary(),
            "错误原因:Timeout | 状态码:N/A | 策略:可重试"
        );
        assert_eq!(
            classified(FailoverReason::FormatError, Some(400), false, false, false, false)
                .summary(),
            "错误原因:FormatError | 状态码:400 | 策略:不可恢复"
        );
    }
}