1use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
29use actix_web::http::header::{HeaderName, HeaderValue};
30use actix_web::http::StatusCode;
31use actix_web::{Error, HttpResponse};
32use futures_util::future::{ok, LocalBoxFuture, Ready};
33use std::collections::HashMap;
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use tokio::sync::RwLock;
37
/// Details of a rejected request.
///
/// Produced by the limiter when a key is over its limit and handed to a
/// custom `error_response` handler, if one is configured.
#[derive(Debug, Clone)]
pub struct RateLimitExceeded {
    /// Seconds the client should wait before retrying.
    pub retry_after: u64,
    /// Human-readable reason; the built-in 429 response uses it as the body.
    pub message: String,
}

impl Default for RateLimitExceeded {
    /// A generic one-minute back-off with a friendly message.
    fn default() -> Self {
        let message = String::from("Too many requests. Please try again later.");
        Self {
            retry_after: 60,
            message,
        }
    }
}
55
56pub type KeyExtractorFn = Arc<dyn Fn(&ServiceRequest) -> Option<String> + Send + Sync>;
58
59#[derive(Clone, Default)]
61pub enum KeyExtractor {
62 #[default]
64 IpAddress,
65 User,
67 Header(String),
69 IpAndEndpoint,
71 Custom(KeyExtractorFn),
73}
74
75impl std::fmt::Debug for KeyExtractor {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 KeyExtractor::IpAddress => write!(f, "IpAddress"),
79 KeyExtractor::User => write!(f, "User"),
80 KeyExtractor::Header(h) => write!(f, "Header({})", h),
81 KeyExtractor::IpAndEndpoint => write!(f, "IpAndEndpoint"),
82 KeyExtractor::Custom(_) => write!(f, "Custom(<fn>)"),
83 }
84 }
85}
86
87impl KeyExtractor {
88 pub fn extract(&self, req: &ServiceRequest) -> Option<String> {
90 match self {
91 KeyExtractor::IpAddress => req
92 .connection_info()
93 .realip_remote_addr()
94 .map(|s| s.to_string()),
95 KeyExtractor::User => req
96 .headers()
97 .get("Authorization")
98 .and_then(|h| h.to_str().ok())
99 .map(|s| s.to_string()),
100 KeyExtractor::Header(name) => req
101 .headers()
102 .get(name.as_str())
103 .and_then(|h| h.to_str().ok())
104 .map(|s| s.to_string()),
105 KeyExtractor::IpAndEndpoint => {
106 let ip = req.connection_info().realip_remote_addr()?.to_string();
107 let path = req.path().to_string();
108 Some(format!("{}:{}", ip, path))
109 }
110 KeyExtractor::Custom(f) => f(req),
111 }
112 }
113}
114
/// Counting strategy applied to each key.
#[derive(Debug, Clone, Default)]
pub enum RateLimitAlgorithm {
    /// Counts requests in consecutive, non-overlapping windows of length
    /// `window`; the count resets when a window ends (the default).
    #[default]
    FixedWindow,
    /// Remembers the arrival time of every admitted request and admits a new
    /// one only while fewer than `max_requests` fall inside the trailing `window`.
    SlidingWindow,
    /// A bucket holding up to `burst_size` tokens, refilled continuously at
    /// `max_requests / window` tokens per second; each request spends one.
    TokenBucket,
}
126
127#[derive(Clone)]
129pub struct RateLimitConfig {
130 pub max_requests: u64,
132 pub window: Duration,
134 pub burst_size: u64,
136 pub algorithm: RateLimitAlgorithm,
138 pub key_extractor: KeyExtractor,
140 pub excluded_paths: Vec<String>,
142 pub add_headers: bool,
144 pub error_response: Option<Arc<dyn Fn(RateLimitExceeded) -> HttpResponse + Send + Sync>>,
146}
147
148impl std::fmt::Debug for RateLimitConfig {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("RateLimitConfig")
151 .field("max_requests", &self.max_requests)
152 .field("window", &self.window)
153 .field("burst_size", &self.burst_size)
154 .field("algorithm", &self.algorithm)
155 .field("key_extractor", &self.key_extractor)
156 .field("excluded_paths", &self.excluded_paths)
157 .field("add_headers", &self.add_headers)
158 .field("error_response", &self.error_response.as_ref().map(|_| "<fn>"))
159 .finish()
160 }
161}
162
163impl Default for RateLimitConfig {
164 fn default() -> Self {
165 Self {
166 max_requests: 100,
167 window: Duration::from_secs(60),
168 burst_size: 10,
169 algorithm: RateLimitAlgorithm::default(),
170 key_extractor: KeyExtractor::default(),
171 excluded_paths: vec![],
172 add_headers: true,
173 error_response: None,
174 }
175 }
176}
177
178impl RateLimitConfig {
179 pub fn new() -> Self {
181 Self::default()
182 }
183
184 pub fn max_requests(mut self, max: u64) -> Self {
186 self.max_requests = max;
187 self
188 }
189
190 pub fn requests_per_second(mut self, rps: u64) -> Self {
192 self.max_requests = rps;
193 self.window = Duration::from_secs(1);
194 self
195 }
196
197 pub fn requests_per_minute(mut self, rpm: u64) -> Self {
199 self.max_requests = rpm;
200 self.window = Duration::from_secs(60);
201 self
202 }
203
204 pub fn window(mut self, window: Duration) -> Self {
206 self.window = window;
207 self
208 }
209
210 pub fn burst_size(mut self, size: u64) -> Self {
212 self.burst_size = size;
213 self
214 }
215
216 pub fn algorithm(mut self, algo: RateLimitAlgorithm) -> Self {
218 self.algorithm = algo;
219 self
220 }
221
222 pub fn key_extractor(mut self, extractor: KeyExtractor) -> Self {
224 self.key_extractor = extractor;
225 self
226 }
227
228 pub fn exclude_paths(mut self, paths: Vec<&str>) -> Self {
230 self.excluded_paths = paths.into_iter().map(String::from).collect();
231 self
232 }
233
234 pub fn add_headers(mut self, add: bool) -> Self {
236 self.add_headers = add;
237 self
238 }
239
240 pub fn error_response<F>(mut self, handler: F) -> Self
242 where
243 F: Fn(RateLimitExceeded) -> HttpResponse + Send + Sync + 'static,
244 {
245 self.error_response = Some(Arc::new(handler));
246 self
247 }
248
249 pub fn strict_login() -> Self {
251 Self::new()
252 .requests_per_minute(5)
253 .burst_size(3)
254 .algorithm(RateLimitAlgorithm::SlidingWindow)
255 }
256
257 pub fn lenient_api() -> Self {
259 Self::new()
260 .requests_per_minute(1000)
261 .burst_size(100)
262 .algorithm(RateLimitAlgorithm::TokenBucket)
263 }
264}
265
266#[derive(Debug, Clone)]
268struct RateLimitEntry {
269 count: u64,
271 window_start: Instant,
273 timestamps: Vec<Instant>,
275 tokens: f64,
277 last_refill: Instant,
279}
280
281impl RateLimitEntry {
282 fn new(config: &RateLimitConfig) -> Self {
283 Self {
284 count: 0,
285 window_start: Instant::now(),
286 timestamps: Vec::new(),
287 tokens: config.burst_size as f64,
288 last_refill: Instant::now(),
289 }
290 }
291}
292
293#[derive(Clone)]
295pub struct RateLimiterState {
296 entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
297 config: RateLimitConfig,
298}
299
300impl RateLimiterState {
301 pub fn new(config: RateLimitConfig) -> Self {
303 Self {
304 entries: Arc::new(RwLock::new(HashMap::new())),
305 config,
306 }
307 }
308
309 pub async fn check(&self, key: &str) -> Result<RateLimitInfo, RateLimitExceeded> {
311 let mut entries = self.entries.write().await;
312 let now = Instant::now();
313
314 let entry = entries
315 .entry(key.to_string())
316 .or_insert_with(|| RateLimitEntry::new(&self.config));
317
318 match self.config.algorithm {
319 RateLimitAlgorithm::FixedWindow => self.check_fixed_window(entry, now),
320 RateLimitAlgorithm::SlidingWindow => self.check_sliding_window(entry, now),
321 RateLimitAlgorithm::TokenBucket => self.check_token_bucket(entry, now),
322 }
323 }
324
325 fn check_fixed_window(
326 &self,
327 entry: &mut RateLimitEntry,
328 now: Instant,
329 ) -> Result<RateLimitInfo, RateLimitExceeded> {
330 if now.duration_since(entry.window_start) >= self.config.window {
332 entry.count = 0;
333 entry.window_start = now;
334 }
335
336 if entry.count >= self.config.max_requests {
337 let reset_time = entry.window_start + self.config.window;
338 let retry_after = reset_time.saturating_duration_since(now).as_secs();
339 return Err(RateLimitExceeded {
340 retry_after,
341 message: "Rate limit exceeded".to_string(),
342 });
343 }
344
345 entry.count += 1;
346
347 let reset_time = entry.window_start + self.config.window;
348 Ok(RateLimitInfo {
349 limit: self.config.max_requests,
350 remaining: self.config.max_requests.saturating_sub(entry.count),
351 reset: reset_time.saturating_duration_since(now).as_secs(),
352 })
353 }
354
355 fn check_sliding_window(
356 &self,
357 entry: &mut RateLimitEntry,
358 now: Instant,
359 ) -> Result<RateLimitInfo, RateLimitExceeded> {
360 let window_start = now - self.config.window;
362 entry.timestamps.retain(|&t| t > window_start);
363
364 if entry.timestamps.len() as u64 >= self.config.max_requests {
365 let oldest = entry.timestamps.first().copied().unwrap_or(now);
366 let retry_after = (oldest + self.config.window)
367 .saturating_duration_since(now)
368 .as_secs();
369 return Err(RateLimitExceeded {
370 retry_after,
371 message: "Rate limit exceeded".to_string(),
372 });
373 }
374
375 entry.timestamps.push(now);
376
377 Ok(RateLimitInfo {
378 limit: self.config.max_requests,
379 remaining: self.config.max_requests.saturating_sub(entry.timestamps.len() as u64),
380 reset: self.config.window.as_secs(),
381 })
382 }
383
384 fn check_token_bucket(
385 &self,
386 entry: &mut RateLimitEntry,
387 now: Instant,
388 ) -> Result<RateLimitInfo, RateLimitExceeded> {
389 let elapsed = now.duration_since(entry.last_refill).as_secs_f64();
391 let refill_rate = self.config.max_requests as f64 / self.config.window.as_secs_f64();
392 let new_tokens = elapsed * refill_rate;
393
394 entry.tokens = (entry.tokens + new_tokens).min(self.config.burst_size as f64);
395 entry.last_refill = now;
396
397 if entry.tokens < 1.0 {
398 let tokens_needed = 1.0 - entry.tokens;
399 let retry_after = (tokens_needed / refill_rate).ceil() as u64;
400 return Err(RateLimitExceeded {
401 retry_after,
402 message: "Rate limit exceeded".to_string(),
403 });
404 }
405
406 entry.tokens -= 1.0;
407
408 Ok(RateLimitInfo {
409 limit: self.config.max_requests,
410 remaining: entry.tokens as u64,
411 reset: self.config.window.as_secs(),
412 })
413 }
414
415 pub async fn cleanup(&self) {
417 let mut entries = self.entries.write().await;
418 let now = Instant::now();
419 let window = self.config.window * 2; entries.retain(|_, entry| now.duration_since(entry.window_start) < window);
422 }
423}
424
/// Quota snapshot returned for an admitted request.
///
/// The middleware copies these values into the `x-ratelimit-*` response
/// headers when header emission is enabled.
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
    /// The configured `max_requests`.
    pub limit: u64,
    /// Requests (or whole tokens) still available after this one.
    pub remaining: u64,
    /// Seconds until the current window ends (fixed window), or the window
    /// length (sliding window and token bucket).
    pub reset: u64,
}
435
436#[derive(Clone)]
438pub struct RateLimiter {
439 state: RateLimiterState,
440}
441
442impl RateLimiter {
443 pub fn new(config: RateLimitConfig) -> Self {
445 Self {
446 state: RateLimiterState::new(config),
447 }
448 }
449
450 pub fn for_login() -> Self {
452 Self::new(RateLimitConfig::strict_login())
453 }
454
455 pub fn for_api() -> Self {
457 Self::new(RateLimitConfig::lenient_api())
458 }
459
460 pub fn state(&self) -> &RateLimiterState {
462 &self.state
463 }
464}
465
466impl<S, B> Transform<S, ServiceRequest> for RateLimiter
467where
468 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
469 B: 'static,
470{
471 type Response = ServiceResponse<B>;
472 type Error = Error;
473 type Transform = RateLimiterMiddleware<S>;
474 type InitError = ();
475 type Future = Ready<Result<Self::Transform, Self::InitError>>;
476
477 fn new_transform(&self, service: S) -> Self::Future {
478 ok(RateLimiterMiddleware {
479 service,
480 state: self.state.clone(),
481 })
482 }
483}
484
485pub struct RateLimiterMiddleware<S> {
487 service: S,
488 state: RateLimiterState,
489}
490
491impl<S, B> Service<ServiceRequest> for RateLimiterMiddleware<S>
492where
493 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
494 B: 'static,
495{
496 type Response = ServiceResponse<B>;
497 type Error = Error;
498 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
499
500 forward_ready!(service);
501
502 fn call(&self, req: ServiceRequest) -> Self::Future {
503 let state = self.state.clone();
504 let config = state.config.clone();
505
506 let path = req.path().to_string();
508 if config.excluded_paths.iter().any(|p| path.starts_with(p)) {
509 let fut = self.service.call(req);
510 return Box::pin(fut);
511 }
512
513 let key = match config.key_extractor.extract(&req) {
515 Some(k) => k,
516 None => {
517 let fut = self.service.call(req);
519 return Box::pin(fut);
520 }
521 };
522
523 let fut = self.service.call(req);
524 let add_headers = config.add_headers;
525 let error_handler = config.error_response.clone();
526
527 Box::pin(async move {
528 match state.check(&key).await {
530 Ok(info) => {
531 let mut resp = fut.await?;
532
533 if add_headers {
535 let headers = resp.headers_mut();
536 if let Ok(v) = HeaderValue::from_str(&info.limit.to_string()) {
537 headers.insert(
538 HeaderName::from_static("x-ratelimit-limit"),
539 v,
540 );
541 }
542 if let Ok(v) = HeaderValue::from_str(&info.remaining.to_string()) {
543 headers.insert(
544 HeaderName::from_static("x-ratelimit-remaining"),
545 v,
546 );
547 }
548 if let Ok(v) = HeaderValue::from_str(&info.reset.to_string()) {
549 headers.insert(
550 HeaderName::from_static("x-ratelimit-reset"),
551 v,
552 );
553 }
554 }
555
556 Ok(resp)
557 }
558 Err(exceeded) => {
559 let response = if let Some(handler) = error_handler {
561 handler(exceeded.clone())
562 } else {
563 HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
564 .insert_header(("Retry-After", exceeded.retry_after.to_string()))
565 .insert_header(("X-RateLimit-Limit", config.max_requests.to_string()))
566 .insert_header(("X-RateLimit-Remaining", "0"))
567 .body(exceeded.message)
568 };
569
570 Err(actix_web::error::InternalError::from_response(
573 std::io::Error::new(std::io::ErrorKind::Other, "Rate limit exceeded"),
574 response,
575 )
576 .into())
577 }
578 }
579 })
580 }
581}
582
583#[derive(Clone, Default)]
585pub struct RateLimitBuilder {
586 rules: Vec<(String, RateLimitConfig)>,
587 default: Option<RateLimitConfig>,
588}
589
590impl RateLimitBuilder {
591 pub fn new() -> Self {
593 Self::default()
594 }
595
596 pub fn add_rule(mut self, pattern: &str, config: RateLimitConfig) -> Self {
598 self.rules.push((pattern.to_string(), config));
599 self
600 }
601
602 pub fn default_limit(mut self, config: RateLimitConfig) -> Self {
604 self.default = Some(config);
605 self
606 }
607
608 pub fn protect_login(self, path: &str) -> Self {
610 self.add_rule(path, RateLimitConfig::strict_login())
611 }
612
613 pub fn protect_api(self, path: &str) -> Self {
615 self.add_rule(path, RateLimitConfig::lenient_api())
616 }
617}
618
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the next `times` checks for `key` are all admitted.
    async fn assert_admits(state: &RateLimiterState, key: &str, times: usize) {
        for _ in 0..times {
            assert!(state.check(key).await.is_ok());
        }
    }

    /// Asserts that the next check for `key` is rejected.
    async fn assert_rejects(state: &RateLimiterState, key: &str) {
        assert!(state.check(key).await.is_err());
    }

    #[tokio::test]
    async fn test_fixed_window_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(3)
                .window(Duration::from_secs(60)),
        );

        assert_admits(&state, "test-key", 3).await;
        assert_rejects(&state, "test-key").await;
    }

    #[tokio::test]
    async fn test_sliding_window_rate_limit() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(3)
                .window(Duration::from_secs(60))
                .algorithm(RateLimitAlgorithm::SlidingWindow),
        );

        assert_admits(&state, "test-key", 3).await;
        assert_rejects(&state, "test-key").await;
    }

    #[tokio::test]
    async fn test_token_bucket_rate_limit() {
        // Refill is 10 tokens per second, so back-to-back requests drain the
        // 3-token bucket before it can refill.
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(10)
                .window(Duration::from_secs(1))
                .burst_size(3)
                .algorithm(RateLimitAlgorithm::TokenBucket),
        );

        assert_admits(&state, "test-key", 3).await;
        assert_rejects(&state, "test-key").await;
    }

    #[tokio::test]
    async fn test_different_keys_independent() {
        let state = RateLimiterState::new(
            RateLimitConfig::new()
                .max_requests(2)
                .window(Duration::from_secs(60)),
        );

        for key in ["key-a", "key-b"] {
            assert_admits(&state, key, 2).await;
            assert_rejects(&state, key).await;
        }
    }

    #[test]
    fn test_rate_limit_info() {
        let RateLimitInfo {
            limit,
            remaining,
            reset,
        } = RateLimitInfo {
            limit: 100,
            remaining: 50,
            reset: 30,
        };

        assert_eq!((limit, remaining, reset), (100, 50, 30));
    }

    #[test]
    fn test_config_builder() {
        let config = RateLimitConfig::new()
            .requests_per_minute(60)
            .burst_size(10)
            .add_headers(true)
            .exclude_paths(vec!["/health", "/metrics"]);

        assert_eq!(config.max_requests, 60);
        assert_eq!(config.burst_size, 10);
        assert!(config.add_headers);
        assert_eq!(config.excluded_paths.len(), 2);
    }

    #[test]
    fn test_strict_login_config() {
        let RateLimitConfig {
            max_requests,
            window,
            ..
        } = RateLimitConfig::strict_login();

        assert_eq!(max_requests, 5);
        assert_eq!(window, Duration::from_secs(60));
    }

    #[test]
    fn test_lenient_api_config() {
        assert_eq!(RateLimitConfig::lenient_api().max_requests, 1000);
    }
}