Skip to main content

devboy_mcp/
proxy.rs

1//! MCP proxy — connects to upstream MCP servers and proxies tool calls.
2//!
3//! Supports two transports:
4//! - **SSE**: Legacy MCP transport with SSE stream for responses.
5//! - **Streamable HTTP**: POST-based transport with `mcp-session-id` header.
6
7use std::sync::Arc;
8use std::sync::atomic::{AtomicI64, Ordering};
9use std::time::Duration;
10
11use futures::StreamExt;
12use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
13use reqwest_eventsource::{Event, EventSource};
14use secrecy::{ExposeSecret, SecretString};
15use serde_json::Value;
16use tokio::sync::{Mutex, RwLock, oneshot};
17
18use crate::protocol::{
19    JSONRPC_VERSION, JsonRpcRequest, JsonRpcResponse, RequestId, ToolCallResult, ToolDefinition,
20};
21
/// Transport mode for an upstream MCP server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProxyTransport {
    /// SSE-based transport (legacy MCP): GET for SSE stream, POST for requests.
    Sse,
    /// Streamable HTTP transport: POST with JSON response and mcp-session-id header.
    StreamableHttp,
}

impl ProxyTransport {
    /// Parse a transport name taken from configuration.
    ///
    /// The Streamable HTTP spellings are `streamable-http`, `streamable_http`
    /// and `http`. Matching ignores ASCII case and surrounding whitespace, so
    /// hand-edited values such as `"HTTP"` or `" Streamable-HTTP "` are
    /// recognised. Any other value, including the empty string, falls back to
    /// [`ProxyTransport::Sse`].
    pub fn parse(s: &str) -> Self {
        const STREAMABLE_HTTP_NAMES: [&str; 3] = ["streamable-http", "streamable_http", "http"];

        let normalized = s.trim();
        if STREAMABLE_HTTP_NAMES
            .iter()
            .any(|candidate| normalized.eq_ignore_ascii_case(candidate))
        {
            Self::StreamableHttp
        } else {
            Self::Sse
        }
    }
}
39
40/// Pending SSE response receivers indexed by request ID.
41type PendingResponses = Arc<Mutex<Vec<(i64, oneshot::Sender<JsonRpcResponse>)>>>;
42
43/// Single upstream MCP server connection.
44pub struct McpProxyClient {
45    name: String,
46    tool_prefix: String,
47    post_url: String,
48    http_client: reqwest::Client,
49    upstream_tools: Vec<ToolDefinition>,
50    next_id: AtomicI64,
51    transport: ProxyTransport,
52    /// Session ID for Streamable HTTP transport.
53    session_id: RwLock<Option<String>>,
54    /// Channel to receive SSE responses routed by request id (SSE transport only).
55    pending: PendingResponses,
56}
57
58impl McpProxyClient {
59    /// Connect to an upstream MCP server, perform initialize handshake, fetch tools.
60    pub async fn connect(
61        name: &str,
62        url: &str,
63        tool_prefix: Option<&str>,
64        token: Option<&SecretString>,
65        auth_type: &str,
66        transport: ProxyTransport,
67    ) -> devboy_core::Result<Self> {
68        let mut headers = HeaderMap::new();
69        if let Some(token) = token {
70            match auth_type {
71                "bearer" => {
72                    let val = HeaderValue::from_str(&format!("Bearer {}", token.expose_secret()))
73                        .map_err(|e| {
74                        devboy_core::Error::Config(format!("Invalid token: {}", e))
75                    })?;
76                    headers.insert(AUTHORIZATION, val);
77                }
78                "api_key" => {
79                    let val = HeaderValue::from_str(token.expose_secret())
80                        .map_err(|e| devboy_core::Error::Config(format!("Invalid token: {}", e)))?;
81                    headers.insert("X-API-Key", val);
82                }
83                _ => {}
84            }
85        }
86
87        let http_client = reqwest::Client::builder()
88            .default_headers(headers.clone())
89            .timeout(Duration::from_secs(60))
90            .pool_max_idle_per_host(0)
91            .build()
92            .map_err(|e| devboy_core::Error::Http(format!("Failed to build HTTP client: {}", e)))?;
93
94        let prefix = tool_prefix.unwrap_or(name).to_string();
95
96        match transport {
97            ProxyTransport::Sse => {
98                Self::connect_sse(name, url, &prefix, headers, http_client).await
99            }
100            ProxyTransport::StreamableHttp => {
101                Self::connect_streamable_http(name, url, &prefix, http_client).await
102            }
103        }
104    }
105
106    /// Connect via SSE transport.
107    async fn connect_sse(
108        name: &str,
109        url: &str,
110        prefix: &str,
111        headers: HeaderMap,
112        http_client: reqwest::Client,
113    ) -> devboy_core::Result<Self> {
114        let sse_url = url.to_string();
115        let mut es = EventSource::new(
116            reqwest::Client::builder()
117                .default_headers(headers)
118                .build()
119                .unwrap()
120                .get(&sse_url),
121        )
122        .map_err(|e| {
123            devboy_core::Error::Http(format!("Failed to connect SSE to {}: {}", sse_url, e))
124        })?;
125
126        let post_url = Self::wait_for_endpoint(&mut es, url).await?;
127
128        let pending: PendingResponses = Arc::new(Mutex::new(Vec::new()));
129
130        // Spawn SSE listener
131        let pending_clone = pending.clone();
132        tokio::spawn(async move {
133            while let Some(event) = es.next().await {
134                match event {
135                    Ok(Event::Message(msg)) => {
136                        if msg.event == "message"
137                            && let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&msg.data)
138                        {
139                            let id_num = match &resp.id {
140                                RequestId::Number(n) => *n,
141                                _ => continue,
142                            };
143                            let mut pending = pending_clone.lock().await;
144                            if let Some(idx) = pending.iter().position(|(id, _)| *id == id_num) {
145                                let (_, sender) = pending.remove(idx);
146                                let _ = sender.send(resp);
147                            }
148                        }
149                    }
150                    Ok(Event::Open) => {
151                        tracing::debug!("SSE stream open");
152                    }
153                    Err(e) => {
154                        tracing::warn!("SSE error: {}", e);
155                        break;
156                    }
157                }
158            }
159        });
160
161        let client = Self {
162            name: name.to_string(),
163            tool_prefix: prefix.to_string(),
164            post_url,
165            http_client,
166            upstream_tools: Vec::new(),
167            next_id: AtomicI64::new(1),
168            transport: ProxyTransport::Sse,
169            session_id: RwLock::new(None),
170            pending,
171        };
172
173        client.initialize().await?;
174
175        Ok(client)
176    }
177
178    /// Connect via Streamable HTTP transport.
179    async fn connect_streamable_http(
180        name: &str,
181        url: &str,
182        prefix: &str,
183        http_client: reqwest::Client,
184    ) -> devboy_core::Result<Self> {
185        let client = Self {
186            name: name.to_string(),
187            tool_prefix: prefix.to_string(),
188            post_url: url.to_string(),
189            http_client,
190            upstream_tools: Vec::new(),
191            next_id: AtomicI64::new(1),
192            transport: ProxyTransport::StreamableHttp,
193            session_id: RwLock::new(None),
194            pending: Arc::new(Mutex::new(Vec::new())),
195        };
196
197        client.initialize().await?;
198
199        Ok(client)
200    }
201
202    /// Wait for the SSE `endpoint` event that tells us where to POST requests.
203    async fn wait_for_endpoint(
204        es: &mut EventSource,
205        base_url: &str,
206    ) -> devboy_core::Result<String> {
207        let timeout = tokio::time::timeout(Duration::from_secs(10), async {
208            while let Some(event) = es.next().await {
209                match event {
210                    Ok(Event::Message(msg)) if msg.event == "endpoint" => {
211                        let endpoint = msg.data.trim().to_string();
212                        // If relative URL, resolve against base
213                        if endpoint.starts_with('/')
214                            && let Ok(base) = reqwest::Url::parse(base_url)
215                            && let Ok(resolved) = base.join(&endpoint)
216                        {
217                            return Ok(resolved.to_string());
218                        }
219                        return Ok(endpoint);
220                    }
221                    Ok(Event::Open) => continue,
222                    Ok(_) => continue,
223                    Err(e) => {
224                        return Err(devboy_core::Error::Http(format!("SSE error: {}", e)));
225                    }
226                }
227            }
228            Err(devboy_core::Error::Http(
229                "SSE stream ended before endpoint event".to_string(),
230            ))
231        });
232
233        timeout.await.map_err(|_| {
234            devboy_core::Error::Http("Timeout waiting for SSE endpoint event".to_string())
235        })?
236    }
237
238    fn next_request_id(&self) -> i64 {
239        self.next_id.fetch_add(1, Ordering::SeqCst)
240    }
241
242    /// Send a JSON-RPC request and wait for response.
243    async fn request(
244        &self,
245        method: &str,
246        params: Option<Value>,
247    ) -> devboy_core::Result<JsonRpcResponse> {
248        match self.transport {
249            ProxyTransport::Sse => self.request_sse(method, params).await,
250            ProxyTransport::StreamableHttp => self.request_http(method, params).await,
251        }
252    }
253
254    /// Send request via SSE transport (POST request, response via SSE stream).
255    async fn request_sse(
256        &self,
257        method: &str,
258        params: Option<Value>,
259    ) -> devboy_core::Result<JsonRpcResponse> {
260        let id = self.next_request_id();
261        let req = JsonRpcRequest {
262            jsonrpc: JSONRPC_VERSION.to_string(),
263            id: RequestId::Number(id),
264            method: method.to_string(),
265            params,
266        };
267
268        let (tx, rx) = oneshot::channel();
269        {
270            let mut pending = self.pending.lock().await;
271            pending.push((id, tx));
272        }
273
274        self.http_client
275            .post(&self.post_url)
276            .json(&req)
277            .send()
278            .await
279            .map_err(|e| devboy_core::Error::Http(format!("POST failed: {}", e)))?;
280
281        let resp = tokio::time::timeout(Duration::from_secs(30), rx)
282            .await
283            .map_err(|_| devboy_core::Error::Http("Timeout waiting for response".to_string()))?
284            .map_err(|_| devboy_core::Error::Http("Response channel closed".to_string()))?;
285
286        Ok(resp)
287    }
288
    /// Send a request via the Streamable HTTP transport and return its response.
    ///
    /// Every request is a single POST to `post_url`:
    /// * `Accept` advertises both response forms handled below (JSON and SSE).
    /// * For every method except `initialize`, the stored `mcp-session-id` is
    ///   attached, if one is known.
    /// * For `initialize`, the server's `mcp-session-id` response header, if
    ///   present and valid ASCII, is stored for later requests. The header is
    ///   read before the status check, so it is captured even on a non-2xx reply.
    ///
    /// The body is read as an SSE stream when the `Content-Type` contains
    /// `text/event-stream`, and as incrementally parsed JSON otherwise. The
    /// response id must equal the request id.
    ///
    /// NOTE(review): the Streamable HTTP spec has servers answer 404 for an
    /// expired session and expects the client to re-initialize. Here that
    /// just surfaces as an `HTTP 404` error. Confirm that is acceptable.
    ///
    /// # Errors
    /// `Error::Http` if any of these occur:
    /// * a send failure (also logged at error level with timeout/connect flags);
    /// * a non-success status (the message includes the body text);
    /// * a body parse failure;
    /// * a mismatched JSON-RPC id.
    async fn request_http(
        &self,
        method: &str,
        params: Option<Value>,
    ) -> devboy_core::Result<JsonRpcResponse> {
        let id = self.next_request_id();
        let req = JsonRpcRequest {
            jsonrpc: JSONRPC_VERSION.to_string(),
            id: RequestId::Number(id),
            method: method.to_string(),
            params,
        };

        // Accept both a plain JSON reply and an SSE stream; both are handled below.
        let mut request = self
            .http_client
            .post(&self.post_url)
            .header(CONTENT_TYPE, "application/json")
            .header(ACCEPT, "application/json, text/event-stream");

        // Add session ID for all requests after initialize. The read guard is
        // released at the end of this block, before the request is sent.
        if method != "initialize" {
            let session = self.session_id.read().await;
            if let Some(sid) = session.as_ref() {
                request = request.header("mcp-session-id", sid);
            }
        }

        let response = request.json(&req).send().await.map_err(|e| {
            tracing::error!(
                "POST to {} failed: {} (is_timeout={}, is_connect={}, is_request={})",
                self.post_url,
                e,
                e.is_timeout(),
                e.is_connect(),
                e.is_request(),
            );
            devboy_core::Error::Http(format!("POST failed: {}", e))
        })?;

        // Extract session ID from response headers (set during initialize).
        // This must happen before the body is consumed below; it also runs
        // before the status check.
        if method == "initialize"
            && let Some(sid) = response.headers().get("mcp-session-id")
            && let Ok(sid_str) = sid.to_str()
        {
            let mut session = self.session_id.write().await;
            *session = Some(sid_str.to_string());
            tracing::debug!("Proxy '{}': got session ID", self.name);
        }

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable body becomes an empty string.
            let body = response.text().await.unwrap_or_default();
            return Err(devboy_core::Error::Http(format!(
                "HTTP {}: {}",
                status, body
            )));
        }

        // Check Content-Type: the server may respond with JSON or an SSE stream.
        // Copied into an owned String because `response` is moved into the
        // body readers below.
        let content_type = response
            .headers()
            .get(CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .to_string();

        let resp = if content_type.contains("text/event-stream") {
            // Parse the SSE stream to extract the JSON-RPC response for `id`.
            tracing::debug!("Response is SSE stream, parsing events...");
            self.parse_sse_response(response, id).await?
        } else {
            // Direct JSON response. It is read via bytes_stream to handle
            // servers that keep the connection open after sending the body
            // (e.g. streamable-http holding it for notifications).
            // On a read error (broken pipe, timeout), whatever was read is parsed.
            tracing::debug!("Response is JSON (content-type: {})", content_type);
            self.read_json_response(response).await?
        };

        // Verify that the response ID matches the request ID.
        let expected_id = RequestId::Number(id);
        if resp.id != expected_id {
            return Err(devboy_core::Error::Http(format!(
                "Mismatched JSON-RPC id: expected {:?}, got {:?}",
                expected_id, resp.id
            )));
        }

        Ok(resp)
    }
380
381    /// Read a JSON response body using streaming, gracefully handling
382    /// connection drops (broken pipe, timeout) mid-transfer.
383    ///
384    /// Streamable-HTTP servers may keep the connection open after sending the
385    /// JSON body (e.g. to push notifications). `response.json()` waits for the
386    /// connection to close, which causes "error decoding response body" when the
387    /// server or an intermediate proxy eventually drops it.
388    async fn read_json_response(
389        &self,
390        response: reqwest::Response,
391    ) -> devboy_core::Result<JsonRpcResponse> {
392        Self::parse_json_stream(response.bytes_stream()).await
393    }
394
    /// Drain a chunked body stream and parse it as a `JsonRpcResponse`.
    ///
    /// After each chunk, the loop tries to parse the accumulated bytes. It
    /// **returns as soon as one complete JSON-RPC response has been
    /// deserialized**, whether or not the upstream has signalled EOF. This is
    /// what fixes the streamable-HTTP hang from #244: servers that keep the
    /// connection open after the response body (for notifications) no longer
    /// block the caller until the proxy/CDN drops the idle connection.
    /// Trailing bytes after the response object are ignored on this path.
    ///
    /// If the stream ends (clean EOF or error) before a complete response is
    /// parsed, one final parse runs over the accumulated bytes. That final
    /// parse uses `serde_json::from_slice`, which, unlike the incremental
    /// path, rejects trailing non-whitespace. Stream-error context is kept in
    /// the resulting message so callers can tell a truncated body from a
    /// malformed one.
    ///
    /// Cost: each chunk re-parses the whole buffer from the start, so work
    /// grows quadratically with the number of chunks received before a
    /// complete response appears.
    ///
    /// The function is generic over the chunk type and the stream-error type
    /// so unit tests can drive it with `futures::stream::iter` without building
    /// a real `reqwest::Response`. The error is stringified late, only when
    /// the final failure message is composed.
    ///
    /// # Errors
    /// `Error::Http` if no bytes were read, or if the final parse fails. The
    /// message includes a 200-byte preview of the body.
    async fn parse_json_stream<S, B, E>(mut stream: S) -> devboy_core::Result<JsonRpcResponse>
    where
        S: futures::Stream<Item = std::result::Result<B, E>> + Unpin,
        B: AsRef<[u8]>,
        E: std::fmt::Display,
    {
        let mut body = Vec::new();
        let mut stream_error: Option<String> = None;

        while let Some(chunk_result) = stream.next().await {
            match chunk_result {
                Ok(chunk) => {
                    body.extend_from_slice(chunk.as_ref());
                    // Try to parse the first JSON value out of the accumulated
                    // bytes. Driving a `Deserializer` directly (without calling
                    // `end()`) deserializes a single value and ignores trailing
                    // bytes. That matters when the upstream then sends
                    // notifications on the same stream.
                    let mut de = serde_json::Deserializer::from_slice(&body);
                    match <JsonRpcResponse as serde::Deserialize>::deserialize(&mut de) {
                        Ok(resp) => {
                            tracing::debug!(
                                "Parsed JSON-RPC response after {} bytes (stream still open)",
                                body.len()
                            );
                            return Ok(resp);
                        }
                        Err(e) if e.is_eof() => {
                            // The body so far is a valid JSON prefix but
                            // incomplete; keep reading.
                        }
                        Err(_) => {
                            // Not a premature-EOF error. Keep reading and let
                            // the post-loop final parse surface a clean error
                            // if the body is genuinely malformed.
                            // NOTE(review): serde_json reports a value cut off
                            // mid-string as EOF (handled above). This branch is
                            // therefore more likely a real syntax error, or a
                            // complete value that does not match
                            // JsonRpcResponse. Continuing only defers the error.
                        }
                    }
                }
                Err(e) => {
                    let msg = e.to_string();
                    tracing::debug!(
                        "Stream ended with error ({} bytes read): {}",
                        body.len(),
                        msg
                    );
                    stream_error = Some(msg);
                    break;
                }
            }
        }

        if body.is_empty() {
            return Err(devboy_core::Error::Http(match stream_error {
                Some(e) => format!("Empty response body from upstream (stream error: {e})"),
                None => "Empty response body from upstream".to_string(),
            }));
        }

        tracing::debug!("Final parse over {} accumulated bytes", body.len());

        serde_json::from_slice::<JsonRpcResponse>(&body).map_err(|json_err| {
            // Byte-based preview. A multi-byte character split at byte 200
            // becomes U+FFFD via from_utf8_lossy rather than panicking.
            let preview = String::from_utf8_lossy(&body[..body.len().min(200)]);
            let base = format!(
                "Failed to parse JSON ({} bytes, starts with: {}): {}",
                body.len(),
                preview,
                json_err
            );
            devboy_core::Error::Http(match stream_error {
                Some(stream_err) => {
                    format!("{base} (stream ended with error: {stream_err})")
                }
                None => base,
            })
        })
    }
493
494    /// Parse an SSE event stream response to extract the JSON-RPC response.
495    ///
496    /// Streamable HTTP spec allows servers to respond with SSE for long-running
497    /// operations like tool calls. Reads the stream line-by-line instead of
498    /// buffering the entire body (which would hang on open SSE connections).
499    async fn parse_sse_response(
500        &self,
501        response: reqwest::Response,
502        expected_id: i64,
503    ) -> devboy_core::Result<JsonRpcResponse> {
504        use futures::TryStreamExt;
505        use tokio::io::AsyncBufReadExt;
506
507        let stream = response.bytes_stream().map_err(std::io::Error::other);
508        let reader = tokio_util::io::StreamReader::new(stream);
509        let mut lines = tokio::io::BufReader::new(reader).lines();
510
511        let mut current_data = String::new();
512
513        tracing::debug!("Starting SSE line reader...");
514
515        tokio::time::timeout(Duration::from_secs(60), async {
516            while let Ok(Some(line)) = lines.next_line().await {
517                let line = line.trim().to_string();
518                let debug_len = line
519                    .char_indices()
520                    .nth(100)
521                    .map(|(i, _)| i)
522                    .unwrap_or(line.len());
523                tracing::debug!("SSE line: {}", &line[..debug_len]);
524
525                if line.is_empty() {
526                    // End of SSE event — try to parse collected data
527                    if !current_data.is_empty()
528                        && let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&current_data)
529                    {
530                        let id_matches = match &resp.id {
531                            RequestId::Number(n) => *n == expected_id,
532                            _ => false,
533                        };
534                        if id_matches {
535                            return Ok(resp);
536                        }
537                        current_data.clear();
538                    } else if !current_data.is_empty() {
539                        current_data.clear();
540                    }
541                    continue;
542                }
543
544                if let Some(data) = line.strip_prefix("data:") {
545                    let data = data.trim();
546                    if !data.is_empty() {
547                        current_data.push_str(data);
548                    }
549                }
550                // Skip event:, id:, retry: lines
551            }
552
553            // Try last accumulated data
554            if !current_data.is_empty()
555                && let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&current_data)
556            {
557                return Ok(resp);
558            }
559
560            Err(devboy_core::Error::Http(
561                "No matching JSON-RPC response found in SSE stream".to_string(),
562            ))
563        })
564        .await
565        .map_err(|_| devboy_core::Error::Http("Timeout reading SSE response".to_string()))?
566    }
567
568    /// Send initialize handshake.
569    async fn initialize(&self) -> devboy_core::Result<()> {
570        let params = serde_json::json!({
571            "protocolVersion": "2025-11-25",
572            "capabilities": {},
573            "clientInfo": {
574                "name": "devboy-mcp-proxy",
575                "version": env!("CARGO_PKG_VERSION")
576            }
577        });
578
579        let resp = self.request("initialize", Some(params)).await?;
580        if let Some(err) = resp.error {
581            return Err(devboy_core::Error::Http(format!(
582                "Initialize failed: {}",
583                err.message
584            )));
585        }
586
587        tracing::info!("Proxy '{}' initialized", self.name);
588        Ok(())
589    }
590
591    /// Fetch tools/list from upstream. Call this before using `prefixed_tools()`.
592    pub async fn fetch_tools(&mut self) -> devboy_core::Result<()> {
593        let resp = self.request("tools/list", None).await?;
594
595        if let Some(result) = resp.result {
596            #[derive(serde::Deserialize)]
597            struct ToolsList {
598                tools: Vec<ToolDefinition>,
599            }
600
601            if let Ok(list) = serde_json::from_value::<ToolsList>(result) {
602                self.upstream_tools = list.tools;
603                tracing::info!(
604                    "Proxy '{}': fetched {} tools",
605                    self.name,
606                    self.upstream_tools.len()
607                );
608            }
609        }
610
611        Ok(())
612    }
613
614    /// Return upstream tools with prefixed names.
615    pub fn prefixed_tools(&self) -> Vec<ToolDefinition> {
616        self.upstream_tools
617            .iter()
618            .map(|t| ToolDefinition {
619                name: format!("{}__{}", self.tool_prefix, t.name),
620                description: format!("[{}] {}", self.name, t.description),
621                input_schema: t.input_schema.clone(),
622                category: None, // Proxy tools don't have a category (always available)
623            })
624            .collect()
625    }
626
627    /// Return the raw upstream tool catalogue (without prefixing).
628    /// Used by the signature matcher to compare upstream definitions against the local
629    /// tool registry. Empty until `fetch_tools()` has been called.
630    pub fn raw_upstream_tools(&self) -> &[ToolDefinition] {
631        &self.upstream_tools
632    }
633
634    /// Execute a tool call, stripping the prefix before forwarding.
635    pub async fn call_tool(
636        &self,
637        original_name: &str,
638        arguments: Option<Value>,
639    ) -> devboy_core::Result<ToolCallResult> {
640        let params = serde_json::json!({
641            "name": original_name,
642            "arguments": arguments.unwrap_or(Value::Object(Default::default()))
643        });
644
645        let resp = self.request("tools/call", Some(params)).await?;
646
647        if let Some(err) = resp.error {
648            return Ok(ToolCallResult::error(err.message));
649        }
650
651        match resp.result {
652            Some(result) => serde_json::from_value(result).map_err(|e| {
653                devboy_core::Error::InvalidData(format!("Invalid tool result: {}", e))
654            }),
655            None => Ok(ToolCallResult::error(
656                "Empty response from upstream".to_string(),
657            )),
658        }
659    }
660
661    /// Get the tool prefix for this client.
662    pub fn prefix(&self) -> &str {
663        &self.tool_prefix
664    }
665}
666
667/// Manages multiple upstream MCP proxy connections.
668pub struct ProxyManager {
669    clients: Vec<McpProxyClient>,
670}
671
672impl Default for ProxyManager {
673    fn default() -> Self {
674        Self::new()
675    }
676}
677
678impl ProxyManager {
679    pub fn new() -> Self {
680        Self {
681            clients: Vec::new(),
682        }
683    }
684
685    pub fn add_client(&mut self, client: McpProxyClient) {
686        self.clients.push(client);
687    }
688
689    pub fn is_empty(&self) -> bool {
690        self.clients.is_empty()
691    }
692
693    /// Fetch tool lists from all upstream servers.
694    pub async fn fetch_all_tools(&mut self) -> devboy_core::Result<()> {
695        for client in &mut self.clients {
696            client.fetch_tools().await?;
697        }
698        Ok(())
699    }
700
701    /// Get all proxied tools (with prefixes) from all upstreams.
702    pub fn all_tools(&self) -> Vec<ToolDefinition> {
703        self.clients
704            .iter()
705            .flat_map(|c| c.prefixed_tools())
706            .collect()
707    }
708
709    /// Check whether a tool name belongs to a proxied upstream.
710    pub fn has_tool(&self, tool_name: &str) -> bool {
711        self.clients
712            .iter()
713            .any(|c| tool_name.starts_with(&format!("{}__", c.prefix())))
714    }
715
716    /// Try to route a tool call to the matching upstream.
717    /// Returns None if no upstream matches the tool name prefix.
718    pub async fn try_call(
719        &self,
720        tool_name: &str,
721        arguments: Option<Value>,
722    ) -> Option<ToolCallResult> {
723        for client in &self.clients {
724            let prefix = format!("{}__", client.prefix());
725            if let Some(original_name) = tool_name.strip_prefix(&prefix) {
726                let result = client.call_tool(original_name, arguments).await;
727                return Some(match result {
728                    Ok(r) => r,
729                    Err(e) => ToolCallResult::error(format!("Proxy error: {}", e)),
730                });
731            }
732        }
733        None
734    }
735
736    /// Call a specific upstream by prefix using the unprefixed tool name.
737    /// Used by the routing engine when it has already decided the remote executor is the
738    /// right target for a matched tool (and therefore doesn't need to rely on the
739    /// prefixed alias).
740    pub async fn call_by_prefix(
741        &self,
742        prefix: &str,
743        unprefixed_tool_name: &str,
744        arguments: Option<Value>,
745    ) -> Option<ToolCallResult> {
746        for client in &self.clients {
747            if client.prefix() == prefix {
748                let result = client.call_tool(unprefixed_tool_name, arguments).await;
749                return Some(match result {
750                    Ok(r) => r,
751                    Err(e) => ToolCallResult::error(format!("Proxy error: {}", e)),
752                });
753            }
754        }
755        None
756    }
757
758    /// Return every upstream's raw (unprefixed) tool catalogue tagged by prefix.
759    /// Consumers use this to feed the signature matcher.
760    pub fn raw_upstream_catalogue(&self) -> Vec<(String, &[ToolDefinition])> {
761        self.clients
762            .iter()
763            .map(|c| (c.prefix().to_string(), c.raw_upstream_tools()))
764            .collect()
765    }
766}
767
768// =============================================================================
769// Tests
770// =============================================================================
771
772#[cfg(test)]
773#[allow(clippy::err_expect)]
774mod tests {
775    use super::*;
776    use crate::protocol::ToolResultContent;
777    use httpmock::prelude::*;
778
779    fn token_secret(s: &str) -> SecretString {
780        SecretString::from(s.to_string())
781    }
782
783    // =========================================================================
784    // ProxyTransport
785    // =========================================================================
786
787    #[test]
788    fn test_proxy_transport_parse() {
789        assert_eq!(
790            ProxyTransport::parse("streamable-http"),
791            ProxyTransport::StreamableHttp
792        );
793        assert_eq!(
794            ProxyTransport::parse("streamable_http"),
795            ProxyTransport::StreamableHttp
796        );
797        assert_eq!(
798            ProxyTransport::parse("http"),
799            ProxyTransport::StreamableHttp
800        );
801        assert_eq!(ProxyTransport::parse("sse"), ProxyTransport::Sse);
802        assert_eq!(ProxyTransport::parse(""), ProxyTransport::Sse);
803        assert_eq!(ProxyTransport::parse("unknown"), ProxyTransport::Sse);
804    }
805
806    #[test]
807    fn test_proxy_transport_debug_clone_eq() {
808        let t = ProxyTransport::Sse;
809        let t2 = t;
810        assert_eq!(t, t2);
811        assert_eq!(format!("{:?}", t), "Sse");
812        assert_eq!(
813            format!("{:?}", ProxyTransport::StreamableHttp),
814            "StreamableHttp"
815        );
816    }
817
818    // =========================================================================
819    // Helper: mock upstream that implements Streamable HTTP MCP protocol
820    // =========================================================================
821
822    /// Create a MockServer that responds to initialize and tools/list.
823    fn setup_mock_upstream(server: &MockServer, tools: Vec<serde_json::Value>) {
824        // Initialize endpoint — returns session ID in header
825        server.mock(|when, then| {
826            when.method(POST)
827                .path("/mcp")
828                .body_includes(r#""method":"initialize""#);
829            then.status(200)
830                .header("mcp-session-id", "test-session-123")
831                .json_body(serde_json::json!({
832                    "jsonrpc": "2.0",
833                    "id": 1,
834                    "result": {
835                        "protocolVersion": "2025-11-25",
836                        "capabilities": { "tools": {} },
837                        "serverInfo": { "name": "mock-server", "version": "1.0.0" }
838                    }
839                }));
840        });
841
842        // tools/list endpoint
843        server.mock(|when, then| {
844            when.method(POST)
845                .path("/mcp")
846                .body_includes(r#""method":"tools/list""#);
847            then.status(200).json_body(serde_json::json!({
848                "jsonrpc": "2.0",
849                "id": 2,
850                "result": { "tools": tools }
851            }));
852        });
853    }
854
855    fn sample_tools() -> Vec<serde_json::Value> {
856        vec![
857            serde_json::json!({
858                "name": "get_issues",
859                "description": "Get issues from tracker",
860                "inputSchema": { "type": "object", "properties": {} }
861            }),
862            serde_json::json!({
863                "name": "get_merge_requests",
864                "description": "Get merge requests",
865                "inputSchema": { "type": "object", "properties": {} }
866            }),
867        ]
868    }
869
870    // =========================================================================
871    // McpProxyClient — Streamable HTTP connect
872    // =========================================================================
873
874    #[tokio::test]
875    async fn test_connect_streamable_http() {
876        let server = MockServer::start();
877        setup_mock_upstream(&server, sample_tools());
878
879        let url = format!("{}/mcp", server.base_url());
880        let token = token_secret("my-token");
881        let client = McpProxyClient::connect(
882            "test-server",
883            &url,
884            None,
885            Some(&token),
886            "bearer",
887            ProxyTransport::StreamableHttp,
888        )
889        .await
890        .unwrap();
891
892        assert_eq!(client.prefix(), "test-server");
893        assert!(client.upstream_tools.is_empty()); // Not fetched yet
894    }
895
896    #[tokio::test]
897    async fn test_connect_with_custom_prefix() {
898        let server = MockServer::start();
899        setup_mock_upstream(&server, sample_tools());
900
901        let url = format!("{}/mcp", server.base_url());
902        let client = McpProxyClient::connect(
903            "test-server",
904            &url,
905            Some("custom"),
906            None,
907            "none",
908            ProxyTransport::StreamableHttp,
909        )
910        .await
911        .unwrap();
912
913        assert_eq!(client.prefix(), "custom");
914    }
915
916    #[tokio::test]
917    async fn test_connect_initialize_failure() {
918        let server = MockServer::start();
919
920        server.mock(|when, then| {
921            when.method(POST)
922                .path("/mcp")
923                .body_includes(r#""method":"initialize""#);
924            then.status(200).json_body(serde_json::json!({
925                "jsonrpc": "2.0",
926                "id": 1,
927                "error": { "code": -32600, "message": "Bad request" }
928            }));
929        });
930
931        let url = format!("{}/mcp", server.base_url());
932        let result = McpProxyClient::connect(
933            "test-server",
934            &url,
935            None,
936            None,
937            "none",
938            ProxyTransport::StreamableHttp,
939        )
940        .await;
941
942        let err = result.err().expect("should be error");
943        assert!(err.to_string().contains("Initialize failed"));
944    }
945
946    #[tokio::test]
947    async fn test_connect_http_error() {
948        let server = MockServer::start();
949
950        server.mock(|when, then| {
951            when.method(POST).path("/mcp");
952            then.status(500).body("Internal Server Error");
953        });
954
955        let url = format!("{}/mcp", server.base_url());
956        let result = McpProxyClient::connect(
957            "test-server",
958            &url,
959            None,
960            None,
961            "none",
962            ProxyTransport::StreamableHttp,
963        )
964        .await;
965
966        let err = result.err().expect("should be error");
967        assert!(err.to_string().contains("500"));
968    }
969
970    // =========================================================================
971    // McpProxyClient — fetch_tools
972    // =========================================================================
973
974    #[tokio::test]
975    async fn test_fetch_tools() {
976        let server = MockServer::start();
977        setup_mock_upstream(&server, sample_tools());
978
979        let url = format!("{}/mcp", server.base_url());
980        let mut client = McpProxyClient::connect(
981            "test-server",
982            &url,
983            None,
984            None,
985            "none",
986            ProxyTransport::StreamableHttp,
987        )
988        .await
989        .unwrap();
990
991        assert!(client.upstream_tools.is_empty());
992
993        client.fetch_tools().await.unwrap();
994
995        assert_eq!(client.upstream_tools.len(), 2);
996        assert_eq!(client.upstream_tools[0].name, "get_issues");
997        assert_eq!(client.upstream_tools[1].name, "get_merge_requests");
998    }
999
1000    // =========================================================================
1001    // McpProxyClient — prefixed_tools
1002    // =========================================================================
1003
1004    #[tokio::test]
1005    async fn test_prefixed_tools() {
1006        let server = MockServer::start();
1007        setup_mock_upstream(&server, sample_tools());
1008
1009        let url = format!("{}/mcp", server.base_url());
1010        let mut client = McpProxyClient::connect(
1011            "my-server",
1012            &url,
1013            Some("cloud"),
1014            None,
1015            "none",
1016            ProxyTransport::StreamableHttp,
1017        )
1018        .await
1019        .unwrap();
1020
1021        client.fetch_tools().await.unwrap();
1022
1023        let prefixed = client.prefixed_tools();
1024        assert_eq!(prefixed.len(), 2);
1025        assert_eq!(prefixed[0].name, "cloud__get_issues");
1026        assert_eq!(prefixed[1].name, "cloud__get_merge_requests");
1027        assert!(prefixed[0].description.starts_with("[my-server]"));
1028    }
1029
1030    #[tokio::test]
1031    async fn test_prefixed_tools_empty_when_not_fetched() {
1032        let server = MockServer::start();
1033        setup_mock_upstream(&server, sample_tools());
1034
1035        let url = format!("{}/mcp", server.base_url());
1036        let client = McpProxyClient::connect(
1037            "test-server",
1038            &url,
1039            None,
1040            None,
1041            "none",
1042            ProxyTransport::StreamableHttp,
1043        )
1044        .await
1045        .unwrap();
1046
1047        let prefixed = client.prefixed_tools();
1048        assert!(prefixed.is_empty());
1049    }
1050
1051    // =========================================================================
1052    // McpProxyClient — call_tool
1053    // =========================================================================
1054
1055    #[tokio::test]
1056    async fn test_call_tool_success() {
1057        let server = MockServer::start();
1058        setup_mock_upstream(&server, sample_tools());
1059
1060        // tools/call endpoint
1061        server.mock(|when, then| {
1062            when.method(POST)
1063                .path("/mcp")
1064                .body_includes(r#""method":"tools/call""#);
1065            then.status(200).json_body(serde_json::json!({
1066                "jsonrpc": "2.0",
1067                "id": 2,
1068                "result": {
1069                    "content": [{ "type": "text", "text": "issue data here" }]
1070                }
1071            }));
1072        });
1073
1074        let url = format!("{}/mcp", server.base_url());
1075        let client = McpProxyClient::connect(
1076            "test-server",
1077            &url,
1078            None,
1079            None,
1080            "none",
1081            ProxyTransport::StreamableHttp,
1082        )
1083        .await
1084        .unwrap();
1085
1086        let result = client
1087            .call_tool("get_issues", Some(serde_json::json!({"state": "open"})))
1088            .await
1089            .unwrap();
1090
1091        assert!(result.is_error.is_none());
1092        assert_eq!(result.content.len(), 1);
1093        match &result.content[0] {
1094            ToolResultContent::Text { text } => assert_eq!(text, "issue data here"),
1095        }
1096    }
1097
1098    #[tokio::test]
1099    async fn test_call_tool_with_upstream_error() {
1100        let server = MockServer::start();
1101        setup_mock_upstream(&server, sample_tools());
1102
1103        server.mock(|when, then| {
1104            when.method(POST)
1105                .path("/mcp")
1106                .body_includes(r#""method":"tools/call""#);
1107            then.status(200).json_body(serde_json::json!({
1108                "jsonrpc": "2.0",
1109                "id": 2,
1110                "error": { "code": -32000, "message": "Tool execution failed" }
1111            }));
1112        });
1113
1114        let url = format!("{}/mcp", server.base_url());
1115        let client = McpProxyClient::connect(
1116            "test-server",
1117            &url,
1118            None,
1119            None,
1120            "none",
1121            ProxyTransport::StreamableHttp,
1122        )
1123        .await
1124        .unwrap();
1125
1126        let result = client.call_tool("get_issues", None).await.unwrap();
1127
1128        assert_eq!(result.is_error, Some(true));
1129        match &result.content[0] {
1130            ToolResultContent::Text { text } => assert!(text.contains("Tool execution failed")),
1131        }
1132    }
1133
1134    #[tokio::test]
1135    async fn test_call_tool_empty_response() {
1136        let server = MockServer::start();
1137        setup_mock_upstream(&server, sample_tools());
1138
1139        server.mock(|when, then| {
1140            when.method(POST)
1141                .path("/mcp")
1142                .body_includes(r#""method":"tools/call""#);
1143            then.status(200).json_body(serde_json::json!({
1144                "jsonrpc": "2.0",
1145                "id": 2
1146            }));
1147        });
1148
1149        let url = format!("{}/mcp", server.base_url());
1150        let client = McpProxyClient::connect(
1151            "test-server",
1152            &url,
1153            None,
1154            None,
1155            "none",
1156            ProxyTransport::StreamableHttp,
1157        )
1158        .await
1159        .unwrap();
1160
1161        let result = client.call_tool("get_issues", None).await.unwrap();
1162
1163        assert_eq!(result.is_error, Some(true));
1164        match &result.content[0] {
1165            ToolResultContent::Text { text } => assert!(text.contains("Empty response")),
1166        }
1167    }
1168
1169    // =========================================================================
1170    // McpProxyClient — session ID management
1171    // =========================================================================
1172
1173    #[tokio::test]
1174    async fn test_session_id_sent_on_subsequent_requests() {
1175        let server = MockServer::start();
1176
1177        // Initialize — returns session ID
1178        server.mock(|when, then| {
1179            when.method(POST)
1180                .path("/mcp")
1181                .body_includes(r#""method":"initialize""#);
1182            then.status(200)
1183                .header("mcp-session-id", "sess-abc")
1184                .json_body(serde_json::json!({
1185                    "jsonrpc": "2.0",
1186                    "id": 1,
1187                    "result": {
1188                        "protocolVersion": "2025-11-25",
1189                        "capabilities": {},
1190                        "serverInfo": { "name": "mock", "version": "1.0" }
1191                    }
1192                }));
1193        });
1194
1195        // tools/call — expect session ID header
1196        let tool_call_mock = server.mock(|when, then| {
1197            when.method(POST)
1198                .path("/mcp")
1199                .header("mcp-session-id", "sess-abc")
1200                .body_includes(r#""method":"tools/call""#);
1201            then.status(200).json_body(serde_json::json!({
1202                "jsonrpc": "2.0",
1203                "id": 2,
1204                "result": {
1205                    "content": [{ "type": "text", "text": "ok" }]
1206                }
1207            }));
1208        });
1209
1210        let url = format!("{}/mcp", server.base_url());
1211        let client = McpProxyClient::connect(
1212            "test-server",
1213            &url,
1214            None,
1215            None,
1216            "none",
1217            ProxyTransport::StreamableHttp,
1218        )
1219        .await
1220        .unwrap();
1221
1222        client.call_tool("test_tool", None).await.unwrap();
1223
1224        // Verify the session header was actually sent
1225        tool_call_mock.assert();
1226    }
1227
1228    // =========================================================================
1229    // McpProxyClient — auth types
1230    // =========================================================================
1231
1232    #[tokio::test]
1233    async fn test_bearer_auth_header() {
1234        let server = MockServer::start();
1235
1236        let init_mock = server.mock(|when, then| {
1237            when.method(POST)
1238                .path("/mcp")
1239                .header("Authorization", "Bearer secret-token")
1240                .body_includes(r#""method":"initialize""#);
1241            then.status(200).json_body(serde_json::json!({
1242                "jsonrpc": "2.0",
1243                "id": 1,
1244                "result": {
1245                    "protocolVersion": "2025-11-25",
1246                    "capabilities": {},
1247                    "serverInfo": { "name": "mock", "version": "1.0" }
1248                }
1249            }));
1250        });
1251
1252        let url = format!("{}/mcp", server.base_url());
1253        let token = token_secret("secret-token");
1254        McpProxyClient::connect(
1255            "test-server",
1256            &url,
1257            None,
1258            Some(&token),
1259            "bearer",
1260            ProxyTransport::StreamableHttp,
1261        )
1262        .await
1263        .unwrap();
1264
1265        init_mock.assert();
1266    }
1267
1268    #[tokio::test]
1269    async fn test_api_key_auth_header() {
1270        let server = MockServer::start();
1271
1272        let init_mock = server.mock(|when, then| {
1273            when.method(POST)
1274                .path("/mcp")
1275                .header("X-API-Key", "my-api-key")
1276                .body_includes(r#""method":"initialize""#);
1277            then.status(200).json_body(serde_json::json!({
1278                "jsonrpc": "2.0",
1279                "id": 1,
1280                "result": {
1281                    "protocolVersion": "2025-11-25",
1282                    "capabilities": {},
1283                    "serverInfo": { "name": "mock", "version": "1.0" }
1284                }
1285            }));
1286        });
1287
1288        let url = format!("{}/mcp", server.base_url());
1289        let token = token_secret("my-api-key");
1290        McpProxyClient::connect(
1291            "test-server",
1292            &url,
1293            None,
1294            Some(&token),
1295            "api_key",
1296            ProxyTransport::StreamableHttp,
1297        )
1298        .await
1299        .unwrap();
1300
1301        init_mock.assert();
1302    }
1303
1304    // =========================================================================
1305    // ProxyManager
1306    // =========================================================================
1307
1308    #[test]
1309    fn test_proxy_manager_new_is_empty() {
1310        let mgr = ProxyManager::new();
1311        assert!(mgr.is_empty());
1312        assert!(mgr.all_tools().is_empty());
1313    }
1314
1315    #[tokio::test]
1316    async fn test_proxy_manager_all_tools() {
1317        let server = MockServer::start();
1318        setup_mock_upstream(&server, sample_tools());
1319
1320        let url = format!("{}/mcp", server.base_url());
1321        let mut client = McpProxyClient::connect(
1322            "upstream",
1323            &url,
1324            Some("up"),
1325            None,
1326            "none",
1327            ProxyTransport::StreamableHttp,
1328        )
1329        .await
1330        .unwrap();
1331
1332        client.fetch_tools().await.unwrap();
1333
1334        let mut mgr = ProxyManager::new();
1335        mgr.add_client(client);
1336
1337        assert!(!mgr.is_empty());
1338
1339        let tools = mgr.all_tools();
1340        assert_eq!(tools.len(), 2);
1341        assert_eq!(tools[0].name, "up__get_issues");
1342        assert_eq!(tools[1].name, "up__get_merge_requests");
1343    }
1344
1345    #[tokio::test]
1346    async fn test_proxy_manager_try_call_routes_correctly() {
1347        let server = MockServer::start();
1348        setup_mock_upstream(&server, sample_tools());
1349
1350        server.mock(|when, then| {
1351            when.method(POST)
1352                .path("/mcp")
1353                .body_includes(r#""method":"tools/call""#);
1354            then.status(200).json_body(serde_json::json!({
1355                "jsonrpc": "2.0",
1356                "id": 2,
1357                "result": {
1358                    "content": [{ "type": "text", "text": "routed ok" }]
1359                }
1360            }));
1361        });
1362
1363        let url = format!("{}/mcp", server.base_url());
1364        let client = McpProxyClient::connect(
1365            "upstream",
1366            &url,
1367            Some("up"),
1368            None,
1369            "none",
1370            ProxyTransport::StreamableHttp,
1371        )
1372        .await
1373        .unwrap();
1374
1375        let mut mgr = ProxyManager::new();
1376        mgr.add_client(client);
1377
1378        let result = mgr
1379            .try_call("up__get_issues", Some(serde_json::json!({})))
1380            .await;
1381
1382        assert!(result.is_some());
1383        let result = result.unwrap();
1384        assert!(result.is_error.is_none());
1385        match &result.content[0] {
1386            ToolResultContent::Text { text } => assert_eq!(text, "routed ok"),
1387        }
1388    }
1389
1390    #[tokio::test]
1391    async fn test_proxy_manager_try_call_no_match() {
1392        let server = MockServer::start();
1393        setup_mock_upstream(&server, sample_tools());
1394
1395        let url = format!("{}/mcp", server.base_url());
1396        let client = McpProxyClient::connect(
1397            "upstream",
1398            &url,
1399            Some("up"),
1400            None,
1401            "none",
1402            ProxyTransport::StreamableHttp,
1403        )
1404        .await
1405        .unwrap();
1406
1407        let mut mgr = ProxyManager::new();
1408        mgr.add_client(client);
1409
1410        let result = mgr
1411            .try_call("unknown__get_issues", Some(serde_json::json!({})))
1412            .await;
1413
1414        assert!(result.is_none());
1415    }
1416
1417    #[tokio::test]
1418    async fn test_proxy_manager_try_call_without_prefix_no_match() {
1419        let mgr = ProxyManager::new();
1420        let result = mgr.try_call("get_issues", None).await;
1421        assert!(result.is_none());
1422    }
1423
1424    #[tokio::test]
1425    async fn test_proxy_manager_fetch_all_tools() {
1426        let server = MockServer::start();
1427        setup_mock_upstream(&server, sample_tools());
1428
1429        let url = format!("{}/mcp", server.base_url());
1430        let client = McpProxyClient::connect(
1431            "upstream",
1432            &url,
1433            Some("up"),
1434            None,
1435            "none",
1436            ProxyTransport::StreamableHttp,
1437        )
1438        .await
1439        .unwrap();
1440
1441        let mut mgr = ProxyManager::new();
1442        mgr.add_client(client);
1443
1444        assert!(mgr.all_tools().is_empty());
1445
1446        mgr.fetch_all_tools().await.unwrap();
1447
1448        assert_eq!(mgr.all_tools().len(), 2);
1449    }
1450
1451    // =========================================================================
1452    // McpProxyClient — invalid token (non-ASCII)
1453    // =========================================================================
1454
1455    #[tokio::test]
1456    async fn test_connect_invalid_bearer_token() {
1457        let token = token_secret("token-with-\x01-control-chars");
1458        let result = McpProxyClient::connect(
1459            "test-server",
1460            "http://localhost:1/mcp",
1461            None,
1462            Some(&token),
1463            "bearer",
1464            ProxyTransport::StreamableHttp,
1465        )
1466        .await;
1467
1468        let err = result.err().expect("should be error");
1469        assert!(err.to_string().contains("Invalid token"));
1470    }
1471
1472    #[tokio::test]
1473    async fn test_connect_invalid_api_key_token() {
1474        let token = token_secret("key-with-\x01-control");
1475        let result = McpProxyClient::connect(
1476            "test-server",
1477            "http://localhost:1/mcp",
1478            None,
1479            Some(&token),
1480            "api_key",
1481            ProxyTransport::StreamableHttp,
1482        )
1483        .await;
1484
1485        let err = result.err().expect("should be error");
1486        assert!(err.to_string().contains("Invalid token"));
1487    }
1488
1489    // =========================================================================
1490    // McpProxyClient — SSE transport via mock
1491    // =========================================================================
1492
1493    /// Helper: set up SSE mock endpoint that returns endpoint event + initialize response.
1494    fn setup_sse_mock(server: &MockServer) {
1495        // SSE stream: returns endpoint event, then initialize response
1496        server.mock(|when, then| {
1497            when.method(GET).path("/sse");
1498            then.status(200)
1499                .header("content-type", "text/event-stream")
1500                .header("cache-control", "no-cache")
1501                .body(
1502                    "event: endpoint\ndata: /messages\n\n\
1503                     event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"mock-sse\",\"version\":\"1.0\"}}}\n\n"
1504                );
1505        });
1506
1507        // POST endpoint for requests (initialize is sent here)
1508        server.mock(|when, then| {
1509            when.method(POST).path("/messages");
1510            then.status(200);
1511        });
1512    }
1513
1514    #[tokio::test]
1515    async fn test_connect_sse_transport() {
1516        let server = MockServer::start();
1517        setup_sse_mock(&server);
1518
1519        let url = format!("{}/sse", server.base_url());
1520        let result = McpProxyClient::connect(
1521            "sse-server",
1522            &url,
1523            Some("sse"),
1524            None,
1525            "none",
1526            ProxyTransport::Sse,
1527        )
1528        .await;
1529
1530        assert!(result.is_ok(), "SSE connect failed: {:?}", result.err());
1531        let client = result.unwrap();
1532        assert_eq!(client.prefix(), "sse");
1533        assert_eq!(client.transport, ProxyTransport::Sse);
1534    }
1535
1536    #[tokio::test]
1537    async fn test_connect_sse_with_bearer_auth() {
1538        let server = MockServer::start();
1539
1540        // SSE stream with auth check
1541        server.mock(|when, then| {
1542            when.method(GET)
1543                .path("/sse")
1544                .header("Authorization", "Bearer sse-token");
1545            then.status(200)
1546                .header("content-type", "text/event-stream")
1547                .header("cache-control", "no-cache")
1548                .body(
1549                    "event: endpoint\ndata: /messages\n\n\
1550                     event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"protocolVersion\":\"2025-11-25\",\"capabilities\":{},\"serverInfo\":{\"name\":\"mock\",\"version\":\"1.0\"}}}\n\n"
1551                );
1552        });
1553
1554        server.mock(|when, then| {
1555            when.method(POST).path("/messages");
1556            then.status(200);
1557        });
1558
1559        let url = format!("{}/sse", server.base_url());
1560        let token = token_secret("sse-token");
1561        let result = McpProxyClient::connect(
1562            "sse-server",
1563            &url,
1564            None,
1565            Some(&token),
1566            "bearer",
1567            ProxyTransport::Sse,
1568        )
1569        .await;
1570
1571        assert!(
1572            result.is_ok(),
1573            "SSE connect with auth failed: {:?}",
1574            result.err()
1575        );
1576    }
1577
1578    #[tokio::test]
1579    async fn test_sse_request_dispatch_path() {
1580        // Verify that an SSE-transport client dispatches via request_sse.
1581        // We test that the request method correctly routes to SSE path
1582        // by checking the client transport type after connect.
1583        let server = MockServer::start();
1584        setup_sse_mock(&server);
1585
1586        let url = format!("{}/sse", server.base_url());
1587        let client = McpProxyClient::connect(
1588            "sse-server",
1589            &url,
1590            Some("sse"),
1591            None,
1592            "none",
1593            ProxyTransport::Sse,
1594        )
1595        .await
1596        .unwrap();
1597
1598        // Verify the client is configured for SSE transport
1599        assert_eq!(client.transport, ProxyTransport::Sse);
1600        // The post_url should be resolved from the endpoint event
1601        assert!(client.post_url.contains("/messages"));
1602    }
1603
1604    // =========================================================================
1605    // McpProxyClient — fetch_tools with error response
1606    // =========================================================================
1607
1608    #[tokio::test]
1609    async fn test_fetch_tools_with_error_response() {
1610        let server = MockServer::start();
1611
1612        // Initialize endpoint
1613        server.mock(|when, then| {
1614            when.method(POST)
1615                .path("/mcp")
1616                .body_includes(r#""method":"initialize""#);
1617            then.status(200).json_body(serde_json::json!({
1618                "jsonrpc": "2.0",
1619                "id": 1,
1620                "result": {
1621                    "protocolVersion": "2025-11-25",
1622                    "capabilities": {},
1623                    "serverInfo": { "name": "mock", "version": "1.0" }
1624                }
1625            }));
1626        });
1627
1628        // tools/list returns error
1629        server.mock(|when, then| {
1630            when.method(POST)
1631                .path("/mcp")
1632                .body_includes(r#""method":"tools/list""#);
1633            then.status(200).json_body(serde_json::json!({
1634                "jsonrpc": "2.0",
1635                "id": 2,
1636                "error": { "code": -32601, "message": "Method not found" }
1637            }));
1638        });
1639
1640        let url = format!("{}/mcp", server.base_url());
1641        let mut client = McpProxyClient::connect(
1642            "test-server",
1643            &url,
1644            None,
1645            None,
1646            "none",
1647            ProxyTransport::StreamableHttp,
1648        )
1649        .await
1650        .unwrap();
1651
1652        // fetch_tools should succeed but not populate tools (error response has no result)
1653        client.fetch_tools().await.unwrap();
1654        assert!(client.upstream_tools.is_empty());
1655    }
1656
1657    #[tokio::test]
1658    async fn test_fetch_tools_with_empty_result() {
1659        let server = MockServer::start();
1660
1661        server.mock(|when, then| {
1662            when.method(POST)
1663                .path("/mcp")
1664                .body_includes(r#""method":"initialize""#);
1665            then.status(200).json_body(serde_json::json!({
1666                "jsonrpc": "2.0",
1667                "id": 1,
1668                "result": {
1669                    "protocolVersion": "2025-11-25",
1670                    "capabilities": {},
1671                    "serverInfo": { "name": "mock", "version": "1.0" }
1672                }
1673            }));
1674        });
1675
1676        // tools/list returns result without tools field
1677        server.mock(|when, then| {
1678            when.method(POST)
1679                .path("/mcp")
1680                .body_includes(r#""method":"tools/list""#);
1681            then.status(200).json_body(serde_json::json!({
1682                "jsonrpc": "2.0",
1683                "id": 2,
1684                "result": { "something_else": true }
1685            }));
1686        });
1687
1688        let url = format!("{}/mcp", server.base_url());
1689        let mut client = McpProxyClient::connect(
1690            "test-server",
1691            &url,
1692            None,
1693            None,
1694            "none",
1695            ProxyTransport::StreamableHttp,
1696        )
1697        .await
1698        .unwrap();
1699
1700        // fetch_tools should succeed, but tools remain empty (deserialization fails silently)
1701        client.fetch_tools().await.unwrap();
1702        assert!(client.upstream_tools.is_empty());
1703    }
1704
1705    // =========================================================================
1706    // McpProxyClient — call_tool with None arguments
1707    // =========================================================================
1708
1709    #[tokio::test]
1710    async fn test_call_tool_with_none_arguments_uses_empty_object() {
1711        let server = MockServer::start();
1712        setup_mock_upstream(&server, sample_tools());
1713
1714        // Verify that arguments defaults to {}
1715        let tool_mock = server.mock(|when, then| {
1716            when.method(POST)
1717                .path("/mcp")
1718                .body_includes(r#""arguments":{}"#)
1719                .body_includes(r#""method":"tools/call""#);
1720            then.status(200).json_body(serde_json::json!({
1721                "jsonrpc": "2.0",
1722                "id": 2,
1723                "result": {
1724                    "content": [{ "type": "text", "text": "no args ok" }]
1725                }
1726            }));
1727        });
1728
1729        let url = format!("{}/mcp", server.base_url());
1730        let client = McpProxyClient::connect(
1731            "test-server",
1732            &url,
1733            None,
1734            None,
1735            "none",
1736            ProxyTransport::StreamableHttp,
1737        )
1738        .await
1739        .unwrap();
1740
1741        let result = client.call_tool("get_issues", None).await.unwrap();
1742        assert!(result.is_error.is_none());
1743        tool_mock.assert();
1744    }
1745
1746    // =========================================================================
1747    // ProxyManager — try_call transport error
1748    // =========================================================================
1749
1750    #[tokio::test]
1751    async fn test_proxy_manager_try_call_transport_error() {
1752        let server = MockServer::start();
1753        setup_mock_upstream(&server, sample_tools());
1754
1755        let url = format!("{}/mcp", server.base_url());
1756        let client = McpProxyClient::connect(
1757            "upstream",
1758            &url,
1759            Some("up"),
1760            None,
1761            "none",
1762            ProxyTransport::StreamableHttp,
1763        )
1764        .await
1765        .unwrap();
1766
1767        let mut mgr = ProxyManager::new();
1768        mgr.add_client(client);
1769
1770        // Drop the mock server so the next call fails with a transport error
1771        drop(server);
1772
1773        let result = mgr
1774            .try_call("up__get_issues", Some(serde_json::json!({})))
1775            .await;
1776
1777        assert!(result.is_some());
1778        let result = result.unwrap();
1779        assert_eq!(result.is_error, Some(true));
1780        match &result.content[0] {
1781            ToolResultContent::Text { text } => assert!(text.contains("Proxy error")),
1782        }
1783    }
1784
1785    // =========================================================================
1786    // ProxyManager — default trait
1787    // =========================================================================
1788
1789    #[test]
1790    fn test_proxy_manager_default() {
1791        let mgr = ProxyManager::default();
1792        assert!(mgr.is_empty());
1793        assert!(mgr.all_tools().is_empty());
1794    }
1795
1796    // =========================================================================
1797    // ProxyManager — multiple clients routing
1798    // =========================================================================
1799
1800    #[tokio::test]
1801    async fn test_proxy_manager_multiple_clients() {
1802        let server1 = MockServer::start();
1803        let server2 = MockServer::start();
1804
1805        setup_mock_upstream(
1806            &server1,
1807            vec![serde_json::json!({
1808                "name": "tool_a",
1809                "description": "Tool A",
1810                "inputSchema": { "type": "object" }
1811            })],
1812        );
1813        setup_mock_upstream(
1814            &server2,
1815            vec![serde_json::json!({
1816                "name": "tool_b",
1817                "description": "Tool B",
1818                "inputSchema": { "type": "object" }
1819            })],
1820        );
1821
1822        let url1 = format!("{}/mcp", server1.base_url());
1823        let url2 = format!("{}/mcp", server2.base_url());
1824
1825        let client1 = McpProxyClient::connect(
1826            "server1",
1827            &url1,
1828            Some("s1"),
1829            None,
1830            "none",
1831            ProxyTransport::StreamableHttp,
1832        )
1833        .await
1834        .unwrap();
1835
1836        let client2 = McpProxyClient::connect(
1837            "server2",
1838            &url2,
1839            Some("s2"),
1840            None,
1841            "none",
1842            ProxyTransport::StreamableHttp,
1843        )
1844        .await
1845        .unwrap();
1846
1847        let mut mgr = ProxyManager::new();
1848        mgr.add_client(client1);
1849        mgr.add_client(client2);
1850
1851        mgr.fetch_all_tools().await.unwrap();
1852
1853        let tools = mgr.all_tools();
1854        assert_eq!(tools.len(), 2);
1855        assert!(tools.iter().any(|t| t.name == "s1__tool_a"));
1856        assert!(tools.iter().any(|t| t.name == "s2__tool_b"));
1857    }
1858
1859    // =========================================================================
1860    // Response ID validation
1861    // =========================================================================
1862
1863    #[tokio::test]
1864    async fn test_mismatched_response_id_returns_error() {
1865        let server = MockServer::start();
1866
1867        // Initialize returns correct id
1868        server.mock(|when, then| {
1869            when.method(POST)
1870                .path("/mcp")
1871                .body_includes(r#""method":"initialize""#);
1872            then.status(200)
1873                .header("mcp-session-id", "sess-1")
1874                .json_body(serde_json::json!({
1875                    "jsonrpc": "2.0",
1876                    "id": 1,
1877                    "result": {
1878                        "protocolVersion": "2025-11-25",
1879                        "capabilities": {},
1880                        "serverInfo": { "name": "mock", "version": "1.0" }
1881                    }
1882                }));
1883        });
1884
1885        // tools/call returns mismatched id
1886        server.mock(|when, then| {
1887            when.method(POST)
1888                .path("/mcp")
1889                .body_includes(r#""method":"tools/call""#);
1890            then.status(200).json_body(serde_json::json!({
1891                "jsonrpc": "2.0",
1892                "id": 999,
1893                "result": {
1894                    "content": [{ "type": "text", "text": "wrong id" }]
1895                }
1896            }));
1897        });
1898
1899        let url = format!("{}/mcp", server.base_url());
1900        let client = McpProxyClient::connect(
1901            "test-server",
1902            &url,
1903            None,
1904            None,
1905            "none",
1906            ProxyTransport::StreamableHttp,
1907        )
1908        .await
1909        .unwrap();
1910
1911        let result = client.call_tool("some_tool", None).await;
1912        let err = result.expect_err("should be error");
1913        assert!(err.to_string().contains("Mismatched JSON-RPC id"));
1914    }
1915
1916    // =========================================================================
1917    // McpProxyClient — read_json_response edge cases
1918    // =========================================================================
1919
1920    #[tokio::test]
1921    async fn test_tools_list_with_empty_body_returns_error() {
1922        let server = MockServer::start();
1923
1924        server.mock(|when, then| {
1925            when.method(POST)
1926                .path("/mcp")
1927                .body_includes(r#""method":"initialize""#);
1928            then.status(200)
1929                .header("mcp-session-id", "sess-empty")
1930                .json_body(serde_json::json!({
1931                    "jsonrpc": "2.0",
1932                    "id": 1,
1933                    "result": {
1934                        "protocolVersion": "2025-11-25",
1935                        "capabilities": {},
1936                        "serverInfo": { "name": "mock", "version": "1.0" }
1937                    }
1938                }));
1939        });
1940
1941        // tools/list returns 200 with empty body
1942        server.mock(|when, then| {
1943            when.method(POST)
1944                .path("/mcp")
1945                .body_includes(r#""method":"tools/list""#);
1946            then.status(200).body("");
1947        });
1948
1949        let url = format!("{}/mcp", server.base_url());
1950        let mut client = McpProxyClient::connect(
1951            "test-server",
1952            &url,
1953            None,
1954            None,
1955            "none",
1956            ProxyTransport::StreamableHttp,
1957        )
1958        .await
1959        .unwrap();
1960
1961        let result = client.fetch_tools().await;
1962        let err = result.expect_err("empty body should fail");
1963        assert!(
1964            err.to_string().contains("Empty response body"),
1965            "expected empty body error, got: {err}"
1966        );
1967    }
1968
1969    #[tokio::test]
1970    async fn test_tools_list_with_invalid_json_returns_parse_error() {
1971        let server = MockServer::start();
1972
1973        server.mock(|when, then| {
1974            when.method(POST)
1975                .path("/mcp")
1976                .body_includes(r#""method":"initialize""#);
1977            then.status(200)
1978                .header("mcp-session-id", "sess-badjson")
1979                .json_body(serde_json::json!({
1980                    "jsonrpc": "2.0",
1981                    "id": 1,
1982                    "result": {
1983                        "protocolVersion": "2025-11-25",
1984                        "capabilities": {},
1985                        "serverInfo": { "name": "mock", "version": "1.0" }
1986                    }
1987                }));
1988        });
1989
1990        // tools/list returns 200 with invalid JSON
1991        server.mock(|when, then| {
1992            when.method(POST)
1993                .path("/mcp")
1994                .body_includes(r#""method":"tools/list""#);
1995            then.status(200)
1996                .header("content-type", "application/json")
1997                .body("this is not json");
1998        });
1999
2000        let url = format!("{}/mcp", server.base_url());
2001        let mut client = McpProxyClient::connect(
2002            "test-server",
2003            &url,
2004            None,
2005            None,
2006            "none",
2007            ProxyTransport::StreamableHttp,
2008        )
2009        .await
2010        .unwrap();
2011
2012        let result = client.fetch_tools().await;
2013        let err = result.expect_err("invalid JSON should fail");
2014        assert!(
2015            err.to_string().contains("Failed to parse JSON"),
2016            "expected parse error, got: {err}"
2017        );
2018        assert!(
2019            err.to_string().contains("this is not json"),
2020            "error should include body preview"
2021        );
2022    }
2023
2024    #[tokio::test]
2025    async fn test_tools_list_with_large_valid_response() {
2026        let server = MockServer::start();
2027
2028        server.mock(|when, then| {
2029            when.method(POST)
2030                .path("/mcp")
2031                .body_includes(r#""method":"initialize""#);
2032            then.status(200)
2033                .header("mcp-session-id", "sess-large")
2034                .json_body(serde_json::json!({
2035                    "jsonrpc": "2.0",
2036                    "id": 1,
2037                    "result": {
2038                        "protocolVersion": "2025-11-25",
2039                        "capabilities": {},
2040                        "serverInfo": { "name": "mock", "version": "1.0" }
2041                    }
2042                }));
2043        });
2044
2045        // Build a tools/list response with 50 tools to exercise streaming
2046        let tools: Vec<serde_json::Value> = (0..50)
2047            .map(|i| {
2048                serde_json::json!({
2049                    "name": format!("tool_{i}"),
2050                    "description": format!("Tool number {i} with a longer description to make the response body larger"),
2051                    "inputSchema": { "type": "object", "properties": {} }
2052                })
2053            })
2054            .collect();
2055
2056        server.mock(|when, then| {
2057            when.method(POST)
2058                .path("/mcp")
2059                .body_includes(r#""method":"tools/list""#);
2060            then.status(200).json_body(serde_json::json!({
2061                "jsonrpc": "2.0",
2062                "id": 2,
2063                "result": { "tools": tools }
2064            }));
2065        });
2066
2067        let url = format!("{}/mcp", server.base_url());
2068        let mut client = McpProxyClient::connect(
2069            "test-server",
2070            &url,
2071            None,
2072            None,
2073            "none",
2074            ProxyTransport::StreamableHttp,
2075        )
2076        .await
2077        .unwrap();
2078
2079        client.fetch_tools().await.unwrap();
2080        assert_eq!(client.upstream_tools.len(), 50);
2081    }
2082
2083    // =========================================================================
2084    // McpProxyClient::parse_json_stream — direct unit coverage
2085    //
2086    // The httpmock-based tests above can't drive the parse-on-stream-error
2087    // branch: httpmock always closes connections cleanly so `bytes_stream()`
2088    // never yields `Err(_)`. These tests feed a synthetic stream straight into
2089    // `parse_json_stream` to cover the production scenario (body delivered,
2090    // then stream errors before clean EOF) and the truncated-body case where
2091    // the original stream error must be preserved in the parse failure
2092    // message.
2093    // =========================================================================
2094
2095    #[tokio::test]
2096    async fn parse_json_stream_succeeds_when_stream_errors_after_complete_body() {
2097        use futures::stream;
2098
2099        let body: Vec<u8> = serde_json::to_vec(&serde_json::json!({
2100            "jsonrpc": "2.0",
2101            "id": 7,
2102            "result": { "tools": [] }
2103        }))
2104        .unwrap();
2105
2106        let chunks: Vec<std::result::Result<Vec<u8>, String>> = vec![
2107            Ok(body),
2108            Err("simulated broken pipe after body".to_string()),
2109        ];
2110        let s = stream::iter(chunks);
2111
2112        let resp = McpProxyClient::parse_json_stream(s)
2113            .await
2114            .expect("complete body before stream error must still parse");
2115        assert!(matches!(resp.id, RequestId::Number(7)));
2116    }
2117
2118    #[tokio::test]
2119    async fn parse_json_stream_partial_body_preserves_stream_error_in_message() {
2120        use futures::stream;
2121
2122        // First chunk is a syntactically truncated JSON body — `from_slice`
2123        // will fail.
2124        let truncated = b"{\"jsonrpc\":\"2.0\",\"id\":1,\"resu".to_vec();
2125        let chunks: Vec<std::result::Result<Vec<u8>, String>> =
2126            vec![Ok(truncated), Err("connection reset by peer".to_string())];
2127        let s = stream::iter(chunks);
2128
2129        let err = McpProxyClient::parse_json_stream(s)
2130            .await
2131            .expect_err("truncated body must fail to parse");
2132        let msg = err.to_string();
2133        assert!(
2134            msg.contains("Failed to parse JSON"),
2135            "expected parse error preface, got: {msg}"
2136        );
2137        assert!(
2138            msg.contains("connection reset by peer"),
2139            "stream error must be preserved in message, got: {msg}"
2140        );
2141    }
2142
2143    #[tokio::test]
2144    async fn parse_json_stream_empty_body_with_stream_error_reports_both() {
2145        use futures::stream;
2146
2147        let chunks: Vec<std::result::Result<Vec<u8>, String>> =
2148            vec![Err("immediate disconnect".to_string())];
2149        let s = stream::iter(chunks);
2150
2151        let err = McpProxyClient::parse_json_stream(s)
2152            .await
2153            .expect_err("empty body must error");
2154        let msg = err.to_string();
2155        assert!(
2156            msg.contains("Empty response body"),
2157            "expected empty-body marker, got: {msg}"
2158        );
2159        assert!(
2160            msg.contains("immediate disconnect"),
2161            "stream error must be preserved, got: {msg}"
2162        );
2163    }
2164
2165    #[tokio::test]
2166    async fn parse_json_stream_returns_early_when_stream_stays_open() {
2167        use futures::stream;
2168
2169        // Production scenario: server sends a complete JSON-RPC response
2170        // then keeps the connection open for notifications. The parser
2171        // must return as soon as a complete response is decoded — without
2172        // waiting for the upstream (or a downstream proxy) to ever close
2173        // the connection. We simulate "stream stays open with extra
2174        // notification chunks after the response" by appending bytes that
2175        // would NOT compose a valid `JsonRpcResponse`. Without an
2176        // incremental early return, the post-loop final parse would fail
2177        // because of the trailing data.
2178        let body: Vec<u8> = serde_json::to_vec(&serde_json::json!({
2179            "jsonrpc": "2.0",
2180            "id": 99,
2181            "result": { "tools": [] }
2182        }))
2183        .unwrap();
2184        let trailing: Vec<u8> =
2185            b"\n{\"jsonrpc\":\"2.0\",\"method\":\"notifications/progress\"}\n".to_vec();
2186
2187        let chunks: Vec<std::result::Result<Vec<u8>, String>> = vec![Ok(body), Ok(trailing)];
2188        let s = stream::iter(chunks);
2189
2190        let resp = McpProxyClient::parse_json_stream(s)
2191            .await
2192            .expect("complete response should parse before EOF, ignoring trailing notifications");
2193        assert!(matches!(resp.id, RequestId::Number(99)));
2194    }
2195}