// bolt_web/middleware/ratelimter.rs

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use tokio::sync::Mutex;

use crate::request::RequestBody;
use crate::response::ResponseWriter;
use crate::types::Middleware;

/// Settings for [`RateLimiter`].
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Maximum number of requests counted for a single client key within one window;
    /// requests beyond this are rejected.
    pub requests: u32,
    /// Length of the counting window, in seconds.
    pub per_seconds: u64,
}

16#[derive(Clone)]
17pub struct RateLimiter {
18    config: Arc<RateLimiterConfig>,
19    state: Arc<Mutex<HashMap<String, (u32, Instant)>>>,
20}
21
22#[allow(dead_code)]
23impl RateLimiter {
24    pub fn new(config: RateLimiterConfig) -> Self {
25        Self {
26            config: Arc::new(config),
27            state: Arc::new(Mutex::new(HashMap::new())),
28        }
29    }
30}
31
32#[async_trait]
33impl Middleware for RateLimiter {
34    async fn run(&self, req: &mut RequestBody, res: &mut ResponseWriter) {
35        let ip = req
36            .headers()
37            .get("x-forwarded-for")
38            .or_else(|| req.headers().get("host"))
39            .and_then(|h| h.to_str().ok())
40            .unwrap_or("unknown")
41            .to_string();
42
43        let mut state = self.state.lock().await;
44        let now = Instant::now();
45        let (count, last_reset) = state.entry(ip.clone()).or_insert((0, now));
46
47        if now.duration_since(*last_reset).as_secs() > self.config.per_seconds {
48            *count = 0;
49            *last_reset = now;
50        }
51
52        if *count >= self.config.requests {
53            res.status(429).send("Too Many Requests");
54        } else {
55            *count += 1;
56        }
57    }
58}