reqwest_drive/throttle_middleware.rs

use crate::DriveCache;
use async_trait::async_trait;
use http::Extensions;
use rand::Rng;
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next};
use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::time::{sleep, Duration};

/// Defines the throttling and backoff behavior for handling HTTP requests.
///
/// This policy determines the **rate limiting strategy** used for outgoing requests,
/// including fixed delays, adaptive backoff, and retry settings.
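///
/// # Examples
///
/// A minimal construction sketch; the import path assumes `ThrottlePolicy` is
/// re-exported at the crate root, and the values are purely illustrative:
///
/// ```ignore
/// use reqwest_drive::ThrottlePolicy;
///
/// // Start from the defaults and tighten the limits for a stricter API.
/// let policy = ThrottlePolicy {
///     max_concurrent: 2,
///     max_retries: 5,
///     ..ThrottlePolicy::default()
/// };
/// assert_eq!(policy.base_delay_ms, 500);
/// ```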
#[derive(Clone, Debug)]
pub struct ThrottlePolicy {
    /// The base delay (in milliseconds) applied before making a request.
    ///
    /// This ensures a **minimum delay** between consecutive requests and also
    /// serves as the base value for the exponential backoff between retries.
    pub base_delay_ms: u64,

    /// The maximum random jitter (in milliseconds) added to the backoff delay.
    ///
    /// This prevents synchronization issues when multiple clients are making requests,
    /// reducing the likelihood of rate-limit collisions.
    pub adaptive_jitter_ms: u64,

    /// The maximum number of concurrent requests allowed at any given time.
    ///
    /// This controls **parallel request execution**, ensuring that no more than
    /// `max_concurrent` requests are in-flight simultaneously.
    pub max_concurrent: usize,

    /// The maximum number of retries allowed for failed requests.
    ///
    /// If a request fails (e.g., due to a **server error or rate limiting**),
    /// it will be retried up to `max_retries` times with exponential backoff.
    pub max_retries: usize,
}

impl Default for ThrottlePolicy {
    /// Provides a sensible default throttling policy.
    ///
    /// This default configuration is suitable for most API use cases and includes:
    /// - A **base delay** of 500ms between requests.
    /// - A **random jitter** of up to 250ms to avoid synchronization issues.
    /// - A **maximum of 5 concurrent requests** to prevent excessive load.
    /// - A **maximum of 3 retries** for failed requests.
    ///
    /// # Returns
    ///
    /// A `ThrottlePolicy` instance with preconfigured defaults.
    fn default() -> Self {
        Self {
            base_delay_ms: 500,      // 500ms base delay between requests
            adaptive_jitter_ms: 250, // Add up to 250ms random jitter
            max_concurrent: 5,       // Allow up to 5 concurrent requests
            max_retries: 3,          // Retry failed requests up to 3 times
        }
    }
}

/// Implements a throttling and exponential backoff middleware for HTTP requests.
///
/// This middleware **limits request concurrency** and applies **adaptive delays**
/// between retries, helping to prevent rate-limiting issues when interacting
/// with APIs that enforce request quotas.
///
/// Requests are throttled using a **semaphore-based** approach, ensuring that
/// the maximum number of concurrent requests does not exceed `max_concurrent`.
///
/// If a request fails, it enters a **retry loop** where each retry is delayed
/// according to an **exponential backoff strategy**.
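///
/// Each retry waits roughly `base_delay_ms * 2^attempt` milliseconds plus a random
/// jitter of up to `adaptive_jitter_ms`. With the defaults (500ms base, 250ms jitter),
/// the first three retries back off by about 1000ms, 2000ms, and 4000ms, each plus
/// 0-250ms of jitter.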
pub struct DriveThrottleBackoff {
    /// Semaphore controlling the maximum number of concurrent requests.
    semaphore: Arc<Semaphore>,

    /// Defines the backoff and throttling behavior.
    policy: ThrottlePolicy,

    /// Cache layer for detecting previously cached responses.
    cache: Arc<DriveCache>,
}

impl DriveThrottleBackoff {
    /// Creates a new `DriveThrottleBackoff` middleware with the specified throttling policy.
    ///
    /// # Arguments
    ///
    /// * `policy` - The throttling configuration defining concurrency limits, delays, and retry behavior.
    /// * `cache` - The shared cache instance used for **detecting previously cached requests**.
    ///
    /// # Returns
    ///
    /// A new instance of `DriveThrottleBackoff`.
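    ///
    /// # Examples
    ///
    /// A sketch of wiring the middleware into a `reqwest_middleware` client stack. It
    /// assumes the types are re-exported at the crate root and that an `Arc<DriveCache>`
    /// handle is already available (its construction is crate-specific):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    /// use reqwest_drive::{DriveCache, DriveThrottleBackoff, ThrottlePolicy};
    /// use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
    ///
    /// fn build_client(cache: Arc<DriveCache>) -> ClientWithMiddleware {
    ///     let throttle = DriveThrottleBackoff::new(ThrottlePolicy::default(), cache);
    ///     ClientBuilder::new(reqwest::Client::new())
    ///         .with(throttle)
    ///         .build()
    /// }
    /// ```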
    pub fn new(policy: ThrottlePolicy, cache: Arc<DriveCache>) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
            policy,
            cache,
        }
    }

    /// Returns the number of semaphore permits currently available (test/debug builds only).
    #[cfg(any(test, debug_assertions))]
    pub fn available_permits(&self) -> usize {
        self.semaphore.available_permits()
    }
}

#[async_trait]
impl Middleware for DriveThrottleBackoff {
    /// Handles throttling and retry logic for HTTP requests.
    ///
    /// This method:
    /// 1. **Checks the cache**: If the request is already cached, it bypasses throttling.
    /// 2. **Enforces concurrency limits**: Ensures no more than `max_concurrent` requests are in flight.
    /// 3. **Applies an initial delay** before sending the request.
    /// 4. **Retries failed requests**: Uses **exponential backoff** with jitter for failed requests.
    ///
    /// # Arguments
    ///
    /// * `req` - The incoming HTTP request.
    /// * `extensions` - A mutable reference to request extensions, used for tracking metadata.
    /// * `next` - The next middleware in the request chain.
    ///
    /// # Returns
    ///
    /// A `Result<Response, Error>` containing either:
    /// - A successfully processed response.
    /// - An error if the request failed after exhausting all retries.
    ///
    /// # Behavior
    ///
    /// - If the request is **already cached**, the middleware immediately forwards it.
    /// - If **throttling is required**, it waits according to the configured delay.
    /// - If a request fails, **exponential backoff** is applied before retrying.
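    ///
    /// A `ThrottlePolicy` stored in the request's `Extensions` overrides `self.policy`
    /// for that single request.
    ///
    /// # Examples
    ///
    /// A sketch of a per-request override. It assumes crate-root re-exports, a `client`
    /// built with this middleware, and `reqwest_middleware`'s `RequestBuilder::with_extension`
    /// (available in recent versions):
    ///
    /// ```ignore
    /// use reqwest_drive::ThrottlePolicy;
    ///
    /// // Give this one request a longer base delay and more retries.
    /// let resp = client
    ///     .get("https://example.com/resource")
    ///     .with_extension(ThrottlePolicy {
    ///         base_delay_ms: 1_000,
    ///         max_retries: 5,
    ///         ..ThrottlePolicy::default()
    ///     })
    ///     .send()
    ///     .await?;
    /// ```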
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<Response, Error> {
        let url = req.url().to_string();

        let cache_key = format!("{} {}", req.method(), &url);

        if self.cache.is_cached(&req).await {
            eprintln!("Using cache for: {}", &cache_key);

            return next.run(req, extensions).await;
        } else {
            eprintln!("No cache found for: {}", &cache_key);
        }

        // Use a per-request throttle policy from the extensions if provided,
        // otherwise fall back to the middleware-wide `self.policy`.
        let policy = extensions
            .get::<ThrottlePolicy>()
            .cloned()
            .unwrap_or_else(|| self.policy.clone());

        // Log if the permit is not immediately available
        if self.semaphore.available_permits() == 0 {
            eprintln!("Waiting for permit... ({} in use)", policy.max_concurrent);
        }

        // Acquire the permit and log when granted
        let permit = self
            .semaphore
            .acquire()
            .await
            .map_err(|e| Error::Middleware(e.into()))?;

        eprintln!(
            "Permit granted: {} ({} permits left)",
            cache_key,
            self.semaphore.available_permits()
        );

        // Hold the permit until this function completes
        let _permit_guard = permit;

        sleep(Duration::from_millis(policy.base_delay_ms)).await;

        let mut attempt = 0;

        loop {
            // Note: requests with a streaming body cannot be cloned and will panic here.
            let req_clone = req.try_clone().expect("Request cloning failed");
            let result = next.clone().run(req_clone, extensions).await;

            match result {
                // Success: return immediately.
                Ok(resp) if resp.status().is_success() => return Ok(resp),
                // Retries already exhausted (e.g. `max_retries == 0`): return the result as-is.
                result if attempt >= policy.max_retries => return result,
                // Failure (error or non-success status): back off, then retry.
                _ => {
                    attempt += 1;

                    let backoff_duration = {
                        let mut rng = rand::rng();
                        Duration::from_millis(
                            policy.base_delay_ms * 2u64.pow(attempt as u32)
                                + rng.random_range(0..=policy.adaptive_jitter_ms),
                        )
                    };

                    eprintln!(
                        "Retry {}/{} for URL {} after {} ms",
                        attempt,
                        policy.max_retries,
                        url,
                        backoff_duration.as_millis()
                    );

                    sleep(backoff_duration).await;

                    if attempt >= policy.max_retries {
                        break;
                    }
                }
            }
        }

        // Final attempt after the retries are exhausted: send the original request
        // and return its result as-is.
        next.run(req, extensions).await
    }
}