karbon_framework/http/middleware/
rate_limit.rs1use axum::{
2 extract::Request,
3 http::StatusCode,
4 response::{IntoResponse, Response},
5};
6use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::Mutex;
11use tower::{Layer, Service};
12
13#[derive(Clone)]
30pub struct RateLimitLayer {
31 max_requests: u32,
32 window: Duration,
33 store: Arc<Mutex<HashMap<IpAddr, (u32, Instant)>>>,
34}
35
36impl RateLimitLayer {
37 pub fn new(max_requests: u32, window: Duration) -> Self {
38 Self {
39 max_requests,
40 window,
41 store: Arc::new(Mutex::new(HashMap::new())),
42 }
43 }
44
45 pub fn per_minute(max: u32) -> Self {
47 Self::new(max, Duration::from_secs(60))
48 }
49}
50
51impl<S> Layer<S> for RateLimitLayer {
52 type Service = RateLimitService<S>;
53
54 fn layer(&self, inner: S) -> Self::Service {
55 RateLimitService {
56 inner,
57 max_requests: self.max_requests,
58 window: self.window,
59 store: self.store.clone(),
60 }
61 }
62}
63
64#[derive(Clone)]
65pub struct RateLimitService<S> {
66 inner: S,
67 max_requests: u32,
68 window: Duration,
69 store: Arc<Mutex<HashMap<IpAddr, (u32, Instant)>>>,
70}
71
72impl<S> Service<Request> for RateLimitService<S>
73where
74 S: Service<Request, Response = Response> + Clone + Send + 'static,
75 S::Future: Send,
76{
77 type Response = Response;
78 type Error = S::Error;
79 type Future = std::pin::Pin<
80 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
81 >;
82
83 fn poll_ready(
84 &mut self,
85 cx: &mut std::task::Context<'_>,
86 ) -> std::task::Poll<Result<(), Self::Error>> {
87 self.inner.poll_ready(cx)
88 }
89
90 fn call(&mut self, request: Request) -> Self::Future {
91 let max = self.max_requests;
92 let window = self.window;
93 let store = self.store.clone();
94 let mut inner = self.inner.clone();
95
96 Box::pin(async move {
97 let ip: IpAddr = crate::util::HttpHelper::client_ip(
99 request.headers(),
100 "127.0.0.1".parse().unwrap(),
101 );
102
103 let mut map = store.lock().await;
104 let now = Instant::now();
105
106 let (count, started) = map.entry(ip).or_insert((0, now));
107
108 if now.duration_since(*started) > window {
110 *count = 0;
111 *started = now;
112 }
113
114 *count += 1;
115
116 if *count > max {
117 drop(map);
118 return Ok((
119 StatusCode::TOO_MANY_REQUESTS,
120 "Rate limit exceeded",
121 )
122 .into_response());
123 }
124 drop(map);
125
126 inner.call(request).await
127 })
128 }
129}