1use fastmcp_core::{McpContext, McpError, McpResult};
18use fastmcp_protocol::JsonRpcRequest;
19
20use std::sync::Arc;
21
22#[derive(Debug, Clone)]
24pub enum MiddlewareDecision {
25 Continue,
27 Respond(serde_json::Value),
29}
30
31pub trait Middleware: Send + Sync {
37 fn on_request(
41 &self,
42 _ctx: &McpContext,
43 _request: &JsonRpcRequest,
44 ) -> McpResult<MiddlewareDecision> {
45 Ok(MiddlewareDecision::Continue)
46 }
47
48 fn on_response(
52 &self,
53 _ctx: &McpContext,
54 _request: &JsonRpcRequest,
55 response: serde_json::Value,
56 ) -> McpResult<serde_json::Value> {
57 Ok(response)
58 }
59
60 fn on_error(&self, _ctx: &McpContext, _request: &JsonRpcRequest, error: McpError) -> McpError {
64 error
65 }
66}
67
68impl<T> Middleware for Arc<T>
69where
70 T: Middleware + ?Sized,
71{
72 fn on_request(
73 &self,
74 ctx: &McpContext,
75 request: &JsonRpcRequest,
76 ) -> McpResult<MiddlewareDecision> {
77 (**self).on_request(ctx, request)
78 }
79
80 fn on_response(
81 &self,
82 ctx: &McpContext,
83 request: &JsonRpcRequest,
84 response: serde_json::Value,
85 ) -> McpResult<serde_json::Value> {
86 (**self).on_response(ctx, request, response)
87 }
88
89 fn on_error(&self, ctx: &McpContext, request: &JsonRpcRequest, error: McpError) -> McpError {
90 (**self).on_error(ctx, request, error)
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::Cx;

    /// Builds the throwaway context shared by every test.
    fn make_ctx() -> McpContext {
        McpContext::new(Cx::for_testing(), 1)
    }

    /// Builds the `tools/call` request shared by every test.
    fn make_request() -> JsonRpcRequest {
        JsonRpcRequest::new("tools/call", None, 1i64)
    }

    // ── MiddlewareDecision derives ──────────────────────────────────────

    #[test]
    fn middleware_decision_continue_debug() {
        let d = MiddlewareDecision::Continue;
        let debug = format!("{:?}", d);
        assert!(debug.contains("Continue"));
    }

    #[test]
    fn middleware_decision_respond_debug() {
        let d = MiddlewareDecision::Respond(serde_json::json!({"ok": true}));
        let debug = format!("{:?}", d);
        assert!(debug.contains("Respond"));
    }

    #[test]
    fn middleware_decision_clone() {
        let d = MiddlewareDecision::Respond(serde_json::json!(42));
        let cloned = d.clone();
        match cloned {
            MiddlewareDecision::Respond(v) => assert_eq!(v, 42),
            MiddlewareDecision::Continue => panic!("expected Respond"),
        }
    }

    // ── Default hook implementations ────────────────────────────────────

    /// Middleware that overrides nothing, so every hook uses the trait default.
    struct NoopMiddleware;
    impl Middleware for NoopMiddleware {}

    #[test]
    fn default_on_request_returns_continue() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        // `matches!` only yields a bool; without `assert!` this test could
        // never fail.
        assert!(matches!(decision, MiddlewareDecision::Continue));
    }

    #[test]
    fn default_on_response_passes_through() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"data": "hello"});
        let output = mw.on_response(&ctx, &req, input.clone()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn default_on_error_passes_through() {
        let mw = NoopMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = McpError::internal_error("test error");
        let result = mw.on_error(&ctx, &req, err);
        assert!(result.message.contains("test error"));
    }

    // ── Custom hook implementations ─────────────────────────────────────

    /// Middleware whose `on_request` always answers with `{"blocked": true}`.
    struct BlockingMiddleware;
    impl Middleware for BlockingMiddleware {
        fn on_request(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
        ) -> McpResult<MiddlewareDecision> {
            Ok(MiddlewareDecision::Respond(
                serde_json::json!({"blocked": true}),
            ))
        }
    }

    #[test]
    fn custom_on_request_can_short_circuit() {
        let mw = BlockingMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        match decision {
            MiddlewareDecision::Respond(v) => assert_eq!(v["blocked"], true),
            MiddlewareDecision::Continue => panic!("expected Respond"),
        }
    }

    /// Middleware whose `on_error` discards the incoming error and returns a
    /// fresh one with the message "rewritten".
    struct ErrorRewritingMiddleware;
    impl Middleware for ErrorRewritingMiddleware {
        fn on_error(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
            _error: McpError,
        ) -> McpError {
            McpError::internal_error("rewritten")
        }
    }

    #[test]
    fn custom_on_error_can_rewrite() {
        let mw = ErrorRewritingMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let original = McpError::internal_error("original");
        let rewritten = mw.on_error(&ctx, &req, original);
        assert!(rewritten.message.contains("rewritten"));
    }

    // ── Arc<T> forwarding ───────────────────────────────────────────────

    #[test]
    fn arc_middleware_delegates_on_request() {
        let mw: Arc<dyn Middleware> = Arc::new(BlockingMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let decision = mw.on_request(&ctx, &req).unwrap();
        match decision {
            MiddlewareDecision::Respond(v) => assert_eq!(v["blocked"], true),
            MiddlewareDecision::Continue => panic!("expected Respond"),
        }
    }

    #[test]
    fn arc_middleware_delegates_on_response() {
        let mw: Arc<dyn Middleware> = Arc::new(NoopMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!("hello");
        let output = mw.on_response(&ctx, &req, input.clone()).unwrap();
        assert_eq!(output, input);
    }

    #[test]
    fn arc_middleware_delegates_on_error() {
        let mw: Arc<dyn Middleware> = Arc::new(ErrorRewritingMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let err = McpError::internal_error("x");
        let result = mw.on_error(&ctx, &req, err);
        assert!(result.message.contains("rewritten"));
    }

    // ── Response transformation and failure paths ───────────────────────

    /// Middleware whose `on_response` sets `"transformed": true` on the value.
    struct TransformResponseMiddleware;
    impl Middleware for TransformResponseMiddleware {
        fn on_response(
            &self,
            _ctx: &McpContext,
            _request: &JsonRpcRequest,
            mut response: serde_json::Value,
        ) -> McpResult<serde_json::Value> {
            response["transformed"] = serde_json::json!(true);
            Ok(response)
        }
    }

    #[test]
    fn custom_on_response_can_transform() {
        let mw = TransformResponseMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"data": 1});
        let output = mw.on_response(&ctx, &req, input).unwrap();
        assert_eq!(output["data"], 1);
        assert_eq!(output["transformed"], true);
    }

    #[test]
    fn on_request_can_return_error() {
        struct RejectMiddleware;
        impl Middleware for RejectMiddleware {
            fn on_request(
                &self,
                _ctx: &McpContext,
                _request: &JsonRpcRequest,
            ) -> McpResult<MiddlewareDecision> {
                Err(McpError::internal_error("rejected"))
            }
        }

        let mw = RejectMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = mw.on_request(&ctx, &req).unwrap_err();
        assert!(err.message.contains("rejected"));
    }

    #[test]
    fn on_response_can_return_error() {
        struct FailResponseMiddleware;
        impl Middleware for FailResponseMiddleware {
            fn on_response(
                &self,
                _ctx: &McpContext,
                _request: &JsonRpcRequest,
                _response: serde_json::Value,
            ) -> McpResult<serde_json::Value> {
                Err(McpError::internal_error("response-fail"))
            }
        }

        let mw = FailResponseMiddleware;
        let ctx = make_ctx();
        let req = make_request();
        let err = mw
            .on_response(&ctx, &req, serde_json::json!({}))
            .unwrap_err();
        assert!(err.message.contains("response-fail"));
    }

    #[test]
    fn middleware_decision_continue_clone() {
        let d = MiddlewareDecision::Continue;
        let cloned = d.clone();
        assert!(matches!(cloned, MiddlewareDecision::Continue));
    }

    #[test]
    fn arc_middleware_delegates_transforming_on_response() {
        let mw: Arc<dyn Middleware> = Arc::new(TransformResponseMiddleware);
        let ctx = make_ctx();
        let req = make_request();
        let input = serde_json::json!({"x": 2});
        let output = mw.on_response(&ctx, &req, input).unwrap();
        assert_eq!(output["x"], 2);
        assert_eq!(output["transformed"], true);
    }
}