bolt_web/middleware/
ratelimter.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::Instant;
5use tokio::sync::Mutex;
6
7use crate::request::RequestBody;
8use crate::response::ResponseWriter;
9use crate::types::Middleware;
10
/// Settings for the per-client request limit.
///
/// A client is allowed at most `requests` requests inside one counting
/// window whose length is `per_seconds` seconds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimiterConfig {
    /// Maximum number of requests counted for one client within a window.
    pub requests: u32,
    /// Length of the counting window, in seconds.
    pub per_seconds: u64,
}
15
16#[derive(Clone)]
17pub struct RateLimiter {
18 config: Arc<RateLimiterConfig>,
19 state: Arc<Mutex<HashMap<String, (u32, Instant)>>>,
20}
21
22impl RateLimiter {
23 pub fn new(config: RateLimiterConfig) -> Self {
24 Self {
25 config: Arc::new(config),
26 state: Arc::new(Mutex::new(HashMap::new())),
27 }
28 }
29}
30
31#[async_trait]
32impl Middleware for RateLimiter {
33 async fn run(&self, req: &mut RequestBody, res: &mut ResponseWriter) {
34 let ip = req
35 .headers()
36 .get("x-forwarded-for")
37 .or_else(|| req.headers().get("host"))
38 .and_then(|h| h.to_str().ok())
39 .unwrap_or("unknown")
40 .to_string();
41
42 let mut state = self.state.lock().await;
43 let now = Instant::now();
44 let (count, last_reset) = state.entry(ip.clone()).or_insert((0, now));
45
46 if now.duration_since(*last_reset).as_secs() > self.config.per_seconds {
47 *count = 0;
48 *last_reset = now;
49 }
50
51 if *count >= self.config.requests {
52 res.status(429).send("Too Many Requests");
53 } else {
54 *count += 1;
55 }
56 }
57}