1use super::pipeline::driver::Router;
8use super::pipeline::stubs::UrlMap;
9use super::pipeline::values::{BufferPolicy, Context, McpTransport, Request, Route};
10use crate::protocol::mcp::{
11 ClientKind, ClientMethod, LifecycleMethod, PromptsMethod, ResourcesMethod, ToolsMethod,
12};
13
14pub struct ProxyRouter;
15
16impl Router for ProxyRouter {
17 fn route(&self, req: &Request, cx: &Context) -> Route {
18 let upstream = cx.intake.proxy.mcp_upstream.clone();
19
20 match req {
21 Request::Mcp(mcp) => match mcp.transport {
22 McpTransport::StreamableHttpPost | McpTransport::StreamableHttpGet => {
23 let method = match &mcp.kind {
24 ClientKind::Request(m) => m.clone(),
25 _ => ClientMethod::Unknown("notification-or-reply".into()),
26 };
27 let buffer_policy =
28 buffer_policy_for(&method, cx.intake.proxy.max_response_body);
29 Route::McpStreamableHttp {
30 upstream,
31 method,
32 buffer_policy,
33 }
34 }
35 McpTransport::SseLegacyGet => Route::McpSseLegacy { upstream },
36 },
37 Request::OAuth(_) => Route::Oauth {
38 upstream,
39 rewrite: UrlMap,
40 },
41 Request::Raw(_) => Route::Raw { upstream },
42 }
43 }
44}
45
46fn buffer_policy_for(method: &ClientMethod, max: usize) -> BufferPolicy {
50 match method {
51 ClientMethod::Lifecycle(LifecycleMethod::Initialize)
52 | ClientMethod::Tools(ToolsMethod::List)
53 | ClientMethod::Tools(ToolsMethod::Call)
54 | ClientMethod::Resources(ResourcesMethod::List)
55 | ClientMethod::Resources(ResourcesMethod::TemplatesList)
56 | ClientMethod::Resources(ResourcesMethod::Read)
57 | ClientMethod::Prompts(PromptsMethod::List) => BufferPolicy::Buffered { max },
58 _ => BufferPolicy::Streamed,
59 }
60}
61
#[cfg(test)]
// Test names follow a `subject__behaviour` double-underscore convention, which the
// `non_snake_case` lint rejects (it flags consecutive underscores).
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::{HeaderMap, Method};
    use serde_json::Value;

    use crate::protocol::jsonrpc::JsonRpcEnvelope;
    use crate::protocol::mcp::{ClientNotifMethod, CompletionMethod, LoggingMethod, TasksMethod};
    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_request, test_context, test_proxy_state,
    };
    use crate::proxy::pipeline::values::{McpRequest, McpTransport, RawRequest, Request};

    const DEFAULT_MAX: usize = 1 << 20;

    /// Builds an MCP request from a literal JSON-RPC body, with an explicit transport and
    /// classified kind, for cases `mcp_request` does not cover.
    fn raw_mcp_request(
        transport: McpTransport,
        body: &'static [u8],
        kind: ClientKind,
    ) -> Request {
        let envelope = JsonRpcEnvelope::parse(body).unwrap();
        Request::Mcp(McpRequest {
            transport,
            envelope,
            kind,
            headers: HeaderMap::new(),
            session_hint: None,
        })
    }

    #[test]
    fn buffer_policy_for__seven_buffered_methods_match_legacy_table() {
        let buffered = [
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
            ClientMethod::Tools(ToolsMethod::List),
            ClientMethod::Tools(ToolsMethod::Call),
            ClientMethod::Resources(ResourcesMethod::List),
            ClientMethod::Resources(ResourcesMethod::TemplatesList),
            ClientMethod::Resources(ResourcesMethod::Read),
            ClientMethod::Prompts(PromptsMethod::List),
        ];
        for m in buffered {
            assert_eq!(
                buffer_policy_for(&m, DEFAULT_MAX),
                BufferPolicy::Buffered { max: DEFAULT_MAX },
                "method {m:?} should buffer"
            );
        }
    }

    #[test]
    fn buffer_policy_for__streamed_methods() {
        let streamed = [
            ClientMethod::Ping,
            ClientMethod::Prompts(PromptsMethod::Get),
            ClientMethod::Resources(ResourcesMethod::Subscribe),
            ClientMethod::Resources(ResourcesMethod::Unsubscribe),
            ClientMethod::Completion(CompletionMethod::Complete),
            ClientMethod::Logging(LoggingMethod::SetLevel),
            ClientMethod::Tasks(TasksMethod::List),
            ClientMethod::Tasks(TasksMethod::Get),
            ClientMethod::Tasks(TasksMethod::Result),
            ClientMethod::Tasks(TasksMethod::Cancel),
        ];
        for m in streamed {
            assert_eq!(
                buffer_policy_for(&m, DEFAULT_MAX),
                BufferPolicy::Streamed,
                "method {m:?} should stream"
            );
        }
    }

    #[test]
    fn buffer_policy_for__unknown_is_streamed() {
        assert_eq!(
            buffer_policy_for(&ClientMethod::Unknown("x".into()), DEFAULT_MAX),
            BufferPolicy::Streamed,
        );
    }

    #[tokio::test]
    async fn route__mcp_post_tools_list_is_buffered_streamable_http() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        // The buffer cap must come from the proxy state, not from a hard-coded default.
        let expected_max = cx.intake.proxy.max_response_body;
        let req = mcp_request("tools/list", Value::Null, None);
        match ProxyRouter.route(&req, &cx) {
            Route::McpStreamableHttp {
                method,
                buffer_policy,
                ..
            } => {
                assert!(matches!(method, ClientMethod::Tools(ToolsMethod::List)));
                assert_eq!(buffer_policy, BufferPolicy::Buffered { max: expected_max });
            }
            other => panic!("expected streamable http, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn route__mcp_get_shares_streamable_http_routing() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = raw_mcp_request(
            McpTransport::StreamableHttpGet,
            br#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#,
            ClientKind::Request(ClientMethod::Tools(ToolsMethod::List)),
        );
        match ProxyRouter.route(&req, &cx) {
            Route::McpStreamableHttp {
                method,
                buffer_policy,
                ..
            } => {
                assert!(matches!(method, ClientMethod::Tools(ToolsMethod::List)));
                assert!(matches!(buffer_policy, BufferPolicy::Buffered { .. }));
            }
            other => panic!("expected streamable http, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn route__mcp_post_ping_is_streamed_streamable_http() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = mcp_request("ping", Value::Null, None);
        match ProxyRouter.route(&req, &cx) {
            Route::McpStreamableHttp { buffer_policy, .. } => {
                assert_eq!(buffer_policy, BufferPolicy::Streamed);
            }
            other => panic!("expected streamable http, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn route__mcp_notification_is_streamed() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = raw_mcp_request(
            McpTransport::StreamableHttpPost,
            br#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#,
            ClientKind::Notification(ClientNotifMethod::Initialized),
        );
        match ProxyRouter.route(&req, &cx) {
            Route::McpStreamableHttp { buffer_policy, .. } => {
                assert_eq!(buffer_policy, BufferPolicy::Streamed);
            }
            other => panic!("expected streamable http, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn route__sse_legacy_intake_is_sse_legacy_route() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = raw_mcp_request(
            McpTransport::SseLegacyGet,
            br#"{"jsonrpc":"2.0","method":"ping"}"#,
            ClientKind::Notification(ClientNotifMethod::Unknown("ping".into())),
        );
        assert!(matches!(
            ProxyRouter.route(&req, &cx),
            Route::McpSseLegacy { .. }
        ));
    }

    #[tokio::test]
    async fn route__raw_is_raw_route() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = Request::Raw(RawRequest {
            method: Method::GET,
            path: "/health".into(),
            body: Body::empty(),
            headers: HeaderMap::new(),
        });
        assert!(matches!(ProxyRouter.route(&req, &cx), Route::Raw { .. }));
    }

    #[tokio::test]
    async fn route__propagates_upstream_string_from_state() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = mcp_request("tools/list", Value::Null, None);
        match ProxyRouter.route(&req, &cx) {
            Route::McpStreamableHttp { upstream, .. } => {
                assert_eq!(upstream, "http://upstream.test");
            }
            _ => panic!("expected streamable http"),
        }
    }
}