1use axum::{
7 body::Body,
8 extract::{ConnectInfo, Request, State},
9 http::{HeaderValue, StatusCode},
10 middleware::Next,
11 response::{IntoResponse, Response},
12};
13use serde::{Deserialize, Serialize};
14use std::net::SocketAddr;
15use std::sync::Arc;
16use std::time::{SystemTime, UNIX_EPOCH};
17use tracing::{debug, warn};
18
19use crate::error::ErrorResponse;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct RateLimitConfig {
24 pub max_requests: u32,
26
27 pub window_secs: u64,
29
30 pub enabled: bool,
32
33 pub by_ip: bool,
35
36 pub by_user: bool,
38
39 #[serde(skip_serializing_if = "Option::is_none")]
41 pub identifier_header: Option<String>,
42}
43
44impl Default for RateLimitConfig {
45 fn default() -> Self {
46 Self {
47 max_requests: 100,
48 window_secs: 60,
49 enabled: true,
50 by_ip: true,
51 by_user: true,
52 identifier_header: None,
53 }
54 }
55}
56
57impl RateLimitConfig {
58 pub fn new(max_requests: u32, window_secs: u64) -> Self {
60 Self {
61 max_requests,
62 window_secs,
63 ..Default::default()
64 }
65 }
66
67 pub fn disabled() -> Self {
69 Self {
70 enabled: false,
71 ..Default::default()
72 }
73 }
74
75 pub fn with_max_requests(mut self, max_requests: u32) -> Self {
77 self.max_requests = max_requests;
78 self
79 }
80
81 pub fn with_window_secs(mut self, window_secs: u64) -> Self {
83 self.window_secs = window_secs;
84 self
85 }
86
87 pub fn with_by_ip(mut self, by_ip: bool) -> Self {
89 self.by_ip = by_ip;
90 self
91 }
92
93 pub fn with_by_user(mut self, by_user: bool) -> Self {
95 self.by_user = by_user;
96 self
97 }
98
99 pub fn with_identifier_header(mut self, header: impl Into<String>) -> Self {
101 self.identifier_header = Some(header.into());
102 self
103 }
104}
105
106#[derive(Clone)]
108pub struct RateLimiterState {
109 config: Arc<RateLimitConfig>,
110 storage: Arc<tokio::sync::RwLock<std::collections::HashMap<String, TokenBucket>>>,
112}
113
114impl RateLimiterState {
115 pub fn new(config: RateLimitConfig) -> Self {
117 Self {
118 config: Arc::new(config),
119 storage: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
120 }
121 }
122
123 pub fn config(&self) -> &RateLimitConfig {
125 &self.config
126 }
127}
128
/// Token-bucket state tracked for a single rate-limit identifier.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Tokens currently available; fractional values accumulate between refills.
    tokens: f64,

    /// Unix timestamp, in whole seconds, at which tokens were last credited.
    last_refill: u64,

    /// Maximum number of tokens the bucket can hold.
    capacity: f64,

    /// Tokens credited per elapsed second.
    refill_rate: f64,
}

impl TokenBucket {
    /// Creates a full bucket holding `capacity` tokens that refills completely
    /// over `window_secs` seconds.
    ///
    /// A `window_secs` of zero is treated as one second so the refill rate is
    /// always finite (a literal division by zero would yield `inf`, or `NaN`
    /// when `capacity` is also zero).
    fn new(capacity: u32, window_secs: u64) -> Self {
        let capacity = f64::from(capacity);
        let refill_rate = capacity / window_secs.max(1) as f64;
        Self {
            tokens: capacity,
            last_refill: Self::current_time_secs(),
            capacity,
            refill_rate,
        }
    }

    /// Returns the current wall-clock time as whole seconds since the Unix epoch.
    ///
    /// A clock set before the epoch yields `0` instead of panicking.
    fn current_time_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_secs())
    }

    /// Credits the tokens earned since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Self::current_time_secs();

        // The wall clock is not monotonic. If it stepped backwards, re-anchor
        // the bucket at the new time without crediting tokens; subtracting
        // would underflow, and keeping the old timestamp would stall refills
        // until the clock caught up again.
        if now < self.last_refill {
            self.last_refill = now;
            return;
        }

        let elapsed = now - self.last_refill;
        if elapsed > 0 {
            let new_tokens = elapsed as f64 * self.refill_rate;
            self.tokens = (self.tokens + new_tokens).min(self.capacity);
            self.last_refill = now;
        }
    }

    /// Refills the bucket, then removes `count` tokens if that many are available.
    ///
    /// Returns `true` when the tokens were consumed and `false` (leaving the
    /// bucket unchanged apart from the refill) otherwise.
    fn try_consume(&mut self, count: f64) -> bool {
        self.refill();

        if self.tokens >= count {
            self.tokens -= count;
            true
        } else {
            false
        }
    }

    /// Returns the number of whole seconds until at least one token is
    /// available, based on the current (un-refilled) token count.
    ///
    /// Returns `0` when a token is already available. With a refill rate of
    /// zero the float-to-integer cast saturates, giving `u64::MAX`.
    fn time_until_available(&self) -> u64 {
        if self.tokens >= 1.0 {
            return 0;
        }

        let tokens_needed = 1.0 - self.tokens;
        (tokens_needed / self.refill_rate).ceil() as u64
    }
}
199
200pub async fn rate_limit(
221 State(limiter): State<RateLimiterState>,
222 request: Request,
223 next: Next,
224) -> Result<Response, RateLimitError> {
225 if !limiter.config.enabled {
227 return Ok(next.run(request).await);
228 }
229
230 let identifier = extract_identifier(&request, &limiter.config);
232
233 debug!("Rate limiting for identifier: {}", identifier);
234
235 let allowed = check_rate_limit(&limiter, &identifier).await;
237
238 if !allowed {
239 warn!("Rate limit exceeded for identifier: {}", identifier);
240 return Err(RateLimitError::LimitExceeded {
241 retry_after: limiter.config.window_secs,
242 });
243 }
244
245 let mut response = next.run(request).await;
247
248 add_rate_limit_headers(&mut response, &limiter.config);
250
251 Ok(response)
252}
253
254fn extract_identifier(request: &Request<Body>, config: &RateLimitConfig) -> String {
256 let mut parts = Vec::new();
257
258 if config.by_ip {
260 if let Some(ConnectInfo(addr)) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
261 parts.push(format!("ip:{}", addr.ip()));
262 }
263 }
264
265 if config.by_user {
267 if let Some(user) = request.extensions().get::<crate::auth::AuthUser>() {
268 parts.push(format!("user:{}", user.user_id()));
269 }
270 }
271
272 if let Some(header_name) = &config.identifier_header {
274 if let Some(value) = request.headers().get(header_name) {
275 if let Ok(value_str) = value.to_str() {
276 parts.push(format!("custom:{}", value_str));
277 }
278 }
279 }
280
281 if parts.is_empty() {
283 parts.push("anonymous".to_string());
284 }
285
286 parts.join("|")
287}
288
289async fn check_rate_limit(limiter: &RateLimiterState, identifier: &str) -> bool {
291 let mut storage = limiter.storage.write().await;
292
293 let bucket = storage
294 .entry(identifier.to_string())
295 .or_insert_with(|| {
296 TokenBucket::new(limiter.config.max_requests, limiter.config.window_secs)
297 });
298
299 bucket.try_consume(1.0)
300}
301
302fn add_rate_limit_headers(response: &mut Response, config: &RateLimitConfig) {
304 response.headers_mut().insert(
306 "X-RateLimit-Limit",
307 HeaderValue::from_str(&config.max_requests.to_string()).unwrap(),
308 );
309
310 response.headers_mut().insert(
311 "X-RateLimit-Window",
312 HeaderValue::from_str(&config.window_secs.to_string()).unwrap(),
313 );
314}
315
316#[derive(Debug)]
318pub enum RateLimitError {
319 LimitExceeded {
321 retry_after: u64,
323 },
324}
325
326impl IntoResponse for RateLimitError {
327 fn into_response(self) -> Response {
328 match self {
329 RateLimitError::LimitExceeded { retry_after } => {
330 let error_response = ErrorResponse {
331 status: 429,
332 error: "Rate limit exceeded".to_string(),
333 code: Some("RATE_LIMIT_EXCEEDED".to_string()),
334 timestamp: chrono::Utc::now(),
335 };
336
337 let mut response = (StatusCode::TOO_MANY_REQUESTS, axum::Json(error_response))
338 .into_response();
339
340 response.headers_mut().insert(
342 "Retry-After",
343 HeaderValue::from_str(&retry_after.to_string()).unwrap(),
344 );
345
346 response
347 }
348 }
349 }
350}
351
352impl std::fmt::Display for RateLimitError {
353 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
354 match self {
355 RateLimitError::LimitExceeded { retry_after } => {
356 write!(f, "Rate limit exceeded. Retry after {} seconds", retry_after)
357 }
358 }
359 }
360}
361
362impl std::error::Error for RateLimitError {}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores the limit and window and leaves limiting enabled.
    #[test]
    fn test_rate_limit_config() {
        let RateLimitConfig {
            max_requests,
            window_secs,
            enabled,
            ..
        } = RateLimitConfig::new(100, 60);

        assert_eq!(max_requests, 100);
        assert_eq!(window_secs, 60);
        assert!(enabled);
    }

    /// Builder methods override individual fields.
    #[test]
    fn test_rate_limit_config_builder() {
        let built = RateLimitConfig::default()
            .with_max_requests(200)
            .with_window_secs(120)
            .with_by_ip(false)
            .with_identifier_header("X-API-Key");

        assert_eq!(built.max_requests, 200);
        assert_eq!(built.window_secs, 120);
        assert!(!built.by_ip);
        assert_eq!(built.identifier_header, Some("X-API-Key".to_string()));
    }

    /// A new bucket starts full.
    #[test]
    fn test_token_bucket_creation() {
        let fresh = TokenBucket::new(100, 60);
        assert_eq!(fresh.capacity, 100.0);
        assert_eq!(fresh.tokens, 100.0);
    }

    /// Exactly `capacity` tokens can be taken before the bucket refuses.
    #[test]
    fn test_token_bucket_consume() {
        let mut bucket = TokenBucket::new(10, 60);

        assert!((0..10).all(|_| bucket.try_consume(1.0)));

        assert!(!bucket.try_consume(1.0));
    }

    /// Back-dating `last_refill` by five seconds credits five tokens at 1 token/s.
    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(10, 10);

        (0..10).for_each(|_| {
            bucket.try_consume(1.0);
        });

        assert_eq!(bucket.tokens, 0.0);

        bucket.last_refill -= 5;
        bucket.refill();

        assert_eq!(bucket.tokens, 5.0);
    }

    /// Buckets are tracked independently per identifier.
    #[tokio::test]
    async fn test_rate_limiter_state() {
        let limiter = RateLimiterState::new(RateLimitConfig::new(5, 60));

        let mut attempt = 0;
        while attempt < 5 {
            assert!(check_rate_limit(&limiter, "test-user").await);
            attempt += 1;
        }

        assert!(!check_rate_limit(&limiter, "test-user").await);

        assert!(check_rate_limit(&limiter, "other-user").await);
    }

    /// `disabled` turns the limiter off.
    #[test]
    fn test_disabled_rate_limit() {
        let off = RateLimitConfig::disabled();
        assert!(!off.enabled);
    }
}