// route_ratelimit/middleware.rs
//! The rate limiting middleware implementation.

use async_trait::async_trait;
use dashmap::DashMap;
use http::Extensions;
use rand::Rng;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next, Result as MiddlewareResult};
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Instant;
use tokio::time::sleep;

use crate::builder::RateLimitBuilder;
use crate::error::RateLimitError;
use crate::gcra::GcraState;
use crate::types::{Route, RouteKey, ThrottleBehavior};
/// The rate limiting middleware.
///
/// This middleware tracks rate limits and either delays or rejects requests
/// based on the configured routes.
///
/// # Thread Safety
///
/// `RateLimitMiddleware` is `Send + Sync` and can be safely shared across
/// threads and async tasks. The internal state uses lock-free atomic operations
/// (via [`DashMap`] and atomic integers) to ensure correct behavior under
/// concurrent access. When cloned, clones share the same rate limit state,
/// so limits are enforced across all clones.
#[derive(Debug, Clone)]
pub struct RateLimitMiddleware {
    /// Configured routes; `Arc` so clones share one configuration.
    pub(crate) routes: Arc<Vec<Route>>,
    /// Per-(route, limit) GCRA state, keyed by indices into `routes`;
    /// shared across clones so limits are enforced globally.
    pub(crate) state: Arc<DashMap<RouteKey, GcraState>>,
    /// Monotonic epoch for `now_nanos()`; `Instant` is `Copy`, so clones
    /// carry the same epoch value and produce comparable timestamps.
    pub(crate) start_instant: Instant,
}
37
impl RateLimitMiddleware {
    /// Create a new builder for configuring the middleware.
    #[must_use]
    pub fn builder() -> RateLimitBuilder {
        RateLimitBuilder::new()
    }

    /// Nanoseconds elapsed on the middleware-local monotonic clock.
    ///
    /// All GCRA timestamps (and the `cleanup` threshold) are measured from
    /// `start_instant`, so wall-clock adjustments cannot skew the limiter.
    #[inline]
    pub(crate) fn now_nanos(&self) -> u64 {
        // Use saturating conversion to prevent overflow on very long-running processes
        // (would require running for ~585 years to overflow)
        self.start_instant
            .elapsed()
            .as_nanos()
            .min(u64::MAX as u128) as u64
    }

    /// Remove stale rate limit state entries that haven't been accessed recently.
    ///
    /// An entry is considered stale when its theoretical arrival time (TAT) has
    /// recovered past twice the limit window, meaning the burst capacity has been
    /// fully recovered for an extended period.
    ///
    /// This method should be called periodically in long-running applications to
    /// prevent unbounded memory growth from accumulated state entries.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use route_ratelimit::RateLimitMiddleware;
    /// use std::time::Duration;
    ///
    /// # async fn example() {
    /// let middleware = RateLimitMiddleware::builder()
    ///     .route(|r| r.limit(100, Duration::from_secs(10)))
    ///     .build();
    ///
    /// // Call periodically to clean up stale entries
    /// middleware.cleanup();
    /// # }
    /// ```
    pub fn cleanup(&self) {
        let now = self.now_nanos();
        self.state.retain(|key, gcra_state| {
            // Bounds check to handle edge cases: defensively drop any key that
            // no longer maps to a configured route/limit.
            if key.route_index >= self.routes.len() {
                return false;
            }
            let route = &self.routes[key.route_index];
            if key.limit_index >= route.limits.len() {
                return false;
            }

            let limit = &route.limits[key.limit_index];
            // NOTE(review): `as u64` silently truncates a window longer than
            // ~585 years (harmless in practice); `now_nanos` saturates instead.
            let window_nanos = limit.window.as_nanos() as u64;
            // Acquire load of the TAT — presumably pairs with a Release store
            // inside `GcraState::try_acquire`; TODO confirm in gcra.rs.
            let tat = gcra_state.tat(Ordering::Acquire);

            // Keep if TAT is within 2x window of now (recently active)
            // An entry with TAT far in the past has fully recovered and can be removed
            tat > now.saturating_sub(window_nanos.saturating_mul(2))
        });
    }

    /// Returns the number of active rate limit state entries.
    ///
    /// This can be useful for monitoring memory usage.
    #[must_use]
    pub fn state_count(&self) -> usize {
        self.state.len()
    }

    /// Check every limit of every route matching `req`, delaying or rejecting
    /// as the route configures.
    ///
    /// Returns `Ok(())` once a single full pass admits the request on all
    /// matching limits. A limit hit on a [`ThrottleBehavior::Delay`] route
    /// triggers a jittered sleep and a complete re-check from the top; a hit
    /// on a [`ThrottleBehavior::Error`] route rejects immediately.
    ///
    /// # Errors
    ///
    /// [`RateLimitError::RateLimited`] carrying the required wait duration,
    /// when an `Error`-behavior limit is exceeded.
    async fn check_and_apply_limits(&self, req: &Request) -> Result<(), RateLimitError> {
        // Labelled so a post-sleep retry restarts the whole scan with a
        // fresh `now` (earlier acquisitions may have aged meanwhile).
        'outer: loop {
            let now = self.now_nanos();

            for (route_index, route) in self.routes.iter().enumerate() {
                if !route.matches(req) {
                    continue;
                }

                for (limit_index, limit) in route.limits.iter().enumerate() {
                    let key = RouteKey {
                        route_index,
                        limit_index,
                    };

                    let emission_interval_nanos = limit.emission_interval().as_nanos() as u64;
                    let limit_nanos = limit.window.as_nanos() as u64;

                    // Get or create GCRA state for this route+limit.
                    // `entry` holds a DashMap shard guard until `state` drops
                    // (end of this iteration, the explicit `drop`, or return).
                    let state = self.state.entry(key).or_insert_with(GcraState::new);

                    // NOTE(review): a successful `try_acquire` presumably
                    // advances this limit's TAT. If a *later* limit in the same
                    // pass then fails, the earlier limits stay charged — and
                    // are charged again after `continue 'outer`, or remain
                    // charged on an `Error` return even though the request is
                    // rejected. TODO confirm whether this over-counting across
                    // multiple limits is intended, or whether a peek-then-commit
                    // two-phase check is needed in gcra.rs.
                    match state.try_acquire(now, emission_interval_nanos, limit_nanos) {
                        Ok(()) => {}
                        Err(wait_duration) => {
                            match route.on_limit {
                                ThrottleBehavior::Delay => {
                                    // Release the lock before sleeping: holding
                                    // the DashMap shard guard across `.await`
                                    // would stall other tasks on this shard.
                                    drop(state);
                                    // Add jitter (0-50% of wait duration) to prevent thundering herd
                                    let jitter_max_nanos = wait_duration.as_nanos() as u64 / 2;
                                    let jitter_nanos = if jitter_max_nanos > 0 {
                                        rand::rng().random_range(0..=jitter_max_nanos)
                                    } else {
                                        // random_range(0..=0) would be fine, but
                                        // skip the RNG call entirely for zero waits.
                                        0
                                    };
                                    let sleep_duration = wait_duration
                                        + std::time::Duration::from_nanos(jitter_nanos);
                                    sleep(sleep_duration).await;
                                    // After sleeping, restart the entire check with fresh timestamp
                                    continue 'outer;
                                }
                                ThrottleBehavior::Error => {
                                    return Err(RateLimitError::RateLimited(wait_duration));
                                }
                            }
                        }
                    }
                }
            }

            // All limits passed, we can proceed
            break Ok(());
        }
    }
}
164
165#[async_trait]
166impl Middleware for RateLimitMiddleware {
167 async fn handle(
168 &self,
169 req: Request,
170 extensions: &mut Extensions,
171 next: Next<'_>,
172 ) -> MiddlewareResult<Response> {
173 // Check and apply rate limits
174 self.check_and_apply_limits(&req).await?;
175
176 // Proceed with the request
177 next.run(req, extensions).await
178 }
179}
180
181impl Default for RateLimitMiddleware {
182 /// Create a middleware with no routes configured.
183 ///
184 /// All requests will pass through without any rate limiting.
185 /// Use [`RateLimitMiddleware::builder()`] to configure routes.
186 fn default() -> Self {
187 Self::builder().build()
188 }
189}