1use std::sync::Arc;
8use std::time::Instant;
9
10use async_trait::async_trait;
11use axum::body::{Body, Bytes};
12use axum::http::{HeaderMap, Method, StatusCode, header};
13
14use super::ProxyState;
15use super::forwarding::{forward_request, read_body_capped};
16use super::pipeline::driver::{StageGuard, Transport};
17use super::pipeline::values::{
18 BufferPolicy, Context, Envelope, McpRequest, RawRequest, Request, Response, Route, Working,
19};
20use super::sse::{extract_json_from_sse, split_upstream};
21use crate::protocol::jsonrpc::JsonRpcEnvelope;
22use crate::protocol::mcp::{McpMessage, MessageKind, classify_server};
23
24pub struct ProxyTransport;
25
26#[async_trait]
27impl Transport for ProxyTransport {
28 async fn dispatch(&self, req: Request, route: Route, cx: &mut Context) -> Response {
29 let state = cx.intake.proxy.clone();
30 match (req, route) {
31 (
32 Request::Mcp(mcp),
33 Route::McpStreamableHttp {
34 upstream,
35 buffer_policy,
36 ..
37 },
38 ) => dispatch_mcp_post(state, mcp, upstream, buffer_policy, &mut cx.working).await,
39 (Request::Mcp(mcp), Route::McpSseLegacy { upstream }) => {
40 dispatch_sse_legacy(state, mcp, upstream, &mut cx.working).await
41 }
42 (Request::Raw(raw), Route::Raw { upstream }) => {
43 dispatch_raw(state, raw, upstream, &mut cx.working).await
44 }
45 (Request::OAuth(_), Route::Oauth { .. }) => Response::Upstream502 {
47 reason: "oauth dispatch not implemented".into(),
48 },
49 _ => Response::Upstream502 {
50 reason: "intake/router variant mismatch".into(),
51 },
52 }
53 }
54}
55
56async fn dispatch_mcp_post(
57 state: Arc<ProxyState>,
58 mcp: McpRequest,
59 upstream: String,
60 buffer_policy: BufferPolicy,
61 working: &mut Working,
62) -> Response {
63 let body_bytes = Bytes::from(mcp.envelope.to_bytes());
64 let is_streaming = matches!(buffer_policy, BufferPolicy::Streamed);
65 let upstream_started = Instant::now();
66 let resp = match forward_request(
67 &state.upstream,
68 &upstream,
69 Method::POST,
70 &mcp.headers,
71 &body_bytes,
72 is_streaming,
73 )
74 .await
75 {
76 Ok(r) => r,
77 Err(e) => {
78 working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
79 return Response::Upstream502 {
80 reason: format!("{e}"),
81 };
82 }
83 };
84 working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
85 let status = resp.status();
86 let headers = resp.headers().clone();
87
88 match buffer_policy {
89 BufferPolicy::Buffered { max } => {
90 buffer_and_parse(resp, max, status, headers, working).await
91 }
92 BufferPolicy::Streamed => Response::McpStreamed {
93 envelope: Envelope::Json,
94 body: Body::from_stream(resp.bytes_stream()),
95 status,
96 headers,
97 },
98 }
99}
100
101async fn buffer_and_parse(
102 resp: reqwest::Response,
103 max: usize,
104 status: StatusCode,
105 headers: HeaderMap,
106 working: &mut Working,
107) -> Response {
108 let raw = {
112 let _g = StageGuard::start("transport_buffer", &mut working.timings);
113 match read_body_capped(resp, max).await {
114 Ok(b) => b,
115 Err(e) => {
116 return Response::Upstream502 {
117 reason: e.to_string(),
118 };
119 }
120 }
121 };
122
123 let was_sse = headers
124 .get(header::CONTENT_TYPE)
125 .and_then(|v| v.to_str().ok())
126 .map(|ct| ct.contains("text/event-stream"))
127 .unwrap_or(false);
128 let json_bytes: Vec<u8> = if was_sse {
129 let _g = StageGuard::start("transport_sse_unwrap", &mut working.timings);
130 extract_json_from_sse(&raw).unwrap_or_else(|| raw.to_vec())
131 } else {
132 raw.to_vec()
133 };
134
135 let envelope = {
136 let _g = StageGuard::start("transport_json_parse", &mut working.timings);
137 match JsonRpcEnvelope::parse(&json_bytes) {
138 Ok(e) => e,
139 Err(_) => {
140 return Response::Raw {
141 body: Body::from(raw),
142 status,
143 headers,
144 };
145 }
146 }
147 };
148
149 let kind = MessageKind::Server(classify_server(&envelope));
150 let message = McpMessage { envelope, kind };
151 Response::McpBuffered {
152 envelope: if was_sse {
153 Envelope::Sse
154 } else {
155 Envelope::Json
156 },
157 message,
158 status,
159 headers,
160 }
161}
162
163async fn dispatch_sse_legacy(
164 state: Arc<ProxyState>,
165 mcp: McpRequest,
166 upstream: String,
167 working: &mut Working,
168) -> Response {
169 let empty = Bytes::new();
170 let upstream_started = Instant::now();
171 let resp = match forward_request(
172 &state.upstream,
173 &upstream,
174 Method::GET,
175 &mcp.headers,
176 &empty,
177 true,
178 )
179 .await
180 {
181 Ok(r) => r,
182 Err(e) => {
183 working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
184 return Response::Upstream502 {
185 reason: format!("{e}"),
186 };
187 }
188 };
189 working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
190 let status = resp.status();
191 let headers = resp.headers().clone();
192 Response::McpStreamed {
193 envelope: Envelope::Sse,
194 body: Body::from_stream(resp.bytes_stream()),
195 status,
196 headers,
197 }
198}
199
200async fn dispatch_raw(
201 state: Arc<ProxyState>,
202 raw: RawRequest,
203 upstream: String,
204 working: &mut Working,
205) -> Response {
206 let (base, _) = split_upstream(&upstream);
207 let url = format!("{}{}", base.trim_end_matches('/'), raw.path);
208 let body_bytes = axum::body::to_bytes(raw.body, usize::MAX)
212 .await
213 .unwrap_or_default();
214 let upstream_started = Instant::now();
215 let resp = match forward_request(
216 &state.upstream,
217 &url,
218 raw.method,
219 &raw.headers,
220 &body_bytes,
221 false,
222 )
223 .await
224 {
225 Ok(r) => r,
226 Err(e) => {
227 working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
228 return Response::Upstream502 {
229 reason: format!("{e}"),
230 };
231 }
232 };
233 working.upstream_us = Some(upstream_started.elapsed().as_micros() as u64);
234 let status = resp.status();
235 let headers = resp.headers().clone();
236 let body_bytes = {
237 let _g = StageGuard::start("transport_buffer", &mut working.timings);
238 match read_body_capped(resp, state.max_response_body).await {
239 Ok(b) => b,
240 Err(e) => {
241 return Response::Upstream502 {
242 reason: e.to_string(),
243 };
244 }
245 }
246 };
247 Response::Raw {
248 body: Body::from(body_bytes),
249 status,
250 headers,
251 }
252}
253
/// Integration-style tests for [`ProxyTransport`]. Each test spins up a real
/// axum upstream on an ephemeral localhost port and dispatches through the
/// transport.
#[cfg(test)]
// Test names use a `subject__behaviour` double-underscore convention.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use std::sync::{Arc as StdArc, Mutex};

    use axum::Router as AxumRouter;
    use axum::extract::State;
    use axum::http::{HeaderMap, HeaderValue, Request as AxumRequest, StatusCode};
    use axum::response::IntoResponse;
    use axum::routing::{any, post};
    use serde_json::Value;
    use tokio::net::TcpListener;

    use crate::protocol::mcp::{ClientMethod, ServerKind, ToolsMethod};
    use crate::proxy::pipeline::middlewares::test_support::{
        test_context, test_proxy_state_upstream,
    };
    use crate::proxy::pipeline::values::{McpTransport, RawRequest, Request};

    /// Serves `app` on an ephemeral 127.0.0.1 port in a background task and
    /// returns its base URL (`http://127.0.0.1:PORT`).
    async fn spawn_upstream(app: AxumRouter) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move { axum::serve(listener, app).await.unwrap() });
        format!("http://{addr}")
    }

    /// Builds a Streamable-HTTP `McpRequest` from a JSON-RPC `body`. The
    /// request is classified as a client request for `method`, and
    /// `mcp-session-id` is set when `session` is given.
    fn mcp_request(method: &str, body: &str, session: Option<&str>) -> McpRequest {
        let envelope = JsonRpcEnvelope::parse(body.as_bytes()).unwrap();
        let mut headers = HeaderMap::new();
        if let Some(sid) = session {
            headers.insert("mcp-session-id", HeaderValue::from_str(sid).unwrap());
        }
        McpRequest {
            transport: McpTransport::StreamableHttpPost,
            envelope,
            kind: crate::protocol::mcp::ClientKind::Request(ClientMethod::parse(method)),
            headers,
            session_hint: None,
        }
    }

    #[tokio::test]
    async fn dispatch__mcp_post_tools_list_buffered_returns_mcp_buffered_result() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "application/json")],
                    r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[{"name":"a"}]}}"#,
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        match out {
            Response::McpBuffered {
                envelope, message, ..
            } => {
                assert_eq!(envelope, Envelope::Json);
                assert!(matches!(
                    message.kind,
                    MessageKind::Server(ServerKind::Result)
                ));
                let v: Value = message.envelope.result_as().unwrap();
                assert_eq!(v["tools"][0]["name"], "a");
            }
            other => panic!("expected McpBuffered, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch__mcp_post_buffered_sse_wrapped_response_unwraps_envelope_sse() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                let body = format!(
                    "data: {}\n\n",
                    r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#
                );
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "text/event-stream")],
                    body,
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        match out {
            Response::McpBuffered { envelope, .. } => assert_eq!(envelope, Envelope::Sse),
            other => panic!("expected McpBuffered, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch__mcp_post_buffered_oversize_body_returns_502() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                // 8 KiB body against a 1 KiB buffer cap below.
                let body = vec![b'x'; 8 * 1024];
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "application/json")],
                    body,
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1024 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Upstream502 { .. }));
    }

    #[tokio::test]
    async fn dispatch__mcp_post_buffered_non_jsonrpc_falls_back_to_raw() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "text/html")],
                    "<!DOCTYPE html><html></html>",
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Raw { .. }));
    }

    #[tokio::test]
    async fn dispatch__mcp_post_streamed_forwards_body_unchanged() {
        let app = AxumRouter::new().route(
            "/mcp",
            post(|| async {
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "application/json")],
                    r#"{"jsonrpc":"2.0","id":1,"result":{"pong":true}}"#,
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request("ping", r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#, None);
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Ping,
            buffer_policy: BufferPolicy::Streamed,
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        match out {
            Response::McpStreamed { body, status, .. } => {
                assert_eq!(status, StatusCode::OK);
                let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
                let s = std::str::from_utf8(&bytes).unwrap();
                assert!(s.contains("\"pong\":true"), "got {s}");
            }
            other => panic!("expected McpStreamed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn dispatch__mcp_sse_legacy_returns_streamed_sse() {
        let app = AxumRouter::new().route(
            "/mcp",
            any(|| async {
                (
                    StatusCode::OK,
                    [(axum::http::header::CONTENT_TYPE, "text/event-stream")],
                    "data: hi\n\n",
                )
            }),
        );
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = McpRequest {
            transport: McpTransport::SseLegacyGet,
            envelope: JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","method":"ping"}"#).unwrap(),
            kind: crate::protocol::mcp::ClientKind::Notification(
                crate::protocol::mcp::ClientNotifMethod::Unknown("ping".into()),
            ),
            headers: HeaderMap::new(),
            session_hint: None,
        };
        let route = Route::McpSseLegacy { upstream: url };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(
            out,
            Response::McpStreamed {
                envelope: Envelope::Sse,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn dispatch__raw_appends_path_to_upstream_base() {
        #[derive(Clone)]
        struct Shared(StdArc<Mutex<Option<String>>>);
        let recorded = Shared(StdArc::new(Mutex::new(None)));
        let app = AxumRouter::new()
            .route(
                "/token",
                any(
                    |State(Shared(slot)): State<Shared>, req: AxumRequest<axum::body::Body>| async move {
                        *slot.lock().unwrap() = Some(req.uri().path().to_string());
                        (StatusCode::OK, "ok").into_response()
                    },
                ),
            )
            .with_state(recorded.clone());
        let base = spawn_upstream(app).await;
        let proxy = test_proxy_state_upstream(format!("{base}/mcp"));
        let mut cx = test_context(proxy);
        let req = RawRequest {
            method: Method::POST,
            path: "/token".into(),
            body: Body::from("grant_type=x"),
            headers: HeaderMap::new(),
        };
        let route = Route::Raw {
            upstream: format!("{base}/mcp"),
        };

        let out = ProxyTransport
            .dispatch(Request::Raw(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Raw { .. }));
        assert_eq!(
            recorded.0.lock().unwrap().as_deref(),
            Some("/token"),
            "upstream should have seen /token",
        );
    }

    #[tokio::test]
    async fn dispatch__upstream_unreachable_is_502() {
        // Port 1 on localhost is expected to refuse connections.
        let url = "http://127.0.0.1:1".to_string();
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Upstream502 { .. }));
    }

    #[tokio::test]
    async fn dispatch__variant_mismatch_is_502() {
        let proxy = test_proxy_state_upstream("http://unused.test".to_string());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            None,
        );
        let route = Route::Raw {
            upstream: "http://unused.test".into(),
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        assert!(matches!(out, Response::Upstream502 { reason } if reason.contains("mismatch")));
    }

    #[tokio::test]
    async fn dispatch__session_header_is_forwarded() {
        #[derive(Clone)]
        struct Shared(StdArc<Mutex<Option<String>>>);
        let recorded = Shared(StdArc::new(Mutex::new(None)));
        let app = AxumRouter::new()
            .route(
                "/mcp",
                post(
                    |State(Shared(slot)): State<Shared>, headers: HeaderMap| async move {
                        let sid = headers
                            .get("mcp-session-id")
                            .and_then(|v| v.to_str().ok())
                            .map(|s| s.to_string());
                        *slot.lock().unwrap() = sid;
                        (
                            StatusCode::OK,
                            [(axum::http::header::CONTENT_TYPE, "application/json")],
                            r#"{"jsonrpc":"2.0","id":1,"result":{}}"#,
                        )
                            .into_response()
                    },
                ),
            )
            .with_state(recorded.clone());
        let url = format!("{}/mcp", spawn_upstream(app).await);
        let proxy = test_proxy_state_upstream(url.clone());
        let mut cx = test_context(proxy);
        let req = mcp_request(
            "tools/list",
            r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            Some("abc-123"),
        );
        let route = Route::McpStreamableHttp {
            upstream: url,
            method: ClientMethod::Tools(ToolsMethod::List),
            buffer_policy: BufferPolicy::Buffered { max: 1 << 20 },
        };

        let out = ProxyTransport
            .dispatch(Request::Mcp(req), route, &mut cx)
            .await;
        // No sleep is needed. The handler records the header before returning
        // its response, and `dispatch` returns only after receiving that
        // response (and, under the buffered policy, its body).
        assert!(matches!(out, Response::McpBuffered { .. }), "got {out:?}");
        assert_eq!(
            recorded.0.lock().unwrap().as_deref(),
            Some("abc-123"),
            "upstream should have seen the mcp-session-id header",
        );
    }
}