claude_agent/client/
fallback.rs

//! Automatic model fallback for handling overload and rate-limit errors.
2
3use std::collections::HashSet;
4
5#[derive(Debug, Clone)]
6pub struct FallbackConfig {
7    pub fallback_model: String,
8    pub triggers: HashSet<FallbackTrigger>,
9    pub max_retries: u32,
10}
11
12impl FallbackConfig {
13    pub fn new(fallback_model: impl Into<String>) -> Self {
14        Self {
15            fallback_model: fallback_model.into(),
16            triggers: Self::default_triggers(),
17            max_retries: 1,
18        }
19    }
20
21    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
22        self.max_retries = max_retries;
23        self
24    }
25
26    pub fn with_trigger(mut self, trigger: FallbackTrigger) -> Self {
27        self.triggers.insert(trigger);
28        self
29    }
30
31    pub fn with_triggers(mut self, triggers: impl IntoIterator<Item = FallbackTrigger>) -> Self {
32        self.triggers.extend(triggers);
33        self
34    }
35
36    pub fn should_fallback(&self, error: &crate::Error) -> bool {
37        self.triggers.iter().any(|t| t.matches(error))
38    }
39
40    fn default_triggers() -> HashSet<FallbackTrigger> {
41        let mut triggers = HashSet::new();
42        triggers.insert(FallbackTrigger::Overloaded);
43        triggers.insert(FallbackTrigger::RateLimited);
44        triggers
45    }
46}
47
/// A condition on an error that can cause a switch to the fallback model.
///
/// The type is `Copy` and hashable, so it can be stored in a `HashSet`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FallbackTrigger {
    /// Fires when the error reports itself as overloaded (`Error::is_overloaded`).
    Overloaded,
    /// Fires on the `Error::RateLimit` variant.
    RateLimited,
    /// Fires when the error's status code equals the wrapped value.
    HttpStatus(u16),
    /// Fires on the `Error::Timeout` variant.
    Timeout,
}
55
56impl FallbackTrigger {
57    pub fn matches(&self, error: &crate::Error) -> bool {
58        match self {
59            Self::Overloaded => error.is_overloaded(),
60            Self::RateLimited => matches!(error, crate::Error::RateLimit { .. }),
61            Self::HttpStatus(code) => error.status_code() == Some(*code),
62            Self::Timeout => matches!(error, crate::Error::Timeout(_)),
63        }
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// Fallback model name shared by every test configuration.
    const FALLBACK_MODEL: &str = "claude-haiku-3-5";

    /// A configuration carrying only the default triggers.
    fn default_config() -> FallbackConfig {
        FallbackConfig::new(FALLBACK_MODEL)
    }

    /// Default triggers fire for a 529 API error but not for a 401.
    #[test]
    fn test_fallback_trigger_overloaded() {
        let cfg = default_config();

        let overloaded = crate::Error::Api {
            message: "Model is overloaded".to_string(),
            status: Some(529),
            error_type: None,
        };
        let unauthorized = crate::Error::Api {
            message: "Invalid API key".to_string(),
            status: Some(401),
            error_type: None,
        };

        assert!(cfg.should_fallback(&overloaded));
        assert!(!cfg.should_fallback(&unauthorized));
    }

    /// Default triggers fire for the dedicated rate-limit error variant.
    #[test]
    fn test_fallback_trigger_rate_limit() {
        let cfg = default_config();
        let retry_after = std::time::Duration::from_secs(60);

        let rate_limited = crate::Error::RateLimit {
            retry_after: Some(retry_after),
        };

        assert!(cfg.should_fallback(&rate_limited));
    }

    /// Triggers added through `with_trigger` fire alongside the defaults.
    #[test]
    fn test_custom_triggers() {
        let cfg = default_config()
            .with_trigger(FallbackTrigger::Timeout)
            .with_trigger(FallbackTrigger::HttpStatus(500));

        let timed_out = crate::Error::Timeout(std::time::Duration::from_secs(30));
        assert!(cfg.should_fallback(&timed_out));

        let internal = crate::Error::Api {
            message: "Internal server error".to_string(),
            status: Some(500),
            error_type: None,
        };
        assert!(cfg.should_fallback(&internal));
    }
}