//! reqwest_drive/throttle_middleware.rs — throttling and exponential-backoff middleware.
1use crate::DriveCache;
2use async_trait::async_trait;
3use http::Extensions;
4use rand::Rng;
5use reqwest::{Request, Response};
6use reqwest_middleware::{Error, Middleware, Next};
7use std::sync::Arc;
8use tokio::sync::Semaphore;
9use tokio::time::{sleep, Duration};
10
/// Configuration governing request throttling and retry backoff.
///
/// A `ThrottlePolicy` bundles every knob the throttling middleware
/// consults: request pacing, backoff jitter, the concurrency ceiling,
/// and the retry budget.
#[derive(Clone, Debug)]
pub struct ThrottlePolicy {
    /// Minimum pause, in milliseconds, inserted before each request is
    /// sent, guaranteeing spacing between consecutive requests.
    pub base_delay_ms: u64,

    /// Upper bound, in milliseconds, of the random jitter mixed into
    /// backoff delays. Jitter spreads retries from multiple clients so
    /// they do not hammer a rate-limited server in lockstep.
    pub adaptive_jitter_ms: u64,

    /// Hard cap on how many requests may be in flight simultaneously.
    pub max_concurrent: usize,

    /// How many times a failing request (server error, rate limit, …)
    /// is retried with exponential backoff before its final outcome is
    /// surrendered to the caller.
    pub max_retries: usize,
}
40
41impl Default for ThrottlePolicy {
42 /// Provides a sensible default throttling policy.
43 ///
44 /// This default configuration is suitable for most API use cases and includes:
45 /// - A **base delay** of 500ms between requests.
46 /// - A **random jitter** of up to 250ms to avoid synchronization issues.
47 /// - A **maximum of 5 concurrent requests** to prevent excessive load.
48 /// - A **maximum of 3 retries** for failed requests.
49 ///
50 /// # Returns
51 ///
52 /// A `ThrottlePolicy` instance with preconfigured defaults.
53 fn default() -> Self {
54 Self {
55 base_delay_ms: 500, // 500ms base delay between requests
56 adaptive_jitter_ms: 250, // Add up to 250ms random jitter
57 max_concurrent: 5, // Allow up to 5 concurrent requests
58 max_retries: 3, // Retry failed requests up to 3 times
59 }
60 }
61}
62
63/// Implements a throttling and exponential backoff middleware for HTTP requests.
64///
65/// This middleware **limits request concurrency** and applies **adaptive delays**
66/// between retries, helping to prevent rate-limiting issues when interacting
67/// with APIs that enforce request quotas.
68///
69/// Requests are throttled using a **semaphore-based** approach, ensuring that
70/// the maximum number of concurrent requests does not exceed `max_concurrent`.
71///
72/// If a request fails, it enters a **retry loop** where each retry is delayed
73/// according to an **exponential backoff strategy**.
74pub struct DriveThrottleBackoff {
75 /// Semaphore controlling the maximum number of concurrent requests.
76 semaphore: Arc<Semaphore>,
77
78 /// Defines the backoff and throttling behavior.
79 policy: ThrottlePolicy,
80
81 /// Cache layer for detecting previously cached responses.
82 cache: Arc<DriveCache>,
83}
84
85impl DriveThrottleBackoff {
86 /// Creates a new `DriveThrottleBackoff` middleware with the specified throttling policy.
87 ///
88 /// # Arguments
89 ///
90 /// * `policy` - The throttling configuration defining concurrency limits, delays, and retry behavior.
91 /// * `cache` - The shared cache instance used for **detecting previously cached requests**.
92 ///
93 /// # Returns
94 ///
95 /// A new instance of `DriveThrottleBackoff`.
96 pub fn new(policy: ThrottlePolicy, cache: Arc<DriveCache>) -> Self {
97 Self {
98 semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
99 policy,
100 cache,
101 }
102 }
103
104 #[cfg(any(test, debug_assertions))]
105 pub fn available_permits(&self) -> usize {
106 self.semaphore.available_permits()
107 }
108}
109
110#[async_trait]
111impl Middleware for DriveThrottleBackoff {
112 /// Handles throttling and retry logic for HTTP requests.
113 ///
114 /// This method:
115 /// 1. **Checks the cache**: If the request is already cached, it bypasses throttling.
116 /// 2. **Enforces concurrency limits**: Ensures no more than `max_concurrent` requests are in flight.
117 /// 3. **Applies an initial delay** before sending the request.
118 /// 4. **Retries failed requests**: Uses **exponential backoff** with jitter for failed requests.
119 ///
120 /// # Arguments
121 ///
122 /// * `req` - The incoming HTTP request.
123 /// * `extensions` - A mutable reference to request extensions, used for tracking metadata.
124 /// * `next` - The next middleware in the request chain.
125 ///
126 /// # Returns
127 ///
128 /// A `Result<Response, Error>` containing either:
129 /// - A successfully processed response.
130 /// - An error if the request failed after exhausting all retries.
131 ///
132 /// # Behavior
133 ///
134 /// - If the request is **already cached**, the middleware immediately forwards it.
135 /// - If **throttling is required**, it waits according to the configured delay.
136 /// - If a request fails, **exponential backoff** is applied before retrying.
137 async fn handle(
138 &self,
139 req: Request,
140 extensions: &mut Extensions,
141 next: Next<'_>,
142 ) -> Result<Response, Error> {
143 let url = req.url().to_string();
144
145 let cache_key = format!("{} {}", req.method(), &url);
146
147 if self.cache.is_cached(&req).await {
148 eprintln!("Using cache for: {}", &cache_key);
149
150 return next.run(req, extensions).await;
151 } else {
152 eprintln!("No cache found for: {}", &cache_key);
153 }
154
155 // Use a custom throttle policy if provided, otherwise default to `self.policy`
156 let custom_policy = extensions.get::<ThrottlePolicy>().cloned();
157 let policy = custom_policy.unwrap_or_else(|| self.policy.clone()); // Use override if available
158
159 // Log if the permit is not immediately available
160 if self.semaphore.available_permits() == 0 {
161 eprintln!("Waiting for permit... ({} in use)", policy.max_concurrent);
162 }
163
164 // Acquire the permit and log when granted
165 let permit = self
166 .semaphore
167 .acquire()
168 .await
169 .map_err(|e| Error::Middleware(e.into()))?;
170
171 eprintln!(
172 "Permit granted: {} ({} permits left)",
173 cache_key,
174 self.semaphore.available_permits()
175 );
176
177 // Hold the permit until this function completes
178 let _permit_guard = permit;
179
180 sleep(Duration::from_millis(policy.base_delay_ms)).await;
181
182 let mut attempt = 0;
183
184 loop {
185 let req_clone = req.try_clone().expect("Request cloning failed");
186 let result = next.clone().run(req_clone, extensions).await;
187
188 match result {
189 Ok(resp) if resp.status().is_success() => return Ok(resp),
190 result if attempt >= policy.max_retries => return result,
191 _ => {
192 attempt += 1;
193
194 let backoff_duration = {
195 let mut rng = rand::rng();
196 Duration::from_millis(
197 policy.base_delay_ms * 2u64.pow(attempt as u32)
198 + rng.random_range(0..=policy.adaptive_jitter_ms),
199 )
200 };
201
202 eprintln!(
203 "Retry {}/{} for URL {} after {} ms",
204 attempt,
205 policy.max_retries,
206 url,
207 backoff_duration.as_millis()
208 );
209
210 sleep(backoff_duration).await;
211
212 if attempt >= policy.max_retries {
213 break;
214 }
215 }
216 }
217 }
218
219 next.run(req, extensions).await
220 }
221}