mcp_host/server/
middleware.rs

//! Middleware system for request processing.
//!
//! Provides a simple middleware chain for processing requests, along with
//! ready-made logging, validation, and GCRA-based rate-limiting middleware.
4
5use std::sync::{Arc, RwLock};
6
7use throttle_machines::gcra;
8
9use crate::protocol::errors::McpError;
10use crate::server::handler::RequestContext;
11
12/// Middleware function type
13///
14/// Takes a RequestContext and returns a HandlerResult
15pub type MiddlewareFn =
16    Arc<dyn Fn(RequestContext) -> Result<RequestContext, McpError> + Send + Sync>;
17
18/// Middleware chain for processing requests
19#[derive(Clone)]
20pub struct MiddlewareChain {
21    middleware: Vec<MiddlewareFn>,
22}
23
24impl MiddlewareChain {
25    /// Create new empty middleware chain
26    pub fn new() -> Self {
27        Self {
28            middleware: Vec::new(),
29        }
30    }
31
32    /// Add middleware to the chain
33    pub fn add(&mut self, middleware: MiddlewareFn) {
34        self.middleware.push(middleware);
35    }
36
37    /// Process request through middleware chain
38    ///
39    /// Each middleware can:
40    /// - Modify the context and pass it to next middleware
41    /// - Return an error to short-circuit the chain
42    pub fn process(&self, mut ctx: RequestContext) -> Result<RequestContext, McpError> {
43        for middleware in &self.middleware {
44            ctx = middleware(ctx)?;
45        }
46        Ok(ctx)
47    }
48
49    /// Get the number of middleware in the chain
50    pub fn len(&self) -> usize {
51        self.middleware.len()
52    }
53
54    /// Check if chain is empty
55    pub fn is_empty(&self) -> bool {
56        self.middleware.is_empty()
57    }
58}
59
60impl Default for MiddlewareChain {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66/// Helper to create logging middleware
67pub fn logging_middleware() -> MiddlewareFn {
68    Arc::new(|ctx: RequestContext| {
69        tracing::debug!(
70            session_id = %ctx.session.id,
71            method = %ctx.request.method,
72            "Processing request"
73        );
74        Ok(ctx)
75    })
76}
77
78/// Helper to create validation middleware
79pub fn validation_middleware() -> MiddlewareFn {
80    Arc::new(|ctx: RequestContext| {
81        // Basic validation - ensure method is not empty
82        if ctx.request.method.is_empty() {
83            return Err(McpError::validation(
84                "empty_method",
85                "Method name cannot be empty",
86            ));
87        }
88        Ok(ctx)
89    })
90}
91
/// Rate limiter configuration for the GCRA (Generic Cell Rate Algorithm).
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Sustained rate: requests allowed per second.
    pub requests_per_second: f64,
    /// Burst capacity, in requests; converted to a GCRA delay tolerance by
    /// [`RateLimiterConfig::delay_tolerance`].
    pub burst_capacity: usize,
}

impl Default for RateLimiterConfig {
    /// 100 requests per second with a burst capacity of 10.
    fn default() -> Self {
        Self::new(100.0, 10)
    }
}

impl RateLimiterConfig {
    /// Builds a configuration from a sustained rate and a burst capacity.
    ///
    /// No validation is performed: a zero, negative or non-finite
    /// `requests_per_second` is stored exactly as given.
    pub fn new(requests_per_second: f64, burst_capacity: usize) -> Self {
        Self {
            requests_per_second,
            burst_capacity,
        }
    }

    /// GCRA emission interval `T`: seconds between requests at the sustained
    /// rate, i.e. `1 / requests_per_second`.
    pub fn emission_interval(&self) -> f64 {
        self.requests_per_second.recip()
    }

    /// GCRA delay tolerance `tau = burst_capacity * T`, in seconds; this is
    /// what permits bursts above the sustained rate.
    pub fn delay_tolerance(&self) -> f64 {
        let burst = self.burst_capacity as f64;
        burst * self.emission_interval()
    }
}
129
130/// Thread-safe rate limiter using GCRA algorithm
131#[derive(Clone)]
132pub struct RateLimiter {
133    config: RateLimiterConfig,
134    /// Theoretical Arrival Time (TAT) - shared state
135    tat: Arc<RwLock<f64>>,
136}
137
138impl RateLimiter {
139    /// Create a new rate limiter with the given configuration
140    pub fn new(config: RateLimiterConfig) -> Self {
141        Self {
142            config,
143            tat: Arc::new(RwLock::new(0.0)),
144        }
145    }
146
147    /// Create rate limiter with defaults (100 req/s, burst of 10)
148    pub fn default_limiter() -> Self {
149        Self::new(RateLimiterConfig::default())
150    }
151
152    /// Check if a request is allowed
153    ///
154    /// Returns Ok(()) if allowed, Err(retry_after_secs) if rate limited
155    pub fn check(&self) -> Result<(), f64> {
156        let now = std::time::SystemTime::now()
157            .duration_since(std::time::UNIX_EPOCH)
158            .map(|d| d.as_secs_f64())
159            .unwrap_or(0.0);
160
161        let emission_interval = self.config.emission_interval();
162        let delay_tolerance = self.config.delay_tolerance();
163
164        let mut tat_guard = self.tat.write().map_err(|_| 1.0)?; // On lock error, ask to retry after 1 second
165
166        let result = gcra::check(*tat_guard, now, emission_interval, delay_tolerance);
167
168        if result.allowed {
169            *tat_guard = result.new_tat;
170            Ok(())
171        } else {
172            Err(result.retry_after)
173        }
174    }
175
176    /// Get current capacity remaining (approximate)
177    pub fn remaining_capacity(&self) -> usize {
178        let now = std::time::SystemTime::now()
179            .duration_since(std::time::UNIX_EPOCH)
180            .map(|d| d.as_secs_f64())
181            .unwrap_or(0.0);
182
183        let delay_tolerance = self.config.delay_tolerance();
184        let emission_interval = self.config.emission_interval();
185
186        let tat = self.tat.read().map(|t| *t).unwrap_or(0.0);
187
188        let result = gcra::peek(tat, now, delay_tolerance);
189
190        if result.allowed {
191            // Estimate remaining based on how much tolerance is left
192            let remaining_tolerance = delay_tolerance - (result.new_tat - now).max(0.0);
193            (remaining_tolerance / emission_interval).floor() as usize + 1
194        } else {
195            0
196        }
197    }
198}
199
200/// Helper to create rate limiter middleware
201///
202/// Creates middleware that enforces rate limits using GCRA algorithm.
203/// Returns `McpError::rate_limited()` when limit is exceeded.
204pub fn rate_limiter_middleware(limiter: Arc<RateLimiter>) -> MiddlewareFn {
205    Arc::new(move |ctx: RequestContext| match limiter.check() {
206        Ok(()) => {
207            tracing::trace!(
208                session_id = %ctx.session.id,
209                method = %ctx.request.method,
210                "Request allowed by rate limiter"
211            );
212            Ok(ctx)
213        }
214        Err(retry_after) => {
215            tracing::warn!(
216                session_id = %ctx.session.id,
217                method = %ctx.request.method,
218                retry_after_secs = %retry_after,
219                "Request rate limited"
220            );
221            Err(McpError::rate_limited(format!(
222                "Rate limit exceeded. Retry after {:.2} seconds",
223                retry_after
224            )))
225        }
226    })
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::JsonRpcRequest;
    use crate::server::session::Session;
    use serde_json::Value;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Builds a fresh context (new session, request id 1) for `method`.
    fn make_ctx(method: &str) -> RequestContext {
        let request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(Value::Number(1.into())),
            method: method.to_string(),
            params: None,
        };
        RequestContext::new(Session::new(), request)
    }

    /// Config with one request per 1000 s: the burst is exhausted long before
    /// any capacity can be replenished while a test runs.
    fn slow_config() -> RateLimiterConfig {
        RateLimiterConfig::new(0.001, 3)
    }

    #[test]
    fn test_empty_chain() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }

    #[test]
    fn test_add_middleware() {
        let mut chain = MiddlewareChain::new();

        let mw = Arc::new(|ctx: RequestContext| Ok(ctx));
        chain.add(mw);

        assert!(!chain.is_empty());
        assert_eq!(chain.len(), 1);
    }

    #[test]
    fn test_process_middleware() {
        let mut chain = MiddlewareChain::new();

        // Add middleware that modifies session state
        let mw = Arc::new(|ctx: RequestContext| {
            ctx.session.set_state("processed", Value::Bool(true));
            Ok(ctx)
        });
        chain.add(mw);

        let result = chain.process(make_ctx("test")).unwrap();
        assert_eq!(
            result.session.get_state("processed"),
            Some(Value::Bool(true))
        );
    }

    #[test]
    fn test_multiple_middleware() {
        let mut chain = MiddlewareChain::new();

        // Add multiple middleware
        let mw1 = Arc::new(|ctx: RequestContext| {
            ctx.session.set_state("step1", Value::Bool(true));
            Ok(ctx)
        });
        let mw2 = Arc::new(|ctx: RequestContext| {
            ctx.session.set_state("step2", Value::Bool(true));
            Ok(ctx)
        });
        chain.add(mw1);
        chain.add(mw2);

        let result = chain.process(make_ctx("test")).unwrap();
        assert_eq!(result.session.get_state("step1"), Some(Value::Bool(true)));
        assert_eq!(result.session.get_state("step2"), Some(Value::Bool(true)));
    }

    #[test]
    fn test_middleware_error() {
        let mut chain = MiddlewareChain::new();

        // Add middleware that returns error
        let mw =
            Arc::new(|_ctx: RequestContext| Err(McpError::validation("test_error", "Test error")));
        chain.add(mw);

        let result = chain.process(make_ctx("test"));
        assert!(result.is_err());
    }

    #[test]
    fn test_middleware_error_short_circuits() {
        let mut chain = MiddlewareChain::new();

        let failing =
            Arc::new(|_ctx: RequestContext| Err(McpError::validation("test_error", "Test error")));
        chain.add(failing);

        // Records whether the middleware after the failing one was invoked.
        let ran = Arc::new(AtomicBool::new(false));
        let ran_in_mw = Arc::clone(&ran);
        let after = Arc::new(move |ctx: RequestContext| {
            ran_in_mw.store(true, Ordering::SeqCst);
            Ok(ctx)
        });
        chain.add(after);

        assert!(chain.process(make_ctx("test")).is_err());
        assert!(!ran.load(Ordering::SeqCst));
    }

    #[test]
    fn test_logging_middleware_passes_through() {
        let mw = logging_middleware();
        let result = mw(make_ctx("test")).unwrap();
        assert_eq!(result.request.method, "test");
    }

    #[test]
    fn test_validation_middleware() {
        let mw = validation_middleware();

        // Valid request
        assert!(mw(make_ctx("test")).is_ok());

        // Invalid request (empty method)
        assert!(mw(make_ctx("")).is_err());
    }

    #[test]
    fn test_rate_limiter_config_math() {
        let config = RateLimiterConfig::new(4.0, 3);
        assert_eq!(config.emission_interval(), 0.25);
        assert_eq!(config.delay_tolerance(), 0.75);
    }

    #[test]
    fn test_rate_limiter_allows_then_limits() {
        let limiter = RateLimiter::new(slow_config());
        assert!(limiter.remaining_capacity() > 0);
        assert!(limiter.check().is_ok());

        let retry_after = (0..100)
            .find_map(|_| limiter.check().err())
            .expect("limiter never rejected a request");
        assert!(retry_after > 0.0);
    }

    #[test]
    fn test_rate_limiter_clones_share_state() {
        let limiter = RateLimiter::new(slow_config());
        let clone = limiter.clone();

        for _ in 0..100 {
            let _ = limiter.check();
        }
        assert!(clone.check().is_err());
    }

    #[test]
    fn test_rate_limiter_middleware() {
        let limiter = Arc::new(RateLimiter::new(slow_config()));
        let mw = rate_limiter_middleware(Arc::clone(&limiter));

        assert!(mw(make_ctx("test")).is_ok());

        // Exhaust the shared limiter directly; the middleware must now reject.
        for _ in 0..100 {
            let _ = limiter.check();
        }
        assert!(mw(make_ctx("test")).is_err());
    }
}