better_auth_core/middleware/
rate_limit.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Mutex;
4use std::time::{Duration, Instant};
5
6use super::Middleware;
7use crate::error::AuthResult;
8use crate::types::{AuthRequest, AuthResponse};
9
10#[derive(Debug, Clone)]
12pub struct RateLimitConfig {
13 pub default: EndpointRateLimit,
15
16 pub per_endpoint: HashMap<String, EndpointRateLimit>,
18
19 pub enabled: bool,
21}
22
/// A request budget: at most `max_requests` requests within any `window`.
#[derive(Debug, Clone)]
pub struct EndpointRateLimit {
    /// Length of the period over which requests are counted.
    pub window: Duration,

    /// Number of requests accepted inside one `window`.
    pub max_requests: u32,
}
32
33impl Default for RateLimitConfig {
34 fn default() -> Self {
35 Self {
36 default: EndpointRateLimit {
37 window: Duration::from_secs(60),
38 max_requests: 100,
39 },
40 per_endpoint: HashMap::new(),
41 enabled: true,
42 }
43 }
44}
45
46impl RateLimitConfig {
47 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn default_limit(mut self, window: Duration, max_requests: u32) -> Self {
52 self.default = EndpointRateLimit {
53 window,
54 max_requests,
55 };
56 self
57 }
58
59 pub fn endpoint(
60 mut self,
61 path: impl Into<String>,
62 window: Duration,
63 max_requests: u32,
64 ) -> Self {
65 self.per_endpoint.insert(
66 path.into(),
67 EndpointRateLimit {
68 window,
69 max_requests,
70 },
71 );
72 self
73 }
74
75 pub fn enabled(mut self, enabled: bool) -> Self {
76 self.enabled = enabled;
77 self
78 }
79}
80
81pub struct RateLimitMiddleware {
87 config: RateLimitConfig,
88 buckets: Mutex<HashMap<String, Vec<Instant>>>,
90}
91
92impl RateLimitMiddleware {
93 pub fn new(config: RateLimitConfig) -> Self {
94 Self {
95 config,
96 buckets: Mutex::new(HashMap::new()),
97 }
98 }
99
100 fn client_key(req: &AuthRequest) -> String {
103 req.headers
104 .get("x-forwarded-for")
105 .or_else(|| req.headers.get("x-real-ip"))
106 .cloned()
107 .unwrap_or_else(|| "unknown".to_string())
108 }
109
110 fn limit_for_path(&self, path: &str) -> &EndpointRateLimit {
111 self.config
112 .per_endpoint
113 .get(path)
114 .unwrap_or(&self.config.default)
115 }
116}
117
118#[async_trait]
119impl Middleware for RateLimitMiddleware {
120 fn name(&self) -> &'static str {
121 "rate-limit"
122 }
123
124 async fn before_request(&self, req: &AuthRequest) -> AuthResult<Option<AuthResponse>> {
125 if !self.config.enabled {
126 return Ok(None);
127 }
128
129 let limit = self.limit_for_path(&req.path);
130 let key = format!("{}:{}", Self::client_key(req), req.path);
131 let now = Instant::now();
132 let window = limit.window;
133
134 let mut buckets = self.buckets.lock().unwrap();
135 let timestamps = buckets.entry(key).or_default();
136
137 timestamps.retain(|&t| now.duration_since(t) < window);
139
140 if timestamps.len() as u32 >= limit.max_requests {
141 let retry_after = timestamps
142 .first()
143 .map(|&t| {
144 window
145 .as_secs()
146 .saturating_sub(now.duration_since(t).as_secs())
147 })
148 .unwrap_or(window.as_secs());
149
150 return Ok(Some(
151 AuthResponse::json(
152 429,
153 &crate::types::RateLimitErrorResponse {
154 code: "RATE_LIMIT_EXCEEDED",
155 message: "Too many requests",
156 retry_after,
157 },
158 )?
159 .with_header("Retry-After", retry_after.to_string()),
160 ));
161 }
162
163 timestamps.push(now);
164 Ok(None)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::HttpMethod;
    use std::collections::HashMap as StdHashMap;

    /// Path used by every test request.
    const SIGN_IN: &str = "/sign-in/email";

    /// Builds a POST request for `path` whose `x-forwarded-for` header is `ip`.
    fn make_request(path: &str, ip: &str) -> AuthRequest {
        let mut headers = StdHashMap::new();
        headers.insert(String::from("x-forwarded-for"), String::from(ip));

        AuthRequest {
            method: HttpMethod::Post,
            path: String::from(path),
            headers,
            body: None,
            query: StdHashMap::new(),
            virtual_user_id: None,
        }
    }

    /// Runs `req` through the middleware and returns its short-circuit response, if any.
    async fn check(mw: &RateLimitMiddleware, req: &AuthRequest) -> Option<AuthResponse> {
        mw.before_request(req).await.unwrap()
    }

    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        let mw = RateLimitMiddleware::new(
            RateLimitConfig::new().default_limit(Duration::from_secs(60), 5),
        );
        let req = make_request(SIGN_IN, "1.2.3.4");

        for attempt in 1..=5 {
            let outcome = check(&mw, &req).await;
            assert!(outcome.is_none(), "request {attempt} should pass");
        }
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_over_limit() {
        let mw = RateLimitMiddleware::new(
            RateLimitConfig::new().default_limit(Duration::from_secs(60), 3),
        );
        let req = make_request(SIGN_IN, "1.2.3.4");

        for attempt in 1..=3 {
            let outcome = check(&mw, &req).await;
            assert!(outcome.is_none(), "request {attempt} should pass");
        }

        let rejected = check(&mw, &req)
            .await
            .expect("fourth request should be rejected");
        assert_eq!(rejected.status, 429);
    }

    #[tokio::test]
    async fn test_rate_limit_per_client() {
        let mw = RateLimitMiddleware::new(
            RateLimitConfig::new().default_limit(Duration::from_secs(60), 2),
        );
        let first_client = make_request(SIGN_IN, "1.1.1.1");
        let second_client = make_request(SIGN_IN, "2.2.2.2");

        // Exhaust the first client's budget.
        for attempt in 1..=2 {
            let outcome = check(&mw, &first_client).await;
            assert!(outcome.is_none(), "request {attempt} should pass");
        }
        assert!(check(&mw, &first_client).await.is_some());

        // A different client is counted separately.
        assert!(check(&mw, &second_client).await.is_none());
    }

    #[tokio::test]
    async fn test_rate_limit_per_endpoint_override() {
        let config = RateLimitConfig::new()
            .default_limit(Duration::from_secs(60), 100)
            .endpoint(SIGN_IN, Duration::from_secs(60), 2);
        let mw = RateLimitMiddleware::new(config);
        let req = make_request(SIGN_IN, "1.2.3.4");

        for attempt in 1..=2 {
            let outcome = check(&mw, &req).await;
            assert!(outcome.is_none(), "request {attempt} should pass");
        }
        assert!(check(&mw, &req).await.is_some());
    }

    #[tokio::test]
    async fn test_rate_limit_disabled() {
        let config = RateLimitConfig::new()
            .default_limit(Duration::from_secs(60), 1)
            .enabled(false);
        let mw = RateLimitMiddleware::new(config);
        let req = make_request(SIGN_IN, "1.2.3.4");

        for attempt in 1..=10 {
            let outcome = check(&mw, &req).await;
            assert!(outcome.is_none(), "request {attempt} should pass");
        }
    }
}