Skip to main content

forge_client/
lib.rs

1#![warn(missing_docs)]
2
3//! # forge-client
4//!
5//! MCP client connections to downstream servers for the Forgemax Code Mode Gateway.
6//!
7//! Provides [`McpClient`] for connecting to individual MCP servers over stdio
8//! or HTTP transports, and [`RouterDispatcher`] for routing tool calls to the
9//! correct downstream server.
10
11pub mod circuit_breaker;
12pub mod reconnect;
13pub mod router;
14pub mod timeout;
15
16use std::borrow::Cow;
17use std::collections::HashMap;
18
19use anyhow::{Context, Result};
20use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
21use rmcp::model::{CallToolRequestParams, CallToolResult, Content, RawContent};
22use rmcp::service::RunningService;
23use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
24use rmcp::transport::{ConfigureCommandExt, StreamableHttpClientTransport, TokioChildProcess};
25use rmcp::{RoleClient, ServiceExt};
26use serde_json::Value;
27use tokio::process::Command;
28
29pub use circuit_breaker::{
30    CircuitBreakerConfig, CircuitBreakerDispatcher, CircuitBreakerResourceDispatcher,
31};
32pub use reconnect::ReconnectingClient;
33pub use router::{RouterDispatcher, RouterResourceDispatcher};
34pub use timeout::{TimeoutDispatcher, TimeoutResourceDispatcher};
35
36/// Configuration for connecting to a downstream MCP server.
37#[derive(Debug, Clone)]
38#[non_exhaustive]
39pub enum TransportConfig {
40    /// Connect via stdio to a child process.
41    Stdio {
42        /// Command to execute.
43        command: String,
44        /// Arguments to the command.
45        args: Vec<String>,
46    },
47    /// Connect via HTTP (Streamable HTTP / SSE).
48    Http {
49        /// URL of the MCP server endpoint.
50        url: String,
51        /// Optional HTTP headers (e.g., Authorization).
52        headers: HashMap<String, String>,
53    },
54}
55
56/// A client connection to a single downstream MCP server.
57///
58/// Wraps an rmcp client session and implements [`ToolDispatcher`] for routing
59/// tool calls from the sandbox.
60pub struct McpClient {
61    name: String,
62    inner: ClientInner,
63}
64
65enum ClientInner {
66    Stdio(RunningService<RoleClient, ()>),
67    Http(RunningService<RoleClient, ()>),
68}
69
70impl ClientInner {
71    fn peer(&self) -> &rmcp::Peer<RoleClient> {
72        match self {
73            ClientInner::Stdio(s) => s,
74            ClientInner::Http(s) => s,
75        }
76    }
77}
78
79/// Information about a tool discovered from a downstream server.
80#[derive(Debug, Clone)]
81pub struct ToolInfo {
82    /// Tool name.
83    pub name: String,
84    /// Tool description.
85    pub description: Option<String>,
86    /// JSON Schema for the tool's input parameters.
87    pub input_schema: Value,
88}
89
/// Description of one resource advertised by a downstream server.
#[derive(Debug, Clone)]
pub struct ResourceInfo {
    /// URI identifying the resource.
    pub uri: String,
    /// Human-readable resource name.
    pub name: String,
    /// Description of the resource, when the server supplies one.
    pub description: Option<String>,
    /// MIME type of the resource, when the server supplies one.
    pub mime_type: Option<String>,
}
102
103impl McpClient {
104    /// Connect to a downstream MCP server over stdio (child process).
105    ///
106    /// Spawns the given command as a child process and communicates via stdin/stdout.
107    pub async fn connect_stdio(
108        name: impl Into<String>,
109        command: &str,
110        args: &[&str],
111    ) -> Result<Self> {
112        let name = name.into();
113        let args_owned: Vec<String> = args.iter().map(|s| s.to_string()).collect();
114
115        tracing::info!(
116            server = %name,
117            command = %command,
118            args = ?args_owned,
119            "connecting to downstream MCP server (stdio)"
120        );
121
122        let transport = TokioChildProcess::new(Command::new(command).configure(|cmd| {
123            for arg in &args_owned {
124                cmd.arg(arg);
125            }
126        }))
127        .with_context(|| {
128            format!(
129                "failed to spawn stdio transport for server '{}' (command: {})",
130                name, command
131            )
132        })?;
133
134        let service: RunningService<RoleClient, ()> = ()
135            .serve(transport)
136            .await
137            .with_context(|| format!("MCP handshake failed for server '{}'", name))?;
138
139        tracing::info!(server = %name, "connected to downstream MCP server (stdio)");
140
141        Ok(Self {
142            name,
143            inner: ClientInner::Stdio(service),
144        })
145    }
146
147    /// Connect to a downstream MCP server over HTTP (Streamable HTTP / SSE).
148    pub async fn connect_http(
149        name: impl Into<String>,
150        url: &str,
151        headers: Option<HashMap<String, String>>,
152    ) -> Result<Self> {
153        let name = name.into();
154
155        if url.starts_with("http://") {
156            tracing::warn!(
157                server = %name,
158                url = %url,
159                "connecting over plain HTTP — consider using HTTPS for production"
160            );
161        }
162
163        tracing::info!(
164            server = %name,
165            url = %url,
166            "connecting to downstream MCP server (HTTP)"
167        );
168
169        let mut config = StreamableHttpClientTransportConfig::with_uri(url);
170
171        // Fail-closed: reject credentials on plain HTTP
172        if let Some(ref hdrs) = headers {
173            check_http_credential_safety(url, hdrs)?;
174        }
175
176        // Defense-in-depth belt: also strip sensitive headers on plain HTTP
177        let headers = headers.map(|mut h| {
178            sanitize_headers_for_transport(url, &mut h);
179            h
180        });
181
182        if let Some(hdrs) = &headers {
183            for (key, value) in hdrs {
184                if key.to_lowercase() == "authorization" {
185                    tracing::debug!(server = %name, header = %key, "setting auth header (redacted)");
186                } else {
187                    tracing::debug!(server = %name, header = %key, value = %value, "setting header");
188                }
189            }
190
191            let mut header_map = HashMap::new();
192            for (key, value) in hdrs {
193                let header_name = http::HeaderName::from_bytes(key.as_bytes())
194                    .with_context(|| format!("invalid header name: {key}"))?;
195                let header_value = http::HeaderValue::from_str(value)
196                    .with_context(|| format!("invalid header value for {key}"))?;
197                header_map.insert(header_name, header_value);
198            }
199            config = config.custom_headers(header_map);
200        }
201
202        let transport = StreamableHttpClientTransport::from_config(config);
203        let service: RunningService<RoleClient, ()> = ()
204            .serve(transport)
205            .await
206            .with_context(|| format!("MCP handshake failed for server '{}' (HTTP)", name))?;
207
208        tracing::info!(server = %name, "connected to downstream MCP server (HTTP)");
209
210        Ok(Self {
211            name,
212            inner: ClientInner::Http(service),
213        })
214    }
215
216    /// Connect using a [`TransportConfig`].
217    pub async fn connect(name: impl Into<String>, config: &TransportConfig) -> Result<Self> {
218        let name = name.into();
219        match config {
220            TransportConfig::Stdio { command, args } => {
221                let arg_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
222                Self::connect_stdio(name, command, &arg_refs).await
223            }
224            TransportConfig::Http { url, headers } => {
225                let hdrs = if headers.is_empty() {
226                    None
227                } else {
228                    Some(headers.clone())
229                };
230                Self::connect_http(name, url, hdrs).await
231            }
232        }
233    }
234
235    /// List all tools available on this server.
236    pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
237        let tools = self
238            .inner
239            .peer()
240            .list_all_tools()
241            .await
242            .with_context(|| format!("failed to list tools for server '{}'", self.name))?;
243
244        Ok(tools
245            .into_iter()
246            .map(|t| ToolInfo {
247                name: t.name.to_string(),
248                description: t.description.map(|d: Cow<'_, str>| d.to_string()),
249                input_schema: serde_json::to_value(&*t.input_schema)
250                    .unwrap_or(Value::Object(Default::default())),
251            })
252            .collect())
253    }
254
255    /// List all resources available on this server.
256    ///
257    /// Returns an empty Vec if the server does not support resources
258    /// (graceful degradation — not all MCP servers implement resources/list).
259    pub async fn list_resources(&self) -> Result<Vec<ResourceInfo>> {
260        let result = self.inner.peer().list_resources(None).await;
261
262        match result {
263            Ok(list) => Ok(list
264                .resources
265                .into_iter()
266                .map(|r| ResourceInfo {
267                    uri: r.uri.clone(),
268                    name: r.name.clone(),
269                    description: r.description.clone(),
270                    mime_type: r.mime_type.clone(),
271                })
272                .collect()),
273            Err(e) => {
274                let err_str = e.to_string();
275                // Graceful degradation: treat "method not found" as "no resources"
276                if err_str.contains("Method not found")
277                    || err_str.contains("method not found")
278                    || err_str.contains("-32601")
279                {
280                    tracing::debug!(
281                        server = %self.name,
282                        "server does not support resources/list, returning empty"
283                    );
284                    Ok(vec![])
285                } else {
286                    Err(anyhow::anyhow!(
287                        "failed to list resources for server '{}': {}",
288                        self.name,
289                        e
290                    ))
291                }
292            }
293        }
294    }
295
296    /// Read a specific resource by URI.
297    pub async fn read_resource(&self, uri: &str) -> Result<Value> {
298        let result = self
299            .inner
300            .peer()
301            .read_resource(rmcp::model::ReadResourceRequestParams::new(uri))
302            .await
303            .with_context(|| {
304                format!(
305                    "resource read failed: server='{}', uri='{}'",
306                    self.name, uri
307                )
308            })?;
309
310        // Convert resource contents to JSON
311        let contents = result.contents;
312        if contents.is_empty() {
313            return Ok(Value::Null);
314        }
315
316        if contents.len() == 1 {
317            resource_content_to_value(&contents[0])
318        } else {
319            let values: Vec<Value> = contents
320                .iter()
321                .filter_map(|c| resource_content_to_value(c).ok())
322                .collect();
323            Ok(Value::Array(values))
324        }
325    }
326
327    /// Get the server name.
328    pub fn name(&self) -> &str {
329        &self.name
330    }
331
332    /// Gracefully disconnect from the server.
333    pub async fn disconnect(self) -> Result<()> {
334        tracing::info!(server = %self.name, "disconnecting from downstream MCP server");
335        match self.inner {
336            ClientInner::Stdio(s) => {
337                let _ = s.cancel().await;
338            }
339            ClientInner::Http(s) => {
340                let _ = s.cancel().await;
341            }
342        }
343        Ok(())
344    }
345}
346
347#[async_trait::async_trait]
348impl ToolDispatcher for McpClient {
349    async fn call_tool(
350        &self,
351        _server: &str,
352        tool: &str,
353        args: Value,
354    ) -> Result<Value, forge_error::DispatchError> {
355        let arguments = args.as_object().cloned().or_else(|| {
356            if args.is_null() {
357                Some(serde_json::Map::new())
358            } else {
359                None
360            }
361        });
362
363        let result: CallToolResult = self
364            .inner
365            .peer()
366            .call_tool(
367                CallToolRequestParams::new(tool.to_string())
368                    .with_arguments(arguments.unwrap_or_default()),
369            )
370            .await
371            .map_err(|e| {
372                let msg = format!("tool call failed: tool='{}': {}", tool, e);
373                let err_str = e.to_string();
374                if is_transport_dead(&err_str) {
375                    forge_error::DispatchError::TransportDead {
376                        server: self.name.clone(),
377                        reason: msg,
378                    }
379                } else {
380                    forge_error::DispatchError::Upstream {
381                        server: self.name.clone(),
382                        message: msg,
383                    }
384                }
385            })?;
386
387        // Tool-level errors (isError: true) mean the server is healthy but
388        // the tool rejected the request (wrong params, missing state, etc.).
389        // These must NOT trip the circuit breaker.
390        if result.is_error == Some(true) && result.structured_content.is_none() {
391            let error_text = result
392                .content
393                .iter()
394                .filter_map(|c| match &c.raw {
395                    RawContent::Text(t) => Some(t.text.as_str()),
396                    _ => None,
397                })
398                .collect::<Vec<_>>()
399                .join("\n");
400            return Err(forge_error::DispatchError::ToolError {
401                server: self.name.clone(),
402                tool: tool.to_string(),
403                message: format!("tool returned error: {}", error_text),
404            });
405        }
406
407        call_tool_result_to_value(result).map_err(|e| forge_error::DispatchError::Upstream {
408            server: self.name.clone(),
409            message: e.to_string(),
410        })
411    }
412}
413
414#[async_trait::async_trait]
415impl ResourceDispatcher for McpClient {
416    async fn read_resource(
417        &self,
418        _server: &str,
419        uri: &str,
420    ) -> Result<Value, forge_error::DispatchError> {
421        self.read_resource(uri).await.map_err(|e| {
422            let msg = format!("resource read failed: uri='{}': {}", uri, e);
423            let err_str = e.to_string();
424            if is_transport_dead(&err_str) {
425                forge_error::DispatchError::TransportDead {
426                    server: self.name.clone(),
427                    reason: msg,
428                }
429            } else {
430                forge_error::DispatchError::Upstream {
431                    server: self.name.clone(),
432                    message: msg,
433                }
434            }
435        })
436    }
437}
438
/// Returns true if the error string indicates a permanently dead transport.
///
/// Detects patterns from rmcp's internal channel overflow, broken pipes,
/// and closed transports that indicate the MCP client session is unrecoverable.
///
/// Matching is ASCII case-insensitive, so spellings such as `Transport closed`
/// or `BROKEN PIPE` are recognized alongside the forms seen so far.
fn is_transport_dead(err_str: &str) -> bool {
    // All markers are lowercase; the input is lowercased once before matching.
    const DEAD_TRANSPORT_MARKERS: [&str; 5] = [
        "transportclosed",
        "transport closed",
        "channel closed",
        "broken pipe",
        "brokenpipe",
    ];

    let lowered = err_str.to_ascii_lowercase();
    DEAD_TRANSPORT_MARKERS
        .iter()
        .any(|marker| lowered.contains(marker))
}
451
452/// Convert a resource content item to a JSON Value.
453fn resource_content_to_value(content: &rmcp::model::ResourceContents) -> Result<Value> {
454    match content {
455        rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
456            // Try to parse as JSON first, fall back to string
457            serde_json::from_str(text).or_else(|_| Ok(Value::String(text.clone())))
458        }
459        rmcp::model::ResourceContents::BlobResourceContents {
460            blob, mime_type, ..
461        } => Ok(serde_json::json!({
462            "_type": "blob",
463            "_encoding": "base64",
464            "data": blob,
465            "mime_type": mime_type.as_deref().unwrap_or("application/octet-stream"),
466        })),
467    }
468}
469
470/// Convert an MCP CallToolResult to a JSON Value.
471fn call_tool_result_to_value(result: CallToolResult) -> Result<Value> {
472    if let Some(structured) = result.structured_content {
473        return Ok(structured);
474    }
475
476    if result.is_error == Some(true) {
477        let error_text = result
478            .content
479            .iter()
480            .filter_map(|c| match &c.raw {
481                RawContent::Text(t) => Some(t.text.as_str()),
482                _ => None,
483            })
484            .collect::<Vec<_>>()
485            .join("\n");
486        return Err(anyhow::anyhow!("tool returned error: {}", error_text));
487    }
488
489    if result.content.len() == 1 {
490        content_to_value(&result.content[0])
491    } else if result.content.is_empty() {
492        Ok(Value::Null)
493    } else {
494        let values: Vec<Value> = result
495            .content
496            .iter()
497            .filter_map(|c| content_to_value(c).ok())
498            .collect();
499        Ok(Value::Array(values))
500    }
501}
502
503/// Maximum size in bytes for binary content (images, audio) before truncation.
504const MAX_BINARY_CONTENT_SIZE: usize = 1_048_576; // 1 MB
505
506/// Maximum size in bytes for text content before truncation.
507/// Prevents OOM from enormous text responses from compromised downstream servers.
508const MAX_TEXT_CONTENT_SIZE: usize = 10_485_760; // 10 MB
509
510/// Convert a single Content item to a JSON Value.
511///
512/// Binary content (images, audio) larger than [`MAX_BINARY_CONTENT_SIZE`] is
513/// replaced with truncation metadata to prevent OOM on large base64 payloads.
514fn content_to_value(content: &Content) -> Result<Value> {
515    match &content.raw {
516        RawContent::Text(t) => {
517            if t.text.len() > MAX_TEXT_CONTENT_SIZE {
518                Ok(serde_json::json!({
519                    "type": "text",
520                    "truncated": true,
521                    "original_size": t.text.len(),
522                    "preview": &t.text[..1024.min(t.text.len())],
523                }))
524            } else {
525                serde_json::from_str(&t.text).or_else(|_| Ok(Value::String(t.text.clone())))
526            }
527        }
528        RawContent::Image(img) => {
529            if img.data.len() > MAX_BINARY_CONTENT_SIZE {
530                Ok(serde_json::json!({
531                    "type": "image",
532                    "truncated": true,
533                    "original_size": img.data.len(),
534                    "mime_type": img.mime_type,
535                }))
536            } else {
537                Ok(serde_json::json!({
538                    "type": "image",
539                    "data": img.data,
540                    "mime_type": img.mime_type,
541                }))
542            }
543        }
544        RawContent::Resource(r) => Ok(serde_json::json!({
545            "type": "resource",
546            "resource": serde_json::to_value(&r.resource).unwrap_or(Value::Null),
547        })),
548        RawContent::Audio(a) => {
549            if a.data.len() > MAX_BINARY_CONTENT_SIZE {
550                Ok(serde_json::json!({
551                    "type": "audio",
552                    "truncated": true,
553                    "original_size": a.data.len(),
554                    "mime_type": a.mime_type,
555                }))
556            } else {
557                Ok(serde_json::json!({
558                    "type": "audio",
559                    "data": a.data,
560                    "mime_type": a.mime_type,
561                }))
562            }
563        }
564        _ => Ok(serde_json::json!({"type": "unknown"})),
565    }
566}
567
/// Lowercase fragments that mark a header name as credential-bearing. A header
/// is treated as sensitive when its lowercased name contains any of them; such
/// headers are stripped on plain HTTP connections.
const SENSITIVE_HEADER_PATTERNS: &[&str] = &[
    "authorization",
    "cookie",
    "token",
    "secret",
    "key",
    "credential",
    "password",
    "auth",
];

/// Reports whether `name` contains one of the [`SENSITIVE_HEADER_PATTERNS`],
/// ignoring case.
fn is_sensitive_header(name: &str) -> bool {
    let normalized = name.to_lowercase();
    for fragment in SENSITIVE_HEADER_PATTERNS {
        if normalized.contains(*fragment) {
            return true;
        }
    }
    false
}
588
589/// Check that credentials are not being sent over plain HTTP.
590///
591/// Returns an error if any sensitive headers are present on an `http://` connection.
592/// This is fail-closed: operators must fix their config to use HTTPS before credentials
593/// will be sent.
594fn check_http_credential_safety(
595    url: &str,
596    headers: &HashMap<String, String>,
597) -> Result<(), anyhow::Error> {
598    if url.starts_with("http://") {
599        let sensitive: Vec<&String> = headers.keys().filter(|k| is_sensitive_header(k)).collect();
600        if !sensitive.is_empty() {
601            return Err(anyhow::anyhow!(
602                "refusing to send credentials over plain HTTP (headers: {}). \
603                 Use HTTPS or remove sensitive headers.",
604                sensitive
605                    .iter()
606                    .map(|s| s.as_str())
607                    .collect::<Vec<_>>()
608                    .join(", ")
609            ));
610        }
611    }
612    Ok(())
613}
614
615/// Strip sensitive headers from HTTP connections over plain HTTP.
616///
617/// Defense-in-depth belt behind [`check_http_credential_safety`].
618/// Strips any header whose name contains "auth", "token", "secret", "key",
619/// "cookie", "credential", or "password" (case-insensitive) to prevent
620/// accidental credential leakage over unencrypted transports.
621fn sanitize_headers_for_transport(url: &str, headers: &mut HashMap<String, String>) {
622    if url.starts_with("http://") {
623        let removed: Vec<String> = headers
624            .keys()
625            .filter(|k| is_sensitive_header(k))
626            .cloned()
627            .collect();
628        for key in &removed {
629            headers.remove(key);
630        }
631        if !removed.is_empty() {
632            tracing::warn!(
633                url = %url,
634                removed_headers = ?removed,
635                "stripped sensitive headers from plain HTTP connection — use HTTPS to send credentials"
636            );
637        }
638    }
639}
640
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{Content, RawContent};

    const PLAIN_URL: &str = "http://example.com/mcp";
    const SECURE_URL: &str = "https://example.com/mcp";

    /// Build an owned header map from `(name, value)` pairs.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    // --- content_to_value: text ---

    #[test]
    fn content_to_value_text_string() {
        let converted = content_to_value(&Content::text("hello")).unwrap();
        assert_eq!(converted, Value::String("hello".into()));
    }

    #[test]
    fn content_to_value_text_json() {
        let converted = content_to_value(&Content::text(r#"{"k":"v"}"#)).unwrap();
        assert_eq!(converted, serde_json::json!({"k": "v"}));
    }

    // --- content_to_value: binary payloads ---

    #[test]
    fn content_to_value_small_image_preserved() {
        // 1KB payload, far below the binary limit.
        let payload = "a".repeat(1024);
        let image = Content::image(payload.clone(), "image/png");
        let converted = content_to_value(&image).unwrap();
        assert_eq!(converted["type"], "image");
        assert_eq!(converted["data"], payload);
        assert!(converted.get("truncated").is_none());
    }

    #[test]
    fn content_to_value_oversized_image_truncated() {
        // 2MB payload, above the binary limit.
        let image = Content::image("a".repeat(2 * 1024 * 1024), "image/png");
        let converted = content_to_value(&image).unwrap();
        assert_eq!(converted["type"], "image");
        assert_eq!(converted["truncated"], true);
        assert!(converted.get("data").is_none());
        assert!(converted["original_size"].as_u64().unwrap() > MAX_BINARY_CONTENT_SIZE as u64);
    }

    #[test]
    fn content_to_value_oversized_audio_truncated() {
        // 2MB payload, above the binary limit.
        let audio = Content {
            raw: RawContent::Audio(rmcp::model::RawAudioContent {
                data: "a".repeat(2 * 1024 * 1024),
                mime_type: "audio/wav".into(),
            }),
            annotations: None,
        };
        let converted = content_to_value(&audio).unwrap();
        assert_eq!(converted["type"], "audio");
        assert_eq!(converted["truncated"], true);
        assert!(converted.get("data").is_none());
    }

    // --- content_to_value: text size limit ---

    #[test]
    fn content_to_value_oversized_text_truncated() {
        // 11MB of text, above the text limit.
        let text = Content::text("x".repeat(11 * 1024 * 1024));
        let converted = content_to_value(&text).unwrap();
        assert_eq!(converted["type"], "text");
        assert_eq!(converted["truncated"], true);
        assert!(converted["original_size"].as_u64().unwrap() > MAX_TEXT_CONTENT_SIZE as u64);
        assert!(converted["preview"].as_str().unwrap().len() <= 1024);
    }

    #[test]
    fn content_to_value_normal_text_not_truncated() {
        // 1KB of text, well under the limit.
        let body = "x".repeat(1024);
        let converted = content_to_value(&Content::text(body.clone())).unwrap();
        assert_eq!(converted, Value::String(body));
    }

    // --- sanitize_headers_for_transport ---

    #[test]
    fn sanitize_headers_strips_auth_on_http() {
        let mut hdrs = header_map(&[
            ("Authorization", "Bearer secret"),
            ("Content-Type", "application/json"),
        ]);
        sanitize_headers_for_transport(PLAIN_URL, &mut hdrs);
        assert!(!hdrs.contains_key("Authorization"));
        assert!(hdrs.contains_key("Content-Type"));
    }

    #[test]
    fn sanitize_headers_strips_api_key_on_http() {
        let mut hdrs = header_map(&[
            ("X-Api-Key", "sk-123"),
            ("Content-Type", "application/json"),
        ]);
        sanitize_headers_for_transport(PLAIN_URL, &mut hdrs);
        assert!(!hdrs.contains_key("X-Api-Key"));
        assert!(hdrs.contains_key("Content-Type"));
    }

    #[test]
    fn sanitize_headers_strips_cookie_on_http() {
        let mut hdrs = header_map(&[("Cookie", "session=abc123")]);
        sanitize_headers_for_transport(PLAIN_URL, &mut hdrs);
        assert!(!hdrs.contains_key("Cookie"));
    }

    #[test]
    fn sanitize_headers_strips_custom_token_on_http() {
        let mut hdrs = header_map(&[
            ("X-Auth-Token", "tok_secret"),
            ("X-Secret-Key", "s3cr3t"),
            ("X-Custom-Credential", "cred"),
            ("X-Password", "pass"),
            ("Accept", "application/json"),
        ]);
        sanitize_headers_for_transport(PLAIN_URL, &mut hdrs);
        for stripped in [
            "X-Auth-Token",
            "X-Secret-Key",
            "X-Custom-Credential",
            "X-Password",
        ] {
            assert!(!hdrs.contains_key(stripped));
        }
        assert!(hdrs.contains_key("Accept"));
    }

    #[test]
    fn sanitize_headers_preserves_all_on_https() {
        let mut hdrs = header_map(&[
            ("Authorization", "Bearer secret"),
            ("X-Api-Key", "sk-123"),
            ("Cookie", "session=abc"),
        ]);
        sanitize_headers_for_transport(SECURE_URL, &mut hdrs);
        for kept in ["Authorization", "X-Api-Key", "Cookie"] {
            assert!(hdrs.contains_key(kept));
        }
    }

    // --- HTTP-SEC-01: rejects credentials on plain HTTP ---
    #[test]
    fn http_sec_01_rejects_creds_on_http() {
        let hdrs = header_map(&[("Authorization", "Bearer secret")]);
        let outcome = check_http_credential_safety(PLAIN_URL, &hdrs);
        assert!(outcome.is_err(), "should reject creds on HTTP");
        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("plain HTTP"),
            "error should mention plain HTTP: {msg}"
        );
    }

    // --- HTTP-SEC-02: allows HTTP without credentials ---
    #[test]
    fn http_sec_02_allows_http_without_creds() {
        let hdrs = header_map(&[("Content-Type", "application/json")]);
        assert!(check_http_credential_safety(PLAIN_URL, &hdrs).is_ok());
        // An empty header map is fine as well.
        assert!(check_http_credential_safety(PLAIN_URL, &HashMap::new()).is_ok());
    }

    // --- HTTP-SEC-03: allows HTTPS with credentials ---
    #[test]
    fn http_sec_03_allows_https_with_creds() {
        let hdrs = header_map(&[
            ("Authorization", "Bearer secret"),
            ("X-Api-Key", "sk-123"),
        ]);
        assert!(check_http_credential_safety(SECURE_URL, &hdrs).is_ok());
    }

    #[test]
    fn is_sensitive_header_matches() {
        for sensitive in [
            "Authorization",
            "x-api-key",
            "Cookie",
            "X-Auth-Token",
            "X-Secret-Key",
            "X-Custom-Credential",
            "X-Password",
        ] {
            assert!(is_sensitive_header(sensitive));
        }
        for harmless in ["Content-Type", "Accept", "User-Agent"] {
            assert!(!is_sensitive_header(harmless));
        }
    }

    // --- isError classification tests ---

    #[test]
    fn call_tool_result_is_error_true_returns_err() {
        let failed = CallToolResult::error(vec![Content::text(
            "Invalid params: missing field 'base_url'",
        )]);
        let outcome = call_tool_result_to_value(failed);
        assert!(outcome.is_err());
        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("Invalid params"),
            "expected error text, got: {msg}"
        );
    }

    #[test]
    fn call_tool_result_success_returns_ok() {
        let succeeded = CallToolResult::success(vec![Content::text(r#"{"status": "ok"}"#)]);
        let converted = call_tool_result_to_value(succeeded).unwrap();
        assert_eq!(converted["status"], "ok");
    }

    #[test]
    fn call_tool_result_structured_content_takes_priority_over_is_error() {
        let expected = serde_json::json!({"data": "important"});
        let mut flagged = CallToolResult::error(vec![Content::text("error text")]);
        flagged.structured_content = Some(expected.clone());
        let converted = call_tool_result_to_value(flagged).unwrap();
        assert_eq!(converted, expected);
    }

    // --- Transport death detection tests ---

    #[test]
    fn transport_dead_detects_transport_closed() {
        for fatal in [
            "TransportClosed: channel full",
            "error: transport closed unexpectedly",
            "channel closed by peer",
            "broken pipe while writing",
            "Broken pipe (os error 32)",
            "BrokenPipe",
        ] {
            assert!(is_transport_dead(fatal));
        }
    }

    #[test]
    fn transport_dead_rejects_normal_errors() {
        for recoverable in [
            "tool not found: echo",
            "timeout after 5000ms",
            "Invalid params: missing field",
            "connection refused",
        ] {
            assert!(!is_transport_dead(recoverable));
        }
    }

    /// Compile-time guard: TransportConfig is #[non_exhaustive].
    #[test]
    #[allow(unreachable_patterns)]
    fn ne_transport_config_is_non_exhaustive() {
        let stdio = TransportConfig::Stdio {
            command: "test".into(),
            args: vec![],
        };
        match stdio {
            TransportConfig::Stdio { .. } | TransportConfig::Http { .. } => {}
            _ => {}
        }
    }
}