1use std::sync::OnceLock;
9use std::time::Instant;
10
11use async_trait::async_trait;
12
13use super::middleware::{Flow, RequestMiddleware, ResponseMiddleware};
14use super::values::{Context, Request, Response, Route, StageTiming};
15
/// Reports whether per-stage timing collection is switched on.
///
/// The `MCPR_STAGE_TIMING` environment variable is consulted only on the first
/// call; the answer is cached for the rest of the process. The value `"1"` or
/// any ASCII casing of `"true"` enables timing. Any other value, or a variable
/// that is unset or not valid Unicode, disables it.
pub fn stage_timing_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    *ENABLED.get_or_init(|| match std::env::var("MCPR_STAGE_TIMING") {
        Ok(raw) => raw == "1" || raw.eq_ignore_ascii_case("true"),
        Err(_) => false,
    })
}
27
28pub struct StageGuard<'a> {
37 name: &'static str,
38 start: Instant,
39 sink: &'a mut Vec<StageTiming>,
40 enabled: bool,
41}
42
43impl<'a> StageGuard<'a> {
44 pub fn start(name: &'static str, sink: &'a mut Vec<StageTiming>) -> Self {
45 Self {
46 name,
47 start: Instant::now(),
48 sink,
49 enabled: stage_timing_enabled(),
50 }
51 }
52}
53
54impl Drop for StageGuard<'_> {
55 fn drop(&mut self) {
56 if self.enabled {
57 self.sink.push(StageTiming {
58 name: self.name,
59 elapsed_us: self.start.elapsed().as_micros() as u64,
60 });
61 }
62 }
63}
64
/// Chooses the [`Route`] for a request that made it through the request chain.
pub trait Router: Send + Sync {
    /// Returns the route for `req`.
    ///
    /// In [`Pipeline::run`] this is called once per run, and only when no
    /// request middleware short-circuited.
    fn route(&self, req: &Request, cx: &Context) -> Route;
}

#[async_trait]
/// Delivers a routed request and produces its [`Response`].
pub trait Transport: Send + Sync {
    /// Dispatches `req` along `route`.
    ///
    /// The signature returns `Response` rather than `Result`, so failures must
    /// presumably be encoded as a `Response` variant (e.g. `Response::Upstream502`).
    /// NOTE(review): confirm against the concrete transport implementations.
    async fn dispatch(&self, req: Request, route: Route, cx: &mut Context) -> Response;
}
78
79pub struct Pipeline<R: Router, T: Transport> {
80 request_chain: Vec<Box<dyn RequestMiddleware>>,
81 response_chain: Vec<Box<dyn ResponseMiddleware>>,
82 router: R,
83 transport: T,
84}
85
86impl<R: Router, T: Transport> Pipeline<R, T> {
87 pub fn new(
88 request_chain: Vec<Box<dyn RequestMiddleware>>,
89 response_chain: Vec<Box<dyn ResponseMiddleware>>,
90 router: R,
91 transport: T,
92 ) -> Self {
93 Self {
97 request_chain,
98 response_chain,
99 router,
100 transport,
101 }
102 }
103
104 pub fn request_chain_names(&self) -> Vec<&'static str> {
105 self.request_chain.iter().map(|mw| mw.name()).collect()
106 }
107
108 pub fn response_chain_names(&self) -> Vec<&'static str> {
109 self.response_chain.iter().map(|mw| mw.name()).collect()
110 }
111
112 pub async fn run(&self, req: Request, cx: &mut Context) -> Response {
113 let resp = match self.run_request_chain(req, cx).await {
114 Ok(req) => {
115 let route = self.router.route(&req, cx);
116 self.transport.dispatch(req, route, cx).await
117 }
118 Err(short) => short,
119 };
120 self.run_response_chain(resp, cx).await
121 }
122
123 async fn run_request_chain(
124 &self,
125 mut req: Request,
126 cx: &mut Context,
127 ) -> Result<Request, Response> {
128 let enabled = stage_timing_enabled();
129 for mw in &self.request_chain {
130 let started = enabled.then(Instant::now);
131 let flow = mw.on_request(req, cx).await;
132 if let Some(t) = started {
133 cx.working.timings.push(StageTiming {
134 name: mw.name(),
135 elapsed_us: t.elapsed().as_micros() as u64,
136 });
137 }
138 match flow {
139 Flow::Continue(r) => req = r,
140 Flow::ShortCircuit(r) => return Err(r),
141 }
142 }
143 Ok(req)
144 }
145
146 async fn run_response_chain(&self, mut resp: Response, cx: &mut Context) -> Response {
147 let enabled = stage_timing_enabled();
148 for mw in &self.response_chain {
149 let started = enabled.then(Instant::now);
150 resp = mw.on_response(resp, cx).await;
151 if let Some(t) = started {
152 cx.working.timings.push(StageTiming {
153 name: mw.name(),
154 elapsed_us: t.elapsed().as_micros() as u64,
155 });
156 }
157 }
158 resp
159 }
160}
161
#[cfg(test)]
#[allow(non_snake_case)] // test names use `subject__behaviour` double-underscore naming
mod tests {
    // Unit tests for `Pipeline`, built entirely from in-memory fakes so that
    // ordering, short-circuiting and call counts can be observed directly.

    use std::sync::{Arc, Mutex};

    use axum::http::{HeaderMap, StatusCode};
    use serde_json::json;

    use super::*;
    use crate::protocol::jsonrpc::JsonRpcEnvelope;
    use crate::protocol::mcp::{ClientKind, ClientMethod, McpMessage, MessageKind, ToolsMethod};
    use crate::proxy::pipeline::middleware::{Flow, RequestMiddleware, ResponseMiddleware};
    use crate::proxy::pipeline::middlewares::test_support::{test_context, test_proxy_state};
    use crate::proxy::pipeline::values::{
        BufferPolicy, Envelope, McpRequest, McpTransport, Request, Response, Route,
    };

    // Shared invocation counter that a fake increments and a test reads back.
    type Calls = Arc<Mutex<u32>>;

    fn counter() -> Calls {
        Arc::new(Mutex::new(0))
    }

    fn count(calls: &Calls) -> u32 {
        *calls.lock().unwrap()
    }

    // What a `FakeReqMw` does when it sees a request.
    enum FakeReqAction {
        Continue,
        AnnotateTag(&'static str),
        ShortCircuit(&'static str),
    }

    struct FakeReqMw {
        name: &'static str,
        action: FakeReqAction,
    }

    #[async_trait]
    impl RequestMiddleware for FakeReqMw {
        fn name(&self) -> &'static str {
            self.name
        }

        async fn on_request(&self, req: Request, cx: &mut Context) -> Flow {
            match self.action {
                FakeReqAction::Continue => Flow::Continue(req),
                FakeReqAction::AnnotateTag(tag) => {
                    cx.working.tags.push(tag);
                    Flow::Continue(req)
                }
                FakeReqAction::ShortCircuit(reason) => Flow::ShortCircuit(Response::Upstream502 {
                    reason: reason.to_owned(),
                }),
            }
        }
    }

    // Response middleware that records `annotate` in the context's tags.
    struct FakeRespMw {
        name: &'static str,
        annotate: &'static str,
    }

    #[async_trait]
    impl ResponseMiddleware for FakeRespMw {
        fn name(&self) -> &'static str {
            self.name
        }

        async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
            cx.working.tags.push(self.annotate);
            resp
        }
    }

    // Router that hands out a single pre-built route and counts its calls.
    struct FakeRouter {
        route: Mutex<Option<Route>>,
        calls: Calls,
    }

    impl Router for FakeRouter {
        fn route(&self, _req: &Request, _cx: &Context) -> Route {
            *self.calls.lock().unwrap() += 1;
            let mut slot = self.route.lock().unwrap();
            slot.take().expect("FakeRouter called more than once")
        }
    }

    // Transport that hands out a single pre-built response and counts its calls.
    struct FakeTransport {
        response: Mutex<Option<Response>>,
        calls: Calls,
    }

    #[async_trait]
    impl Transport for FakeTransport {
        async fn dispatch(&self, _req: Request, _route: Route, _cx: &mut Context) -> Response {
            *self.calls.lock().unwrap() += 1;
            let mut slot = self.response.lock().unwrap();
            slot.take().expect("FakeTransport called more than once")
        }
    }

    fn req_mw(name: &'static str, action: FakeReqAction) -> Box<dyn RequestMiddleware> {
        Box::new(FakeReqMw { name, action })
    }

    fn resp_mw(name: &'static str, annotate: &'static str) -> Box<dyn ResponseMiddleware> {
        Box::new(FakeRespMw { name, annotate })
    }

    fn fake_router(calls: Calls) -> FakeRouter {
        FakeRouter {
            route: Mutex::new(Some(stub_route())),
            calls,
        }
    }

    fn fake_transport(response: Response, calls: Calls) -> FakeTransport {
        FakeTransport {
            response: Mutex::new(Some(response)),
            calls,
        }
    }

    // Pipeline over the given chains with the stub route/response and
    // throw-away call counters.
    fn build_pipeline(
        request_chain: Vec<Box<dyn RequestMiddleware>>,
        response_chain: Vec<Box<dyn ResponseMiddleware>>,
    ) -> Pipeline<FakeRouter, FakeTransport> {
        Pipeline::new(
            request_chain,
            response_chain,
            fake_router(counter()),
            fake_transport(stub_buffered_response(), counter()),
        )
    }

    // A `tools/list` JSON-RPC request over streamable HTTP POST.
    fn stub_mcp_request() -> Request {
        Request::Mcp(McpRequest {
            transport: McpTransport::StreamableHttpPost,
            envelope: JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#)
                .unwrap(),
            kind: ClientKind::Request(ClientMethod::Tools(ToolsMethod::List)),
            headers: HeaderMap::new(),
            session_hint: None,
        })
    }

    // Wraps an already-parsed server result envelope in a 200 JSON buffered response.
    fn buffered_result(envelope: JsonRpcEnvelope) -> Response {
        Response::McpBuffered {
            envelope: Envelope::Json,
            message: McpMessage {
                envelope,
                kind: MessageKind::Server(crate::protocol::mcp::ServerKind::Result),
            },
            status: StatusCode::OK,
            headers: HeaderMap::new(),
        }
    }

    fn stub_buffered_response() -> Response {
        buffered_result(
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#).unwrap(),
        )
    }

    fn stub_route() -> Route {
        Route::McpStreamableHttp {
            upstream: "http://upstream.test/mcp".into(),
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 4096 },
        }
    }

    #[tokio::test]
    async fn run__empty_chain_returns_transport_response() {
        let mut cx = test_context(test_proxy_state());
        let (router_calls, transport_calls) = (counter(), counter());
        let pipeline = Pipeline::new(
            Vec::new(),
            Vec::new(),
            fake_router(Arc::clone(&router_calls)),
            fake_transport(stub_buffered_response(), Arc::clone(&transport_calls)),
        );

        let resp = pipeline.run(stub_mcp_request(), &mut cx).await;
        assert!(matches!(resp, Response::McpBuffered { .. }));
        assert_eq!(count(&router_calls), 1);
        assert_eq!(count(&transport_calls), 1);
    }

    #[tokio::test]
    async fn run__request_chain_fires_in_order() {
        let mut cx = test_context(test_proxy_state());
        let pipeline = build_pipeline(
            vec![
                req_mw("a", FakeReqAction::AnnotateTag("tag-a")),
                req_mw("b", FakeReqAction::AnnotateTag("tag-b")),
                req_mw("c", FakeReqAction::AnnotateTag("tag-c")),
            ],
            Vec::new(),
        );

        pipeline.run(stub_mcp_request(), &mut cx).await;
        assert_eq!(cx.working.tags.as_slice(), &["tag-a", "tag-b", "tag-c"]);
    }

    #[tokio::test]
    async fn run__short_circuit_skips_router_transport_and_later_request_mws() {
        let mut cx = test_context(test_proxy_state());
        let (router_calls, transport_calls) = (counter(), counter());
        let pipeline = Pipeline::new(
            vec![
                req_mw("before", FakeReqAction::AnnotateTag("before")),
                req_mw("cut", FakeReqAction::ShortCircuit("cut")),
                req_mw("after", FakeReqAction::AnnotateTag("after")),
            ],
            Vec::new(),
            fake_router(Arc::clone(&router_calls)),
            fake_transport(stub_buffered_response(), Arc::clone(&transport_calls)),
        );

        let resp = pipeline.run(stub_mcp_request(), &mut cx).await;
        assert!(matches!(resp, Response::Upstream502 { .. }));
        assert_eq!(cx.working.tags.as_slice(), &["before"]);
        assert_eq!(count(&router_calls), 0);
        assert_eq!(count(&transport_calls), 0);
    }

    #[tokio::test]
    async fn run__response_chain_runs_after_short_circuit() {
        let mut cx = test_context(test_proxy_state());
        let pipeline = build_pipeline(
            vec![req_mw("cut", FakeReqAction::ShortCircuit("x"))],
            vec![resp_mw("r1", "resp-1"), resp_mw("r2", "resp-2")],
        );

        pipeline.run(stub_mcp_request(), &mut cx).await;
        assert_eq!(cx.working.tags.as_slice(), &["resp-1", "resp-2"]);
    }

    #[tokio::test]
    async fn run__response_chain_folds_in_order() {
        let mut cx = test_context(test_proxy_state());
        let pipeline = build_pipeline(
            Vec::new(),
            vec![resp_mw("r1", "r1"), resp_mw("r2", "r2"), resp_mw("r3", "r3")],
        );

        pipeline.run(stub_mcp_request(), &mut cx).await;
        assert_eq!(cx.working.tags.as_slice(), &["r1", "r2", "r3"]);
    }

    #[tokio::test]
    async fn chain_names__reports_registered_middlewares() {
        let pipeline = build_pipeline(
            vec![
                req_mw("session_touch", FakeReqAction::Continue),
                req_mw("client_info_inject", FakeReqAction::Continue),
            ],
            vec![resp_mw("schema_ingest", ""), resp_mw("envelope_seal", "")],
        );

        assert_eq!(
            pipeline.request_chain_names(),
            vec!["session_touch", "client_info_inject"],
        );
        assert_eq!(
            pipeline.response_chain_names(),
            vec!["schema_ingest", "envelope_seal"],
        );
    }

    // End-to-end: one request middleware, the transport, and one response
    // middleware all observe the same context, and the transport's payload
    // arrives intact.
    #[tokio::test]
    async fn smoke__request_response_roundtrip_with_mutation() {
        let mut cx = test_context(test_proxy_state());
        let upstream_reply = buffered_result(
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":42,"result":{"tools":[]}}"#).unwrap(),
        );
        let pipeline = Pipeline::new(
            vec![req_mw("tag", FakeReqAction::AnnotateTag("touched"))],
            vec![resp_mw("tag_resp", "sealed")],
            fake_router(counter()),
            fake_transport(upstream_reply, counter()),
        );

        let resp = pipeline.run(stub_mcp_request(), &mut cx).await;
        match resp {
            Response::McpBuffered {
                status, message, ..
            } => {
                assert_eq!(status, StatusCode::OK);
                let result: serde_json::Value = message
                    .envelope
                    .result_as()
                    .expect("result should deserialize");
                assert_eq!(result, json!({"tools": []}));
            }
            other => panic!("expected McpBuffered, got {other:?}"),
        }
        assert_eq!(cx.working.tags.as_slice(), &["touched", "sealed"]);
    }
}