1use crate::{Error, Middleware, Result};
2use serde_json::Value;
3use std::time::Duration;
4use tokio::time::timeout;
5
/// Middleware that carries a timeout duration for the pipeline.
///
/// NOTE(review): the `Middleware` impl for this type is a pure
/// pass-through; the duration is only exposed via `duration()` —
/// presumably the surrounding pipeline enforces it. Confirm.
pub struct TimeoutMiddleware {
    // Maximum time the wrapped operation is allowed to take.
    duration: Duration,
}
11
12impl TimeoutMiddleware {
13 pub fn new(duration: Duration) -> Self {
14 Self { duration }
15 }
16
17 pub fn from_millis(millis: u64) -> Self {
18 Self::new(Duration::from_millis(millis))
19 }
20
21 pub fn from_secs(secs: u64) -> Self {
22 Self::new(Duration::from_secs(secs))
23 }
24
25 pub fn duration(&self) -> Duration {
26 self.duration
27 }
28}
29
#[async_trait::async_trait]
impl Middleware for TimeoutMiddleware {
    /// Pass-through: forwards the request unchanged.
    ///
    /// NOTE(review): the configured `duration` is not enforced in these
    /// hooks; presumably the pipeline reads it via `duration()` and wraps
    /// handler execution itself (e.g. with `with_timeout`) — confirm.
    async fn before(&self, request: Value) -> Result<Value> {
        Ok(request)
    }

    /// Pass-through: forwards the response unchanged.
    async fn after(&self, _request: Value, response: Value) -> Result<Value> {
        Ok(response)
    }
}
40
/// Configuration for retrying failed operations with exponential backoff.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of operation calls allowed (the initial try included).
    pub max_attempts: u32,
    /// Base delay used for the first retry.
    pub initial_backoff: Duration,
    /// Upper bound on the computed backoff delay.
    pub max_backoff: Duration,
    /// Factor by which the delay grows with each attempt.
    pub backoff_multiplier: f64,
    /// Whether random jitter is added to each computed delay.
    pub use_jitter: bool,
}
55
56impl RetryPolicy {
57 pub fn new(max_attempts: u32) -> Self {
58 Self {
59 max_attempts,
60 initial_backoff: Duration::from_millis(100),
61 max_backoff: Duration::from_secs(30),
62 backoff_multiplier: 2.0,
63 use_jitter: true,
64 }
65 }
66
67 pub fn with_backoff(mut self, initial: Duration, max: Duration) -> Self {
68 self.initial_backoff = initial;
69 self.max_backoff = max;
70 self
71 }
72
73 pub fn with_multiplier(mut self, multiplier: f64) -> Self {
74 self.backoff_multiplier = multiplier;
75 self
76 }
77
78 pub fn with_jitter(mut self, use_jitter: bool) -> Self {
79 self.use_jitter = use_jitter;
80 self
81 }
82
83 pub fn backoff_duration(&self, attempt: u32) -> Duration {
85 let base_duration =
86 self.initial_backoff.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
87
88 let capped = base_duration.min(self.max_backoff.as_millis() as f64);
89
90 if self.use_jitter {
91 let jitter = rand::random::<f64>() * capped * 0.1; Duration::from_millis((capped + jitter) as u64)
93 } else {
94 Duration::from_millis(capped as u64)
95 }
96 }
97
98 pub fn is_retryable(&self, error: &Error) -> bool {
100 match error {
102 Error::Handler(msg) => {
103 msg.contains("timeout")
105 || msg.contains("timed out")
106 || msg.contains("connection")
107 || msg.contains("temporary")
108 }
109 _ => false,
110 }
111 }
112}
113
114impl Default for RetryPolicy {
115 fn default() -> Self {
116 Self::new(3)
117 }
118}
119
/// Middleware that carries a retry policy for the pipeline.
///
/// NOTE(review): the `Middleware` impl below does not drive retries
/// itself; retries are performed by `retry_with_policy` — confirm how
/// the pipeline wires the two together.
pub struct RetryMiddleware {
    // Policy consulted for attempt limits and backoff timing.
    policy: RetryPolicy,
}
125
126impl RetryMiddleware {
127 pub fn new(policy: RetryPolicy) -> Self {
128 Self { policy }
129 }
130
131 pub fn with_max_attempts(max_attempts: u32) -> Self {
132 Self::new(RetryPolicy::new(max_attempts))
133 }
134
135 pub fn policy(&self) -> &RetryPolicy {
136 &self.policy
137 }
138}
139
#[async_trait::async_trait]
impl Middleware for RetryMiddleware {
    /// Propagates the error unchanged.
    ///
    /// NOTE(review): no retry happens here — this hook cannot re-invoke
    /// the handler from its signature. Actual retries are driven by
    /// `retry_with_policy`, which owns the operation closure. Confirm
    /// this hook is intentionally a no-op.
    async fn on_error(&self, _request: Value, error: Error) -> Result<Value> {
        Err(error)
    }
}
149
150async fn apply_backoff_delay(policy: &RetryPolicy, attempt: u32, max_attempts: u32) {
152 if attempt < max_attempts {
153 let backoff = policy.backoff_duration(attempt - 1);
154 tokio::time::sleep(backoff).await;
155 }
156}
157
158fn handle_retry_result<T>(
160 error: Error,
161 policy: &RetryPolicy,
162 attempt: &mut u32,
163 last_error: &mut Option<Error>,
164) -> Option<Result<T>> {
165 if !policy.is_retryable(&error) {
166 return Some(Err(error));
167 }
168
169 *last_error = Some(error);
170 *attempt += 1;
171 None
172}
173
174pub async fn retry_with_policy<F, Fut, T>(policy: &RetryPolicy, mut operation: F) -> Result<T>
176where
177 F: FnMut() -> Fut,
178 Fut: std::future::Future<Output = Result<T>>,
179{
180 let mut attempt = 0;
181 let mut last_error = None;
182
183 while attempt < policy.max_attempts {
184 match operation().await {
185 Ok(result) => return Ok(result),
186 Err(error) => {
187 if let Some(result) =
188 handle_retry_result(error, policy, &mut attempt, &mut last_error)
189 {
190 return result;
191 }
192 apply_backoff_delay(policy, attempt, policy.max_attempts).await;
193 }
194 }
195 }
196
197 Err(last_error.unwrap_or_else(|| Error::Handler("All retry attempts failed".to_string())))
198}
199
200pub async fn with_timeout<F>(duration: Duration, future: F) -> Result<F::Output>
202where
203 F: std::future::Future,
204{
205 timeout(duration, future)
206 .await
207 .map_err(|_| Error::Handler(format!("Operation timed out after {:?}", duration)))
208}
209
#[cfg(test)]
mod tests {
    //! Unit tests covering backoff math, retry behavior, and timeouts.
    //! Timing-sensitive tests use small (ms-scale) backoffs to stay fast.

    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    // Jitter disabled: delays follow initial * multiplier^attempt exactly.
    #[test]
    fn test_retry_policy_backoff() {
        let policy = RetryPolicy::new(3)
            .with_backoff(Duration::from_millis(100), Duration::from_secs(5))
            .with_multiplier(2.0)
            .with_jitter(false);

        let backoff1 = policy.backoff_duration(0);
        let backoff2 = policy.backoff_duration(1);
        let backoff3 = policy.backoff_duration(2);

        assert_eq!(backoff1.as_millis(), 100);
        assert_eq!(backoff2.as_millis(), 200);
        assert_eq!(backoff3.as_millis(), 400);
    }

    // Even at a high attempt number the delay never exceeds max_backoff.
    #[test]
    fn test_retry_policy_max_backoff() {
        let policy = RetryPolicy::new(10)
            .with_backoff(Duration::from_millis(100), Duration::from_secs(1))
            .with_multiplier(2.0)
            .with_jitter(false);

        let backoff = policy.backoff_duration(10);
        assert!(backoff <= Duration::from_secs(1));
    }

    // Two retryable failures then a success: the Ok value is returned and
    // the operation runs exactly three times.
    #[tokio::test]
    async fn test_retry_with_policy_success() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let policy = RetryPolicy::new(3)
            .with_backoff(Duration::from_millis(10), Duration::from_millis(50))
            .with_jitter(false);

        let result = retry_with_policy(&policy, || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(Error::Handler("timeout error".to_string()))
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // Persistent retryable failure: exactly max_attempts calls, then Err.
    #[tokio::test]
    async fn test_retry_with_policy_max_attempts() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let policy = RetryPolicy::new(3)
            .with_backoff(Duration::from_millis(10), Duration::from_millis(50))
            .with_jitter(false);

        let result = retry_with_policy(&policy, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(Error::Handler("timeout error".to_string()))
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // A non-retryable ("fatal") error aborts after a single attempt.
    #[tokio::test]
    async fn test_retry_non_retryable_error() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let policy = RetryPolicy::new(3);

        let result = retry_with_policy(&policy, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(Error::Handler("fatal error".to_string()))
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // A future finishing well inside the limit passes its value through.
    #[tokio::test]
    async fn test_with_timeout_success() {
        let result = with_timeout(Duration::from_secs(1), async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            42
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    // A future outliving the limit yields a "timed out" handler error.
    #[tokio::test]
    async fn test_with_timeout_exceeded() {
        let result = with_timeout(Duration::from_millis(50), async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            42
        })
        .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timed out"));
    }

    // Timeout errors contain "timed out", so is_retryable treats them as
    // retryable; the third attempt skips the slow path and succeeds.
    #[tokio::test]
    async fn test_combined_timeout_and_retry() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let policy = RetryPolicy::new(3)
            .with_backoff(Duration::from_millis(10), Duration::from_millis(50))
            .with_jitter(false);

        let result = retry_with_policy(&policy, || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    with_timeout(Duration::from_millis(10), async {
                        tokio::time::sleep(Duration::from_secs(10)).await;
                        42
                    })
                    .await
                } else {
                    Ok(100)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 100);
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    // Backoff with a 3x multiplier: 100, 300, 900, 2700 ms.
    #[test]
    fn test_backoff_multiplier_exact() {
        let policy = RetryPolicy::new(5)
            .with_backoff(Duration::from_millis(100), Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(policy.backoff_duration(0).as_millis(), 100);
        assert_eq!(policy.backoff_duration(1).as_millis(), 300);
        assert_eq!(policy.backoff_duration(2).as_millis(), 900);
        assert_eq!(policy.backoff_duration(3).as_millis(), 2700);
    }

    // Message-based retry classification, plus the non-Handler fallthrough
    // (Error::Timeout is deliberately NOT retryable here).
    #[test]
    fn test_is_retryable_logic() {
        let policy = RetryPolicy::new(3);

        assert!(policy.is_retryable(&Error::Handler("timeout error".to_string())));
        assert!(policy.is_retryable(&Error::Handler("timed out".to_string())));
        assert!(policy.is_retryable(&Error::Handler("connection failed".to_string())));
        assert!(policy.is_retryable(&Error::Handler("temporary issue".to_string())));

        assert!(!policy.is_retryable(&Error::Handler("fatal error".to_string())));
        assert!(!policy.is_retryable(&Error::Timeout));
    }

    // The operation runs exactly max_attempts (5) times before giving up.
    #[tokio::test]
    async fn test_retry_attempt_comparison() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let policy = RetryPolicy::new(5)
            .with_backoff(Duration::from_millis(1), Duration::from_millis(10))
            .with_jitter(false);

        let _result = retry_with_policy(&policy, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(Error::Handler("timeout error".to_string()))
            }
        })
        .await;

        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    // Three attempts produce two sleeps (100 + 200 ms); total elapsed time
    // should land between those delays and a loose upper bound.
    #[tokio::test]
    async fn test_retry_backoff_calculation() {
        let policy = RetryPolicy::new(3)
            .with_backoff(Duration::from_millis(100), Duration::from_secs(1))
            .with_multiplier(2.0)
            .with_jitter(false);

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();
        let start = std::time::Instant::now();

        let _result = retry_with_policy(&policy, || {
            let c = counter_clone.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Err::<(), _>(Error::Handler("timeout error".to_string()))
            }
        })
        .await;

        let total_time = start.elapsed();
        assert!(
            total_time >= Duration::from_millis(250),
            "Total time too short: {:?}",
            total_time
        );
        assert!(
            total_time < Duration::from_millis(500),
            "Total time too long: {:?}",
            total_time
        );
    }
}