claude_agent/client/
fallback.rs1use std::collections::HashSet;
4
5#[derive(Debug, Clone)]
6pub struct FallbackConfig {
7 pub fallback_model: String,
8 pub triggers: HashSet<FallbackTrigger>,
9 pub max_retries: u32,
10}
11
12impl FallbackConfig {
13 pub fn new(fallback_model: impl Into<String>) -> Self {
14 Self {
15 fallback_model: fallback_model.into(),
16 triggers: Self::default_triggers(),
17 max_retries: 1,
18 }
19 }
20
21 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
22 self.max_retries = max_retries;
23 self
24 }
25
26 pub fn with_trigger(mut self, trigger: FallbackTrigger) -> Self {
27 self.triggers.insert(trigger);
28 self
29 }
30
31 pub fn with_triggers(mut self, triggers: impl IntoIterator<Item = FallbackTrigger>) -> Self {
32 self.triggers.extend(triggers);
33 self
34 }
35
36 pub fn should_fallback(&self, error: &crate::Error) -> bool {
37 self.triggers.iter().any(|t| t.matches(error))
38 }
39
40 fn default_triggers() -> HashSet<FallbackTrigger> {
41 let mut triggers = HashSet::new();
42 triggers.insert(FallbackTrigger::Overloaded);
43 triggers.insert(FallbackTrigger::RateLimited);
44 triggers
45 }
46}
47
/// A single condition under which an error qualifies for fallback.
///
/// Each variant is checked against an error by [`FallbackTrigger::matches`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FallbackTrigger {
    /// Matches errors for which `crate::Error::is_overloaded` returns `true`.
    Overloaded,
    /// Matches the `crate::Error::RateLimit` variant.
    RateLimited,
    /// Matches errors whose `status_code()` equals the wrapped HTTP status.
    HttpStatus(u16),
    /// Matches the `crate::Error::Timeout` variant.
    Timeout,
}
55
56impl FallbackTrigger {
57 pub fn matches(&self, error: &crate::Error) -> bool {
58 match self {
59 Self::Overloaded => error.is_overloaded(),
60 Self::RateLimited => matches!(error, crate::Error::RateLimit { .. }),
61 Self::HttpStatus(code) => error.status_code() == Some(*code),
62 Self::Timeout => matches!(error, crate::Error::Timeout(_)),
63 }
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// The default trigger set reacts to an overloaded API error but ignores
    /// an authentication failure.
    #[test]
    fn test_fallback_trigger_overloaded() {
        let cfg = FallbackConfig::new("claude-haiku-3-5");

        let overloaded = crate::Error::Api {
            message: "Model is overloaded".to_string(),
            status: Some(529),
            error_type: None,
        };
        let unauthorized = crate::Error::Api {
            message: "Invalid API key".to_string(),
            status: Some(401),
            error_type: None,
        };

        assert!(cfg.should_fallback(&overloaded));
        assert!(!cfg.should_fallback(&unauthorized));
    }

    /// Rate-limit errors are part of the default trigger set.
    #[test]
    fn test_fallback_trigger_rate_limit() {
        let cfg = FallbackConfig::new("claude-haiku-3-5");
        let throttled = crate::Error::RateLimit {
            retry_after: Some(Duration::from_secs(60)),
        };

        assert!(cfg.should_fallback(&throttled));
    }

    /// Triggers added through the builder are honoured by `should_fallback`.
    #[test]
    fn test_custom_triggers() {
        let cfg = FallbackConfig::new("claude-haiku-3-5")
            .with_trigger(FallbackTrigger::Timeout)
            .with_trigger(FallbackTrigger::HttpStatus(500));

        let timed_out = crate::Error::Timeout(Duration::from_secs(30));
        assert!(cfg.should_fallback(&timed_out));

        let internal = crate::Error::Api {
            message: "Internal server error".to_string(),
            status: Some(500),
            error_type: None,
        };
        assert!(cfg.should_fallback(&internal));
    }
}