reqwest_drive/throttle_middleware.rs

use crate::DriveCache;
use async_trait::async_trait;
use http::Extensions;
use rand::Rng;
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next};
use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::time::{sleep, Duration};

/// Defines the throttling and backoff behavior for handling HTTP requests.
///
/// This policy determines the **rate limiting strategy** used for outgoing requests,
/// including fixed delays, adaptive backoff, and retry settings.
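///
/// # Example
///
/// A minimal sketch of constructing a policy; the values below are illustrative,
/// not defaults (the import path assumes `ThrottlePolicy` is re-exported at the
/// crate root).
///
/// ```ignore
/// use reqwest_drive::ThrottlePolicy;
///
/// let policy = ThrottlePolicy {
///     base_delay_ms: 250,      // minimum delay before each request
///     adaptive_jitter_ms: 100, // up to 100 ms of random jitter added to backoff
///     max_concurrent: 4,       // at most 4 requests in flight at once
///     max_retries: 3,          // retry failed requests up to 3 times
/// };
/// ```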
#[derive(Clone, Debug)]
pub struct ThrottlePolicy {
    /// The base delay (in milliseconds) applied before making a request.
    ///
    /// This ensures a **minimum delay** between consecutive requests.
    pub base_delay_ms: u64,

    /// The maximum random jitter (in milliseconds) added to the backoff delay.
    ///
    /// This prevents synchronization issues when multiple clients are making requests,
    /// reducing the likelihood of rate-limit collisions.
    pub adaptive_jitter_ms: u64,

    /// The maximum number of concurrent requests allowed at any given time.
    ///
    /// This controls **parallel request execution**, ensuring that no more than
    /// `max_concurrent` requests are in-flight simultaneously.
    pub max_concurrent: usize,

    /// The maximum number of retries allowed in case of failed requests.
    ///
    /// If a request fails (e.g., due to a **server error or rate limiting**),
    /// it will be retried up to `max_retries` times with exponential backoff.
    pub max_retries: usize,
}

/// Implements a throttling and exponential backoff middleware for HTTP requests.
///
/// This middleware **limits request concurrency** and applies **adaptive delays**
/// between retries, helping to prevent rate-limiting issues when interacting
/// with APIs that enforce request quotas.
///
/// Requests are throttled using a **semaphore-based** approach, ensuring that
/// the maximum number of concurrent requests does not exceed `max_concurrent`.
///
/// If a request fails, it enters a **retry loop** where each retry is delayed
/// according to an **exponential backoff strategy**.
pub struct DriveThrottleBackoff {
    /// Semaphore controlling the maximum number of concurrent requests.
    semaphore: Arc<Semaphore>,

    /// Defines the backoff and throttling behavior.
    policy: ThrottlePolicy,

    /// Cache layer for detecting previously cached responses.
    cache: Arc<DriveCache>,
}

impl DriveThrottleBackoff {
    /// Creates a new `DriveThrottleBackoff` middleware with the specified throttling policy.
    ///
    /// # Arguments
    ///
    /// * `policy` - The throttling configuration defining concurrency limits, delays, and retry behavior.
    /// * `cache` - The shared cache instance used for **detecting previously cached requests**.
    ///
    /// # Returns
    ///
    /// A new instance of `DriveThrottleBackoff`.
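    ///
    /// # Example
    ///
    /// A minimal sketch of attaching the middleware to a client via
    /// `reqwest_middleware::ClientBuilder`. It assumes a `ThrottlePolicy` named `policy`
    /// and an `Arc<DriveCache>` named `cache` are already in scope, and that the types
    /// are re-exported at the crate root.
    ///
    /// ```ignore
    /// use reqwest_drive::DriveThrottleBackoff;
    /// use reqwest_middleware::ClientBuilder;
    ///
    /// let middleware = DriveThrottleBackoff::new(policy, cache);
    /// let client = ClientBuilder::new(reqwest::Client::new())
    ///     .with(middleware)
    ///     .build();
    /// ```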
    pub fn new(policy: ThrottlePolicy, cache: Arc<DriveCache>) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
            policy,
            cache,
        }
    }

    /// Returns the number of permits currently available on the internal semaphore.
    ///
    /// Only compiled for tests and debug builds; useful for asserting on concurrency
    /// behavior.
    #[cfg(any(test, debug_assertions))]
    pub fn available_permits(&self) -> usize {
        self.semaphore.available_permits()
    }
}

#[async_trait]
impl Middleware for DriveThrottleBackoff {
    /// Handles throttling and retry logic for HTTP requests.
    ///
    /// This method:
    /// 1. **Checks the cache**: If the request is already cached, it bypasses throttling.
    /// 2. **Enforces concurrency limits**: Ensures no more than `max_concurrent` requests are in flight.
    /// 3. **Applies an initial delay** before sending the request.
    /// 4. **Retries failed requests**: applies **exponential backoff** with jitter between attempts.
    ///
    /// # Arguments
    ///
    /// * `req` - The outgoing HTTP request to be sent.
    /// * `extensions` - A mutable reference to request extensions, used for tracking metadata.
    /// * `next` - The next middleware in the request chain.
    ///
    /// # Returns
    ///
    /// A `Result<Response, Error>` containing either:
    /// - A successful response.
    /// - The last response or error observed once all retries have been exhausted.
    ///
    /// # Behavior
    ///
    /// - If the request is **already cached**, the middleware immediately forwards it.
    /// - If **throttling is required**, it waits according to the configured delay.
    /// - If a request fails, **exponential backoff** is applied before retrying.
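    ///
    /// The backoff delay for retry `n` is `base_delay_ms * 2^n` plus a random jitter
    /// drawn from `0..=adaptive_jitter_ms`. For example, with `base_delay_ms = 250`,
    /// the third retry waits `250 * 2^3 = 2000` ms plus jitter.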
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<Response, Error> {
        let url = req.url().to_string();

        let cache_key = format!("{} {}", req.method(), &url);

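        // Cached requests bypass throttling entirely: they are forwarded to the next
        // middleware without consuming a permit or waiting for the base delay.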
        if self.cache.is_cached(&req).await {
            eprintln!("Using cache for: {}", &cache_key);

            return next.run(req, extensions).await;
        } else {
            eprintln!("No cache found for: {}", &cache_key);
        }

        // Log if the permit is not immediately available
        if self.semaphore.available_permits() == 0 {
            eprintln!(
                "Waiting for permit... ({} in use)",
                self.policy.max_concurrent
            );
        }

        // Acquire the permit and log when granted
        let permit = self
            .semaphore
            .acquire()
            .await
            .map_err(|e| Error::Middleware(e.into()))?;

        eprintln!(
            "Permit granted: {} ({} permits left)",
            cache_key,
            self.semaphore.available_permits()
        );

        // Hold the permit until this function completes
        let _permit_guard = permit;

        sleep(Duration::from_millis(self.policy.base_delay_ms)).await;

        let mut attempt = 0;

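        // Retry loop: each failed attempt doubles the backoff (plus jitter) before the
        // next try. Once the retry budget is used up, the loop breaks and the original
        // (un-cloned) request is sent one final time below.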
        loop {
            // `try_clone` returns `None` for requests with streaming bodies, so this
            // middleware assumes cloneable (buffered) request bodies.
            let req_clone = req.try_clone().expect("Request cloning failed");
            let result = next.clone().run(req_clone, extensions).await;

            match result {
                Ok(resp) if resp.status().is_success() => return Ok(resp),
                result if attempt >= self.policy.max_retries => return result,
                _ => {
                    attempt += 1;

                    let backoff_duration = {
                        let mut rng = rand::rng();
                        Duration::from_millis(
                            self.policy.base_delay_ms * 2u64.pow(attempt as u32)
                                + rng.random_range(0..=self.policy.adaptive_jitter_ms),
                        )
                    };

                    eprintln!(
                        "Retry {}/{} for URL {} after {} ms",
                        attempt,
                        self.policy.max_retries,
                        url,
                        backoff_duration.as_millis()
                    );

                    sleep(backoff_duration).await;

                    if attempt >= self.policy.max_retries {
                        break;
                    }
                }
            }
        }

        next.run(req, extensions).await
    }