Skip to main content

better_auth_core/middleware/
rate_limit.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Mutex;
4use std::time::{Duration, Instant};
5
6use super::Middleware;
7use crate::error::AuthResult;
8use crate::types::{AuthRequest, AuthResponse};
9
/// Configuration for the rate limiting middleware.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Limit applied to every request path that has no entry in
    /// `per_endpoint`.
    pub default: EndpointRateLimit,

    /// Per-endpoint overrides, keyed by the exact request path
    /// (e.g. "/sign-in/email").
    pub per_endpoint: HashMap<String, EndpointRateLimit>,

    /// Whether rate limiting is enabled. When `false`, the middleware lets
    /// every request through.
    pub enabled: bool,
}

/// Rate limit parameters for a single endpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EndpointRateLimit {
    /// Length of the sliding window over which requests are counted.
    pub window: Duration,

    /// Maximum number of requests allowed within `window`.
    pub max_requests: u32,
}
32
33impl Default for RateLimitConfig {
34    fn default() -> Self {
35        Self {
36            default: EndpointRateLimit {
37                window: Duration::from_secs(60),
38                max_requests: 100,
39            },
40            per_endpoint: HashMap::new(),
41            enabled: true,
42        }
43    }
44}
45
46impl RateLimitConfig {
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    pub fn default_limit(mut self, window: Duration, max_requests: u32) -> Self {
52        self.default = EndpointRateLimit {
53            window,
54            max_requests,
55        };
56        self
57    }
58
59    pub fn endpoint(
60        mut self,
61        path: impl Into<String>,
62        window: Duration,
63        max_requests: u32,
64    ) -> Self {
65        self.per_endpoint.insert(
66            path.into(),
67            EndpointRateLimit {
68                window,
69                max_requests,
70            },
71        );
72        self
73    }
74
75    pub fn enabled(mut self, enabled: bool) -> Self {
76        self.enabled = enabled;
77        self
78    }
79}
80
81/// In-memory sliding-window rate limiter.
82///
83/// For production use with multiple instances, a `CacheAdapter`-backed
84/// implementation should be used instead. This implementation is suitable
85/// for single-process deployments and testing.
86pub struct RateLimitMiddleware {
87    config: RateLimitConfig,
88    /// Keyed by (client_identifier, path) → list of request timestamps.
89    buckets: Mutex<HashMap<String, Vec<Instant>>>,
90}
91
92impl RateLimitMiddleware {
93    pub fn new(config: RateLimitConfig) -> Self {
94        Self {
95            config,
96            buckets: Mutex::new(HashMap::new()),
97        }
98    }
99
100    /// Derive a client key from the request. Uses X-Forwarded-For, then
101    /// falls back to a fixed key (single-bucket) when no IP is available.
102    fn client_key(req: &AuthRequest) -> String {
103        req.headers
104            .get("x-forwarded-for")
105            .or_else(|| req.headers.get("x-real-ip"))
106            .cloned()
107            .unwrap_or_else(|| "unknown".to_string())
108    }
109
110    fn limit_for_path(&self, path: &str) -> &EndpointRateLimit {
111        self.config
112            .per_endpoint
113            .get(path)
114            .unwrap_or(&self.config.default)
115    }
116}
117
118#[async_trait]
119impl Middleware for RateLimitMiddleware {
120    fn name(&self) -> &'static str {
121        "rate-limit"
122    }
123
124    async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
125        if !self.config.enabled {
126            return Ok(None);
127        }
128
129        let limit = self.limit_for_path(&req.path);
130        let key = format!("{}:{}", Self::client_key(req), req.path);
131        let now = Instant::now();
132        let window = limit.window;
133
134        let mut buckets = self.buckets.lock().unwrap();
135        let timestamps = buckets.entry(key).or_default();
136
137        // Remove timestamps outside the window
138        timestamps.retain(|&t| now.duration_since(t) < window);
139
140        if timestamps.len() as u32 >= limit.max_requests {
141            let retry_after = timestamps
142                .first()
143                .map(|&t| {
144                    window
145                        .as_secs()
146                        .saturating_sub(now.duration_since(t).as_secs())
147                })
148                .unwrap_or(window.as_secs());
149
150            return Ok(Some(
151                AuthResponse::json(
152                    429,
153                    &crate::types::RateLimitErrorResponse {
154                        code: "RATE_LIMIT_EXCEEDED",
155                        message: "Too many requests",
156                        retry_after,
157                    },
158                )?
159                .with_header("Retry-After", retry_after.to_string()),
160            ));
161        }
162
163        timestamps.push(now);
164        Ok(None)
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::HttpMethod;
    use std::collections::HashMap as StdHashMap;

    /// Endpoint exercised by every test below.
    const SIGN_IN: &str = "/sign-in/email";

    /// Build a POST request to `path` whose `x-forwarded-for` header is `ip`.
    fn make_request(path: &str, ip: &str) -> AuthRequest {
        let headers: StdHashMap<String, String> =
            std::iter::once(("x-forwarded-for".to_string(), ip.to_string())).collect();
        AuthRequest {
            method: HttpMethod::Post,
            path: path.to_string(),
            headers,
            body: None,
            query: StdHashMap::new(),
            virtual_user_id: None,
        }
    }

    /// Middleware whose default limit is `max` requests per 60 seconds.
    fn limiter(max: u32) -> RateLimitMiddleware {
        RateLimitMiddleware::new(
            RateLimitConfig::new().default_limit(Duration::from_secs(60), max),
        )
    }

    /// Send `req` `times` times, asserting every attempt is let through.
    async fn expect_allowed(mw: &RateLimitMiddleware, req: &AuthRequest, times: usize) {
        for attempt in 0..times {
            let outcome = mw.before_request(req).await.unwrap();
            assert!(outcome.is_none(), "attempt {} was unexpectedly limited", attempt);
        }
    }

    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        let mw = limiter(5);
        expect_allowed(&mw, &make_request(SIGN_IN, "1.2.3.4"), 5).await;
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let mw = limiter(3);
        let req = make_request(SIGN_IN, "1.2.3.4");
        expect_allowed(&mw, &req, 3).await;

        let blocked = mw
            .before_request(&req)
            .await
            .unwrap()
            .expect("fourth request should be rate limited");
        assert_eq!(blocked.status, 429);
    }

    #[tokio::test]
    async fn test_rate_limit_per_client() {
        let mw = limiter(2);
        let (client_a, client_b) = (
            make_request(SIGN_IN, "1.1.1.1"),
            make_request(SIGN_IN, "2.2.2.2"),
        );

        // Client A exhausts its own allowance...
        expect_allowed(&mw, &client_a, 2).await;
        assert!(mw.before_request(&client_a).await.unwrap().is_some());

        // ...which must not affect client B's bucket.
        assert!(mw.before_request(&client_b).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_rate_limit_per_endpoint_override() {
        let mw = RateLimitMiddleware::new(
            RateLimitConfig::new()
                .default_limit(Duration::from_secs(60), 100)
                .endpoint(SIGN_IN, Duration::from_secs(60), 2),
        );
        let req = make_request(SIGN_IN, "1.2.3.4");

        expect_allowed(&mw, &req, 2).await;
        assert!(mw.before_request(&req).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let mw = RateLimitMiddleware::new(
            RateLimitConfig::new()
                .default_limit(Duration::from_secs(60), 1)
                .enabled(false),
        );
        expect_allowed(&mw, &make_request(SIGN_IN, "1.2.3.4"), 10).await;
    }
}