mcp_host/server/
middleware.rs1use std::sync::{Arc, RwLock};
6
7use throttle_machines::gcra;
8
9use crate::protocol::errors::McpError;
10use crate::server::handler::RequestContext;
11
12pub type MiddlewareFn =
16 Arc<dyn Fn(RequestContext) -> Result<RequestContext, McpError> + Send + Sync>;
17
18#[derive(Clone)]
20pub struct MiddlewareChain {
21 middleware: Vec<MiddlewareFn>,
22}
23
24impl MiddlewareChain {
25 pub fn new() -> Self {
27 Self {
28 middleware: Vec::new(),
29 }
30 }
31
32 pub fn add(&mut self, middleware: MiddlewareFn) {
34 self.middleware.push(middleware);
35 }
36
37 pub fn process(&self, mut ctx: RequestContext) -> Result<RequestContext, McpError> {
43 for middleware in &self.middleware {
44 ctx = middleware(ctx)?;
45 }
46 Ok(ctx)
47 }
48
49 pub fn len(&self) -> usize {
51 self.middleware.len()
52 }
53
54 pub fn is_empty(&self) -> bool {
56 self.middleware.is_empty()
57 }
58}
59
60impl Default for MiddlewareChain {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66pub fn logging_middleware() -> MiddlewareFn {
68 Arc::new(|ctx: RequestContext| {
69 tracing::debug!(
70 session_id = %ctx.session.id,
71 method = %ctx.request.method,
72 "Processing request"
73 );
74 Ok(ctx)
75 })
76}
77
78pub fn validation_middleware() -> MiddlewareFn {
80 Arc::new(|ctx: RequestContext| {
81 if ctx.request.method.is_empty() {
83 return Err(McpError::validation(
84 "empty_method",
85 "Method name cannot be empty",
86 ));
87 }
88 Ok(ctx)
89 })
90}
91
/// Tuning parameters for the rate limiter.
///
/// No validation is performed: the derived quantities are plain floating-point
/// arithmetic on these two fields, so a zero rate yields an infinite emission
/// interval.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Sustained request rate; its reciprocal is the emission interval.
    pub requests_per_second: f64,
    /// Burst size; scales the emission interval into the delay tolerance.
    pub burst_capacity: usize,
}

impl Default for RateLimiterConfig {
    /// 100 requests per second with a burst capacity of 10.
    fn default() -> Self {
        Self::new(100.0, 10)
    }
}

impl RateLimiterConfig {
    /// Creates a configuration from an explicit rate and burst capacity.
    pub fn new(requests_per_second: f64, burst_capacity: usize) -> Self {
        Self {
            requests_per_second,
            burst_capacity,
        }
    }

    /// Seconds between two evenly spaced requests: `1 / requests_per_second`.
    pub fn emission_interval(&self) -> f64 {
        self.requests_per_second.recip()
    }

    /// Seconds of slack granted for bursts:
    /// `burst_capacity * emission_interval`.
    pub fn delay_tolerance(&self) -> f64 {
        let interval = self.emission_interval();
        interval * self.burst_capacity as f64
    }
}
129
130#[derive(Clone)]
132pub struct RateLimiter {
133 config: RateLimiterConfig,
134 tat: Arc<RwLock<f64>>,
136}
137
138impl RateLimiter {
139 pub fn new(config: RateLimiterConfig) -> Self {
141 Self {
142 config,
143 tat: Arc::new(RwLock::new(0.0)),
144 }
145 }
146
147 pub fn default_limiter() -> Self {
149 Self::new(RateLimiterConfig::default())
150 }
151
152 pub fn check(&self) -> Result<(), f64> {
156 let now = std::time::SystemTime::now()
157 .duration_since(std::time::UNIX_EPOCH)
158 .map(|d| d.as_secs_f64())
159 .unwrap_or(0.0);
160
161 let emission_interval = self.config.emission_interval();
162 let delay_tolerance = self.config.delay_tolerance();
163
164 let mut tat_guard = self.tat.write().map_err(|_| 1.0)?; let result = gcra::check(*tat_guard, now, emission_interval, delay_tolerance);
167
168 if result.allowed {
169 *tat_guard = result.new_tat;
170 Ok(())
171 } else {
172 Err(result.retry_after)
173 }
174 }
175
176 pub fn remaining_capacity(&self) -> usize {
178 let now = std::time::SystemTime::now()
179 .duration_since(std::time::UNIX_EPOCH)
180 .map(|d| d.as_secs_f64())
181 .unwrap_or(0.0);
182
183 let delay_tolerance = self.config.delay_tolerance();
184 let emission_interval = self.config.emission_interval();
185
186 let tat = self.tat.read().map(|t| *t).unwrap_or(0.0);
187
188 let result = gcra::peek(tat, now, delay_tolerance);
189
190 if result.allowed {
191 let remaining_tolerance = delay_tolerance - (result.new_tat - now).max(0.0);
193 (remaining_tolerance / emission_interval).floor() as usize + 1
194 } else {
195 0
196 }
197 }
198}
199
200pub fn rate_limiter_middleware(limiter: Arc<RateLimiter>) -> MiddlewareFn {
205 Arc::new(move |ctx: RequestContext| match limiter.check() {
206 Ok(()) => {
207 tracing::trace!(
208 session_id = %ctx.session.id,
209 method = %ctx.request.method,
210 "Request allowed by rate limiter"
211 );
212 Ok(ctx)
213 }
214 Err(retry_after) => {
215 tracing::warn!(
216 session_id = %ctx.session.id,
217 method = %ctx.request.method,
218 retry_after_secs = %retry_after,
219 "Request rate limited"
220 );
221 Err(McpError::rate_limited(format!(
222 "Rate limit exceeded. Retry after {:.2} seconds",
223 retry_after
224 )))
225 }
226 })
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::JsonRpcRequest;
    use crate::server::session::Session;
    use serde_json::Value;

    /// Builds a context from a fresh session and a JSON-RPC 2.0 request with
    /// id `1`, no params, and the given method name.
    fn context_for(method: &str) -> RequestContext {
        let session = Session::new();
        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(Value::Number(1.into())),
            method: method.to_string(),
            params: None,
        };
        RequestContext::new(session, request)
    }

    #[test]
    fn test_empty_chain() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }

    #[test]
    fn test_add_middleware() {
        let mut chain = MiddlewareChain::new();

        let pass_through: MiddlewareFn = Arc::new(|ctx: RequestContext| Ok(ctx));
        chain.add(pass_through);

        assert!(!chain.is_empty());
        assert_eq!(chain.len(), 1);
    }

    #[test]
    fn test_process_middleware() {
        let mut chain = MiddlewareChain::new();

        let marker: MiddlewareFn = Arc::new(|ctx: RequestContext| {
            ctx.session.set_state("processed", Value::Bool(true));
            Ok(ctx)
        });
        chain.add(marker);

        let result = chain.process(context_for("test")).unwrap();
        assert_eq!(
            result.session.get_state("processed"),
            Some(Value::Bool(true))
        );
    }

    #[test]
    fn test_multiple_middleware() {
        let mut chain = MiddlewareChain::new();

        let first: MiddlewareFn = Arc::new(|ctx: RequestContext| {
            ctx.session.set_state("step1", Value::Bool(true));
            Ok(ctx)
        });
        let second: MiddlewareFn = Arc::new(|ctx: RequestContext| {
            ctx.session.set_state("step2", Value::Bool(true));
            Ok(ctx)
        });
        chain.add(first);
        chain.add(second);

        let result = chain.process(context_for("test")).unwrap();
        assert_eq!(result.session.get_state("step1"), Some(Value::Bool(true)));
        assert_eq!(result.session.get_state("step2"), Some(Value::Bool(true)));
    }

    #[test]
    fn test_middleware_error() {
        let mut chain = MiddlewareChain::new();

        let failing: MiddlewareFn =
            Arc::new(|_ctx: RequestContext| Err(McpError::validation("test_error", "Test error")));
        chain.add(failing);

        let result = chain.process(context_for("test"));
        assert!(result.is_err());
    }

    #[test]
    fn test_validation_middleware() {
        let mw = validation_middleware();

        // A named method is accepted.
        assert!(mw(context_for("test")).is_ok());

        // An empty method name is rejected.
        assert!(mw(context_for("")).is_err());
    }
}