Skip to main content

forge_client/
lib.rs

1#![warn(missing_docs)]
2
3//! # forge-client
4//!
5//! MCP client connections to downstream servers for the Forgemax Code Mode Gateway.
6//!
7//! Provides [`McpClient`] for connecting to individual MCP servers over stdio
8//! or HTTP transports, and [`RouterDispatcher`] for routing tool calls to the
9//! correct downstream server.
10
11pub mod circuit_breaker;
12pub mod router;
13pub mod timeout;
14
15use std::borrow::Cow;
16use std::collections::HashMap;
17
18use anyhow::{Context, Result};
19use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
20use rmcp::model::{CallToolRequestParams, CallToolResult, Content, RawContent};
21use rmcp::service::RunningService;
22use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
23use rmcp::transport::{ConfigureCommandExt, StreamableHttpClientTransport, TokioChildProcess};
24use rmcp::{RoleClient, ServiceExt};
25use serde_json::Value;
26use tokio::process::Command;
27
28pub use circuit_breaker::{
29    CircuitBreakerConfig, CircuitBreakerDispatcher, CircuitBreakerResourceDispatcher,
30};
31pub use router::{RouterDispatcher, RouterResourceDispatcher};
32pub use timeout::{TimeoutDispatcher, TimeoutResourceDispatcher};
33
/// Configuration for connecting to a downstream MCP server.
///
/// `Debug` is implemented by hand rather than derived: HTTP header values
/// commonly carry bearer tokens or API keys, so the `Debug` output lists header
/// *names* only and replaces every value with `<redacted>`.
#[derive(Clone)]
#[non_exhaustive]
pub enum TransportConfig {
    /// Connect via stdio to a child process.
    Stdio {
        /// Command to execute.
        command: String,
        /// Arguments to the command.
        args: Vec<String>,
    },
    /// Connect via HTTP (Streamable HTTP / SSE).
    Http {
        /// URL of the MCP server endpoint.
        url: String,
        /// Optional HTTP headers (e.g., Authorization).
        headers: HashMap<String, String>,
    },
}

impl std::fmt::Debug for TransportConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TransportConfig::Stdio { command, args } => f
                .debug_struct("Stdio")
                .field("command", command)
                .field("args", args)
                .finish(),
            TransportConfig::Http { url, headers } => {
                // BTreeMap gives a stable, sorted order for the header names.
                let redacted: std::collections::BTreeMap<&str, &str> = headers
                    .keys()
                    .map(|name| (name.as_str(), "<redacted>"))
                    .collect();
                f.debug_struct("Http")
                    .field("url", url)
                    .field("headers", &redacted)
                    .finish()
            }
        }
    }
}
53
54/// A client connection to a single downstream MCP server.
55///
56/// Wraps an rmcp client session and implements [`ToolDispatcher`] for routing
57/// tool calls from the sandbox.
58pub struct McpClient {
59    name: String,
60    inner: ClientInner,
61}
62
63enum ClientInner {
64    Stdio(RunningService<RoleClient, ()>),
65    Http(RunningService<RoleClient, ()>),
66}
67
68impl ClientInner {
69    fn peer(&self) -> &rmcp::Peer<RoleClient> {
70        match self {
71            ClientInner::Stdio(s) => s,
72            ClientInner::Http(s) => s,
73        }
74    }
75}
76
77/// Information about a tool discovered from a downstream server.
78#[derive(Debug, Clone)]
79pub struct ToolInfo {
80    /// Tool name.
81    pub name: String,
82    /// Tool description.
83    pub description: Option<String>,
84    /// JSON Schema for the tool's input parameters.
85    pub input_schema: Value,
86}
87
/// Metadata for one resource advertised by a downstream server.
#[derive(Debug, Clone)]
pub struct ResourceInfo {
    /// URI identifying the resource on its server.
    pub uri: String,
    /// Human-readable resource name.
    pub name: String,
    /// Optional description supplied by the server.
    pub description: Option<String>,
    /// Optional MIME type supplied by the server.
    pub mime_type: Option<String>,
}
100
101impl McpClient {
102    /// Connect to a downstream MCP server over stdio (child process).
103    ///
104    /// Spawns the given command as a child process and communicates via stdin/stdout.
105    pub async fn connect_stdio(
106        name: impl Into<String>,
107        command: &str,
108        args: &[&str],
109    ) -> Result<Self> {
110        let name = name.into();
111        let args_owned: Vec<String> = args.iter().map(|s| s.to_string()).collect();
112
113        tracing::info!(
114            server = %name,
115            command = %command,
116            args = ?args_owned,
117            "connecting to downstream MCP server (stdio)"
118        );
119
120        let transport = TokioChildProcess::new(Command::new(command).configure(|cmd| {
121            for arg in &args_owned {
122                cmd.arg(arg);
123            }
124        }))
125        .with_context(|| {
126            format!(
127                "failed to spawn stdio transport for server '{}' (command: {})",
128                name, command
129            )
130        })?;
131
132        let service: RunningService<RoleClient, ()> = ()
133            .serve(transport)
134            .await
135            .with_context(|| format!("MCP handshake failed for server '{}'", name))?;
136
137        tracing::info!(server = %name, "connected to downstream MCP server (stdio)");
138
139        Ok(Self {
140            name,
141            inner: ClientInner::Stdio(service),
142        })
143    }
144
145    /// Connect to a downstream MCP server over HTTP (Streamable HTTP / SSE).
146    pub async fn connect_http(
147        name: impl Into<String>,
148        url: &str,
149        headers: Option<HashMap<String, String>>,
150    ) -> Result<Self> {
151        let name = name.into();
152
153        if url.starts_with("http://") {
154            tracing::warn!(
155                server = %name,
156                url = %url,
157                "connecting over plain HTTP — consider using HTTPS for production"
158            );
159        }
160
161        tracing::info!(
162            server = %name,
163            url = %url,
164            "connecting to downstream MCP server (HTTP)"
165        );
166
167        let mut config = StreamableHttpClientTransportConfig::with_uri(url);
168
169        // Fail-closed: reject credentials on plain HTTP
170        if let Some(ref hdrs) = headers {
171            check_http_credential_safety(url, hdrs)?;
172        }
173
174        // Defense-in-depth belt: also strip sensitive headers on plain HTTP
175        let headers = headers.map(|mut h| {
176            sanitize_headers_for_transport(url, &mut h);
177            h
178        });
179
180        if let Some(hdrs) = &headers {
181            for (key, value) in hdrs {
182                if key.to_lowercase() == "authorization" {
183                    tracing::debug!(server = %name, header = %key, "setting auth header (redacted)");
184                } else {
185                    tracing::debug!(server = %name, header = %key, value = %value, "setting header");
186                }
187            }
188
189            let mut header_map = HashMap::new();
190            for (key, value) in hdrs {
191                let header_name = http::HeaderName::from_bytes(key.as_bytes())
192                    .with_context(|| format!("invalid header name: {key}"))?;
193                let header_value = http::HeaderValue::from_str(value)
194                    .with_context(|| format!("invalid header value for {key}"))?;
195                header_map.insert(header_name, header_value);
196            }
197            config = config.custom_headers(header_map);
198        }
199
200        let transport = StreamableHttpClientTransport::from_config(config);
201        let service: RunningService<RoleClient, ()> = ()
202            .serve(transport)
203            .await
204            .with_context(|| format!("MCP handshake failed for server '{}' (HTTP)", name))?;
205
206        tracing::info!(server = %name, "connected to downstream MCP server (HTTP)");
207
208        Ok(Self {
209            name,
210            inner: ClientInner::Http(service),
211        })
212    }
213
214    /// Connect using a [`TransportConfig`].
215    pub async fn connect(name: impl Into<String>, config: &TransportConfig) -> Result<Self> {
216        let name = name.into();
217        match config {
218            TransportConfig::Stdio { command, args } => {
219                let arg_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
220                Self::connect_stdio(name, command, &arg_refs).await
221            }
222            TransportConfig::Http { url, headers } => {
223                let hdrs = if headers.is_empty() {
224                    None
225                } else {
226                    Some(headers.clone())
227                };
228                Self::connect_http(name, url, hdrs).await
229            }
230        }
231    }
232
233    /// List all tools available on this server.
234    pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
235        let tools = self
236            .inner
237            .peer()
238            .list_all_tools()
239            .await
240            .with_context(|| format!("failed to list tools for server '{}'", self.name))?;
241
242        Ok(tools
243            .into_iter()
244            .map(|t| ToolInfo {
245                name: t.name.to_string(),
246                description: t.description.map(|d: Cow<'_, str>| d.to_string()),
247                input_schema: serde_json::to_value(&*t.input_schema)
248                    .unwrap_or(Value::Object(Default::default())),
249            })
250            .collect())
251    }
252
253    /// List all resources available on this server.
254    ///
255    /// Returns an empty Vec if the server does not support resources
256    /// (graceful degradation — not all MCP servers implement resources/list).
257    pub async fn list_resources(&self) -> Result<Vec<ResourceInfo>> {
258        let result = self.inner.peer().list_resources(None).await;
259
260        match result {
261            Ok(list) => Ok(list
262                .resources
263                .into_iter()
264                .map(|r| ResourceInfo {
265                    uri: r.uri.clone(),
266                    name: r.name.clone(),
267                    description: r.description.clone(),
268                    mime_type: r.mime_type.clone(),
269                })
270                .collect()),
271            Err(e) => {
272                let err_str = e.to_string();
273                // Graceful degradation: treat "method not found" as "no resources"
274                if err_str.contains("Method not found")
275                    || err_str.contains("method not found")
276                    || err_str.contains("-32601")
277                {
278                    tracing::debug!(
279                        server = %self.name,
280                        "server does not support resources/list, returning empty"
281                    );
282                    Ok(vec![])
283                } else {
284                    Err(anyhow::anyhow!(
285                        "failed to list resources for server '{}': {}",
286                        self.name,
287                        e
288                    ))
289                }
290            }
291        }
292    }
293
294    /// Read a specific resource by URI.
295    pub async fn read_resource(&self, uri: &str) -> Result<Value> {
296        let result = self
297            .inner
298            .peer()
299            .read_resource(rmcp::model::ReadResourceRequestParams {
300                uri: uri.to_string(),
301                meta: None,
302            })
303            .await
304            .with_context(|| {
305                format!(
306                    "resource read failed: server='{}', uri='{}'",
307                    self.name, uri
308                )
309            })?;
310
311        // Convert resource contents to JSON
312        let contents = result.contents;
313        if contents.is_empty() {
314            return Ok(Value::Null);
315        }
316
317        if contents.len() == 1 {
318            resource_content_to_value(&contents[0])
319        } else {
320            let values: Vec<Value> = contents
321                .iter()
322                .filter_map(|c| resource_content_to_value(c).ok())
323                .collect();
324            Ok(Value::Array(values))
325        }
326    }
327
328    /// Get the server name.
329    pub fn name(&self) -> &str {
330        &self.name
331    }
332
333    /// Gracefully disconnect from the server.
334    pub async fn disconnect(self) -> Result<()> {
335        tracing::info!(server = %self.name, "disconnecting from downstream MCP server");
336        match self.inner {
337            ClientInner::Stdio(s) => {
338                let _ = s.cancel().await;
339            }
340            ClientInner::Http(s) => {
341                let _ = s.cancel().await;
342            }
343        }
344        Ok(())
345    }
346}
347
348#[async_trait::async_trait]
349impl ToolDispatcher for McpClient {
350    async fn call_tool(
351        &self,
352        _server: &str,
353        tool: &str,
354        args: Value,
355    ) -> Result<Value, forge_error::DispatchError> {
356        let arguments = args.as_object().cloned().or_else(|| {
357            if args.is_null() {
358                Some(serde_json::Map::new())
359            } else {
360                None
361            }
362        });
363
364        let result: CallToolResult = self
365            .inner
366            .peer()
367            .call_tool(CallToolRequestParams {
368                meta: None,
369                name: Cow::Owned(tool.to_string()),
370                arguments,
371                task: None,
372            })
373            .await
374            .map_err(|e| forge_error::DispatchError::Upstream {
375                server: self.name.clone(),
376                message: format!("tool call failed: tool='{}': {}", tool, e),
377            })?;
378
379        // Tool-level errors (isError: true) mean the server is healthy but
380        // the tool rejected the request (wrong params, missing state, etc.).
381        // These must NOT trip the circuit breaker.
382        if result.is_error == Some(true) && result.structured_content.is_none() {
383            let error_text = result
384                .content
385                .iter()
386                .filter_map(|c| match &c.raw {
387                    RawContent::Text(t) => Some(t.text.as_str()),
388                    _ => None,
389                })
390                .collect::<Vec<_>>()
391                .join("\n");
392            return Err(forge_error::DispatchError::ToolError {
393                server: self.name.clone(),
394                tool: tool.to_string(),
395                message: format!("tool returned error: {}", error_text),
396            });
397        }
398
399        call_tool_result_to_value(result).map_err(|e| forge_error::DispatchError::Upstream {
400            server: self.name.clone(),
401            message: e.to_string(),
402        })
403    }
404}
405
406#[async_trait::async_trait]
407impl ResourceDispatcher for McpClient {
408    async fn read_resource(
409        &self,
410        _server: &str,
411        uri: &str,
412    ) -> Result<Value, forge_error::DispatchError> {
413        self.read_resource(uri)
414            .await
415            .map_err(|e| forge_error::DispatchError::Upstream {
416                server: self.name.clone(),
417                message: format!("resource read failed: uri='{}': {}", uri, e),
418            })
419    }
420}
421
422/// Convert a resource content item to a JSON Value.
423fn resource_content_to_value(content: &rmcp::model::ResourceContents) -> Result<Value> {
424    match content {
425        rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
426            // Try to parse as JSON first, fall back to string
427            serde_json::from_str(text).or_else(|_| Ok(Value::String(text.clone())))
428        }
429        rmcp::model::ResourceContents::BlobResourceContents {
430            blob, mime_type, ..
431        } => Ok(serde_json::json!({
432            "_type": "blob",
433            "_encoding": "base64",
434            "data": blob,
435            "mime_type": mime_type.as_deref().unwrap_or("application/octet-stream"),
436        })),
437    }
438}
439
440/// Convert an MCP CallToolResult to a JSON Value.
441fn call_tool_result_to_value(result: CallToolResult) -> Result<Value> {
442    if let Some(structured) = result.structured_content {
443        return Ok(structured);
444    }
445
446    if result.is_error == Some(true) {
447        let error_text = result
448            .content
449            .iter()
450            .filter_map(|c| match &c.raw {
451                RawContent::Text(t) => Some(t.text.as_str()),
452                _ => None,
453            })
454            .collect::<Vec<_>>()
455            .join("\n");
456        return Err(anyhow::anyhow!("tool returned error: {}", error_text));
457    }
458
459    if result.content.len() == 1 {
460        content_to_value(&result.content[0])
461    } else if result.content.is_empty() {
462        Ok(Value::Null)
463    } else {
464        let values: Vec<Value> = result
465            .content
466            .iter()
467            .filter_map(|c| content_to_value(c).ok())
468            .collect();
469        Ok(Value::Array(values))
470    }
471}
472
/// Largest image/audio payload, in bytes, that is inlined into a tool result.
///
/// Measured on the content's `data` string as received (base64 text), not on
/// decoded bytes. Larger payloads are replaced with size metadata.
const MAX_BINARY_CONTENT_SIZE: usize = 1024 * 1024; // 1 MiB

/// Largest text item, in bytes, that is returned in full from a tool result.
///
/// Larger text is replaced with size metadata and a short preview, so a
/// compromised downstream server cannot push enormous text through to callers.
const MAX_TEXT_CONTENT_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
479
480/// Convert a single Content item to a JSON Value.
481///
482/// Binary content (images, audio) larger than [`MAX_BINARY_CONTENT_SIZE`] is
483/// replaced with truncation metadata to prevent OOM on large base64 payloads.
484fn content_to_value(content: &Content) -> Result<Value> {
485    match &content.raw {
486        RawContent::Text(t) => {
487            if t.text.len() > MAX_TEXT_CONTENT_SIZE {
488                Ok(serde_json::json!({
489                    "type": "text",
490                    "truncated": true,
491                    "original_size": t.text.len(),
492                    "preview": &t.text[..1024.min(t.text.len())],
493                }))
494            } else {
495                serde_json::from_str(&t.text).or_else(|_| Ok(Value::String(t.text.clone())))
496            }
497        }
498        RawContent::Image(img) => {
499            if img.data.len() > MAX_BINARY_CONTENT_SIZE {
500                Ok(serde_json::json!({
501                    "type": "image",
502                    "truncated": true,
503                    "original_size": img.data.len(),
504                    "mime_type": img.mime_type,
505                }))
506            } else {
507                Ok(serde_json::json!({
508                    "type": "image",
509                    "data": img.data,
510                    "mime_type": img.mime_type,
511                }))
512            }
513        }
514        RawContent::Resource(r) => Ok(serde_json::json!({
515            "type": "resource",
516            "resource": serde_json::to_value(&r.resource).unwrap_or(Value::Null),
517        })),
518        RawContent::Audio(a) => {
519            if a.data.len() > MAX_BINARY_CONTENT_SIZE {
520                Ok(serde_json::json!({
521                    "type": "audio",
522                    "truncated": true,
523                    "original_size": a.data.len(),
524                    "mime_type": a.mime_type,
525                }))
526            } else {
527                Ok(serde_json::json!({
528                    "type": "audio",
529                    "data": a.data,
530                    "mime_type": a.mime_type,
531                }))
532            }
533        }
534        _ => Ok(serde_json::json!({"type": "unknown"})),
535    }
536}
537
/// Lowercase substrings that mark a header name as credential-bearing.
///
/// Matching is by *containment* on the lowercased name, so `"key"` catches
/// `X-Api-Key` and `"auth"` catches `X-Auth-Token`. Matching headers are
/// stripped on plain HTTP connections.
const SENSITIVE_HEADER_PATTERNS: &[&str] = &[
    "authorization",
    "cookie",
    "token",
    "secret",
    "key",
    "credential",
    "password",
    "auth",
];

/// Reports whether `name` looks like a header that carries credentials.
///
/// The check is case-insensitive: `name` is lowercased and tested against each
/// entry of [`SENSITIVE_HEADER_PATTERNS`].
fn is_sensitive_header(name: &str) -> bool {
    let haystack = name.to_lowercase();
    for needle in SENSITIVE_HEADER_PATTERNS {
        if haystack.contains(*needle) {
            return true;
        }
    }
    false
}
558
559/// Check that credentials are not being sent over plain HTTP.
560///
561/// Returns an error if any sensitive headers are present on an `http://` connection.
562/// This is fail-closed: operators must fix their config to use HTTPS before credentials
563/// will be sent.
564fn check_http_credential_safety(
565    url: &str,
566    headers: &HashMap<String, String>,
567) -> Result<(), anyhow::Error> {
568    if url.starts_with("http://") {
569        let sensitive: Vec<&String> = headers.keys().filter(|k| is_sensitive_header(k)).collect();
570        if !sensitive.is_empty() {
571            return Err(anyhow::anyhow!(
572                "refusing to send credentials over plain HTTP (headers: {}). \
573                 Use HTTPS or remove sensitive headers.",
574                sensitive
575                    .iter()
576                    .map(|s| s.as_str())
577                    .collect::<Vec<_>>()
578                    .join(", ")
579            ));
580        }
581    }
582    Ok(())
583}
584
585/// Strip sensitive headers from HTTP connections over plain HTTP.
586///
587/// Defense-in-depth belt behind [`check_http_credential_safety`].
588/// Strips any header whose name contains "auth", "token", "secret", "key",
589/// "cookie", "credential", or "password" (case-insensitive) to prevent
590/// accidental credential leakage over unencrypted transports.
591fn sanitize_headers_for_transport(url: &str, headers: &mut HashMap<String, String>) {
592    if url.starts_with("http://") {
593        let removed: Vec<String> = headers
594            .keys()
595            .filter(|k| is_sensitive_header(k))
596            .cloned()
597            .collect();
598        for key in &removed {
599            headers.remove(key);
600        }
601        if !removed.is_empty() {
602            tracing::warn!(
603                url = %url,
604                removed_headers = ?removed,
605                "stripped sensitive headers from plain HTTP connection — use HTTPS to send credentials"
606            );
607        }
608    }
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{Content, RawContent};

    /// Plain-HTTP endpoint used by the transport-security tests.
    const PLAIN_URL: &str = "http://example.com/mcp";
    /// TLS endpoint used by the transport-security tests.
    const TLS_URL: &str = "https://example.com/mcp";

    /// Build an owned header map from `(name, value)` pairs.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    /// Build a tool result holding a single text item.
    fn text_result(
        text: &str,
        is_error: Option<bool>,
        structured_content: Option<Value>,
    ) -> CallToolResult {
        CallToolResult {
            content: vec![Content::text(text)],
            is_error,
            structured_content,
            meta: None,
        }
    }

    // --- content_to_value ---

    #[test]
    fn content_to_value_text_string() {
        let val = content_to_value(&Content::text("hello")).unwrap();
        assert_eq!(val, Value::String("hello".into()));
    }

    #[test]
    fn content_to_value_text_json() {
        let val = content_to_value(&Content::text(r#"{"k":"v"}"#)).unwrap();
        assert_eq!(val, serde_json::json!({"k": "v"}));
    }

    #[test]
    fn content_to_value_small_image_preserved() {
        let payload = "a".repeat(1024); // 1KB
        let val = content_to_value(&Content::image(payload.clone(), "image/png")).unwrap();
        assert_eq!(val["type"], "image");
        assert_eq!(val["data"], payload);
        assert!(val.get("truncated").is_none());
    }

    #[test]
    fn content_to_value_oversized_image_truncated() {
        let payload = "a".repeat(2 * 1024 * 1024); // 2MB
        let val = content_to_value(&Content::image(payload, "image/png")).unwrap();
        assert_eq!(val["type"], "image");
        assert_eq!(val["truncated"], true);
        assert!(val.get("data").is_none());
        let reported = val["original_size"].as_u64().unwrap();
        assert!(reported > MAX_BINARY_CONTENT_SIZE as u64);
    }

    #[test]
    fn content_to_value_oversized_audio_truncated() {
        let audio = Content {
            raw: RawContent::Audio(rmcp::model::RawAudioContent {
                data: "a".repeat(2 * 1024 * 1024), // 2MB
                mime_type: "audio/wav".into(),
            }),
            annotations: None,
        };
        let val = content_to_value(&audio).unwrap();
        assert_eq!(val["type"], "audio");
        assert_eq!(val["truncated"], true);
        assert!(val.get("data").is_none());
    }

    #[test]
    fn content_to_value_oversized_text_truncated() {
        let val = content_to_value(&Content::text("x".repeat(11 * 1024 * 1024))).unwrap(); // 11MB
        assert_eq!(val["type"], "text");
        assert_eq!(val["truncated"], true);
        let reported = val["original_size"].as_u64().unwrap();
        assert!(reported > MAX_TEXT_CONTENT_SIZE as u64);
        let preview = val["preview"].as_str().unwrap();
        assert!(preview.len() <= 1024);
    }

    #[test]
    fn content_to_value_normal_text_not_truncated() {
        let text = "x".repeat(1024); // 1KB — well under limit
        let val = content_to_value(&Content::text(text.clone())).unwrap();
        assert_eq!(val, Value::String(text));
    }

    // --- sanitize_headers_for_transport ---

    #[test]
    fn sanitize_headers_strips_auth_on_http() {
        let mut headers = header_map(&[
            ("Authorization", "Bearer secret"),
            ("Content-Type", "application/json"),
        ]);
        sanitize_headers_for_transport(PLAIN_URL, &mut headers);
        assert!(!headers.contains_key("Authorization"));
        assert!(headers.contains_key("Content-Type"));
    }

    #[test]
    fn sanitize_headers_strips_api_key_on_http() {
        let mut headers = header_map(&[
            ("X-Api-Key", "sk-123"),
            ("Content-Type", "application/json"),
        ]);
        sanitize_headers_for_transport(PLAIN_URL, &mut headers);
        assert!(!headers.contains_key("X-Api-Key"));
        assert!(headers.contains_key("Content-Type"));
    }

    #[test]
    fn sanitize_headers_strips_cookie_on_http() {
        let mut headers = header_map(&[("Cookie", "session=abc123")]);
        sanitize_headers_for_transport(PLAIN_URL, &mut headers);
        assert!(!headers.contains_key("Cookie"));
    }

    #[test]
    fn sanitize_headers_strips_custom_token_on_http() {
        let stripped = [
            "X-Auth-Token",
            "X-Secret-Key",
            "X-Custom-Credential",
            "X-Password",
        ];
        let mut headers = header_map(&[
            ("X-Auth-Token", "tok_secret"),
            ("X-Secret-Key", "s3cr3t"),
            ("X-Custom-Credential", "cred"),
            ("X-Password", "pass"),
            ("Accept", "application/json"),
        ]);
        sanitize_headers_for_transport(PLAIN_URL, &mut headers);
        for name in stripped {
            assert!(!headers.contains_key(name), "{name} should be stripped");
        }
        assert!(headers.contains_key("Accept"));
    }

    #[test]
    fn sanitize_headers_preserves_all_on_https() {
        let kept = ["Authorization", "X-Api-Key", "Cookie"];
        let mut headers = header_map(&[
            ("Authorization", "Bearer secret"),
            ("X-Api-Key", "sk-123"),
            ("Cookie", "session=abc"),
        ]);
        sanitize_headers_for_transport(TLS_URL, &mut headers);
        for name in kept {
            assert!(headers.contains_key(name), "{name} should be kept on HTTPS");
        }
    }

    // --- HTTP-SEC-01: rejects credentials on plain HTTP ---
    #[test]
    fn http_sec_01_rejects_creds_on_http() {
        let headers = header_map(&[("Authorization", "Bearer secret")]);
        let outcome = check_http_credential_safety(PLAIN_URL, &headers);
        assert!(outcome.is_err(), "should reject creds on HTTP");
        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("plain HTTP"),
            "error should mention plain HTTP: {msg}"
        );
    }

    // --- HTTP-SEC-02: allows HTTP without credentials ---
    #[test]
    fn http_sec_02_allows_http_without_creds() {
        let headers = header_map(&[("Content-Type", "application/json")]);
        assert!(check_http_credential_safety(PLAIN_URL, &headers).is_ok());
        // Empty headers also OK
        assert!(check_http_credential_safety(PLAIN_URL, &HashMap::new()).is_ok());
    }

    // --- HTTP-SEC-03: allows HTTPS with credentials ---
    #[test]
    fn http_sec_03_allows_https_with_creds() {
        let headers = header_map(&[
            ("Authorization", "Bearer secret"),
            ("X-Api-Key", "sk-123"),
        ]);
        assert!(check_http_credential_safety(TLS_URL, &headers).is_ok());
    }

    #[test]
    fn is_sensitive_header_matches() {
        let sensitive = [
            "Authorization",
            "x-api-key",
            "Cookie",
            "X-Auth-Token",
            "X-Secret-Key",
            "X-Custom-Credential",
            "X-Password",
        ];
        let benign = ["Content-Type", "Accept", "User-Agent"];
        for name in sensitive {
            assert!(is_sensitive_header(name), "{name} should be sensitive");
        }
        for name in benign {
            assert!(!is_sensitive_header(name), "{name} should not be sensitive");
        }
    }

    // --- isError classification tests ---

    #[test]
    fn call_tool_result_is_error_true_returns_err() {
        let result = text_result("Invalid params: missing field 'base_url'", Some(true), None);
        let outcome = call_tool_result_to_value(result);
        assert!(outcome.is_err());
        let msg = outcome.unwrap_err().to_string();
        assert!(
            msg.contains("Invalid params"),
            "expected error text, got: {msg}"
        );
    }

    #[test]
    fn call_tool_result_success_returns_ok() {
        let result = text_result(r#"{"status": "ok"}"#, None, None);
        let val = call_tool_result_to_value(result).unwrap();
        assert_eq!(val["status"], "ok");
    }

    #[test]
    fn call_tool_result_structured_content_takes_priority_over_is_error() {
        let structured = serde_json::json!({"data": "important"});
        let result = text_result("error text", Some(true), Some(structured.clone()));
        let val = call_tool_result_to_value(result).unwrap();
        assert_eq!(val, structured);
    }

    /// Compile-time guard: TransportConfig is #[non_exhaustive].
    #[test]
    // The wildcard arm is unreachable inside this crate; it exists to mirror
    // how downstream crates must match on a `#[non_exhaustive]` enum.
    #[allow(unreachable_patterns)]
    fn ne_transport_config_is_non_exhaustive() {
        let config = TransportConfig::Stdio {
            command: "test".into(),
            args: vec![],
        };
        match config {
            TransportConfig::Stdio { .. } | TransportConfig::Http { .. } => {}
            _ => {}
        }
    }
}