Skip to main content

mcp_protocol/
client.rs

1//! # MCP Client
2//!
//! MCP client for connecting to MCP servers over Streamable HTTP or, when the
//! `sse-client` feature is enabled, the legacy direct SSE transport.
5
6use crate::{
7    AuthHandler, CallToolResult, ClientCapabilities, ClientInfo, InitializeRequest, JsonRpcError,
8    JsonRpcRequest, JsonRpcResponse, ListToolsResult, MCP_PROTOCOL_VERSION, Tool, ToolCapabilities,
9};
10use protocol_transport_core::{ProtocolError, TransportError};
11use serde_json::json;
12use std::collections::HashMap;
13use std::sync::Mutex;
14
15#[cfg(feature = "sse-client")]
16use crate::ToolProvider;
17#[cfg(feature = "sse-client")]
18use protocol_transport_core::{SseTransport, Transport, TransportFactory, UniversalRequest};
19
// Media types: request bodies are JSON; responses may be JSON or an SSE event stream.
/// Media type of JSON request and response bodies.
const CONTENT_TYPE_JSON: &'static str = "application/json";
/// Media type of a server-sent-events response body.
const CONTENT_TYPE_EVENT_STREAM: &'static str = "text/event-stream";

// HTTP header names. Response-header lookups compare these case-insensitively.
/// Header advertising the response media types the client accepts.
const HEADER_ACCEPT: &'static str = "Accept";
/// Header carrying the bearer token.
const HEADER_AUTHORIZATION: &'static str = "Authorization";
/// Header naming the media type of a body.
const HEADER_CONTENT_TYPE: &'static str = "Content-Type";
/// Header carrying the server-assigned MCP session id.
const HEADER_MCP_SESSION_ID: &'static str = "Mcp-Session-Id";
26
27enum ClientTransport {
28    StreamableHttp(StreamableHttpClientTransport),
29
30    #[cfg(feature = "sse-client")]
31    Sse {
32        transport: SseTransport,
33    },
34}
35
36struct StreamableHttpClientTransport {
37    endpoint: String,
38    auth_token: Option<String>,
39    extra_headers: HashMap<String, String>,
40    client_info: ClientInfo,
41    initialized: Mutex<bool>,
42    protocol_version: Mutex<Option<String>>,
43    session_id: Mutex<Option<String>>,
44    next_id: Mutex<u64>,
45}
46
47impl StreamableHttpClientTransport {
48    fn new(endpoint: impl Into<String>) -> Self {
49        Self {
50            endpoint: endpoint.into(),
51            auth_token: None,
52            extra_headers: HashMap::new(),
53            client_info: ClientInfo {
54                name: "promptfleet-mcp-client".to_string(),
55                version: env!("CARGO_PKG_VERSION").to_string(),
56                description: Some("PromptFleet Streamable HTTP MCP client".to_string()),
57            },
58            initialized: Mutex::new(false),
59            protocol_version: Mutex::new(None),
60            session_id: Mutex::new(None),
61            next_id: Mutex::new(0),
62        }
63    }
64
65    fn with_auth_token(mut self, token: impl Into<String>) -> Self {
66        self.auth_token = Some(token.into());
67        self
68    }
69
70    fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
71        self.extra_headers = headers;
72        self
73    }
74
75    fn with_client_info(mut self, client_info: ClientInfo) -> Self {
76        self.client_info = client_info;
77        self
78    }
79
80    async fn initialize_if_needed(&self) -> Result<(), ProtocolError> {
81        let already_initialized = *self
82            .initialized
83            .lock()
84            .map_err(|_| ProtocolError::internal_error("streamable client init mutex poisoned"))?;
85        if already_initialized {
86            return Ok(());
87        }
88
89        let init_request = InitializeRequest {
90            protocol_version: MCP_PROTOCOL_VERSION.to_string(),
91            capabilities: ClientCapabilities {
92                tools: Some(ToolCapabilities { supported: true }),
93            },
94            client_info: self.client_info.clone(),
95        };
96
97        let result = self
98            .send_jsonrpc_raw(
99                "initialize",
100                Some(
101                    serde_json::to_value(init_request)
102                        .map_err(|e| ProtocolError::Parsing(format!("init serialize: {e}")))?,
103                ),
104            )
105            .await?;
106
107        let negotiated_protocol_version = result
108            .get("protocolVersion")
109            .or_else(|| result.get("protocol_version"))
110            .and_then(|value| value.as_str())
111            .map(ToString::to_string);
112        if negotiated_protocol_version.is_none() {
113            return Err(ProtocolError::Parsing(
114                "invalid initialize result: missing protocolVersion".to_string(),
115            ));
116        }
117        *self.protocol_version.lock().map_err(|_| {
118            ProtocolError::internal_error("streamable client protocol-version mutex poisoned")
119        })? = negotiated_protocol_version;
120
121        self.send_notification_raw("notifications/initialized", None)
122            .await?;
123
124        let mut initialized = self
125            .initialized
126            .lock()
127            .map_err(|_| ProtocolError::internal_error("streamable client init mutex poisoned"))?;
128        *initialized = true;
129        Ok(())
130    }
131
132    async fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
133        let result = self
134            .send_jsonrpc("tools/list", Some(json!({})), true)
135            .await?;
136        let list_result: ListToolsResult = serde_json::from_value(result)
137            .map_err(|e| ProtocolError::Parsing(format!("invalid tools list format: {e}")))?;
138        Ok(list_result.tools)
139    }
140
141    async fn call_tool(
142        &self,
143        name: &str,
144        arguments: Option<serde_json::Value>,
145        meta: Option<serde_json::Value>,
146    ) -> Result<CallToolResult, ProtocolError> {
147        let result = self
148            .send_jsonrpc(
149                "tools/call",
150                Some(json!({
151                    "name": name,
152                    "arguments": arguments,
153                    "_meta": meta,
154                })),
155                true,
156            )
157            .await?;
158
159        serde_json::from_value(result)
160            .map_err(|e| ProtocolError::Parsing(format!("invalid tool call result format: {e}")))
161    }
162
163    async fn send_jsonrpc(
164        &self,
165        method: &str,
166        params: Option<serde_json::Value>,
167        require_initialized: bool,
168    ) -> Result<serde_json::Value, ProtocolError> {
169        if require_initialized {
170            self.initialize_if_needed().await?;
171        }
172        self.send_jsonrpc_raw(method, params).await
173    }
174
175    async fn send_jsonrpc_raw(
176        &self,
177        method: &str,
178        params: Option<serde_json::Value>,
179    ) -> Result<serde_json::Value, ProtocolError> {
180        let id = {
181            let mut next_id = self.next_id.lock().map_err(|_| {
182                ProtocolError::internal_error("streamable client request-id mutex poisoned")
183            })?;
184            *next_id += 1;
185            *next_id
186        };
187
188        let request = JsonRpcRequest {
189            jsonrpc: "2.0".to_string(),
190            id: Some(json!(id)),
191            method: method.to_string(),
192            params,
193        };
194
195        let body = serde_json::to_vec(&request)?;
196        let response = self.http_post(body).await?;
197
198        if let Some(session_id) = find_header(&response.headers, HEADER_MCP_SESSION_ID) {
199            let mut stored = self.session_id.lock().map_err(|_| {
200                ProtocolError::internal_error("streamable client session mutex poisoned")
201            })?;
202            *stored = Some(session_id.to_string());
203        }
204
205        let content_type = find_header(&response.headers, HEADER_CONTENT_TYPE)
206            .map(|value| value.to_ascii_lowercase())
207            .unwrap_or_else(|| CONTENT_TYPE_JSON.to_string());
208
209        let rpc_response = if content_type.contains(CONTENT_TYPE_JSON) {
210            serde_json::from_slice::<JsonRpcResponse>(&response.body).map_err(|e| {
211                ProtocolError::Parsing(format!("invalid JSON-RPC response body: {e}"))
212            })?
213        } else if content_type.contains(CONTENT_TYPE_EVENT_STREAM) {
214            parse_sse_jsonrpc_response(&response.body)?
215        } else {
216            return Err(ProtocolError::Parsing(format!(
217                "unsupported response content-type '{content_type}'"
218            )));
219        };
220
221        if let Some(error) = rpc_response.error {
222            return Err(protocol_error_from_jsonrpc(error));
223        }
224
225        rpc_response
226            .result
227            .ok_or_else(|| ProtocolError::Parsing("missing JSON-RPC result field".to_string()))
228    }
229
230    async fn send_notification_raw(
231        &self,
232        method: &str,
233        params: Option<serde_json::Value>,
234    ) -> Result<(), ProtocolError> {
235        let request = JsonRpcRequest {
236            jsonrpc: "2.0".to_string(),
237            id: None,
238            method: method.to_string(),
239            params,
240        };
241
242        let body = serde_json::to_vec(&request)
243            .map_err(|e| ProtocolError::Parsing(format!("request serialize: {e}")))?;
244        let _ = self.http_post(body).await?;
245        Ok(())
246    }
247
248    async fn http_post(&self, body: Vec<u8>) -> Result<HttpResponse, ProtocolError> {
249        let mut headers = HashMap::new();
250        for (key, value) in &self.extra_headers {
251            if !key.eq_ignore_ascii_case(HEADER_MCP_SESSION_ID) {
252                headers.insert(key.clone(), value.clone());
253            }
254        }
255        headers.insert(
256            HEADER_ACCEPT.to_string(),
257            format!("{CONTENT_TYPE_JSON}, {CONTENT_TYPE_EVENT_STREAM}"),
258        );
259        headers.insert(
260            HEADER_CONTENT_TYPE.to_string(),
261            CONTENT_TYPE_JSON.to_string(),
262        );
263        if let Some(protocol_version) = self
264            .protocol_version
265            .lock()
266            .map_err(|_| {
267                ProtocolError::internal_error("streamable client protocol-version mutex poisoned")
268            })?
269            .clone()
270        {
271            headers.insert("MCP-Protocol-Version".to_string(), protocol_version);
272        }
273
274        if let Some(token) = &self.auth_token {
275            headers.insert(HEADER_AUTHORIZATION.to_string(), format!("Bearer {token}"));
276        }
277
278        if let Some(session_id) = self
279            .session_id
280            .lock()
281            .map_err(|_| ProtocolError::internal_error("streamable client session mutex poisoned"))?
282            .clone()
283        {
284            headers.insert(HEADER_MCP_SESSION_ID.to_string(), session_id);
285        }
286
287        #[cfg(target_arch = "wasm32")]
288        {
289            use spin_sdk::http::{Method, Request as SpinRequest, Response as SpinResponse, send};
290
291            let mut builder = SpinRequest::builder();
292            builder.method(Method::Post);
293            builder.uri(&self.endpoint);
294            for (key, value) in &headers {
295                builder.header(key, value);
296            }
297            let request = builder.body(body).build();
298            let response: SpinResponse = send(request).await.map_err(|e| {
299                ProtocolError::Transport(TransportError::Network(format!(
300                    "Spin HTTP send failed: {e}"
301                )))
302            })?;
303
304            let response_headers = response
305                .headers()
306                .filter_map(|(name, value)| {
307                    value
308                        .as_str()
309                        .map(|value| (name.to_string(), value.to_string()))
310                })
311                .collect::<HashMap<_, _>>();
312            let status = *response.status();
313            let body = response.body().to_vec();
314
315            if !(200..300).contains(&status) {
316                return Err(ProtocolError::Transport(TransportError::Http {
317                    status,
318                    message: format!("streamable HTTP request failed with status {status}"),
319                    body: Some(body),
320                    headers: Some(response_headers),
321                }));
322            }
323
324            Ok(HttpResponse {
325                headers: response_headers,
326                body,
327            })
328        }
329
330        #[cfg(not(target_arch = "wasm32"))]
331        {
332            let client = reqwest::Client::builder()
333                .use_rustls_tls()
334                .build()
335                .map_err(|e| {
336                    ProtocolError::Transport(TransportError::Network(format!(
337                        "streamable HTTP client build failed: {e}; debug={e:?}"
338                    )))
339                })?;
340            let mut request = client.post(&self.endpoint);
341            for (key, value) in &headers {
342                request = request.header(key, value);
343            }
344            let response = request.body(body).send().await.map_err(|e| {
345                ProtocolError::Transport(TransportError::Network(format!(
346                    "streamable HTTP request failed: {e}; debug={e:?}"
347                )))
348            })?;
349
350            let status = response.status().as_u16();
351            let response_headers = response
352                .headers()
353                .iter()
354                .filter_map(|(name, value)| {
355                    value
356                        .to_str()
357                        .ok()
358                        .map(|value| (name.to_string(), value.to_string()))
359                })
360                .collect::<HashMap<_, _>>();
361            let body = response.bytes().await.map_err(|e| {
362                ProtocolError::Transport(TransportError::Network(format!(
363                    "streamable HTTP response read failed: {e}"
364                )))
365            })?;
366            let body = body.to_vec();
367
368            if !(200..300).contains(&status) {
369                return Err(ProtocolError::Transport(TransportError::Http {
370                    status,
371                    message: format!("streamable HTTP request failed with status {status}"),
372                    body: Some(body),
373                    headers: Some(response_headers),
374                }));
375            }
376
377            Ok(HttpResponse {
378                headers: response_headers,
379                body,
380            })
381        }
382    }
383}
384
/// A successful (2xx) HTTP response as produced by the transport's POST helper.
///
/// Header values that are not valid strings were dropped when the map was built.
struct HttpResponse {
    /// Raw response body bytes.
    body: Vec<u8>,
    /// Response headers keyed by the name as received. Look them up with `find_header`.
    headers: HashMap<String, String>,
}
389
/// Looks up a header value by name, ignoring ASCII case as HTTP field names require.
///
/// Returns the value of the first matching entry in the map's iteration order, or `None`.
fn find_header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
    headers.iter().find_map(|(key, value)| {
        if key.eq_ignore_ascii_case(name) {
            Some(value.as_str())
        } else {
            None
        }
    })
}
396
397fn protocol_error_from_jsonrpc(error: JsonRpcError) -> ProtocolError {
398    let details = error
399        .data
400        .map(|value| format!(" data={value}"))
401        .unwrap_or_default();
402    ProtocolError::Validation(format!(
403        "JSON-RPC error {}: {}{}",
404        error.code, error.message, details
405    ))
406}
407
408fn parse_sse_jsonrpc_response(body: &[u8]) -> Result<JsonRpcResponse, ProtocolError> {
409    let text = std::str::from_utf8(body)
410        .map_err(|e| ProtocolError::Parsing(format!("invalid UTF-8 event-stream body: {e}")))?;
411    let mut data_lines = Vec::new();
412
413    for line in text.lines() {
414        if let Some(rest) = line.strip_prefix("data:") {
415            data_lines.push(rest.trim_start().to_string());
416            continue;
417        }
418
419        if line.trim().is_empty() && !data_lines.is_empty() {
420            let payload = data_lines.join("\n");
421            if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&payload) {
422                return Ok(response);
423            }
424            data_lines.clear();
425        }
426    }
427
428    if !data_lines.is_empty() {
429        let payload = data_lines.join("\n");
430        if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&payload) {
431            return Ok(response);
432        }
433    }
434
435    Err(ProtocolError::Parsing(format!(
436        "event-stream response did not contain an MCP JSON-RPC payload; legacy SSE-only endpoints are unsupported; body={text:?}"
437    )))
438}
439
440/// **MCP Client** - Connect to MCP servers
441pub struct McpClient {
442    auth_handler: Option<Box<dyn AuthHandler>>,
443    transport: Option<ClientTransport>,
444}
445
446impl McpClient {
447    /// Create new MCP client
448    pub fn new() -> Self {
449        Self {
450            auth_handler: None,
451            transport: None,
452        }
453    }
454
455    /// Configure authentication
456    pub fn with_auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
457        self.auth_handler = Some(Box::new(handler));
458        self
459    }
460
461    /// Connect to MCP server via Streamable HTTP.
462    pub fn with_streamable_http_server(mut self, endpoint: &str) -> Self {
463        self.transport = Some(ClientTransport::StreamableHttp(
464            StreamableHttpClientTransport::new(endpoint),
465        ));
466        self
467    }
468
469    /// Connect to MCP server via Streamable HTTP with bearer authentication.
470    pub fn with_streamable_http_server_auth(mut self, endpoint: &str, auth_token: &str) -> Self {
471        self.transport = Some(ClientTransport::StreamableHttp(
472            StreamableHttpClientTransport::new(endpoint).with_auth_token(auth_token),
473        ));
474        self
475    }
476
477    /// Add extra HTTP headers to the Streamable HTTP transport.
478    pub fn with_streamable_http_headers(mut self, headers: HashMap<String, String>) -> Self {
479        let transport = match self.transport.take() {
480            Some(ClientTransport::StreamableHttp(transport)) => {
481                ClientTransport::StreamableHttp(transport.with_headers(headers))
482            }
483            other => {
484                self.transport = other;
485                return self;
486            }
487        };
488        self.transport = Some(transport);
489        self
490    }
491
492    /// Override the Streamable HTTP client identity used during initialize.
493    pub fn with_streamable_http_client_info(mut self, client_info: ClientInfo) -> Self {
494        let transport = match self.transport.take() {
495            Some(ClientTransport::StreamableHttp(transport)) => {
496                ClientTransport::StreamableHttp(transport.with_client_info(client_info))
497            }
498            other => {
499                self.transport = other;
500                return self;
501            }
502        };
503        self.transport = Some(transport);
504        self
505    }
506
507    /// Connect to MCP server via SSE (feature: "sse-client")
508    #[cfg(feature = "sse-client")]
509    pub fn with_sse_server(mut self, endpoint: &str) -> Self {
510        self.transport = Some(ClientTransport::Sse {
511            transport: TransportFactory::mcp_sse(endpoint),
512        });
513        self
514    }
515
516    /// Connect to MCP server via SSE with authentication (feature: "sse-client")
517    #[cfg(feature = "sse-client")]
518    pub fn with_sse_server_auth(mut self, endpoint: &str, auth_token: &str) -> Self {
519        self.transport = Some(ClientTransport::Sse {
520            transport: TransportFactory::mcp_sse_auth(endpoint, auth_token),
521        });
522        self
523    }
524
525    /// List tools from server.
526    pub async fn list_tools_async(&self) -> Result<Vec<Tool>, ProtocolError> {
527        match self
528            .transport
529            .as_ref()
530            .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
531        {
532            ClientTransport::StreamableHttp(transport) => transport.list_tools().await,
533
534            #[cfg(feature = "sse-client")]
535            ClientTransport::Sse { transport } => {
536                let result = send_sse_request(transport, "tools/list", json!({})).await?;
537                let list_result: ListToolsResult = serde_json::from_value(result).map_err(|e| {
538                    ProtocolError::Parsing(format!("invalid tools list format: {e}"))
539                })?;
540                Ok(list_result.tools)
541            }
542        }
543    }
544
545    /// Call tool on server.
546    pub async fn call_tool_async(
547        &self,
548        name: &str,
549        arguments: Option<serde_json::Value>,
550    ) -> Result<CallToolResult, ProtocolError> {
551        self.call_tool_with_meta_async(name, arguments, None).await
552    }
553
554    pub async fn call_tool_with_meta_async(
555        &self,
556        name: &str,
557        arguments: Option<serde_json::Value>,
558        meta: Option<serde_json::Value>,
559    ) -> Result<CallToolResult, ProtocolError> {
560        match self
561            .transport
562            .as_ref()
563            .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
564        {
565            ClientTransport::StreamableHttp(transport) => {
566                transport.call_tool(name, arguments, meta).await
567            }
568
569            #[cfg(feature = "sse-client")]
570            ClientTransport::Sse { transport } => {
571                let result = send_sse_request(
572                    transport,
573                    "tools/call",
574                    json!({
575                        "name": name,
576                        "arguments": arguments,
577                        "_meta": meta,
578                    }),
579                )
580                .await?;
581
582                serde_json::from_value(result).map_err(|e| {
583                    ProtocolError::Parsing(format!("invalid tool call result format: {e}"))
584                })
585            }
586        }
587    }
588
589    /// Initialize the current Streamable HTTP transport explicitly.
590    pub async fn initialize_async(&self) -> Result<(), ProtocolError> {
591        match self
592            .transport
593            .as_ref()
594            .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
595        {
596            ClientTransport::StreamableHttp(transport) => transport.initialize_if_needed().await,
597
598            #[cfg(feature = "sse-client")]
599            ClientTransport::Sse { .. } => Ok(()),
600        }
601    }
602
603    /// Check server health.
604    pub async fn health_check(&self) -> Result<(), ProtocolError> {
605        match self
606            .transport
607            .as_ref()
608            .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
609        {
610            ClientTransport::StreamableHttp(transport) => transport.initialize_if_needed().await,
611
612            #[cfg(feature = "sse-client")]
613            ClientTransport::Sse { transport } => transport.health_check().await.map_err(|e| {
614                ProtocolError::internal_error(&format!("health check failed: {:?}", e))
615            }),
616        }
617    }
618}
619
/// Sends one JSON-RPC request (fixed id `1`) over the legacy SSE transport and returns its
/// `result`.
///
/// # Errors
/// - Transport failures are reported as internal errors.
/// - Non-JSON bodies and a missing `result` are reported as `ProtocolError::Parsing`.
/// - A JSON-RPC `error` reply is reported as `ProtocolError::Validation`, the same way the
///   Streamable HTTP path reports it.
#[cfg(feature = "sse-client")]
async fn send_sse_request(
    transport: &SseTransport,
    method: &str,
    params: serde_json::Value,
) -> Result<serde_json::Value, ProtocolError> {
    let request = UniversalRequest {
        method: method.to_string(),
        uri: "/".to_string(),
        headers: HashMap::new(),
        body: json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
            "id": 1,
        })
        .to_string()
        .into_bytes(),
        protocol: "MCP".to_string(),
        correlation_id: format!("mcp-client-{}", method.replace('/', "-")),
    };

    let response = transport
        .send(request)
        .await
        .map_err(|e| ProtocolError::internal_error(&format!("transport error: {e:?}")))?;

    let response_json: serde_json::Value = serde_json::from_slice(&response.body)
        .map_err(|e| ProtocolError::Parsing(format!("invalid JSON response: {e}")))?;

    // Previously an error reply surfaced only as "missing 'result' field", dropping the
    // server's code and message.
    if let Some(error) = response_json.get("error").filter(|error| !error.is_null()) {
        return Err(match serde_json::from_value::<JsonRpcError>(error.clone()) {
            Ok(rpc_error) => protocol_error_from_jsonrpc(rpc_error),
            Err(_) => ProtocolError::Validation(format!("JSON-RPC error: {error}")),
        });
    }

    response_json
        .get("result")
        .cloned()
        .ok_or_else(|| ProtocolError::Parsing("missing 'result' field".to_string()))
}
655
/// `ToolProvider` adapter for `McpClient`.
///
/// Both methods always return an internal error that points callers at the async methods
/// (`list_tools_async` / `call_tool_async`).
#[cfg(feature = "sse-client")]
#[async_trait::async_trait]
impl ToolProvider for McpClient {
    fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
        const HINT: &str =
            "async tool listing not supported in sync context. Use list_tools_async().";
        Err(ProtocolError::internal_error(HINT))
    }

    async fn call_tool(
        &self,
        name: &str,
        _arguments: Option<serde_json::Value>,
    ) -> Result<CallToolResult, ProtocolError> {
        let hint = format!(
            "async tool calls not supported in sync context. Use call_tool_async() for tool '{name}'.",
        );
        Err(ProtocolError::internal_error(&hint))
    }
}
675
676impl Default for McpClient {
677    fn default() -> Self {
678        Self::new()
679    }
680}
681
682/// **MCP Client Builder** - Convenient client configuration
683pub struct McpClientBuilder {
684    auth_handler: Option<Box<dyn AuthHandler>>,
685    streamable_http_endpoint: Option<String>,
686    streamable_http_auth_token: Option<String>,
687
688    #[cfg(feature = "sse-client")]
689    sse_endpoint: Option<String>,
690
691    #[cfg(feature = "sse-client")]
692    sse_auth_token: Option<String>,
693}
694
695impl McpClientBuilder {
696    /// Create new client builder
697    pub fn new() -> Self {
698        Self {
699            auth_handler: None,
700            streamable_http_endpoint: None,
701            streamable_http_auth_token: None,
702
703            #[cfg(feature = "sse-client")]
704            sse_endpoint: None,
705
706            #[cfg(feature = "sse-client")]
707            sse_auth_token: None,
708        }
709    }
710
711    /// Set authentication handler
712    pub fn with_auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
713        self.auth_handler = Some(Box::new(handler));
714        self
715    }
716
717    /// Configure Streamable HTTP transport.
718    pub fn with_streamable_http_server(mut self, endpoint: &str) -> Self {
719        self.streamable_http_endpoint = Some(endpoint.to_string());
720        self
721    }
722
723    /// Set bearer token for Streamable HTTP transport.
724    pub fn with_streamable_http_auth_token(mut self, token: &str) -> Self {
725        self.streamable_http_auth_token = Some(token.to_string());
726        self
727    }
728
729    /// Connect to MCP server via SSE (feature: "sse-client")
730    #[cfg(feature = "sse-client")]
731    pub fn with_sse_server(mut self, endpoint: &str) -> Self {
732        self.sse_endpoint = Some(endpoint.to_string());
733        self
734    }
735
736    /// Set authentication token for SSE server (feature: "sse-client")
737    #[cfg(feature = "sse-client")]
738    pub fn with_auth_token(mut self, token: &str) -> Self {
739        self.sse_auth_token = Some(token.to_string());
740        self
741    }
742
743    /// Build the MCP client
744    pub fn build(self) -> McpClient {
745        let mut client = McpClient::new();
746
747        if let Some(handler) = self.auth_handler {
748            client.auth_handler = Some(handler);
749        }
750
751        if let Some(endpoint) = self.streamable_http_endpoint {
752            client = if let Some(token) = self.streamable_http_auth_token {
753                client.with_streamable_http_server_auth(&endpoint, &token)
754            } else {
755                client.with_streamable_http_server(&endpoint)
756            };
757        }
758
759        #[cfg(feature = "sse-client")]
760        {
761            if let Some(endpoint) = self.sse_endpoint {
762                client = if let Some(token) = self.sse_auth_token {
763                    client.with_sse_server_auth(&endpoint, &token)
764                } else {
765                    client.with_sse_server(&endpoint)
766                };
767            }
768        }
769
770        client
771    }
772}
773
774#[cfg(all(test, not(target_arch = "wasm32")))]
775mod tests {
776    use super::*;
777    use axum::{
778        Json, Router,
779        body::Bytes,
780        extract::State,
781        http::{HeaderMap, HeaderValue, StatusCode},
782        response::IntoResponse,
783        routing::post,
784    };
785    use std::sync::{
786        Arc,
787        atomic::{AtomicUsize, Ordering},
788    };
789    use tokio::net::TcpListener;
790
791    #[derive(Clone)]
792    struct TestState {
793        session_seen: Arc<AtomicUsize>,
794        initialized_seen: Arc<AtomicUsize>,
795    }
796
797    async fn json_handler(
798        State(state): State<TestState>,
799        headers: HeaderMap,
800        body: Bytes,
801    ) -> impl IntoResponse {
802        let request: serde_json::Value = serde_json::from_slice(&body).expect("json body");
803        let method = request["method"].as_str().expect("method");
804        match method {
805            "initialize" => {
806                let mut response_headers = HeaderMap::new();
807                response_headers.insert(
808                    HEADER_MCP_SESSION_ID,
809                    HeaderValue::from_static("session-123"),
810                );
811                (
812                    response_headers,
813                    Json(json!({
814                        "jsonrpc": "2.0",
815                        "id": request["id"].clone(),
816                        "result": {
817                            "protocolVersion": MCP_PROTOCOL_VERSION,
818                            "capabilities": { "tools": { "supported": true } },
819                            "serverInfo": {
820                                "name": "test-server",
821                                "version": "0.1.0"
822                            }
823                        }
824                    })),
825                )
826                    .into_response()
827            }
828            "notifications/initialized" => {
829                assert!(
830                    request.get("id").is_none(),
831                    "initialized notification must not carry an id"
832                );
833                state.initialized_seen.fetch_add(1, Ordering::SeqCst);
834                StatusCode::ACCEPTED.into_response()
835            }
836            "tools/list" => {
837                assert_eq!(
838                    state.initialized_seen.load(Ordering::SeqCst),
839                    1,
840                    "tools/list should only be called after notifications/initialized"
841                );
842                if headers
843                    .get(HEADER_MCP_SESSION_ID)
844                    .and_then(|value| value.to_str().ok())
845                    == Some("session-123")
846                {
847                    state.session_seen.fetch_add(1, Ordering::SeqCst);
848                }
849                Json(json!({
850                    "jsonrpc": "2.0",
851                    "id": request["id"].clone(),
852                    "result": {
853                        "tools": [{
854                            "name": "search_agents",
855                            "description": "Search directory",
856                            "inputSchema": { "type": "object", "properties": {} }
857                        }]
858                    }
859                }))
860                .into_response()
861            }
862            "tools/call" => {
863                let body = format!(
864                    "event: message\ndata: {}\n\n",
865                    json!({
866                        "jsonrpc": "2.0",
867                        "id": request["id"].clone(),
868                        "result": {
869                            "content": [{ "type": "text", "text": "{\"ok\":true}" }],
870                            "isError": false
871                        }
872                    })
873                );
874                ([(HEADER_CONTENT_TYPE, CONTENT_TYPE_EVENT_STREAM)], body).into_response()
875            }
876            _ => StatusCode::NOT_FOUND.into_response(),
877        }
878    }
879
880    async fn error_handler(body: Bytes) -> impl IntoResponse {
881        let request: serde_json::Value = serde_json::from_slice(&body).expect("json body");
882        Json(json!({
883            "jsonrpc": "2.0",
884            "id": request["id"].clone(),
885            "error": {
886                "code": -32602,
887                "message": "bad input"
888            }
889        }))
890    }
891
892    async fn legacy_sse_handler(_body: Bytes) -> impl IntoResponse {
893        (
894            [(HEADER_CONTENT_TYPE, CONTENT_TYPE_EVENT_STREAM)],
895            "event: endpoint\ndata: /messages?session=abc\n\n".to_string(),
896        )
897    }
898
899    async fn start_server(app: Router) -> (String, tokio::task::JoinHandle<()>) {
900        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
901        let addr = listener.local_addr().expect("local addr");
902        let handle = tokio::spawn(async move {
903            axum::serve(listener, app).await.expect("server");
904        });
905        (format!("http://{addr}/mcp"), handle)
906    }
907
908    #[tokio::test]
909    async fn streamable_http_initializes_and_replays_session_header() {
910        let state = TestState {
911            session_seen: Arc::new(AtomicUsize::new(0)),
912            initialized_seen: Arc::new(AtomicUsize::new(0)),
913        };
914        let session_seen = state.session_seen.clone();
915        let initialized_seen = state.initialized_seen.clone();
916        let app = Router::new()
917            .route("/mcp", post(json_handler))
918            .with_state(state);
919        let (url, handle) = start_server(app).await;
920
921        let client = McpClient::new().with_streamable_http_server(&url);
922        let tools = client.list_tools_async().await.expect("list tools");
923
924        assert_eq!(tools.len(), 1);
925        assert_eq!(tools[0].name, "search_agents");
926        assert_eq!(session_seen.load(Ordering::SeqCst), 1);
927        assert_eq!(initialized_seen.load(Ordering::SeqCst), 1);
928
929        handle.abort();
930    }
931
932    #[tokio::test]
933    async fn streamable_http_parses_event_stream_tool_results() {
934        let app = Router::new()
935            .route("/mcp", post(json_handler))
936            .with_state(TestState {
937                session_seen: Arc::new(AtomicUsize::new(0)),
938                initialized_seen: Arc::new(AtomicUsize::new(0)),
939            });
940        let (url, handle) = start_server(app).await;
941
942        let client = McpClient::new().with_streamable_http_server(&url);
943        let result = client
944            .call_tool_async("search_agents", Some(json!({"q": "planner"})))
945            .await
946            .expect("tool call");
947
948        assert_eq!(result.is_error, Some(false));
949        assert_eq!(result.content.len(), 1);
950
951        handle.abort();
952    }
953
954    #[tokio::test]
955    async fn streamable_http_surfaces_jsonrpc_errors() {
956        let app = Router::new().route("/mcp", post(error_handler));
957        let (url, handle) = start_server(app).await;
958
959        let client = McpClient::new().with_streamable_http_server(&url);
960        let error = client.list_tools_async().await.expect_err("should fail");
961
962        assert!(error.to_string().contains("JSON-RPC error -32602"));
963
964        handle.abort();
965    }
966
967    #[tokio::test]
968    async fn streamable_http_rejects_legacy_sse_only_responses() {
969        let app = Router::new().route("/mcp", post(legacy_sse_handler));
970        let (url, handle) = start_server(app).await;
971
972        let client = McpClient::new().with_streamable_http_server(&url);
973        let error = client.list_tools_async().await.expect_err("should fail");
974
975        assert!(
976            error
977                .to_string()
978                .contains("legacy SSE-only endpoints are unsupported")
979        );
980
981        handle.abort();
982    }
983}