brainwires_proxy/middleware/
rate_limit.rs1use crate::error::ProxyResult;
4use crate::middleware::{LayerAction, ProxyLayer};
5use crate::types::{ProxyRequest, ProxyResponse};
6use http::StatusCode;
7use std::sync::Arc;
8use std::time::Instant;
9use tokio::sync::Mutex;
10
11pub struct RateLimitLayer {
13 bucket: Arc<Mutex<TokenBucket>>,
14}
15
/// Token-bucket state: a pool of at most `capacity` tokens that is topped up
/// at `refill_rate` tokens per second of elapsed time.
struct TokenBucket {
    /// Tokens currently available; fractional amounts accumulate between calls.
    tokens: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    ///
    /// A negative or NaN `capacity` or `refill_rate` is treated as `0.0`, so a
    /// misconfigured bucket can never start with a negative balance or lose
    /// tokens as time passes. Infinite values are left as they are.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        // `f64::max` returns the non-NaN operand, so NaN is mapped to 0.0 too.
        let capacity = capacity.max(0.0);
        let refill_rate = refill_rate.max(0.0);
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Refills the bucket for the time elapsed since the previous call, then
    /// tries to spend one token.
    ///
    /// Returns `true` if a whole token was available and has been consumed,
    /// `false` otherwise (the balance is left untouched in that case).
    fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        // Credit the elapsed time, never exceeding the configured capacity.
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.capacity);
        self.last_refill = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
47
48impl RateLimitLayer {
49 pub fn new(capacity: f64, per_second: f64) -> Self {
52 Self {
53 bucket: Arc::new(Mutex::new(TokenBucket::new(capacity, per_second))),
54 }
55 }
56}
57
58#[async_trait::async_trait]
59impl ProxyLayer for RateLimitLayer {
60 async fn on_request(&self, request: ProxyRequest) -> ProxyResult<LayerAction> {
61 let mut bucket = self.bucket.lock().await;
62 if bucket.try_acquire() {
63 Ok(LayerAction::Forward(request))
64 } else {
65 tracing::warn!(request_id = %request.id, "rate limited");
66 Ok(LayerAction::Respond(
67 ProxyResponse::for_request(request.id, StatusCode::TOO_MANY_REQUESTS)
68 .with_body("Rate limit exceeded"),
69 ))
70 }
71 }
72
73 fn name(&self) -> &str {
74 "rate_limit"
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;

    /// Builds a minimal `GET /test` request for driving the limiter.
    fn make_request() -> ProxyRequest {
        ProxyRequest::new(Method::GET, "/test".parse().unwrap())
    }

    /// A bucket with three tokens forwards the first three requests.
    #[tokio::test]
    async fn allows_within_capacity() {
        let limiter = RateLimitLayer::new(3.0, 1.0);
        for _attempt in 1..=3 {
            let action = limiter.on_request(make_request()).await.unwrap();
            assert!(matches!(action, LayerAction::Forward(_)));
        }
    }

    /// Once both tokens are spent, the next request is answered with 429.
    #[tokio::test]
    async fn rejects_over_capacity() {
        // Refill rate 0.0: spent tokens are never replenished.
        let limiter = RateLimitLayer::new(2.0, 0.0);

        // Use up the two available tokens.
        limiter.on_request(make_request()).await.unwrap();
        limiter.on_request(make_request()).await.unwrap();

        let outcome = limiter.on_request(make_request()).await.unwrap();
        if let LayerAction::Respond(resp) = outcome {
            assert_eq!(resp.status, StatusCode::TOO_MANY_REQUESTS);
        } else {
            panic!("should have been rate limited");
        }
    }

    /// After waiting, a drained bucket accepts a request again.
    #[tokio::test]
    async fn refills_over_time() {
        // 100 tokens per second, so 50 ms is ample time to regain one token.
        let limiter = RateLimitLayer::new(1.0, 100.0);

        // Drain the single token.
        limiter.on_request(make_request()).await.unwrap();

        let pause = std::time::Duration::from_millis(50);
        tokio::time::sleep(pause).await;

        let action = limiter.on_request(make_request()).await.unwrap();
        assert!(matches!(action, LayerAction::Forward(_)));
    }
}