1use crate::{
7 AuthHandler, CallToolResult, ClientCapabilities, ClientInfo, InitializeRequest, JsonRpcError,
8 JsonRpcRequest, JsonRpcResponse, ListToolsResult, MCP_PROTOCOL_VERSION, Tool, ToolCapabilities,
9};
10use protocol_transport_core::{ProtocolError, TransportError};
11use serde_json::json;
12use std::collections::HashMap;
13use std::sync::Mutex;
14
15#[cfg(feature = "sse-client")]
16use crate::ToolProvider;
17#[cfg(feature = "sse-client")]
18use protocol_transport_core::{SseTransport, Transport, TransportFactory, UniversalRequest};
19
/// MIME type of request bodies, and of single-message JSON-RPC responses (also assumed when a
/// response carries no `Content-Type` header).
const CONTENT_TYPE_JSON: &str = "application/json";
/// MIME type of responses that deliver the JSON-RPC reply inside a server-sent event stream.
const CONTENT_TYPE_EVENT_STREAM: &str = "text/event-stream";
/// Request header advertising which response content types the client accepts.
const HEADER_ACCEPT: &str = "Accept";
/// Request header carrying the bearer token, when one is configured.
const HEADER_AUTHORIZATION: &str = "Authorization";
/// Header naming the body's MIME type (sent on requests, inspected on responses).
const HEADER_CONTENT_TYPE: &str = "Content-Type";
/// Session header: captured from server responses and replayed on later requests.
const HEADER_MCP_SESSION_ID: &str = "Mcp-Session-Id";
26
/// Wire transport backing an [`McpClient`].
enum ClientTransport {
    /// MCP Streamable HTTP: each JSON-RPC message is sent as an HTTP POST to a single endpoint.
    StreamableHttp(StreamableHttpClientTransport),

    /// Transport produced by `TransportFactory::mcp_sse*`; only compiled with the
    /// `sse-client` feature.
    #[cfg(feature = "sse-client")]
    Sse {
        transport: SseTransport,
    },
}
35
/// Client side of the MCP Streamable HTTP transport, including the lazily performed
/// initialize handshake and the session state it produces.
struct StreamableHttpClientTransport {
    // URL every JSON-RPC message is POSTed to.
    endpoint: String,
    // Optional bearer token, sent as `Authorization: Bearer <token>`.
    auth_token: Option<String>,
    // Caller-supplied headers added to every request (a caller-supplied `Mcp-Session-Id` is
    // never forwarded).
    extra_headers: HashMap<String, String>,
    // Identity sent in the `initialize` request.
    client_info: ClientInfo,
    // Set once `initialize` and `notifications/initialized` have both succeeded.
    initialized: Mutex<bool>,
    // `protocolVersion` returned by the server's initialize result; replayed as a header.
    protocol_version: Mutex<Option<String>>,
    // Most recent `Mcp-Session-Id` seen on a response; replayed as a header.
    session_id: Mutex<Option<String>>,
    // Last JSON-RPC request id handed out; ids start at 1.
    next_id: Mutex<u64>,
}
46
47impl StreamableHttpClientTransport {
48 fn new(endpoint: impl Into<String>) -> Self {
49 Self {
50 endpoint: endpoint.into(),
51 auth_token: None,
52 extra_headers: HashMap::new(),
53 client_info: ClientInfo {
54 name: "promptfleet-mcp-client".to_string(),
55 version: env!("CARGO_PKG_VERSION").to_string(),
56 description: Some("PromptFleet Streamable HTTP MCP client".to_string()),
57 },
58 initialized: Mutex::new(false),
59 protocol_version: Mutex::new(None),
60 session_id: Mutex::new(None),
61 next_id: Mutex::new(0),
62 }
63 }
64
65 fn with_auth_token(mut self, token: impl Into<String>) -> Self {
66 self.auth_token = Some(token.into());
67 self
68 }
69
70 fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
71 self.extra_headers = headers;
72 self
73 }
74
75 fn with_client_info(mut self, client_info: ClientInfo) -> Self {
76 self.client_info = client_info;
77 self
78 }
79
80 async fn initialize_if_needed(&self) -> Result<(), ProtocolError> {
81 let already_initialized = *self
82 .initialized
83 .lock()
84 .map_err(|_| ProtocolError::internal_error("streamable client init mutex poisoned"))?;
85 if already_initialized {
86 return Ok(());
87 }
88
89 let init_request = InitializeRequest {
90 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
91 capabilities: ClientCapabilities {
92 tools: Some(ToolCapabilities { supported: true }),
93 },
94 client_info: self.client_info.clone(),
95 };
96
97 let result = self
98 .send_jsonrpc_raw(
99 "initialize",
100 Some(
101 serde_json::to_value(init_request)
102 .map_err(|e| ProtocolError::Parsing(format!("init serialize: {e}")))?,
103 ),
104 )
105 .await?;
106
107 let negotiated_protocol_version = result
108 .get("protocolVersion")
109 .or_else(|| result.get("protocol_version"))
110 .and_then(|value| value.as_str())
111 .map(ToString::to_string);
112 if negotiated_protocol_version.is_none() {
113 return Err(ProtocolError::Parsing(
114 "invalid initialize result: missing protocolVersion".to_string(),
115 ));
116 }
117 *self.protocol_version.lock().map_err(|_| {
118 ProtocolError::internal_error("streamable client protocol-version mutex poisoned")
119 })? = negotiated_protocol_version;
120
121 self.send_notification_raw("notifications/initialized", None)
122 .await?;
123
124 let mut initialized = self
125 .initialized
126 .lock()
127 .map_err(|_| ProtocolError::internal_error("streamable client init mutex poisoned"))?;
128 *initialized = true;
129 Ok(())
130 }
131
132 async fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
133 let result = self
134 .send_jsonrpc("tools/list", Some(json!({})), true)
135 .await?;
136 let list_result: ListToolsResult = serde_json::from_value(result)
137 .map_err(|e| ProtocolError::Parsing(format!("invalid tools list format: {e}")))?;
138 Ok(list_result.tools)
139 }
140
141 async fn call_tool(
142 &self,
143 name: &str,
144 arguments: Option<serde_json::Value>,
145 meta: Option<serde_json::Value>,
146 ) -> Result<CallToolResult, ProtocolError> {
147 let result = self
148 .send_jsonrpc(
149 "tools/call",
150 Some(json!({
151 "name": name,
152 "arguments": arguments,
153 "_meta": meta,
154 })),
155 true,
156 )
157 .await?;
158
159 serde_json::from_value(result)
160 .map_err(|e| ProtocolError::Parsing(format!("invalid tool call result format: {e}")))
161 }
162
163 async fn send_jsonrpc(
164 &self,
165 method: &str,
166 params: Option<serde_json::Value>,
167 require_initialized: bool,
168 ) -> Result<serde_json::Value, ProtocolError> {
169 if require_initialized {
170 self.initialize_if_needed().await?;
171 }
172 self.send_jsonrpc_raw(method, params).await
173 }
174
175 async fn send_jsonrpc_raw(
176 &self,
177 method: &str,
178 params: Option<serde_json::Value>,
179 ) -> Result<serde_json::Value, ProtocolError> {
180 let id = {
181 let mut next_id = self.next_id.lock().map_err(|_| {
182 ProtocolError::internal_error("streamable client request-id mutex poisoned")
183 })?;
184 *next_id += 1;
185 *next_id
186 };
187
188 let request = JsonRpcRequest {
189 jsonrpc: "2.0".to_string(),
190 id: Some(json!(id)),
191 method: method.to_string(),
192 params,
193 };
194
195 let body = serde_json::to_vec(&request)?;
196 let response = self.http_post(body).await?;
197
198 if let Some(session_id) = find_header(&response.headers, HEADER_MCP_SESSION_ID) {
199 let mut stored = self.session_id.lock().map_err(|_| {
200 ProtocolError::internal_error("streamable client session mutex poisoned")
201 })?;
202 *stored = Some(session_id.to_string());
203 }
204
205 let content_type = find_header(&response.headers, HEADER_CONTENT_TYPE)
206 .map(|value| value.to_ascii_lowercase())
207 .unwrap_or_else(|| CONTENT_TYPE_JSON.to_string());
208
209 let rpc_response = if content_type.contains(CONTENT_TYPE_JSON) {
210 serde_json::from_slice::<JsonRpcResponse>(&response.body).map_err(|e| {
211 ProtocolError::Parsing(format!("invalid JSON-RPC response body: {e}"))
212 })?
213 } else if content_type.contains(CONTENT_TYPE_EVENT_STREAM) {
214 parse_sse_jsonrpc_response(&response.body)?
215 } else {
216 return Err(ProtocolError::Parsing(format!(
217 "unsupported response content-type '{content_type}'"
218 )));
219 };
220
221 if let Some(error) = rpc_response.error {
222 return Err(protocol_error_from_jsonrpc(error));
223 }
224
225 rpc_response
226 .result
227 .ok_or_else(|| ProtocolError::Parsing("missing JSON-RPC result field".to_string()))
228 }
229
230 async fn send_notification_raw(
231 &self,
232 method: &str,
233 params: Option<serde_json::Value>,
234 ) -> Result<(), ProtocolError> {
235 let request = JsonRpcRequest {
236 jsonrpc: "2.0".to_string(),
237 id: None,
238 method: method.to_string(),
239 params,
240 };
241
242 let body = serde_json::to_vec(&request)
243 .map_err(|e| ProtocolError::Parsing(format!("request serialize: {e}")))?;
244 let _ = self.http_post(body).await?;
245 Ok(())
246 }
247
248 async fn http_post(&self, body: Vec<u8>) -> Result<HttpResponse, ProtocolError> {
249 let mut headers = HashMap::new();
250 for (key, value) in &self.extra_headers {
251 if !key.eq_ignore_ascii_case(HEADER_MCP_SESSION_ID) {
252 headers.insert(key.clone(), value.clone());
253 }
254 }
255 headers.insert(
256 HEADER_ACCEPT.to_string(),
257 format!("{CONTENT_TYPE_JSON}, {CONTENT_TYPE_EVENT_STREAM}"),
258 );
259 headers.insert(
260 HEADER_CONTENT_TYPE.to_string(),
261 CONTENT_TYPE_JSON.to_string(),
262 );
263 if let Some(protocol_version) = self
264 .protocol_version
265 .lock()
266 .map_err(|_| {
267 ProtocolError::internal_error("streamable client protocol-version mutex poisoned")
268 })?
269 .clone()
270 {
271 headers.insert("MCP-Protocol-Version".to_string(), protocol_version);
272 }
273
274 if let Some(token) = &self.auth_token {
275 headers.insert(HEADER_AUTHORIZATION.to_string(), format!("Bearer {token}"));
276 }
277
278 if let Some(session_id) = self
279 .session_id
280 .lock()
281 .map_err(|_| ProtocolError::internal_error("streamable client session mutex poisoned"))?
282 .clone()
283 {
284 headers.insert(HEADER_MCP_SESSION_ID.to_string(), session_id);
285 }
286
287 #[cfg(target_arch = "wasm32")]
288 {
289 use spin_sdk::http::{Method, Request as SpinRequest, Response as SpinResponse, send};
290
291 let mut builder = SpinRequest::builder();
292 builder.method(Method::Post);
293 builder.uri(&self.endpoint);
294 for (key, value) in &headers {
295 builder.header(key, value);
296 }
297 let request = builder.body(body).build();
298 let response: SpinResponse = send(request).await.map_err(|e| {
299 ProtocolError::Transport(TransportError::Network(format!(
300 "Spin HTTP send failed: {e}"
301 )))
302 })?;
303
304 let response_headers = response
305 .headers()
306 .filter_map(|(name, value)| {
307 value
308 .as_str()
309 .map(|value| (name.to_string(), value.to_string()))
310 })
311 .collect::<HashMap<_, _>>();
312 let status = *response.status();
313 let body = response.body().to_vec();
314
315 if !(200..300).contains(&status) {
316 return Err(ProtocolError::Transport(TransportError::Http {
317 status,
318 message: format!("streamable HTTP request failed with status {status}"),
319 body: Some(body),
320 headers: Some(response_headers),
321 }));
322 }
323
324 Ok(HttpResponse {
325 headers: response_headers,
326 body,
327 })
328 }
329
330 #[cfg(not(target_arch = "wasm32"))]
331 {
332 let client = reqwest::Client::builder()
333 .use_rustls_tls()
334 .build()
335 .map_err(|e| {
336 ProtocolError::Transport(TransportError::Network(format!(
337 "streamable HTTP client build failed: {e}; debug={e:?}"
338 )))
339 })?;
340 let mut request = client.post(&self.endpoint);
341 for (key, value) in &headers {
342 request = request.header(key, value);
343 }
344 let response = request.body(body).send().await.map_err(|e| {
345 ProtocolError::Transport(TransportError::Network(format!(
346 "streamable HTTP request failed: {e}; debug={e:?}"
347 )))
348 })?;
349
350 let status = response.status().as_u16();
351 let response_headers = response
352 .headers()
353 .iter()
354 .filter_map(|(name, value)| {
355 value
356 .to_str()
357 .ok()
358 .map(|value| (name.to_string(), value.to_string()))
359 })
360 .collect::<HashMap<_, _>>();
361 let body = response.bytes().await.map_err(|e| {
362 ProtocolError::Transport(TransportError::Network(format!(
363 "streamable HTTP response read failed: {e}"
364 )))
365 })?;
366 let body = body.to_vec();
367
368 if !(200..300).contains(&status) {
369 return Err(ProtocolError::Transport(TransportError::Http {
370 status,
371 message: format!("streamable HTTP request failed with status {status}"),
372 body: Some(body),
373 headers: Some(response_headers),
374 }));
375 }
376
377 Ok(HttpResponse {
378 headers: response_headers,
379 body,
380 })
381 }
382 }
383}
384
/// Successful (2xx) HTTP response, reduced to what the JSON-RPC layer needs.
struct HttpResponse {
    // Response headers; values that could not be read as strings were dropped.
    headers: HashMap<String, String>,
    // Raw response body bytes.
    body: Vec<u8>,
}
389
/// Looks up `name` in `headers` ignoring ASCII case, since HTTP header names are
/// case-insensitive. Returns the value of the first match in map iteration order.
fn find_header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
    for (header_name, header_value) in headers {
        if header_name.eq_ignore_ascii_case(name) {
            return Some(header_value.as_str());
        }
    }
    None
}
396
397fn protocol_error_from_jsonrpc(error: JsonRpcError) -> ProtocolError {
398 let details = error
399 .data
400 .map(|value| format!(" data={value}"))
401 .unwrap_or_default();
402 ProtocolError::Validation(format!(
403 "JSON-RPC error {}: {}{}",
404 error.code, error.message, details
405 ))
406}
407
408fn parse_sse_jsonrpc_response(body: &[u8]) -> Result<JsonRpcResponse, ProtocolError> {
409 let text = std::str::from_utf8(body)
410 .map_err(|e| ProtocolError::Parsing(format!("invalid UTF-8 event-stream body: {e}")))?;
411 let mut data_lines = Vec::new();
412
413 for line in text.lines() {
414 if let Some(rest) = line.strip_prefix("data:") {
415 data_lines.push(rest.trim_start().to_string());
416 continue;
417 }
418
419 if line.trim().is_empty() && !data_lines.is_empty() {
420 let payload = data_lines.join("\n");
421 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&payload) {
422 return Ok(response);
423 }
424 data_lines.clear();
425 }
426 }
427
428 if !data_lines.is_empty() {
429 let payload = data_lines.join("\n");
430 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&payload) {
431 return Ok(response);
432 }
433 }
434
435 Err(ProtocolError::Parsing(format!(
436 "event-stream response did not contain an MCP JSON-RPC payload; legacy SSE-only endpoints are unsupported; body={text:?}"
437 )))
438}
439
/// MCP client that lists and calls tools on a remote server over a configured transport.
pub struct McpClient {
    // Optional auth handler installed via `with_auth_handler`.
    // NOTE(review): not consulted by any request path in this file — confirm intended use.
    auth_handler: Option<Box<dyn AuthHandler>>,
    // Transport used for all requests; `None` until a `with_*_server*` method is called.
    transport: Option<ClientTransport>,
}
445
446impl McpClient {
447 pub fn new() -> Self {
449 Self {
450 auth_handler: None,
451 transport: None,
452 }
453 }
454
455 pub fn with_auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
457 self.auth_handler = Some(Box::new(handler));
458 self
459 }
460
461 pub fn with_streamable_http_server(mut self, endpoint: &str) -> Self {
463 self.transport = Some(ClientTransport::StreamableHttp(
464 StreamableHttpClientTransport::new(endpoint),
465 ));
466 self
467 }
468
469 pub fn with_streamable_http_server_auth(mut self, endpoint: &str, auth_token: &str) -> Self {
471 self.transport = Some(ClientTransport::StreamableHttp(
472 StreamableHttpClientTransport::new(endpoint).with_auth_token(auth_token),
473 ));
474 self
475 }
476
477 pub fn with_streamable_http_headers(mut self, headers: HashMap<String, String>) -> Self {
479 let transport = match self.transport.take() {
480 Some(ClientTransport::StreamableHttp(transport)) => {
481 ClientTransport::StreamableHttp(transport.with_headers(headers))
482 }
483 other => {
484 self.transport = other;
485 return self;
486 }
487 };
488 self.transport = Some(transport);
489 self
490 }
491
492 pub fn with_streamable_http_client_info(mut self, client_info: ClientInfo) -> Self {
494 let transport = match self.transport.take() {
495 Some(ClientTransport::StreamableHttp(transport)) => {
496 ClientTransport::StreamableHttp(transport.with_client_info(client_info))
497 }
498 other => {
499 self.transport = other;
500 return self;
501 }
502 };
503 self.transport = Some(transport);
504 self
505 }
506
507 #[cfg(feature = "sse-client")]
509 pub fn with_sse_server(mut self, endpoint: &str) -> Self {
510 self.transport = Some(ClientTransport::Sse {
511 transport: TransportFactory::mcp_sse(endpoint),
512 });
513 self
514 }
515
516 #[cfg(feature = "sse-client")]
518 pub fn with_sse_server_auth(mut self, endpoint: &str, auth_token: &str) -> Self {
519 self.transport = Some(ClientTransport::Sse {
520 transport: TransportFactory::mcp_sse_auth(endpoint, auth_token),
521 });
522 self
523 }
524
525 pub async fn list_tools_async(&self) -> Result<Vec<Tool>, ProtocolError> {
527 match self
528 .transport
529 .as_ref()
530 .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
531 {
532 ClientTransport::StreamableHttp(transport) => transport.list_tools().await,
533
534 #[cfg(feature = "sse-client")]
535 ClientTransport::Sse { transport } => {
536 let result = send_sse_request(transport, "tools/list", json!({})).await?;
537 let list_result: ListToolsResult = serde_json::from_value(result).map_err(|e| {
538 ProtocolError::Parsing(format!("invalid tools list format: {e}"))
539 })?;
540 Ok(list_result.tools)
541 }
542 }
543 }
544
545 pub async fn call_tool_async(
547 &self,
548 name: &str,
549 arguments: Option<serde_json::Value>,
550 ) -> Result<CallToolResult, ProtocolError> {
551 self.call_tool_with_meta_async(name, arguments, None).await
552 }
553
554 pub async fn call_tool_with_meta_async(
555 &self,
556 name: &str,
557 arguments: Option<serde_json::Value>,
558 meta: Option<serde_json::Value>,
559 ) -> Result<CallToolResult, ProtocolError> {
560 match self
561 .transport
562 .as_ref()
563 .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
564 {
565 ClientTransport::StreamableHttp(transport) => {
566 transport.call_tool(name, arguments, meta).await
567 }
568
569 #[cfg(feature = "sse-client")]
570 ClientTransport::Sse { transport } => {
571 let result = send_sse_request(
572 transport,
573 "tools/call",
574 json!({
575 "name": name,
576 "arguments": arguments,
577 "_meta": meta,
578 }),
579 )
580 .await?;
581
582 serde_json::from_value(result).map_err(|e| {
583 ProtocolError::Parsing(format!("invalid tool call result format: {e}"))
584 })
585 }
586 }
587 }
588
589 pub async fn initialize_async(&self) -> Result<(), ProtocolError> {
591 match self
592 .transport
593 .as_ref()
594 .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
595 {
596 ClientTransport::StreamableHttp(transport) => transport.initialize_if_needed().await,
597
598 #[cfg(feature = "sse-client")]
599 ClientTransport::Sse { .. } => Ok(()),
600 }
601 }
602
603 pub async fn health_check(&self) -> Result<(), ProtocolError> {
605 match self
606 .transport
607 .as_ref()
608 .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
609 {
610 ClientTransport::StreamableHttp(transport) => transport.initialize_if_needed().await,
611
612 #[cfg(feature = "sse-client")]
613 ClientTransport::Sse { transport } => transport.health_check().await.map_err(|e| {
614 ProtocolError::internal_error(&format!("health check failed: {:?}", e))
615 }),
616 }
617 }
618}
619
/// Sends one JSON-RPC request (always id `1`) over the SSE transport and returns its `result`.
///
/// # Errors
/// - Transport failures.
/// - A non-JSON response body.
/// - A JSON-RPC `error` member, mapped like the Streamable HTTP path does (falling back to
///   the raw error value if it does not decode as [`JsonRpcError`]).
/// - A missing `result`.
#[cfg(feature = "sse-client")]
async fn send_sse_request(
    transport: &SseTransport,
    method: &str,
    params: serde_json::Value,
) -> Result<serde_json::Value, ProtocolError> {
    let request = UniversalRequest {
        method: method.to_string(),
        uri: "/".to_string(),
        headers: HashMap::new(),
        body: json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
            "id": 1,
        })
        .to_string()
        .into_bytes(),
        protocol: "MCP".to_string(),
        correlation_id: format!("mcp-client-{}", method.replace('/', "-")),
    };

    let response = transport
        .send(request)
        .await
        .map_err(|e| ProtocolError::internal_error(&format!("transport error: {e:?}")))?;

    let response_json: serde_json::Value = serde_json::from_slice(&response.body)
        .map_err(|e| ProtocolError::Parsing(format!("invalid JSON response: {e}")))?;

    // An error reply has `error` instead of `result`; report the server's error rather than
    // a generic "missing 'result' field".
    if let Some(error) = response_json.get("error").filter(|error| !error.is_null()) {
        return Err(match serde_json::from_value::<JsonRpcError>(error.clone()) {
            Ok(error) => protocol_error_from_jsonrpc(error),
            Err(_) => ProtocolError::Validation(format!("JSON-RPC error: {error}")),
        });
    }

    response_json
        .get("result")
        .cloned()
        .ok_or_else(|| ProtocolError::Parsing("missing 'result' field".to_string()))
}
655
/// `ToolProvider` adapter. Both methods always fail and point the caller at the
/// corresponding `*_async` method on [`McpClient`].
///
/// NOTE(review): `call_tool` is itself async here, so delegating to `call_tool_async` may be
/// possible — confirm the trait's `Send` requirements before changing it.
#[cfg(feature = "sse-client")]
#[async_trait::async_trait]
impl ToolProvider for McpClient {
    fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
        let hint = "async tool listing not supported in sync context. Use list_tools_async().";
        Err(ProtocolError::internal_error(hint))
    }

    async fn call_tool(
        &self,
        name: &str,
        _arguments: Option<serde_json::Value>,
    ) -> Result<CallToolResult, ProtocolError> {
        let hint = format!(
            "async tool calls not supported in sync context. Use call_tool_async() for tool '{name}'."
        );
        Err(ProtocolError::internal_error(&hint))
    }
}
675
676impl Default for McpClient {
677 fn default() -> Self {
678 Self::new()
679 }
680}
681
/// Collects client settings and produces an [`McpClient`] via [`McpClientBuilder::build`].
pub struct McpClientBuilder {
    // Auth handler moved into the built client.
    auth_handler: Option<Box<dyn AuthHandler>>,
    // Streamable HTTP endpoint; when `None`, no Streamable HTTP transport is configured.
    streamable_http_endpoint: Option<String>,
    // Bearer token for the Streamable HTTP endpoint (ignored without an endpoint).
    streamable_http_auth_token: Option<String>,

    // SSE endpoint; when set, it replaces any Streamable HTTP transport at build time.
    #[cfg(feature = "sse-client")]
    sse_endpoint: Option<String>,

    // Bearer token for the SSE endpoint (ignored without an endpoint).
    #[cfg(feature = "sse-client")]
    sse_auth_token: Option<String>,
}
694
695impl McpClientBuilder {
696 pub fn new() -> Self {
698 Self {
699 auth_handler: None,
700 streamable_http_endpoint: None,
701 streamable_http_auth_token: None,
702
703 #[cfg(feature = "sse-client")]
704 sse_endpoint: None,
705
706 #[cfg(feature = "sse-client")]
707 sse_auth_token: None,
708 }
709 }
710
711 pub fn with_auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
713 self.auth_handler = Some(Box::new(handler));
714 self
715 }
716
717 pub fn with_streamable_http_server(mut self, endpoint: &str) -> Self {
719 self.streamable_http_endpoint = Some(endpoint.to_string());
720 self
721 }
722
723 pub fn with_streamable_http_auth_token(mut self, token: &str) -> Self {
725 self.streamable_http_auth_token = Some(token.to_string());
726 self
727 }
728
729 #[cfg(feature = "sse-client")]
731 pub fn with_sse_server(mut self, endpoint: &str) -> Self {
732 self.sse_endpoint = Some(endpoint.to_string());
733 self
734 }
735
736 #[cfg(feature = "sse-client")]
738 pub fn with_auth_token(mut self, token: &str) -> Self {
739 self.sse_auth_token = Some(token.to_string());
740 self
741 }
742
743 pub fn build(self) -> McpClient {
745 let mut client = McpClient::new();
746
747 if let Some(handler) = self.auth_handler {
748 client.auth_handler = Some(handler);
749 }
750
751 if let Some(endpoint) = self.streamable_http_endpoint {
752 client = if let Some(token) = self.streamable_http_auth_token {
753 client.with_streamable_http_server_auth(&endpoint, &token)
754 } else {
755 client.with_streamable_http_server(&endpoint)
756 };
757 }
758
759 #[cfg(feature = "sse-client")]
760 {
761 if let Some(endpoint) = self.sse_endpoint {
762 client = if let Some(token) = self.sse_auth_token {
763 client.with_sse_server_auth(&endpoint, &token)
764 } else {
765 client.with_sse_server(&endpoint)
766 };
767 }
768 }
769
770 client
771 }
772}
773
// Tests for the Streamable HTTP transport against in-process axum servers bound to an
// ephemeral localhost port. Native targets only; the wasm32 build uses Spin's HTTP client.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use axum::{
        Json, Router,
        body::Bytes,
        extract::State,
        http::{HeaderMap, HeaderValue, StatusCode},
        response::IntoResponse,
        routing::post,
    };
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use tokio::net::TcpListener;

    // Counters shared between a test and its mock server handler.
    #[derive(Clone)]
    struct TestState {
        // Number of `tools/list` requests that replayed `Mcp-Session-Id: session-123`.
        session_seen: Arc<AtomicUsize>,
        // Number of `notifications/initialized` messages received.
        initialized_seen: Arc<AtomicUsize>,
    }

    // Mock Streamable HTTP server, dispatching on the JSON-RPC `method`:
    // - `initialize`: JSON result plus an `Mcp-Session-Id` header.
    // - `notifications/initialized`: 202 Accepted, asserting the message has no `id`.
    // - `tools/list`: JSON result, asserting the handshake already finished.
    // - `tools/call`: result delivered as a single SSE event.
    // - Anything else: 404.
    async fn json_handler(
        State(state): State<TestState>,
        headers: HeaderMap,
        body: Bytes,
    ) -> impl IntoResponse {
        let request: serde_json::Value = serde_json::from_slice(&body).expect("json body");
        let method = request["method"].as_str().expect("method");
        match method {
            "initialize" => {
                let mut response_headers = HeaderMap::new();
                response_headers.insert(
                    HEADER_MCP_SESSION_ID,
                    HeaderValue::from_static("session-123"),
                );
                (
                    response_headers,
                    Json(json!({
                        "jsonrpc": "2.0",
                        "id": request["id"].clone(),
                        "result": {
                            "protocolVersion": MCP_PROTOCOL_VERSION,
                            "capabilities": { "tools": { "supported": true } },
                            "serverInfo": {
                                "name": "test-server",
                                "version": "0.1.0"
                            }
                        }
                    })),
                )
                    .into_response()
            }
            "notifications/initialized" => {
                assert!(
                    request.get("id").is_none(),
                    "initialized notification must not carry an id"
                );
                state.initialized_seen.fetch_add(1, Ordering::SeqCst);
                StatusCode::ACCEPTED.into_response()
            }
            "tools/list" => {
                assert_eq!(
                    state.initialized_seen.load(Ordering::SeqCst),
                    1,
                    "tools/list should only be called after notifications/initialized"
                );
                // Count only requests that replay the session id issued by `initialize`.
                if headers
                    .get(HEADER_MCP_SESSION_ID)
                    .and_then(|value| value.to_str().ok())
                    == Some("session-123")
                {
                    state.session_seen.fetch_add(1, Ordering::SeqCst);
                }
                Json(json!({
                    "jsonrpc": "2.0",
                    "id": request["id"].clone(),
                    "result": {
                        "tools": [{
                            "name": "search_agents",
                            "description": "Search directory",
                            "inputSchema": { "type": "object", "properties": {} }
                        }]
                    }
                }))
                .into_response()
            }
            "tools/call" => {
                // Reply as an event stream to exercise the SSE response parser.
                let body = format!(
                    "event: message\ndata: {}\n\n",
                    json!({
                        "jsonrpc": "2.0",
                        "id": request["id"].clone(),
                        "result": {
                            "content": [{ "type": "text", "text": "{\"ok\":true}" }],
                            "isError": false
                        }
                    })
                );
                ([(HEADER_CONTENT_TYPE, CONTENT_TYPE_EVENT_STREAM)], body).into_response()
            }
            _ => StatusCode::NOT_FOUND.into_response(),
        }
    }

    // Answers every request (including `initialize`) with a JSON-RPC "invalid params" error.
    async fn error_handler(body: Bytes) -> impl IntoResponse {
        let request: serde_json::Value = serde_json::from_slice(&body).expect("json body");
        Json(json!({
            "jsonrpc": "2.0",
            "id": request["id"].clone(),
            "error": {
                "code": -32602,
                "message": "bad input"
            }
        }))
    }

    // Mimics a legacy SSE-only server: the stream carries an `endpoint` event instead of a
    // JSON-RPC response.
    async fn legacy_sse_handler(_body: Bytes) -> impl IntoResponse {
        (
            [(HEADER_CONTENT_TYPE, CONTENT_TYPE_EVENT_STREAM)],
            "event: endpoint\ndata: /messages?session=abc\n\n".to_string(),
        )
    }

    // Serves `app` on 127.0.0.1 with an OS-assigned port; returns the `/mcp` URL and the
    // server task handle (aborted by each test when done).
    async fn start_server(app: Router) -> (String, tokio::task::JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let addr = listener.local_addr().expect("local addr");
        let handle = tokio::spawn(async move {
            axum::serve(listener, app).await.expect("server");
        });
        (format!("http://{addr}/mcp"), handle)
    }

    // The first request triggers the handshake exactly once, and the session id from
    // `initialize` is replayed on the follow-up `tools/list`.
    #[tokio::test]
    async fn streamable_http_initializes_and_replays_session_header() {
        let state = TestState {
            session_seen: Arc::new(AtomicUsize::new(0)),
            initialized_seen: Arc::new(AtomicUsize::new(0)),
        };
        let session_seen = state.session_seen.clone();
        let initialized_seen = state.initialized_seen.clone();
        let app = Router::new()
            .route("/mcp", post(json_handler))
            .with_state(state);
        let (url, handle) = start_server(app).await;

        let client = McpClient::new().with_streamable_http_server(&url);
        let tools = client.list_tools_async().await.expect("list tools");

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search_agents");
        assert_eq!(session_seen.load(Ordering::SeqCst), 1);
        assert_eq!(initialized_seen.load(Ordering::SeqCst), 1);

        handle.abort();
    }

    // A `tools/call` result delivered as `text/event-stream` decodes into a CallToolResult.
    #[tokio::test]
    async fn streamable_http_parses_event_stream_tool_results() {
        let app = Router::new()
            .route("/mcp", post(json_handler))
            .with_state(TestState {
                session_seen: Arc::new(AtomicUsize::new(0)),
                initialized_seen: Arc::new(AtomicUsize::new(0)),
            });
        let (url, handle) = start_server(app).await;

        let client = McpClient::new().with_streamable_http_server(&url);
        let result = client
            .call_tool_async("search_agents", Some(json!({"q": "planner"})))
            .await
            .expect("tool call");

        assert_eq!(result.is_error, Some(false));
        assert_eq!(result.content.len(), 1);

        handle.abort();
    }

    // A JSON-RPC error object is surfaced with its code in the error text.
    #[tokio::test]
    async fn streamable_http_surfaces_jsonrpc_errors() {
        let app = Router::new().route("/mcp", post(error_handler));
        let (url, handle) = start_server(app).await;

        let client = McpClient::new().with_streamable_http_server(&url);
        let error = client.list_tools_async().await.expect_err("should fail");

        assert!(error.to_string().contains("JSON-RPC error -32602"));

        handle.abort();
    }

    // An event stream without any JSON-RPC payload is rejected with the legacy-SSE message.
    #[tokio::test]
    async fn streamable_http_rejects_legacy_sse_only_responses() {
        let app = Router::new().route("/mcp", post(legacy_sse_handler));
        let (url, handle) = start_server(app).await;

        let client = McpClient::new().with_streamable_http_server(&url);
        let error = client.list_tools_async().await.expect_err("should fail");

        assert!(
            error
                .to_string()
                .contains("legacy SSE-only endpoints are unsupported")
        );

        handle.abort();
    }
}