brainwires_agent_network/middleware/
rate_limit.rs1use async_trait::async_trait;
2use brainwires_mcp::{JsonRpcError, JsonRpcRequest};
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::Mutex;
6use tokio::time::Instant;
7
8use super::{Middleware, MiddlewareResult};
9use crate::connection::RequestContext;
10
/// Token-bucket state for a single rate-limit key.
struct RateLimitBucket {
    /// The instant at which `tokens` was last topped up.
    last_refill: Instant,
    /// Tokens currently available; each admitted request spends one.
    tokens: f64,
}
15
16pub struct RateLimitMiddleware {
18 max_requests_per_second: f64,
19 per_tool_limits: HashMap<String, f64>,
20 buckets: Arc<Mutex<HashMap<String, RateLimitBucket>>>,
21}
22
23impl RateLimitMiddleware {
24 pub fn new(max_requests_per_second: f64) -> Self {
26 Self {
27 max_requests_per_second,
28 per_tool_limits: HashMap::new(),
29 buckets: Arc::new(Mutex::new(HashMap::new())),
30 }
31 }
32
33 pub fn with_tool_limit(mut self, tool_name: &str, limit: f64) -> Self {
35 self.per_tool_limits.insert(tool_name.to_string(), limit);
36 self
37 }
38
39 fn get_limit(&self, key: &str) -> f64 {
40 self.per_tool_limits
41 .get(key)
42 .copied()
43 .unwrap_or(self.max_requests_per_second)
44 }
45}
46
47#[async_trait]
48impl Middleware for RateLimitMiddleware {
49 async fn process_request(
50 &self,
51 request: &JsonRpcRequest,
52 _ctx: &mut RequestContext,
53 ) -> MiddlewareResult {
54 if request.method != "tools/call" {
56 return MiddlewareResult::Continue;
57 }
58
59 let tool_name = request
60 .params
61 .as_ref()
62 .and_then(|p| p.get("name"))
63 .and_then(|n| n.as_str())
64 .unwrap_or("unknown");
65
66 let limit = self.get_limit(tool_name);
67 let key = format!("tool:{tool_name}");
68
69 let mut buckets = self.buckets.lock().await;
70 let bucket = buckets.entry(key).or_insert(RateLimitBucket {
71 tokens: limit,
72 last_refill: Instant::now(),
73 });
74
75 let now = Instant::now();
77 let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
78 bucket.tokens = (bucket.tokens + elapsed * limit).min(limit);
79 bucket.last_refill = now;
80
81 if bucket.tokens >= 1.0 {
82 bucket.tokens -= 1.0;
83 MiddlewareResult::Continue
84 } else {
85 MiddlewareResult::Reject(JsonRpcError {
86 code: -32002,
87 message: format!("Rate limited: too many requests for tool '{tool_name}'"),
88 data: None,
89 })
90 }
91 }
92}