// ggen_api/middleware/rate_limit.rs
use axum::{
    extract::{ConnectInfo, Request},
    http::StatusCode,
    middleware::Next,
    response::{IntoResponse, Response},
};
use moka::future::Cache;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex, PoisonError};
use std::time::{Duration, Instant};
12
13pub struct RateLimiter {
15 request_counts: Arc<Cache<String, u32>>,
17 max_requests_per_minute: u32,
18}
19
20impl RateLimiter {
21 pub fn new(max_requests_per_minute: u32) -> Self {
22 Self {
23 request_counts: Arc::new(
24 Cache::builder()
25 .time_to_live(Duration::from_secs(60))
26 .build(),
27 ),
28 max_requests_per_minute,
29 }
30 }
31
32 pub async fn check_rate_limit(&self, ip: &str) -> Result<(), RateLimitError> {
33 let current = self
34 .request_counts
35 .try_get_with(ip.to_string(), async { Ok::<u32, std::io::Error>(0) })
36 .await
37 .unwrap_or(0);
38
39 if current >= self.max_requests_per_minute {
40 return Err(RateLimitError::TooManyRequests);
41 }
42
43 self.request_counts
44 .insert(ip.to_string(), current + 1)
45 .await;
46
47 Ok(())
48 }
49}
50
/// Errors produced by [`RateLimiter::check_rate_limit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitError {
    /// The caller has used up its request budget for the current window.
    TooManyRequests,
}

impl std::fmt::Display for RateLimitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TooManyRequests => f.write_str("Rate limit exceeded"),
        }
    }
}

impl std::error::Error for RateLimitError {}
55
56impl IntoResponse for RateLimitError {
57 fn into_response(self) -> Response {
58 match self {
59 RateLimitError::TooManyRequests => (
60 StatusCode::TOO_MANY_REQUESTS,
61 "Rate limit exceeded",
62 )
63 .into_response(),
64 }
65 }
66}
67
68pub async fn rate_limit(
70 request: Request,
71 next: Next,
72) -> Response {
73 next.run(request).await
78}