Skip to main content

better_auth_core/middleware/
rate_limit.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Mutex;
4use std::time::{Duration, Instant};
5
6use super::Middleware;
7use crate::error::AuthResult;
8use crate::types::{AuthRequest, AuthResponse};
9
/// Configuration for the rate limiting middleware.
///
/// Built with chained setters, e.g.
/// `RateLimitConfig::new().default_limit(Duration::from_secs(60), 10)`.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Limit applied to every path that has no entry in `per_endpoint`.
    pub default: EndpointRateLimit,

    /// Per-endpoint overrides. Key is the path (e.g. "/sign-in/email") and is
    /// compared exactly against the request path.
    pub per_endpoint: HashMap<String, EndpointRateLimit>,

    /// Whether rate limiting is enabled. When `false`, every request passes.
    pub enabled: bool,
}

/// Rate limit parameters for a single endpoint.
#[derive(Debug, Clone)]
pub struct EndpointRateLimit {
    /// Sliding window duration; only requests younger than this are counted.
    pub window: Duration,

    /// Maximum number of requests admitted per client within `window`.
    /// A value of `0` rejects every request.
    pub max_requests: u32,
}

impl Default for RateLimitConfig {
    /// Enabled, 100 requests per 60-second window, no per-endpoint overrides.
    fn default() -> Self {
        Self {
            default: EndpointRateLimit {
                window: Duration::from_secs(60),
                max_requests: 100,
            },
            per_endpoint: HashMap::new(),
            enabled: true,
        }
    }
}

impl RateLimitConfig {
    /// Creates a configuration with the [`Default`] settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the limit used for paths without a per-endpoint override.
    #[must_use = "builder methods consume `self` and return the updated configuration"]
    pub fn default_limit(mut self, window: Duration, max_requests: u32) -> Self {
        self.default = EndpointRateLimit {
            window,
            max_requests,
        };
        self
    }

    /// Sets the limit for exactly `path`, replacing any earlier override for
    /// the same path.
    #[must_use = "builder methods consume `self` and return the updated configuration"]
    pub fn endpoint(
        mut self,
        path: impl Into<String>,
        window: Duration,
        max_requests: u32,
    ) -> Self {
        self.per_endpoint.insert(
            path.into(),
            EndpointRateLimit {
                window,
                max_requests,
            },
        );
        self
    }

    /// Turns rate limiting on or off.
    #[must_use = "builder methods consume `self` and return the updated configuration"]
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }
}
80
81/// In-memory sliding-window rate limiter.
82///
83/// For production use with multiple instances, a `CacheAdapter`-backed
84/// implementation should be used instead. This implementation is suitable
85/// for single-process deployments and testing.
86pub struct RateLimitMiddleware {
87    config: RateLimitConfig,
88    /// Keyed by (client_identifier, path) → list of request timestamps.
89    buckets: Mutex<HashMap<String, Vec<Instant>>>,
90}
91
92impl RateLimitMiddleware {
93    pub fn new(config: RateLimitConfig) -> Self {
94        Self {
95            config,
96            buckets: Mutex::new(HashMap::new()),
97        }
98    }
99
100    /// Derive a client key from the request. Uses X-Forwarded-For, then
101    /// falls back to a fixed key (single-bucket) when no IP is available.
102    fn client_key(req: &AuthRequest) -> String {
103        req.headers
104            .get("x-forwarded-for")
105            .or_else(|| req.headers.get("x-real-ip"))
106            .cloned()
107            .unwrap_or_else(|| "unknown".to_string())
108    }
109
110    fn limit_for_path(&self, path: &str) -> &EndpointRateLimit {
111        self.config
112            .per_endpoint
113            .get(path)
114            .unwrap_or(&self.config.default)
115    }
116}
117
118#[async_trait]
119impl Middleware for RateLimitMiddleware {
120    fn name(&self) -> &'static str {
121        "rate-limit"
122    }
123
124    async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
125        if !self.config.enabled {
126            return Ok(None);
127        }
128
129        let limit = self.limit_for_path(&req.path);
130        let key = format!("{}:{}", Self::client_key(req), req.path);
131        let now = Instant::now();
132        let window = limit.window;
133
134        let mut buckets = self.buckets.lock().unwrap();
135        let timestamps = buckets.entry(key).or_default();
136
137        // Remove timestamps outside the window
138        timestamps.retain(|&t| now.duration_since(t) < window);
139
140        if timestamps.len() as u32 >= limit.max_requests {
141            let retry_after = timestamps
142                .first()
143                .map(|&t| {
144                    window
145                        .as_secs()
146                        .saturating_sub(now.duration_since(t).as_secs())
147                })
148                .unwrap_or(window.as_secs());
149
150            return Ok(Some(
151                AuthResponse::json(
152                    429,
153                    &serde_json::json!({
154                        "code": "RATE_LIMIT_EXCEEDED",
155                        "message": "Too many requests",
156                        "retryAfter": retry_after,
157                    }),
158                )?
159                .with_header("Retry-After", retry_after.to_string()),
160            ));
161        }
162
163        timestamps.push(now);
164        Ok(None)
165    }
166}
167
168#[cfg(test)]
169mod tests {
170    use super::*;
171    use crate::types::HttpMethod;
172    use std::collections::HashMap as StdHashMap;
173
174    fn make_request(path: &str, ip: &str) -> AuthRequest {
175        let mut headers = StdHashMap::new();
176        headers.insert("x-forwarded-for".to_string(), ip.to_string());
177        AuthRequest {
178            method: HttpMethod::Post,
179            path: path.to_string(),
180            headers,
181            body: None,
182            query: StdHashMap::new(),
183        }
184    }
185
186    #[tokio::test]
187    async fn test_rate_limit_allows_within_limit() {
188        let config = RateLimitConfig::new().default_limit(Duration::from_secs(60), 5);
189        let mw = RateLimitMiddleware::new(config);
190        let req = make_request("/sign-in/email", "1.2.3.4");
191
192        for _ in 0..5 {
193            assert!(mw.before_request(&req).await.unwrap().is_none());
194        }
195    }
196
197    #[tokio::test]
198    async fn test_rate_limit_blocks_over_limit() {
199        let config = RateLimitConfig::new().default_limit(Duration::from_secs(60), 3);
200        let mw = RateLimitMiddleware::new(config);
201        let req = make_request("/sign-in/email", "1.2.3.4");
202
203        for _ in 0..3 {
204            assert!(mw.before_request(&req).await.unwrap().is_none());
205        }
206
207        let resp = mw.before_request(&req).await.unwrap();
208        assert!(resp.is_some());
209        assert_eq!(resp.unwrap().status, 429);
210    }
211
212    #[tokio::test]
213    async fn test_rate_limit_per_client() {
214        let config = RateLimitConfig::new().default_limit(Duration::from_secs(60), 2);
215        let mw = RateLimitMiddleware::new(config);
216
217        let req_a = make_request("/sign-in/email", "1.1.1.1");
218        let req_b = make_request("/sign-in/email", "2.2.2.2");
219
220        // Client A uses up its limit
221        for _ in 0..2 {
222            assert!(mw.before_request(&req_a).await.unwrap().is_none());
223        }
224        assert!(mw.before_request(&req_a).await.unwrap().is_some());
225
226        // Client B should still be allowed
227        assert!(mw.before_request(&req_b).await.unwrap().is_none());
228    }
229
230    #[tokio::test]
231    async fn test_rate_limit_per_endpoint_override() {
232        let config = RateLimitConfig::new()
233            .default_limit(Duration::from_secs(60), 100)
234            .endpoint("/sign-in/email", Duration::from_secs(60), 2);
235        let mw = RateLimitMiddleware::new(config);
236        let req = make_request("/sign-in/email", "1.2.3.4");
237
238        for _ in 0..2 {
239            assert!(mw.before_request(&req).await.unwrap().is_none());
240        }
241        assert!(mw.before_request(&req).await.unwrap().is_some());
242    }
243
244    #[tokio::test]
245    async fn test_rate_limit_disabled() {
246        let config = RateLimitConfig::new()
247            .default_limit(Duration::from_secs(60), 1)
248            .enabled(false);
249        let mw = RateLimitMiddleware::new(config);
250        let req = make_request("/sign-in/email", "1.2.3.4");
251
252        for _ in 0..10 {
253            assert!(mw.before_request(&req).await.unwrap().is_none());
254        }
255    }
256}