1use std::time::Duration;
9
/// Tunable knobs for retry and exponential-backoff behavior.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Retry budget shared by rate-limit and stream-interruption errors.
    pub max_retries: u32,
    /// Delay applied to the first backoff retry; grows by `multiplier` after that.
    pub initial_backoff: Duration,
    /// Ceiling for the computed backoff delay.
    pub max_backoff: Duration,
    /// Exponential growth factor applied per attempt.
    pub multiplier: f64,
    /// Overload retries allowed before escalating (e.g. to a fallback model).
    pub max_overload_retries: u32,
}
24
25impl Default for RetryConfig {
26 fn default() -> Self {
27 Self {
28 max_retries: 3,
29 initial_backoff: Duration::from_millis(1000),
30 max_backoff: Duration::from_secs(60),
31 multiplier: 2.0,
32 max_overload_retries: 3,
33 }
34 }
35}
36
/// Mutable counters tracking retry progress across a sequence of failures.
#[derive(Debug, Default)]
pub struct RetryState {
    /// Failures since the last `reset()`; drives stream-interruption limits.
    pub consecutive_failures: u32,
    /// Rate-limit errors seen; cleared by `reset()`.
    pub rate_limit_retries: u32,
    /// Overload errors on the current model; zeroed when switching to fallback.
    /// Note: NOT cleared by `reset()` (see the reset tests below).
    pub overload_retries: u32,
    /// True once we have escalated to the fallback model; survives `reset()`.
    pub using_fallback: bool,
}
49
50impl RetryState {
51 pub fn next_action(&mut self, error: &RetryableError, config: &RetryConfig) -> RetryAction {
53 self.consecutive_failures += 1;
54
55 match error {
56 RetryableError::RateLimited { retry_after } => {
57 self.rate_limit_retries += 1;
58 if self.rate_limit_retries > config.max_retries {
59 return RetryAction::Abort("Rate limit retries exhausted".into());
60 }
61 RetryAction::Retry {
62 after: Duration::from_millis(*retry_after),
63 }
64 }
65 RetryableError::Overloaded => {
66 self.overload_retries += 1;
67 if self.overload_retries > config.max_overload_retries {
68 if !self.using_fallback {
69 self.using_fallback = true;
70 self.overload_retries = 0;
71 return RetryAction::FallbackModel;
72 }
73 return RetryAction::Abort("Overload retries exhausted on fallback".into());
74 }
75 let backoff = calculate_backoff(
76 self.overload_retries,
77 config.initial_backoff,
78 config.max_backoff,
79 config.multiplier,
80 );
81 RetryAction::Retry { after: backoff }
82 }
83 RetryableError::StreamInterrupted => {
84 if self.consecutive_failures > config.max_retries {
85 return RetryAction::Abort("Stream retry limit reached".into());
86 }
87 let backoff = calculate_backoff(
88 self.consecutive_failures,
89 config.initial_backoff,
90 config.max_backoff,
91 config.multiplier,
92 );
93 RetryAction::Retry { after: backoff }
94 }
95 RetryableError::NonRetryable(msg) => RetryAction::Abort(msg.clone()),
96 }
97 }
98
99 pub fn reset(&mut self) {
101 self.consecutive_failures = 0;
102 self.rate_limit_retries = 0;
103 }
105}
106
/// Failure classifications the retry machinery knows how to handle.
///
/// Derives `Debug`/`Clone` for consistency with the other public types in
/// this module (`RetryConfig`, `RetryAction`), so errors can be logged and
/// re-submitted.
#[derive(Debug, Clone)]
pub enum RetryableError {
    /// Server asked us to slow down; `retry_after` is the wait in milliseconds.
    RateLimited { retry_after: u64 },
    /// Upstream is overloaded; retry with backoff or fall back to another model.
    Overloaded,
    /// The response stream dropped mid-flight.
    StreamInterrupted,
    /// Permanent failure; the message is surfaced to the caller.
    NonRetryable(String),
}
114
/// What the caller should do after reporting a failure to
/// `RetryState::next_action`.
#[derive(Debug)]
pub enum RetryAction {
    /// Wait for `after`, then retry the same request.
    Retry { after: Duration },
    /// Switch to the fallback model and try again.
    FallbackModel,
    /// Stop retrying; the string explains why.
    Abort(String),
}
125
/// Exponential backoff with jitter: `initial * multiplier^(attempt - 1)`,
/// plus up to 10% jitter, never exceeding `max`.
///
/// `attempt` is 1-based; an `attempt` of 0 is treated defensively as the
/// first attempt (a negative exponent would shrink the delay below `initial`).
fn calculate_backoff(attempt: u32, initial: Duration, max: Duration, multiplier: f64) -> Duration {
    let max_ms = max.as_millis() as f64;
    let exponent = attempt.max(1) as i32 - 1;
    let base = initial.as_millis() as f64 * multiplier.powi(exponent);
    let capped = base.min(max_ms);
    // Up to 10% jitter to avoid thundering-herd retries. Re-clamp afterwards:
    // previously the jitter was added AFTER capping, letting the final delay
    // exceed `max` by up to 10%.
    let jitter = capped * 0.1 * rand_f64();
    Duration::from_millis((capped + jitter).min(max_ms) as u64)
}

/// Cheap, non-cryptographic pseudo-random value in [0, 1), derived from the
/// sub-second portion of the wall clock. Adequate for retry jitter only.
fn rand_f64() -> f64 {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    (nanos % 1000) as f64 / 1000.0
}
143
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let c = RetryConfig::default();
        assert_eq!(c.max_retries, 3);
        assert!(c.multiplier > 1.0);
    }

    #[test]
    fn test_retry_on_rate_limit() {
        let mut state = RetryState::default();
        let config = RetryConfig::default();
        let err = RetryableError::RateLimited { retry_after: 500 };
        match state.next_action(&err, &config) {
            // The server-supplied delay is honored exactly.
            RetryAction::Retry { after } => assert_eq!(after, Duration::from_millis(500)),
            other => panic!("Expected Retry, got {other:?}"),
        }
    }

    #[test]
    fn test_retry_exhaustion() {
        let mut state = RetryState::default();
        let config = RetryConfig {
            max_retries: 1,
            ..Default::default()
        };
        let err = RetryableError::RateLimited { retry_after: 100 };
        // First failure consumes the single allowed retry...
        let _ = state.next_action(&err, &config);
        // ...so the second must abort.
        match state.next_action(&err, &config) {
            RetryAction::Abort(_) => {}
            other => panic!("Expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn test_non_retryable_aborts() {
        let mut state = RetryState::default();
        let config = RetryConfig::default();
        let err = RetryableError::NonRetryable("bad request".into());
        match state.next_action(&err, &config) {
            RetryAction::Abort(msg) => assert!(msg.contains("bad request")),
            other => panic!("Expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn test_overload_escalates_to_fallback() {
        let mut state = RetryState::default();
        let config = RetryConfig {
            max_overload_retries: 2,
            ..Default::default()
        };
        let err = RetryableError::Overloaded;
        // Two retries within budget, then the third escalates.
        let _ = state.next_action(&err, &config);
        let _ = state.next_action(&err, &config);
        match state.next_action(&err, &config) {
            RetryAction::FallbackModel => {}
            other => panic!("Expected FallbackModel, got {other:?}"),
        }
    }

    #[test]
    fn test_reset_preserves_fallback() {
        let mut state = RetryState {
            using_fallback: true,
            consecutive_failures: 5,
            ..Default::default()
        };
        state.reset();
        assert_eq!(state.consecutive_failures, 0);
        // The fallback decision must survive a successful request.
        assert!(state.using_fallback);
    }

    #[test]
    fn test_backoff_increases_with_attempt() {
        let initial = Duration::from_millis(1000);
        let max = Duration::from_secs(60);
        let multiplier = 2.0;

        let _b1 = calculate_backoff(1, initial, max, multiplier);
        let b2 = calculate_backoff(2, initial, max, multiplier);
        let b3 = calculate_backoff(3, initial, max, multiplier);

        assert!(b2.as_millis() >= 1500, "b2 should be >= 1.5s, got {:?}", b2);
        assert!(b3.as_millis() >= 3000, "b3 should be >= 3s, got {:?}", b3);
    }

    #[test]
    fn test_reset_clears_rate_limit_retries() {
        let mut state = RetryState {
            consecutive_failures: 3,
            rate_limit_retries: 5,
            overload_retries: 2,
            using_fallback: false,
        };
        state.reset();
        assert_eq!(state.rate_limit_retries, 0);
        assert_eq!(state.consecutive_failures, 0);
        // Overload pressure deliberately persists across resets.
        assert_eq!(state.overload_retries, 2);
    }

    #[test]
    fn test_overloads_then_fallback_then_abort() {
        let mut state = RetryState::default();
        let config = RetryConfig {
            max_overload_retries: 1,
            ..Default::default()
        };
        let err = RetryableError::Overloaded;

        // One retry on the primary model.
        match state.next_action(&err, &config) {
            RetryAction::Retry { .. } => {}
            other => panic!("Expected Retry, got {other:?}"),
        }

        // Budget exhausted: escalate to the fallback model.
        match state.next_action(&err, &config) {
            RetryAction::FallbackModel => {}
            other => panic!("Expected FallbackModel, got {other:?}"),
        }
        assert!(state.using_fallback);

        // The fallback gets a fresh overload budget.
        match state.next_action(&err, &config) {
            RetryAction::Retry { .. } => {}
            other => panic!("Expected Retry on fallback, got {other:?}"),
        }

        // Exhausting the fallback too is terminal.
        match state.next_action(&err, &config) {
            RetryAction::Abort(msg) => assert!(msg.contains("fallback")),
            other => panic!("Expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn test_stream_interrupted_retries_then_aborts() {
        let mut state = RetryState::default();
        let config = RetryConfig {
            max_retries: 2,
            ..Default::default()
        };
        let err = RetryableError::StreamInterrupted;

        match state.next_action(&err, &config) {
            RetryAction::Retry { .. } => {}
            other => panic!("Expected Retry, got {other:?}"),
        }
        match state.next_action(&err, &config) {
            RetryAction::Retry { .. } => {}
            other => panic!("Expected Retry, got {other:?}"),
        }

        match state.next_action(&err, &config) {
            RetryAction::Abort(msg) => assert!(msg.contains("Stream")),
            other => panic!("Expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn test_retry_state_default_values() {
        let state = RetryState::default();
        assert_eq!(state.consecutive_failures, 0);
        assert_eq!(state.rate_limit_retries, 0);
        assert_eq!(state.overload_retries, 0);
        assert!(!state.using_fallback);
    }
}