// bolt_web/middleware/ratelimter.rs
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use tokio::sync::Mutex;

use crate::request::RequestBody;
use crate::response::ResponseWriter;
use crate::types::Middleware;

/// Limits enforced by the rate-limiting middleware.
///
/// A client may make at most `requests` requests in each window of
/// `per_seconds` seconds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimiterConfig {
    /// Maximum number of requests accepted per client within one window;
    /// `0` rejects every request.
    pub requests: u32,
    /// Length of one counting window, in seconds.
    pub per_seconds: u64,
}

16#[derive(Clone)]
17pub struct RateLimiter {
18 config: Arc<RateLimiterConfig>,
19 state: Arc<Mutex<HashMap<String, (u32, Instant)>>>,
20}
21
22#[allow(dead_code)]
23impl RateLimiter {
24 pub fn new(config: RateLimiterConfig) -> Self {
25 Self {
26 config: Arc::new(config),
27 state: Arc::new(Mutex::new(HashMap::new())),
28 }
29 }
30}
31
32#[async_trait]
33impl Middleware for RateLimiter {
34 async fn run(&self, req: &mut RequestBody, res: &mut ResponseWriter) {
35 let ip = req
36 .headers()
37 .get("x-forwarded-for")
38 .or_else(|| req.headers().get("host"))
39 .and_then(|h| h.to_str().ok())
40 .unwrap_or("unknown")
41 .to_string();
42
43 let mut state = self.state.lock().await;
44 let now = Instant::now();
45 let (count, last_reset) = state.entry(ip.clone()).or_insert((0, now));
46
47 if now.duration_since(*last_reset).as_secs() > self.config.per_seconds {
48 *count = 0;
49 *last_reset = now;
50 }
51
52 if *count >= self.config.requests {
53 res.status(429).send("Too Many Requests");
54 } else {
55 *count += 1;
56 }
57 }
58}