reqwest_drive/throttle_middleware.rs
1use crate::DriveCache;
2use async_trait::async_trait;
3use http::Extensions;
4use rand::Rng;
5use reqwest::{Request, Response};
6use reqwest_middleware::{Error, Middleware, Next};
7use std::sync::Arc;
8use tokio::sync::Semaphore;
9use tokio::time::{sleep, Duration};
10
/// Configuration for request throttling and retry backoff.
///
/// A policy bundles the settings consumed by [`DriveThrottleBackoff`]: the
/// fixed pre-request delay, the bound on random jitter, the concurrency cap,
/// and the retry budget.
#[derive(Debug, Clone)]
pub struct ThrottlePolicy {
    /// Delay, in milliseconds, slept before an uncached request is first sent.
    ///
    /// The same value is the base of the exponential backoff: the wait before
    /// retry `n` is `base_delay_ms * 2^n` milliseconds plus jitter.
    pub base_delay_ms: u64,

    /// Inclusive upper bound, in milliseconds, of the random jitter added to
    /// every backoff delay.
    ///
    /// Randomizing the wait keeps clients that failed at the same moment from
    /// retrying in lockstep.
    pub adaptive_jitter_ms: u64,

    /// Number of permits in the semaphore that gates uncached requests.
    ///
    /// At most this many throttled requests hold a permit at the same time;
    /// a permit is held across the initial delay, every attempt, and every
    /// backoff wait of a request.
    pub max_concurrent: usize,

    /// How many times a request is re-sent after the initial attempt fails.
    ///
    /// A failure is a transport error or a response without a success status.
    /// `0` disables retries.
    pub max_retries: usize,
}
40
41/// Implements a throttling and exponential backoff middleware for HTTP requests.
42///
43/// This middleware **limits request concurrency** and applies **adaptive delays**
44/// between retries, helping to prevent rate-limiting issues when interacting
45/// with APIs that enforce request quotas.
46///
47/// Requests are throttled using a **semaphore-based** approach, ensuring that
48/// the maximum number of concurrent requests does not exceed `max_concurrent`.
49///
50/// If a request fails, it enters a **retry loop** where each retry is delayed
51/// according to an **exponential backoff strategy**.
52pub struct DriveThrottleBackoff {
53 /// Semaphore controlling the maximum number of concurrent requests.
54 semaphore: Arc<Semaphore>,
55
56 /// Defines the backoff and throttling behavior.
57 policy: ThrottlePolicy,
58
59 /// Cache layer for detecting previously cached responses.
60 cache: Arc<DriveCache>,
61}
62
63impl DriveThrottleBackoff {
64 /// Creates a new `DriveThrottleBackoff` middleware with the specified throttling policy.
65 ///
66 /// # Arguments
67 ///
68 /// * `policy` - The throttling configuration defining concurrency limits, delays, and retry behavior.
69 /// * `cache` - The shared cache instance used for **detecting previously cached requests**.
70 ///
71 /// # Returns
72 ///
73 /// A new instance of `DriveThrottleBackoff`.
74 pub fn new(policy: ThrottlePolicy, cache: Arc<DriveCache>) -> Self {
75 Self {
76 semaphore: Arc::new(Semaphore::new(policy.max_concurrent)),
77 policy,
78 cache,
79 }
80 }
81
82 #[cfg(any(test, debug_assertions))]
83 pub fn available_permits(&self) -> usize {
84 self.semaphore.available_permits()
85 }
86}
87
88#[async_trait]
89impl Middleware for DriveThrottleBackoff {
90 /// Handles throttling and retry logic for HTTP requests.
91 ///
92 /// This method:
93 /// 1. **Checks the cache**: If the request is already cached, it bypasses throttling.
94 /// 2. **Enforces concurrency limits**: Ensures no more than `max_concurrent` requests are in flight.
95 /// 3. **Applies an initial delay** before sending the request.
96 /// 4. **Retries failed requests**: Uses **exponential backoff** with jitter for failed requests.
97 ///
98 /// # Arguments
99 ///
100 /// * `req` - The incoming HTTP request.
101 /// * `extensions` - A mutable reference to request extensions, used for tracking metadata.
102 /// * `next` - The next middleware in the request chain.
103 ///
104 /// # Returns
105 ///
106 /// A `Result<Response, Error>` containing either:
107 /// - A successfully processed response.
108 /// - An error if the request failed after exhausting all retries.
109 ///
110 /// # Behavior
111 ///
112 /// - If the request is **already cached**, the middleware immediately forwards it.
113 /// - If **throttling is required**, it waits according to the configured delay.
114 /// - If a request fails, **exponential backoff** is applied before retrying.
115 async fn handle(
116 &self,
117 req: Request,
118 extensions: &mut Extensions,
119 next: Next<'_>,
120 ) -> Result<Response, Error> {
121 let url = req.url().to_string();
122
123 let cache_key = format!("{} {}", req.method(), &url);
124
125 if self.cache.is_cached(&req).await {
126 eprintln!("Using cache for: {}", &cache_key);
127
128 return next.run(req, extensions).await;
129 } else {
130 eprintln!("No cache found for: {}", &cache_key);
131 }
132
133 // Log if the permit is not immediately available
134 if self.semaphore.available_permits() == 0 {
135 eprintln!(
136 "Waiting for permit... ({} in use)",
137 self.policy.max_concurrent
138 );
139 }
140
141 // Acquire the permit and log when granted
142 let permit = self
143 .semaphore
144 .acquire()
145 .await
146 .map_err(|e| Error::Middleware(e.into()))?;
147
148 eprintln!(
149 "Permit granted: {} ({} permits left)",
150 cache_key,
151 self.semaphore.available_permits()
152 );
153
154 // Hold the permit until this function completes
155 let _permit_guard = permit;
156
157 sleep(Duration::from_millis(self.policy.base_delay_ms)).await;
158
159 let mut attempt = 0;
160
161 loop {
162 let req_clone = req.try_clone().expect("Request cloning failed");
163 let result = next.clone().run(req_clone, extensions).await;
164
165 match result {
166 Ok(resp) if resp.status().is_success() => return Ok(resp),
167 result if attempt >= self.policy.max_retries => return result,
168 _ => {
169 attempt += 1;
170
171 let backoff_duration = {
172 let mut rng = rand::rng();
173 Duration::from_millis(
174 self.policy.base_delay_ms * 2u64.pow(attempt as u32)
175 + rng.random_range(0..=self.policy.adaptive_jitter_ms),
176 )
177 };
178
179 eprintln!(
180 "Retry {}/{} for URL {} after {} ms",
181 attempt,
182 self.policy.max_retries,
183 url,
184 backoff_duration.as_millis()
185 );
186
187 sleep(backoff_duration).await;
188
189 if attempt >= self.policy.max_retries {
190 break;
191 }
192 }
193 }
194 }
195
196 next.run(req, extensions).await
197 }
198}