Skip to main content

devboy_mcp/
proxy.rs

1//! MCP proxy — connects to upstream MCP servers and proxies tool calls.
2//!
3//! Supports two transports:
4//! - **SSE**: Legacy MCP transport with SSE stream for responses.
5//! - **Streamable HTTP**: POST-based transport with `mcp-session-id` header.
6
7use std::sync::Arc;
8use std::sync::atomic::{AtomicI64, Ordering};
9use std::time::Duration;
10
11use futures::StreamExt;
12use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
13use reqwest_eventsource::{Event, EventSource};
14use secrecy::{ExposeSecret, SecretString};
15use serde_json::Value;
16use tokio::sync::{Mutex, RwLock, oneshot};
17
18use crate::protocol::{
19    JSONRPC_VERSION, JsonRpcRequest, JsonRpcResponse, RequestId, ToolCallResult, ToolDefinition,
20};
21
/// Wire transport used to talk to an upstream MCP server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProxyTransport {
    /// Legacy MCP transport: a GET opens the SSE stream, requests are POSTed.
    Sse,
    /// Streamable HTTP: POST with a JSON response and an `mcp-session-id` header.
    StreamableHttp,
}

impl ProxyTransport {
    /// Map a configuration string to a transport.
    ///
    /// `"streamable-http"`, `"streamable_http"` and `"http"` select
    /// [`ProxyTransport::StreamableHttp`]; every other value (including the
    /// empty string) selects [`ProxyTransport::Sse`]. Matching is exact and
    /// case-sensitive.
    pub fn parse(s: &str) -> Self {
        if matches!(s, "streamable-http" | "streamable_http" | "http") {
            Self::StreamableHttp
        } else {
            Self::Sse
        }
    }
}
39
40/// Pending SSE response receivers indexed by request ID.
41type PendingResponses = Arc<Mutex<Vec<(i64, oneshot::Sender<JsonRpcResponse>)>>>;
42
43/// Single upstream MCP server connection.
44pub struct McpProxyClient {
45    name: String,
46    tool_prefix: String,
47    post_url: String,
48    http_client: reqwest::Client,
49    upstream_tools: Vec<ToolDefinition>,
50    next_id: AtomicI64,
51    transport: ProxyTransport,
52    /// Session ID for Streamable HTTP transport.
53    session_id: RwLock<Option<String>>,
54    /// Channel to receive SSE responses routed by request id (SSE transport only).
55    pending: PendingResponses,
56}
57
58impl McpProxyClient {
59    /// Connect to an upstream MCP server, perform initialize handshake, fetch tools.
60    pub async fn connect(
61        name: &str,
62        url: &str,
63        tool_prefix: Option<&str>,
64        token: Option<&SecretString>,
65        auth_type: &str,
66        transport: ProxyTransport,
67    ) -> devboy_core::Result<Self> {
68        let mut headers = HeaderMap::new();
69        if let Some(token) = token {
70            match auth_type {
71                "bearer" => {
72                    let val = HeaderValue::from_str(&format!("Bearer {}", token.expose_secret()))
73                        .map_err(|e| {
74                        devboy_core::Error::Config(format!("Invalid token: {}", e))
75                    })?;
76                    headers.insert(AUTHORIZATION, val);
77                }
78                "api_key" => {
79                    let val = HeaderValue::from_str(token.expose_secret())
80                        .map_err(|e| devboy_core::Error::Config(format!("Invalid token: {}", e)))?;
81                    headers.insert("X-API-Key", val);
82                }
83                _ => {}
84            }
85        }
86
87        let http_client = reqwest::Client::builder()
88            .default_headers(headers.clone())
89            .timeout(Duration::from_secs(60))
90            .pool_max_idle_per_host(0)
91            .build()
92            .map_err(|e| devboy_core::Error::Http(format!("Failed to build HTTP client: {}", e)))?;
93
94        let prefix = tool_prefix.unwrap_or(name).to_string();
95
96        match transport {
97            ProxyTransport::Sse => {
98                Self::connect_sse(name, url, &prefix, headers, http_client).await
99            }
100            ProxyTransport::StreamableHttp => {
101                Self::connect_streamable_http(name, url, &prefix, http_client).await
102            }
103        }
104    }
105
106    /// Connect via SSE transport.
107    async fn connect_sse(
108        name: &str,
109        url: &str,
110        prefix: &str,
111        headers: HeaderMap,
112        http_client: reqwest::Client,
113    ) -> devboy_core::Result<Self> {
114        let sse_url = url.to_string();
115        let mut es = EventSource::new(
116            reqwest::Client::builder()
117                .default_headers(headers)
118                .build()
119                .unwrap()
120                .get(&sse_url),
121        )
122        .map_err(|e| {
123            devboy_core::Error::Http(format!("Failed to connect SSE to {}: {}", sse_url, e))
124        })?;
125
126        let post_url = Self::wait_for_endpoint(&mut es, url).await?;
127
128        let pending: PendingResponses = Arc::new(Mutex::new(Vec::new()));
129
130        // Spawn SSE listener
131        let pending_clone = pending.clone();
132        tokio::spawn(async move {
133            while let Some(event) = es.next().await {
134                match event {
135                    Ok(Event::Message(msg)) => {
136                        if msg.event == "message"
137                            && let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&msg.data)
138                        {
139                            let id_num = match &resp.id {
140                                RequestId::Number(n) => *n,
141                                _ => continue,
142                            };
143                            let mut pending = pending_clone.lock().await;
144                            if let Some(idx) = pending.iter().position(|(id, _)| *id == id_num) {
145                                let (_, sender) = pending.remove(idx);
146                                let _ = sender.send(resp);
147                            }
148                        }
149                    }
150                    Ok(Event::Open) => {
151                        tracing::debug!("SSE stream open");
152                    }
153                    Err(e) => {
154                        tracing::warn!("SSE error: {}", e);
155                        break;
156                    }
157                }
158            }
159        });
160
161        let client = Self {
162            name: name.to_string(),
163            tool_prefix: prefix.to_string(),
164            post_url,
165            http_client,
166            upstream_tools: Vec::new(),
167            next_id: AtomicI64::new(1),
168            transport: ProxyTransport::Sse,
169            session_id: RwLock::new(None),
170            pending,
171        };
172
173        client.initialize().await?;
174
175        Ok(client)
176    }
177
178    /// Connect via Streamable HTTP transport.
179    async fn connect_streamable_http(
180        name: &str,
181        url: &str,
182        prefix: &str,
183        http_client: reqwest::Client,
184    ) -> devboy_core::Result<Self> {
185        let client = Self {
186            name: name.to_string(),
187            tool_prefix: prefix.to_string(),
188            post_url: url.to_string(),
189            http_client,
190            upstream_tools: Vec::new(),
191            next_id: AtomicI64::new(1),
192            transport: ProxyTransport::StreamableHttp,
193            session_id: RwLock::new(None),
194            pending: Arc::new(Mutex::new(Vec::new())),
195        };
196
197        client.initialize().await?;
198
199        Ok(client)
200    }
201
202    /// Wait for the SSE `endpoint` event that tells us where to POST requests.
203    async fn wait_for_endpoint(
204        es: &mut EventSource,
205        base_url: &str,
206    ) -> devboy_core::Result<String> {
207        let timeout = tokio::time::timeout(Duration::from_secs(10), async {
208            while let Some(event) = es.next().await {
209                match event {
210                    Ok(Event::Message(msg)) if msg.event == "endpoint" => {
211                        let endpoint = msg.data.trim().to_string();
212                        // If relative URL, resolve against base
213                        if endpoint.starts_with('/')
214                            && let Ok(base) = reqwest::Url::parse(base_url)
215                            && let Ok(resolved) = base.join(&endpoint)
216                        {
217                            return Ok(resolved.to_string());
218                        }
219                        return Ok(endpoint);
220                    }
221                    Ok(Event::Open) => continue,
222                    Ok(_) => continue,
223                    Err(e) => {
224                        return Err(devboy_core::Error::Http(format!("SSE error: {}", e)));
225                    }
226                }
227            }
228            Err(devboy_core::Error::Http(
229                "SSE stream ended before endpoint event".to_string(),
230            ))
231        });
232
233        timeout.await.map_err(|_| {
234            devboy_core::Error::Http("Timeout waiting for SSE endpoint event".to_string())
235        })?
236    }
237
238    fn next_request_id(&self) -> i64 {
239        self.next_id.fetch_add(1, Ordering::SeqCst)
240    }
241
242    /// Send a JSON-RPC request and wait for response.
243    async fn request(
244        &self,
245        method: &str,
246        params: Option<Value>,
247    ) -> devboy_core::Result<JsonRpcResponse> {
248        match self.transport {
249            ProxyTransport::Sse => self.request_sse(method, params).await,
250            ProxyTransport::StreamableHttp => self.request_http(method, params).await,
251        }
252    }
253
254    /// Send request via SSE transport (POST request, response via SSE stream).
255    async fn request_sse(
256        &self,
257        method: &str,
258        params: Option<Value>,
259    ) -> devboy_core::Result<JsonRpcResponse> {
260        let id = self.next_request_id();
261        let req = JsonRpcRequest {
262            jsonrpc: JSONRPC_VERSION.to_string(),
263            id: RequestId::Number(id),
264            method: method.to_string(),
265            params,
266        };
267
268        let (tx, rx) = oneshot::channel();
269        {
270            let mut pending = self.pending.lock().await;
271            pending.push((id, tx));
272        }
273
274        self.http_client
275            .post(&self.post_url)
276            .json(&req)
277            .send()
278            .await
279            .map_err(|e| devboy_core::Error::Http(format!("POST failed: {}", e)))?;
280
281        let resp = tokio::time::timeout(Duration::from_secs(30), rx)
282            .await
283            .map_err(|_| devboy_core::Error::Http("Timeout waiting for response".to_string()))?
284            .map_err(|_| devboy_core::Error::Http("Response channel closed".to_string()))?;
285
286        Ok(resp)
287    }
288
289    /// Send request via Streamable HTTP transport (POST, JSON response, session header).
290    async fn request_http(
291        &self,
292        method: &str,
293        params: Option<Value>,
294    ) -> devboy_core::Result<JsonRpcResponse> {
295        let id = self.next_request_id();
296        let req = JsonRpcRequest {
297            jsonrpc: JSONRPC_VERSION.to_string(),
298            id: RequestId::Number(id),
299            method: method.to_string(),
300            params,
301        };
302
303        let mut request = self
304            .http_client
305            .post(&self.post_url)
306            .header(CONTENT_TYPE, "application/json")
307            .header(ACCEPT, "application/json, text/event-stream");
308
309        // Add session ID for all requests after initialize
310        if method != "initialize" {
311            let session = self.session_id.read().await;
312            if let Some(sid) = session.as_ref() {
313                request = request.header("mcp-session-id", sid);
314            }
315        }
316
317        let response = request.json(&req).send().await.map_err(|e| {
318            tracing::error!(
319                "POST to {} failed: {} (is_timeout={}, is_connect={}, is_request={})",
320                self.post_url,
321                e,
322                e.is_timeout(),
323                e.is_connect(),
324                e.is_request(),
325            );
326            devboy_core::Error::Http(format!("POST failed: {}", e))
327        })?;
328
329        // Extract session ID from response headers (set during initialize)
330        if method == "initialize"
331            && let Some(sid) = response.headers().get("mcp-session-id")
332            && let Ok(sid_str) = sid.to_str()
333        {
334            let mut session = self.session_id.write().await;
335            *session = Some(sid_str.to_string());
336            tracing::debug!("Proxy '{}': got session ID", self.name);
337        }
338
339        let status = response.status();
340        if !status.is_success() {
341            let body = response.text().await.unwrap_or_default();
342            return Err(devboy_core::Error::Http(format!(
343                "HTTP {}: {}",
344                status, body
345            )));
346        }
347
348        // Check Content-Type: server may respond with JSON or SSE stream
349        let content_type = response
350            .headers()
351            .get(CONTENT_TYPE)
352            .and_then(|v| v.to_str().ok())
353            .unwrap_or("")
354            .to_string();
355
356        let resp = if content_type.contains("text/event-stream") {
357            // Parse SSE stream to extract JSON-RPC response
358            tracing::debug!("Response is SSE stream, parsing events...");
359            self.parse_sse_response(response, id).await?
360        } else {
361            // Direct JSON response
362            tracing::debug!("Response is JSON (content-type: {})", content_type);
363            response
364                .json::<JsonRpcResponse>()
365                .await
366                .map_err(|e| devboy_core::Error::Http(format!("Failed to parse response: {}", e)))?
367        };
368
369        // Verify response ID matches request ID
370        let expected_id = RequestId::Number(id);
371        if resp.id != expected_id {
372            return Err(devboy_core::Error::Http(format!(
373                "Mismatched JSON-RPC id: expected {:?}, got {:?}",
374                expected_id, resp.id
375            )));
376        }
377
378        Ok(resp)
379    }
380
381    /// Parse an SSE event stream response to extract the JSON-RPC response.
382    ///
383    /// Streamable HTTP spec allows servers to respond with SSE for long-running
384    /// operations like tool calls. Reads the stream line-by-line instead of
385    /// buffering the entire body (which would hang on open SSE connections).
386    async fn parse_sse_response(
387        &self,
388        response: reqwest::Response,
389        expected_id: i64,
390    ) -> devboy_core::Result<JsonRpcResponse> {
391        use futures::TryStreamExt;
392        use tokio::io::AsyncBufReadExt;
393
394        let stream = response.bytes_stream().map_err(std::io::Error::other);
395        let reader = tokio_util::io::StreamReader::new(stream);
396        let mut lines = tokio::io::BufReader::new(reader).lines();
397
398        let mut current_data = String::new();
399
400        tracing::debug!("Starting SSE line reader...");
401
402        tokio::time::timeout(Duration::from_secs(60), async {
403            while let Ok(Some(line)) = lines.next_line().await {
404                let line = line.trim().to_string();
405                let debug_len = line
406                    .char_indices()
407                    .nth(100)
408                    .map(|(i, _)| i)
409                    .unwrap_or(line.len());
410                tracing::debug!("SSE line: {}", &line[..debug_len]);
411
412                if line.is_empty() {
413                    // End of SSE event — try to parse collected data
414                    if !current_data.is_empty()
415                        && let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&current_data)
416                    {
417                        let id_matches = match &resp.id {
418                            RequestId::Number(n) => *n == expected_id,
419                            _ => false,
420                        };
421                        if id_matches {
422                            return Ok(resp);
423                        }
424                        current_data.clear();
425                    } else if !current_data.is_empty() {
426                        current_data.clear();
427                    }
428                    continue;
429                }
430
431                if let Some(data) = line.strip_prefix("data:") {
432                    let data = data.trim();
433                    if !data.is_empty() {
434                        current_data.push_str(data);
435                    }
436                }
437                // Skip event:, id:, retry: lines
438            }
439
440            // Try last accumulated data
441            if !current_data.is_empty()
442                && let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&current_data)
443            {
444                return Ok(resp);
445            }
446
447            Err(devboy_core::Error::Http(
448                "No matching JSON-RPC response found in SSE stream".to_string(),
449            ))
450        })
451        .await
452        .map_err(|_| devboy_core::Error::Http("Timeout reading SSE response".to_string()))?
453    }
454
455    /// Send initialize handshake.
456    async fn initialize(&self) -> devboy_core::Result<()> {
457        let params = serde_json::json!({
458            "protocolVersion": "2025-11-25",
459            "capabilities": {},
460            "clientInfo": {
461                "name": "devboy-mcp-proxy",
462                "version": env!("CARGO_PKG_VERSION")
463            }
464        });
465
466        let resp = self.request("initialize", Some(params)).await?;
467        if let Some(err) = resp.error {
468            return Err(devboy_core::Error::Http(format!(
469                "Initialize failed: {}",
470                err.message
471            )));
472        }
473
474        tracing::info!("Proxy '{}' initialized", self.name);
475        Ok(())
476    }
477
478    /// Fetch tools/list from upstream. Call this before using `prefixed_tools()`.
479    pub async fn fetch_tools(&mut self) -> devboy_core::Result<()> {
480        let resp = self.request("tools/list", None).await?;
481
482        if let Some(result) = resp.result {
483            #[derive(serde::Deserialize)]
484            struct ToolsList {
485                tools: Vec<ToolDefinition>,
486            }
487
488            if let Ok(list) = serde_json::from_value::<ToolsList>(result) {
489                self.upstream_tools = list.tools;
490                tracing::info!(
491                    "Proxy '{}': fetched {} tools",
492                    self.name,
493                    self.upstream_tools.len()
494                );
495            }
496        }
497
498        Ok(())
499    }
500
501    /// Return upstream tools with prefixed names.
502    pub fn prefixed_tools(&self) -> Vec<ToolDefinition> {
503        self.upstream_tools
504            .iter()
505            .map(|t| ToolDefinition {
506                name: format!("{}__{}", self.tool_prefix, t.name),
507                description: format!("[{}] {}", self.name, t.description),
508                input_schema: t.input_schema.clone(),
509                category: None, // Proxy tools don't have a category (always available)
510            })
511            .collect()
512    }
513
514    /// Return the raw upstream tool catalogue (without prefixing).
515    /// Used by the signature matcher to compare upstream definitions against the local
516    /// tool registry. Empty until `fetch_tools()` has been called.
517    pub fn raw_upstream_tools(&self) -> &[ToolDefinition] {
518        &self.upstream_tools
519    }
520
521    /// Execute a tool call, stripping the prefix before forwarding.
522    pub async fn call_tool(
523        &self,
524        original_name: &str,
525        arguments: Option<Value>,
526    ) -> devboy_core::Result<ToolCallResult> {
527        let params = serde_json::json!({
528            "name": original_name,
529            "arguments": arguments.unwrap_or(Value::Object(Default::default()))
530        });
531
532        let resp = self.request("tools/call", Some(params)).await?;
533
534        if let Some(err) = resp.error {
535            return Ok(ToolCallResult::error(err.message));
536        }
537
538        match resp.result {
539            Some(result) => serde_json::from_value(result).map_err(|e| {
540                devboy_core::Error::InvalidData(format!("Invalid tool result: {}", e))
541            }),
542            None => Ok(ToolCallResult::error(
543                "Empty response from upstream".to_string(),
544            )),
545        }
546    }
547
548    /// Get the tool prefix for this client.
549    pub fn prefix(&self) -> &str {
550        &self.tool_prefix
551    }
552}
553
554/// Manages multiple upstream MCP proxy connections.
555pub struct ProxyManager {
556    clients: Vec<McpProxyClient>,
557}
558
559impl Default for ProxyManager {
560    fn default() -> Self {
561        Self::new()
562    }
563}
564
565impl ProxyManager {
566    pub fn new() -> Self {
567        Self {
568            clients: Vec::new(),
569        }
570    }
571
572    pub fn add_client(&mut self, client: McpProxyClient) {
573        self.clients.push(client);
574    }
575
576    pub fn is_empty(&self) -> bool {
577        self.clients.is_empty()
578    }
579
580    /// Fetch tool lists from all upstream servers.
581    pub async fn fetch_all_tools(&mut self) -> devboy_core::Result<()> {
582        for client in &mut self.clients {
583            client.fetch_tools().await?;
584        }
585        Ok(())
586    }
587
588    /// Get all proxied tools (with prefixes) from all upstreams.
589    pub fn all_tools(&self) -> Vec<ToolDefinition> {
590        self.clients
591            .iter()
592            .flat_map(|c| c.prefixed_tools())
593            .collect()
594    }
595
596    /// Check whether a tool name belongs to a proxied upstream.
597    pub fn has_tool(&self, tool_name: &str) -> bool {
598        self.clients
599            .iter()
600            .any(|c| tool_name.starts_with(&format!("{}__", c.prefix())))
601    }
602
603    /// Try to route a tool call to the matching upstream.
604    /// Returns None if no upstream matches the tool name prefix.
605    pub async fn try_call(
606        &self,
607        tool_name: &str,
608        arguments: Option<Value>,
609    ) -> Option<ToolCallResult> {
610        for client in &self.clients {
611            let prefix = format!("{}__", client.prefix());
612            if let Some(original_name) = tool_name.strip_prefix(&prefix) {
613                let result = client.call_tool(original_name, arguments).await;
614                return Some(match result {
615                    Ok(r) => r,
616                    Err(e) => ToolCallResult::error(format!("Proxy error: {}", e)),
617                });
618            }
619        }
620        None
621    }
622
623    /// Call a specific upstream by prefix using the unprefixed tool name.
624    /// Used by the routing engine when it has already decided the remote executor is the
625    /// right target for a matched tool (and therefore doesn't need to rely on the
626    /// prefixed alias).
627    pub async fn call_by_prefix(
628        &self,
629        prefix: &str,
630        unprefixed_tool_name: &str,
631        arguments: Option<Value>,
632    ) -> Option<ToolCallResult> {
633        for client in &self.clients {
634            if client.prefix() == prefix {
635                let result = client.call_tool(unprefixed_tool_name, arguments).await;
636                return Some(match result {
637                    Ok(r) => r,
638                    Err(e) => ToolCallResult::error(format!("Proxy error: {}", e)),
639                });
640            }
641        }
642        None
643    }
644
645    /// Return every upstream's raw (unprefixed) tool catalogue tagged by prefix.
646    /// Consumers use this to feed the signature matcher.
647    pub fn raw_upstream_catalogue(&self) -> Vec<(String, &[ToolDefinition])> {
648        self.clients
649            .iter()
650            .map(|c| (c.prefix().to_string(), c.raw_upstream_tools()))
651            .collect()
652    }
653}
654
655// =============================================================================
656// Tests
657// =============================================================================
658
659#[cfg(test)]
660#[allow(clippy::err_expect)]
661mod tests {
662    use super::*;
663    use crate::protocol::ToolResultContent;
664    use httpmock::prelude::*;
665
666    fn token_secret(s: &str) -> SecretString {
667        SecretString::from(s.to_string())
668    }
669
670    // =========================================================================
671    // ProxyTransport
672    // =========================================================================
673
674    #[test]
675    fn test_proxy_transport_parse() {
676        assert_eq!(
677            ProxyTransport::parse("streamable-http"),
678            ProxyTransport::StreamableHttp
679        );
680        assert_eq!(
681            ProxyTransport::parse("streamable_http"),
682            ProxyTransport::StreamableHttp
683        );
684        assert_eq!(
685            ProxyTransport::parse("http"),
686            ProxyTransport::StreamableHttp
687        );
688        assert_eq!(ProxyTransport::parse("sse"), ProxyTransport::Sse);
689        assert_eq!(ProxyTransport::parse(""), ProxyTransport::Sse);
690        assert_eq!(ProxyTransport::parse("unknown"), ProxyTransport::Sse);
691    }
692
693    #[test]
694    fn test_proxy_transport_debug_clone_eq() {
695        let t = ProxyTransport::Sse;
696        let t2 = t;
697        assert_eq!(t, t2);
698        assert_eq!(format!("{:?}", t), "Sse");
699        assert_eq!(
700            format!("{:?}", ProxyTransport::StreamableHttp),
701            "StreamableHttp"
702        );
703    }
704
705    // =========================================================================
706    // Helper: mock upstream that implements Streamable HTTP MCP protocol
707    // =========================================================================
708
709    /// Create a MockServer that responds to initialize and tools/list.
710    fn setup_mock_upstream(server: &MockServer, tools: Vec<serde_json::Value>) {
711        // Initialize endpoint — returns session ID in header
712        server.mock(|when, then| {
713            when.method(POST)
714                .path("/mcp")
715                .body_includes(r#""method":"initialize""#);
716            then.status(200)
717                .header("mcp-session-id", "test-session-123")
718                .json_body(serde_json::json!({
719                    "jsonrpc": "2.0",
720                    "id": 1,
721                    "result": {
722                        "protocolVersion": "2025-11-25",
723                        "capabilities": { "tools": {} },
724                        "serverInfo": { "name": "mock-server", "version": "1.0.0" }
725                    }
726                }));
727        });
728
729        // tools/list endpoint
730        server.mock(|when, then| {
731            when.method(POST)
732                .path("/mcp")
733                .body_includes(r#""method":"tools/list""#);
734            then.status(200).json_body(serde_json::json!({
735                "jsonrpc": "2.0",
736                "id": 2,
737                "result": { "tools": tools }
738            }));
739        });
740    }
741
742    fn sample_tools() -> Vec<serde_json::Value> {
743        vec![
744            serde_json::json!({
745                "name": "get_issues",
746                "description": "Get issues from tracker",
747                "inputSchema": { "type": "object", "properties": {} }
748            }),
749            serde_json::json!({
750                "name": "get_merge_requests",
751                "description": "Get merge requests",
752                "inputSchema": { "type": "object", "properties": {} }
753            }),
754        ]
755    }
756
757    // =========================================================================
758    // McpProxyClient — Streamable HTTP connect
759    // =========================================================================
760
761    #[tokio::test]
762    async fn test_connect_streamable_http() {
763        let server = MockServer::start();
764        setup_mock_upstream(&server, sample_tools());
765
766        let url = format!("{}/mcp", server.base_url());
767        let token = token_secret("my-token");
768        let client = McpProxyClient::connect(
769            "test-server",
770            &url,
771            None,
772            Some(&token),
773            "bearer",
774            ProxyTransport::StreamableHttp,
775        )
776        .await
777        .unwrap();
778
779        assert_eq!(client.prefix(), "test-server");
780        assert!(client.upstream_tools.is_empty()); // Not fetched yet
781    }
782
783    #[tokio::test]
784    async fn test_connect_with_custom_prefix() {
785        let server = MockServer::start();
786        setup_mock_upstream(&server, sample_tools());
787
788        let url = format!("{}/mcp", server.base_url());
789        let client = McpProxyClient::connect(
790            "test-server",
791            &url,
792            Some("custom"),
793            None,
794            "none",
795            ProxyTransport::StreamableHttp,
796        )
797        .await
798        .unwrap();
799
800        assert_eq!(client.prefix(), "custom");
801    }
802
803    #[tokio::test]
804    async fn test_connect_initialize_failure() {
805        let server = MockServer::start();
806
807        server.mock(|when, then| {
808            when.method(POST)
809                .path("/mcp")
810                .body_includes(r#""method":"initialize""#);
811            then.status(200).json_body(serde_json::json!({
812                "jsonrpc": "2.0",
813                "id": 1,
814                "error": { "code": -32600, "message": "Bad request" }
815            }));
816        });
817
818        let url = format!("{}/mcp", server.base_url());
819        let result = McpProxyClient::connect(
820            "test-server",
821            &url,
822            None,
823            None,
824            "none",
825            ProxyTransport::StreamableHttp,
826        )
827        .await;
828
829        let err = result.err().expect("should be error");
830        assert!(err.to_string().contains("Initialize failed"));
831    }
832
833    #[tokio::test]
834    async fn test_connect_http_error() {
835        let server = MockServer::start();
836
837        server.mock(|when, then| {
838            when.method(POST).path("/mcp");
839            then.status(500).body("Internal Server Error");
840        });
841
842        let url = format!("{}/mcp", server.base_url());
843        let result = McpProxyClient::connect(
844            "test-server",
845            &url,
846            None,
847            None,
848            "none",
849            ProxyTransport::StreamableHttp,
850        )
851        .await;
852
853        let err = result.err().expect("should be error");
854        assert!(err.to_string().contains("500"));
855    }
856
857    // =========================================================================
858    // McpProxyClient — fetch_tools
859    // =========================================================================
860
861    #[tokio::test]
862    async fn test_fetch_tools() {
863        let server = MockServer::start();
864        setup_mock_upstream(&server, sample_tools());
865
866        let url = format!("{}/mcp", server.base_url());
867        let mut client = McpProxyClient::connect(
868            "test-server",
869            &url,
870            None,
871            None,
872            "none",
873            ProxyTransport::StreamableHttp,
874        )
875        .await
876        .unwrap();
877
878        assert!(client.upstream_tools.is_empty());
879
880        client.fetch_tools().await.unwrap();
881
882        assert_eq!(client.upstream_tools.len(), 2);
883        assert_eq!(client.upstream_tools[0].name, "get_issues");
884        assert_eq!(client.upstream_tools[1].name, "get_merge_requests");
885    }
886
887    // =========================================================================
888    // McpProxyClient — prefixed_tools
889    // =========================================================================
890
891    #[tokio::test]
892    async fn test_prefixed_tools() {
893        let server = MockServer::start();
894        setup_mock_upstream(&server, sample_tools());
895
896        let url = format!("{}/mcp", server.base_url());
897        let mut client = McpProxyClient::connect(
898            "my-server",
899            &url,
900            Some("cloud"),
901            None,
902            "none",
903            ProxyTransport::StreamableHttp,
904        )
905        .await
906        .unwrap();
907
908        client.fetch_tools().await.unwrap();
909
910        let prefixed = client.prefixed_tools();
911        assert_eq!(prefixed.len(), 2);
912        assert_eq!(prefixed[0].name, "cloud__get_issues");
913        assert_eq!(prefixed[1].name, "cloud__get_merge_requests");
914        assert!(prefixed[0].description.starts_with("[my-server]"));
915    }
916
917    #[tokio::test]
918    async fn test_prefixed_tools_empty_when_not_fetched() {
919        let server = MockServer::start();
920        setup_mock_upstream(&server, sample_tools());
921
922        let url = format!("{}/mcp", server.base_url());
923        let client = McpProxyClient::connect(
924            "test-server",
925            &url,
926            None,
927            None,
928            "none",
929            ProxyTransport::StreamableHttp,
930        )
931        .await
932        .unwrap();
933
934        let prefixed = client.prefixed_tools();
935        assert!(prefixed.is_empty());
936    }
937
938    // =========================================================================
939    // McpProxyClient — call_tool
940    // =========================================================================
941
942    #[tokio::test]
943    async fn test_call_tool_success() {
944        let server = MockServer::start();
945        setup_mock_upstream(&server, sample_tools());
946
947        // tools/call endpoint
948        server.mock(|when, then| {
949            when.method(POST)
950                .path("/mcp")
951                .body_includes(r#""method":"tools/call""#);
952            then.status(200).json_body(serde_json::json!({
953                "jsonrpc": "2.0",
954                "id": 2,
955                "result": {
956                    "content": [{ "type": "text", "text": "issue data here" }]
957                }
958            }));
959        });
960
961        let url = format!("{}/mcp", server.base_url());
962        let client = McpProxyClient::connect(
963            "test-server",
964            &url,
965            None,
966            None,
967            "none",
968            ProxyTransport::StreamableHttp,
969        )
970        .await
971        .unwrap();
972
973        let result = client
974            .call_tool("get_issues", Some(serde_json::json!({"state": "open"})))
975            .await
976            .unwrap();
977
978        assert!(result.is_error.is_none());
979        assert_eq!(result.content.len(), 1);
980        match &result.content[0] {
981            ToolResultContent::Text { text } => assert_eq!(text, "issue data here"),
982        }
983    }
984
985    #[tokio::test]
986    async fn test_call_tool_with_upstream_error() {
987        let server = MockServer::start();
988        setup_mock_upstream(&server, sample_tools());
989
990        server.mock(|when, then| {
991            when.method(POST)
992                .path("/mcp")
993                .body_includes(r#""method":"tools/call""#);
994            then.status(200).json_body(serde_json::json!({
995                "jsonrpc": "2.0",
996                "id": 2,
997                "error": { "code": -32000, "message": "Tool execution failed" }
998            }));
999        });
1000
1001        let url = format!("{}/mcp", server.base_url());
1002        let client = McpProxyClient::connect(
1003            "test-server",
1004            &url,
1005            None,
1006            None,
1007            "none",
1008            ProxyTransport::StreamableHttp,
1009        )
1010        .await
1011        .unwrap();
1012
1013        let result = client.call_tool("get_issues", None).await.unwrap();
1014
1015        assert_eq!(result.is_error, Some(true));
1016        match &result.content[0] {
1017            ToolResultContent::Text { text } => assert!(text.contains("Tool execution failed")),
1018        }
1019    }
1020
1021    #[tokio::test]
1022    async fn test_call_tool_empty_response() {
1023        let server = MockServer::start();
1024        setup_mock_upstream(&server, sample_tools());
1025
1026        server.mock(|when, then| {
1027            when.method(POST)
1028                .path("/mcp")
1029                .body_includes(r#""method":"tools/call""#);
1030            then.status(200).json_body(serde_json::json!({
1031                "jsonrpc": "2.0",
1032                "id": 2
1033            }));
1034        });
1035
1036        let url = format!("{}/mcp", server.base_url());
1037        let client = McpProxyClient::connect(
1038            "test-server",
1039            &url,
1040            None,
1041            None,
1042            "none",
1043            ProxyTransport::StreamableHttp,
1044        )
1045        .await
1046        .unwrap();
1047
1048        let result = client.call_tool("get_issues", None).await.unwrap();
1049
1050        assert_eq!(result.is_error, Some(true));
1051        match &result.content[0] {
1052            ToolResultContent::Text { text } => assert!(text.contains("Empty response")),
1053        }
1054    }
1055
1056    // =========================================================================
1057    // McpProxyClient — session ID management
1058    // =========================================================================
1059
1060    #[tokio::test]
1061    async fn test_session_id_sent_on_subsequent_requests() {
1062        let server = MockServer::start();
1063
1064        // Initialize — returns session ID
1065        server.mock(|when, then| {
1066            when.method(POST)
1067                .path("/mcp")
1068                .body_includes(r#""method":"initialize""#);
1069            then.status(200)
1070                .header("mcp-session-id", "sess-abc")
1071                .json_body(serde_json::json!({
1072                    "jsonrpc": "2.0",
1073                    "id": 1,
1074                    "result": {
1075                        "protocolVersion": "2025-11-25",
1076                        "capabilities": {},
1077                        "serverInfo": { "name": "mock", "version": "1.0" }
1078                    }
1079                }));
1080        });
1081
1082        // tools/call — expect session ID header
1083        let tool_call_mock = server.mock(|when, then| {
1084            when.method(POST)
1085                .path("/mcp")
1086                .header("mcp-session-id", "sess-abc")
1087                .body_includes(r#""method":"tools/call""#);
1088            then.status(200).json_body(serde_json::json!({
1089                "jsonrpc": "2.0",
1090                "id": 2,
1091                "result": {
1092                    "content": [{ "type": "text", "text": "ok" }]
1093                }
1094            }));
1095        });
1096
1097        let url = format!("{}/mcp", server.base_url());
1098        let client = McpProxyClient::connect(
1099            "test-server",
1100            &url,
1101            None,
1102            None,
1103            "none",
1104            ProxyTransport::StreamableHttp,
1105        )
1106        .await
1107        .unwrap();
1108
1109        client.call_tool("test_tool", None).await.unwrap();
1110
1111        // Verify the session header was actually sent
1112        tool_call_mock.assert();
1113    }
1114
1115    // =========================================================================
1116    // McpProxyClient — auth types
1117    // =========================================================================
1118
1119    #[tokio::test]
1120    async fn test_bearer_auth_header() {
1121        let server = MockServer::start();
1122
1123        let init_mock = server.mock(|when, then| {
1124            when.method(POST)
1125                .path("/mcp")
1126                .header("Authorization", "Bearer secret-token")
1127                .body_includes(r#""method":"initialize""#);
1128            then.status(200).json_body(serde_json::json!({
1129                "jsonrpc": "2.0",
1130                "id": 1,
1131                "result": {
1132                    "protocolVersion": "2025-11-25",
1133                    "capabilities": {},
1134                    "serverInfo": { "name": "mock", "version": "1.0" }
1135                }
1136            }));
1137        });
1138
1139        let url = format!("{}/mcp", server.base_url());
1140        let token = token_secret("secret-token");
1141        McpProxyClient::connect(
1142            "test-server",
1143            &url,
1144            None,
1145            Some(&token),
1146            "bearer",
1147            ProxyTransport::StreamableHttp,
1148        )
1149        .await
1150        .unwrap();
1151
1152        init_mock.assert();
1153    }
1154
1155    #[tokio::test]
1156    async fn test_api_key_auth_header() {
1157        let server = MockServer::start();
1158
1159        let init_mock = server.mock(|when, then| {
1160            when.method(POST)
1161                .path("/mcp")
1162                .header("X-API-Key", "my-api-key")
1163                .body_includes(r#""method":"initialize""#);
1164            then.status(200).json_body(serde_json::json!({
1165                "jsonrpc": "2.0",
1166                "id": 1,
1167                "result": {
1168                    "protocolVersion": "2025-11-25",
1169                    "capabilities": {},
1170                    "serverInfo": { "name": "mock", "version": "1.0" }
1171                }
1172            }));
1173        });
1174
1175        let url = format!("{}/mcp", server.base_url());
1176        let token = token_secret("my-api-key");
1177        McpProxyClient::connect(
1178            "test-server",
1179            &url,
1180            None,
1181            Some(&token),
1182            "api_key",
1183            ProxyTransport::StreamableHttp,
1184        )
1185        .await
1186        .unwrap();
1187
1188        init_mock.assert();
1189    }
1190
1191    // =========================================================================
1192    // ProxyManager
1193    // =========================================================================
1194
1195    #[test]
1196    fn test_proxy_manager_new_is_empty() {
1197        let mgr = ProxyManager::new();
1198        assert!(mgr.is_empty());
1199        assert!(mgr.all_tools().is_empty());
1200    }
1201
1202    #[tokio::test]
1203    async fn test_proxy_manager_all_tools() {
1204        let server = MockServer::start();
1205        setup_mock_upstream(&server, sample_tools());
1206
1207        let url = format!("{}/mcp", server.base_url());
1208        let mut client = McpProxyClient::connect(
1209            "upstream",
1210            &url,
1211            Some("up"),
1212            None,
1213            "none",
1214            ProxyTransport::StreamableHttp,
1215        )
1216        .await
1217        .unwrap();
1218
1219        client.fetch_tools().await.unwrap();
1220
1221        let mut mgr = ProxyManager::new();
1222        mgr.add_client(client);
1223
1224        assert!(!mgr.is_empty());
1225
1226        let tools = mgr.all_tools();
1227        assert_eq!(tools.len(), 2);
1228        assert_eq!(tools[0].name, "up__get_issues");
1229        assert_eq!(tools[1].name, "up__get_merge_requests");
1230    }
1231
1232    #[tokio::test]
1233    async fn test_proxy_manager_try_call_routes_correctly() {
1234        let server = MockServer::start();
1235        setup_mock_upstream(&server, sample_tools());
1236
1237        server.mock(|when, then| {
1238            when.method(POST)
1239                .path("/mcp")
1240                .body_includes(r#""method":"tools/call""#);
1241            then.status(200).json_body(serde_json::json!({
1242                "jsonrpc": "2.0",
1243                "id": 2,
1244                "result": {
1245                    "content": [{ "type": "text", "text": "routed ok" }]
1246                }
1247            }));
1248        });
1249
1250        let url = format!("{}/mcp", server.base_url());
1251        let client = McpProxyClient::connect(
1252            "upstream",
1253            &url,
1254            Some("up"),
1255            None,
1256            "none",
1257            ProxyTransport::StreamableHttp,
1258        )
1259        .await
1260        .unwrap();
1261
1262        let mut mgr = ProxyManager::new();
1263        mgr.add_client(client);
1264
1265        let result = mgr
1266            .try_call("up__get_issues", Some(serde_json::json!({})))
1267            .await;
1268
1269        assert!(result.is_some());
1270        let result = result.unwrap();
1271        assert!(result.is_error.is_none());
1272        match &result.content[0] {
1273            ToolResultContent::Text { text } => assert_eq!(text, "routed ok"),
1274        }
1275    }
1276
1277    #[tokio::test]
1278    async fn test_proxy_manager_try_call_no_match() {
1279        let server = MockServer::start();
1280        setup_mock_upstream(&server, sample_tools());
1281
1282        let url = format!("{}/mcp", server.base_url());
1283        let client = McpProxyClient::connect(
1284            "upstream",
1285            &url,
1286            Some("up"),
1287            None,
1288            "none",
1289            ProxyTransport::StreamableHttp,
1290        )
1291        .await
1292        .unwrap();
1293
1294        let mut mgr = ProxyManager::new();
1295        mgr.add_client(client);
1296
1297        let result = mgr
1298            .try_call("unknown__get_issues", Some(serde_json::json!({})))
1299            .await;
1300
1301        assert!(result.is_none());
1302    }
1303
1304    #[tokio::test]
1305    async fn test_proxy_manager_try_call_without_prefix_no_match() {
1306        let mgr = ProxyManager::new();
1307        let result = mgr.try_call("get_issues", None).await;
1308        assert!(result.is_none());
1309    }
1310
1311    #[tokio::test]
1312    async fn test_proxy_manager_fetch_all_tools() {
1313        let server = MockServer::start();
1314        setup_mock_upstream(&server, sample_tools());
1315
1316        let url = format!("{}/mcp", server.base_url());
1317        let client = McpProxyClient::connect(
1318            "upstream",
1319            &url,
1320            Some("up"),
1321            None,
1322            "none",
1323            ProxyTransport::StreamableHttp,
1324        )
1325        .await
1326        .unwrap();
1327
1328        let mut mgr = ProxyManager::new();
1329        mgr.add_client(client);
1330
1331        assert!(mgr.all_tools().is_empty());
1332
1333        mgr.fetch_all_tools().await.unwrap();
1334
1335        assert_eq!(mgr.all_tools().len(), 2);
1336    }
1337
1338    // =========================================================================
1339    // McpProxyClient — invalid token (control characters)
1340    // =========================================================================
1341
1342    #[tokio::test]
1343    async fn test_connect_invalid_bearer_token() {
1344        let token = token_secret("token-with-\x01-control-chars");
1345        let result = McpProxyClient::connect(
1346            "test-server",
1347            "http://localhost:1/mcp",
1348            None,
1349            Some(&token),
1350            "bearer",
1351            ProxyTransport::StreamableHttp,
1352        )
1353        .await;
1354
1355        let err = result.err().expect("should be error");
1356        assert!(err.to_string().contains("Invalid token"));
1357    }
1358
1359    #[tokio::test]
1360    async fn test_connect_invalid_api_key_token() {
1361        let token = token_secret("key-with-\x01-control");
1362        let result = McpProxyClient::connect(
1363            "test-server",
1364            "http://localhost:1/mcp",
1365            None,
1366            Some(&token),
1367            "api_key",
1368            ProxyTransport::StreamableHttp,
1369        )
1370        .await;
1371
1372        let err = result.err().expect("should be error");
1373        assert!(err.to_string().contains("Invalid token"));
1374    }
1375
1376    // =========================================================================
1377    // McpProxyClient — SSE transport via mock
1378    // =========================================================================
1379
1380    /// Helper: set up SSE mock endpoint that returns endpoint event + initialize response.
1381    fn setup_sse_mock(server: &MockServer) {
1382        // SSE stream: returns endpoint event, then initialize response
1383        server.mock(|when, then| {
1384            when.method(GET).path("/sse");
1385            then.status(200)
1386                .header("content-type", "text/event-stream")
1387                .header("cache-control", "no-cache")
1388                .body(
1389                    "event: endpoint\ndata: /messages\n\n\
1390                     event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"mock-sse\",\"version\":\"1.0\"}}}\n\n"
1391                );
1392        });
1393
1394        // POST endpoint for requests (initialize is sent here)
1395        server.mock(|when, then| {
1396            when.method(POST).path("/messages");
1397            then.status(200);
1398        });
1399    }
1400
1401    #[tokio::test]
1402    async fn test_connect_sse_transport() {
1403        let server = MockServer::start();
1404        setup_sse_mock(&server);
1405
1406        let url = format!("{}/sse", server.base_url());
1407        let result = McpProxyClient::connect(
1408            "sse-server",
1409            &url,
1410            Some("sse"),
1411            None,
1412            "none",
1413            ProxyTransport::Sse,
1414        )
1415        .await;
1416
1417        assert!(result.is_ok(), "SSE connect failed: {:?}", result.err());
1418        let client = result.unwrap();
1419        assert_eq!(client.prefix(), "sse");
1420        assert_eq!(client.transport, ProxyTransport::Sse);
1421    }
1422
1423    #[tokio::test]
1424    async fn test_connect_sse_with_bearer_auth() {
1425        let server = MockServer::start();
1426
1427        // SSE stream with auth check
1428        server.mock(|when, then| {
1429            when.method(GET)
1430                .path("/sse")
1431                .header("Authorization", "Bearer sse-token");
1432            then.status(200)
1433                .header("content-type", "text/event-stream")
1434                .header("cache-control", "no-cache")
1435                .body(
1436                    "event: endpoint\ndata: /messages\n\n\
1437                     event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"mock\",\"version\":\"1.0\"}}}\n\n"
1438                );
1439        });
1440
1441        server.mock(|when, then| {
1442            when.method(POST).path("/messages");
1443            then.status(200);
1444        });
1445
1446        let url = format!("{}/sse", server.base_url());
1447        let token = token_secret("sse-token");
1448        let result = McpProxyClient::connect(
1449            "sse-server",
1450            &url,
1451            None,
1452            Some(&token),
1453            "bearer",
1454            ProxyTransport::Sse,
1455        )
1456        .await;
1457
1458        assert!(
1459            result.is_ok(),
1460            "SSE connect with auth failed: {:?}",
1461            result.err()
1462        );
1463    }
1464
1465    #[tokio::test]
1466    async fn test_sse_request_dispatch_path() {
1467        // Verify that an SSE-transport client dispatches via request_sse.
1468        // We test that the request method correctly routes to SSE path
1469        // by checking the client transport type after connect.
1470        let server = MockServer::start();
1471        setup_sse_mock(&server);
1472
1473        let url = format!("{}/sse", server.base_url());
1474        let client = McpProxyClient::connect(
1475            "sse-server",
1476            &url,
1477            Some("sse"),
1478            None,
1479            "none",
1480            ProxyTransport::Sse,
1481        )
1482        .await
1483        .unwrap();
1484
1485        // Verify the client is configured for SSE transport
1486        assert_eq!(client.transport, ProxyTransport::Sse);
1487        // The post_url should be resolved from the endpoint event
1488        assert!(client.post_url.contains("/messages"));
1489    }
1490
1491    // =========================================================================
1492    // McpProxyClient — fetch_tools with error response
1493    // =========================================================================
1494
1495    #[tokio::test]
1496    async fn test_fetch_tools_with_error_response() {
1497        let server = MockServer::start();
1498
1499        // Initialize endpoint
1500        server.mock(|when, then| {
1501            when.method(POST)
1502                .path("/mcp")
1503                .body_includes(r#""method":"initialize""#);
1504            then.status(200).json_body(serde_json::json!({
1505                "jsonrpc": "2.0",
1506                "id": 1,
1507                "result": {
1508                    "protocolVersion": "2025-11-25",
1509                    "capabilities": {},
1510                    "serverInfo": { "name": "mock", "version": "1.0" }
1511                }
1512            }));
1513        });
1514
1515        // tools/list returns error
1516        server.mock(|when, then| {
1517            when.method(POST)
1518                .path("/mcp")
1519                .body_includes(r#""method":"tools/list""#);
1520            then.status(200).json_body(serde_json::json!({
1521                "jsonrpc": "2.0",
1522                "id": 2,
1523                "error": { "code": -32601, "message": "Method not found" }
1524            }));
1525        });
1526
1527        let url = format!("{}/mcp", server.base_url());
1528        let mut client = McpProxyClient::connect(
1529            "test-server",
1530            &url,
1531            None,
1532            None,
1533            "none",
1534            ProxyTransport::StreamableHttp,
1535        )
1536        .await
1537        .unwrap();
1538
1539        // fetch_tools should succeed but not populate tools (error response has no result)
1540        client.fetch_tools().await.unwrap();
1541        assert!(client.upstream_tools.is_empty());
1542    }
1543
1544    #[tokio::test]
1545    async fn test_fetch_tools_with_empty_result() {
1546        let server = MockServer::start();
1547
1548        server.mock(|when, then| {
1549            when.method(POST)
1550                .path("/mcp")
1551                .body_includes(r#""method":"initialize""#);
1552            then.status(200).json_body(serde_json::json!({
1553                "jsonrpc": "2.0",
1554                "id": 1,
1555                "result": {
1556                    "protocolVersion": "2025-11-25",
1557                    "capabilities": {},
1558                    "serverInfo": { "name": "mock", "version": "1.0" }
1559                }
1560            }));
1561        });
1562
1563        // tools/list returns result without tools field
1564        server.mock(|when, then| {
1565            when.method(POST)
1566                .path("/mcp")
1567                .body_includes(r#""method":"tools/list""#);
1568            then.status(200).json_body(serde_json::json!({
1569                "jsonrpc": "2.0",
1570                "id": 2,
1571                "result": { "something_else": true }
1572            }));
1573        });
1574
1575        let url = format!("{}/mcp", server.base_url());
1576        let mut client = McpProxyClient::connect(
1577            "test-server",
1578            &url,
1579            None,
1580            None,
1581            "none",
1582            ProxyTransport::StreamableHttp,
1583        )
1584        .await
1585        .unwrap();
1586
1587        // fetch_tools should succeed, but tools remain empty (deserialization fails silently)
1588        client.fetch_tools().await.unwrap();
1589        assert!(client.upstream_tools.is_empty());
1590    }
1591
1592    // =========================================================================
1593    // McpProxyClient — call_tool with None arguments
1594    // =========================================================================
1595
1596    #[tokio::test]
1597    async fn test_call_tool_with_none_arguments_uses_empty_object() {
1598        let server = MockServer::start();
1599        setup_mock_upstream(&server, sample_tools());
1600
1601        // Verify that arguments defaults to {}
1602        let tool_mock = server.mock(|when, then| {
1603            when.method(POST)
1604                .path("/mcp")
1605                .body_includes(r#""arguments":{}"#)
1606                .body_includes(r#""method":"tools/call""#);
1607            then.status(200).json_body(serde_json::json!({
1608                "jsonrpc": "2.0",
1609                "id": 2,
1610                "result": {
1611                    "content": [{ "type": "text", "text": "no args ok" }]
1612                }
1613            }));
1614        });
1615
1616        let url = format!("{}/mcp", server.base_url());
1617        let client = McpProxyClient::connect(
1618            "test-server",
1619            &url,
1620            None,
1621            None,
1622            "none",
1623            ProxyTransport::StreamableHttp,
1624        )
1625        .await
1626        .unwrap();
1627
1628        let result = client.call_tool("get_issues", None).await.unwrap();
1629        assert!(result.is_error.is_none());
1630        tool_mock.assert();
1631    }
1632
1633    // =========================================================================
1634    // ProxyManager — try_call transport error
1635    // =========================================================================
1636
1637    #[tokio::test]
1638    async fn test_proxy_manager_try_call_transport_error() {
1639        let server = MockServer::start();
1640        setup_mock_upstream(&server, sample_tools());
1641
1642        let url = format!("{}/mcp", server.base_url());
1643        let client = McpProxyClient::connect(
1644            "upstream",
1645            &url,
1646            Some("up"),
1647            None,
1648            "none",
1649            ProxyTransport::StreamableHttp,
1650        )
1651        .await
1652        .unwrap();
1653
1654        let mut mgr = ProxyManager::new();
1655        mgr.add_client(client);
1656
1657        // Drop the mock server so the next call fails with a transport error
1658        drop(server);
1659
1660        let result = mgr
1661            .try_call("up__get_issues", Some(serde_json::json!({})))
1662            .await;
1663
1664        assert!(result.is_some());
1665        let result = result.unwrap();
1666        assert_eq!(result.is_error, Some(true));
1667        match &result.content[0] {
1668            ToolResultContent::Text { text } => assert!(text.contains("Proxy error")),
1669        }
1670    }
1671
1672    // =========================================================================
1673    // ProxyManager — default trait
1674    // =========================================================================
1675
1676    #[test]
1677    fn test_proxy_manager_default() {
1678        let mgr = ProxyManager::default();
1679        assert!(mgr.is_empty());
1680        assert!(mgr.all_tools().is_empty());
1681    }
1682
1683    // =========================================================================
1684    // ProxyManager — multiple clients routing
1685    // =========================================================================
1686
1687    #[tokio::test]
1688    async fn test_proxy_manager_multiple_clients() {
1689        let server1 = MockServer::start();
1690        let server2 = MockServer::start();
1691
1692        setup_mock_upstream(
1693            &server1,
1694            vec![serde_json::json!({
1695                "name": "tool_a",
1696                "description": "Tool A",
1697                "inputSchema": { "type": "object" }
1698            })],
1699        );
1700        setup_mock_upstream(
1701            &server2,
1702            vec![serde_json::json!({
1703                "name": "tool_b",
1704                "description": "Tool B",
1705                "inputSchema": { "type": "object" }
1706            })],
1707        );
1708
1709        let url1 = format!("{}/mcp", server1.base_url());
1710        let url2 = format!("{}/mcp", server2.base_url());
1711
1712        let client1 = McpProxyClient::connect(
1713            "server1",
1714            &url1,
1715            Some("s1"),
1716            None,
1717            "none",
1718            ProxyTransport::StreamableHttp,
1719        )
1720        .await
1721        .unwrap();
1722
1723        let client2 = McpProxyClient::connect(
1724            "server2",
1725            &url2,
1726            Some("s2"),
1727            None,
1728            "none",
1729            ProxyTransport::StreamableHttp,
1730        )
1731        .await
1732        .unwrap();
1733
1734        let mut mgr = ProxyManager::new();
1735        mgr.add_client(client1);
1736        mgr.add_client(client2);
1737
1738        mgr.fetch_all_tools().await.unwrap();
1739
1740        let tools = mgr.all_tools();
1741        assert_eq!(tools.len(), 2);
1742        assert!(tools.iter().any(|t| t.name == "s1__tool_a"));
1743        assert!(tools.iter().any(|t| t.name == "s2__tool_b"));
1744    }
1745
1746    // =========================================================================
1747    // Response ID validation
1748    // =========================================================================
1749
1750    #[tokio::test]
1751    async fn test_mismatched_response_id_returns_error() {
1752        let server = MockServer::start();
1753
1754        // Initialize returns correct id
1755        server.mock(|when, then| {
1756            when.method(POST)
1757                .path("/mcp")
1758                .body_includes(r#""method":"initialize""#);
1759            then.status(200)
1760                .header("mcp-session-id", "sess-1")
1761                .json_body(serde_json::json!({
1762                    "jsonrpc": "2.0",
1763                    "id": 1,
1764                    "result": {
1765                        "protocolVersion": "2025-11-25",
1766                        "capabilities": {},
1767                        "serverInfo": { "name": "mock", "version": "1.0" }
1768                    }
1769                }));
1770        });
1771
1772        // tools/call returns mismatched id
1773        server.mock(|when, then| {
1774            when.method(POST)
1775                .path("/mcp")
1776                .body_includes(r#""method":"tools/call""#);
1777            then.status(200).json_body(serde_json::json!({
1778                "jsonrpc": "2.0",
1779                "id": 999,
1780                "result": {
1781                    "content": [{ "type": "text", "text": "wrong id" }]
1782                }
1783            }));
1784        });
1785
1786        let url = format!("{}/mcp", server.base_url());
1787        let client = McpProxyClient::connect(
1788            "test-server",
1789            &url,
1790            None,
1791            None,
1792            "none",
1793            ProxyTransport::StreamableHttp,
1794        )
1795        .await
1796        .unwrap();
1797
1798        let result = client.call_tool("some_tool", None).await;
1799        let err = result.expect_err("should be error");
1800        assert!(err.to_string().contains("Mismatched JSON-RPC id"));
1801    }
1802}