Skip to main content

fastmcp_server/
middleware.rs

1//! Middleware hooks for request/response interception.
2//!
3//! This provides a minimal, synchronous middleware system for MCP requests.
4//! Middleware can short-circuit requests, transform responses, and rewrite errors.
5//!
6//! # Ordering Semantics
7//!
8//! - `on_request` runs **in registration order** (first registered, first called).
9//! - `on_response` runs **in reverse order** for middleware whose `on_request` ran.
10//! - `on_error` runs **in reverse order** for middleware whose `on_request` ran.
11//! - Server-generated failures before `on_request` entry (for example auth errors)
12//!   may still be passed through `on_error` for the registered middleware stack.
13//!
14//! If a middleware returns `Respond` from `on_request`, the response is still
15//! passed through `on_response` for the already-entered middleware stack.
16//! If any `on_request` or `on_response` returns an error, `on_error` is invoked
17//! for the entered middleware stack to allow error rewriting.
18
19use fastmcp_core::{McpContext, McpError, McpResult};
20use fastmcp_protocol::JsonRpcRequest;
21
22use std::sync::Arc;
23
24/// Result of middleware request interception.
25#[derive(Debug, Clone)]
26pub enum MiddlewareDecision {
27    /// Continue normal dispatch.
28    Continue,
29    /// Short-circuit dispatch and return this JSON value as the result.
30    Respond(serde_json::Value),
31}
32
33/// Middleware hook trait for request/response interception.
34///
35/// This is intentionally minimal: synchronous hooks only, with simple
36/// short-circuit and response transform capabilities. See the module-level
37/// documentation for ordering semantics.
38pub trait Middleware: Send + Sync {
39    /// Invoked before routing the request.
40    ///
41    /// Return `Respond` to skip normal dispatch and return a custom result.
42    fn on_request(
43        &self,
44        _ctx: &McpContext,
45        _request: &JsonRpcRequest,
46    ) -> McpResult<MiddlewareDecision> {
47        Ok(MiddlewareDecision::Continue)
48    }
49
50    /// Invoked after a successful handler result is produced.
51    ///
52    /// Middleware can transform the response value (or return an error).
53    fn on_response(
54        &self,
55        _ctx: &McpContext,
56        _request: &JsonRpcRequest,
57        response: serde_json::Value,
58    ) -> McpResult<serde_json::Value> {
59        Ok(response)
60    }
61
62    /// Invoked when a handler or middleware returns an error.
63    ///
64    /// Middleware may rewrite the error before it is sent to the client.
65    fn on_error(&self, _ctx: &McpContext, _request: &JsonRpcRequest, error: McpError) -> McpError {
66        error
67    }
68}
69
70impl<T> Middleware for Arc<T>
71where
72    T: Middleware + ?Sized,
73{
74    fn on_request(
75        &self,
76        ctx: &McpContext,
77        request: &JsonRpcRequest,
78    ) -> McpResult<MiddlewareDecision> {
79        (**self).on_request(ctx, request)
80    }
81
82    fn on_response(
83        &self,
84        ctx: &McpContext,
85        request: &JsonRpcRequest,
86        response: serde_json::Value,
87    ) -> McpResult<serde_json::Value> {
88        (**self).on_response(ctx, request, response)
89    }
90
91    fn on_error(&self, ctx: &McpContext, request: &JsonRpcRequest, error: McpError) -> McpError {
92        (**self).on_error(ctx, request, error)
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::Cx;

    fn make_ctx() -> McpContext {
        McpContext::new(Cx::for_testing(), 1)
    }

    fn make_request() -> JsonRpcRequest {
        JsonRpcRequest::new("tools/call", None, 1i64)
    }

    // ── MiddlewareDecision ───────────────────────────────────────────

    #[test]
    fn middleware_decision_continue_debug() {
        let d = MiddlewareDecision::Continue;
        let debug = format!("{:?}", d);
        assert!(debug.contains("Continue"));
    }

    #[test]
    fn middleware_decision_respond_debug() {
        let d = MiddlewareDecision::Respond(serde_json::json!({"ok": true}));
        let debug = format!("{:?}", d);
        assert!(debug.contains("Respond"));
    }

    #[test]
    fn middleware_decision_clone() {
        let d = MiddlewareDecision::Respond(serde_json::json!(42));
        let cloned = d.clone();
        match cloned {
            MiddlewareDecision::Respond(v) => assert_eq!(v, 42),
            _ => panic!("expected Respond"),
        }
    }

    // ── Default trait methods ────────────────────────────────────────

    struct NoopMiddleware;
    impl Middleware for NoopMiddleware {}

    #[test]
    fn default_on_request_returns_continue() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        // `matches!` only yields a bool; it must be asserted or the test
        // can never fail.
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn default_on_response_passes_through() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"data": "hello"});
        let output = mw.on_response(&ctx, &req, input.clone()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn default_on_error_passes_through() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = McpError::internal_error("test error");
        let result = mw.on_error(&ctx, &req, err);
        assert!(result.message.contains("test error"));
    }

    // ── Custom middleware ─────────────────────────────────────────────

    struct BlockingMiddleware;
    impl Middleware for BlockingMiddleware {
        fn on_request(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
        ) -> McpResult<MiddlewareDecision> {
            Ok(MiddlewareDecision::Respond(
                serde_json::json!({"blocked": true}),
            ))
        }
    }

    #[test]
    fn custom_on_request_can_short_circuit() {
        let mw = BlockingMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        match decision {
            MiddlewareDecision::Respond(v) => assert_eq!(v["blocked"], true),
            _ => panic!("expected Respond"),
        }
    }

    struct ErrorRewritingMiddleware;
    impl Middleware for ErrorRewritingMiddleware {
        fn on_error(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
            _error: McpError,
        ) -> McpError {
            McpError::internal_error("rewritten")
        }
    }

    #[test]
    fn custom_on_error_can_rewrite() {
        let mw = ErrorRewritingMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let original = McpError::internal_error("original");
        let rewritten = mw.on_error(&ctx, &req, original);
        assert!(rewritten.message.contains("rewritten"));
    }

    // ── Arc delegation ───────────────────────────────────────────────

    #[test]
    fn arc_middleware_delegates_on_request() {
        let mw: Arc<dyn Middleware> = Arc::new(BlockingMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        match decision {
            MiddlewareDecision::Respond(v) => assert_eq!(v["blocked"], true),
            _ => panic!("expected Respond"),
        }
    }

    #[test]
    fn arc_middleware_delegates_default_on_request() {
        let mw: Arc<dyn Middleware> = Arc::new(NoopMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn arc_of_concrete_type_delegates_on_request() {
        // Exercises the blanket impl with a sized `T`, not just `dyn Middleware`.
        let mw: Arc<BlockingMiddleware> = Arc::new(BlockingMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let decision = Middleware::on_request(&mw, &ctx, &req).unwrap();
        match decision {
            MiddlewareDecision::Respond(v) => assert_eq!(v["blocked"], true),
            _ => panic!("expected Respond"),
        }
    }

    #[test]
    fn arc_middleware_delegates_on_response() {
        let mw: Arc<dyn Middleware> = Arc::new(NoopMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!("hello");
        let output = mw.on_response(&ctx, &req, input.clone()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn arc_middleware_delegates_on_error() {
        let mw: Arc<dyn Middleware> = Arc::new(ErrorRewritingMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let err = McpError::internal_error("x");
        let result = mw.on_error(&ctx, &req, err);
        assert!(result.message.contains("rewritten"));
    }

    // ── Additional coverage ─────────────────────────────────────────

    struct TransformResponseMiddleware;
    impl Middleware for TransformResponseMiddleware {
        fn on_response(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
            mut response: serde_json::Value,
        ) -> McpResult<serde_json::Value> {
            response["transformed"] = serde_json::json!(true);
            Ok(response)
        }
    }

    #[test]
    fn custom_on_response_can_transform() {
        let mw = TransformResponseMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"data": 1});
        let output = mw.on_response(&ctx, &req, input).unwrap();
        assert_eq!(output["data"], 1);
        assert_eq!(output["transformed"], true);
    }

    #[test]
    fn on_request_can_return_error() {
        struct RejectMiddleware;
        impl Middleware for RejectMiddleware {
            fn on_request(
                &self,
                _ctx: &McpContext,
                _request: &JsonRpcRequest,
            ) -> McpResult<MiddlewareDecision> {
                Err(McpError::internal_error("rejected"))
            }
        }

        let mw = RejectMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = mw.on_request(&ctx, &req).unwrap_err();
        assert!(err.message.contains("rejected"));
    }

    #[test]
    fn on_response_can_return_error() {
        struct FailResponseMiddleware;
        impl Middleware for FailResponseMiddleware {
            fn on_response(
                &self,
                _ctx: &McpContext,
                _request: &JsonRpcRequest,
                _response: serde_json::Value,
            ) -> McpResult<serde_json::Value> {
                Err(McpError::internal_error("response-fail"))
            }
        }

        let mw = FailResponseMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = mw
            .on_response(&ctx, &req, serde_json::json!({}))
            .unwrap_err();
        assert!(err.message.contains("response-fail"));
    }

    #[test]
    fn middleware_decision_continue_clone() {
        let d = MiddlewareDecision::Continue;
        let cloned = d.clone();
        assert!(matches!(cloned, MiddlewareDecision::Continue));
    }

    #[test]
    fn arc_middleware_delegates_transforming_on_response() {
        let mw: Arc<dyn Middleware> = Arc::new(TransformResponseMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"x": 2});
        let output = mw.on_response(&ctx, &req, input).unwrap();
        assert_eq!(output["x"], 2);
        assert_eq!(output["transformed"], true);
    }
}