1use dashmap::DashMap;
2use ntex::http::header::{HeaderName, HeaderValue};
3use ntex::{http::StatusCode, Middleware, ServiceCtx};
4use std::net::IpAddr;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7
8use ntex::{web, Service};
9use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
10
11#[cfg(feature = "tokio")]
12use tokio::time::interval;
13
14#[cfg(feature = "async-std")]
15use async_std::task;
16
17#[cfg(feature = "json")]
18use serde::{Deserialize, Serialize};
19
/// Response header carrying the whole tokens left in the client's bucket.
const HEADER_RATELIMIT_REMAINING: &str = "x-ratelimit-remaining";
/// Response header carrying the bucket capacity (`RateLimiterConfig::capacity`).
const HEADER_RATELIMIT_LIMIT: &str = "x-ratelimit-limit";
/// Response header carrying the Unix timestamp (seconds) at which the bucket is full again.
const HEADER_RATELIMIT_RESET: &str = "x-ratelimit-reset";
23
24#[derive(Debug)]
26struct TokenBucket {
27 tokens: f64,
28 last_refill: Instant,
29}
30
31impl TokenBucket {
32 fn new(capacity: usize) -> Self {
33 Self {
34 tokens: capacity as f64,
35 last_refill: Instant::now(),
36 }
37 }
38
39 fn consume(&mut self, tokens: usize, now: Instant, config: &RateLimiterConfig) -> bool {
40 self.refill(now, config);
41 if self.tokens >= tokens as f64 {
42 self.tokens -= tokens as f64;
43 true
44 } else {
45 false
46 }
47 }
48
49 fn refill(&mut self, now: Instant, config: &RateLimiterConfig) {
50 let elapsed = now.duration_since(self.last_refill).as_secs_f64();
51 let refill_rate = config.capacity as f64 / config.window as f64;
52 let new_tokens = elapsed * refill_rate;
53 self.tokens = (self.tokens + new_tokens).min(config.capacity as f64);
54 self.last_refill = now;
55 }
56
57 fn remaining_tokens(&self) -> u32 {
58 self.tokens.floor() as u32
59 }
60
61 fn reset_time(&self, _now: Instant, config: &RateLimiterConfig) -> u64 {
62 let now_secs = SystemTime::now()
63 .duration_since(UNIX_EPOCH)
64 .unwrap_or_default()
65 .as_secs();
66
67 if self.tokens >= config.capacity as f64 {
68 return now_secs;
69 }
70
71 let missing_tokens = config.capacity as f64 - self.tokens;
72 let refill_rate = config.capacity as f64 / config.window as f64;
73 let seconds_to_refill = missing_tokens / refill_rate;
74
75 now_secs + seconds_to_refill.ceil() as u64
76 }
77
78 fn is_stale(&self, now: Instant, stale_threshold: Duration) -> bool {
80 now.duration_since(self.last_refill) > stale_threshold
81 }
82}
83
/// Tunable settings for [`RateLimiter`].
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Maximum tokens per client bucket; also the largest burst accepted at once.
    pub capacity: usize,
    /// Seconds needed to refill an empty bucket to `capacity`
    /// (refill rate is `capacity / window` tokens per second). Must be non-zero.
    pub window: u64,
    /// Period between background cleanup passes.
    pub cleanup_interval: Duration,
    /// Buckets left unused for longer than this are evicted during cleanup.
    pub stale_threshold: Duration,
}

impl Default for RateLimiterConfig {
    /// 100 requests per 60-second window, cleanup every 5 minutes,
    /// eviction after 1 hour of inactivity.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);

        RateLimiterConfig {
            capacity: 100,
            window: 60,
            cleanup_interval: FIVE_MINUTES,
            stale_threshold: ONE_HOUR,
        }
    }
}
103
104pub struct RateLimiter {
106 map: DashMap<IpAddr, TokenBucket>,
107 config: RateLimiterConfig,
108 last_cleanup: AtomicU64,
109}
110
111impl RateLimiter {
112 pub fn new(capacity: usize, window: u64) -> Arc<Self> {
114 let config = RateLimiterConfig {
115 capacity,
116 window,
117 ..Default::default()
118 };
119 Self::with_config(config)
120 }
121
122 pub fn with_config(config: RateLimiterConfig) -> Arc<Self> {
124 assert!(config.window > 0, "RateLimiter window must be greater than zero");
125
126 let limiter = Arc::new(RateLimiter {
127 map: DashMap::new(),
128 config,
129 last_cleanup: AtomicU64::new(
130 SystemTime::now()
131 .duration_since(UNIX_EPOCH)
132 .unwrap_or_default()
133 .as_secs(),
134 ),
135 });
136
137 #[cfg(any(feature = "tokio", feature = "async-std"))]
139 Self::start_cleanup_task(Arc::clone(&limiter));
140
141 limiter
142 }
143
144 #[cfg(feature = "tokio")]
145 fn start_cleanup_task(limiter: Arc<RateLimiter>) {
146 tokio::spawn(async move {
147 let mut interval = interval(limiter.config.cleanup_interval);
148 loop {
149 interval.tick().await;
150 limiter.cleanup().await;
151 }
152 });
153 }
154
155 #[cfg(feature = "async-std")]
156 fn start_cleanup_task(limiter: Arc<RateLimiter>) {
157 let cleanup_interval = limiter.config.cleanup_interval;
158 task::spawn(async move {
159 loop {
160 task::sleep(cleanup_interval).await;
161 limiter.cleanup().await;
162 }
163 });
164 }
165
166 pub fn check_rate_limit(&self, identifier: IpAddr) -> RateLimitResult {
168 let now = Instant::now();
169 let mut bucket = self
170 .map
171 .entry(identifier)
172 .or_insert_with(|| TokenBucket::new(self.config.capacity));
173
174 let allowed = bucket.consume(1, now, &self.config);
175 let remaining = bucket.remaining_tokens();
176 let reset = bucket.reset_time(now, &self.config);
177
178 RateLimitResult {
179 allowed,
180 remaining,
181 reset,
182 limit: self.config.capacity,
183 }
184 }
185
186 async fn cleanup(&self) {
188 let now_secs = SystemTime::now()
189 .duration_since(UNIX_EPOCH)
190 .unwrap_or_default()
191 .as_secs();
192
193 let last_cleanup = self.last_cleanup.load(Ordering::Acquire);
194
195 if now_secs.saturating_sub(last_cleanup) < self.config.cleanup_interval.as_secs() {
197 return;
198 }
199
200 if self
202 .last_cleanup
203 .compare_exchange(last_cleanup, now_secs, Ordering::AcqRel, Ordering::Relaxed)
204 .is_err()
205 {
206 return;
208 }
209
210 let now = Instant::now();
211 let stale_threshold = self.config.stale_threshold;
212
213 let initial_size = self.map.len();
214 self.map
215 .retain(|_, bucket| !bucket.is_stale(now, stale_threshold));
216 let final_size = self.map.len();
217
218 if cfg!(debug_assertions) && initial_size > final_size {
219 eprintln!(
220 "Cleaned {} stale rate limit entries",
221 initial_size - final_size
222 );
223 }
224 }
225
226 pub fn stats(&self) -> RateLimiterStats {
228 RateLimiterStats {
229 active_entries: self.map.len(),
230 capacity: self.config.capacity,
231 window: self.config.window,
232 }
233 }
234}
235
/// Outcome of a single [`RateLimiter::check_rate_limit`] call.
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// Whether a token was available and the request may proceed.
    pub allowed: bool,
    /// Whole tokens left in the bucket after this check.
    pub remaining: u32,
    /// Unix timestamp (seconds) at which the bucket will be full again.
    pub reset: u64,
    /// Configured bucket capacity.
    pub limit: usize,
}

/// Point-in-time statistics returned by [`RateLimiter::stats`].
#[derive(Debug, Clone)]
pub struct RateLimiterStats {
    /// Number of client buckets currently held in the map.
    pub active_entries: usize,
    /// Configured bucket capacity.
    pub capacity: usize,
    /// Configured refill window in seconds.
    pub window: u64,
}
252
253pub struct RateLimit {
255 pub limiter: Arc<RateLimiter>,
256}
257
258impl RateLimit {
259 pub fn new(limiter: Arc<RateLimiter>) -> Self {
260 Self { limiter }
261 }
262}
263
264impl<S> Middleware<S> for RateLimit {
265 type Service = RateLimitMiddlewareService<S>;
266
267 fn create(&self, service: S) -> Self::Service {
268 RateLimitMiddlewareService {
269 service,
270 limiter: Arc::clone(&self.limiter),
271 }
272 }
273}
274
275pub struct RateLimitMiddlewareService<S> {
276 service: S,
277 limiter: Arc<RateLimiter>,
278}
279
280impl<S, Err> Service<web::WebRequest<Err>> for RateLimitMiddlewareService<S>
281where
282 S: Service<web::WebRequest<Err>, Response = web::WebResponse, Error = web::Error> + 'static,
283 Err: web::ErrorRenderer,
284{
285 type Response = web::WebResponse;
286 type Error = web::Error;
287
288 async fn call(
289 &self,
290 req: web::WebRequest<Err>,
291 ctx: ServiceCtx<'_, Self>,
292 ) -> Result<Self::Response, Self::Error> {
293 let ip = extract_client_ip(&req);
294
295 let result = self.limiter.check_rate_limit(ip);
296
297 if !result.allowed {
298 return Err(RateLimitError::from(result).into());
299 }
300
301 let mut response = ctx.call(&self.service, req).await?;
302
303 add_rate_limit_headers(response.headers_mut(), &result);
305
306 Ok(response)
307 }
308}
309
310fn extract_client_ip<Err>(req: &web::WebRequest<Err>) -> IpAddr {
312 if let Some(forwarded) = req.headers().get("x-forwarded-for") {
314 if let Ok(forwarded_str) = forwarded.to_str() {
315 if let Some(ip) = forwarded_str.split(',').next() {
316 let ip = ip.trim();
317 if let Ok(parsed_ip) = ip.parse::<IpAddr>() {
318 return parsed_ip;
319 }
320 }
321 }
322 }
323
324 if let Some(real_ip) = req.headers().get("x-real-ip") {
326 if let Ok(ip_str) = real_ip.to_str() {
327 let ip = ip_str.trim();
328 if let Ok(parsed_ip) = ip.parse::<IpAddr>() {
329 return parsed_ip;
330 }
331 }
332 }
333
334 if let Some(addr_str) = req.connection_info().remote() {
336 if let Ok(sock_addr) = addr_str.parse::<std::net::SocketAddr>() {
337 return sock_addr.ip();
338 }
339 }
340
341 IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1))
343}
344
345fn add_rate_limit_headers(headers: &mut ntex::http::HeaderMap, result: &RateLimitResult) {
347 if let Ok(value) = HeaderValue::from_str(&result.remaining.to_string()) {
348 headers.insert(HeaderName::from_static(HEADER_RATELIMIT_REMAINING), value);
349 }
350 if let Ok(value) = HeaderValue::from_str(&result.limit.to_string()) {
351 headers.insert(HeaderName::from_static(HEADER_RATELIMIT_LIMIT), value);
352 }
353 if let Ok(value) = HeaderValue::from_str(&result.reset.to_string()) {
354 headers.insert(HeaderName::from_static(HEADER_RATELIMIT_RESET), value);
355 }
356}
357
/// Rate-limit figures reported to a rejected client (copied from a
/// `RateLimitResult`).
#[derive(Debug)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
struct RateLimitErrorData {
    // Whole tokens left in the bucket.
    remaining: u32,
    // Unix timestamp (seconds) at which the bucket is full again.
    reset: u64,
    // Configured bucket capacity.
    limit: usize,
}

/// JSON envelope of the 429 response body
/// (serialized with serde when the `json` feature is enabled).
#[derive(Debug)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
struct RateLimitErrorResponse {
    code: u32,
    message: String,
    data: RateLimitErrorData,
}

/// Error returned by the middleware when a request is rejected; rendered as a
/// `429 Too Many Requests` response by its `WebResponseError` impl.
#[derive(Debug)]
struct RateLimitError {
    data: RateLimitErrorData,
}
379
380impl From<RateLimitResult> for RateLimitError {
381 fn from(result: RateLimitResult) -> Self {
382 Self {
383 data: RateLimitErrorData {
384 remaining: result.remaining,
385 reset: result.reset,
386 limit: result.limit,
387 },
388 }
389 }
390}
391
392impl std::fmt::Display for RateLimitError {
393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 write!(
395 f,
396 "Rate limit exceeded. Remaining: {}, Reset: {}, Limit: {}",
397 self.data.remaining, self.data.reset, self.data.limit
398 )
399 }
400}
401
402impl web::error::WebResponseError for RateLimitError {
403 fn error_response(&self, _: &ntex::web::HttpRequest) -> web::HttpResponse {
404 let error_response = RateLimitErrorResponse {
405 code: 429,
406 message: "Rate limit exceeded".to_string(),
407 data: RateLimitErrorData {
408 remaining: self.data.remaining,
409 reset: self.data.reset,
410 limit: self.data.limit,
411 },
412 };
413
414 #[cfg(feature = "json")]
415 let body = serde_json::to_string(&error_response)
416 .unwrap_or_else(|_| r#"{"code":429,"message":"Rate limit exceeded"}"#.to_string());
417
418 #[cfg(not(feature = "json"))]
419 let body = format!(
420 r#"{{"code":429,"message":"Rate limit exceeded","data":{{"remaining":{},"reset":{},"limit":{}}}}}"#,
421 self.data.remaining, self.data.reset, self.data.limit
422 );
423
424 web::HttpResponse::build(StatusCode::TOO_MANY_REQUESTS)
425 .set_header("content-type", "application/json")
426 .set_header(HEADER_RATELIMIT_REMAINING, self.data.remaining.to_string())
427 .set_header(HEADER_RATELIMIT_LIMIT, self.data.limit.to_string())
428 .set_header(HEADER_RATELIMIT_RESET, self.data.reset.to_string())
429 .body(body)
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with the given limits and default housekeeping settings.
    fn limits(capacity: usize, window: u64) -> RateLimiterConfig {
        RateLimiterConfig {
            capacity,
            window,
            ..Default::default()
        }
    }

    #[test]
    fn test_token_bucket_basic() {
        let config = limits(5, 10);
        let mut bucket = TokenBucket::new(5);
        let now = Instant::now();

        // The full burst is accepted...
        (0..5).for_each(|_| assert!(bucket.consume(1, now, &config)));
        // ...and the next request at the same instant is refused.
        assert!(!bucket.consume(1, now, &config));
        assert_eq!(bucket.remaining_tokens(), 0);
    }

    #[test]
    fn test_token_bucket_refill() {
        let config = limits(10, 10);
        let mut bucket = TokenBucket::new(10);
        let start = Instant::now();

        (0..10).for_each(|_| assert!(bucket.consume(1, start, &config)));
        assert!(!bucket.consume(1, start, &config));

        // 10 tokens per 10 s: half a window restores 5 tokens.
        let halfway = start + Duration::from_secs(5);
        bucket.refill(halfway, &config);
        assert_eq!(bucket.remaining_tokens(), 5);

        (0..5).for_each(|_| assert!(bucket.consume(1, halfway, &config)));
        assert!(!bucket.consume(1, halfway, &config));
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let limiter = RateLimiter::with_config(limits(5, 1));
        let ip: IpAddr = "192.168.1.1".parse().unwrap();

        for served in 1..=5u32 {
            let result = limiter.check_rate_limit(ip);
            assert!(result.allowed, "Request {} should be allowed", served);
            assert_eq!(result.remaining, 5 - served);
        }

        let rejected = limiter.check_rate_limit(ip);
        assert!(!rejected.allowed);
        assert_eq!(rejected.remaining, 0);
    }

    #[tokio::test]
    async fn test_rate_limiter_different_ips() {
        let limiter = RateLimiter::new(2, 60);
        let first: IpAddr = "192.168.1.1".parse().unwrap();
        let second: IpAddr = "192.168.1.2".parse().unwrap();

        let (a, b) = (
            limiter.check_rate_limit(first),
            limiter.check_rate_limit(second),
        );

        assert!(a.allowed && b.allowed);
        assert_eq!((a.remaining, b.remaining), (1, 1));
    }
}