1use crate::error::Error;
11use std::time::Duration;
12
/// Back-off curve used to space out successive retry attempts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RetryStrategyType {
    /// Wait the same base delay before every retry.
    Fixed,
    /// Double the delay on each attempt: `base * 2^(attempt - 1)`.
    Exponential,
    /// Grow the delay proportionally to the attempt number: `base * attempt`.
    Linear,
}
23
24#[derive(Debug, Clone)]
26pub struct RetryConfig {
27 pub max_retries: u32,
29 pub strategy_type: RetryStrategyType,
31 pub base_delay_ms: u64,
33 pub max_delay_ms: u64,
35 pub retry_on_network_error: bool,
37 pub retry_on_rate_limit: bool,
39 pub retry_on_server_error: bool,
41 pub retry_on_timeout: bool,
43 pub jitter_factor: f64,
45}
46
47impl Default for RetryConfig {
48 fn default() -> Self {
49 Self {
50 max_retries: 3,
51 strategy_type: RetryStrategyType::Exponential,
52 base_delay_ms: 100,
53 max_delay_ms: 30000,
54 retry_on_network_error: true,
55 retry_on_rate_limit: true,
56 retry_on_server_error: true,
57 retry_on_timeout: true,
58 jitter_factor: 0.1,
59 }
60 }
61}
62
63impl RetryConfig {
64 pub fn conservative() -> Self {
66 Self {
67 max_retries: 2,
68 strategy_type: RetryStrategyType::Fixed,
69 base_delay_ms: 500,
70 max_delay_ms: 5000,
71 retry_on_network_error: true,
72 retry_on_rate_limit: true,
73 retry_on_server_error: false,
74 retry_on_timeout: false,
75 jitter_factor: 0.0,
76 }
77 }
78
79 pub fn aggressive() -> Self {
81 Self {
82 max_retries: 5,
83 strategy_type: RetryStrategyType::Exponential,
84 base_delay_ms: 200,
85 max_delay_ms: 60000,
86 retry_on_network_error: true,
87 retry_on_rate_limit: true,
88 retry_on_server_error: true,
89 retry_on_timeout: true,
90 jitter_factor: 0.2,
91 }
92 }
93
94 pub fn rate_limit_only() -> Self {
96 Self {
97 max_retries: 3,
98 strategy_type: RetryStrategyType::Linear,
99 base_delay_ms: 2000,
100 max_delay_ms: 10000,
101 retry_on_network_error: false,
102 retry_on_rate_limit: true,
103 retry_on_server_error: false,
104 retry_on_timeout: false,
105 jitter_factor: 0.0,
106 }
107 }
108}
109
110#[derive(Debug)]
112pub struct RetryStrategy {
113 config: RetryConfig,
114}
115
116impl RetryStrategy {
117 pub fn new(config: RetryConfig) -> Self {
119 Self { config }
120 }
121
122 pub fn default_strategy() -> Self {
124 Self::new(RetryConfig::default())
125 }
126
127 pub fn should_retry(&self, error: &Error, attempt: u32) -> bool {
138 if attempt > self.config.max_retries {
139 return false;
140 }
141 match error {
142 Error::Network(_) => self.config.retry_on_network_error,
143 Error::RateLimit { .. } => self.config.retry_on_rate_limit,
144 Error::Exchange(details) => {
145 if self.config.retry_on_server_error && self.is_server_error(&details.message) {
146 return true;
147 }
148 if self.config.retry_on_timeout && self.is_timeout_error(&details.message) {
149 return true;
150 }
151 false
152 }
153 _ => false,
154 }
155 }
156
157 pub fn calculate_delay(&self, attempt: u32, error: &Error) -> Duration {
168 let base_delay = match self.config.strategy_type {
169 RetryStrategyType::Fixed => self.config.base_delay_ms,
170 RetryStrategyType::Exponential => {
171 self.config.base_delay_ms * 2_u64.pow(attempt.saturating_sub(1))
172 }
173 RetryStrategyType::Linear => self.config.base_delay_ms * attempt as u64,
174 };
175
176 let mut delay = base_delay.min(self.config.max_delay_ms);
177
178 if matches!(error, Error::RateLimit { .. }) {
179 delay = delay.max(2000);
180 }
181 if self.config.jitter_factor > 0.0 {
182 delay = self.apply_jitter(delay);
183 }
184
185 Duration::from_millis(delay)
186 }
187
188 fn apply_jitter(&self, delay_ms: u64) -> u64 {
190 use rand::Rng;
191 let mut rng = rand::rngs::ThreadRng::default();
192 let jitter_range = (delay_ms as f64 * self.config.jitter_factor) as u64;
193 let jitter = rng.random_range(0..=jitter_range);
194 delay_ms + jitter
195 }
196
197 fn is_server_error(&self, msg: &str) -> bool {
199 let msg_lower = msg.to_lowercase();
200 msg_lower.contains("500")
201 || msg_lower.contains("502")
202 || msg_lower.contains("503")
203 || msg_lower.contains("504")
204 || msg_lower.contains("internal server error")
205 || msg_lower.contains("bad gateway")
206 || msg_lower.contains("service unavailable")
207 || msg_lower.contains("gateway timeout")
208 }
209
210 fn is_timeout_error(&self, msg: &str) -> bool {
212 let msg_lower = msg.to_lowercase();
213 msg_lower.contains("timeout")
214 || msg_lower.contains("timed out")
215 || msg_lower.contains("408")
216 }
217
218 pub fn config(&self) -> &RetryConfig {
220 &self.config
221 }
222
223 pub fn max_retries(&self) -> u32 {
225 self.config.max_retries
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the delay, in milliseconds, computed for each attempt in `attempts`.
    fn delays_ms(
        strategy: &RetryStrategy,
        error: &Error,
        attempts: std::ops::RangeInclusive<u32>,
    ) -> Vec<u128> {
        attempts
            .map(|attempt| strategy.calculate_delay(attempt, error).as_millis())
            .collect()
    }

    #[test]
    fn test_retry_config_default() {
        let RetryConfig {
            max_retries,
            strategy_type,
            base_delay_ms,
            retry_on_network_error,
            retry_on_rate_limit,
            ..
        } = RetryConfig::default();

        assert_eq!(max_retries, 3);
        assert_eq!(strategy_type, RetryStrategyType::Exponential);
        assert_eq!(base_delay_ms, 100);
        assert!(retry_on_network_error);
        assert!(retry_on_rate_limit);
    }

    #[test]
    fn test_retry_config_conservative() {
        let preset = RetryConfig::conservative();

        assert_eq!(preset.max_retries, 2);
        assert_eq!(preset.strategy_type, RetryStrategyType::Fixed);
        assert!(!preset.retry_on_server_error);
    }

    #[test]
    fn test_retry_config_aggressive() {
        let preset = RetryConfig::aggressive();

        assert_eq!(preset.max_retries, 5);
        assert!(preset.retry_on_server_error);
        assert!(preset.retry_on_timeout);
    }

    #[test]
    fn test_should_retry_network_error() {
        let strategy = RetryStrategy::default_strategy();
        let error = Error::network("Connection failed");

        // Attempts up to and including `max_retries` (3) are retried.
        for attempt in 1..=3 {
            assert!(strategy.should_retry(&error, attempt));
        }
        assert!(!strategy.should_retry(&error, 4));
    }

    #[test]
    fn test_should_retry_rate_limit() {
        let strategy = RetryStrategy::default_strategy();
        let error = Error::rate_limit("Rate limit exceeded", None);

        for attempt in [1, 3] {
            assert!(strategy.should_retry(&error, attempt));
        }
    }

    #[test]
    fn test_should_not_retry_invalid_request() {
        let strategy = RetryStrategy::default_strategy();
        let error = Error::invalid_request("Bad request");

        assert!(!strategy.should_retry(&error, 1));
    }

    #[test]
    fn test_calculate_delay_fixed() {
        let strategy = RetryStrategy::new(RetryConfig {
            strategy_type: RetryStrategyType::Fixed,
            base_delay_ms: 1000,
            jitter_factor: 0.0,
            ..Default::default()
        });
        let error = Error::network("test");

        assert_eq!(
            delays_ms(&strategy, &error, 1..=3),
            vec![1000_u128, 1000, 1000]
        );
    }

    #[test]
    fn test_calculate_delay_exponential() {
        let strategy = RetryStrategy::new(RetryConfig {
            strategy_type: RetryStrategyType::Exponential,
            base_delay_ms: 100,
            max_delay_ms: 10000,
            jitter_factor: 0.0,
            ..Default::default()
        });
        let error = Error::network("test");

        assert_eq!(
            delays_ms(&strategy, &error, 1..=4),
            vec![100_u128, 200, 400, 800]
        );
    }

    #[test]
    fn test_calculate_delay_linear() {
        let strategy = RetryStrategy::new(RetryConfig {
            strategy_type: RetryStrategyType::Linear,
            base_delay_ms: 500,
            max_delay_ms: 10000,
            jitter_factor: 0.0,
            ..Default::default()
        });
        let error = Error::network("test");

        assert_eq!(
            delays_ms(&strategy, &error, 1..=3),
            vec![500_u128, 1000, 1500]
        );
    }

    #[test]
    fn test_calculate_delay_with_max_limit() {
        let strategy = RetryStrategy::new(RetryConfig {
            strategy_type: RetryStrategyType::Exponential,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
            jitter_factor: 0.0,
            ..Default::default()
        });
        let error = Error::network("test");

        // Attempts 4 and 5 would be 8000 and 16000 ms without the 5000 ms cap.
        assert_eq!(
            delays_ms(&strategy, &error, 1..=5),
            vec![1000_u128, 2000, 4000, 5000, 5000]
        );
    }

    #[test]
    fn test_is_server_error() {
        let strategy = RetryStrategy::default_strategy();

        for message in [
            "500 Internal Server Error",
            "502 Bad Gateway",
            "503 Service Unavailable",
            "504 Gateway Timeout",
        ] {
            assert!(strategy.is_server_error(message));
        }
        for message in ["400 Bad Request", "404 Not Found"] {
            assert!(!strategy.is_server_error(message));
        }
    }

    #[test]
    fn test_is_timeout_error() {
        let strategy = RetryStrategy::default_strategy();

        for message in [
            "Request timeout",
            "Connection timed out",
            "408 Request Timeout",
        ] {
            assert!(strategy.is_timeout_error(message));
        }
        assert!(!strategy.is_timeout_error("Connection refused"));
    }

    #[test]
    fn test_rate_limit_error_minimum_delay() {
        // The base delay is far below the rate-limit floor on purpose.
        let strategy = RetryStrategy::new(RetryConfig {
            strategy_type: RetryStrategyType::Fixed,
            base_delay_ms: 100,
            jitter_factor: 0.0,
            ..Default::default()
        });
        let error = Error::rate_limit("Rate limit exceeded", None);

        let waited_ms = strategy.calculate_delay(1, &error).as_millis();
        assert!(waited_ms >= 2000);
    }
}