Skip to main content

mcpr_core/proxy/
transport.rs

1//! The single I/O layer for the pipeline.
2//!
3//! Wraps `forward_request` from `proxy/forwarding.rs` and maps reqwest
4//! failures to [`Response::Upstream502`]. Buffer vs stream decision
5//! comes from the `Route` — no content-type sniffing at dispatch time.
6
7use std::sync::Arc;
8use std::time::Instant;
9
10use async_trait::async_trait;
11use axum::body::{Body, Bytes};
12use axum::http::{HeaderMap, Method, StatusCode, header};
13
14use super::ProxyState;
15use super::forwarding::{forward_request, read_body_capped};
16use super::pipeline::driver::{StageGuard, Transport};
17use super::pipeline::values::{
18    BufferPolicy, Context, Envelope, McpRequest, RawRequest, Request, Response, Route, Working,
19};
20use super::sse::{extract_json_from_sse, split_upstream};
21use crate::protocol::jsonrpc::JsonRpcEnvelope;
22use crate::protocol::mcp::{McpMessage, MessageKind, classify_server};
23
24pub struct ProxyTransport;
25
26#[async_trait]
27impl Transport for ProxyTransport {
28    async fn dispatch(&self, req: Request, route: Route, cx: &mut Context) -> Response {
29        let state = cx.intake.proxy.clone();
30        match (req, route) {
31            (
32                Request::Mcp(mcp),
33                Route::McpStreamableHttp {
34                    upstream,
35                    buffer_policy,
36                    ..
37                },
38            ) => dispatch_mcp_post(state, mcp, upstream, buffer_policy, &mut cx.working).await,
39            (Request::Mcp(mcp), Route::McpSseLegacy { upstream }) => {
40                dispatch_sse_legacy(state, mcp, upstream, &mut cx.working).await
41            }
42            (Request::Raw(raw), Route::Raw { upstream }) => {
43                dispatch_raw(state, raw, upstream, &mut cx.working).await
44            }
45            // Intake never produces `Request::OAuth` today; arm is defensive.
46            (Request::OAuth(_), Route::Oauth { .. }) => Response::Upstream502 {
47                reason: "oauth dispatch not implemented".into(),
48            },
49            _ => Response::Upstream502 {
50                reason: "intake/router variant mismatch".into(),
51            },
52        }
53    }
54}
55
56async fn dispatch_mcp_post(
57    state: Arc<ProxyState>,
58    mcp: McpRequest,
59    upstream: String,
60    buffer_policy: BufferPolicy,
61    working: &mut Working,
62) -> Response {
63    let body_bytes = Bytes::from(mcp.envelope.to_bytes());
64    let is_streaming = matches!(buffer_policy, BufferPolicy::Streamed);
65    let upstream_started = Instant::now();
66    let resp = match forward_request(
67        &state.upstream,
68        &upstream,
69        Method::POST,
70        &mcp.headers,
71        &body_bytes,
72        is_streaming,
73    )
74    .await
75    {
76        Ok(r) => r,
77        Err(e) => {
78            working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
79            return Response::Upstream502 {
80                reason: format!("{e}"),
81            };
82        }
83    };
84    working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
85    let status = resp.status();
86    let headers = resp.headers().clone();
87
88    match buffer_policy {
89        BufferPolicy::Buffered { max } => {
90            buffer_and_parse(resp, max, status, headers, working).await
91        }
92        BufferPolicy::Streamed => Response::McpStreamed {
93            envelope: Envelope::Json,
94            body: Body::from_stream(resp.bytes_stream()),
95            status,
96            headers,
97        },
98    }
99}
100
101async fn buffer_and_parse(
102    resp: reqwest::Response,
103    max: usize,
104    status: StatusCode,
105    headers: HeaderMap,
106    working: &mut Working,
107) -> Response {
108    // `StageGuard` pushes a named timing on drop. Each wrapping block
109    // scopes one phase; `?` / early returns still fire the guard's
110    // Drop, so failure paths are measured too.
111    let raw = {
112        let _g = StageGuard::start("transport_buffer", &mut working.timings);
113        match read_body_capped(resp, max).await {
114            Ok(b) => b,
115            Err(e) => {
116                return Response::Upstream502 {
117                    reason: e.to_string(),
118                };
119            }
120        }
121    };
122
123    let was_sse = headers
124        .get(header::CONTENT_TYPE)
125        .and_then(|v| v.to_str().ok())
126        .map(|ct| ct.contains("text/event-stream"))
127        .unwrap_or(false);
128    let json_bytes: Vec<u8> = if was_sse {
129        let _g = StageGuard::start("transport_sse_unwrap", &mut working.timings);
130        extract_json_from_sse(&raw).unwrap_or_else(|| raw.to_vec())
131    } else {
132        raw.to_vec()
133    };
134
135    let envelope = {
136        let _g = StageGuard::start("transport_json_parse", &mut working.timings);
137        match JsonRpcEnvelope::parse(&json_bytes) {
138            Ok(e) => e,
139            Err(_) => {
140                return Response::Raw {
141                    body: Body::from(raw),
142                    status,
143                    headers,
144                };
145            }
146        }
147    };
148
149    let kind = MessageKind::Server(classify_server(&envelope));
150    let message = McpMessage { envelope, kind };
151    Response::McpBuffered {
152        envelope: if was_sse {
153            Envelope::Sse
154        } else {
155            Envelope::Json
156        },
157        message,
158        status,
159        headers,
160    }
161}
162
163async fn dispatch_sse_legacy(
164    state: Arc<ProxyState>,
165    mcp: McpRequest,
166    upstream: String,
167    working: &mut Working,
168) -> Response {
169    let empty = Bytes::new();
170    let upstream_started = Instant::now();
171    let resp = match forward_request(
172        &state.upstream,
173        &upstream,
174        Method::GET,
175        &mcp.headers,
176        &empty,
177        true,
178    )
179    .await
180    {
181        Ok(r) => r,
182        Err(e) => {
183            working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
184            return Response::Upstream502 {
185                reason: format!("{e}"),
186            };
187        }
188    };
189    working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
190    let status = resp.status();
191    let headers = resp.headers().clone();
192    Response::McpStreamed {
193        envelope: Envelope::Sse,
194        body: Body::from_stream(resp.bytes_stream()),
195        status,
196        headers,
197    }
198}
199
200async fn dispatch_raw(
201    state: Arc<ProxyState>,
202    raw: RawRequest,
203    upstream: String,
204    working: &mut Working,
205) -> Response {
206    let (base, _) = split_upstream(&upstream);
207    let url = format!("{}{}", base.trim_end_matches('/'), raw.path);
208    // Passthrough does not cap the request body — `DefaultBodyLimit`
209    // at the axum edge already rejected oversize requests, so everything
210    // reaching here is within the configured limit.
211    let body_bytes = axum::body::to_bytes(raw.body, usize::MAX)
212        .await
213        .unwrap_or_default();
214    let upstream_started = Instant::now();
215    let resp = match forward_request(
216        &state.upstream,
217        &url,
218        raw.method,
219        &raw.headers,
220        &body_bytes,
221        false,
222    )
223    .await
224    {
225        Ok(r) => r,
226        Err(e) => {
227            working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
228            return Response::Upstream502 {
229                reason: format!("{e}"),
230            };
231        }
232    };
233    working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
234    let status = resp.status();
235    let headers = resp.headers().clone();
236    let body_bytes = {
237        let _g = StageGuard::start("transport_buffer", &mut working.timings);
238        match read_body_capped(resp, state.max_response_body).await {
239            Ok(b) => b,
240            Err(e) => {
241                return Response::Upstream502 {
242                    reason: e.to_string(),
243                };
244            }
245        }
246    };
247    Response::Raw {
248        body: Body::from(body_bytes),
249        status,
250        headers,
251    }
252}
253
#[cfg(test)]
// Test names follow a `subject__expected_behaviour` convention. rustc's
// `non_snake_case` lint rejects the double underscore.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use std::sync::{Arc as StdArc, Mutex};

    use axum::Router as AxumRouter;
    use axum::extract::State;
    use axum::http::{HeaderMap, HeaderValue, Request as AxumRequest, StatusCode};
    use axum::response::IntoResponse;
    use axum::routing::{any, post};
    use serde_json::Value;
    use tokio::net::TcpListener;

    use crate::protocol::mcp::{ClientMethod, ServerKind, ToolsMethod};
    use crate::proxy::pipeline::middlewares::test_support::{
        test_context, test_proxy_state_upstream,
    };
    use crate::proxy::pipeline::values::{McpTransport, RawRequest, Request};

    /// Serves `app` on an ephemeral localhost port and returns its base URL.
    async fn spawn_upstream(app: AxumRouter) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
        format!("http://{addr}")
    }

    /// Builds a Streamable-HTTP MCP request from a JSON-RPC `body`, optionally
    /// carrying an `mcp-session-id` header.
    fn mcp_request(method: &str, body: &str, session: Option<&str>) -> McpRequest {
        let envelope = JsonRpcEnvelope::parse(body.as_bytes()).unwrap();
        let mut headers = HeaderMap::new();
        if let Some(sid) = session {
            headers.insert("mcp-session-id", HeaderValue::from_str(sid).unwrap());
        }
        McpRequest {
            transport: McpTransport::StreamableHttpPost,
            envelope,
            kind: crate::protocol::mcp::ClientKind::Request(ClientMethod::parse(method)),
            headers,
            session_hint: None,
        }
    }

    #[tokio::test]
    async fn dispatch__mcp_post_tools_list_buffered_returns_mcp_buffered_result() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "application/json")],
                    r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"name":"a"}]}}"#,
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        match out {
            Response::McpBuffered {
                envelope, message, ..
            } => {
                assert_eq!(envelope, Envelope::Json);
                assert!(matches!(
                    message.kind,
                    MessageKind::Server(ServerKind::Result)
                ));
                let v: Value = message.envelope.result_as().unwrap();
                assert_eq!(v["tools"][0]["name"], "a");
            }
            other => panic!("expected McpBuffered, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch__mcp_post_buffered_sse_wrapped_response_unwraps_envelope_sse() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                let body = format!(
                    "data: {}\n\n",
                    r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#
                );
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "text/event-stream")],
                    body,
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        match out {
            Response::McpBuffered { envelope, .. } => assert_eq!(envelope, Envelope::Sse),
            other => panic!("expected McpBuffered, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch__mcp_post_buffered_oversize_body_returns_502() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                // 8 KiB body. The content-length header is auto-populated,
                // so `read_body_capped` short-circuits on the size mismatch.
                let body = vec![b'x'; 8 * 1024];
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "application/json")],
                    body,
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1024 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Upstream502 { .. }));
    }

    #[tokio::test]
    async fn dispatch__mcp_post_buffered_non_jsonrpc_falls_back_to_raw() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "text/html")],
                    "<!DOCTYPE html><html></html>",
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Raw { .. }));
    }

    #[tokio::test]
    async fn dispatch__mcp_post_streamed_forwards_body_unchanged() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "application/json")],
                    r#"{"jsonrpc":"2.0","id":1,"result":{"pong":true}}"#,
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request("ping", r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#, None);
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Ping,
            buffer_policy: BufferPolicy::Streamed,
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        match out {
            Response::McpStreamed { body, status, .. } => {
                assert_eq!(status, StatusCode::OK);
                let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
                let s = std::str::from_utf8(&bytes).unwrap();
                assert!(s.contains("\"pong\":true"), "got {s}");
            }
            other => panic!("expected McpStreamed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch__mcp_sse_legacy_returns_streamed_sse() {
        let app = AxumRouter::new().route(
            "/mcp",
            any(|| async {
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "text/event-stream")],
                    "data: hi\n\n",
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = McpRequest {
            transport: McpTransport::SseLegacyGet,
            envelope: JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","method":"ping"}"#).unwrap(),
            kind: crate::protocol::mcp::ClientKind::Notification(
                crate::protocol::mcp::ClientNotifMethod::Unknown("ping".into()),
            ),
            headers: HeaderMap::new(),
            session_hint: None,
        };
        let route = Route::McpSseLegacy { upstream: url };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(
            out,
            Response::McpStreamed {
                envelope: Envelope::Sse,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn dispatch__raw_appends_path_to_upstream_base() {
        #[derive(Clone)]
        struct Shared(StdArc<Mutex<Option<String>>>);
        let recorded = Shared(StdArc::new(Mutex::new(None)));
        let app = AxumRouter::new()
            .route(
                "/token",
                any(
                    |State(Shared(slot)): State<Shared>, req: AxumRequest<axum::body::Body>| async move {
                        *slot.lock().unwrap() = Some(req.uri().path().to_string());
                        (StatusCode::OK, "ok").into_response()
                    },
                ),
            )
            .with_state(recorded.clone());
        let base = spawn_upstream(app).await;
        let proxy = test_proxy_state_upstream(format!("{base}/mcp"));
        let mut cx = test_context(proxy);
        let req = RawRequest {
            method: Method::POST,
            path: "/token".into(),
            body: Body::from("grant_type=x"),
            headers: HeaderMap::new(),
        };
        let route = Route::Raw {
            upstream: format!("{base}/mcp"),
        };

        let out = ProxyTransport
            .dispatch(Request::Raw(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Raw { .. }));
        assert_eq!(
            recorded.0.lock().unwrap().as_deref(),
            Some("/token"),
            "upstream should have seen /token",
        );
    }

    #[tokio::test]
    async fn dispatch__upstream_unreachable_is_502() {
        // Port 1 is a privileged port that is normally unbound, so the
        // connection attempt is expected to be refused.
        let url = "http://127.0.0.1:1".to_string();
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Upstream502 { .. }));
    }

    #[tokio::test]
    async fn dispatch__variant_mismatch_is_502() {
        let proxy = test_proxy_state_upstream("http://unused.test".to_string());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::Raw {
            upstream: "http://unused.test".into(),
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Upstream502 { reason } if reason.contains("mismatch")));
    }

    #[tokio::test]
    async fn dispatch__session_header_is_forwarded() {
        #[derive(Clone)]
        struct Shared(StdArc<Mutex<Option<String>>>);
        let recorded = Shared(StdArc::new(Mutex::new(None)));
        let app = AxumRouter::new()
            .route(
                "/mcp",
                post(
                    |State(Shared(slot)): State<Shared>, headers: HeaderMap| async move {
                        let sid = headers
                            .get("mcp-session-id")
                            .and_then(|v| v.to_str().ok())
                            .map(|s| s.to_string());
                        *slot.lock().unwrap() = sid;
                        (
                            StatusCode::OK,
                            [(axum::http::header::CONTENT_TYPE, "application/json")],
                            r#"{"jsonrpc":"2.0","id":1,"result":{}}"#,
                        )
                            .into_response()
                    },
                ),
            )
            .with_state(recorded.clone());
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            Some("abc-123"),
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::McpBuffered { .. }));
        // No sleep needed. The handler records the header before it returns
        // its response, and `dispatch` only returns after that response has
        // arrived, so the slot is already populated here.
        assert_eq!(
            recorded.0.lock().unwrap().as_deref(),
            Some("abc-123"),
            "upstream should have seen the mcp-session-id header",
        );
    }
}