route_ratelimit/
middleware.rs

//! The rate limiting middleware implementation.

use async_trait::async_trait;
use dashmap::DashMap;
use http::Extensions;
use rand::Rng;
use reqwest::{Request, Response};
use reqwest_middleware::{Middleware, Next, Result as MiddlewareResult};
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Instant;
use tokio::time::sleep;

use crate::builder::RateLimitBuilder;
use crate::error::RateLimitError;
use crate::gcra::GcraState;
use crate::types::{Route, RouteKey, ThrottleBehavior};

/// The rate limiting middleware.
///
/// This middleware tracks rate limits and either delays or rejects requests
/// based on the configured routes.
///
/// # Thread Safety
///
/// `RateLimitMiddleware` is `Send + Sync` and can be safely shared across
/// threads and async tasks. The internal state lives in a concurrent map
/// ([`DashMap`]) whose per-limit counters are updated with atomic operations,
/// so behavior stays correct under concurrent access. When cloned, clones
/// share the same rate limit state, so limits are enforced across all clones.
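///
/// # Example
///
/// A minimal sketch of wiring the middleware into a `reqwest_middleware`
/// client (the route and limit values here are illustrative, not defaults):
///
/// ```rust,no_run
/// use route_ratelimit::RateLimitMiddleware;
/// use std::time::Duration;
///
/// # fn example() {
/// let middleware = RateLimitMiddleware::builder()
///     .route(|r| r.limit(100, Duration::from_secs(10)))
///     .build();
///
/// // Attach a clone so this handle can still call `cleanup()` later;
/// // the client and the handle share the same rate limit state.
/// let client = reqwest_middleware::ClientBuilder::new(reqwest::Client::new())
///     .with(middleware.clone())
///     .build();
/// # let _ = client;
/// # }
/// ```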
#[derive(Debug, Clone)]
pub struct RateLimitMiddleware {
    pub(crate) routes: Arc<Vec<Route>>,
    pub(crate) state: Arc<DashMap<RouteKey, GcraState>>,
    pub(crate) start_instant: Instant,
}

impl RateLimitMiddleware {
    /// Create a new builder for configuring the middleware.
    #[must_use]
    pub fn builder() -> RateLimitBuilder {
        RateLimitBuilder::new()
    }

    #[inline]
    pub(crate) fn now_nanos(&self) -> u64 {
        // Use saturating conversion to prevent overflow on very long-running processes
        // (would require running for ~585 years to overflow)
        self.start_instant
            .elapsed()
            .as_nanos()
            .min(u64::MAX as u128) as u64
    }

    /// Remove stale rate limit state entries that haven't been accessed recently.
    ///
    /// An entry is considered stale when its theoretical arrival time (TAT) lies
    /// more than twice the limit window in the past, meaning its burst capacity
    /// has been fully recovered for an extended period.
    ///
    /// This method should be called periodically in long-running applications to
    /// prevent unbounded memory growth from accumulated state entries.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use route_ratelimit::RateLimitMiddleware;
    /// use std::time::Duration;
    ///
    /// # async fn example() {
    /// let middleware = RateLimitMiddleware::builder()
    ///     .route(|r| r.limit(100, Duration::from_secs(10)))
    ///     .build();
    ///
    /// // Call periodically to clean up stale entries
    /// middleware.cleanup();
    /// # }
    /// ```
    pub fn cleanup(&self) {
        let now = self.now_nanos();
        self.state.retain(|key, gcra_state| {
            // Bounds check to handle edge cases
            if key.route_index >= self.routes.len() {
                return false;
            }
            let route = &self.routes[key.route_index];
            if key.limit_index >= route.limits.len() {
                return false;
            }

            let limit = &route.limits[key.limit_index];
            let window_nanos = limit.window.as_nanos() as u64;
            let tat = gcra_state.tat(Ordering::Acquire);

            // Keep if TAT is within 2x window of now (recently active)
            // An entry with TAT far in the past has fully recovered and can be removed
            tat > now.saturating_sub(window_nanos.saturating_mul(2))
        });
    }

    /// Returns the number of active rate limit state entries.
    ///
    /// This can be useful for monitoring memory usage.
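    ///
    /// # Example
    ///
    /// A small sketch of checking the entry count from a periodic maintenance
    /// task, alongside [`cleanup`](Self::cleanup):
    ///
    /// ```rust,no_run
    /// # use route_ratelimit::RateLimitMiddleware;
    /// # fn maintain(middleware: &RateLimitMiddleware) {
    /// middleware.cleanup();
    /// println!("active rate limit entries: {}", middleware.state_count());
    /// # }
    /// ```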
    #[must_use]
    pub fn state_count(&self) -> usize {
        self.state.len()
    }

    async fn check_and_apply_limits(&self, req: &Request) -> Result<(), RateLimitError> {
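        // Evaluate every matching route's limits against a single timestamp. If any
        // limit forces a delay, sleep and restart from the top so all limits are
        // re-checked against a fresh timestamp.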
        'outer: loop {
            let now = self.now_nanos();

            for (route_index, route) in self.routes.iter().enumerate() {
                if !route.matches(req) {
                    continue;
                }

                for (limit_index, limit) in route.limits.iter().enumerate() {
                    let key = RouteKey {
                        route_index,
                        limit_index,
                    };

                    let emission_interval_nanos = limit.emission_interval().as_nanos() as u64;
                    let limit_nanos = limit.window.as_nanos() as u64;

                    // Get or create GCRA state for this route+limit
                    let state = self.state.entry(key).or_insert_with(GcraState::new);

                    match state.try_acquire(now, emission_interval_nanos, limit_nanos) {
                        Ok(()) => {}
                        Err(wait_duration) => {
                            match route.on_limit {
                                ThrottleBehavior::Delay => {
                                    // Release the lock before sleeping
                                    drop(state);
                                    // Add jitter (0-50% of wait duration) to prevent thundering herd
                                    let jitter_max_nanos = wait_duration.as_nanos() as u64 / 2;
                                    let jitter_nanos = if jitter_max_nanos > 0 {
                                        rand::rng().random_range(0..=jitter_max_nanos)
                                    } else {
                                        0
                                    };
                                    let sleep_duration = wait_duration
                                        + std::time::Duration::from_nanos(jitter_nanos);
                                    sleep(sleep_duration).await;
                                    // After sleeping, restart the entire check with a fresh timestamp
                                    continue 'outer;
                                }
                                ThrottleBehavior::Error => {
                                    return Err(RateLimitError::RateLimited(wait_duration));
                                }
                            }
                        }
                    }
                }
            }

            // All limits passed, we can proceed
            break Ok(());
        }
    }
}

#[async_trait]
impl Middleware for RateLimitMiddleware {
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> MiddlewareResult<Response> {
        // Check and apply rate limits
        self.check_and_apply_limits(&req).await?;

        // Proceed with the request
        next.run(req, extensions).await
    }
}

impl Default for RateLimitMiddleware {
    /// Create a middleware with no routes configured.
    ///
    /// All requests will pass through without any rate limiting.
    /// Use [`RateLimitMiddleware::builder()`] to configure routes.
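    ///
    /// # Example
    ///
    /// A default middleware is equivalent to `RateLimitMiddleware::builder().build()`
    /// with no routes:
    ///
    /// ```rust,no_run
    /// use route_ratelimit::RateLimitMiddleware;
    ///
    /// // No routes configured: every request passes straight through.
    /// let middleware = RateLimitMiddleware::default();
    /// # let _ = middleware;
    /// ```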
    fn default() -> Self {
        Self::builder().build()
    }
}