// reqwest_drive/throttle_middleware.rs
1use crate::{
2    DriveCache,
3    cache_middleware::{CacheBust, CacheBypass},
4};
5use async_trait::async_trait;
6use http::Extensions;
7use rand::RngExt;
8use reqwest::{Request, Response};
9use reqwest_middleware::{Error, Middleware, Next};
10use std::sync::Arc;
11use tokio::sync::Semaphore;
12use tokio::time::{Duration, sleep};
13
/// Defines the throttling and backoff behavior for handling HTTP requests.
///
/// This policy determines the **rate limiting strategy** used for outgoing requests,
/// including fixed delays, adaptive backoff, and retry settings.
///
/// All fields are plain scalars, so the policy is `Copy` and cheap to pass by
/// value; `PartialEq`/`Eq` allow policies to be compared directly (e.g. in tests).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ThrottlePolicy {
    /// The base delay (in milliseconds) applied before making a request.
    ///
    /// This ensures a **minimum delay** between consecutive requests.
    pub base_delay_ms: u64,

    /// The maximum random jitter (in milliseconds) added to the backoff delay.
    ///
    /// This prevents synchronization issues when multiple clients are making requests,
    /// reducing the likelihood of rate-limit collisions.
    pub adaptive_jitter_ms: u64,

    /// The maximum number of concurrent requests allowed at any given time.
    ///
    /// This controls **parallel request execution**, ensuring that no more than
    /// `max_concurrent` requests are in-flight simultaneously.
    pub max_concurrent: usize,

    /// The maximum number of retries allowed in case of failed requests.
    ///
    /// If a request fails (e.g., due to a **server error or rate limiting**),
    /// it will be retried up to `max_retries` times with exponential backoff.
    pub max_retries: usize,
}
43
44impl Default for ThrottlePolicy {
45    /// Provides a sensible default throttling policy.
46    ///
47    /// This default configuration is suitable for most API use cases and includes:
48    /// - A **base delay** of 500ms between requests.
49    /// - A **random jitter** of up to 250ms to avoid synchronization issues.
50    /// - A **maximum of 5 concurrent requests** to prevent excessive load.
51    /// - A **maximum of 3 retries** for failed requests.
52    ///
53    /// # Returns
54    ///
55    /// A `ThrottlePolicy` instance with preconfigured defaults.
56    fn default() -> Self {
57        Self {
58            base_delay_ms: 500,      // 500ms base delay between requests
59            adaptive_jitter_ms: 250, // Add up to 250ms random jitter
60            max_concurrent: 5,       // Allow up to 5 concurrent requests
61            max_retries: 3,          // Retry failed requests up to 3 times
62        }
63    }
64}
65
/// Implements a throttling and exponential backoff middleware for HTTP requests.
///
/// This middleware **limits request concurrency** and applies **adaptive delays**
/// between retries, helping to prevent rate-limiting issues when interacting
/// with APIs that enforce request quotas.
///
/// It can run in two modes:
/// - **Cache-aware mode** via [`DriveThrottleBackoff::new`], where cached requests can bypass throttling.
/// - **Throttle-only mode** via [`DriveThrottleBackoff::without_cache`], where all requests are throttled.
///
/// Requests are throttled using a **semaphore-based** approach, ensuring that
/// the maximum number of concurrent requests does not exceed `max_concurrent`.
///
/// If a request fails, it enters a **retry loop** where each retry is delayed
/// according to an **exponential backoff strategy**.
pub struct DriveThrottleBackoff {
    /// Semaphore controlling the maximum number of concurrent requests.
    /// Sized from `ThrottlePolicy::max_concurrent` at construction time.
    semaphore: Arc<Semaphore>,

    /// Defines the backoff and throttling behavior; can be overridden
    /// per-request via a `ThrottlePolicy` stored in the request extensions.
    policy: ThrottlePolicy,

    /// Optional cache layer for detecting previously cached responses.
    /// `None` in throttle-only mode (`without_cache`).
    cache: Option<Arc<DriveCache>>,
}
91
92impl DriveThrottleBackoff {
93    /// Creates a new `DriveThrottleBackoff` middleware with the specified throttling policy.
94    ///
95    /// # Arguments
96    ///
97    /// * `policy` - The throttling configuration defining concurrency limits, delays, and retry behavior.
98    /// * `cache` - The shared cache instance used for **detecting previously cached requests**.
99    ///
100    /// # Returns
101    ///
102    /// A new instance of `DriveThrottleBackoff`.
103    pub fn new(policy: ThrottlePolicy, cache: Arc<DriveCache>) -> Self {
104        Self {
105            semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
106            policy,
107            cache: Some(cache),
108        }
109    }
110
111    /// Creates a new `DriveThrottleBackoff` middleware without any cache integration.
112    ///
113    /// In this mode, every request is throttled based on the configured policy,
114    /// and no cache checks are performed.
115    pub fn without_cache(policy: ThrottlePolicy) -> Self {
116        Self {
117            semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
118            policy,
119            cache: None,
120        }
121    }
122
123    pub fn available_permits(&self) -> usize {
124        self.semaphore.available_permits()
125    }
126}
127
128#[async_trait]
129impl Middleware for DriveThrottleBackoff {
130    /// Handles throttling and retry logic for HTTP requests.
131    ///
132    /// This method:
133    /// 1. **Optionally checks the cache**: In cache-aware mode, cached requests bypass throttling.
134    /// 2. **Enforces concurrency limits**: Ensures no more than `max_concurrent` requests are in flight.
135    /// 3. **Applies an initial delay** before sending the request.
136    /// 4. **Retries failed requests**: Uses **exponential backoff** with jitter for failed requests.
137    ///
138    /// # Arguments
139    ///
140    /// * `req` - The incoming HTTP request.
141    /// * `extensions` - A mutable reference to request extensions, used for tracking metadata.
142    /// * `next` - The next middleware in the request chain.
143    ///
144    /// # Returns
145    ///
146    /// A `Result<Response, Error>` containing either:
147    /// - A successfully processed response.
148    /// - An error if the request failed after exhausting all retries.
149    ///
150    /// # Behavior
151    ///
152    /// - In cache-aware mode, if the request is **already cached**, the middleware immediately forwards it.
153    /// - In throttle-only mode, cache checks are skipped.
154    /// - If **throttling is required**, it waits according to the configured delay.
155    /// - If a request fails, **exponential backoff** is applied before retrying.
156    async fn handle(
157        &self,
158        req: Request,
159        extensions: &mut Extensions,
160        next: Next<'_>,
161    ) -> Result<Response, Error> {
162        let url = req.url().to_string();
163        let bypass_cache = extensions
164            .get::<CacheBypass>()
165            .map(|flag| flag.0)
166            .unwrap_or(false);
167        let bust_cache = extensions
168            .get::<CacheBust>()
169            .map(|flag| flag.0)
170            .unwrap_or(false);
171
172        let cache_key = format!("{} {}", req.method(), &url);
173
174        if !bypass_cache
175            && !bust_cache
176            && let Some(cache) = &self.cache
177        {
178            if cache.is_cached(&req).await {
179                tracing::debug!("Using cache for: {}", &cache_key);
180
181                return next.run(req, extensions).await;
182            } else {
183                tracing::debug!("No cache found for: {}", &cache_key);
184            }
185        }
186
187        // Use a custom throttle policy if provided, otherwise default to `self.policy`
188        let custom_policy = extensions.get::<ThrottlePolicy>().cloned();
189        let policy = custom_policy.unwrap_or_else(|| self.policy.clone()); // Use override if available
190
191        // Log if the permit is not immediately available
192        if self.semaphore.available_permits() == 0 {
193            tracing::debug!("Waiting for permit... ({} in use)", policy.max_concurrent);
194        }
195
196        // Acquire the permit and log when granted
197        let permit = self
198            .semaphore
199            .acquire()
200            .await
201            .map_err(|e| Error::Middleware(e.into()))?;
202
203        tracing::debug!(
204            "Permit granted: {} ({} permits left)",
205            cache_key,
206            self.semaphore.available_permits()
207        );
208
209        // Hold the permit until this function completes
210        let _permit_guard = permit;
211
212        sleep(Duration::from_millis(policy.base_delay_ms)).await;
213
214        let mut attempt = 0;
215
216        loop {
217            let req_clone = req.try_clone().expect("Request cloning failed");
218            let result = next.clone().run(req_clone, extensions).await;
219
220            match result {
221                Ok(resp) if resp.status().is_success() => return Ok(resp),
222                result if attempt >= policy.max_retries => return result,
223                _ => {
224                    attempt += 1;
225
226                    let backoff_duration = {
227                        let mut rng = rand::rng();
228                        Duration::from_millis(
229                            policy.base_delay_ms * 2u64.pow(attempt as u32)
230                                + rng.random_range(0..=policy.adaptive_jitter_ms),
231                        )
232                    };
233
234                    tracing::debug!(
235                        "Retry {}/{} for URL {} after {} ms",
236                        attempt,
237                        policy.max_retries,
238                        url,
239                        backoff_duration.as_millis()
240                    );
241
242                    sleep(backoff_duration).await;
243
244                    if attempt >= policy.max_retries {
245                        break;
246                    }
247                }
248            }
249        }
250
251        next.run(req, extensions).await
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest_middleware::ClientBuilder;

    #[test]
    fn throttle_policy_default_values_are_stable() {
        // Destructure so a newly added field would force this test to change.
        let ThrottlePolicy {
            base_delay_ms,
            adaptive_jitter_ms,
            max_concurrent,
            max_retries,
        } = ThrottlePolicy::default();

        assert_eq!(base_delay_ms, 500);
        assert_eq!(adaptive_jitter_ms, 250);
        assert_eq!(max_concurrent, 5);
        assert_eq!(max_retries, 3);
    }

    #[tokio::test]
    async fn closed_semaphore_returns_middleware_error() {
        let policy = ThrottlePolicy {
            base_delay_ms: 1,
            adaptive_jitter_ms: 0,
            max_concurrent: 1,
            max_retries: 0,
        };
        let middleware = Arc::new(DriveThrottleBackoff::without_cache(policy));

        // Force semaphore acquire to fail so we cover the map_err path.
        middleware.semaphore.close();

        let client = ClientBuilder::new(reqwest::Client::new())
            .with_arc(Arc::clone(&middleware))
            .build();

        let err = client
            .get("https://example.test/closed-semaphore")
            .send()
            .await
            .expect_err("closed semaphore should return middleware error");

        assert!(
            matches!(err, reqwest_middleware::Error::Middleware(_)),
            "expected middleware error variant, got: {:?}",
            err
        );
    }
}