Skip to main content

brainwires_agent_network/middleware/
rate_limit.rs

1use async_trait::async_trait;
2use brainwires_mcp::{JsonRpcError, JsonRpcRequest};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::Mutex;
6use tokio::time::Instant;
7
8use super::{Middleware, MiddlewareResult};
9use crate::connection::RequestContext;
10
11struct RateLimitBucket {
12    tokens: f64,
13    last_refill: Instant,
14}
15
16/// Token-bucket rate limiting middleware.
17pub struct RateLimitMiddleware {
18    max_requests_per_second: f64,
19    per_tool_limits: HashMap<String, f64>,
20    buckets: Arc<Mutex<HashMap<String, RateLimitBucket>>>,
21}
22
23impl RateLimitMiddleware {
24    /// Create a new rate limiter with a global requests-per-second limit.
25    pub fn new(max_requests_per_second: f64) -> Self {
26        Self {
27            max_requests_per_second,
28            per_tool_limits: HashMap::new(),
29            buckets: Arc::new(Mutex::new(HashMap::new())),
30        }
31    }
32
33    /// Set a per-tool rate limit override.
34    pub fn with_tool_limit(mut self, tool_name: &str, limit: f64) -> Self {
35        self.per_tool_limits.insert(tool_name.to_string(), limit);
36        self
37    }
38
39    fn get_limit(&self, key: &str) -> f64 {
40        self.per_tool_limits
41            .get(key)
42            .copied()
43            .unwrap_or(self.max_requests_per_second)
44    }
45}
46
47#[async_trait]
48impl Middleware for RateLimitMiddleware {
49    async fn process_request(
50        &self,
51        request: &JsonRpcRequest,
52        _ctx: &mut RequestContext,
53    ) -> MiddlewareResult {
54        // Only rate-limit tools/call
55        if request.method != "tools/call" {
56            return MiddlewareResult::Continue;
57        }
58
59        let tool_name = request
60            .params
61            .as_ref()
62            .and_then(|p| p.get("name"))
63            .and_then(|n| n.as_str())
64            .unwrap_or("unknown");
65
66        let limit = self.get_limit(tool_name);
67        let key = format!("tool:{tool_name}");
68
69        let mut buckets = self.buckets.lock().await;
70        let bucket = buckets.entry(key).or_insert(RateLimitBucket {
71            tokens: limit,
72            last_refill: Instant::now(),
73        });
74
75        // Token bucket refill
76        let now = Instant::now();
77        let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
78        bucket.tokens = (bucket.tokens + elapsed * limit).min(limit);
79        bucket.last_refill = now;
80
81        if bucket.tokens >= 1.0 {
82            bucket.tokens -= 1.0;
83            MiddlewareResult::Continue
84        } else {
85            MiddlewareResult::Reject(JsonRpcError {
86                code: -32002,
87                message: format!("Rate limited: too many requests for tool '{tool_name}'"),
88                data: None,
89            })
90        }
91    }
92}