bolt_web/middleware/ratelimter.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::Instant;
5use tokio::sync::Mutex;
6
7use crate::request::RequestBody;
8use crate::response::ResponseWriter;
9use crate::types::Middleware;
10
/// Settings for [`RateLimiter`]: each client key may make at most `requests`
/// requests within one fixed window of `per_seconds` seconds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimiterConfig {
    /// Maximum number of requests accepted per key within a single window.
    pub requests: u32,
    /// Length of the fixed window, in seconds.
    pub per_seconds: u64,
}
15
16#[derive(Clone)]
17pub struct RateLimiter {
18    config: Arc<RateLimiterConfig>,
19    state: Arc<Mutex<HashMap<String, (u32, Instant)>>>,
20}
21
22impl RateLimiter {
23    pub fn new(config: RateLimiterConfig) -> Self {
24        Self {
25            config: Arc::new(config),
26            state: Arc::new(Mutex::new(HashMap::new())),
27        }
28    }
29}
30
31#[async_trait]
32impl Middleware for RateLimiter {
33    async fn run(&self, req: &mut RequestBody, res: &mut ResponseWriter) {
34        let ip = req
35            .headers()
36            .get("x-forwarded-for")
37            .or_else(|| req.headers().get("host"))
38            .and_then(|h| h.to_str().ok())
39            .unwrap_or("unknown")
40            .to_string();
41
42        let mut state = self.state.lock().await;
43        let now = Instant::now();
44        let (count, last_reset) = state.entry(ip.clone()).or_insert((0, now));
45
46        if now.duration_since(*last_reset).as_secs() > self.config.per_seconds {
47            *count = 0;
48            *last_reset = now;
49        }
50
51        if *count >= self.config.requests {
52            res.status(429).send("Too Many Requests");
53        } else {
54            *count += 1;
55        }
56    }
57}