1use crate::error::Error;
11use std::time::Duration;
12
/// Shape of the back-off curve used between retry attempts.
///
/// [`RetryStrategy::calculate_delay`] turns an attempt number into a delay
/// according to the selected variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryStrategyType {
    /// Every retry waits `base_delay_ms`.
    Fixed,
    /// The delay doubles per attempt: `base_delay_ms * 2^(attempt - 1)`.
    Exponential,
    /// The delay grows by `base_delay_ms` per attempt: `base_delay_ms * attempt`.
    Linear,
}
23
24#[derive(Debug, Clone)]
26pub struct RetryConfig {
27 pub max_retries: u32,
29 pub strategy_type: RetryStrategyType,
31 pub base_delay_ms: u64,
33 pub max_delay_ms: u64,
35 pub retry_on_network_error: bool,
37 pub retry_on_rate_limit: bool,
39 pub retry_on_server_error: bool,
41 pub retry_on_timeout: bool,
43 pub jitter_factor: f64,
45}
46
47impl Default for RetryConfig {
48 fn default() -> Self {
49 Self {
50 max_retries: 3,
51 strategy_type: RetryStrategyType::Exponential,
52 base_delay_ms: 100,
53 max_delay_ms: 30000,
54 retry_on_network_error: true,
55 retry_on_rate_limit: true,
56 retry_on_server_error: true,
57 retry_on_timeout: true,
58 jitter_factor: 0.1,
59 }
60 }
61}
62
63impl RetryConfig {
64 pub fn conservative() -> Self {
66 Self {
67 max_retries: 2,
68 strategy_type: RetryStrategyType::Fixed,
69 base_delay_ms: 500,
70 max_delay_ms: 5000,
71 retry_on_network_error: true,
72 retry_on_rate_limit: true,
73 retry_on_server_error: false,
74 retry_on_timeout: false,
75 jitter_factor: 0.0,
76 }
77 }
78
79 pub fn aggressive() -> Self {
81 Self {
82 max_retries: 5,
83 strategy_type: RetryStrategyType::Exponential,
84 base_delay_ms: 200,
85 max_delay_ms: 60000,
86 retry_on_network_error: true,
87 retry_on_rate_limit: true,
88 retry_on_server_error: true,
89 retry_on_timeout: true,
90 jitter_factor: 0.2,
91 }
92 }
93
94 pub fn rate_limit_only() -> Self {
96 Self {
97 max_retries: 3,
98 strategy_type: RetryStrategyType::Linear,
99 base_delay_ms: 2000,
100 max_delay_ms: 10000,
101 retry_on_network_error: false,
102 retry_on_rate_limit: true,
103 retry_on_server_error: false,
104 retry_on_timeout: false,
105 jitter_factor: 0.0,
106 }
107 }
108}
109
110#[derive(Debug)]
112pub struct RetryStrategy {
113 config: RetryConfig,
114}
115
116impl RetryStrategy {
117 pub fn new(config: RetryConfig) -> Self {
119 Self { config }
120 }
121
122 pub fn default_strategy() -> Self {
124 Self::new(RetryConfig::default())
125 }
126
127 pub fn should_retry(&self, error: &Error, attempt: u32) -> bool {
138 if attempt > self.config.max_retries {
139 return false;
140 }
141 match error {
142 Error::Network(_) => self.config.retry_on_network_error,
143 Error::RateLimit { .. } => self.config.retry_on_rate_limit,
144 Error::Exchange(details) => {
145 if self.config.retry_on_server_error && Self::is_server_error(&details.message) {
146 return true;
147 }
148 if self.config.retry_on_timeout && Self::is_timeout_error(&details.message) {
149 return true;
150 }
151 false
152 }
153 _ => false,
154 }
155 }
156
157 pub fn calculate_delay(&self, attempt: u32, error: &Error) -> Duration {
168 let base_delay = match self.config.strategy_type {
169 RetryStrategyType::Fixed => self.config.base_delay_ms,
170 RetryStrategyType::Exponential => {
171 self.config.base_delay_ms * 2_u64.pow(attempt.saturating_sub(1))
172 }
173 RetryStrategyType::Linear => self.config.base_delay_ms * u64::from(attempt),
174 };
175
176 let mut delay = base_delay.min(self.config.max_delay_ms);
177
178 if matches!(error, Error::RateLimit { .. }) {
179 delay = delay.max(2000);
180 }
181 if self.config.jitter_factor > 0.0 {
182 delay = self.apply_jitter(delay);
183 }
184
185 Duration::from_millis(delay)
186 }
187
188 fn apply_jitter(&self, delay_ms: u64) -> u64 {
190 use rand::Rng;
191 let mut rng = rand::rngs::ThreadRng::default();
192 #[allow(clippy::cast_precision_loss)]
193 #[allow(clippy::cast_possible_truncation)]
194 let jitter_range = (delay_ms as f64 * self.config.jitter_factor) as u64;
195 let jitter = rng.random_range(0..=jitter_range);
196 delay_ms + jitter
197 }
198
199 fn is_server_error(msg: &str) -> bool {
201 let msg_lower = msg.to_lowercase();
202 msg_lower.contains("500")
203 || msg_lower.contains("502")
204 || msg_lower.contains("503")
205 || msg_lower.contains("504")
206 || msg_lower.contains("internal server error")
207 || msg_lower.contains("bad gateway")
208 || msg_lower.contains("service unavailable")
209 || msg_lower.contains("gateway timeout")
210 }
211
212 fn is_timeout_error(msg: &str) -> bool {
214 let msg_lower = msg.to_lowercase();
215 msg_lower.contains("timeout")
216 || msg_lower.contains("timed out")
217 || msg_lower.contains("408")
218 }
219
220 pub fn config(&self) -> &RetryConfig {
222 &self.config
223 }
224
225 pub fn max_retries(&self) -> u32 {
227 self.config.max_retries
228 }
229}
230
#[cfg(test)]
mod tests {
    //! Unit tests for the retry presets, retry decisions and delay computation.
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.strategy_type, RetryStrategyType::Exponential);
        assert_eq!(config.base_delay_ms, 100);
        assert_eq!(config.max_delay_ms, 30000);
        assert!(config.retry_on_network_error);
        assert!(config.retry_on_rate_limit);
    }

    #[test]
    fn test_retry_config_conservative() {
        let config = RetryConfig::conservative();
        assert_eq!(config.max_retries, 2);
        assert_eq!(config.strategy_type, RetryStrategyType::Fixed);
        assert!(!config.retry_on_server_error);
    }

    #[test]
    fn test_retry_config_aggressive() {
        let config = RetryConfig::aggressive();
        assert_eq!(config.max_retries, 5);
        assert!(config.retry_on_server_error);
        assert!(config.retry_on_timeout);
    }

    #[test]
    fn test_retry_config_rate_limit_only() {
        let config = RetryConfig::rate_limit_only();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.strategy_type, RetryStrategyType::Linear);
        assert!(config.retry_on_rate_limit);
        assert!(!config.retry_on_network_error);
        assert!(!config.retry_on_server_error);
        assert!(!config.retry_on_timeout);
    }

    #[test]
    fn test_should_retry_network_error() {
        let strategy = RetryStrategy::default_strategy();
        let error = Error::network("Connection failed");

        assert!(strategy.should_retry(&error, 1));
        assert!(strategy.should_retry(&error, 2));
        assert!(strategy.should_retry(&error, 3));
        assert!(!strategy.should_retry(&error, 4));
    }

    #[test]
    fn test_should_retry_rate_limit() {
        let strategy = RetryStrategy::default_strategy();
        let error = Error::rate_limit("Rate limit exceeded", None);

        assert!(strategy.should_retry(&error, 1));
        assert!(strategy.should_retry(&error, 3));
        assert!(!strategy.should_retry(&error, 4));
    }

    #[test]
    fn test_should_retry_respects_disabled_flags() {
        let strategy = RetryStrategy::new(RetryConfig::rate_limit_only());

        assert!(!strategy.should_retry(&Error::network("Connection failed"), 1));
        assert!(strategy.should_retry(&Error::rate_limit("Rate limit exceeded", None), 1));
    }

    #[test]
    fn test_should_retry_conservative_limit() {
        let strategy = RetryStrategy::new(RetryConfig::conservative());
        let error = Error::network("Connection failed");

        assert!(strategy.should_retry(&error, 2));
        assert!(!strategy.should_retry(&error, 3));
    }

    #[test]
    fn test_should_not_retry_invalid_request() {
        let strategy = RetryStrategy::default_strategy();
        let error = Error::invalid_request("Bad request");

        assert!(!strategy.should_retry(&error, 1));
    }

    #[test]
    fn test_calculate_delay_fixed() {
        let config = RetryConfig {
            strategy_type: RetryStrategyType::Fixed,
            base_delay_ms: 1000,
            jitter_factor: 0.0,
            ..Default::default()
        };
        let strategy = RetryStrategy::new(config);
        let error = Error::network("test");

        assert_eq!(strategy.calculate_delay(1, &error).as_millis(), 1000);
        assert_eq!(strategy.calculate_delay(2, &error).as_millis(), 1000);
        assert_eq!(strategy.calculate_delay(3, &error).as_millis(), 1000);
    }

    #[test]
    fn test_calculate_delay_exponential() {
        let config = RetryConfig {
            strategy_type: RetryStrategyType::Exponential,
            base_delay_ms: 100,
            max_delay_ms: 10000,
            jitter_factor: 0.0,
            ..Default::default()
        };
        let strategy = RetryStrategy::new(config);
        let error = Error::network("test");

        assert_eq!(strategy.calculate_delay(1, &error).as_millis(), 100);
        assert_eq!(strategy.calculate_delay(2, &error).as_millis(), 200);
        assert_eq!(strategy.calculate_delay(3, &error).as_millis(), 400);
        assert_eq!(strategy.calculate_delay(4, &error).as_millis(), 800);
    }

    #[test]
    fn test_calculate_delay_linear() {
        let config = RetryConfig {
            strategy_type: RetryStrategyType::Linear,
            base_delay_ms: 500,
            max_delay_ms: 10000,
            jitter_factor: 0.0,
            ..Default::default()
        };
        let strategy = RetryStrategy::new(config);
        let error = Error::network("test");

        assert_eq!(strategy.calculate_delay(1, &error).as_millis(), 500);
        assert_eq!(strategy.calculate_delay(2, &error).as_millis(), 1000);
        assert_eq!(strategy.calculate_delay(3, &error).as_millis(), 1500);
    }

    #[test]
    fn test_calculate_delay_with_max_limit() {
        let config = RetryConfig {
            strategy_type: RetryStrategyType::Exponential,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
            jitter_factor: 0.0,
            ..Default::default()
        };
        let strategy = RetryStrategy::new(config);
        let error = Error::network("test");

        assert_eq!(strategy.calculate_delay(1, &error).as_millis(), 1000);
        assert_eq!(strategy.calculate_delay(2, &error).as_millis(), 2000);
        assert_eq!(strategy.calculate_delay(3, &error).as_millis(), 4000);
        assert_eq!(strategy.calculate_delay(4, &error).as_millis(), 5000);
        assert_eq!(strategy.calculate_delay(5, &error).as_millis(), 5000);
    }

    #[test]
    fn test_calculate_delay_rate_limit_only_preset() {
        let strategy = RetryStrategy::new(RetryConfig::rate_limit_only());
        let error = Error::rate_limit("Rate limit exceeded", None);

        assert_eq!(strategy.calculate_delay(1, &error).as_millis(), 2000);
        assert_eq!(strategy.calculate_delay(2, &error).as_millis(), 4000);
        assert_eq!(strategy.calculate_delay(6, &error).as_millis(), 10000);
    }

    #[test]
    fn test_calculate_delay_jitter_stays_within_range() {
        let config = RetryConfig {
            strategy_type: RetryStrategyType::Fixed,
            base_delay_ms: 1000,
            jitter_factor: 0.1,
            ..Default::default()
        };
        let strategy = RetryStrategy::new(config);
        let error = Error::network("test");

        for attempt in 1..=20 {
            let delay = strategy.calculate_delay(attempt, &error).as_millis();
            assert!((1000..=1100).contains(&delay), "delay {} out of range", delay);
        }
    }

    #[test]
    fn test_is_server_error() {
        assert!(RetryStrategy::is_server_error("500 Internal Server Error"));
        assert!(RetryStrategy::is_server_error("502 Bad Gateway"));
        assert!(RetryStrategy::is_server_error("503 Service Unavailable"));
        assert!(RetryStrategy::is_server_error("504 Gateway Timeout"));
        assert!(RetryStrategy::is_server_error("Service Unavailable"));
        assert!(RetryStrategy::is_server_error("upstream said: bad gateway"));
        assert!(!RetryStrategy::is_server_error("400 Bad Request"));
        assert!(!RetryStrategy::is_server_error("404 Not Found"));
    }

    #[test]
    fn test_is_timeout_error() {
        assert!(RetryStrategy::is_timeout_error("Request timeout"));
        assert!(RetryStrategy::is_timeout_error("Connection timed out"));
        assert!(RetryStrategy::is_timeout_error("408 Request Timeout"));
        assert!(!RetryStrategy::is_timeout_error("Connection refused"));
    }

    #[test]
    fn test_rate_limit_error_minimum_delay() {
        let config = RetryConfig {
            strategy_type: RetryStrategyType::Fixed,
            base_delay_ms: 100,
            jitter_factor: 0.0,
            ..Default::default()
        };
        let strategy = RetryStrategy::new(config);
        let error = Error::rate_limit("Rate limit exceeded", None);

        assert!(strategy.calculate_delay(1, &error).as_millis() >= 2000);
    }

    #[test]
    fn test_max_retries_accessor() {
        let strategy = RetryStrategy::new(RetryConfig::aggressive());
        assert_eq!(strategy.max_retries(), 5);
        assert_eq!(strategy.config().base_delay_ms, 200);
    }
}