// reqwest_drive/throttle_middleware.rs
1use crate::{
2 DriveCache,
3 cache_middleware::{CacheBust, CacheBypass},
4};
5use async_trait::async_trait;
6use http::Extensions;
7use rand::RngExt;
8use reqwest::{Request, Response};
9use reqwest_middleware::{Error, Middleware, Next};
10use std::sync::Arc;
11use tokio::sync::Semaphore;
12use tokio::time::{Duration, sleep};
13
14/// Defines the throttling and backoff behavior for handling HTTP requests.
15///
16/// This policy determines the **rate limiting strategy** used for outgoing requests,
17/// including fixed delays, adaptive backoff, and retry settings.
18#[derive(Clone, Debug)]
19pub struct ThrottlePolicy {
20 /// The base delay (in milliseconds) applied before making a request.
21 ///
22 /// This ensures a **minimum delay** between consecutive requests.
23 pub base_delay_ms: u64,
24
25 /// The maximum random jitter (in milliseconds) added to the backoff delay.
26 ///
27 /// This prevents synchronization issues when multiple clients are making requests,
28 /// reducing the likelihood of rate-limit collisions.
29 pub adaptive_jitter_ms: u64,
30
31 /// The maximum number of concurrent requests allowed at any given time.
32 ///
33 /// This controls **parallel request execution**, ensuring that no more than
34 /// `max_concurrent` requests are in-flight simultaneously.
35 pub max_concurrent: usize,
36
37 /// The maximum number of retries allowed in case of failed requests.
38 ///
39 /// If a request fails (e.g., due to a **server error or rate limiting**),
40 /// it will be retried up to `max_retries` times with exponential backoff.
41 pub max_retries: usize,
42}
43
44impl Default for ThrottlePolicy {
45 /// Provides a sensible default throttling policy.
46 ///
47 /// This default configuration is suitable for most API use cases and includes:
48 /// - A **base delay** of 500ms between requests.
49 /// - A **random jitter** of up to 250ms to avoid synchronization issues.
50 /// - A **maximum of 5 concurrent requests** to prevent excessive load.
51 /// - A **maximum of 3 retries** for failed requests.
52 ///
53 /// # Returns
54 ///
55 /// A `ThrottlePolicy` instance with preconfigured defaults.
56 fn default() -> Self {
57 Self {
58 base_delay_ms: 500, // 500ms base delay between requests
59 adaptive_jitter_ms: 250, // Add up to 250ms random jitter
60 max_concurrent: 5, // Allow up to 5 concurrent requests
61 max_retries: 3, // Retry failed requests up to 3 times
62 }
63 }
64}
65
/// Implements a throttling and exponential backoff middleware for HTTP requests.
///
/// This middleware **limits request concurrency** and applies **adaptive delays**
/// between retries, helping to prevent rate-limiting issues when interacting
/// with APIs that enforce request quotas.
///
/// It can run in two modes:
/// - **Cache-aware mode** via [`DriveThrottleBackoff::new`], where cached requests can bypass throttling.
/// - **Throttle-only mode** via [`DriveThrottleBackoff::without_cache`], where all requests are throttled.
///
/// Requests are throttled using a **semaphore-based** approach, ensuring that
/// the maximum number of concurrent requests does not exceed `max_concurrent`.
///
/// If a request fails, it enters a **retry loop** where each retry is delayed
/// according to an **exponential backoff strategy**.
pub struct DriveThrottleBackoff {
    /// Semaphore controlling the maximum number of concurrent requests.
    ///
    /// Sized once, at construction, from `ThrottlePolicy::max_concurrent`;
    /// it is never resized afterwards.
    semaphore: Arc<Semaphore>,

    /// Defines the backoff and throttling behavior. Individual requests may
    /// override this by inserting a `ThrottlePolicy` into their extensions.
    policy: ThrottlePolicy,

    /// Optional cache layer for detecting previously cached responses.
    /// `None` when built via `without_cache`, disabling the cache fast path.
    cache: Option<Arc<DriveCache>>,
}
91
92impl DriveThrottleBackoff {
93 /// Creates a new `DriveThrottleBackoff` middleware with the specified throttling policy.
94 ///
95 /// # Arguments
96 ///
97 /// * `policy` - The throttling configuration defining concurrency limits, delays, and retry behavior.
98 /// * `cache` - The shared cache instance used for **detecting previously cached requests**.
99 ///
100 /// # Returns
101 ///
102 /// A new instance of `DriveThrottleBackoff`.
103 pub fn new(policy: ThrottlePolicy, cache: Arc<DriveCache>) -> Self {
104 Self {
105 semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
106 policy,
107 cache: Some(cache),
108 }
109 }
110
111 /// Creates a new `DriveThrottleBackoff` middleware without any cache integration.
112 ///
113 /// In this mode, every request is throttled based on the configured policy,
114 /// and no cache checks are performed.
115 pub fn without_cache(policy: ThrottlePolicy) -> Self {
116 Self {
117 semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
118 policy,
119 cache: None,
120 }
121 }
122
123 pub fn available_permits(&self) -> usize {
124 self.semaphore.available_permits()
125 }
126}
127
#[async_trait]
impl Middleware for DriveThrottleBackoff {
    /// Handles throttling and retry logic for HTTP requests.
    ///
    /// This method:
    /// 1. **Optionally checks the cache**: In cache-aware mode, cached requests bypass throttling.
    /// 2. **Enforces concurrency limits**: Ensures no more than `max_concurrent` requests are in flight.
    /// 3. **Applies an initial delay** before sending the request.
    /// 4. **Retries failed requests**: Uses **exponential backoff** with jitter for failed requests.
    ///
    /// # Arguments
    ///
    /// * `req` - The incoming HTTP request.
    /// * `extensions` - A mutable reference to request extensions, used for tracking metadata.
    /// * `next` - The next middleware in the request chain.
    ///
    /// # Returns
    ///
    /// A `Result<Response, Error>` containing either:
    /// - A successfully processed response.
    /// - An error if the request failed after exhausting all retries.
    ///
    /// # Behavior
    ///
    /// - In cache-aware mode, if the request is **already cached**, the middleware immediately forwards it.
    /// - In throttle-only mode, cache checks are skipped.
    /// - If **throttling is required**, it waits according to the configured delay.
    /// - If a request fails, **exponential backoff** is applied before retrying.
    /// - A per-request `ThrottlePolicy` override (via extensions) affects delays
    ///   and retries only; the semaphore keeps its construction-time size.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<Response, Error> {
        let url = req.url().to_string();
        // Per-request cache control flags supplied by callers through request
        // extensions; missing flags default to `false`.
        let bypass_cache = extensions
            .get::<CacheBypass>()
            .map(|flag| flag.0)
            .unwrap_or(false);
        let bust_cache = extensions
            .get::<CacheBust>()
            .map(|flag| flag.0)
            .unwrap_or(false);

        // Human-readable "METHOD url" identifier, used only in log messages.
        let cache_key = format!("{} {}", req.method(), &url);

        // Cache-aware fast path: if the response is already cached and the caller
        // did not ask to bypass or bust the cache, skip throttling entirely and
        // let the rest of the chain serve the cached response.
        if !bypass_cache
            && !bust_cache
            && let Some(cache) = &self.cache
        {
            if cache.is_cached(&req).await {
                tracing::debug!("Using cache for: {}", &cache_key);

                return next.run(req, extensions).await;
            } else {
                tracing::debug!("No cache found for: {}", &cache_key);
            }
        }

        // Use a custom throttle policy if provided, otherwise default to `self.policy`.
        // NOTE(review): an overridden `max_concurrent` has no effect on the
        // semaphore below — it was sized at construction and is never resized.
        let custom_policy = extensions.get::<ThrottlePolicy>().cloned();
        let policy = custom_policy.unwrap_or_else(|| self.policy.clone()); // Use override if available

        // Log if the permit is not immediately available
        if self.semaphore.available_permits() == 0 {
            tracing::debug!("Waiting for permit... ({} in use)", policy.max_concurrent);
        }

        // Acquire the permit; a closed semaphore surfaces as Error::Middleware.
        let permit = self
            .semaphore
            .acquire()
            .await
            .map_err(|e| Error::Middleware(e.into()))?;

        tracing::debug!(
            "Permit granted: {} ({} permits left)",
            cache_key,
            self.semaphore.available_permits()
        );

        // Hold the permit until this function completes, so the request keeps
        // counting against `max_concurrent` for its whole lifetime.
        let _permit_guard = permit;

        // Base pacing delay applied before the first attempt of every
        // throttled request.
        sleep(Duration::from_millis(policy.base_delay_ms)).await;

        let mut attempt = 0;

        loop {
            // NOTE(review): `try_clone` returns `None` for streaming request
            // bodies, making this `expect` panic — acceptable only if callers
            // never send non-cloneable bodies; confirm.
            let req_clone = req.try_clone().expect("Request cloning failed");
            let result = next.clone().run(req_clone, extensions).await;

            match result {
                // Success: return the response immediately.
                Ok(resp) if resp.status().is_success() => return Ok(resp),
                // Retries already exhausted: surface the outcome as-is, which
                // may be an `Ok` response with a non-success status code.
                result if attempt >= policy.max_retries => return result,
                // Transport error or non-success status: back off, then retry.
                _ => {
                    attempt += 1;

                    // Exponential backoff (base_delay * 2^attempt) plus random
                    // jitter to de-synchronize retries across clients.
                    let backoff_duration = {
                        let mut rng = rand::rng();
                        Duration::from_millis(
                            policy.base_delay_ms * 2u64.pow(attempt as u32)
                                + rng.random_range(0..=policy.adaptive_jitter_ms),
                        )
                    };

                    tracing::debug!(
                        "Retry {}/{} for URL {} after {} ms",
                        attempt,
                        policy.max_retries,
                        url,
                        backoff_duration.as_millis()
                    );

                    sleep(backoff_duration).await;

                    // After the final backoff, break out so the last attempt
                    // below can consume the original (un-cloned) request.
                    if attempt >= policy.max_retries {
                        break;
                    }
                }
            }
        }

        // Final attempt, made with the original request after `max_retries`
        // failed attempts and their backoff sleeps.
        next.run(req, extensions).await
    }
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest_middleware::ClientBuilder;

    /// The documented defaults are part of the public contract — pin them so a
    /// silent change shows up as a test failure.
    #[test]
    fn throttle_policy_default_values_are_stable() {
        let defaults = ThrottlePolicy::default();

        assert_eq!(defaults.base_delay_ms, 500);
        assert_eq!(defaults.adaptive_jitter_ms, 250);
        assert_eq!(defaults.max_concurrent, 5);
        assert_eq!(defaults.max_retries, 3);
    }

    /// A failed `Semaphore::acquire` must surface as `Error::Middleware`
    /// instead of panicking or hanging.
    #[tokio::test]
    async fn closed_semaphore_returns_middleware_error() {
        let middleware = Arc::new(DriveThrottleBackoff::without_cache(ThrottlePolicy {
            base_delay_ms: 1,
            adaptive_jitter_ms: 0,
            max_concurrent: 1,
            max_retries: 0,
        }));

        // Force semaphore acquire to fail so we cover the map_err path.
        middleware.semaphore.close();

        let client = ClientBuilder::new(reqwest::Client::new())
            .with_arc(middleware)
            .build();

        let err = client
            .get("https://example.test/closed-semaphore")
            .send()
            .await
            .expect_err("closed semaphore should return middleware error");

        assert!(
            matches!(err, reqwest_middleware::Error::Middleware(_)),
            "expected middleware error variant, got: {:?}",
            err
        );
    }
}
298}