Skip to main content

mcpr_core/proxy/
intake.rs

//! Content-based classification of axum request parts into the
//! [`Request`] sum type.
//!
//! Rules, checked in order (the first match wins):
//!
//! 1. `POST` + body parses as JSON-RPC → `Request::Mcp` with
//!    `transport: StreamableHttpPost`.
//! 2. `GET` + `Accept: text/event-stream` → `Request::Mcp` with
//!    `transport: SseLegacyGet` (synthetic envelope; downstream matches
//!    on transport variant, not envelope contents).
//! 3. `DELETE` + `mcp-session-id` header → `Request::Mcp` with
//!    `transport: StreamableHttpPost` and a synthetic envelope, so
//!    `SessionDeleteMiddleware` can pattern-match on `Request::Mcp`.
//! 4. Everything else → `Request::Raw`.
//!
//! OAuth classification is deferred. Non-MCP traffic becomes
//! `Request::Raw`; `UrlMapMiddleware` rewrites JSON Raw bodies.
18
19use axum::body::{Body, Bytes};
20use axum::http::{HeaderMap, Method, Uri, header};
21
22use super::pipeline::stubs::SessionId;
23use super::pipeline::values::{McpRequest, McpTransport, RawRequest, Request};
24use crate::protocol::jsonrpc::JsonRpcEnvelope;
25use crate::protocol::mcp::{ClientKind, ClientNotifMethod, classify_client};
26
27pub fn from_axum_parts(method: Method, headers: HeaderMap, uri: Uri, body: Bytes) -> Request {
28    let path = uri.path().to_string();
29
30    // MCP POST — JSON-RPC body parse succeeds.
31    if method == Method::POST
32        && let Ok(envelope) = JsonRpcEnvelope::parse(&body)
33    {
34        let kind = classify_client(&envelope);
35        let session_hint = session_hint_from_headers(&headers);
36        return Request::Mcp(McpRequest {
37            transport: McpTransport::StreamableHttpPost,
38            envelope,
39            kind,
40            headers,
41            session_hint,
42        });
43    }
44
45    // Legacy SSE GET — `Accept: text/event-stream` opens a server-push
46    // stream. Envelope is synthetic; downstream matches on transport.
47    if method == Method::GET && wants_sse(&headers) {
48        let envelope = JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","method":"ping"}"#)
49            .expect("static synthetic envelope parses");
50        let kind = ClientKind::Notification(ClientNotifMethod::Unknown("ping".into()));
51        let session_hint = session_hint_from_headers(&headers);
52        return Request::Mcp(McpRequest {
53            transport: McpTransport::SseLegacyGet,
54            envelope,
55            kind,
56            headers,
57            session_hint,
58        });
59    }
60
61    // Session DELETE — empty-body DELETE + `mcp-session-id`.
62    if method == Method::DELETE
63        && let Some(sid_value) = headers.get("mcp-session-id").cloned()
64    {
65        let envelope = JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","method":"delete"}"#)
66            .expect("static synthetic envelope parses");
67        let kind = ClientKind::Notification(ClientNotifMethod::Unknown("delete".into()));
68        let session_hint = sid_value
69            .to_str()
70            .ok()
71            .map(|s| SessionId::new(s.to_string()));
72        return Request::Mcp(McpRequest {
73            transport: McpTransport::StreamableHttpPost,
74            envelope,
75            kind,
76            headers,
77            session_hint,
78        });
79    }
80
81    Request::Raw(RawRequest {
82        method,
83        path,
84        body: Body::from(body),
85        headers,
86    })
87}
88
89fn wants_sse(headers: &HeaderMap) -> bool {
90    headers
91        .get(header::ACCEPT)
92        .and_then(|v| v.to_str().ok())
93        .map(|a| a.contains("text/event-stream"))
94        .unwrap_or(false)
95}
96
97fn session_hint_from_headers(headers: &HeaderMap) -> Option<SessionId> {
98    headers
99        .get("mcp-session-id")
100        .and_then(|v| v.to_str().ok())
101        .map(|s| SessionId::new(s.to_string()))
102}
103
#[cfg(test)]
// Test names follow a `function__scenario` convention; the double underscore
// trips rustc's `non_snake_case` lint.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use crate::protocol::mcp::{ClientMethod, LifecycleMethod, ToolsMethod};

    /// Parses a request-target string such as `/mcp` into a [`Uri`].
    fn uri(path: &str) -> Uri {
        path.parse().unwrap()
    }

    /// Builds a [`HeaderMap`] from `(name, value)` pairs. Uses `insert`, so a
    /// repeated name overwrites the earlier value.
    fn headers_with(pairs: &[(&str, &str)]) -> HeaderMap {
        let mut h = HeaderMap::new();
        for (k, v) in pairs {
            h.insert(
                axum::http::HeaderName::from_bytes(k.as_bytes()).unwrap(),
                v.parse().unwrap(),
            );
        }
        h
    }

    #[test]
    fn from_axum_parts__post_tools_list_is_mcp_streamable() {
        let req = from_axum_parts(
            Method::POST,
            HeaderMap::new(),
            uri("/mcp"),
            Bytes::from_static(br#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#),
        );
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        assert_eq!(mcp.transport, McpTransport::StreamableHttpPost);
        assert_eq!(
            mcp.kind,
            ClientKind::Request(ClientMethod::Tools(ToolsMethod::List))
        );
    }

    #[test]
    fn from_axum_parts__session_header_populates_hint() {
        let req = from_axum_parts(
            Method::POST,
            headers_with(&[("mcp-session-id", "abc")]),
            uri("/mcp"),
            Bytes::from_static(br#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#),
        );
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        assert_eq!(mcp.session_hint.map(|s| s.0), Some("abc".into()));
    }

    #[test]
    fn from_axum_parts__post_invalid_json_falls_through_to_raw() {
        let req = from_axum_parts(
            Method::POST,
            HeaderMap::new(),
            uri("/mcp"),
            Bytes::from_static(b"not json"),
        );
        assert!(matches!(req, Request::Raw(_)));
    }

    #[test]
    fn from_axum_parts__post_valid_json_but_not_jsonrpc_falls_through_to_raw() {
        let req = from_axum_parts(
            Method::POST,
            HeaderMap::new(),
            uri("/"),
            Bytes::from_static(br#"{"foo":"bar"}"#),
        );
        assert!(matches!(req, Request::Raw(_)));
    }

    #[test]
    fn from_axum_parts__get_with_sse_accept_is_sse_legacy() {
        let req = from_axum_parts(
            Method::GET,
            headers_with(&[("accept", "text/event-stream")]),
            uri("/mcp"),
            Bytes::new(),
        );
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        assert_eq!(mcp.transport, McpTransport::SseLegacyGet);
    }

    #[test]
    fn from_axum_parts__get_with_sse_in_accept_list_is_sse_legacy() {
        let req = from_axum_parts(
            Method::GET,
            headers_with(&[("accept", "application/json, text/event-stream")]),
            uri("/mcp"),
            Bytes::new(),
        );
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        assert_eq!(mcp.transport, McpTransport::SseLegacyGet);
    }

    #[test]
    fn from_axum_parts__sse_get_populates_hint() {
        let req = from_axum_parts(
            Method::GET,
            headers_with(&[("accept", "text/event-stream"), ("mcp-session-id", "abc")]),
            uri("/mcp"),
            Bytes::new(),
        );
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        assert_eq!(mcp.session_hint.map(|s| s.0), Some("abc".into()));
    }

    #[test]
    fn from_axum_parts__get_without_sse_is_raw() {
        let req = from_axum_parts(Method::GET, HeaderMap::new(), uri("/health"), Bytes::new());
        assert!(matches!(req, Request::Raw(_)));
    }

    #[test]
    fn from_axum_parts__raw_keeps_method_and_drops_query_from_path() {
        let req = from_axum_parts(
            Method::GET,
            HeaderMap::new(),
            uri("/health?verbose=1"),
            Bytes::new(),
        );
        let Request::Raw(raw) = req else {
            panic!("expected Raw");
        };
        assert_eq!(raw.method, Method::GET);
        assert_eq!(raw.path, "/health");
    }

    #[test]
    fn from_axum_parts__delete_with_session_id_is_mcp() {
        let req = from_axum_parts(
            Method::DELETE,
            headers_with(&[("mcp-session-id", "abc")]),
            uri("/mcp"),
            Bytes::new(),
        );
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        assert_eq!(mcp.session_hint.map(|s| s.0), Some("abc".into()));
        assert_eq!(mcp.transport, McpTransport::StreamableHttpPost);
    }

    #[test]
    fn from_axum_parts__delete_uses_synthetic_delete_kind() {
        let req = from_axum_parts(
            Method::DELETE,
            headers_with(&[("mcp-session-id", "abc")]),
            uri("/mcp"),
            Bytes::new(),
        );
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        assert_eq!(
            mcp.kind,
            ClientKind::Notification(ClientNotifMethod::Unknown("delete".into()))
        );
    }

    #[test]
    fn from_axum_parts__delete_without_session_id_is_raw() {
        let req = from_axum_parts(Method::DELETE, HeaderMap::new(), uri("/mcp"), Bytes::new());
        assert!(matches!(req, Request::Raw(_)));
    }

    #[test]
    fn from_axum_parts__notification_classified_as_notification() {
        let req = from_axum_parts(
            Method::POST,
            HeaderMap::new(),
            uri("/mcp"),
            Bytes::from_static(br#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#),
        );
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        assert_eq!(
            mcp.kind,
            ClientKind::Notification(ClientNotifMethod::Initialized)
        );
    }

    #[test]
    fn from_axum_parts__initialize_request_stays_mcp() {
        let req = from_axum_parts(
            Method::POST,
            HeaderMap::new(),
            uri("/mcp"),
            Bytes::from_static(br#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#),
        );
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        assert_eq!(
            mcp.kind,
            ClientKind::Request(ClientMethod::Lifecycle(LifecycleMethod::Initialize))
        );
    }
}