1#![forbid(unsafe_code)]
2use chrono::{DateTime, Duration as ChronoDuration, Utc};
114use serde::{Deserialize, Serialize};
115
116use std::collections::HashMap;
117use std::sync::Arc;
118use std::time::{Duration, Instant};
119use tokio::sync::RwLock;
120use warp::{
121 http::header::{self, HeaderMap, HeaderValue},
122 reject, Filter, Rejection
123};
124
125pub use chrono;
126pub use serde;
127
128#[derive(Clone, Debug, PartialEq)]
130pub struct RateLimitConfig {
131 pub max_requests: u32,
133 pub window: Duration,
135 pub retry_after_format: RetryAfterFormat,
137}
138
139#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
141pub enum RetryAfterFormat {
142 #[default]
144 HttpDate,
145 Seconds,
147}
148
149#[derive(Clone, Debug, Serialize, Deserialize)]
151pub struct RateLimitInfo {
152 pub retry_after: String,
154 pub limit: u32,
156 pub remaining: u32,
158 pub reset_timestamp: i64,
160 pub retry_after_format: RetryAfterFormat,
162}
163
164#[derive(Debug)]
166pub struct RateLimitRejection {
167 pub retry_after: Duration,
169 pub limit: u32,
171 pub reset_time: DateTime<Utc>,
173 pub retry_after_format: RetryAfterFormat,
175}
176
177impl warp::reject::Reject for RateLimitRejection {}
178
179impl Default for RateLimitConfig {
181 fn default() -> Self {
182 Self {
183 max_requests: 60, window: Duration::from_secs(60),
185 retry_after_format: RetryAfterFormat::HttpDate,
186 }
187 }
188}
189
190impl RateLimitConfig {
192 pub fn max_per_minute(max: u32) -> Self {
194 Self {
195 max_requests: max,
196 window: Duration::from_secs(60),
197 ..Default::default()
198 }
199 }
200
201 pub fn max_per_window(max_requests: u32, window_seconds: u64) -> Self {
203 Self {
204 max_requests,
205 window: Duration::from_secs(window_seconds),
206 ..Default::default()
207 }
208 }
209}
210
211#[derive(Debug)]
213pub enum RateLimitError {
214 HeaderError(warp::http::header::InvalidHeaderValue),
216 Other(Box<dyn std::error::Error + Send + Sync>),
218}
219
220impl std::fmt::Display for RateLimitError {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 match self {
223 RateLimitError::HeaderError(e) => write!(f, "Failed to set rate limit header: {}", e),
224 RateLimitError::Other(e) => write!(f, "Rate limit error: {}", e),
225 }
226 }
227}
228
229impl std::error::Error for RateLimitError {
230 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
231 match self {
232 RateLimitError::HeaderError(e) => Some(e),
233 RateLimitError::Other(e) => Some(&**e),
234 }
235 }
236}
237
238#[derive(Clone)]
239struct RateLimiter {
240 state: Arc<RwLock<HashMap<String, (Instant, u32)>>>,
241 config: RateLimitConfig,
242}
243
244impl RateLimiter {
245 fn new(config: RateLimitConfig) -> Self {
246 Self {
247 state: Arc::new(RwLock::new(HashMap::new())),
248 config,
249 }
250 }
251
252 async fn check_rate_limit(&self, key: &str) -> Result<RateLimitInfo, Rejection> {
253 let mut state = self.state.write().await;
254 let now = Instant::now();
255 let current = state.get(key).copied();
256
257 match current {
258 Some((last_request, count)) => {
259 if now.duration_since(last_request) > self.config.window {
260 state.insert(key.to_string(), (now, 1));
262 Ok(self.create_info(self.config.max_requests - 1, now))
263 } else if count >= self.config.max_requests {
264 let retry_after = self.config.window - now.duration_since(last_request);
266 let reset_time = Utc::now() + ChronoDuration::from_std(retry_after).unwrap();
267
268 Err(reject::custom(RateLimitRejection {
269 retry_after,
270 limit: self.config.max_requests,
271 reset_time,
272 retry_after_format: self.config.retry_after_format.clone(),
273 }))
274 } else {
275 state.insert(key.to_string(), (last_request, count + 1));
277 Ok(self.create_info(
278 self.config.max_requests - (count + 1),
279 last_request,
280 ))
281 }
282 }
283 None => {
284 state.insert(key.to_string(), (now, 1));
286 Ok(self.create_info(self.config.max_requests - 1, now))
287 }
288 }
289 }
290
291 fn create_info(&self, remaining: u32, start: Instant) -> RateLimitInfo {
292 let reset_time = start + self.config.window;
293 let retry_after = match self.config.retry_after_format {
294 RetryAfterFormat::HttpDate => {
295 (Utc::now() + ChronoDuration::from_std(self.config.window).unwrap()).to_rfc2822()
296 }
297 RetryAfterFormat::Seconds => self.config.window.as_secs().to_string(),
298 };
299
300 RateLimitInfo {
301 retry_after,
302 limit: self.config.max_requests,
303 remaining,
304 reset_timestamp: (Utc::now() + ChronoDuration::from_std(reset_time.duration_since(start)).unwrap()).timestamp(),
305 retry_after_format: self.config.retry_after_format.clone(),
306 }
307 }
308}
309
310pub fn with_rate_limit(
312 config: RateLimitConfig,
313) -> impl Filter<Extract = (RateLimitInfo,), Error = Rejection> + Clone {
314 let rate_limiter = RateLimiter::new(config);
315
316 warp::filters::addr::remote()
317 .map(move |addr: Option<std::net::SocketAddr>| {
318 (
319 rate_limiter.clone(),
320 addr.map(|a| a.ip().to_string())
321 .unwrap_or_else(|| "unknown".to_string()),
322 )
323 })
324 .and_then(|(rate_limiter, ip): (RateLimiter, String)| async move {
325 rate_limiter.check_rate_limit(&ip).await
326 })
327}
328
329pub fn add_rate_limit_headers(
331 headers: &mut HeaderMap,
332 info: &RateLimitInfo,
333) -> Result<(), RateLimitError> {
334 headers.insert(header::RETRY_AFTER,
335 HeaderValue::from_str(&info.retry_after).map_err(RateLimitError::HeaderError)?);
336 headers.insert(
337 "X-RateLimit-Limit",
338 HeaderValue::from_str(&info.limit.to_string()).map_err(RateLimitError::HeaderError)?,
339 );
340 headers.insert(
341 "X-RateLimit-Remaining",
342 HeaderValue::from_str(&info.remaining.to_string()).map_err(RateLimitError::HeaderError)?,
343 );
344 headers.insert(
345 "X-RateLimit-Reset",
346 HeaderValue::from_str(&info.reset_timestamp.to_string()).map_err(RateLimitError::HeaderError)?,
347 );
348 Ok(())
349}
350
351pub fn get_rate_limit_info(rejection: &RateLimitRejection) -> RateLimitInfo {
353 let retry_after = match rejection.retry_after_format {
354 RetryAfterFormat::HttpDate => rejection.reset_time.to_rfc2822(),
355 RetryAfterFormat::Seconds => rejection.retry_after.as_secs().to_string(),
356 };
357
358 RateLimitInfo {
359 retry_after,
360 limit: rejection.limit,
361 remaining: 0,
362 reset_timestamp: rejection.reset_time.timestamp(),
363 retry_after_format: rejection.retry_after_format.clone(),
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tokio::task::JoinSet;
    use warp::{http::StatusCode, test::request, Filter, Reply};

    /// Parses a fixed socket address for `RequestBuilder::remote_addr`.
    fn addr(text: &str) -> std::net::SocketAddr {
        text.parse().expect("test socket address must be valid")
    }

    /// Builds a route that replies with the remaining quota on success.
    ///
    /// A `RateLimitRejection` is mapped to a 429 carrying the rate-limit headers;
    /// any other rejection becomes a 500.
    fn create_test_route(
        config: RateLimitConfig,
    ) -> impl Filter<Extract = impl Reply, Error = Infallible> + Clone {
        with_rate_limit(config)
            .map(|info: RateLimitInfo| info.remaining.to_string())
            .recover(|rejection: Rejection| async move {
                let response = match rejection.find::<RateLimitRejection>() {
                    Some(rate_limit) => {
                        let info = get_rate_limit_info(rate_limit);
                        let mut resp = warp::reply::with_status(
                            "Rate limit exceeded",
                            StatusCode::TOO_MANY_REQUESTS,
                        )
                        .into_response();
                        add_rate_limit_headers(resp.headers_mut(), &info)
                            .expect("rate-limit headers must be valid header values");
                        resp
                    }
                    None => warp::reply::with_status(
                        "Internal error",
                        StatusCode::INTERNAL_SERVER_ERROR,
                    )
                    .into_response(),
                };
                Ok::<_, Infallible>(response)
            })
    }

    #[test]
    fn test_config_builders() {
        let per_minute = RateLimitConfig::max_per_minute(60);
        assert_eq!(per_minute.window, Duration::from_secs(60));
        assert_eq!(per_minute.max_requests, 60);
        assert_eq!(per_minute.retry_after_format, RetryAfterFormat::HttpDate);

        let custom = RateLimitConfig::max_per_window(30, 120);
        assert_eq!(custom.window, Duration::from_secs(120));
        assert_eq!(custom.max_requests, 30);
        assert_eq!(custom.retry_after_format, RetryAfterFormat::HttpDate);

        let default = RateLimitConfig::default();
        assert_eq!(default.window, Duration::from_secs(60));
        assert_eq!(default.max_requests, 60);
        assert_eq!(default.retry_after_format, RetryAfterFormat::HttpDate);
    }

    #[tokio::test]
    async fn test_comprehensive_rate_limit_rejection() {
        let config = RateLimitConfig {
            max_requests: 1,
            window: Duration::from_secs(5),
            retry_after_format: RetryAfterFormat::Seconds,
        };
        let route = create_test_route(config);

        let first = request()
            .remote_addr(addr("127.0.0.1:1234"))
            .reply(&route)
            .await;
        assert_eq!(first.status(), StatusCode::OK);
        assert_eq!(first.body(), "0");

        let second = request()
            .remote_addr(addr("127.0.0.1:1234"))
            .reply(&route)
            .await;
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);

        let headers = second.headers();
        assert!(headers.contains_key(header::RETRY_AFTER));
        for name in ["X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"] {
            assert!(headers.contains_key(name), "missing header {name}");
        }
        assert_eq!(headers.get("X-RateLimit-Limit").unwrap(), "1");
        assert_eq!(headers.get("X-RateLimit-Remaining").unwrap(), "0");

        let retry_after: u64 = headers
            .get(header::RETRY_AFTER)
            .unwrap()
            .to_str()
            .unwrap()
            .parse()
            .expect("Retry-After must be whole seconds");
        assert!((1..=5).contains(&retry_after), "unexpected Retry-After {retry_after}");
    }

    #[tokio::test]
    async fn test_retry_after_formats() {
        let http_date_route = create_test_route(RateLimitConfig {
            max_requests: 1,
            window: Duration::from_secs(15),
            retry_after_format: RetryAfterFormat::HttpDate,
        });
        let _ = request()
            .remote_addr(addr("127.0.0.1:1234"))
            .reply(&http_date_route)
            .await;
        let resp_http = request()
            .remote_addr(addr("127.0.0.1:1234"))
            .reply(&http_date_route)
            .await;
        let retry_after_http = resp_http
            .headers()
            .get(header::RETRY_AFTER)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(!retry_after_http.is_empty());
        // A date, not a bare number of seconds.
        assert!(retry_after_http.parse::<u64>().is_err());

        let seconds_route = create_test_route(RateLimitConfig {
            max_requests: 1,
            window: Duration::from_secs(5),
            retry_after_format: RetryAfterFormat::Seconds,
        });
        let _ = request()
            .remote_addr(addr("127.0.0.2:1234"))
            .reply(&seconds_route)
            .await;
        let resp_sec = request()
            .remote_addr(addr("127.0.0.2:1234"))
            .reply(&seconds_route)
            .await;
        let retry_after_sec = resp_sec
            .headers()
            .get(header::RETRY_AFTER)
            .unwrap()
            .to_str()
            .unwrap()
            .parse::<u64>()
            .expect("Retry-After must be whole seconds");
        assert!(retry_after_sec <= 5);
    }

    #[tokio::test]
    async fn test_limits_are_tracked_per_client_ip() {
        let route = create_test_route(RateLimitConfig::max_per_window(1, 60));

        let first = request().remote_addr(addr("127.0.0.1:1000")).reply(&route).await;
        assert_eq!(first.status(), StatusCode::OK);

        // Same IP on a different port shares the same bucket.
        let same_ip = request().remote_addr(addr("127.0.0.1:2000")).reply(&route).await;
        assert_eq!(same_ip.status(), StatusCode::TOO_MANY_REQUESTS);

        // A different IP has its own, untouched quota.
        let other_ip = request().remote_addr(addr("127.0.0.2:1000")).reply(&route).await;
        assert_eq!(other_ip.status(), StatusCode::OK);
        assert_eq!(other_ip.body(), "0");
    }

    #[test]
    fn test_rate_limit_info_extraction() {
        let now = Utc::now();
        let rejection = RateLimitRejection {
            retry_after: Duration::from_secs(60),
            limit: 100,
            reset_time: now,
            retry_after_format: RetryAfterFormat::Seconds,
        };

        let info = get_rate_limit_info(&rejection);
        assert_eq!(info.limit, 100);
        assert_eq!(info.remaining, 0);
        assert_eq!(info.reset_timestamp, now.timestamp());
        assert_eq!(info.retry_after, "60");

        let rejection_http = RateLimitRejection {
            retry_after: Duration::from_secs(60),
            limit: 100,
            reset_time: now,
            retry_after_format: RetryAfterFormat::HttpDate,
        };
        let info_http = get_rate_limit_info(&rejection_http);
        assert!(!info_http.retry_after.is_empty());
    }

    #[tokio::test]
    async fn test_concurrent_requests() {
        // A long window keeps the outcome deterministic. With a 1 s window, a slow
        // scheduler could let the window reset mid-test and admit extra requests.
        let config = RateLimitConfig {
            max_requests: 5,
            window: Duration::from_secs(60),
            retry_after_format: RetryAfterFormat::Seconds,
        };
        let route = create_test_route(config);
        let mut set = JoinSet::new();

        for _ in 0..10 {
            let route = route.clone();
            set.spawn(async move {
                request()
                    .remote_addr(addr("127.0.0.1:1234"))
                    .reply(&route)
                    .await
            });
        }

        let mut success_count = 0;
        let mut rate_limited_count = 0;
        while let Some(joined) = set.join_next().await {
            // Surface task panics instead of silently ending the loop early.
            let resp = joined.expect("request task panicked");
            match resp.status() {
                StatusCode::OK => success_count += 1,
                StatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1,
                other => panic!("Unexpected response status: {other}"),
            }
        }

        assert_eq!(success_count, 5, "Expected exactly 5 successful requests");
        assert_eq!(rate_limited_count, 5, "Expected exactly 5 rate-limited requests");
    }

    #[test]
    fn test_invalid_header_value_handling() {
        let mut headers = HeaderMap::new();
        let invalid_info = RateLimitInfo {
            retry_after: "invalid\u{0000}characters".to_string(),
            limit: 100,
            remaining: 50,
            reset_timestamp: 1234567890,
            retry_after_format: RetryAfterFormat::Seconds,
        };

        let result = add_rate_limit_headers(&mut headers, &invalid_info);
        assert!(matches!(result, Err(RateLimitError::HeaderError(_))));
    }
}