Skip to main content

fastmcp_server/
middleware.rs

//! Middleware hooks for request/response interception.
//!
//! This provides a minimal, synchronous middleware system for MCP requests.
//! Middleware can short-circuit requests, transform responses, and rewrite errors.
//!
//! # Ordering Semantics
//!
//! - `on_request` runs **in registration order** (first registered, first called).
//! - `on_response` runs **in reverse order** for middleware whose `on_request` ran.
//! - `on_error` runs **in reverse order** for middleware whose `on_request` ran.
//!
//! If a middleware returns `Respond` from `on_request`, the response is still
//! passed through `on_response` for the middleware whose `on_request` already ran.
//! If the handler, or any `on_request` or `on_response` hook, returns an error,
//! `on_error` is invoked for that same entered middleware stack so the error can
//! be rewritten before it reaches the client.
16
17use fastmcp_core::{McpContext, McpError, McpResult};
18use fastmcp_protocol::JsonRpcRequest;
19
20use std::sync::Arc;
21
22/// Result of middleware request interception.
23#[derive(Debug, Clone)]
24pub enum MiddlewareDecision {
25    /// Continue normal dispatch.
26    Continue,
27    /// Short-circuit dispatch and return this JSON value as the result.
28    Respond(serde_json::Value),
29}
30
31/// Middleware hook trait for request/response interception.
32///
33/// This is intentionally minimal: synchronous hooks only, with simple
34/// short-circuit and response transform capabilities. See the module-level
35/// documentation for ordering semantics.
36pub trait Middleware: Send + Sync {
37    /// Invoked before routing the request.
38    ///
39    /// Return `Respond` to skip normal dispatch and return a custom result.
40    fn on_request(
41        &self,
42        _ctx: &McpContext,
43        _request: &JsonRpcRequest,
44    ) -> McpResult<MiddlewareDecision> {
45        Ok(MiddlewareDecision::Continue)
46    }
47
48    /// Invoked after a successful handler result is produced.
49    ///
50    /// Middleware can transform the response value (or return an error).
51    fn on_response(
52        &self,
53        _ctx: &McpContext,
54        _request: &JsonRpcRequest,
55        response: serde_json::Value,
56    ) -> McpResult<serde_json::Value> {
57        Ok(response)
58    }
59
60    /// Invoked when a handler or middleware returns an error.
61    ///
62    /// Middleware may rewrite the error before it is sent to the client.
63    fn on_error(&self, _ctx: &McpContext, _request: &JsonRpcRequest, error: McpError) -> McpError {
64        error
65    }
66}
67
68impl<T> Middleware for Arc<T>
69where
70    T: Middleware + ?Sized,
71{
72    fn on_request(
73        &self,
74        ctx: &McpContext,
75        request: &JsonRpcRequest,
76    ) -> McpResult<MiddlewareDecision> {
77        (**self).on_request(ctx, request)
78    }
79
80    fn on_response(
81        &self,
82        ctx: &McpContext,
83        request: &JsonRpcRequest,
84        response: serde_json::Value,
85    ) -> McpResult<serde_json::Value> {
86        (**self).on_response(ctx, request, response)
87    }
88
89    fn on_error(&self, ctx: &McpContext, request: &JsonRpcRequest, error: McpError) -> McpError {
90        (**self).on_error(ctx, request, error)
91    }
92}
93
#[cfg(test)]
mod tests {
    //! Unit tests for the trait's default hooks, custom overrides, and the
    //! `Arc<T>` delegation impl.

    use super::*;
    use asupersync::Cx;

    fn make_ctx() -> McpContext {
        McpContext::new(Cx::for_testing(), 1)
    }

    fn make_request() -> JsonRpcRequest {
        JsonRpcRequest::new("tools/call", None, 1i64)
    }

    // ── MiddlewareDecision ───────────────────────────────────────────

    #[test]
    fn middleware_decision_continue_debug() {
        let d = MiddlewareDecision::Continue;
        let debug = format!("{:?}", d);
        assert!(debug.contains("Continue"));
    }

    #[test]
    fn middleware_decision_respond_debug() {
        let d = MiddlewareDecision::Respond(serde_json::json!({"ok": true}));
        let debug = format!("{:?}", d);
        assert!(debug.contains("Respond"));
    }

    #[test]
    fn middleware_decision_clone() {
        let d = MiddlewareDecision::Respond(serde_json::json!(42));
        let cloned = d.clone();
        match cloned {
            MiddlewareDecision::Respond(v) => assert_eq!(v, 42),
            _ => panic!("expected Respond"),
        }
    }

    // ── Default trait methods ────────────────────────────────────────

    struct NoopMiddleware;
    impl Middleware for NoopMiddleware {}

    #[test]
    fn default_on_request_returns_continue() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        // `matches!` only yields a bool; it must be asserted, otherwise this
        // test could never fail.
        assert!(
            matches!(decision, MiddlewareDecision::Continue),
            "expected Continue, got {:?}",
            decision
        );
    }

    #[test]
    fn default_on_response_passes_through() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"data": "hello"});
        let output = mw.on_response(&ctx, &req, input.clone()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn default_on_error_passes_through() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = McpError::internal_error("test error");
        let result = mw.on_error(&ctx, &req, err);
        assert!(result.message.contains("test error"));
    }

    // ── Custom middleware ─────────────────────────────────────────────

    struct BlockingMiddleware;
    impl Middleware for BlockingMiddleware {
        fn on_request(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
        ) -> McpResult<MiddlewareDecision> {
            Ok(MiddlewareDecision::Respond(
                serde_json::json!({"blocked": true}),
            ))
        }
    }

    #[test]
    fn custom_on_request_can_short_circuit() {
        let mw = BlockingMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        match decision {
            MiddlewareDecision::Respond(v) => assert_eq!(v["blocked"], true),
            _ => panic!("expected Respond"),
        }
    }

    struct ErrorRewritingMiddleware;
    impl Middleware for ErrorRewritingMiddleware {
        fn on_error(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
            _error: McpError,
        ) -> McpError {
            McpError::internal_error("rewritten")
        }
    }

    #[test]
    fn custom_on_error_can_rewrite() {
        let mw = ErrorRewritingMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let original = McpError::internal_error("original");
        let rewritten = mw.on_error(&ctx, &req, original);
        assert!(rewritten.message.contains("rewritten"));
    }

    // ── Arc delegation ───────────────────────────────────────────────

    #[test]
    fn arc_middleware_delegates_on_request() {
        let mw: Arc<dyn Middleware> = Arc::new(BlockingMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        match decision {
            MiddlewareDecision::Respond(v) => assert_eq!(v["blocked"], true),
            _ => panic!("expected Respond"),
        }
    }

    #[test]
    fn arc_middleware_delegates_on_response() {
        let mw: Arc<dyn Middleware> = Arc::new(NoopMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!("hello");
        let output = mw.on_response(&ctx, &req, input.clone()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn arc_middleware_delegates_on_error() {
        let mw: Arc<dyn Middleware> = Arc::new(ErrorRewritingMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let err = McpError::internal_error("x");
        let result = mw.on_error(&ctx, &req, err);
        assert!(result.message.contains("rewritten"));
    }

    // ── Additional coverage ─────────────────────────────────────────

    struct TransformResponseMiddleware;
    impl Middleware for TransformResponseMiddleware {
        fn on_response(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
            mut response: serde_json::Value,
        ) -> McpResult<serde_json::Value> {
            response["transformed"] = serde_json::json!(true);
            Ok(response)
        }
    }

    #[test]
    fn custom_on_response_can_transform() {
        let mw = TransformResponseMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"data": 1});
        let output = mw.on_response(&ctx, &req, input).unwrap();
        assert_eq!(output["data"], 1);
        assert_eq!(output["transformed"], true);
    }

    #[test]
    fn on_request_can_return_error() {
        struct RejectMiddleware;
        impl Middleware for RejectMiddleware {
            fn on_request(
                &self,
                _ctx: &McpContext,
                _request: &JsonRpcRequest,
            ) -> McpResult<MiddlewareDecision> {
                Err(McpError::internal_error("rejected"))
            }
        }

        let mw = RejectMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = mw.on_request(&ctx, &req).unwrap_err();
        assert!(err.message.contains("rejected"));
    }

    #[test]
    fn on_response_can_return_error() {
        struct FailResponseMiddleware;
        impl Middleware for FailResponseMiddleware {
            fn on_response(
                &self,
                _ctx: &McpContext,
                _request: &JsonRpcRequest,
                _response: serde_json::Value,
            ) -> McpResult<serde_json::Value> {
                Err(McpError::internal_error("response-fail"))
            }
        }

        let mw = FailResponseMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = mw
            .on_response(&ctx, &req, serde_json::json!({}))
            .unwrap_err();
        assert!(err.message.contains("response-fail"));
    }

    #[test]
    fn middleware_decision_continue_clone() {
        let d = MiddlewareDecision::Continue;
        let cloned = d.clone();
        assert!(matches!(cloned, MiddlewareDecision::Continue));
    }

    #[test]
    fn arc_middleware_delegates_transforming_on_response() {
        let mw: Arc<dyn Middleware> = Arc::new(TransformResponseMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"x": 2});
        let output = mw.on_response(&ctx, &req, input).unwrap();
        assert_eq!(output["x"], 2);
        assert_eq!(output["transformed"], true);
    }
}