1use fastmcp_core::{McpContext, McpError, McpResult};
20use fastmcp_protocol::JsonRpcRequest;
21
22use std::sync::Arc;
23
24#[derive(Debug, Clone)]
26pub enum MiddlewareDecision {
27 Continue,
29 Respond(serde_json::Value),
31}
32
33pub trait Middleware: Send + Sync {
39 fn on_request(
43 &self,
44 _ctx: &McpContext,
45 _request: &JsonRpcRequest,
46 ) -> McpResult<MiddlewareDecision> {
47 Ok(MiddlewareDecision::Continue)
48 }
49
50 fn on_response(
54 &self,
55 _ctx: &McpContext,
56 _request: &JsonRpcRequest,
57 response: serde_json::Value,
58 ) -> McpResult<serde_json::Value> {
59 Ok(response)
60 }
61
62 fn on_error(&self, _ctx: &McpContext, _request: &JsonRpcRequest, error: McpError) -> McpError {
66 error
67 }
68}
69
70impl<T> Middleware for Arc<T>
71where
72 T: Middleware + ?Sized,
73{
74 fn on_request(
75 &self,
76 ctx: &McpContext,
77 request: &JsonRpcRequest,
78 ) -> McpResult<MiddlewareDecision> {
79 (**self).on_request(ctx, request)
80 }
81
82 fn on_response(
83 &self,
84 ctx: &McpContext,
85 request: &JsonRpcRequest,
86 response: serde_json::Value,
87 ) -> McpResult<serde_json::Value> {
88 (**self).on_response(ctx, request, response)
89 }
90
91 fn on_error(&self, ctx: &McpContext, request: &JsonRpcRequest, error: McpError) -> McpError {
92 (**self).on_error(ctx, request, error)
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::Cx;

    /// Builds a context backed by the testing `Cx`.
    fn make_ctx() -> McpContext {
        McpContext::new(Cx::for_testing(), 1)
    }

    /// Builds the minimal `tools/call` request shared by every test.
    fn make_request() -> JsonRpcRequest {
        JsonRpcRequest::new("tools/call", None, 1i64)
    }

    #[test]
    fn middleware_decision_continue_debug() {
        let d = MiddlewareDecision::Continue;
        let debug = format!("{d:?}");
        assert!(debug.contains("Continue"));
    }

    #[test]
    fn middleware_decision_respond_debug() {
        let d = MiddlewareDecision::Respond(serde_json::json!({"ok": true}));
        let debug = format!("{d:?}");
        assert!(debug.contains("Respond"));
    }

    #[test]
    fn middleware_decision_clone() {
        let d = MiddlewareDecision::Respond(serde_json::json!(42));
        let cloned = d.clone();
        match cloned {
            MiddlewareDecision::Respond(v) => assert_eq!(v, 42),
            _ => panic!("expected Respond"),
        }
    }

    /// Overrides nothing, so every hook uses the trait's default.
    struct NoopMiddleware;

    impl Middleware for NoopMiddleware {}

    #[test]
    fn default_on_request_returns_continue() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        // `matches!` only evaluates to a bool; it must be asserted or the
        // test passes regardless of the decision returned.
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn default_on_response_passes_through() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"data": "hello"});
        let output = mw.on_response(&ctx, &req, input.clone()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn default_on_error_passes_through() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = McpError::internal_error("test error");
        let result = mw.on_error(&ctx, &req, err);
        assert!(result.message.contains("test error"));
    }

    /// Always answers `on_request` with a canned `Respond` value.
    struct BlockingMiddleware;

    impl Middleware for BlockingMiddleware {
        fn on_request(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
        ) -> McpResult<MiddlewareDecision> {
            Ok(MiddlewareDecision::Respond(
                serde_json::json!({"blocked": true}),
            ))
        }
    }

    #[test]
    fn custom_on_request_can_short_circuit() {
        let mw = BlockingMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        match decision {
            MiddlewareDecision::Respond(v) => assert_eq!(v["blocked"], true),
            _ => panic!("expected Respond"),
        }
    }

    /// Replaces any error with a fixed "rewritten" internal error.
    struct ErrorRewritingMiddleware;

    impl Middleware for ErrorRewritingMiddleware {
        fn on_error(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
            _error: McpError,
        ) -> McpError {
            McpError::internal_error("rewritten")
        }
    }

    #[test]
    fn custom_on_error_can_rewrite() {
        let mw = ErrorRewritingMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let original = McpError::internal_error("original");
        let rewritten = mw.on_error(&ctx, &req, original);
        assert!(rewritten.message.contains("rewritten"));
    }

    #[test]
    fn arc_middleware_delegates_on_request() {
        let mw: Arc<dyn Middleware> = Arc::new(BlockingMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        match decision {
            MiddlewareDecision::Respond(v) => assert_eq!(v["blocked"], true),
            _ => panic!("expected Respond"),
        }
    }

    #[test]
    fn arc_middleware_delegates_on_response() {
        let mw: Arc<dyn Middleware> = Arc::new(NoopMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!("hello");
        let output = mw.on_response(&ctx, &req, input.clone()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn arc_middleware_delegates_on_error() {
        let mw: Arc<dyn Middleware> = Arc::new(ErrorRewritingMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let err = McpError::internal_error("x");
        let result = mw.on_error(&ctx, &req, err);
        assert!(result.message.contains("rewritten"));
    }

    /// Adds `"transformed": true` to the response object.
    struct TransformResponseMiddleware;

    impl Middleware for TransformResponseMiddleware {
        fn on_response(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
            mut response: serde_json::Value,
        ) -> McpResult<serde_json::Value> {
            response["transformed"] = serde_json::json!(true);
            Ok(response)
        }
    }

    #[test]
    fn custom_on_response_can_transform() {
        let mw = TransformResponseMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"data": 1});
        let output = mw.on_response(&ctx, &req, input).unwrap();
        assert_eq!(output["data"], 1);
        assert_eq!(output["transformed"], true);
    }

    #[test]
    fn on_request_can_return_error() {
        struct RejectMiddleware;
        impl Middleware for RejectMiddleware {
            fn on_request(
                &self,
                _ctx: &McpContext,
                _request: &JsonRpcRequest,
            ) -> McpResult<MiddlewareDecision> {
                Err(McpError::internal_error("rejected"))
            }
        }

        let mw = RejectMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = mw.on_request(&ctx, &req).unwrap_err();
        assert!(err.message.contains("rejected"));
    }

    #[test]
    fn on_response_can_return_error() {
        struct FailResponseMiddleware;
        impl Middleware for FailResponseMiddleware {
            fn on_response(
                &self,
                _ctx: &McpContext,
                _request: &JsonRpcRequest,
                _response: serde_json::Value,
            ) -> McpResult<serde_json::Value> {
                Err(McpError::internal_error("response-fail"))
            }
        }

        let mw = FailResponseMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = mw
            .on_response(&ctx, &req, serde_json::json!({}))
            .unwrap_err();
        assert!(err.message.contains("response-fail"));
    }

    #[test]
    fn middleware_decision_continue_clone() {
        let d = MiddlewareDecision::Continue;
        let cloned = d.clone();
        assert!(matches!(cloned, MiddlewareDecision::Continue));
    }

    #[test]
    fn arc_middleware_delegates_transforming_on_response() {
        let mw: Arc<dyn Middleware> = Arc::new(TransformResponseMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"x": 2});
        let output = mw.on_response(&ctx, &req, input).unwrap();
        assert_eq!(output["x"], 2);
        assert_eq!(output["transformed"], true);
    }
}
338}