1use crate::{
7 AuthHandler, CallToolResult, ClientCapabilities, ClientInfo, InitializeRequest, JsonRpcError,
8 JsonRpcRequest, JsonRpcResponse, ListToolsResult, MCP_PROTOCOL_VERSION, Tool, ToolCapabilities,
9};
10use protocol_transport_core::{ProtocolError, TransportError};
11use serde_json::json;
12use std::collections::HashMap;
13use std::sync::Mutex;
14
15#[cfg(feature = "sse-client")]
16use crate::ToolProvider;
17#[cfg(feature = "sse-client")]
18use protocol_transport_core::{SseTransport, Transport, TransportFactory, UniversalRequest};
19
// Media types negotiated with a Streamable HTTP server.

/// Media type of plain JSON request and response bodies.
const CONTENT_TYPE_JSON: &str = "application/json";
/// Media type of a response delivered as a server-sent event stream.
const CONTENT_TYPE_EVENT_STREAM: &str = "text/event-stream";

// Header names written on outgoing requests (response lookups go through a
// case-insensitive search).

/// Header advertising which response media types the client accepts.
const HEADER_ACCEPT: &str = "Accept";
/// Header carrying the bearer token.
const HEADER_AUTHORIZATION: &str = "Authorization";
/// Header describing the body media type.
const HEADER_CONTENT_TYPE: &str = "Content-Type";
/// Header carrying the server-assigned MCP session identifier.
const HEADER_MCP_SESSION_ID: &str = "Mcp-Session-Id";
26
27enum ClientTransport {
28 StreamableHttp(StreamableHttpClientTransport),
29
30 #[cfg(feature = "sse-client")]
31 Sse {
32 transport: SseTransport,
33 },
34}
35
36struct StreamableHttpClientTransport {
37 endpoint: String,
38 auth_token: Option<String>,
39 extra_headers: HashMap<String, String>,
40 client_info: ClientInfo,
41 initialized: Mutex<bool>,
42 protocol_version: Mutex<Option<String>>,
43 session_id: Mutex<Option<String>>,
44 next_id: Mutex<u64>,
45}
46
47impl StreamableHttpClientTransport {
48 fn new(endpoint: impl Into<String>) -> Self {
49 Self {
50 endpoint: endpoint.into(),
51 auth_token: None,
52 extra_headers: HashMap::new(),
53 client_info: ClientInfo {
54 name: "promptfleet-mcp-client".to_string(),
55 version: env!("CARGO_PKG_VERSION").to_string(),
56 description: Some("PromptFleet Streamable HTTP MCP client".to_string()),
57 },
58 initialized: Mutex::new(false),
59 protocol_version: Mutex::new(None),
60 session_id: Mutex::new(None),
61 next_id: Mutex::new(0),
62 }
63 }
64
65 fn with_auth_token(mut self, token: impl Into<String>) -> Self {
66 self.auth_token = Some(token.into());
67 self
68 }
69
70 fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
71 self.extra_headers = headers;
72 self
73 }
74
75 fn with_client_info(mut self, client_info: ClientInfo) -> Self {
76 self.client_info = client_info;
77 self
78 }
79
80 async fn initialize_if_needed(&self) -> Result<(), ProtocolError> {
81 let already_initialized = *self
82 .initialized
83 .lock()
84 .map_err(|_| ProtocolError::internal_error("streamable client init mutex poisoned"))?;
85 if already_initialized {
86 return Ok(());
87 }
88
89 let init_request = InitializeRequest {
90 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
91 capabilities: ClientCapabilities {
92 tools: Some(ToolCapabilities { supported: true }),
93 },
94 client_info: self.client_info.clone(),
95 };
96
97 let result = self
98 .send_jsonrpc_raw(
99 "initialize",
100 Some(
101 serde_json::to_value(init_request)
102 .map_err(|e| ProtocolError::Parsing(format!("init serialize: {e}")))?,
103 ),
104 )
105 .await?;
106
107 let negotiated_protocol_version = result
108 .get("protocolVersion")
109 .or_else(|| result.get("protocol_version"))
110 .and_then(|value| value.as_str())
111 .map(ToString::to_string);
112 if negotiated_protocol_version.is_none() {
113 return Err(ProtocolError::Parsing(
114 "invalid initialize result: missing protocolVersion".to_string(),
115 ));
116 }
117 *self.protocol_version.lock().map_err(|_| {
118 ProtocolError::internal_error("streamable client protocol-version mutex poisoned")
119 })? = negotiated_protocol_version;
120
121 self.send_notification_raw("notifications/initialized", None)
122 .await?;
123
124 let mut initialized = self
125 .initialized
126 .lock()
127 .map_err(|_| ProtocolError::internal_error("streamable client init mutex poisoned"))?;
128 *initialized = true;
129 Ok(())
130 }
131
132 async fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
133 let result = self
134 .send_jsonrpc("tools/list", Some(json!({})), true)
135 .await?;
136 let list_result: ListToolsResult = serde_json::from_value(result)
137 .map_err(|e| ProtocolError::Parsing(format!("invalid tools list format: {e}")))?;
138 Ok(list_result.tools)
139 }
140
141 async fn call_tool(
142 &self,
143 name: &str,
144 arguments: Option<serde_json::Value>,
145 meta: Option<serde_json::Value>,
146 ) -> Result<CallToolResult, ProtocolError> {
147 let result = self
148 .send_jsonrpc(
149 "tools/call",
150 Some(json!({
151 "name": name,
152 "arguments": arguments,
153 "_meta": meta,
154 })),
155 true,
156 )
157 .await?;
158
159 serde_json::from_value(result)
160 .map_err(|e| ProtocolError::Parsing(format!("invalid tool call result format: {e}")))
161 }
162
163 async fn send_jsonrpc(
164 &self,
165 method: &str,
166 params: Option<serde_json::Value>,
167 require_initialized: bool,
168 ) -> Result<serde_json::Value, ProtocolError> {
169 if require_initialized {
170 self.initialize_if_needed().await?;
171 }
172 self.send_jsonrpc_raw(method, params).await
173 }
174
175 async fn send_jsonrpc_raw(
176 &self,
177 method: &str,
178 params: Option<serde_json::Value>,
179 ) -> Result<serde_json::Value, ProtocolError> {
180 let id = {
181 let mut next_id = self.next_id.lock().map_err(|_| {
182 ProtocolError::internal_error("streamable client request-id mutex poisoned")
183 })?;
184 *next_id += 1;
185 *next_id
186 };
187
188 let request = JsonRpcRequest {
189 jsonrpc: "2.0".to_string(),
190 id: Some(json!(id)),
191 method: method.to_string(),
192 params,
193 };
194
195 let body = serde_json::to_vec(&request)?;
196 let response = self.http_post(body).await?;
197
198 if let Some(session_id) = find_header(&response.headers, HEADER_MCP_SESSION_ID) {
199 let mut stored = self.session_id.lock().map_err(|_| {
200 ProtocolError::internal_error("streamable client session mutex poisoned")
201 })?;
202 *stored = Some(session_id.to_string());
203 }
204
205 let content_type = find_header(&response.headers, HEADER_CONTENT_TYPE)
206 .map(|value| value.to_ascii_lowercase())
207 .unwrap_or_else(|| CONTENT_TYPE_JSON.to_string());
208
209 let rpc_response = if content_type.contains(CONTENT_TYPE_JSON) {
210 serde_json::from_slice::<JsonRpcResponse>(&response.body).map_err(|e| {
211 ProtocolError::Parsing(format!("invalid JSON-RPC response body: {e}"))
212 })?
213 } else if content_type.contains(CONTENT_TYPE_EVENT_STREAM) {
214 parse_sse_jsonrpc_response(&response.body)?
215 } else {
216 return Err(ProtocolError::Parsing(format!(
217 "unsupported response content-type '{content_type}'"
218 )));
219 };
220
221 if let Some(error) = rpc_response.error {
222 return Err(protocol_error_from_jsonrpc(error));
223 }
224
225 rpc_response
226 .result
227 .ok_or_else(|| ProtocolError::Parsing("missing JSON-RPC result field".to_string()))
228 }
229
230 async fn send_notification_raw(
231 &self,
232 method: &str,
233 params: Option<serde_json::Value>,
234 ) -> Result<(), ProtocolError> {
235 let request = JsonRpcRequest {
236 jsonrpc: "2.0".to_string(),
237 id: None,
238 method: method.to_string(),
239 params,
240 };
241
242 let body = serde_json::to_vec(&request)
243 .map_err(|e| ProtocolError::Parsing(format!("request serialize: {e}")))?;
244 let _ = self.http_post(body).await?;
245 Ok(())
246 }
247
248 async fn http_post(&self, body: Vec<u8>) -> Result<HttpResponse, ProtocolError> {
249 let mut headers = HashMap::new();
250 for (key, value) in &self.extra_headers {
251 if !key.eq_ignore_ascii_case(HEADER_MCP_SESSION_ID) {
252 headers.insert(key.clone(), value.clone());
253 }
254 }
255 headers.insert(
256 HEADER_ACCEPT.to_string(),
257 format!("{CONTENT_TYPE_JSON}, {CONTENT_TYPE_EVENT_STREAM}"),
258 );
259 headers.insert(
260 HEADER_CONTENT_TYPE.to_string(),
261 CONTENT_TYPE_JSON.to_string(),
262 );
263 if let Some(protocol_version) = self
264 .protocol_version
265 .lock()
266 .map_err(|_| {
267 ProtocolError::internal_error("streamable client protocol-version mutex poisoned")
268 })?
269 .clone()
270 {
271 headers.insert("MCP-Protocol-Version".to_string(), protocol_version);
272 }
273
274 if let Some(token) = &self.auth_token {
275 headers.insert(HEADER_AUTHORIZATION.to_string(), format!("Bearer {token}"));
276 }
277
278 if let Some(session_id) = self
279 .session_id
280 .lock()
281 .map_err(|_| ProtocolError::internal_error("streamable client session mutex poisoned"))?
282 .clone()
283 {
284 headers.insert(HEADER_MCP_SESSION_ID.to_string(), session_id);
285 }
286
287 #[cfg(target_arch = "wasm32")]
288 {
289 use spin_sdk::http::{Method, Request as SpinRequest, Response as SpinResponse, send};
290
291 let mut builder = SpinRequest::builder();
292 builder.method(Method::Post);
293 builder.uri(&self.endpoint);
294 for (key, value) in &headers {
295 builder.header(key, value);
296 }
297 let request = builder.body(body).build();
298 let response: SpinResponse = send(request).await.map_err(|e| {
299 ProtocolError::Transport(TransportError::Network(format!(
300 "Spin HTTP send failed: {e}"
301 )))
302 })?;
303
304 let response_headers = response
305 .headers()
306 .filter_map(|(name, value)| {
307 value
308 .as_str()
309 .map(|value| (name.to_string(), value.to_string()))
310 })
311 .collect::<HashMap<_, _>>();
312 let status = *response.status();
313 let body = response.body().to_vec();
314
315 if !(200..300).contains(&status) {
316 return Err(ProtocolError::Transport(TransportError::Http {
317 status,
318 message: format!("streamable HTTP request failed with status {status}"),
319 body: Some(body),
320 headers: Some(response_headers),
321 }));
322 }
323
324 Ok(HttpResponse {
325 headers: response_headers,
326 body,
327 })
328 }
329
330 #[cfg(not(target_arch = "wasm32"))]
331 {
332 let client = reqwest::Client::new();
333 let mut request = client.post(&self.endpoint);
334 for (key, value) in &headers {
335 request = request.header(key, value);
336 }
337 let response = request.body(body).send().await.map_err(|e| {
338 ProtocolError::Transport(TransportError::Network(format!(
339 "streamable HTTP request failed: {e}"
340 )))
341 })?;
342
343 let status = response.status().as_u16();
344 let response_headers = response
345 .headers()
346 .iter()
347 .filter_map(|(name, value)| {
348 value
349 .to_str()
350 .ok()
351 .map(|value| (name.to_string(), value.to_string()))
352 })
353 .collect::<HashMap<_, _>>();
354 let body = response.bytes().await.map_err(|e| {
355 ProtocolError::Transport(TransportError::Network(format!(
356 "streamable HTTP response read failed: {e}"
357 )))
358 })?;
359 let body = body.to_vec();
360
361 if !(200..300).contains(&status) {
362 return Err(ProtocolError::Transport(TransportError::Http {
363 status,
364 message: format!("streamable HTTP request failed with status {status}"),
365 body: Some(body),
366 headers: Some(response_headers),
367 }));
368 }
369
370 Ok(HttpResponse {
371 headers: response_headers,
372 body,
373 })
374 }
375 }
376}
377
/// A successful (2xx) HTTP exchange as seen by the Streamable HTTP transport.
struct HttpResponse {
    /// Response headers whose values could be read as strings, keyed by the
    /// name the HTTP client reported.
    headers: HashMap<String, String>,
    /// Raw response body.
    body: Vec<u8>,
}
382
/// Returns the value of header `name`, comparing names ASCII-case-insensitively.
///
/// If several keys differ only in case, which one is returned depends on the
/// map's iteration order and is unspecified.
fn find_header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
    for (candidate, value) in headers {
        if candidate.eq_ignore_ascii_case(name) {
            return Some(value.as_str());
        }
    }
    None
}
389
390fn protocol_error_from_jsonrpc(error: JsonRpcError) -> ProtocolError {
391 let details = error
392 .data
393 .map(|value| format!(" data={value}"))
394 .unwrap_or_default();
395 ProtocolError::Validation(format!(
396 "JSON-RPC error {}: {}{}",
397 error.code, error.message, details
398 ))
399}
400
401fn parse_sse_jsonrpc_response(body: &[u8]) -> Result<JsonRpcResponse, ProtocolError> {
402 let text = std::str::from_utf8(body)
403 .map_err(|e| ProtocolError::Parsing(format!("invalid UTF-8 event-stream body: {e}")))?;
404 let mut data_lines = Vec::new();
405
406 for line in text.lines() {
407 if let Some(rest) = line.strip_prefix("data:") {
408 data_lines.push(rest.trim_start().to_string());
409 continue;
410 }
411
412 if line.trim().is_empty() && !data_lines.is_empty() {
413 let payload = data_lines.join("\n");
414 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&payload) {
415 return Ok(response);
416 }
417 data_lines.clear();
418 }
419 }
420
421 if !data_lines.is_empty() {
422 let payload = data_lines.join("\n");
423 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&payload) {
424 return Ok(response);
425 }
426 }
427
428 Err(ProtocolError::Parsing(format!(
429 "event-stream response did not contain an MCP JSON-RPC payload; legacy SSE-only endpoints are unsupported; body={text:?}"
430 )))
431}
432
433pub struct McpClient {
435 auth_handler: Option<Box<dyn AuthHandler>>,
436 transport: Option<ClientTransport>,
437}
438
439impl McpClient {
440 pub fn new() -> Self {
442 Self {
443 auth_handler: None,
444 transport: None,
445 }
446 }
447
448 pub fn with_auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
450 self.auth_handler = Some(Box::new(handler));
451 self
452 }
453
454 pub fn with_streamable_http_server(mut self, endpoint: &str) -> Self {
456 self.transport = Some(ClientTransport::StreamableHttp(
457 StreamableHttpClientTransport::new(endpoint),
458 ));
459 self
460 }
461
462 pub fn with_streamable_http_server_auth(mut self, endpoint: &str, auth_token: &str) -> Self {
464 self.transport = Some(ClientTransport::StreamableHttp(
465 StreamableHttpClientTransport::new(endpoint).with_auth_token(auth_token),
466 ));
467 self
468 }
469
470 pub fn with_streamable_http_headers(mut self, headers: HashMap<String, String>) -> Self {
472 let transport = match self.transport.take() {
473 Some(ClientTransport::StreamableHttp(transport)) => {
474 ClientTransport::StreamableHttp(transport.with_headers(headers))
475 }
476 other => {
477 self.transport = other;
478 return self;
479 }
480 };
481 self.transport = Some(transport);
482 self
483 }
484
485 pub fn with_streamable_http_client_info(mut self, client_info: ClientInfo) -> Self {
487 let transport = match self.transport.take() {
488 Some(ClientTransport::StreamableHttp(transport)) => {
489 ClientTransport::StreamableHttp(transport.with_client_info(client_info))
490 }
491 other => {
492 self.transport = other;
493 return self;
494 }
495 };
496 self.transport = Some(transport);
497 self
498 }
499
500 #[cfg(feature = "sse-client")]
502 pub fn with_sse_server(mut self, endpoint: &str) -> Self {
503 self.transport = Some(ClientTransport::Sse {
504 transport: TransportFactory::mcp_sse(endpoint),
505 });
506 self
507 }
508
509 #[cfg(feature = "sse-client")]
511 pub fn with_sse_server_auth(mut self, endpoint: &str, auth_token: &str) -> Self {
512 self.transport = Some(ClientTransport::Sse {
513 transport: TransportFactory::mcp_sse_auth(endpoint, auth_token),
514 });
515 self
516 }
517
518 pub async fn list_tools_async(&self) -> Result<Vec<Tool>, ProtocolError> {
520 match self
521 .transport
522 .as_ref()
523 .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
524 {
525 ClientTransport::StreamableHttp(transport) => transport.list_tools().await,
526
527 #[cfg(feature = "sse-client")]
528 ClientTransport::Sse { transport } => {
529 let result = send_sse_request(transport, "tools/list", json!({})).await?;
530 let list_result: ListToolsResult = serde_json::from_value(result).map_err(|e| {
531 ProtocolError::Parsing(format!("invalid tools list format: {e}"))
532 })?;
533 Ok(list_result.tools)
534 }
535 }
536 }
537
538 pub async fn call_tool_async(
540 &self,
541 name: &str,
542 arguments: Option<serde_json::Value>,
543 ) -> Result<CallToolResult, ProtocolError> {
544 self.call_tool_with_meta_async(name, arguments, None).await
545 }
546
547 pub async fn call_tool_with_meta_async(
548 &self,
549 name: &str,
550 arguments: Option<serde_json::Value>,
551 meta: Option<serde_json::Value>,
552 ) -> Result<CallToolResult, ProtocolError> {
553 match self
554 .transport
555 .as_ref()
556 .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
557 {
558 ClientTransport::StreamableHttp(transport) => {
559 transport.call_tool(name, arguments, meta).await
560 }
561
562 #[cfg(feature = "sse-client")]
563 ClientTransport::Sse { transport } => {
564 let result = send_sse_request(
565 transport,
566 "tools/call",
567 json!({
568 "name": name,
569 "arguments": arguments,
570 "_meta": meta,
571 }),
572 )
573 .await?;
574
575 serde_json::from_value(result).map_err(|e| {
576 ProtocolError::Parsing(format!("invalid tool call result format: {e}"))
577 })
578 }
579 }
580 }
581
582 pub async fn initialize_async(&self) -> Result<(), ProtocolError> {
584 match self
585 .transport
586 .as_ref()
587 .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
588 {
589 ClientTransport::StreamableHttp(transport) => transport.initialize_if_needed().await,
590
591 #[cfg(feature = "sse-client")]
592 ClientTransport::Sse { .. } => Ok(()),
593 }
594 }
595
596 pub async fn health_check(&self) -> Result<(), ProtocolError> {
598 match self
599 .transport
600 .as_ref()
601 .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
602 {
603 ClientTransport::StreamableHttp(transport) => transport.initialize_if_needed().await,
604
605 #[cfg(feature = "sse-client")]
606 ClientTransport::Sse { transport } => transport.health_check().await.map_err(|e| {
607 ProtocolError::internal_error(&format!("health check failed: {:?}", e))
608 }),
609 }
610 }
611}
612
/// Sends one JSON-RPC request over the legacy SSE transport and returns its `result`.
///
/// The request always uses id `1`.
///
/// # Errors
/// Transport failures, a non-JSON body, a JSON-RPC `error` member (reported like
/// the Streamable HTTP path, as `ProtocolError::Validation`), or a missing `result`.
#[cfg(feature = "sse-client")]
async fn send_sse_request(
    transport: &SseTransport,
    method: &str,
    params: serde_json::Value,
) -> Result<serde_json::Value, ProtocolError> {
    let payload = json!({
        "jsonrpc": "2.0",
        "method": method,
        "params": params,
        "id": 1,
    });
    let request = UniversalRequest {
        method: method.to_string(),
        uri: "/".to_string(),
        headers: HashMap::new(),
        body: payload.to_string().into_bytes(),
        protocol: "MCP".to_string(),
        correlation_id: format!("mcp-client-{}", method.replace('/', "-")),
    };

    let response = transport
        .send(request)
        .await
        .map_err(|e| ProtocolError::internal_error(&format!("transport error: {e:?}")))?;

    let response_json: serde_json::Value = serde_json::from_slice(&response.body)
        .map_err(|e| ProtocolError::Parsing(format!("invalid JSON response: {e}")))?;

    // Surface server-reported errors instead of a generic "missing result" failure.
    if let Some(error_value) = response_json.get("error").filter(|value| !value.is_null()) {
        return Err(
            match serde_json::from_value::<JsonRpcError>(error_value.clone()) {
                Ok(error) => protocol_error_from_jsonrpc(error),
                Err(_) => ProtocolError::Validation(format!("JSON-RPC error: {error_value}")),
            },
        );
    }

    response_json
        .get("result")
        .cloned()
        .ok_or_else(|| ProtocolError::Parsing("missing 'result' field".to_string()))
}
648
/// `ToolProvider` adapter whose methods always fail.
///
/// The messages point callers at the `_async` methods of [`McpClient`].
#[cfg(feature = "sse-client")]
#[async_trait::async_trait]
impl ToolProvider for McpClient {
    /// Always returns an error directing callers to `list_tools_async()`.
    fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
        let hint = "async tool listing not supported in sync context. Use list_tools_async().";
        Err(ProtocolError::internal_error(hint))
    }

    /// Always returns an error naming the tool and directing callers to
    /// `call_tool_async()`.
    async fn call_tool(
        &self,
        name: &str,
        _arguments: Option<serde_json::Value>,
    ) -> Result<CallToolResult, ProtocolError> {
        let hint = format!(
            "async tool calls not supported in sync context. Use call_tool_async() for tool '{name}'."
        );
        Err(ProtocolError::internal_error(&hint))
    }
}
668
669impl Default for McpClient {
670 fn default() -> Self {
671 Self::new()
672 }
673}
674
675pub struct McpClientBuilder {
677 auth_handler: Option<Box<dyn AuthHandler>>,
678 streamable_http_endpoint: Option<String>,
679 streamable_http_auth_token: Option<String>,
680
681 #[cfg(feature = "sse-client")]
682 sse_endpoint: Option<String>,
683
684 #[cfg(feature = "sse-client")]
685 sse_auth_token: Option<String>,
686}
687
688impl McpClientBuilder {
689 pub fn new() -> Self {
691 Self {
692 auth_handler: None,
693 streamable_http_endpoint: None,
694 streamable_http_auth_token: None,
695
696 #[cfg(feature = "sse-client")]
697 sse_endpoint: None,
698
699 #[cfg(feature = "sse-client")]
700 sse_auth_token: None,
701 }
702 }
703
704 pub fn with_auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
706 self.auth_handler = Some(Box::new(handler));
707 self
708 }
709
710 pub fn with_streamable_http_server(mut self, endpoint: &str) -> Self {
712 self.streamable_http_endpoint = Some(endpoint.to_string());
713 self
714 }
715
716 pub fn with_streamable_http_auth_token(mut self, token: &str) -> Self {
718 self.streamable_http_auth_token = Some(token.to_string());
719 self
720 }
721
722 #[cfg(feature = "sse-client")]
724 pub fn with_sse_server(mut self, endpoint: &str) -> Self {
725 self.sse_endpoint = Some(endpoint.to_string());
726 self
727 }
728
729 #[cfg(feature = "sse-client")]
731 pub fn with_auth_token(mut self, token: &str) -> Self {
732 self.sse_auth_token = Some(token.to_string());
733 self
734 }
735
736 pub fn build(self) -> McpClient {
738 let mut client = McpClient::new();
739
740 if let Some(handler) = self.auth_handler {
741 client.auth_handler = Some(handler);
742 }
743
744 if let Some(endpoint) = self.streamable_http_endpoint {
745 client = if let Some(token) = self.streamable_http_auth_token {
746 client.with_streamable_http_server_auth(&endpoint, &token)
747 } else {
748 client.with_streamable_http_server(&endpoint)
749 };
750 }
751
752 #[cfg(feature = "sse-client")]
753 {
754 if let Some(endpoint) = self.sse_endpoint {
755 client = if let Some(token) = self.sse_auth_token {
756 client.with_sse_server_auth(&endpoint, &token)
757 } else {
758 client.with_sse_server(&endpoint)
759 };
760 }
761 }
762
763 client
764 }
765}
766
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use axum::{
        Json, Router,
        body::Bytes,
        extract::State,
        http::{HeaderMap, HeaderValue, StatusCode},
        response::IntoResponse,
        routing::post,
    };
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use tokio::net::TcpListener;

    /// Counters shared between a test and the mock server it talks to.
    #[derive(Clone)]
    struct TestState {
        /// `tools/list` requests that replayed the session id issued by `initialize`.
        session_seen: Arc<AtomicUsize>,
        /// `notifications/initialized` messages received.
        initialized_seen: Arc<AtomicUsize>,
    }

    impl TestState {
        /// Both counters start at zero.
        fn fresh() -> Self {
            Self {
                session_seen: Arc::new(AtomicUsize::new(0)),
                initialized_seen: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    /// Mock Streamable HTTP server.
    ///
    /// `initialize` gets JSON plus an `Mcp-Session-Id` header, `tools/list` gets
    /// JSON, and `tools/call` gets a single-event `text/event-stream` body.
    async fn json_handler(
        State(state): State<TestState>,
        headers: HeaderMap,
        body: Bytes,
    ) -> impl IntoResponse {
        let request: serde_json::Value = serde_json::from_slice(&body).expect("json body");
        let request_id = request["id"].clone();

        match request["method"].as_str().expect("method") {
            "initialize" => {
                let mut reply_headers = HeaderMap::new();
                reply_headers.insert(
                    HEADER_MCP_SESSION_ID,
                    HeaderValue::from_static("session-123"),
                );
                let reply = json!({
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": {
                        "protocolVersion": MCP_PROTOCOL_VERSION,
                        "capabilities": { "tools": { "supported": true } },
                        "serverInfo": { "name": "test-server", "version": "0.1.0" }
                    }
                });
                (reply_headers, Json(reply)).into_response()
            }
            "notifications/initialized" => {
                assert!(
                    request.get("id").is_none(),
                    "initialized notification must not carry an id"
                );
                state.initialized_seen.fetch_add(1, Ordering::SeqCst);
                StatusCode::ACCEPTED.into_response()
            }
            "tools/list" => {
                assert_eq!(
                    state.initialized_seen.load(Ordering::SeqCst),
                    1,
                    "tools/list should only be called after notifications/initialized"
                );
                let replayed_session = headers
                    .get(HEADER_MCP_SESSION_ID)
                    .and_then(|value| value.to_str().ok());
                if replayed_session == Some("session-123") {
                    state.session_seen.fetch_add(1, Ordering::SeqCst);
                }
                let reply = json!({
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": {
                        "tools": [{
                            "name": "search_agents",
                            "description": "Search directory",
                            "inputSchema": { "type": "object", "properties": {} }
                        }]
                    }
                });
                Json(reply).into_response()
            }
            "tools/call" => {
                let envelope = json!({
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": {
                        "content": [{ "type": "text", "text": "{\"ok\":true}" }],
                        "isError": false
                    }
                });
                let event = format!("event: message\ndata: {envelope}\n\n");
                ([(HEADER_CONTENT_TYPE, CONTENT_TYPE_EVENT_STREAM)], event).into_response()
            }
            _ => StatusCode::NOT_FOUND.into_response(),
        }
    }

    /// Answers every request with a JSON-RPC `-32602` error.
    async fn error_handler(body: Bytes) -> impl IntoResponse {
        let request: serde_json::Value = serde_json::from_slice(&body).expect("json body");
        let failure = json!({
            "jsonrpc": "2.0",
            "id": request["id"].clone(),
            "error": { "code": -32602, "message": "bad input" }
        });
        Json(failure)
    }

    /// Mimics a legacy SSE server that only announces a message endpoint.
    async fn legacy_sse_handler(_body: Bytes) -> impl IntoResponse {
        let endpoint_event = "event: endpoint\ndata: /messages?session=abc\n\n".to_string();
        ([(HEADER_CONTENT_TYPE, CONTENT_TYPE_EVENT_STREAM)], endpoint_event)
    }

    /// Serves `app` on an ephemeral localhost port and returns its `/mcp` URL.
    async fn spawn_server(app: Router) -> (String, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let address = listener.local_addr().expect("local addr");
        let server = tokio::spawn(async move {
            axum::serve(listener, app).await.expect("server");
        });
        (format!("http://{address}/mcp"), server)
    }

    #[tokio::test]
    async fn streamable_http_initializes_and_replays_session_header() {
        let state = TestState::fresh();
        let session_seen = Arc::clone(&state.session_seen);
        let initialized_seen = Arc::clone(&state.initialized_seen);
        let app = Router::new()
            .route("/mcp", post(json_handler))
            .with_state(state);
        let (url, server) = spawn_server(app).await;

        let client = McpClient::new().with_streamable_http_server(&url);
        let tools = client.list_tools_async().await.expect("list tools");

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search_agents");
        assert_eq!(session_seen.load(Ordering::SeqCst), 1);
        assert_eq!(initialized_seen.load(Ordering::SeqCst), 1);

        server.abort();
    }

    #[tokio::test]
    async fn streamable_http_parses_event_stream_tool_results() {
        let app = Router::new()
            .route("/mcp", post(json_handler))
            .with_state(TestState::fresh());
        let (url, server) = spawn_server(app).await;

        let client = McpClient::new().with_streamable_http_server(&url);
        let outcome = client
            .call_tool_async("search_agents", Some(json!({"q": "planner"})))
            .await
            .expect("tool call");

        assert_eq!(outcome.is_error, Some(false));
        assert_eq!(outcome.content.len(), 1);

        server.abort();
    }

    #[tokio::test]
    async fn streamable_http_surfaces_jsonrpc_errors() {
        let (url, server) = spawn_server(Router::new().route("/mcp", post(error_handler))).await;

        let client = McpClient::new().with_streamable_http_server(&url);
        let err = client.list_tools_async().await.expect_err("should fail");

        let message = err.to_string();
        assert!(message.contains("JSON-RPC error -32602"));

        server.abort();
    }

    #[tokio::test]
    async fn streamable_http_rejects_legacy_sse_only_responses() {
        let (url, server) =
            spawn_server(Router::new().route("/mcp", post(legacy_sse_handler))).await;

        let client = McpClient::new().with_streamable_http_server(&url);
        let err = client.list_tools_async().await.expect_err("should fail");

        let message = err.to_string();
        assert!(message.contains("legacy SSE-only endpoints are unsupported"));

        server.abort();
    }
}