1use crate::config::RetryConfig;
8use std::time::Duration;
9
/// Coarse category assigned to a failure, used to decide whether retrying makes sense.
///
/// Values are produced by [`classify_error`] from the text of an error message;
/// [`ErrorKind::is_retryable`] reports which categories are treated as transient.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorKind {
    /// The request was throttled (rate-limit wording or HTTP 429).
    RateLimited,
    /// A transport-level failure: connection, timeout, network or DNS wording.
    NetworkError,
    /// The remote side failed (HTTP 500/502/503 or equivalent wording).
    ServerError,
    /// Authentication or authorization was rejected (HTTP 401/403, invalid API key, ...).
    AuthError,
    /// The request itself was rejected as invalid or malformed (HTTP 400).
    InvalidRequest,
    /// The message mentions a tool that was not found.
    ToolNotFound,
    /// The message mentions the context together with a length, window or "exceeded".
    ContextOverflow,
    /// Nothing recognisable in the message; this is the classifier's fallback.
    Unknown,
}
30
31impl ErrorKind {
32 pub fn is_retryable(&self) -> bool {
34 matches!(
35 self,
36 ErrorKind::RateLimited
37 | ErrorKind::NetworkError
38 | ErrorKind::ServerError
39 | ErrorKind::Unknown
40 )
41 }
42}
43
44pub fn classify_error(error_msg: &str) -> ErrorKind {
49 let lower = error_msg.to_lowercase();
50
51 if lower.contains("rate limit") || lower.contains("429") || lower.contains("too many requests")
52 {
53 ErrorKind::RateLimited
54 } else if lower.contains("connection")
55 || lower.contains("timeout")
56 || lower.contains("network")
57 || lower.contains("dns")
58 {
59 ErrorKind::NetworkError
60 } else if lower.contains("500")
61 || lower.contains("502")
62 || lower.contains("503")
63 || lower.contains("internal server error")
64 || lower.contains("service unavailable")
65 {
66 ErrorKind::ServerError
67 } else if lower.contains("auth")
68 || lower.contains("unauthorized")
69 || lower.contains("401")
70 || lower.contains("403")
71 || lower.contains("forbidden")
72 || lower.contains("invalid api key")
73 {
74 ErrorKind::AuthError
75 } else if lower.contains("context")
76 && (lower.contains("length") || lower.contains("window") || lower.contains("exceeded"))
77 {
78 ErrorKind::ContextOverflow
79 } else if lower.contains("not found") && lower.contains("tool") {
80 ErrorKind::ToolNotFound
81 } else if lower.contains("invalid")
82 || lower.contains("malformed")
83 || lower.contains("400")
84 || lower.contains("bad request")
85 {
86 ErrorKind::InvalidRequest
87 } else {
88 ErrorKind::Unknown
89 }
90}
91
92pub struct RetryHandler {
94 config: RetryConfig,
95 attempt: u32,
96}
97
98impl RetryHandler {
99 pub fn new(config: RetryConfig) -> Self {
101 Self { config, attempt: 0 }
102 }
103
104 pub fn should_retry(&mut self, error_msg: &str) -> Option<Duration> {
107 let kind = classify_error(error_msg);
108
109 if !kind.is_retryable() {
110 tracing::warn!(
111 error_kind = ?kind,
112 "Non-retryable error, failing immediately"
113 );
114 return None;
115 }
116
117 if self.attempt >= self.config.max_retries {
118 tracing::warn!(
119 attempt = self.attempt,
120 max = self.config.max_retries,
121 "Max retries exceeded"
122 );
123 return None;
124 }
125
126 let delay = self.next_delay();
127 self.attempt += 1;
128
129 tracing::info!(
130 attempt = self.attempt,
131 delay_ms = delay.as_millis() as u64,
132 error_kind = ?kind,
133 "Retrying after transient error"
134 );
135
136 Some(delay)
137 }
138
139 fn next_delay(&self) -> Duration {
141 let base = self.config.initial_delay.as_millis() as f64;
142 let multiplied = base * self.config.backoff_multiplier.powi(self.attempt as i32);
143 let capped = multiplied.min(self.config.max_delay.as_millis() as f64);
144 Duration::from_millis(capped as u64)
145 }
146
147 pub fn reset(&mut self) {
149 self.attempt = 0;
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Rate-limit wording and the HTTP 429 status both map to `RateLimited`.
    #[test]
    fn test_classify_rate_limit() {
        assert_eq!(
            classify_error("Rate limit exceeded"),
            ErrorKind::RateLimited
        );
        assert_eq!(
            classify_error("HTTP 429 Too Many Requests"),
            ErrorKind::RateLimited
        );
    }

    /// Connection and timeout wording maps to `NetworkError`.
    #[test]
    fn test_classify_network() {
        assert_eq!(
            classify_error("Connection refused"),
            ErrorKind::NetworkError
        );
        assert_eq!(
            classify_error("Request timeout after 30s"),
            ErrorKind::NetworkError
        );
    }

    /// 5xx status codes and their reason phrases map to `ServerError`.
    #[test]
    fn test_classify_server() {
        assert_eq!(
            classify_error("HTTP 500 Internal Server Error"),
            ErrorKind::ServerError
        );
        assert_eq!(
            classify_error("503 Service Unavailable"),
            ErrorKind::ServerError
        );
    }

    /// Auth failures win over the generic "invalid" keyword.
    #[test]
    fn test_classify_auth() {
        assert_eq!(classify_error("401 Unauthorized"), ErrorKind::AuthError);
        assert_eq!(
            classify_error("Invalid API key provided"),
            ErrorKind::AuthError
        );
    }

    /// "context" combined with a length/window/exceeded keyword maps to `ContextOverflow`.
    #[test]
    fn test_classify_context_overflow() {
        assert_eq!(
            classify_error("Context length exceeded: 128000 tokens"),
            ErrorKind::ContextOverflow
        );
    }

    /// A message naming a tool that was not found maps to `ToolNotFound`.
    #[test]
    fn test_classify_tool_not_found() {
        assert_eq!(
            classify_error("Tool 'search' not found"),
            ErrorKind::ToolNotFound
        );
    }

    /// Malformed input and HTTP 400 map to `InvalidRequest`.
    #[test]
    fn test_classify_invalid_request() {
        assert_eq!(
            classify_error("Malformed JSON in request body"),
            ErrorKind::InvalidRequest
        );
        assert_eq!(
            classify_error("400 Bad Request"),
            ErrorKind::InvalidRequest
        );
    }

    /// Anything unrecognised falls back to `Unknown`.
    #[test]
    fn test_classify_unknown() {
        assert_eq!(
            classify_error("Something weird happened"),
            ErrorKind::Unknown
        );
    }

    /// Pins the retryable / non-retryable split for every variant.
    #[test]
    fn test_retryable() {
        assert!(ErrorKind::RateLimited.is_retryable());
        assert!(ErrorKind::NetworkError.is_retryable());
        assert!(ErrorKind::ServerError.is_retryable());
        assert!(ErrorKind::Unknown.is_retryable());
        assert!(!ErrorKind::AuthError.is_retryable());
        assert!(!ErrorKind::InvalidRequest.is_retryable());
        assert!(!ErrorKind::ToolNotFound.is_retryable());
        assert!(!ErrorKind::ContextOverflow.is_retryable());
    }

    /// Delays double from `initial_delay`, and the budget of three retries is enforced.
    #[test]
    fn test_retry_handler_backoff() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        };
        let mut handler = RetryHandler::new(config);

        let delay = handler.should_retry("rate limit exceeded").unwrap();
        assert_eq!(delay, Duration::from_secs(1));

        let delay = handler.should_retry("rate limit exceeded").unwrap();
        assert_eq!(delay, Duration::from_secs(2));

        let delay = handler.should_retry("rate limit exceeded").unwrap();
        assert_eq!(delay, Duration::from_secs(4));

        // The fourth request exceeds `max_retries`.
        assert!(handler.should_retry("rate limit exceeded").is_none());
    }

    /// Non-retryable errors are refused immediately.
    #[test]
    fn test_retry_handler_non_retryable() {
        let mut handler = RetryHandler::new(RetryConfig::default());
        assert!(handler.should_retry("401 Unauthorized").is_none());
    }

    /// `reset` restores the retry budget after it has been used up.
    #[test]
    fn test_retry_handler_reset() {
        let mut handler = RetryHandler::new(RetryConfig {
            max_retries: 1,
            ..Default::default()
        });

        handler.should_retry("rate limit").unwrap();
        assert!(handler.should_retry("rate limit").is_none());

        handler.reset();
        assert!(handler.should_retry("rate limit").is_some());
    }
}