Skip to main content

mcpr_core/proxy/pipeline/
driver.rs

//! Pipeline driver — the engine that runs the middleware chains, the
//! router, and the transport.
//!
//! See `PIPELINE.md` §Layers. The driver is a short, explicit loop that
//! owns an ordered `Vec<Box<dyn …>>` for each chain. No tower, no
//! service combinators.
7
8use std::sync::OnceLock;
9use std::time::Instant;
10
11use async_trait::async_trait;
12
13use super::middleware::{Flow, RequestMiddleware, ResponseMiddleware};
14use super::values::{Context, Request, Response, Route, StageTiming};
15
/// Reports whether per-stage timing collection is switched on.
///
/// Returns `true` when the `MCPR_STAGE_TIMING` environment variable is
/// exactly `1`, or is `true` in any ASCII letter case. An unset or
/// non-Unicode value counts as off. The environment is consulted on the
/// first call only — the answer is cached in a `OnceLock` for the rest
/// of the process, so every later call is just a load of that cached
/// flag.
pub fn stage_timing_enabled() -> bool {
    static ENABLED: OnceLock<bool> = OnceLock::new();
    *ENABLED.get_or_init(|| match std::env::var("MCPR_STAGE_TIMING") {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    })
}
27
28/// RAII timer — start one at the top of a block, get a
29/// [`StageTiming`] pushed onto `sink` when the guard drops. Honors
30/// [`stage_timing_enabled`] so disabled builds skip the push.
31/// Handles early returns / `?` / panics because `Drop` always runs.
32///
33/// Used by non-middleware sites (transport sub-stages, intake parse).
34/// Middlewares themselves are wrapped by the driver — their bodies
35/// never construct a `StageGuard`.
36pub struct StageGuard<'a> {
37    name: &'static str,
38    start: Instant,
39    sink: &'a mut Vec<StageTiming>,
40    enabled: bool,
41}
42
43impl<'a> StageGuard<'a> {
44    pub fn start(name: &'static str, sink: &'a mut Vec<StageTiming>) -> Self {
45        Self {
46            name,
47            start: Instant::now(),
48            sink,
49            enabled: stage_timing_enabled(),
50        }
51    }
52}
53
54impl Drop for StageGuard<'_> {
55    fn drop(&mut self) {
56        if self.enabled {
57            self.sink.push(StageTiming {
58                name: self.name,
59                elapsed_us: self.start.elapsed().as_micros() as u64,
60            });
61        }
62    }
63}
64
/// Pure function: decide where a request is headed. No I/O.
///
/// The driver calls the router once for each request that gets through
/// the request chain without a short-circuit, and hands the returned
/// [`Route`] to the transport together with the request. The
/// `Send + Sync` supertraits are required of every implementor.
pub trait Router: Send + Sync {
    /// Picks the [`Route`] for `req`. The request and the per-request
    /// context are both taken by shared reference, and the call is
    /// synchronous.
    fn route(&self, req: &Request, cx: &Context) -> Route;
}
69
/// The one layer that touches the network. Reqwest errors become
/// `Response::Upstream502`. Takes `&mut Context` so the transport can
/// write per-stage timings (`upstream_us`, `buffer_us`, `sse_unwrap_us`,
/// `json_parse_us`) onto `cx.working`.
///
/// The driver calls [`Transport::dispatch`] with the request that came
/// out of the request chain and the route the router chose for it.
/// Requests that a middleware short-circuits never reach the transport.
#[async_trait]
pub trait Transport: Send + Sync {
    /// Sends `req` along `route` and returns the resulting [`Response`].
    ///
    /// Both `req` and `route` are consumed. The signature has no error
    /// channel: every outcome, including failure, has to be expressed
    /// as a `Response` value.
    async fn dispatch(&self, req: Request, route: Route, cx: &mut Context) -> Response;
}
78
79pub struct Pipeline<R: Router, T: Transport> {
80    request_chain: Vec<Box<dyn RequestMiddleware>>,
81    response_chain: Vec<Box<dyn ResponseMiddleware>>,
82    router: R,
83    transport: T,
84}
85
86impl<R: Router, T: Transport> Pipeline<R, T> {
87    pub fn new(
88        request_chain: Vec<Box<dyn RequestMiddleware>>,
89        response_chain: Vec<Box<dyn ResponseMiddleware>>,
90        router: R,
91        transport: T,
92    ) -> Self {
93        // Registration logging is handled by `build_default_pipeline`,
94        // which owns construction ordering and is the single site where
95        // operator-visible chain composition needs to be reported.
96        Self {
97            request_chain,
98            response_chain,
99            router,
100            transport,
101        }
102    }
103
104    pub fn request_chain_names(&self) -> Vec<&'static str> {
105        self.request_chain.iter().map(|mw| mw.name()).collect()
106    }
107
108    pub fn response_chain_names(&self) -> Vec<&'static str> {
109        self.response_chain.iter().map(|mw| mw.name()).collect()
110    }
111
112    pub async fn run(&self, req: Request, cx: &mut Context) -> Response {
113        let resp = match self.run_request_chain(req, cx).await {
114            Ok(req) => {
115                let route = self.router.route(&req, cx);
116                self.transport.dispatch(req, route, cx).await
117            }
118            Err(short) => short,
119        };
120        self.run_response_chain(resp, cx).await
121    }
122
123    async fn run_request_chain(
124        &self,
125        mut req: Request,
126        cx: &mut Context,
127    ) -> Result<Request, Response> {
128        let enabled = stage_timing_enabled();
129        for mw in &self.request_chain {
130            let started = enabled.then(Instant::now);
131            let flow = mw.on_request(req, cx).await;
132            if let Some(t) = started {
133                cx.working.timings.push(StageTiming {
134                    name: mw.name(),
135                    elapsed_us: t.elapsed().as_micros() as u64,
136                });
137            }
138            match flow {
139                Flow::Continue(r) => req = r,
140                Flow::ShortCircuit(r) => return Err(r),
141            }
142        }
143        Ok(req)
144    }
145
146    async fn run_response_chain(&self, mut resp: Response, cx: &mut Context) -> Response {
147        let enabled = stage_timing_enabled();
148        for mw in &self.response_chain {
149            let started = enabled.then(Instant::now);
150            resp = mw.on_response(resp, cx).await;
151            if let Some(t) = started {
152                cx.working.timings.push(StageTiming {
153                    name: mw.name(),
154                    elapsed_us: t.elapsed().as_micros() as u64,
155                });
156            }
157        }
158        resp
159    }
160}
161
// Test names follow the `subject__scenario` convention; the double
// underscore trips the `non_snake_case` lint.
#[cfg(test)]
#[allow(non_snake_case)]
mod tests {
    use std::sync::{Arc, Mutex};

    use axum::http::{HeaderMap, StatusCode};
    use serde_json::json;

    use super::*;
    use crate::protocol::jsonrpc::JsonRpcEnvelope;
    use crate::protocol::mcp::{
        ClientKind, ClientMethod, McpMessage, MessageKind, ServerKind, ToolsMethod,
    };
    use crate::proxy::pipeline::middleware::{Flow, RequestMiddleware, ResponseMiddleware};
    use crate::proxy::pipeline::middlewares::test_support::{test_context, test_proxy_state};
    use crate::proxy::pipeline::values::{
        BufferPolicy, Envelope, McpRequest, McpTransport, Request, Response, Route,
    };

    // ── Call counters ───────────────────────────────────────

    /// Shared call counter handed to the fake router and transport.
    type Calls = Arc<Mutex<u32>>;

    fn new_counter() -> Calls {
        Arc::new(Mutex::new(0))
    }

    fn bump(calls: &Mutex<u32>) {
        *calls.lock().unwrap() += 1;
    }

    fn count(calls: &Mutex<u32>) -> u32 {
        *calls.lock().unwrap()
    }

    // ── Fakes ────────────────────────────────────────────────

    /// What a [`FakeReqMw`] does when the driver calls it.
    enum FakeReqAction {
        Continue,
        AnnotateTag(&'static str),
        ShortCircuit(&'static str),
    }

    struct FakeReqMw {
        name: &'static str,
        action: FakeReqAction,
    }

    #[async_trait]
    impl RequestMiddleware for FakeReqMw {
        fn name(&self) -> &'static str {
            self.name
        }

        async fn on_request(&self, req: Request, cx: &mut Context) -> Flow {
            match &self.action {
                FakeReqAction::ShortCircuit(reason) => {
                    return Flow::ShortCircuit(Response::Upstream502 {
                        reason: (*reason).to_owned(),
                    });
                }
                FakeReqAction::AnnotateTag(tag) => {
                    cx.working.tags.push(tag);
                }
                FakeReqAction::Continue => {}
            }
            Flow::Continue(req)
        }
    }

    /// Response middleware that tags the context and passes the
    /// response through untouched.
    struct FakeRespMw {
        name: &'static str,
        annotate: &'static str,
    }

    #[async_trait]
    impl ResponseMiddleware for FakeRespMw {
        fn name(&self) -> &'static str {
            self.name
        }

        async fn on_response(&self, resp: Response, cx: &mut Context) -> Response {
            let tag = self.annotate;
            cx.working.tags.push(tag);
            resp
        }
    }

    /// Router that yields its canned route exactly once and counts calls.
    struct FakeRouter {
        route: Mutex<Option<Route>>,
        calls: Calls,
    }

    impl Router for FakeRouter {
        fn route(&self, _req: &Request, _cx: &Context) -> Route {
            bump(&self.calls);
            let canned = self.route.lock().unwrap().take();
            canned.expect("FakeRouter called more than once")
        }
    }

    /// Transport that yields its canned response exactly once and counts
    /// calls.
    struct FakeTransport {
        response: Mutex<Option<Response>>,
        calls: Calls,
    }

    #[async_trait]
    impl Transport for FakeTransport {
        async fn dispatch(&self, _req: Request, _route: Route, _cx: &mut Context) -> Response {
            bump(&self.calls);
            let canned = self.response.lock().unwrap().take();
            canned.expect("FakeTransport called more than once")
        }
    }

    // ── Harness ─────────────────────────────────────────────

    fn stub_mcp_request() -> Request {
        let envelope =
            JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#).unwrap();
        let kind = ClientKind::Request(ClientMethod::Tools(ToolsMethod::List));
        Request::Mcp(McpRequest {
            transport: McpTransport::StreamableHttpPost,
            envelope,
            kind,
            headers: HeaderMap::new(),
            session_hint: None,
        })
    }

    /// Wraps a raw JSON-RPC result body in a `200 OK` buffered response.
    fn buffered_response(raw: &'static [u8]) -> Response {
        let envelope = JsonRpcEnvelope::parse(raw).unwrap();
        Response::McpBuffered {
            envelope: Envelope::Json,
            message: McpMessage {
                envelope,
                kind: MessageKind::Server(ServerKind::Result),
            },
            status: StatusCode::OK,
            headers: HeaderMap::new(),
        }
    }

    fn stub_buffered_response() -> Response {
        buffered_response(br#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#)
    }

    fn stub_route() -> Route {
        Route::McpStreamableHttp {
            upstream: "http://upstream.test/mcp".into(),
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 4096 },
        }
    }

    fn req_mw(name: &'static str, action: FakeReqAction) -> Box<dyn RequestMiddleware> {
        Box::new(FakeReqMw { name, action })
    }

    fn resp_mw(name: &'static str, annotate: &'static str) -> Box<dyn ResponseMiddleware> {
        Box::new(FakeRespMw { name, annotate })
    }

    fn fake_router(calls: &Calls) -> FakeRouter {
        FakeRouter {
            route: Mutex::new(Some(stub_route())),
            calls: Arc::clone(calls),
        }
    }

    fn fake_transport(response: Response, calls: &Calls) -> FakeTransport {
        FakeTransport {
            response: Mutex::new(Some(response)),
            calls: Arc::clone(calls),
        }
    }

    /// Pipeline over the given chains with the stub route, the stub
    /// buffered response, and throwaway call counters.
    fn build_pipeline(
        request_chain: Vec<Box<dyn RequestMiddleware>>,
        response_chain: Vec<Box<dyn ResponseMiddleware>>,
    ) -> Pipeline<FakeRouter, FakeTransport> {
        Pipeline::new(
            request_chain,
            response_chain,
            fake_router(&new_counter()),
            fake_transport(stub_buffered_response(), &new_counter()),
        )
    }

    // ── Tests ───────────────────────────────────────────────

    #[tokio::test]
    async fn run__empty_chain_returns_transport_response() {
        let mut cx = test_context(test_proxy_state());
        let router_calls = new_counter();
        let transport_calls = new_counter();
        let pipeline = Pipeline::new(
            Vec::new(),
            Vec::new(),
            fake_router(&router_calls),
            fake_transport(stub_buffered_response(), &transport_calls),
        );

        let resp = pipeline.run(stub_mcp_request(), &mut cx).await;

        assert!(matches!(resp, Response::McpBuffered { .. }));
        assert_eq!(count(&router_calls), 1);
        assert_eq!(count(&transport_calls), 1);
    }

    #[tokio::test]
    async fn run__request_chain_fires_in_order() {
        let mut cx = test_context(test_proxy_state());
        let pipeline = build_pipeline(
            vec![
                req_mw("a", FakeReqAction::AnnotateTag("tag-a")),
                req_mw("b", FakeReqAction::AnnotateTag("tag-b")),
                req_mw("c", FakeReqAction::AnnotateTag("tag-c")),
            ],
            Vec::new(),
        );

        pipeline.run(stub_mcp_request(), &mut cx).await;

        assert_eq!(cx.working.tags.as_slice(), &["tag-a", "tag-b", "tag-c"]);
    }

    #[tokio::test]
    async fn run__short_circuit_skips_router_transport_and_later_request_mws() {
        let mut cx = test_context(test_proxy_state());
        let router_calls = new_counter();
        let transport_calls = new_counter();
        let pipeline = Pipeline::new(
            vec![
                req_mw("before", FakeReqAction::AnnotateTag("before")),
                req_mw("cut", FakeReqAction::ShortCircuit("cut")),
                req_mw("after", FakeReqAction::AnnotateTag("after")),
            ],
            Vec::new(),
            fake_router(&router_calls),
            fake_transport(stub_buffered_response(), &transport_calls),
        );

        let resp = pipeline.run(stub_mcp_request(), &mut cx).await;

        assert!(matches!(resp, Response::Upstream502 { .. }));
        assert_eq!(cx.working.tags.as_slice(), &["before"]);
        assert_eq!(count(&router_calls), 0);
        assert_eq!(count(&transport_calls), 0);
    }

    #[tokio::test]
    async fn run__response_chain_runs_after_short_circuit() {
        let mut cx = test_context(test_proxy_state());
        let pipeline = build_pipeline(
            vec![req_mw("cut", FakeReqAction::ShortCircuit("x"))],
            vec![resp_mw("r1", "resp-1"), resp_mw("r2", "resp-2")],
        );

        pipeline.run(stub_mcp_request(), &mut cx).await;

        assert_eq!(cx.working.tags.as_slice(), &["resp-1", "resp-2"]);
    }

    #[tokio::test]
    async fn run__response_chain_folds_in_order() {
        let mut cx = test_context(test_proxy_state());
        let pipeline = build_pipeline(
            Vec::new(),
            vec![
                resp_mw("r1", "r1"),
                resp_mw("r2", "r2"),
                resp_mw("r3", "r3"),
            ],
        );

        pipeline.run(stub_mcp_request(), &mut cx).await;

        assert_eq!(cx.working.tags.as_slice(), &["r1", "r2", "r3"]);
    }

    #[tokio::test]
    async fn chain_names__reports_registered_middlewares() {
        let pipeline = build_pipeline(
            vec![
                req_mw("session_touch", FakeReqAction::Continue),
                req_mw("client_info_inject", FakeReqAction::Continue),
            ],
            vec![resp_mw("schema_ingest", ""), resp_mw("envelope_seal", "")],
        );

        assert_eq!(
            pipeline.request_chain_names(),
            vec!["session_touch", "client_info_inject"],
        );
        assert_eq!(
            pipeline.response_chain_names(),
            vec!["schema_ingest", "envelope_seal"],
        );
    }

    // ── Smoke test — one request through a full stub chain ──

    #[tokio::test]
    async fn smoke__request_response_roundtrip_with_mutation() {
        let mut cx = test_context(test_proxy_state());
        let upstream_reply =
            buffered_response(br#"{"jsonrpc":"2.0","id":42,"result":{"tools":[]}}"#);
        let pipeline = Pipeline::new(
            vec![req_mw("tag", FakeReqAction::AnnotateTag("touched"))],
            vec![resp_mw("tag_resp", "sealed")],
            fake_router(&new_counter()),
            fake_transport(upstream_reply, &new_counter()),
        );

        let resp = pipeline.run(stub_mcp_request(), &mut cx).await;

        match resp {
            Response::McpBuffered {
                status, message, ..
            } => {
                assert_eq!(status, StatusCode::OK);
                let result: serde_json::Value = message
                    .envelope
                    .result_as()
                    .expect("result should deserialize");
                assert_eq!(result, json!({"tools": []}));
            }
            other => panic!("expected McpBuffered, got {other:?}"),
        }
        assert_eq!(cx.working.tags.as_slice(), &["touched", "sealed"]);
    }
}