Skip to main content

mcp_protocol/
client.rs

1//! # MCP Client
2//!
3//! MCP client implementation for connecting to MCP servers over either
4//! Streamable HTTP or the older direct SSE feature path.
5
6use crate::{
7    AuthHandler, CallToolResult, ClientCapabilities, ClientInfo, InitializeRequest, JsonRpcError,
8    JsonRpcRequest, JsonRpcResponse, ListToolsResult, MCP_PROTOCOL_VERSION, Tool, ToolCapabilities,
9};
10use protocol_transport_core::{ProtocolError, TransportError};
11use serde_json::json;
12use std::collections::HashMap;
13use std::sync::Mutex;
14
15#[cfg(feature = "sse-client")]
16use crate::ToolProvider;
17#[cfg(feature = "sse-client")]
18use protocol_transport_core::{SseTransport, Transport, TransportFactory, UniversalRequest};
19
/// Media type of a plain JSON-RPC request or response body.
const CONTENT_TYPE_JSON: &str = "application/json";
/// Media type of a response delivered as a server-sent-event stream.
const CONTENT_TYPE_EVENT_STREAM: &str = "text/event-stream";

/// Request header listing the response media types the client accepts.
const HEADER_ACCEPT: &str = "Accept";
/// Request header carrying the bearer token, when one is configured.
const HEADER_AUTHORIZATION: &str = "Authorization";
/// Header naming the media type of a request or response body.
const HEADER_CONTENT_TYPE: &str = "Content-Type";
/// Header through which the server hands out a session id and the client
/// replays it on later requests.
const HEADER_MCP_SESSION_ID: &str = "Mcp-Session-Id";
26
27enum ClientTransport {
28    StreamableHttp(StreamableHttpClientTransport),
29
30    #[cfg(feature = "sse-client")]
31    Sse {
32        transport: SseTransport,
33    },
34}
35
36struct StreamableHttpClientTransport {
37    endpoint: String,
38    auth_token: Option<String>,
39    extra_headers: HashMap<String, String>,
40    client_info: ClientInfo,
41    initialized: Mutex<bool>,
42    protocol_version: Mutex<Option<String>>,
43    session_id: Mutex<Option<String>>,
44    next_id: Mutex<u64>,
45}
46
47impl StreamableHttpClientTransport {
48    fn new(endpoint: impl Into<String>) -> Self {
49        Self {
50            endpoint: endpoint.into(),
51            auth_token: None,
52            extra_headers: HashMap::new(),
53            client_info: ClientInfo {
54                name: "promptfleet-mcp-client".to_string(),
55                version: env!("CARGO_PKG_VERSION").to_string(),
56                description: Some("PromptFleet Streamable HTTP MCP client".to_string()),
57            },
58            initialized: Mutex::new(false),
59            protocol_version: Mutex::new(None),
60            session_id: Mutex::new(None),
61            next_id: Mutex::new(0),
62        }
63    }
64
65    fn with_auth_token(mut self, token: impl Into<String>) -> Self {
66        self.auth_token = Some(token.into());
67        self
68    }
69
70    fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
71        self.extra_headers = headers;
72        self
73    }
74
75    fn with_client_info(mut self, client_info: ClientInfo) -> Self {
76        self.client_info = client_info;
77        self
78    }
79
80    async fn initialize_if_needed(&self) -> Result<(), ProtocolError> {
81        let already_initialized = *self
82            .initialized
83            .lock()
84            .map_err(|_| ProtocolError::internal_error("streamable client init mutex poisoned"))?;
85        if already_initialized {
86            return Ok(());
87        }
88
89        let init_request = InitializeRequest {
90            protocol_version: MCP_PROTOCOL_VERSION.to_string(),
91            capabilities: ClientCapabilities {
92                tools: Some(ToolCapabilities { supported: true }),
93            },
94            client_info: self.client_info.clone(),
95        };
96
97        let result = self
98            .send_jsonrpc_raw(
99                "initialize",
100                Some(
101                    serde_json::to_value(init_request)
102                        .map_err(|e| ProtocolError::Parsing(format!("init serialize: {e}")))?,
103                ),
104            )
105            .await?;
106
107        let negotiated_protocol_version = result
108            .get("protocolVersion")
109            .or_else(|| result.get("protocol_version"))
110            .and_then(|value| value.as_str())
111            .map(ToString::to_string);
112        if negotiated_protocol_version.is_none() {
113            return Err(ProtocolError::Parsing(
114                "invalid initialize result: missing protocolVersion".to_string(),
115            ));
116        }
117        *self.protocol_version.lock().map_err(|_| {
118            ProtocolError::internal_error("streamable client protocol-version mutex poisoned")
119        })? = negotiated_protocol_version;
120
121        self.send_notification_raw("notifications/initialized", None)
122            .await?;
123
124        let mut initialized = self
125            .initialized
126            .lock()
127            .map_err(|_| ProtocolError::internal_error("streamable client init mutex poisoned"))?;
128        *initialized = true;
129        Ok(())
130    }
131
132    async fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
133        let result = self
134            .send_jsonrpc("tools/list", Some(json!({})), true)
135            .await?;
136        let list_result: ListToolsResult = serde_json::from_value(result)
137            .map_err(|e| ProtocolError::Parsing(format!("invalid tools list format: {e}")))?;
138        Ok(list_result.tools)
139    }
140
141    async fn call_tool(
142        &self,
143        name: &str,
144        arguments: Option<serde_json::Value>,
145        meta: Option<serde_json::Value>,
146    ) -> Result<CallToolResult, ProtocolError> {
147        let result = self
148            .send_jsonrpc(
149                "tools/call",
150                Some(json!({
151                    "name": name,
152                    "arguments": arguments,
153                    "_meta": meta,
154                })),
155                true,
156            )
157            .await?;
158
159        serde_json::from_value(result)
160            .map_err(|e| ProtocolError::Parsing(format!("invalid tool call result format: {e}")))
161    }
162
163    async fn send_jsonrpc(
164        &self,
165        method: &str,
166        params: Option<serde_json::Value>,
167        require_initialized: bool,
168    ) -> Result<serde_json::Value, ProtocolError> {
169        if require_initialized {
170            self.initialize_if_needed().await?;
171        }
172        self.send_jsonrpc_raw(method, params).await
173    }
174
175    async fn send_jsonrpc_raw(
176        &self,
177        method: &str,
178        params: Option<serde_json::Value>,
179    ) -> Result<serde_json::Value, ProtocolError> {
180        let id = {
181            let mut next_id = self.next_id.lock().map_err(|_| {
182                ProtocolError::internal_error("streamable client request-id mutex poisoned")
183            })?;
184            *next_id += 1;
185            *next_id
186        };
187
188        let request = JsonRpcRequest {
189            jsonrpc: "2.0".to_string(),
190            id: Some(json!(id)),
191            method: method.to_string(),
192            params,
193        };
194
195        let body = serde_json::to_vec(&request)?;
196        let response = self.http_post(body).await?;
197
198        if let Some(session_id) = find_header(&response.headers, HEADER_MCP_SESSION_ID) {
199            let mut stored = self.session_id.lock().map_err(|_| {
200                ProtocolError::internal_error("streamable client session mutex poisoned")
201            })?;
202            *stored = Some(session_id.to_string());
203        }
204
205        let content_type = find_header(&response.headers, HEADER_CONTENT_TYPE)
206            .map(|value| value.to_ascii_lowercase())
207            .unwrap_or_else(|| CONTENT_TYPE_JSON.to_string());
208
209        let rpc_response = if content_type.contains(CONTENT_TYPE_JSON) {
210            serde_json::from_slice::<JsonRpcResponse>(&response.body).map_err(|e| {
211                ProtocolError::Parsing(format!("invalid JSON-RPC response body: {e}"))
212            })?
213        } else if content_type.contains(CONTENT_TYPE_EVENT_STREAM) {
214            parse_sse_jsonrpc_response(&response.body)?
215        } else {
216            return Err(ProtocolError::Parsing(format!(
217                "unsupported response content-type '{content_type}'"
218            )));
219        };
220
221        if let Some(error) = rpc_response.error {
222            return Err(protocol_error_from_jsonrpc(error));
223        }
224
225        rpc_response
226            .result
227            .ok_or_else(|| ProtocolError::Parsing("missing JSON-RPC result field".to_string()))
228    }
229
230    async fn send_notification_raw(
231        &self,
232        method: &str,
233        params: Option<serde_json::Value>,
234    ) -> Result<(), ProtocolError> {
235        let request = JsonRpcRequest {
236            jsonrpc: "2.0".to_string(),
237            id: None,
238            method: method.to_string(),
239            params,
240        };
241
242        let body = serde_json::to_vec(&request)
243            .map_err(|e| ProtocolError::Parsing(format!("request serialize: {e}")))?;
244        let _ = self.http_post(body).await?;
245        Ok(())
246    }
247
248    async fn http_post(&self, body: Vec<u8>) -> Result<HttpResponse, ProtocolError> {
249        let mut headers = HashMap::new();
250        for (key, value) in &self.extra_headers {
251            if !key.eq_ignore_ascii_case(HEADER_MCP_SESSION_ID) {
252                headers.insert(key.clone(), value.clone());
253            }
254        }
255        headers.insert(
256            HEADER_ACCEPT.to_string(),
257            format!("{CONTENT_TYPE_JSON}, {CONTENT_TYPE_EVENT_STREAM}"),
258        );
259        headers.insert(
260            HEADER_CONTENT_TYPE.to_string(),
261            CONTENT_TYPE_JSON.to_string(),
262        );
263        if let Some(protocol_version) = self
264            .protocol_version
265            .lock()
266            .map_err(|_| {
267                ProtocolError::internal_error("streamable client protocol-version mutex poisoned")
268            })?
269            .clone()
270        {
271            headers.insert("MCP-Protocol-Version".to_string(), protocol_version);
272        }
273
274        if let Some(token) = &self.auth_token {
275            headers.insert(HEADER_AUTHORIZATION.to_string(), format!("Bearer {token}"));
276        }
277
278        if let Some(session_id) = self
279            .session_id
280            .lock()
281            .map_err(|_| ProtocolError::internal_error("streamable client session mutex poisoned"))?
282            .clone()
283        {
284            headers.insert(HEADER_MCP_SESSION_ID.to_string(), session_id);
285        }
286
287        #[cfg(target_arch = "wasm32")]
288        {
289            use spin_sdk::http::{Method, Request as SpinRequest, Response as SpinResponse, send};
290
291            let mut builder = SpinRequest::builder();
292            builder.method(Method::Post);
293            builder.uri(&self.endpoint);
294            for (key, value) in &headers {
295                builder.header(key, value);
296            }
297            let request = builder.body(body).build();
298            let response: SpinResponse = send(request).await.map_err(|e| {
299                ProtocolError::Transport(TransportError::Network(format!(
300                    "Spin HTTP send failed: {e}"
301                )))
302            })?;
303
304            let response_headers = response
305                .headers()
306                .filter_map(|(name, value)| {
307                    value
308                        .as_str()
309                        .map(|value| (name.to_string(), value.to_string()))
310                })
311                .collect::<HashMap<_, _>>();
312            let status = *response.status();
313            let body = response.body().to_vec();
314
315            if !(200..300).contains(&status) {
316                return Err(ProtocolError::Transport(TransportError::Http {
317                    status,
318                    message: format!("streamable HTTP request failed with status {status}"),
319                    body: Some(body),
320                    headers: Some(response_headers),
321                }));
322            }
323
324            Ok(HttpResponse {
325                headers: response_headers,
326                body,
327            })
328        }
329
330        #[cfg(not(target_arch = "wasm32"))]
331        {
332            let client = reqwest::Client::new();
333            let mut request = client.post(&self.endpoint);
334            for (key, value) in &headers {
335                request = request.header(key, value);
336            }
337            let response = request.body(body).send().await.map_err(|e| {
338                ProtocolError::Transport(TransportError::Network(format!(
339                    "streamable HTTP request failed: {e}"
340                )))
341            })?;
342
343            let status = response.status().as_u16();
344            let response_headers = response
345                .headers()
346                .iter()
347                .filter_map(|(name, value)| {
348                    value
349                        .to_str()
350                        .ok()
351                        .map(|value| (name.to_string(), value.to_string()))
352                })
353                .collect::<HashMap<_, _>>();
354            let body = response.bytes().await.map_err(|e| {
355                ProtocolError::Transport(TransportError::Network(format!(
356                    "streamable HTTP response read failed: {e}"
357                )))
358            })?;
359            let body = body.to_vec();
360
361            if !(200..300).contains(&status) {
362                return Err(ProtocolError::Transport(TransportError::Http {
363                    status,
364                    message: format!("streamable HTTP request failed with status {status}"),
365                    body: Some(body),
366                    headers: Some(response_headers),
367                }));
368            }
369
370            Ok(HttpResponse {
371                headers: response_headers,
372                body,
373            })
374        }
375    }
376}
377
/// The parts of a successful (2xx) HTTP response the client needs.
struct HttpResponse {
    /// Response headers whose values could be read as strings.
    headers: HashMap<String, String>,
    /// Raw response body bytes.
    body: Vec<u8>,
}
382
/// Looks up the value of header `name`, comparing names ASCII
/// case-insensitively.
///
/// Returns `None` when no header matches. If several keys differ only by
/// case, whichever the map yields first is returned.
fn find_header<'a>(headers: &'a HashMap<String, String>, name: &str) -> Option<&'a str> {
    headers.iter().find_map(|(candidate, value)| {
        candidate
            .eq_ignore_ascii_case(name)
            .then_some(value.as_str())
    })
}
389
390fn protocol_error_from_jsonrpc(error: JsonRpcError) -> ProtocolError {
391    let details = error
392        .data
393        .map(|value| format!(" data={value}"))
394        .unwrap_or_default();
395    ProtocolError::Validation(format!(
396        "JSON-RPC error {}: {}{}",
397        error.code, error.message, details
398    ))
399}
400
401fn parse_sse_jsonrpc_response(body: &[u8]) -> Result<JsonRpcResponse, ProtocolError> {
402    let text = std::str::from_utf8(body)
403        .map_err(|e| ProtocolError::Parsing(format!("invalid UTF-8 event-stream body: {e}")))?;
404    let mut data_lines = Vec::new();
405
406    for line in text.lines() {
407        if let Some(rest) = line.strip_prefix("data:") {
408            data_lines.push(rest.trim_start().to_string());
409            continue;
410        }
411
412        if line.trim().is_empty() && !data_lines.is_empty() {
413            let payload = data_lines.join("\n");
414            if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&payload) {
415                return Ok(response);
416            }
417            data_lines.clear();
418        }
419    }
420
421    if !data_lines.is_empty() {
422        let payload = data_lines.join("\n");
423        if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&payload) {
424            return Ok(response);
425        }
426    }
427
428    Err(ProtocolError::Parsing(format!(
429        "event-stream response did not contain an MCP JSON-RPC payload; legacy SSE-only endpoints are unsupported; body={text:?}"
430    )))
431}
432
433/// **MCP Client** - Connect to MCP servers
434pub struct McpClient {
435    auth_handler: Option<Box<dyn AuthHandler>>,
436    transport: Option<ClientTransport>,
437}
438
439impl McpClient {
440    /// Create new MCP client
441    pub fn new() -> Self {
442        Self {
443            auth_handler: None,
444            transport: None,
445        }
446    }
447
448    /// Configure authentication
449    pub fn with_auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
450        self.auth_handler = Some(Box::new(handler));
451        self
452    }
453
454    /// Connect to MCP server via Streamable HTTP.
455    pub fn with_streamable_http_server(mut self, endpoint: &str) -> Self {
456        self.transport = Some(ClientTransport::StreamableHttp(
457            StreamableHttpClientTransport::new(endpoint),
458        ));
459        self
460    }
461
462    /// Connect to MCP server via Streamable HTTP with bearer authentication.
463    pub fn with_streamable_http_server_auth(mut self, endpoint: &str, auth_token: &str) -> Self {
464        self.transport = Some(ClientTransport::StreamableHttp(
465            StreamableHttpClientTransport::new(endpoint).with_auth_token(auth_token),
466        ));
467        self
468    }
469
470    /// Add extra HTTP headers to the Streamable HTTP transport.
471    pub fn with_streamable_http_headers(mut self, headers: HashMap<String, String>) -> Self {
472        let transport = match self.transport.take() {
473            Some(ClientTransport::StreamableHttp(transport)) => {
474                ClientTransport::StreamableHttp(transport.with_headers(headers))
475            }
476            other => {
477                self.transport = other;
478                return self;
479            }
480        };
481        self.transport = Some(transport);
482        self
483    }
484
485    /// Override the Streamable HTTP client identity used during initialize.
486    pub fn with_streamable_http_client_info(mut self, client_info: ClientInfo) -> Self {
487        let transport = match self.transport.take() {
488            Some(ClientTransport::StreamableHttp(transport)) => {
489                ClientTransport::StreamableHttp(transport.with_client_info(client_info))
490            }
491            other => {
492                self.transport = other;
493                return self;
494            }
495        };
496        self.transport = Some(transport);
497        self
498    }
499
500    /// Connect to MCP server via SSE (feature: "sse-client")
501    #[cfg(feature = "sse-client")]
502    pub fn with_sse_server(mut self, endpoint: &str) -> Self {
503        self.transport = Some(ClientTransport::Sse {
504            transport: TransportFactory::mcp_sse(endpoint),
505        });
506        self
507    }
508
509    /// Connect to MCP server via SSE with authentication (feature: "sse-client")
510    #[cfg(feature = "sse-client")]
511    pub fn with_sse_server_auth(mut self, endpoint: &str, auth_token: &str) -> Self {
512        self.transport = Some(ClientTransport::Sse {
513            transport: TransportFactory::mcp_sse_auth(endpoint, auth_token),
514        });
515        self
516    }
517
518    /// List tools from server.
519    pub async fn list_tools_async(&self) -> Result<Vec<Tool>, ProtocolError> {
520        match self
521            .transport
522            .as_ref()
523            .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
524        {
525            ClientTransport::StreamableHttp(transport) => transport.list_tools().await,
526
527            #[cfg(feature = "sse-client")]
528            ClientTransport::Sse { transport } => {
529                let result = send_sse_request(transport, "tools/list", json!({})).await?;
530                let list_result: ListToolsResult = serde_json::from_value(result).map_err(|e| {
531                    ProtocolError::Parsing(format!("invalid tools list format: {e}"))
532                })?;
533                Ok(list_result.tools)
534            }
535        }
536    }
537
538    /// Call tool on server.
539    pub async fn call_tool_async(
540        &self,
541        name: &str,
542        arguments: Option<serde_json::Value>,
543    ) -> Result<CallToolResult, ProtocolError> {
544        self.call_tool_with_meta_async(name, arguments, None).await
545    }
546
547    pub async fn call_tool_with_meta_async(
548        &self,
549        name: &str,
550        arguments: Option<serde_json::Value>,
551        meta: Option<serde_json::Value>,
552    ) -> Result<CallToolResult, ProtocolError> {
553        match self
554            .transport
555            .as_ref()
556            .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
557        {
558            ClientTransport::StreamableHttp(transport) => {
559                transport.call_tool(name, arguments, meta).await
560            }
561
562            #[cfg(feature = "sse-client")]
563            ClientTransport::Sse { transport } => {
564                let result = send_sse_request(
565                    transport,
566                    "tools/call",
567                    json!({
568                        "name": name,
569                        "arguments": arguments,
570                        "_meta": meta,
571                    }),
572                )
573                .await?;
574
575                serde_json::from_value(result).map_err(|e| {
576                    ProtocolError::Parsing(format!("invalid tool call result format: {e}"))
577                })
578            }
579        }
580    }
581
582    /// Initialize the current Streamable HTTP transport explicitly.
583    pub async fn initialize_async(&self) -> Result<(), ProtocolError> {
584        match self
585            .transport
586            .as_ref()
587            .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
588        {
589            ClientTransport::StreamableHttp(transport) => transport.initialize_if_needed().await,
590
591            #[cfg(feature = "sse-client")]
592            ClientTransport::Sse { .. } => Ok(()),
593        }
594    }
595
596    /// Check server health.
597    pub async fn health_check(&self) -> Result<(), ProtocolError> {
598        match self
599            .transport
600            .as_ref()
601            .ok_or_else(|| ProtocolError::internal_error("no MCP transport configured"))?
602        {
603            ClientTransport::StreamableHttp(transport) => transport.initialize_if_needed().await,
604
605            #[cfg(feature = "sse-client")]
606            ClientTransport::Sse { transport } => transport.health_check().await.map_err(|e| {
607                ProtocolError::internal_error(&format!("health check failed: {:?}", e))
608            }),
609        }
610    }
611}
612
/// Sends one JSON-RPC request over the SSE transport and returns the
/// `result` member of the reply.
///
/// Every request is sent with the fixed JSON-RPC id `1`.
///
/// # Errors
/// - internal error when the transport fails;
/// - `ProtocolError::Parsing` when the reply is not JSON or has neither
///   `error` nor `result`;
/// - `ProtocolError::Validation` when the reply carries a JSON-RPC `error`
///   member, so the server's code and message reach the caller instead of a
///   misleading "missing 'result' field".
#[cfg(feature = "sse-client")]
async fn send_sse_request(
    transport: &SseTransport,
    method: &str,
    params: serde_json::Value,
) -> Result<serde_json::Value, ProtocolError> {
    let request = UniversalRequest {
        method: method.to_string(),
        uri: "/".to_string(),
        headers: HashMap::new(),
        body: json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
            "id": 1,
        })
        .to_string()
        .into_bytes(),
        protocol: "MCP".to_string(),
        correlation_id: format!("mcp-client-{}", method.replace('/', "-")),
    };

    let response = transport
        .send(request)
        .await
        .map_err(|e| ProtocolError::internal_error(&format!("transport error: {e:?}")))?;

    let response_json: serde_json::Value = serde_json::from_slice(&response.body)
        .map_err(|e| ProtocolError::Parsing(format!("invalid JSON response: {e}")))?;

    // A failed call answers with `error` in place of `result`; report it as such.
    if let Some(error) = response_json.get("error").filter(|value| !value.is_null()) {
        return Err(
            match serde_json::from_value::<JsonRpcError>(error.clone()) {
                Ok(parsed) => protocol_error_from_jsonrpc(parsed),
                // Keep the raw payload when it is not a well-formed error object.
                Err(_) => ProtocolError::Validation(format!("JSON-RPC error: {error}")),
            },
        );
    }

    response_json
        .get("result")
        .cloned()
        .ok_or_else(|| ProtocolError::Parsing("missing 'result' field".to_string()))
}
648
/// `ToolProvider` implementation that always fails: both methods return an
/// internal error pointing the caller to the `*_async` methods of
/// [`McpClient`].
#[cfg(feature = "sse-client")]
#[async_trait::async_trait]
impl ToolProvider for McpClient {
    /// Always returns an error; use `list_tools_async()` instead.
    fn list_tools(&self) -> Result<Vec<Tool>, ProtocolError> {
        let message = "async tool listing not supported in sync context. Use list_tools_async().";
        Err(ProtocolError::internal_error(message))
    }

    /// Always returns an error naming the requested tool; use
    /// `call_tool_async()` instead.
    async fn call_tool(
        &self,
        name: &str,
        _arguments: Option<serde_json::Value>,
    ) -> Result<CallToolResult, ProtocolError> {
        let message = format!(
            "async tool calls not supported in sync context. Use call_tool_async() for tool '{name}'.",
        );
        Err(ProtocolError::internal_error(&message))
    }
}
668
669impl Default for McpClient {
670    fn default() -> Self {
671        Self::new()
672    }
673}
674
675/// **MCP Client Builder** - Convenient client configuration
676pub struct McpClientBuilder {
677    auth_handler: Option<Box<dyn AuthHandler>>,
678    streamable_http_endpoint: Option<String>,
679    streamable_http_auth_token: Option<String>,
680
681    #[cfg(feature = "sse-client")]
682    sse_endpoint: Option<String>,
683
684    #[cfg(feature = "sse-client")]
685    sse_auth_token: Option<String>,
686}
687
688impl McpClientBuilder {
689    /// Create new client builder
690    pub fn new() -> Self {
691        Self {
692            auth_handler: None,
693            streamable_http_endpoint: None,
694            streamable_http_auth_token: None,
695
696            #[cfg(feature = "sse-client")]
697            sse_endpoint: None,
698
699            #[cfg(feature = "sse-client")]
700            sse_auth_token: None,
701        }
702    }
703
704    /// Set authentication handler
705    pub fn with_auth_handler<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
706        self.auth_handler = Some(Box::new(handler));
707        self
708    }
709
710    /// Configure Streamable HTTP transport.
711    pub fn with_streamable_http_server(mut self, endpoint: &str) -> Self {
712        self.streamable_http_endpoint = Some(endpoint.to_string());
713        self
714    }
715
716    /// Set bearer token for Streamable HTTP transport.
717    pub fn with_streamable_http_auth_token(mut self, token: &str) -> Self {
718        self.streamable_http_auth_token = Some(token.to_string());
719        self
720    }
721
722    /// Connect to MCP server via SSE (feature: "sse-client")
723    #[cfg(feature = "sse-client")]
724    pub fn with_sse_server(mut self, endpoint: &str) -> Self {
725        self.sse_endpoint = Some(endpoint.to_string());
726        self
727    }
728
729    /// Set authentication token for SSE server (feature: "sse-client")
730    #[cfg(feature = "sse-client")]
731    pub fn with_auth_token(mut self, token: &str) -> Self {
732        self.sse_auth_token = Some(token.to_string());
733        self
734    }
735
736    /// Build the MCP client
737    pub fn build(self) -> McpClient {
738        let mut client = McpClient::new();
739
740        if let Some(handler) = self.auth_handler {
741            client.auth_handler = Some(handler);
742        }
743
744        if let Some(endpoint) = self.streamable_http_endpoint {
745            client = if let Some(token) = self.streamable_http_auth_token {
746                client.with_streamable_http_server_auth(&endpoint, &token)
747            } else {
748                client.with_streamable_http_server(&endpoint)
749            };
750        }
751
752        #[cfg(feature = "sse-client")]
753        {
754            if let Some(endpoint) = self.sse_endpoint {
755                client = if let Some(token) = self.sse_auth_token {
756                    client.with_sse_server_auth(&endpoint, &token)
757                } else {
758                    client.with_sse_server(&endpoint)
759                };
760            }
761        }
762
763        client
764    }
765}
766
767#[cfg(all(test, not(target_arch = "wasm32")))]
768mod tests {
769    use super::*;
770    use axum::{
771        Json, Router,
772        body::Bytes,
773        extract::State,
774        http::{HeaderMap, HeaderValue, StatusCode},
775        response::IntoResponse,
776        routing::post,
777    };
778    use std::sync::{
779        Arc,
780        atomic::{AtomicUsize, Ordering},
781    };
782    use tokio::net::TcpListener;
783
784    #[derive(Clone)]
785    struct TestState {
786        session_seen: Arc<AtomicUsize>,
787        initialized_seen: Arc<AtomicUsize>,
788    }
789
790    async fn json_handler(
791        State(state): State<TestState>,
792        headers: HeaderMap,
793        body: Bytes,
794    ) -> impl IntoResponse {
795        let request: serde_json::Value = serde_json::from_slice(&body).expect("json body");
796        let method = request["method"].as_str().expect("method");
797        match method {
798            "initialize" => {
799                let mut response_headers = HeaderMap::new();
800                response_headers.insert(
801                    HEADER_MCP_SESSION_ID,
802                    HeaderValue::from_static("session-123"),
803                );
804                (
805                    response_headers,
806                    Json(json!({
807                        "jsonrpc": "2.0",
808                        "id": request["id"].clone(),
809                        "result": {
810                            "protocolVersion": MCP_PROTOCOL_VERSION,
811                            "capabilities": { "tools": { "supported": true } },
812                            "serverInfo": {
813                                "name": "test-server",
814                                "version": "0.1.0"
815                            }
816                        }
817                    })),
818                )
819                    .into_response()
820            }
821            "notifications/initialized" => {
822                assert!(
823                    request.get("id").is_none(),
824                    "initialized notification must not carry an id"
825                );
826                state.initialized_seen.fetch_add(1, Ordering::SeqCst);
827                StatusCode::ACCEPTED.into_response()
828            }
829            "tools/list" => {
830                assert_eq!(
831                    state.initialized_seen.load(Ordering::SeqCst),
832                    1,
833                    "tools/list should only be called after notifications/initialized"
834                );
835                if headers
836                    .get(HEADER_MCP_SESSION_ID)
837                    .and_then(|value| value.to_str().ok())
838                    == Some("session-123")
839                {
840                    state.session_seen.fetch_add(1, Ordering::SeqCst);
841                }
842                Json(json!({
843                    "jsonrpc": "2.0",
844                    "id": request["id"].clone(),
845                    "result": {
846                        "tools": [{
847                            "name": "search_agents",
848                            "description": "Search directory",
849                            "inputSchema": { "type": "object", "properties": {} }
850                        }]
851                    }
852                }))
853                .into_response()
854            }
855            "tools/call" => {
856                let body = format!(
857                    "event: message\ndata: {}\n\n",
858                    json!({
859                        "jsonrpc": "2.0",
860                        "id": request["id"].clone(),
861                        "result": {
862                            "content": [{ "type": "text", "text": "{\"ok\":true}" }],
863                            "isError": false
864                        }
865                    })
866                );
867                ([(HEADER_CONTENT_TYPE, CONTENT_TYPE_EVENT_STREAM)], body).into_response()
868            }
869            _ => StatusCode::NOT_FOUND.into_response(),
870        }
871    }
872
873    async fn error_handler(body: Bytes) -> impl IntoResponse {
874        let request: serde_json::Value = serde_json::from_slice(&body).expect("json body");
875        Json(json!({
876            "jsonrpc": "2.0",
877            "id": request["id"].clone(),
878            "error": {
879                "code": -32602,
880                "message": "bad input"
881            }
882        }))
883    }
884
885    async fn legacy_sse_handler(_body: Bytes) -> impl IntoResponse {
886        (
887            [(HEADER_CONTENT_TYPE, CONTENT_TYPE_EVENT_STREAM)],
888            "event: endpoint\ndata: /messages?session=abc\n\n".to_string(),
889        )
890    }
891
892    async fn start_server(app: Router) -> (String, tokio::task::JoinHandle<()>) {
893        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
894        let addr = listener.local_addr().expect("local addr");
895        let handle = tokio::spawn(async move {
896            axum::serve(listener, app).await.expect("server");
897        });
898        (format!("http://{addr}/mcp"), handle)
899    }
900
901    #[tokio::test]
902    async fn streamable_http_initializes_and_replays_session_header() {
903        let state = TestState {
904            session_seen: Arc::new(AtomicUsize::new(0)),
905            initialized_seen: Arc::new(AtomicUsize::new(0)),
906        };
907        let session_seen = state.session_seen.clone();
908        let initialized_seen = state.initialized_seen.clone();
909        let app = Router::new()
910            .route("/mcp", post(json_handler))
911            .with_state(state);
912        let (url, handle) = start_server(app).await;
913
914        let client = McpClient::new().with_streamable_http_server(&url);
915        let tools = client.list_tools_async().await.expect("list tools");
916
917        assert_eq!(tools.len(), 1);
918        assert_eq!(tools[0].name, "search_agents");
919        assert_eq!(session_seen.load(Ordering::SeqCst), 1);
920        assert_eq!(initialized_seen.load(Ordering::SeqCst), 1);
921
922        handle.abort();
923    }
924
925    #[tokio::test]
926    async fn streamable_http_parses_event_stream_tool_results() {
927        let app = Router::new()
928            .route("/mcp", post(json_handler))
929            .with_state(TestState {
930                session_seen: Arc::new(AtomicUsize::new(0)),
931                initialized_seen: Arc::new(AtomicUsize::new(0)),
932            });
933        let (url, handle) = start_server(app).await;
934
935        let client = McpClient::new().with_streamable_http_server(&url);
936        let result = client
937            .call_tool_async("search_agents", Some(json!({"q": "planner"})))
938            .await
939            .expect("tool call");
940
941        assert_eq!(result.is_error, Some(false));
942        assert_eq!(result.content.len(), 1);
943
944        handle.abort();
945    }
946
947    #[tokio::test]
948    async fn streamable_http_surfaces_jsonrpc_errors() {
949        let app = Router::new().route("/mcp", post(error_handler));
950        let (url, handle) = start_server(app).await;
951
952        let client = McpClient::new().with_streamable_http_server(&url);
953        let error = client.list_tools_async().await.expect_err("should fail");
954
955        assert!(error.to_string().contains("JSON-RPC error -32602"));
956
957        handle.abort();
958    }
959
960    #[tokio::test]
961    async fn streamable_http_rejects_legacy_sse_only_responses() {
962        let app = Router::new().route("/mcp", post(legacy_sse_handler));
963        let (url, handle) = start_server(app).await;
964
965        let client = McpClient::new().with_streamable_http_server(&url);
966        let error = client.list_tools_async().await.expect_err("should fail");
967
968        assert!(
969            error
970                .to_string()
971                .contains("legacy SSE-only endpoints are unsupported")
972        );
973
974        handle.abort();
975    }
976}