1use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
29use actix_web::http::header::{HeaderName, HeaderValue};
30use actix_web::http::StatusCode;
31use actix_web::{Error, HttpResponse};
32use futures_util::future::{ok, LocalBoxFuture, Ready};
33use std::collections::HashMap;
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use tokio::sync::RwLock;
37
/// Rejection details produced when a key has used up its request budget.
///
/// The middleware in this module either hands this value to a user-supplied
/// `error_response` handler or turns it into a built-in 429 reply.
#[derive(Debug, Clone)]
pub struct RateLimitExceeded {
    /// Seconds the client should wait before retrying (sent as `Retry-After`
    /// by the built-in reply).
    pub retry_after: u64,
    /// Human-readable explanation, used as the built-in reply's body.
    pub message: String,
}

impl Default for RateLimitExceeded {
    /// A generic one-minute back-off with a friendly message.
    fn default() -> Self {
        let message = String::from("Too many requests. Please try again later.");
        Self {
            message,
            retry_after: 60,
        }
    }
}
55
56pub type KeyExtractorFn = Arc<dyn Fn(&ServiceRequest) -> Option<String> + Send + Sync>;
58
59#[derive(Clone, Default)]
61pub enum KeyExtractor {
62 #[default]
64 IpAddress,
65 User,
67 Header(String),
69 IpAndEndpoint,
71 Custom(KeyExtractorFn),
73}
74
75impl std::fmt::Debug for KeyExtractor {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 KeyExtractor::IpAddress => write!(f, "IpAddress"),
79 KeyExtractor::User => write!(f, "User"),
80 KeyExtractor::Header(h) => write!(f, "Header({})", h),
81 KeyExtractor::IpAndEndpoint => write!(f, "IpAndEndpoint"),
82 KeyExtractor::Custom(_) => write!(f, "Custom(<fn>)"),
83 }
84 }
85}
86
87impl KeyExtractor {
88 pub fn extract(&self, req: &ServiceRequest) -> Option<String> {
90 match self {
91 KeyExtractor::IpAddress => req
92 .connection_info()
93 .realip_remote_addr()
94 .map(|s| s.to_string()),
95 KeyExtractor::User => req
96 .headers()
97 .get("Authorization")
98 .and_then(|h| h.to_str().ok())
99 .map(|s| s.to_string()),
100 KeyExtractor::Header(name) => req
101 .headers()
102 .get(name.as_str())
103 .and_then(|h| h.to_str().ok())
104 .map(|s| s.to_string()),
105 KeyExtractor::IpAndEndpoint => {
106 let ip = req.connection_info().realip_remote_addr()?.to_string();
107 let path = req.path().to_string();
108 Some(format!("{}:{}", ip, path))
109 }
110 KeyExtractor::Custom(f) => f(req),
111 }
112 }
113}
114
/// Counting strategy used to decide whether a request is within budget.
///
/// Fieldless, so it is `Copy` and comparable (e.g.
/// `config.algorithm == RateLimitAlgorithm::TokenBucket`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum RateLimitAlgorithm {
    /// Counts requests in consecutive, non-overlapping windows of length
    /// `window` (the default).
    #[default]
    FixedWindow,
    /// Counts requests whose timestamps fall inside the trailing `window`.
    SlidingWindow,
    /// Refills `max_requests` tokens per `window`, holding at most `burst_size`.
    TokenBucket,
}
126
127#[derive(Clone)]
129pub struct RateLimitConfig {
130 pub max_requests: u64,
132 pub window: Duration,
134 pub burst_size: u64,
136 pub algorithm: RateLimitAlgorithm,
138 pub key_extractor: KeyExtractor,
140 pub excluded_paths: Vec<String>,
142 pub add_headers: bool,
144 pub error_response: Option<Arc<dyn Fn(RateLimitExceeded) -> HttpResponse + Send + Sync>>,
146}
147
148impl std::fmt::Debug for RateLimitConfig {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("RateLimitConfig")
151 .field("max_requests", &self.max_requests)
152 .field("window", &self.window)
153 .field("burst_size", &self.burst_size)
154 .field("algorithm", &self.algorithm)
155 .field("key_extractor", &self.key_extractor)
156 .field("excluded_paths", &self.excluded_paths)
157 .field("add_headers", &self.add_headers)
158 .field(
159 "error_response",
160 &self.error_response.as_ref().map(|_| "<fn>"),
161 )
162 .finish()
163 }
164}
165
166impl Default for RateLimitConfig {
167 fn default() -> Self {
168 Self {
169 max_requests: 100,
170 window: Duration::from_secs(60),
171 burst_size: 10,
172 algorithm: RateLimitAlgorithm::default(),
173 key_extractor: KeyExtractor::default(),
174 excluded_paths: vec![],
175 add_headers: true,
176 error_response: None,
177 }
178 }
179}
180
181impl RateLimitConfig {
182 pub fn new() -> Self {
184 Self::default()
185 }
186
187 pub fn max_requests(mut self, max: u64) -> Self {
189 self.max_requests = max;
190 self
191 }
192
193 pub fn requests_per_second(mut self, rps: u64) -> Self {
195 self.max_requests = rps;
196 self.window = Duration::from_secs(1);
197 self
198 }
199
200 pub fn requests_per_minute(mut self, rpm: u64) -> Self {
202 self.max_requests = rpm;
203 self.window = Duration::from_secs(60);
204 self
205 }
206
207 pub fn window(mut self, window: Duration) -> Self {
209 self.window = window;
210 self
211 }
212
213 pub fn burst_size(mut self, size: u64) -> Self {
215 self.burst_size = size;
216 self
217 }
218
219 pub fn algorithm(mut self, algo: RateLimitAlgorithm) -> Self {
221 self.algorithm = algo;
222 self
223 }
224
225 pub fn key_extractor(mut self, extractor: KeyExtractor) -> Self {
227 self.key_extractor = extractor;
228 self
229 }
230
231 pub fn exclude_paths(mut self, paths: Vec<&str>) -> Self {
233 self.excluded_paths = paths.into_iter().map(String::from).collect();
234 self
235 }
236
237 pub fn add_headers(mut self, add: bool) -> Self {
239 self.add_headers = add;
240 self
241 }
242
243 pub fn error_response<F>(mut self, handler: F) -> Self
245 where
246 F: Fn(RateLimitExceeded) -> HttpResponse + Send + Sync + 'static,
247 {
248 self.error_response = Some(Arc::new(handler));
249 self
250 }
251
252 pub fn strict_login() -> Self {
254 Self::new()
255 .requests_per_minute(5)
256 .burst_size(3)
257 .algorithm(RateLimitAlgorithm::SlidingWindow)
258 }
259
260 pub fn lenient_api() -> Self {
262 Self::new()
263 .requests_per_minute(1000)
264 .burst_size(100)
265 .algorithm(RateLimitAlgorithm::TokenBucket)
266 }
267}
268
269#[derive(Debug, Clone)]
271struct RateLimitEntry {
272 count: u64,
274 window_start: Instant,
276 timestamps: Vec<Instant>,
278 tokens: f64,
280 last_refill: Instant,
282}
283
284impl RateLimitEntry {
285 fn new(config: &RateLimitConfig) -> Self {
286 Self {
287 count: 0,
288 window_start: Instant::now(),
289 timestamps: Vec::new(),
290 tokens: config.burst_size as f64,
291 last_refill: Instant::now(),
292 }
293 }
294}
295
296#[derive(Clone)]
298pub struct RateLimiterState {
299 entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
300 config: RateLimitConfig,
301}
302
303impl RateLimiterState {
304 pub fn new(config: RateLimitConfig) -> Self {
306 Self {
307 entries: Arc::new(RwLock::new(HashMap::new())),
308 config,
309 }
310 }
311
312 pub async fn check(&self, key: &str) -> Result<RateLimitInfo, RateLimitExceeded> {
314 let mut entries = self.entries.write().await;
315 let now = Instant::now();
316
317 let entry = entries
318 .entry(key.to_string())
319 .or_insert_with(|| RateLimitEntry::new(&self.config));
320
321 match self.config.algorithm {
322 RateLimitAlgorithm::FixedWindow => self.check_fixed_window(entry, now),
323 RateLimitAlgorithm::SlidingWindow => self.check_sliding_window(entry, now),
324 RateLimitAlgorithm::TokenBucket => self.check_token_bucket(entry, now),
325 }
326 }
327
328 fn check_fixed_window(
329 &self,
330 entry: &mut RateLimitEntry,
331 now: Instant,
332 ) -> Result<RateLimitInfo, RateLimitExceeded> {
333 if now.duration_since(entry.window_start) >= self.config.window {
335 entry.count = 0;
336 entry.window_start = now;
337 }
338
339 if entry.count >= self.config.max_requests {
340 let reset_time = entry.window_start + self.config.window;
341 let retry_after = reset_time.saturating_duration_since(now).as_secs();
342 return Err(RateLimitExceeded {
343 retry_after,
344 message: "Rate limit exceeded".to_string(),
345 });
346 }
347
348 entry.count += 1;
349
350 let reset_time = entry.window_start + self.config.window;
351 Ok(RateLimitInfo {
352 limit: self.config.max_requests,
353 remaining: self.config.max_requests.saturating_sub(entry.count),
354 reset: reset_time.saturating_duration_since(now).as_secs(),
355 })
356 }
357
358 fn check_sliding_window(
359 &self,
360 entry: &mut RateLimitEntry,
361 now: Instant,
362 ) -> Result<RateLimitInfo, RateLimitExceeded> {
363 let window_start = now - self.config.window;
365 entry.timestamps.retain(|&t| t > window_start);
366
367 if entry.timestamps.len() as u64 >= self.config.max_requests {
368 let oldest = entry.timestamps.first().copied().unwrap_or(now);
369 let retry_after = (oldest + self.config.window)
370 .saturating_duration_since(now)
371 .as_secs();
372 return Err(RateLimitExceeded {
373 retry_after,
374 message: "Rate limit exceeded".to_string(),
375 });
376 }
377
378 entry.timestamps.push(now);
379
380 Ok(RateLimitInfo {
381 limit: self.config.max_requests,
382 remaining: self
383 .config
384 .max_requests
385 .saturating_sub(entry.timestamps.len() as u64),
386 reset: self.config.window.as_secs(),
387 })
388 }
389
390 fn check_token_bucket(
391 &self,
392 entry: &mut RateLimitEntry,
393 now: Instant,
394 ) -> Result<RateLimitInfo, RateLimitExceeded> {
395 let elapsed = now.duration_since(entry.last_refill).as_secs_f64();
397 let refill_rate = self.config.max_requests as f64 / self.config.window.as_secs_f64();
398 let new_tokens = elapsed * refill_rate;
399
400 entry.tokens = (entry.tokens + new_tokens).min(self.config.burst_size as f64);
401 entry.last_refill = now;
402
403 if entry.tokens < 1.0 {
404 let tokens_needed = 1.0 - entry.tokens;
405 let retry_after = (tokens_needed / refill_rate).ceil() as u64;
406 return Err(RateLimitExceeded {
407 retry_after,
408 message: "Rate limit exceeded".to_string(),
409 });
410 }
411
412 entry.tokens -= 1.0;
413
414 Ok(RateLimitInfo {
415 limit: self.config.max_requests,
416 remaining: entry.tokens as u64,
417 reset: self.config.window.as_secs(),
418 })
419 }
420
421 pub async fn cleanup(&self) {
423 let mut entries = self.entries.write().await;
424 let now = Instant::now();
425 let window = self.config.window * 2; entries.retain(|_, entry| now.duration_since(entry.window_start) < window);
428 }
429}
430
/// Budget snapshot returned for an allowed request.
///
/// The middleware copies these values into the `x-ratelimit-limit`,
/// `x-ratelimit-remaining` and `x-ratelimit-reset` response headers. All
/// fields are plain integers, so the type is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateLimitInfo {
    /// Configured `max_requests` of the limiter.
    pub limit: u64,
    /// Requests (or whole tokens) still available for this key.
    pub remaining: u64,
    /// Reset hint in seconds, as computed by the active algorithm.
    pub reset: u64,
}
441
442#[derive(Clone)]
444pub struct RateLimiter {
445 state: RateLimiterState,
446}
447
448impl RateLimiter {
449 pub fn new(config: RateLimitConfig) -> Self {
451 Self {
452 state: RateLimiterState::new(config),
453 }
454 }
455
456 pub fn for_login() -> Self {
458 Self::new(RateLimitConfig::strict_login())
459 }
460
461 pub fn for_api() -> Self {
463 Self::new(RateLimitConfig::lenient_api())
464 }
465
466 pub fn state(&self) -> &RateLimiterState {
468 &self.state
469 }
470}
471
472impl<S, B> Transform<S, ServiceRequest> for RateLimiter
473where
474 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
475 B: 'static,
476{
477 type Response = ServiceResponse<B>;
478 type Error = Error;
479 type Transform = RateLimiterMiddleware<S>;
480 type InitError = ();
481 type Future = Ready<Result<Self::Transform, Self::InitError>>;
482
483 fn new_transform(&self, service: S) -> Self::Future {
484 ok(RateLimiterMiddleware {
485 service,
486 state: self.state.clone(),
487 })
488 }
489}
490
491pub struct RateLimiterMiddleware<S> {
493 service: S,
494 state: RateLimiterState,
495}
496
497impl<S, B> Service<ServiceRequest> for RateLimiterMiddleware<S>
498where
499 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
500 B: 'static,
501{
502 type Response = ServiceResponse<B>;
503 type Error = Error;
504 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
505
506 forward_ready!(service);
507
508 fn call(&self, req: ServiceRequest) -> Self::Future {
509 let state = self.state.clone();
510 let config = state.config.clone();
511
512 let path = req.path().to_string();
514 if config.excluded_paths.iter().any(|p| path.starts_with(p)) {
515 let fut = self.service.call(req);
516 return Box::pin(fut);
517 }
518
519 let key = match config.key_extractor.extract(&req) {
521 Some(k) => k,
522 None => {
523 let fut = self.service.call(req);
525 return Box::pin(fut);
526 }
527 };
528
529 let fut = self.service.call(req);
530 let add_headers = config.add_headers;
531 let error_handler = config.error_response.clone();
532
533 Box::pin(async move {
534 match state.check(&key).await {
536 Ok(info) => {
537 let mut resp = fut.await?;
538
539 if add_headers {
541 let headers = resp.headers_mut();
542 if let Ok(v) = HeaderValue::from_str(&info.limit.to_string()) {
543 headers.insert(HeaderName::from_static("x-ratelimit-limit"), v);
544 }
545 if let Ok(v) = HeaderValue::from_str(&info.remaining.to_string()) {
546 headers.insert(HeaderName::from_static("x-ratelimit-remaining"), v);
547 }
548 if let Ok(v) = HeaderValue::from_str(&info.reset.to_string()) {
549 headers.insert(HeaderName::from_static("x-ratelimit-reset"), v);
550 }
551 }
552
553 Ok(resp)
554 }
555 Err(exceeded) => {
556 let response = if let Some(handler) = error_handler {
558 handler(exceeded.clone())
559 } else {
560 HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
561 .insert_header(("Retry-After", exceeded.retry_after.to_string()))
562 .insert_header(("X-RateLimit-Limit", config.max_requests.to_string()))
563 .insert_header(("X-RateLimit-Remaining", "0"))
564 .body(exceeded.message)
565 };
566
567 Err(actix_web::error::InternalError::from_response(
570 std::io::Error::other("Rate limit exceeded"),
571 response,
572 )
573 .into())
574 }
575 }
576 })
577 }
578}
579
580#[derive(Clone, Default)]
582pub struct RateLimitBuilder {
583 rules: Vec<(String, RateLimitConfig)>,
584 default: Option<RateLimitConfig>,
585}
586
587impl RateLimitBuilder {
588 pub fn new() -> Self {
590 Self::default()
591 }
592
593 pub fn add_rule(mut self, pattern: &str, config: RateLimitConfig) -> Self {
595 self.rules.push((pattern.to_string(), config));
596 self
597 }
598
599 pub fn default_limit(mut self, config: RateLimitConfig) -> Self {
601 self.default = Some(config);
602 self
603 }
604
605 pub fn protect_login(self, path: &str) -> Self {
607 self.add_rule(path, RateLimitConfig::strict_login())
608 }
609
610 pub fn protect_api(self, path: &str) -> Self {
612 self.add_rule(path, RateLimitConfig::lenient_api())
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends one `check` per element of `expected` for `key`, asserting that
    /// each attempt is allowed (`true`) or rejected (`false`) as listed.
    async fn assert_outcomes(state: &RateLimiterState, key: &str, expected: &[bool]) {
        for (attempt, &should_pass) in expected.iter().enumerate() {
            let passed = state.check(key).await.is_ok();
            assert_eq!(passed, should_pass, "attempt #{attempt} for {key:?}");
        }
    }

    #[tokio::test]
    async fn test_fixed_window_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(3)
                .window(Duration::from_secs(60)),
        );

        // Three requests fit in the window; the fourth is rejected.
        assert_outcomes(&state, "test-key", &[true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_sliding_window_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(3)
                .window(Duration::from_secs(60))
                .algorithm(RateLimitAlgorithm::SlidingWindow),
        );

        assert_outcomes(&state, "test-key", &[true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_token_bucket_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(10)
                .window(Duration::from_secs(1))
                .burst_size(3)
                .algorithm(RateLimitAlgorithm::TokenBucket),
        );

        // The bucket starts with `burst_size` tokens; the refill between
        // back-to-back calls stays well under one token.
        assert_outcomes(&state, "test-key", &[true, true, true, false]).await;
    }

    #[tokio::test]
    async fn test_different_keys_independent() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(2)
                .window(Duration::from_secs(60)),
        );

        // Exhausting one key must not affect the next.
        for key in ["key-a", "key-b"] {
            assert_outcomes(&state, key, &[true, true, false]).await;
        }
    }

    #[test]
    fn test_rate_limit_info() {
        let RateLimitInfo {
            limit,
            remaining,
            reset,
        } = RateLimitInfo {
            limit: 100,
            remaining: 50,
            reset: 30,
        };

        assert_eq!((limit, remaining, reset), (100, 50, 30));
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::new()
            .requests_per_minute(60)
            .burst_size(10)
            .add_headers(true)
            .exclude_paths(vec!["/health", "/metrics"]);

        assert_eq!(
            (config.max_requests, config.burst_size, config.add_headers),
            (60, 10, true)
        );
        assert_eq!(config.excluded_paths.len(), 2);
    }

    #[test]
    fn test_strict_login_config() {
        let RateLimitConfig {
            max_requests,
            window,
            ..
        } = RateLimitConfig::strict_login();

        assert_eq!(max_requests, 5);
        assert_eq!(window, Duration::from_secs(60));
    }

    #[test]
    fn test_lenient_api_config() {
        assert_eq!(RateLimitConfig::lenient_api().max_requests, 1000);
    }
}