1use axum::body::{Body, Bytes};
20use axum::http::{HeaderMap, Method, Uri, header};
21
22use super::pipeline::stubs::SessionId;
23use super::pipeline::values::{McpRequest, McpTransport, RawRequest, Request};
24use crate::protocol::jsonrpc::JsonRpcEnvelope;
25use crate::protocol::mcp::{ClientKind, ClientNotifMethod, classify_client};
26
27pub fn from_axum_parts(method: Method, headers: HeaderMap, uri: Uri, body: Bytes) -> Request {
28 let path = uri.path().to_string();
29
30 if method == Method::POST
32 && let Ok(envelope) = JsonRpcEnvelope::parse(&body)
33 {
34 let kind = classify_client(&envelope);
35 let session_hint = session_hint_from_headers(&headers);
36 return Request::Mcp(McpRequest {
37 transport: McpTransport::StreamableHttpPost,
38 envelope,
39 kind,
40 headers,
41 session_hint,
42 });
43 }
44
45 if method == Method::GET && wants_sse(&headers) {
48 let envelope = JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","method":"ping"}"#)
49 .expect("static synthetic envelope parses");
50 let kind = ClientKind::Notification(ClientNotifMethod::Unknown("ping".into()));
51 let session_hint = session_hint_from_headers(&headers);
52 return Request::Mcp(McpRequest {
53 transport: McpTransport::SseLegacyGet,
54 envelope,
55 kind,
56 headers,
57 session_hint,
58 });
59 }
60
61 if method == Method::DELETE
63 && let Some(sid_value) = headers.get("mcp-session-id").cloned()
64 {
65 let envelope = JsonRpcEnvelope::parse(br#"{"jsonrpc":"2.0","method":"delete"}"#)
66 .expect("static synthetic envelope parses");
67 let kind = ClientKind::Notification(ClientNotifMethod::Unknown("delete".into()));
68 let session_hint = sid_value
69 .to_str()
70 .ok()
71 .map(|s| SessionId::new(s.to_string()));
72 return Request::Mcp(McpRequest {
73 transport: McpTransport::StreamableHttpPost,
74 envelope,
75 kind,
76 headers,
77 session_hint,
78 });
79 }
80
81 Request::Raw(RawRequest {
82 method,
83 path,
84 body: Body::from(body),
85 headers,
86 })
87}
88
89fn wants_sse(headers: &HeaderMap) -> bool {
90 headers
91 .get(header::ACCEPT)
92 .and_then(|v| v.to_str().ok())
93 .map(|a| a.contains("text/event-stream"))
94 .unwrap_or(false)
95}
96
97fn session_hint_from_headers(headers: &HeaderMap) -> Option<SessionId> {
98 headers
99 .get("mcp-session-id")
100 .and_then(|v| v.to_str().ok())
101 .map(|s| SessionId::new(s.to_string()))
102}
103
#[cfg(test)]
// Test names follow `function__scenario`; the double underscore trips `non_snake_case`.
#[allow(non_snake_case)]
mod tests {
    use super::*;

    use crate::protocol::mcp::{ClientMethod, LifecycleMethod, ToolsMethod};

    /// A JSON-RPC `tools/list` request body shared by several POST tests.
    const TOOLS_LIST: &[u8] = br#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#;

    /// Parses a request-target string into a [`Uri`], panicking on malformed input.
    fn uri(target: &str) -> Uri {
        target.parse().unwrap()
    }

    /// Builds a header map from `(name, value)` pairs; a repeated name keeps its last value.
    fn headers_with(pairs: &[(&str, &str)]) -> HeaderMap {
        pairs.iter().fold(HeaderMap::new(), |mut map, &(name, value)| {
            let name = axum::http::HeaderName::from_bytes(name.as_bytes()).unwrap();
            map.insert(name, value.parse().unwrap());
            map
        })
    }

    /// Unwraps `Request::Mcp`, failing the test with "expected Mcp" for any other variant.
    fn expect_mcp(req: Request) -> McpRequest {
        let Request::Mcp(mcp) = req else {
            panic!("expected Mcp");
        };
        mcp
    }

    /// Classifies a header-less POST to `/mcp` carrying the given static body.
    fn post_to_mcp(body: &'static [u8]) -> Request {
        from_axum_parts(Method::POST, HeaderMap::new(), uri("/mcp"), Bytes::from_static(body))
    }

    #[test]
    fn from_axum_parts__post_tools_list_is_mcp_streamable() {
        let mcp = expect_mcp(post_to_mcp(TOOLS_LIST));
        assert_eq!(mcp.transport, McpTransport::StreamableHttpPost);
        assert_eq!(
            mcp.kind,
            ClientKind::Request(ClientMethod::Tools(ToolsMethod::List))
        );
    }

    #[test]
    fn from_axum_parts__session_header_populates_hint() {
        let headers = headers_with(&[("mcp-session-id", "abc")]);
        let body = Bytes::from_static(TOOLS_LIST);
        let hint = expect_mcp(from_axum_parts(Method::POST, headers, uri("/mcp"), body))
            .session_hint
            .map(|s| s.0);
        assert_eq!(hint, Some("abc".into()));
    }

    #[test]
    fn from_axum_parts__post_invalid_json_falls_through_to_raw() {
        assert!(matches!(post_to_mcp(b"not json"), Request::Raw(_)));
    }

    #[test]
    fn from_axum_parts__post_valid_json_but_not_jsonrpc_falls_through_to_raw() {
        let body = Bytes::from_static(br#"{"foo":"bar"}"#);
        let req = from_axum_parts(Method::POST, HeaderMap::new(), uri("/"), body);
        assert!(matches!(req, Request::Raw(_)));
    }

    #[test]
    fn from_axum_parts__get_with_sse_accept_is_sse_legacy() {
        let accept_sse = headers_with(&[("accept", "text/event-stream")]);
        let mcp = expect_mcp(from_axum_parts(Method::GET, accept_sse, uri("/mcp"), Bytes::new()));
        assert_eq!(mcp.transport, McpTransport::SseLegacyGet);
    }

    #[test]
    fn from_axum_parts__get_without_sse_is_raw() {
        let outcome = from_axum_parts(Method::GET, HeaderMap::new(), uri("/health"), Bytes::new());
        assert!(matches!(outcome, Request::Raw(_)));
    }

    #[test]
    fn from_axum_parts__delete_with_session_id_is_mcp() {
        let session = headers_with(&[("mcp-session-id", "abc")]);
        let mcp = expect_mcp(from_axum_parts(Method::DELETE, session, uri("/mcp"), Bytes::new()));
        assert_eq!(mcp.session_hint.map(|s| s.0), Some("abc".into()));
        assert_eq!(mcp.transport, McpTransport::StreamableHttpPost);
    }

    #[test]
    fn from_axum_parts__delete_without_session_id_is_raw() {
        let outcome = from_axum_parts(Method::DELETE, HeaderMap::new(), uri("/mcp"), Bytes::new());
        assert!(matches!(outcome, Request::Raw(_)));
    }

    #[test]
    fn from_axum_parts__notification_classified_as_notification() {
        let body = br#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        let mcp = expect_mcp(post_to_mcp(body));
        assert_eq!(
            mcp.kind,
            ClientKind::Notification(ClientNotifMethod::Initialized)
        );
    }

    #[test]
    fn from_axum_parts__initialize_request_stays_mcp() {
        let body = br#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#;
        let mcp = expect_mcp(post_to_mcp(body));
        assert_eq!(
            mcp.kind,
            ClientKind::Request(ClientMethod::Lifecycle(LifecycleMethod::Initialize))
        );
    }
}