// brainwires_proxy/middleware/rate_limit.rs
1//! Token-bucket rate limiter middleware.
2
3use crate::error::ProxyResult;
4use crate::middleware::{LayerAction, ProxyLayer};
5use crate::types::{ProxyRequest, ProxyResponse};
6use http::StatusCode;
7use std::sync::Arc;
8use std::time::Instant;
9use tokio::sync::Mutex;
10
/// Token-bucket rate limiter.
///
/// Holds shared `TokenBucket` state behind an async mutex so a single
/// limiter instance can be consulted from concurrent request handlers.
pub struct RateLimitLayer {
    // `tokio::sync::Mutex` (not `std`) because the bucket is locked from
    // the async `on_request` path.
    bucket: Arc<Mutex<TokenBucket>>,
}
15
/// Internal token-bucket state: a fractional token count that refills
/// continuously at `refill_rate` tokens per second, capped at `capacity`.
struct TokenBucket {
    tokens: f64,
    capacity: f64,
    refill_rate: f64, // tokens per second
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a bucket that starts full, so the initial burst is allowed.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Refill according to elapsed wall time, then attempt to spend one
    /// token. Returns `true` when a token was available.
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let refilled =
            self.tokens + now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        // Clamp so idle periods never bank more than one full burst.
        self.tokens = refilled.min(self.capacity);
        self.last_refill = now;

        let has_token = self.tokens >= 1.0;
        if has_token {
            self.tokens -= 1.0;
        }
        has_token
    }
}
47
48impl RateLimitLayer {
49    /// Create a rate limiter allowing `capacity` burst requests
50    /// with a sustained rate of `per_second` requests/sec.
51    pub fn new(capacity: f64, per_second: f64) -> Self {
52        Self {
53            bucket: Arc::new(Mutex::new(TokenBucket::new(capacity, per_second))),
54        }
55    }
56}
57
58#[async_trait::async_trait]
59impl ProxyLayer for RateLimitLayer {
60    async fn on_request(&self, request: ProxyRequest) -> ProxyResult<LayerAction> {
61        let mut bucket = self.bucket.lock().await;
62        if bucket.try_acquire() {
63            Ok(LayerAction::Forward(request))
64        } else {
65            tracing::warn!(request_id = %request.id, "rate limited");
66            Ok(LayerAction::Respond(
67                ProxyResponse::for_request(request.id, StatusCode::TOO_MANY_REQUESTS)
68                    .with_body("Rate limit exceeded"),
69            ))
70        }
71    }
72
73    fn name(&self) -> &str {
74        "rate_limit"
75    }
76}
77
78#[cfg(test)]
79mod tests {
80    use super::*;
81    use http::Method;
82
83    fn make_request() -> ProxyRequest {
84        ProxyRequest::new(Method::GET, "/test".parse().unwrap())
85    }
86
87    #[tokio::test]
88    async fn allows_within_capacity() {
89        let limiter = RateLimitLayer::new(3.0, 1.0);
90        // First 3 should pass (burst capacity)
91        for _ in 0..3 {
92            let result = limiter.on_request(make_request()).await.unwrap();
93            assert!(matches!(result, LayerAction::Forward(_)));
94        }
95    }
96
97    #[tokio::test]
98    async fn rejects_over_capacity() {
99        let limiter = RateLimitLayer::new(2.0, 0.0); // 2 burst, no refill
100        // Consume both tokens
101        limiter.on_request(make_request()).await.unwrap();
102        limiter.on_request(make_request()).await.unwrap();
103
104        // Third should be rejected
105        let result = limiter.on_request(make_request()).await.unwrap();
106        match result {
107            LayerAction::Respond(resp) => {
108                assert_eq!(resp.status, StatusCode::TOO_MANY_REQUESTS);
109            }
110            LayerAction::Forward(_) => panic!("should have been rate limited"),
111        }
112    }
113
114    #[tokio::test]
115    async fn refills_over_time() {
116        let limiter = RateLimitLayer::new(1.0, 100.0); // 1 burst, 100/sec refill
117        // Consume token
118        limiter.on_request(make_request()).await.unwrap();
119
120        // Wait for refill
121        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
122
123        // Should have refilled
124        let result = limiter.on_request(make_request()).await.unwrap();
125        assert!(matches!(result, LayerAction::Forward(_)));
126    }
127}