Skip to main content

mcpr_core/proxy/
router.rs

//! Pure `(Request, Context) -> Route` mapping. No I/O.
//!
//! Owns the `BufferPolicy` table — buffering is a routing decision,
//! not an intrinsic method property, so it lives here rather than on
//! the method enum.
6
7use super::pipeline::driver::Router;
8use super::pipeline::stubs::UrlMap;
9use super::pipeline::values::{BufferPolicy, Context, McpTransport, Request, Route};
10use crate::protocol::mcp::{
11    ClientKind, ClientMethod, LifecycleMethod, PromptsMethod, ResourcesMethod, ToolsMethod,
12};
13
/// Stateless [`Router`]: every routing decision is derived from the
/// [`Request`] and [`Context`] handed to [`Router::route`], so the router
/// itself carries no data and is trivially copyable.
#[derive(Debug, Clone, Copy, Default)]
pub struct ProxyRouter;
15
16impl Router for ProxyRouter {
17    fn route(&self, req: &Request, cx: &Context) -> Route {
18        let upstream = cx.intake.proxy.mcp_upstream.clone();
19
20        match req {
21            Request::Mcp(mcp) => match mcp.transport {
22                McpTransport::StreamableHttpPost | McpTransport::StreamableHttpGet => {
23                    let method = match &mcp.kind {
24                        ClientKind::Request(m) => m.clone(),
25                        _ => ClientMethod::Unknown("notification-or-reply".into()),
26                    };
27                    let buffer_policy =
28                        buffer_policy_for(&method, cx.intake.proxy.max_response_body);
29                    Route::McpStreamableHttp {
30                        upstream,
31                        method,
32                        buffer_policy,
33                    }
34                }
35                McpTransport::SseLegacyGet => Route::McpSseLegacy { upstream },
36            },
37            Request::OAuth(_) => Route::Oauth {
38                upstream,
39                rewrite: UrlMap,
40            },
41            Request::Raw(_) => Route::Raw { upstream },
42        }
43    }
44}
45
46/// The buffer-policy table. These 7 methods get their responses parsed
47/// so response middlewares can mutate them (schema ingest, CSP rewrite).
48/// Every other method streams bytes through untouched.
49fn buffer_policy_for(method: &ClientMethod, max: usize) -> BufferPolicy {
50    match method {
51        ClientMethod::Lifecycle(LifecycleMethod::Initialize)
52        | ClientMethod::Tools(ToolsMethod::List)
53        | ClientMethod::Tools(ToolsMethod::Call)
54        | ClientMethod::Resources(ResourcesMethod::List)
55        | ClientMethod::Resources(ResourcesMethod::TemplatesList)
56        | ClientMethod::Resources(ResourcesMethod::Read)
57        | ClientMethod::Prompts(PromptsMethod::List) => BufferPolicy::Buffered { max },
58        _ => BufferPolicy::Streamed,
59    }
60}
61
#[cfg(test)]
// Test names follow `subject__scenario`. rustc's `non_snake_case` lint rejects
// the double underscore, so the lint is allowed for this module only.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use axum::body::Body;
    use axum::http::{HeaderMap, Method};
    use serde_json::Value;

    use crate::protocol::jsonrpc::JsonRpcEnvelope;
    use crate::protocol::mcp::{ClientNotifMethod, CompletionMethod, LoggingMethod, TasksMethod};
    use crate::proxy::pipeline::middlewares::test_support::{
        mcp_request, test_context, test_proxy_state,
    };
    use crate::proxy::pipeline::values::{McpRequest, RawRequest, Request};

    const DEFAULT_MAX: usize = 1 << 20;

    /// Builds an MCP intake request directly from a raw JSON-RPC body.
    ///
    /// `mcp_request` hard-codes `id=1` and takes no transport argument.
    /// Tests that need a notification or a specific transport use this
    /// helper instead.
    fn hand_built_mcp_request(
        transport: McpTransport,
        body: &'static [u8],
        kind: ClientKind,
    ) -> Request {
        let envelope = JsonRpcEnvelope::parse(body).unwrap();
        Request::Mcp(McpRequest {
            transport,
            envelope,
            kind,
            headers: HeaderMap::new(),
            session_hint: None,
        })
    }

    /// Routes `req` and returns the `(method, buffer_policy)` of the
    /// resulting streamable-HTTP route. Panics, printing the route, if
    /// any other route is produced.
    fn route_streamable_http(req: &Request, cx: &Context) -> (ClientMethod, BufferPolicy) {
        match ProxyRouter.route(req, cx) {
            Route::McpStreamableHttp {
                method,
                buffer_policy,
                ..
            } => (method, buffer_policy),
            other => panic!("expected streamable http, got {other:?}"),
        }
    }

    #[test]
    fn buffer_policy_for__seven_buffered_methods_match_legacy_table() {
        let buffered = [
            ClientMethod::Lifecycle(LifecycleMethod::Initialize),
            ClientMethod::Tools(ToolsMethod::List),
            ClientMethod::Tools(ToolsMethod::Call),
            ClientMethod::Resources(ResourcesMethod::List),
            ClientMethod::Resources(ResourcesMethod::TemplatesList),
            ClientMethod::Resources(ResourcesMethod::Read),
            ClientMethod::Prompts(PromptsMethod::List),
        ];
        for m in buffered {
            assert!(
                matches!(buffer_policy_for(&m, DEFAULT_MAX), BufferPolicy::Buffered { max } if max == DEFAULT_MAX),
                "method {m:?} should buffer"
            );
        }
    }

    #[test]
    fn buffer_policy_for__streamed_methods() {
        let streamed = [
            ClientMethod::Ping,
            ClientMethod::Prompts(PromptsMethod::Get),
            ClientMethod::Resources(ResourcesMethod::Subscribe),
            ClientMethod::Resources(ResourcesMethod::Unsubscribe),
            ClientMethod::Completion(CompletionMethod::Complete),
            ClientMethod::Logging(LoggingMethod::SetLevel),
            ClientMethod::Tasks(TasksMethod::List),
            ClientMethod::Tasks(TasksMethod::Get),
            ClientMethod::Tasks(TasksMethod::Result),
            ClientMethod::Tasks(TasksMethod::Cancel),
        ];
        for m in streamed {
            assert_eq!(
                buffer_policy_for(&m, DEFAULT_MAX),
                BufferPolicy::Streamed,
                "method {m:?} should stream"
            );
        }
    }

    #[test]
    fn buffer_policy_for__unknown_is_streamed() {
        assert_eq!(
            buffer_policy_for(&ClientMethod::Unknown("x".into()), DEFAULT_MAX),
            BufferPolicy::Streamed,
        );
    }

    #[tokio::test]
    async fn route__mcp_post_tools_list_is_buffered_streamable_http() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = mcp_request("tools/list", Value::Null, None);
        let (method, buffer_policy) = route_streamable_http(&req, &cx);
        assert!(matches!(method, ClientMethod::Tools(ToolsMethod::List)));
        assert!(matches!(buffer_policy, BufferPolicy::Buffered { .. }));
    }

    #[tokio::test]
    async fn route__buffered_max_comes_from_config() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let expected = cx.intake.proxy.max_response_body;
        let req = mcp_request("tools/list", Value::Null, None);
        let (_, buffer_policy) = route_streamable_http(&req, &cx);
        assert_eq!(buffer_policy, BufferPolicy::Buffered { max: expected });
    }

    #[tokio::test]
    async fn route__mcp_post_ping_is_streamed_streamable_http() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = mcp_request("ping", Value::Null, None);
        let (_, buffer_policy) = route_streamable_http(&req, &cx);
        assert_eq!(buffer_policy, BufferPolicy::Streamed);
    }

    #[tokio::test]
    async fn route__mcp_notification_is_streamed() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = hand_built_mcp_request(
            McpTransport::StreamableHttpPost,
            br#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#,
            ClientKind::Notification(ClientNotifMethod::Initialized),
        );
        let (_, buffer_policy) = route_streamable_http(&req, &cx);
        assert_eq!(buffer_policy, BufferPolicy::Streamed);
    }

    #[tokio::test]
    async fn route__sse_legacy_intake_is_sse_legacy_route() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = hand_built_mcp_request(
            McpTransport::SseLegacyGet,
            br#"{"jsonrpc":"2.0","method":"ping"}"#,
            ClientKind::Notification(ClientNotifMethod::Unknown("ping".into())),
        );
        assert!(matches!(
            ProxyRouter.route(&req, &cx),
            Route::McpSseLegacy { .. }
        ));
    }

    #[tokio::test]
    async fn route__raw_is_raw_route() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = Request::Raw(RawRequest {
            method: Method::GET,
            path: "/health".into(),
            body: Body::empty(),
            headers: HeaderMap::new(),
        });
        assert!(matches!(ProxyRouter.route(&req, &cx), Route::Raw { .. }));
    }

    #[tokio::test]
    async fn route__propagates_upstream_string_from_state() {
        let proxy = test_proxy_state();
        let cx = test_context(proxy);
        let req = mcp_request("tools/list", Value::Null, None);
        match ProxyRouter.route(&req, &cx) {
            Route::McpStreamableHttp { upstream, .. } => {
                assert_eq!(upstream, "http://upstream.test");
            }
            other => panic!("expected streamable http, got {other:?}"),
        }
    }
}