1use async_trait::async_trait;
17use log::{debug, info};
18use moka::future::Cache;
19use rand::distributions::{Distribution, Uniform};
20use std::hash::Hash;
21use std::sync::Arc;
22use std::time::Duration;
23use tokio::sync::Mutex;
24use tokio::time::{Instant, sleep};
25
26use governor::clock::DefaultClock;
27use governor::state::{InMemoryState, NotKeyed};
28use governor::{Quota, RateLimiter as GovernorRateLimiter};
29use std::num::NonZeroU32;
30
31use crate::middleware::{Middleware, MiddlewareAction};
32use spider_util::constants::{
33 MIDDLEWARE_CACHE_CAPACITY, MIDDLEWARE_CACHE_TTL_SECS, RATE_LIMIT_INITIAL_DELAY_MS,
34 RATE_LIMIT_MAX_DELAY_MS, RATE_LIMIT_MAX_JITTER_MS, RATE_LIMIT_MIN_DELAY_MS,
35};
36use spider_util::error::SpiderError;
37use spider_util::request::Request;
38use spider_util::response::Response;
39
/// Granularity at which rate limits are tracked and enforced.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Scope {
    /// A single shared limiter for every outgoing request.
    Global,
    /// One limiter per normalized request origin (domain).
    Domain,
}
48
/// Abstraction over rate-limiting strategies used by `RateLimitMiddleware`.
#[async_trait]
pub trait RateLimiter: Send + Sync {
    /// Waits until the next request is permitted to proceed.
    async fn acquire(&self);
    /// Lets the limiter adapt its pacing based on `response` (e.g. back off
    /// on 429/5xx). May be a no-op for fixed-rate limiters.
    async fn adjust(&self, response: &Response);
    /// Reports the delay currently enforced between requests; limiters
    /// without a meaningful inter-request delay return `Duration::ZERO`.
    async fn current_delay(&self) -> Duration;
}
59
// Delay bounds for the adaptive limiter, sourced from crate-wide constants.
const INITIAL_DELAY: Duration = Duration::from_millis(RATE_LIMIT_INITIAL_DELAY_MS);
const MIN_DELAY: Duration = Duration::from_millis(RATE_LIMIT_MIN_DELAY_MS);
const MAX_DELAY: Duration = Duration::from_millis(RATE_LIMIT_MAX_DELAY_MS);

// Multiplicative factors applied to the current delay per response class:
// 429/5xx grow it by 1.5x, 403 by 1.2x, 2xx/3xx shrink it by 0.95x.
const ERROR_PENALTY_MULTIPLIER: f64 = 1.5;
const SUCCESS_DECAY_MULTIPLIER: f64 = 0.95;
const FORBIDDEN_PENALTY_MULTIPLIER: f64 = 1.2;
67
/// Mutable pacing state for `AdaptiveLimiter`, guarded by a mutex.
struct AdaptiveState {
    // Current inter-request delay; adjusted multiplicatively per response status.
    delay: Duration,
    // Earliest instant at which the next request may be issued.
    next_allowed_at: Instant,
}
72
/// Rate limiter that adapts its inter-request delay to response statuses:
/// errors (429/5xx) and 403s lengthen the delay, successes shorten it.
pub struct AdaptiveLimiter {
    // Shared pacing state; a tokio Mutex because it is locked from async code.
    state: Mutex<AdaptiveState>,
    // When true, computed sleeps are randomized within a bounded window.
    jitter: bool,
}
78
79impl AdaptiveLimiter {
80 pub fn new(initial_delay: Duration, jitter: bool) -> Self {
82 Self {
83 state: Mutex::new(AdaptiveState {
84 delay: initial_delay,
85 next_allowed_at: Instant::now(),
86 }),
87 jitter,
88 }
89 }
90
91 fn apply_jitter(&self, delay: Duration) -> Duration {
92 if !self.jitter || delay.is_zero() {
93 return delay;
94 }
95
96 let max_jitter = Duration::from_millis(RATE_LIMIT_MAX_JITTER_MS);
97 let jitter_window = delay.mul_f64(0.25).min(max_jitter);
98
99 let low = delay.saturating_sub(jitter_window);
100 let high = delay + jitter_window;
101
102 let mut rng = rand::thread_rng();
103 let uniform = Uniform::new_inclusive(low, high);
104 uniform.sample(&mut rng)
105 }
106}
107
#[async_trait]
impl RateLimiter for AdaptiveLimiter {
    /// Reserves the next request slot and sleeps until it is reached.
    ///
    /// The mutex is held only to compute the wait and advance
    /// `next_allowed_at`; the sleep happens after the lock is released so
    /// concurrent callers are not serialized on the mutex while sleeping.
    async fn acquire(&self) {
        let sleep_duration = {
            let mut state = self.state.lock().await;
            let now = Instant::now();

            let delay = state.delay;
            if now < state.next_allowed_at {
                // A slot is already scheduled in the future: wait for it and
                // push the following slot one full delay further out.
                let wait = state.next_allowed_at - now;
                state.next_allowed_at += delay;
                wait
            } else {
                // No backlog: proceed immediately and reserve the next slot.
                state.next_allowed_at = now + delay;
                Duration::ZERO
            }
        };

        // Jitter only affects this caller's sleep, not the shared schedule.
        let sleep_duration = self.apply_jitter(sleep_duration);
        if !sleep_duration.is_zero() {
            debug!("Rate limiting: sleeping for {:?}", sleep_duration);
            sleep(sleep_duration).await;
        }
    }

    /// Scales the delay multiplicatively by response status and clamps it to
    /// [MIN_DELAY, MAX_DELAY]: 2xx/3xx decay it, 403 and 429/5xx grow it,
    /// any other status leaves it unchanged.
    async fn adjust(&self, response: &Response) {
        let mut state = self.state.lock().await;

        let old_delay = state.delay;
        let status = response.status.as_u16();
        let new_delay = match status {
            200..=399 => state.delay.mul_f64(SUCCESS_DECAY_MULTIPLIER),
            403 => state.delay.mul_f64(FORBIDDEN_PENALTY_MULTIPLIER),
            429 | 500..=599 => state.delay.mul_f64(ERROR_PENALTY_MULTIPLIER),
            _ => state.delay,
        };

        state.delay = new_delay.clamp(MIN_DELAY, MAX_DELAY);

        if old_delay != state.delay {
            debug!(
                "Adjusting delay for status {}: {:?} -> {:?}",
                status, old_delay, state.delay
            );
        }
    }

    /// Returns the current (unjittered) inter-request delay.
    async fn current_delay(&self) -> Duration {
        self.state.lock().await.delay
    }
}
159
/// Fixed-rate limiter backed by the `governor` crate's token bucket.
pub struct TokenBucketLimiter {
    // Keyless (`NotKeyed`) direct limiter enforcing a single quota;
    // wrapped in an Arc so the governor state can be shared cheaply.
    limiter: Arc<GovernorRateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
}
164
165impl TokenBucketLimiter {
166 pub fn new(requests_per_second: u32) -> Self {
172 let requests_per_second = match NonZeroU32::new(requests_per_second) {
173 Some(rps) => rps,
174 None => panic!("requests_per_second must be non-zero"),
175 };
176 let quota = Quota::per_second(requests_per_second);
177 TokenBucketLimiter {
178 limiter: Arc::new(GovernorRateLimiter::direct_with_clock(
179 quota,
180 &DefaultClock::default(),
181 )),
182 }
183 }
184}
185
#[async_trait]
impl RateLimiter for TokenBucketLimiter {
    /// Waits until the token bucket grants a permit.
    async fn acquire(&self) {
        self.limiter.until_ready().await;
    }

    /// No-op: the token-bucket rate is fixed and does not react to responses.
    async fn adjust(&self, _response: &Response) {}

    /// Always zero: pacing is handled internally by `governor`, so there is
    /// no explicit inter-request delay to report.
    async fn current_delay(&self) -> Duration {
        Duration::ZERO
    }
}
199
/// Middleware that throttles outgoing requests, keeping one `RateLimiter`
/// per scope key (globally, or per domain) in a bounded, idle-expiring cache.
pub struct RateLimitMiddleware {
    // How requests are grouped into limiter instances.
    scope: Scope,
    // Scope key -> limiter. Entries are created lazily via `limiter_factory`
    // and evicted by moka based on capacity and idle time.
    limiters: Cache<String, Arc<dyn RateLimiter>>,
    // Produces the limiter for each new scope key.
    limiter_factory: Arc<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
206
207impl RateLimitMiddleware {
208 pub fn builder() -> RateLimitMiddlewareBuilder {
210 RateLimitMiddlewareBuilder::default()
211 }
212
213 fn scope_key(&self, request: &Request) -> String {
214 match self.scope {
215 Scope::Global => "global".to_string(),
216 Scope::Domain => spider_util::util::normalize_origin(request),
217 }
218 }
219}
220
#[async_trait]
impl<C: Send + Sync> Middleware<C> for RateLimitMiddleware {
    fn name(&self) -> &str {
        "RateLimitMiddleware"
    }

    /// Blocks the request until the limiter for its scope key grants
    /// permission, creating that limiter on first use.
    async fn process_request(
        &mut self,
        _client: &C,
        request: Request,
    ) -> Result<MiddlewareAction<Request>, SpiderError> {
        let key = self.scope_key(&request);

        // `get_with` deduplicates concurrent initialization: only one caller
        // runs the factory for a given key; the rest await the same entry.
        let limiter = self
            .limiters
            .get_with(key.clone(), async { (self.limiter_factory)() })
            .await;

        // NOTE(review): fetched purely for the debug log; this takes the
        // limiter's internal lock even when debug logging is disabled.
        let current_delay = limiter.current_delay().await;
        debug!(
            "Acquiring lock for key '{}' (delay: {:?})",
            key, current_delay
        );

        limiter.acquire().await;
        Ok(MiddlewareAction::Continue(request))
    }

    /// Feeds the response back to its scope's limiter so adaptive limiters
    /// can speed up or back off. If the limiter is absent (e.g. evicted from
    /// the cache), the response passes through untouched.
    async fn process_response(
        &mut self,
        response: Response,
    ) -> Result<MiddlewareAction<Response>, SpiderError> {
        let key = self.scope_key(&response.request_from_response());

        if let Some(limiter) = self.limiters.get(&key).await {
            // Delay is sampled before/after adjust only to log real changes.
            let old_delay = limiter.current_delay().await;
            limiter.adjust(&response).await;
            let new_delay = limiter.current_delay().await;
            if old_delay != new_delay {
                debug!(
                    "Adjusted rate limit for key '{}': {:?} -> {:?}",
                    key, old_delay, new_delay
                );
            }
        }

        Ok(MiddlewareAction::Continue(response))
    }
}
270
/// Builder for `RateLimitMiddleware`; obtain one via
/// `RateLimitMiddleware::builder()`.
pub struct RateLimitMiddlewareBuilder {
    // Grouping of requests into limiters (global vs per-domain).
    scope: Scope,
    // Idle window after which cached limiters are evicted. Despite the name,
    // `build()` applies this via moka's `time_to_idle`, so the clock resets
    // on every access rather than counting from insertion.
    cache_ttl: Duration,
    // Maximum number of limiter entries kept in the cache.
    cache_capacity: u64,
    // Creates the limiter for each new scope key.
    limiter_factory: Box<dyn Fn() -> Arc<dyn RateLimiter> + Send + Sync>,
}
278
279impl Default for RateLimitMiddlewareBuilder {
280 fn default() -> Self {
281 Self {
282 scope: Scope::Domain,
283 cache_ttl: Duration::from_secs(MIDDLEWARE_CACHE_TTL_SECS),
284 cache_capacity: MIDDLEWARE_CACHE_CAPACITY,
285 limiter_factory: Box::new(|| Arc::new(AdaptiveLimiter::new(INITIAL_DELAY, true))),
286 }
287 }
288}
289
impl RateLimitMiddlewareBuilder {
    /// Sets how requests are grouped into limiters (global or per-domain).
    pub fn scope(mut self, scope: Scope) -> Self {
        self.scope = scope;
        self
    }

    /// Replaces the default adaptive limiter with a fixed-rate token bucket
    /// of `requests_per_second`. The factory runs lazily, so a zero rate
    /// panics when the first limiter is created, not here.
    pub fn use_token_bucket_limiter(mut self, requests_per_second: u32) -> Self {
        self.limiter_factory =
            Box::new(move || Arc::new(TokenBucketLimiter::new(requests_per_second)));
        self
    }

    /// Uses the given limiter instance for every scope key.
    ///
    /// NOTE(review): the same `Arc` is cloned for all keys, so with
    /// `Scope::Domain` this degrades to a single shared limiter — confirm
    /// that is intended; use `limiter_factory` for per-key instances.
    pub fn limiter(mut self, limiter: impl RateLimiter + 'static) -> Self {
        let arc = Arc::new(limiter);
        self.limiter_factory = Box::new(move || arc.clone());
        self
    }

    /// Supplies a factory that creates a fresh limiter for each scope key.
    pub fn limiter_factory(
        mut self,
        factory: impl Fn() -> Arc<dyn RateLimiter> + Send + Sync + 'static,
    ) -> Self {
        self.limiter_factory = Box::new(factory);
        self
    }

    /// Consumes the builder and produces the middleware, wiring the limiter
    /// cache with the configured capacity and idle timeout.
    pub fn build(self) -> RateLimitMiddleware {
        info!(
            "Initializing RateLimitMiddleware with config: scope={:?}, cache_ttl={:?}, cache_capacity={}",
            self.scope, self.cache_ttl, self.cache_capacity
        );
        RateLimitMiddleware {
            scope: self.scope,
            limiters: Cache::builder()
                // `cache_ttl` is applied as time-to-idle, not time-to-live:
                // entries expire only after going unused for this long.
                .time_to_idle(self.cache_ttl)
                .max_capacity(self.cache_capacity)
                .build(),
            limiter_factory: self.limiter_factory.into(),
        }
    }
}