//! Retry layer with exponential backoff and jitter (modkit_http/layers/retry.rs).
1use crate::config::{ExponentialBackoff, RetryConfig, RetryTrigger};
2use crate::error::HttpError;
3use crate::response::{ResponseBody, parse_retry_after};
4use bytes::Bytes;
5use http::{HeaderValue, Request, Response};
6use http_body_util::{BodyExt, Full};
7use rand::Rng;
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Duration;
12use tower::{Layer, Service, ServiceExt};
13
/// Header name for retry attempt number (1-indexed).
/// Added to retried requests to indicate which retry attempt this is.
/// The original request carries no such header; the first retry is sent with "1".
pub const RETRY_ATTEMPT_HEADER: &str = "X-Retry-Attempt";
17
/// Tower layer that implements retry with exponential backoff and jitter
///
/// This layer operates on services that return `HttpError` and makes retry
/// decisions based on error type and HTTP status codes.
#[derive(Clone)]
pub struct RetryLayer {
    /// Retry policy: max retries, backoff parameters, and trigger rules.
    config: RetryConfig,
    /// Optional hard deadline covering all attempts and backoff sleeps.
    total_timeout: Option<Duration>,
}
27
28impl RetryLayer {
29    /// Create a new `RetryLayer` with the specified configuration
30    #[must_use]
31    pub fn new(config: RetryConfig) -> Self {
32        Self {
33            config,
34            total_timeout: None,
35        }
36    }
37
38    /// Create a new `RetryLayer` with total timeout (deadline across all retries)
39    #[must_use]
40    pub fn with_total_timeout(config: RetryConfig, total_timeout: Option<Duration>) -> Self {
41        Self {
42            config,
43            total_timeout,
44        }
45    }
46}
47
48impl<S> Layer<S> for RetryLayer {
49    type Service = RetryService<S>;
50
51    fn layer(&self, inner: S) -> Self::Service {
52        RetryService {
53            inner,
54            config: self.config.clone(),
55            total_timeout: self.total_timeout,
56        }
57    }
58}
59
/// Service that implements retry logic with exponential backoff
///
/// Retries on both `Err(HttpError)` and `Ok(Response)` based on status codes.
/// When retrying on status codes, drains response body up to configured limit
/// to allow connection reuse.
///
/// `send()` returns `Ok(Response)` for ALL HTTP statuses after retries exhaust.
/// `send()` returns `Err` only for transport/timeout errors.
///
/// # Total Timeout (Deadline)
///
/// When `total_timeout` is set, the entire operation (including all retries and
/// backoff delays) must complete within that duration. This provides a hard
/// deadline for the caller, regardless of how many retries are configured.
#[derive(Clone)]
pub struct RetryService<S> {
    /// Wrapped service that performs the actual HTTP call.
    inner: S,
    /// Retry policy consulted via `should_retry` on each attempt.
    config: RetryConfig,
    /// Optional hard deadline across all attempts and backoff sleeps.
    total_timeout: Option<Duration>,
}
80
impl<S> Service<Request<Full<Bytes>>> for RetryService<S>
where
    S: Service<Request<Full<Bytes>>, Response = Response<ResponseBody>, Error = HttpError>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = HttpError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Readiness is delegated straight to the inner service.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
        // Swap so we consume the instance that was poll_ready'd,
        // leaving a fresh clone for the next poll_ready cycle.
        let clone = self.inner.clone();
        let inner = std::mem::replace(&mut self.inner, clone);
        let config = self.config.clone();
        let total_timeout = self.total_timeout;

        // Split the request so parts and body can be re-cloned per attempt
        // (the Full<Bytes> body clones cheaply: Bytes is reference-counted).
        let (parts, body_bytes) = req.into_parts();

        // Preserve HTTP version for retry requests (required per HTTP spec)
        let http_version = parts.version;

        // Preserve extensions for retry requests (tracing context, matched routes, etc.)
        // Note: Only extensions implementing Clone + Send + Sync are preserved.
        // Non-cloneable extensions (like some tracing spans) will be lost on retry.
        let extensions = parts.extensions.clone();

        // Check for idempotency key header before wrapping in Arc
        // Header name is pre-parsed at config construction, so just check directly
        let has_idempotency_key = config
            .idempotency_key_header
            .as_ref()
            .is_some_and(|name| parts.headers.contains_key(name));

        let parts = std::sync::Arc::new(parts);

        Box::pin(async move {
            let method = parts.method.clone();

            // Extract request identity for logging (host + optional request-id)
            // Use authority() for full host:port, falling back to host() or "unknown"
            let url_host = parts
                .uri
                .authority()
                .map(ToString::to_string)
                .or_else(|| parts.uri.host().map(ToOwned::to_owned))
                .unwrap_or_else(|| "unknown".to_owned());
            let request_id = parts
                .headers
                .get("x-request-id")
                .or_else(|| parts.headers.get("x-correlation-id"))
                .and_then(|v| v.to_str().ok())
                .map(String::from);

            // Calculate deadline if total_timeout is set.
            // Store (deadline_instant, timeout_duration) together to avoid unsafe unwrap/expect later.
            let deadline_info = total_timeout.map(|t| (tokio::time::Instant::now() + t, t));

            let mut attempt = 0usize;
            loop {
                // Check deadline before each attempt
                if let Some((deadline, timeout_duration)) = deadline_info
                    && tokio::time::Instant::now() >= deadline
                {
                    return Err(HttpError::DeadlineExceeded(timeout_duration));
                }

                // Reconstruct request from preserved parts
                let mut req = Request::from_parts((*parts).clone(), body_bytes.clone());

                // Restore HTTP version (may have been lost during Parts clone)
                *req.version_mut() = http_version;

                // Restore extensions (tracing context, matched routes, etc.)
                // This ensures retry requests maintain the same context as the original
                *req.extensions_mut() = extensions.clone();

                // Add retry attempt header for retried requests (attempt > 0)
                if attempt > 0 {
                    // Safe: attempt is a small usize, always valid as a header value
                    if let Ok(value) = HeaderValue::try_from(attempt.to_string()) {
                        req.headers_mut().insert(RETRY_ATTEMPT_HEADER, value);
                    }
                }

                // Wait for inner-service readiness; readiness errors propagate
                // immediately via `?` and are not themselves retried.
                let mut svc = inner.clone();
                svc.ready().await?;

                match svc.call(req).await {
                    Ok(resp) => {
                        // Check if we should retry based on HTTP status code
                        let status_code = resp.status().as_u16();
                        let trigger = RetryTrigger::Status(status_code);

                        if config.max_retries > 0
                            && attempt < config.max_retries
                            && config.should_retry(trigger, &method, has_idempotency_key)
                        {
                            // Parse Retry-After from response headers
                            let retry_after = parse_retry_after(resp.headers());
                            let backoff_duration = if config.ignore_retry_after {
                                calculate_backoff(&config.backoff, attempt)
                            } else {
                                retry_after
                                    .unwrap_or_else(|| calculate_backoff(&config.backoff, attempt))
                            };

                            // Drain response body to allow connection reuse
                            let drain_limit = config.retry_response_drain_limit;
                            let should_drain = if config.skip_drain_on_retry {
                                // Configured to skip drain entirely
                                tracing::trace!("Skipping drain: skip_drain_on_retry enabled");
                                false
                            } else if let Some(content_length) = resp
                                .headers()
                                .get(http::header::CONTENT_LENGTH)
                                .and_then(|v| v.to_str().ok())
                                .and_then(|s| s.parse::<u64>().ok())
                            {
                                if content_length > drain_limit as u64 {
                                    // Content-Length exceeds drain limit, skip to avoid
                                    // expensive decompression of large error bodies
                                    tracing::debug!(
                                        content_length,
                                        drain_limit,
                                        "Skipping drain: Content-Length exceeds limit"
                                    );
                                    false
                                } else {
                                    true
                                }
                            } else {
                                // No Content-Length, attempt drain up to limit
                                true
                            };

                            if should_drain
                                && let Err(e) = drain_response_body(resp, drain_limit).await
                            {
                                // If drain fails, log but continue with retry
                                tracing::debug!(
                                    error = %e,
                                    "Failed to drain response body before retry; connection may not be reused"
                                );
                            }

                            // Check if backoff would exceed deadline
                            let effective_backoff =
                                if let Some((deadline, timeout_duration)) = deadline_info {
                                    let remaining = deadline
                                        .saturating_duration_since(tokio::time::Instant::now());
                                    if remaining.is_zero() {
                                        return Err(HttpError::DeadlineExceeded(timeout_duration));
                                    }
                                    backoff_duration.min(remaining)
                                } else {
                                    backoff_duration
                                };

                            tracing::debug!(
                                retry = attempt + 1,
                                max_retries = config.max_retries,
                                status = status_code,
                                trigger = ?trigger,
                                method = %method,
                                host = %url_host,
                                request_id = ?request_id,
                                backoff_ms = effective_backoff.as_millis(),
                                retry_after_used = retry_after.is_some() && !config.ignore_retry_after,
                                "Retrying request after status code"
                            );
                            tokio::time::sleep(effective_backoff).await;
                            attempt += 1;
                            continue;
                        }

                        // No retry needed or retries exhausted - return Ok(Response)
                        return Ok(resp);
                    }
                    Err(err) => {
                        // Transport/timeout errors: bail out when retries are
                        // disabled or already exhausted.
                        if config.max_retries == 0 || attempt >= config.max_retries {
                            return Err(err);
                        }

                        let trigger = get_retry_trigger(&err);
                        if !config.should_retry(trigger, &method, has_idempotency_key) {
                            return Err(err);
                        }

                        // For errors, there's no response body to drain
                        let backoff_duration = calculate_backoff(&config.backoff, attempt);

                        // Check if backoff would exceed deadline
                        let effective_backoff =
                            if let Some((deadline, timeout_duration)) = deadline_info {
                                let remaining =
                                    deadline.saturating_duration_since(tokio::time::Instant::now());
                                if remaining.is_zero() {
                                    return Err(HttpError::DeadlineExceeded(timeout_duration));
                                }
                                backoff_duration.min(remaining)
                            } else {
                                backoff_duration
                            };

                        tracing::debug!(
                            retry = attempt + 1,
                            max_retries = config.max_retries,
                            error = %err,
                            trigger = ?trigger,
                            method = %method,
                            host = %url_host,
                            request_id = ?request_id,
                            backoff_ms = effective_backoff.as_millis(),
                            "Retrying request after error"
                        );
                        tokio::time::sleep(effective_backoff).await;
                        attempt += 1;
                    }
                }
            }
        })
    }
}
312
313/// Drain response body up to limit bytes to allow connection reuse.
314///
315/// # Connection Reuse
316///
317/// For HTTP/1.1, the response body must be fully consumed before the connection
318/// can be reused for subsequent requests. This function drains up to `limit`
319/// bytes to enable connection pooling.
320///
321/// # Decompression Note
322///
323/// This operates on the **decompressed** body (after `DecompressionLayer`).
324/// The limit applies to decompressed bytes. For compressed responses, the
325/// actual network traffic may be smaller than the configured limit.
326///
327/// This means draining can cost CPU for highly compressible responses, but
328/// provides protection against unexpected memory consumption.
329///
330/// # Behavior
331///
332/// - Stops draining once `limit` bytes have been read
333/// - If the body exceeds the limit, draining stops early and the connection
334///   may not be reused (a new connection will be established for the retry)
335/// - Returns `Ok(())` on success, or `HttpError` if body read fails
336async fn drain_response_body(
337    response: Response<ResponseBody>,
338    limit: usize,
339) -> Result<(), HttpError> {
340    let (_parts, body) = response.into_parts();
341    let mut body = std::pin::pin!(body);
342    let mut drained = 0usize;
343
344    while let Some(frame) = body.frame().await {
345        let frame = frame.map_err(HttpError::Transport)?;
346        if let Some(chunk) = frame.data_ref() {
347            drained += chunk.len();
348            if drained >= limit {
349                // Hit limit, stop draining (connection may not be reused)
350                break;
351            }
352        }
353    }
354
355    Ok(())
356}
357
358/// Extract retry trigger from an error
359fn get_retry_trigger(err: &HttpError) -> RetryTrigger {
360    match err {
361        HttpError::Transport(_) => RetryTrigger::TransportError,
362        HttpError::Timeout(_) => RetryTrigger::Timeout,
363        // DeadlineExceeded, ServiceClosed, and other errors are not retryable
364        _ => RetryTrigger::NonRetryable,
365    }
366}
367
368/// Calculate backoff duration for a given attempt
369///
370/// Safely handles edge cases (NaN, infinity, negative values) to avoid panics.
371pub fn calculate_backoff(backoff: &ExponentialBackoff, attempt: usize) -> Duration {
372    // Maximum safe backoff in seconds (1 day - beyond this is unreasonable for retry logic)
373    const MAX_BACKOFF_SECS: f64 = 86400.0;
374
375    // Safely convert attempt to i32, clamping to i32::MAX (which is already way beyond
376    // any reasonable retry count - at that point backoff will be at max anyway)
377    let attempt_i32 = i32::try_from(attempt).unwrap_or(i32::MAX);
378
379    // Sanitize multiplier: must be finite and >= 0, default to 1.0
380    let multiplier = if backoff.multiplier.is_finite() && backoff.multiplier >= 0.0 {
381        backoff.multiplier
382    } else {
383        1.0
384    };
385
386    // Sanitize initial backoff
387    let initial_secs = backoff.initial.as_secs_f64();
388    let initial_secs = if initial_secs.is_finite() && initial_secs >= 0.0 {
389        initial_secs
390    } else {
391        0.0
392    };
393
394    // Sanitize max backoff
395    let max_secs = backoff.max.as_secs_f64();
396    let max_secs = if max_secs.is_finite() && max_secs >= 0.0 {
397        max_secs.min(MAX_BACKOFF_SECS)
398    } else {
399        MAX_BACKOFF_SECS
400    };
401
402    // Calculate with sanitized values
403    let base_duration = initial_secs * multiplier.powi(attempt_i32);
404
405    // Clamp to valid range for Duration::from_secs_f64 (must be finite, non-negative)
406    let clamped = if base_duration.is_finite() {
407        base_duration.min(max_secs).max(0.0)
408    } else {
409        max_secs
410    };
411    let duration = Duration::from_secs_f64(clamped);
412
413    // Apply jitter
414    let duration = if backoff.jitter {
415        let mut rng = rand::rng();
416        let jitter_factor = rng.random_range(0.0..=0.25);
417        let jitter = duration.mul_f64(jitter_factor);
418        duration + jitter
419    } else {
420        duration
421    };
422
423    // Keep jittered value within max_backoff
424    let max_duration = Duration::from_secs_f64(max_secs);
425    duration.min(max_duration)
426}
427
428#[cfg(test)]
429#[cfg_attr(coverage_nightly, coverage(off))]
430mod tests {
431    use super::*;
432    use crate::config::IDEMPOTENCY_KEY_HEADER;
433    use bytes::Bytes;
434    use http::{Method, Request, Response, StatusCode};
435    use http_body_util::Full;
436
437    /// Helper to create a boxed `ResponseBody` from bytes
438    fn make_response_body(data: &[u8]) -> ResponseBody {
439        let body = Full::new(Bytes::from(data.to_vec()));
440        body.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })
441            .boxed()
442    }
443
    /// Test: a successful (200) response is returned on the first attempt
    /// with no retries issued.
    #[tokio::test]
    async fn test_retry_layer_successful_request() {
        use std::sync::{Arc, Mutex};

        // Minimal inner service that counts invocations and always returns 200 OK.
        #[derive(Clone)]
        struct CountingService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for CountingService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    *count.lock().unwrap() += 1;
                    let response = Response::builder()
                        .status(StatusCode::OK)
                        .body(make_response_body(b""))
                        .unwrap();
                    Ok(response)
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = CountingService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig::default();
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(*call_count.lock().unwrap(), 1); // Should only call once on success
    }
494
    /// Test: POST request with 500 is NOT retried and returns Ok(Response).
    /// With new semantics: 500 for non-idempotent method passes through as Ok(Response).
    #[tokio::test]
    async fn test_retry_layer_post_not_retried_on_5xx() {
        use std::sync::{Arc, Mutex};

        // Inner service that counts invocations and always returns 500.
        #[derive(Clone)]
        struct ServerErrorService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for ServerErrorService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    *count.lock().unwrap() += 1;
                    // Return Ok(Response) with 500 status - POST won't retry
                    Ok(Response::builder()
                        .status(StatusCode::INTERNAL_SERVER_ERROR)
                        .body(make_response_body(b"Internal Server Error"))
                        .unwrap())
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = ServerErrorService {
            call_count: call_count.clone(),
        };

        // Fast backoff keeps the test quick if a retry were (incorrectly) attempted.
        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        // New semantics: returns Ok(Response) with 500 status, NOT Err
        assert!(result.is_ok());
        let resp = result.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(*call_count.lock().unwrap(), 1); // POST should NOT be retried on 500
    }
553
    /// Test: GET request with 500 is retried (idempotent method).
    /// Returns Ok(Response) with final status after retries exhaust or success.
    #[tokio::test]
    async fn test_retry_layer_get_retried_on_5xx() {
        use std::sync::{Arc, Mutex};

        // Inner service that returns 500 for the first two calls, then 200.
        #[derive(Clone)]
        struct FailThenSucceedService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for FailThenSucceedService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 3 {
                        // Return 500 - will trigger retry for GET
                        Ok(Response::builder()
                            .status(StatusCode::INTERNAL_SERVER_ERROR)
                            .body(make_response_body(b"Internal Server Error"))
                            .unwrap())
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = FailThenSucceedService {
            call_count: call_count.clone(),
        };

        // Fast backoff keeps total test time negligible across retries.
        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), StatusCode::OK);
        assert_eq!(*call_count.lock().unwrap(), 3); // GET should retry on 500
    }
618
    /// Test: 429 is always retried (POST included), returns Ok(Response).
    #[tokio::test]
    async fn test_retry_layer_always_retries_429() {
        use std::sync::{Arc, Mutex};

        // Inner service that returns 429 on the first call, then 200.
        #[derive(Clone)]
        struct RateLimitThenSucceedService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for RateLimitThenSucceedService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 2 {
                        // Return 429 - triggers retry for all methods
                        Ok(Response::builder()
                            .status(StatusCode::TOO_MANY_REQUESTS)
                            .body(make_response_body(b"Rate limited"))
                            .unwrap())
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = RateLimitThenSucceedService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        // 429 should be retried even for POST
        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), StatusCode::OK);
        assert_eq!(*call_count.lock().unwrap(), 2); // POST retries on 429
    }
683
    /// Test: GET (idempotent) is retried on transport errors until the inner
    /// service eventually succeeds.
    #[tokio::test]
    async fn test_retry_layer_retries_transport_errors() {
        use std::sync::{Arc, Mutex};

        // Inner service that fails with a transport error twice, then returns 200.
        #[derive(Clone)]
        struct FailThenSucceedService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for FailThenSucceedService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 3 {
                        Err(HttpError::Transport(Box::new(std::io::Error::new(
                            std::io::ErrorKind::ConnectionReset,
                            "connection reset",
                        ))))
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = FailThenSucceedService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::GET)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_ok());
        assert_eq!(*call_count.lock().unwrap(), 3); // Should retry until success
    }
744
    /// Test: POST request is NOT retried on transport errors (by default, for safety)
    #[tokio::test]
    async fn test_retry_layer_post_not_retried_on_transport_error() {
        use std::sync::{Arc, Mutex};

        // Inner service that always fails with a transport error.
        #[derive(Clone)]
        struct TransportErrorService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for TransportErrorService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    *count.lock().unwrap() += 1;
                    Err(HttpError::Transport(Box::new(std::io::Error::new(
                        std::io::ErrorKind::ConnectionReset,
                        "connection reset",
                    ))))
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = TransportErrorService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        // POST without idempotency key should NOT be retried on transport error
        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let result = retry_service.call(req).await;
        assert!(result.is_err()); // Should return error, not retry
        assert_eq!(*call_count.lock().unwrap(), 1); // Only one attempt
    }
799
800    /// Test: POST request WITH idempotency key IS retried on transport errors
801    #[tokio::test]
802    async fn test_retry_layer_post_with_idempotency_key_retried() {
803        use std::sync::{Arc, Mutex};
804
805        #[derive(Clone)]
806        struct FailThenSucceedService {
807            call_count: Arc<Mutex<usize>>,
808        }
809
810        impl Service<Request<Full<Bytes>>> for FailThenSucceedService {
811            type Response = Response<ResponseBody>;
812            type Error = HttpError;
813            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
814
815            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
816                Poll::Ready(Ok(()))
817            }
818
819            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
820                let count = self.call_count.clone();
821                Box::pin(async move {
822                    let mut c = count.lock().unwrap();
823                    *c += 1;
824                    if *c < 3 {
825                        Err(HttpError::Transport(Box::new(std::io::Error::new(
826                            std::io::ErrorKind::ConnectionReset,
827                            "connection reset",
828                        ))))
829                    } else {
830                        Ok(Response::builder()
831                            .status(StatusCode::OK)
832                            .body(make_response_body(b""))
833                            .unwrap())
834                    }
835                })
836            }
837        }
838
839        let call_count = Arc::new(Mutex::new(0));
840        let service = FailThenSucceedService {
841            call_count: call_count.clone(),
842        };
843
844        let retry_config = RetryConfig {
845            backoff: ExponentialBackoff::fast(),
846            ..RetryConfig::default()
847        };
848        let layer = RetryLayer::new(retry_config);
849        let mut retry_service = layer.layer(service);
850
851        // POST WITH idempotency key should be retried on transport error
852        let req = Request::builder()
853            .method(Method::POST)
854            .uri("http://example.com")
855            .header(IDEMPOTENCY_KEY_HEADER, "unique-key-123")
856            .body(Full::new(Bytes::new()))
857            .unwrap();
858
859        let result = retry_service.call(req).await;
860        assert!(result.is_ok()); // Should succeed after retries
861        assert_eq!(*call_count.lock().unwrap(), 3); // Should retry until success
862    }
863
864    #[tokio::test]
865    async fn test_retry_layer_does_not_retry_json_errors() {
866        use std::sync::{Arc, Mutex};
867
868        #[derive(Clone)]
869        struct JsonErrorService {
870            call_count: Arc<Mutex<usize>>,
871        }
872
873        impl Service<Request<Full<Bytes>>> for JsonErrorService {
874            type Response = Response<ResponseBody>;
875            type Error = HttpError;
876            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
877
878            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
879                Poll::Ready(Ok(()))
880            }
881
882            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
883                let count = self.call_count.clone();
884                Box::pin(async move {
885                    *count.lock().unwrap() += 1;
886                    // Simulate a JSON parse error (non-retryable)
887                    let err: serde_json::Error =
888                        serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
889                    Err(HttpError::Json(err))
890                })
891            }
892        }
893
894        let call_count = Arc::new(Mutex::new(0));
895        let service = JsonErrorService {
896            call_count: call_count.clone(),
897        };
898
899        let retry_config = RetryConfig::default();
900        let layer = RetryLayer::new(retry_config);
901        let mut retry_service = layer.layer(service);
902
903        let req = Request::builder()
904            .method(Method::GET)
905            .uri("http://example.com")
906            .body(Full::new(Bytes::new()))
907            .unwrap();
908
909        let result = retry_service.call(req).await;
910        assert!(result.is_err());
911        assert_eq!(*call_count.lock().unwrap(), 1); // Should NOT retry JSON errors
912    }
913
914    #[test]
915    fn test_calculate_backoff_no_jitter() {
916        let backoff = ExponentialBackoff {
917            initial: Duration::from_millis(100),
918            max: Duration::from_secs(10),
919            multiplier: 2.0,
920            jitter: false,
921        };
922
923        let backoff0 = calculate_backoff(&backoff, 0);
924        assert_eq!(backoff0, Duration::from_millis(100));
925
926        let backoff1 = calculate_backoff(&backoff, 1);
927        assert_eq!(backoff1, Duration::from_millis(200));
928
929        let backoff2 = calculate_backoff(&backoff, 2);
930        assert_eq!(backoff2, Duration::from_millis(400));
931
932        // Should cap at max
933        let backoff_capped = calculate_backoff(&backoff, 10);
934        assert_eq!(backoff_capped, Duration::from_secs(10));
935    }
936
937    #[test]
938    fn test_calculate_backoff_with_jitter() {
939        let backoff = ExponentialBackoff {
940            initial: Duration::from_millis(100),
941            max: Duration::from_secs(10),
942            multiplier: 2.0,
943            jitter: true,
944        };
945
946        let backoff0 = calculate_backoff(&backoff, 0);
947        // With jitter, should be between 100ms and 125ms
948        assert!(backoff0 >= Duration::from_millis(100));
949        assert!(backoff0 <= Duration::from_millis(125));
950    }
951
952    #[test]
953    fn test_calculate_backoff_with_nan_multiplier() {
954        // NaN multiplier should default to 1.0, not panic
955        let backoff = ExponentialBackoff {
956            initial: Duration::from_millis(100),
957            max: Duration::from_secs(10),
958            multiplier: f64::NAN,
959            jitter: false,
960        };
961
962        // Should not panic, NaN multiplier falls back to 1.0
963        let result = calculate_backoff(&backoff, 0);
964        assert_eq!(result, Duration::from_millis(100));
965
966        let result1 = calculate_backoff(&backoff, 1);
967        // With multiplier = 1.0, backoff stays at initial value
968        assert_eq!(result1, Duration::from_millis(100));
969    }
970
971    #[test]
972    fn test_calculate_backoff_with_infinity_multiplier() {
973        // Infinity multiplier should default to 1.0, not panic
974        let backoff = ExponentialBackoff {
975            initial: Duration::from_millis(100),
976            max: Duration::from_secs(10),
977            multiplier: f64::INFINITY,
978            jitter: false,
979        };
980
981        // Should not panic
982        let result = calculate_backoff(&backoff, 0);
983        assert_eq!(result, Duration::from_millis(100));
984    }
985
986    #[test]
987    fn test_calculate_backoff_with_negative_multiplier() {
988        // Negative multiplier should default to 1.0, not panic
989        let backoff = ExponentialBackoff {
990            initial: Duration::from_millis(100),
991            max: Duration::from_secs(10),
992            multiplier: -2.0,
993            jitter: false,
994        };
995
996        // Should not panic, negative multiplier falls back to 1.0
997        let result = calculate_backoff(&backoff, 0);
998        assert_eq!(result, Duration::from_millis(100));
999    }
1000
1001    #[test]
1002    fn test_calculate_backoff_with_huge_attempt() {
1003        // Large attempt number should not overflow or panic
1004        let backoff = ExponentialBackoff {
1005            initial: Duration::from_millis(100),
1006            max: Duration::from_secs(10),
1007            multiplier: 2.0,
1008            jitter: false,
1009        };
1010
1011        // usize::MAX should be clamped to i32::MAX internally
1012        let result = calculate_backoff(&backoff, usize::MAX);
1013        // Should return max since 2^(i32::MAX) is way beyond max
1014        assert_eq!(result, Duration::from_secs(10));
1015    }
1016
    /// Test: Retry-After header in response is used for backoff timing.
    ///
    /// The backoff policy is deliberately set to 10s so that, if the layer
    /// wrongly used the policy instead of the header, the elapsed-time
    /// assertion below would fail.
    #[tokio::test]
    async fn test_retry_layer_uses_retry_after_header() {
        use std::sync::{Arc, Mutex};

        // Returns 429 + Retry-After on the first call, then 200.
        #[derive(Clone)]
        struct RetryAfterService {
            call_count: Arc<Mutex<usize>>,
        }

        impl Service<Request<Full<Bytes>>> for RetryAfterService {
            type Response = Response<ResponseBody>;
            type Error = HttpError;
            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                Poll::Ready(Ok(()))
            }

            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
                let count = self.call_count.clone();
                Box::pin(async move {
                    let mut c = count.lock().unwrap();
                    *c += 1;
                    if *c < 2 {
                        // Return 429 with Retry-After: 0 (retry immediately)
                        Ok(Response::builder()
                            .status(StatusCode::TOO_MANY_REQUESTS)
                            .header(http::header::RETRY_AFTER, "0")
                            .body(make_response_body(b"Rate limited"))
                            .unwrap())
                    } else {
                        Ok(Response::builder()
                            .status(StatusCode::OK)
                            .body(make_response_body(b""))
                            .unwrap())
                    }
                })
            }
        }

        let call_count = Arc::new(Mutex::new(0));
        let service = RetryAfterService {
            call_count: call_count.clone(),
        };

        let retry_config = RetryConfig {
            backoff: ExponentialBackoff {
                initial: Duration::from_secs(10), // Long backoff that would fail test
                jitter: false,
                ..ExponentialBackoff::default()
            },
            ignore_retry_after: false, // Use Retry-After header
            ..RetryConfig::default()
        };
        let layer = RetryLayer::new(retry_config);
        let mut retry_service = layer.layer(service);

        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap();

        let start = std::time::Instant::now();
        let result = retry_service.call(req).await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());
        assert_eq!(*call_count.lock().unwrap(), 2);

        // Should have used Retry-After: 0 (immediate), not 10s backoff
        assert!(
            elapsed < Duration::from_secs(1),
            "Expected quick retry using Retry-After, but took {elapsed:?}",
        );
    }
1094
1095    /// Test: Retry-After header is ignored when config says to ignore it.
1096    #[tokio::test]
1097    async fn test_retry_layer_ignores_retry_after_when_configured() {
1098        use std::sync::{Arc, Mutex};
1099
1100        #[derive(Clone)]
1101        struct RetryAfterService {
1102            call_count: Arc<Mutex<usize>>,
1103        }
1104
1105        impl Service<Request<Full<Bytes>>> for RetryAfterService {
1106            type Response = Response<ResponseBody>;
1107            type Error = HttpError;
1108            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1109
1110            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1111                Poll::Ready(Ok(()))
1112            }
1113
1114            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
1115                let count = self.call_count.clone();
1116                Box::pin(async move {
1117                    let mut c = count.lock().unwrap();
1118                    *c += 1;
1119                    if *c < 2 {
1120                        // Return 429 with Retry-After: 10s (should be ignored)
1121                        Ok(Response::builder()
1122                            .status(StatusCode::TOO_MANY_REQUESTS)
1123                            .header(http::header::RETRY_AFTER, "10")
1124                            .body(make_response_body(b"Rate limited"))
1125                            .unwrap())
1126                    } else {
1127                        Ok(Response::builder()
1128                            .status(StatusCode::OK)
1129                            .body(make_response_body(b""))
1130                            .unwrap())
1131                    }
1132                })
1133            }
1134        }
1135
1136        let call_count = Arc::new(Mutex::new(0));
1137        let service = RetryAfterService {
1138            call_count: call_count.clone(),
1139        };
1140
1141        let retry_config = RetryConfig {
1142            backoff: ExponentialBackoff::fast(), // Fast backoff (1ms initial, no jitter)
1143            ignore_retry_after: true,            // Ignore Retry-After header
1144            ..RetryConfig::default()
1145        };
1146        let layer = RetryLayer::new(retry_config);
1147        let mut retry_service = layer.layer(service);
1148
1149        let req = Request::builder()
1150            .method(Method::POST)
1151            .uri("http://example.com")
1152            .body(Full::new(Bytes::new()))
1153            .unwrap();
1154
1155        let start = std::time::Instant::now();
1156        let result = retry_service.call(req).await;
1157        let elapsed = start.elapsed();
1158
1159        assert!(result.is_ok());
1160        assert_eq!(*call_count.lock().unwrap(), 2);
1161
1162        // Should have used 1ms backoff, not 10s Retry-After
1163        assert!(
1164            elapsed < Duration::from_secs(1),
1165            "Expected quick retry using backoff policy (1ms), but took {elapsed:?}",
1166        );
1167    }
1168
1169    #[tokio::test]
1170    async fn test_retry_attempt_header_added_on_retry() {
1171        use std::sync::{Arc, Mutex};
1172
1173        #[derive(Clone)]
1174        struct HeaderCapturingService {
1175            call_count: Arc<Mutex<usize>>,
1176            captured_headers: Arc<Mutex<Vec<Option<String>>>>,
1177        }
1178
1179        impl Service<Request<Full<Bytes>>> for HeaderCapturingService {
1180            type Response = Response<ResponseBody>;
1181            type Error = HttpError;
1182            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1183
1184            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1185                Poll::Ready(Ok(()))
1186            }
1187
1188            fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
1189                let count = self.call_count.clone();
1190                let captured_headers = self.captured_headers.clone();
1191
1192                // Capture the X-Retry-Attempt header value
1193                let retry_header = req
1194                    .headers()
1195                    .get(RETRY_ATTEMPT_HEADER)
1196                    .map(|v| v.to_str().unwrap_or("invalid").to_owned());
1197
1198                Box::pin(async move {
1199                    let mut c = count.lock().unwrap();
1200                    *c += 1;
1201                    captured_headers.lock().unwrap().push(retry_header);
1202
1203                    if *c < 3 {
1204                        // Fail with transport error (always retried)
1205                        Err(HttpError::Transport(Box::new(std::io::Error::new(
1206                            std::io::ErrorKind::ConnectionReset,
1207                            "connection reset",
1208                        ))))
1209                    } else {
1210                        Ok(Response::builder()
1211                            .status(StatusCode::OK)
1212                            .body(make_response_body(b""))
1213                            .unwrap())
1214                    }
1215                })
1216            }
1217        }
1218
1219        let call_count = Arc::new(Mutex::new(0));
1220        let captured_headers = Arc::new(Mutex::new(Vec::new()));
1221        let service = HeaderCapturingService {
1222            call_count: call_count.clone(),
1223            captured_headers: captured_headers.clone(),
1224        };
1225
1226        let retry_config = RetryConfig {
1227            backoff: ExponentialBackoff::fast(),
1228            ..RetryConfig::default()
1229        };
1230        let layer = RetryLayer::new(retry_config);
1231        let mut retry_service = layer.layer(service);
1232
1233        let req = Request::builder()
1234            .method(Method::GET)
1235            .uri("http://example.com")
1236            .body(Full::new(Bytes::new()))
1237            .unwrap();
1238
1239        let result = retry_service.call(req).await;
1240        assert!(result.is_ok());
1241        assert_eq!(*call_count.lock().unwrap(), 3);
1242
1243        // Verify captured headers
1244        let headers = captured_headers.lock().unwrap();
1245        assert_eq!(headers.len(), 3);
1246        // First call (attempt 0): no header
1247        assert_eq!(headers[0], None);
1248        // Second call (attempt 1): header = "1"
1249        assert_eq!(headers[1], Some("1".to_owned()));
1250        // Third call (attempt 2): header = "2"
1251        assert_eq!(headers[2], Some("2".to_owned()));
1252    }
1253
1254    /// Test: Retries exhausted returns Ok(Response) with final status, not Err.
1255    #[tokio::test]
1256    async fn test_retry_layer_exhausted_returns_ok_with_status() {
1257        use std::sync::{Arc, Mutex};
1258
1259        #[derive(Clone)]
1260        struct AlwaysFailService {
1261            call_count: Arc<Mutex<usize>>,
1262        }
1263
1264        impl Service<Request<Full<Bytes>>> for AlwaysFailService {
1265            type Response = Response<ResponseBody>;
1266            type Error = HttpError;
1267            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1268
1269            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1270                Poll::Ready(Ok(()))
1271            }
1272
1273            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
1274                let count = self.call_count.clone();
1275                Box::pin(async move {
1276                    *count.lock().unwrap() += 1;
1277                    // Always return 500
1278                    Ok(Response::builder()
1279                        .status(StatusCode::INTERNAL_SERVER_ERROR)
1280                        .body(make_response_body(b"error"))
1281                        .unwrap())
1282                })
1283            }
1284        }
1285
1286        let call_count = Arc::new(Mutex::new(0));
1287        let service = AlwaysFailService {
1288            call_count: call_count.clone(),
1289        };
1290
1291        let retry_config = RetryConfig {
1292            max_retries: 2,
1293            backoff: ExponentialBackoff::fast(),
1294            ..RetryConfig::default()
1295        };
1296        let layer = RetryLayer::new(retry_config);
1297        let mut retry_service = layer.layer(service);
1298
1299        let req = Request::builder()
1300            .method(Method::GET)
1301            .uri("http://example.com")
1302            .body(Full::new(Bytes::new()))
1303            .unwrap();
1304
1305        let result = retry_service.call(req).await;
1306
1307        // Retries exhausted: returns Ok(Response) with 500 status
1308        assert!(result.is_ok());
1309        let resp = result.unwrap();
1310        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
1311
1312        // 1 initial + 2 retries = 3 calls
1313        assert_eq!(*call_count.lock().unwrap(), 3);
1314    }
1315
1316    /// Test: Non-retryable status (404) passes through immediately.
1317    #[tokio::test]
1318    async fn test_retry_layer_non_retryable_status_passes_through() {
1319        use std::sync::{Arc, Mutex};
1320
1321        #[derive(Clone)]
1322        struct NotFoundService {
1323            call_count: Arc<Mutex<usize>>,
1324        }
1325
1326        impl Service<Request<Full<Bytes>>> for NotFoundService {
1327            type Response = Response<ResponseBody>;
1328            type Error = HttpError;
1329            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1330
1331            fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1332                Poll::Ready(Ok(()))
1333            }
1334
1335            fn call(&mut self, _req: Request<Full<Bytes>>) -> Self::Future {
1336                let count = self.call_count.clone();
1337                Box::pin(async move {
1338                    *count.lock().unwrap() += 1;
1339                    Ok(Response::builder()
1340                        .status(StatusCode::NOT_FOUND)
1341                        .body(make_response_body(b"not found"))
1342                        .unwrap())
1343                })
1344            }
1345        }
1346
1347        let call_count = Arc::new(Mutex::new(0));
1348        let service = NotFoundService {
1349            call_count: call_count.clone(),
1350        };
1351
1352        let retry_config = RetryConfig {
1353            max_retries: 3,
1354            backoff: ExponentialBackoff::fast(),
1355            ..RetryConfig::default()
1356        };
1357        let layer = RetryLayer::new(retry_config);
1358        let mut retry_service = layer.layer(service);
1359
1360        let req = Request::builder()
1361            .method(Method::GET)
1362            .uri("http://example.com")
1363            .body(Full::new(Bytes::new()))
1364            .unwrap();
1365
1366        let result = retry_service.call(req).await;
1367
1368        // 404 is not retryable - passes through immediately
1369        assert!(result.is_ok());
1370        let resp = result.unwrap();
1371        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
1372
1373        // Only 1 call (no retries)
1374        assert_eq!(*call_count.lock().unwrap(), 1);
1375    }
1376}