1#![warn(missing_docs)]
2
3pub mod circuit_breaker;
12pub mod router;
13pub mod timeout;
14
15use std::borrow::Cow;
16use std::collections::HashMap;
17
18use anyhow::{Context, Result};
19use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
20use rmcp::model::{CallToolRequestParams, CallToolResult, Content, RawContent};
21use rmcp::service::RunningService;
22use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
23use rmcp::transport::{ConfigureCommandExt, StreamableHttpClientTransport, TokioChildProcess};
24use rmcp::{RoleClient, ServiceExt};
25use serde_json::Value;
26use tokio::process::Command;
27
28pub use circuit_breaker::{
29 CircuitBreakerConfig, CircuitBreakerDispatcher, CircuitBreakerResourceDispatcher,
30};
31pub use router::{RouterDispatcher, RouterResourceDispatcher};
32pub use timeout::{TimeoutDispatcher, TimeoutResourceDispatcher};
33
/// How to reach a downstream MCP server.
///
/// Passed to [`McpClient::connect`], which dispatches to the matching
/// `connect_stdio` / `connect_http` constructor. The enum is `#[non_exhaustive]` so
/// further transports can be added without breaking external `match`es.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum TransportConfig {
    /// Launch the server as a child process and talk MCP over its stdio.
    Stdio {
        /// Executable to spawn.
        command: String,
        /// Arguments passed to `command`, in order.
        args: Vec<String>,
    },
    /// Connect to the server over the streamable HTTP transport.
    Http {
        /// Endpoint URL of the server.
        url: String,
        /// Extra headers to attach to the transport. An empty map is treated the same
        /// as "no headers"; credential-like headers are refused over plain HTTP.
        headers: HashMap<String, String>,
    },
}
53
54pub struct McpClient {
59 name: String,
60 inner: ClientInner,
61}
62
/// Transport-tagged wrapper around the running rmcp client session.
///
/// Both variants hold the same service type; the variant only records which
/// transport (`connect_stdio` vs `connect_http`) established the session.
enum ClientInner {
    /// Session opened over a spawned child process (stdio transport).
    Stdio(RunningService<RoleClient, ()>),
    /// Session opened over the streamable HTTP transport.
    Http(RunningService<RoleClient, ()>),
}
67
68impl ClientInner {
69 fn peer(&self) -> &rmcp::Peer<RoleClient> {
70 match self {
71 ClientInner::Stdio(s) => s,
72 ClientInner::Http(s) => s,
73 }
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct ToolInfo {
80 pub name: String,
82 pub description: Option<String>,
84 pub input_schema: Value,
86}
87
/// A resource advertised by a downstream server, as returned by
/// [`McpClient::list_resources`].
#[derive(Debug, Clone)]
pub struct ResourceInfo {
    /// Resource URI; pass it to [`McpClient::read_resource`] to fetch the contents.
    pub uri: String,
    /// Server-provided resource name.
    pub name: String,
    /// Optional human-readable description supplied by the server.
    pub description: Option<String>,
    /// Optional MIME type reported by the server.
    pub mime_type: Option<String>,
}
100
101impl McpClient {
102 pub async fn connect_stdio(
106 name: impl Into<String>,
107 command: &str,
108 args: &[&str],
109 ) -> Result<Self> {
110 let name = name.into();
111 let args_owned: Vec<String> = args.iter().map(|s| s.to_string()).collect();
112
113 tracing::info!(
114 server = %name,
115 command = %command,
116 args = ?args_owned,
117 "connecting to downstream MCP server (stdio)"
118 );
119
120 let transport = TokioChildProcess::new(Command::new(command).configure(|cmd| {
121 for arg in &args_owned {
122 cmd.arg(arg);
123 }
124 }))
125 .with_context(|| {
126 format!(
127 "failed to spawn stdio transport for server '{}' (command: {})",
128 name, command
129 )
130 })?;
131
132 let service: RunningService<RoleClient, ()> = ()
133 .serve(transport)
134 .await
135 .with_context(|| format!("MCP handshake failed for server '{}'", name))?;
136
137 tracing::info!(server = %name, "connected to downstream MCP server (stdio)");
138
139 Ok(Self {
140 name,
141 inner: ClientInner::Stdio(service),
142 })
143 }
144
145 pub async fn connect_http(
147 name: impl Into<String>,
148 url: &str,
149 headers: Option<HashMap<String, String>>,
150 ) -> Result<Self> {
151 let name = name.into();
152
153 if url.starts_with("http://") {
154 tracing::warn!(
155 server = %name,
156 url = %url,
157 "connecting over plain HTTP — consider using HTTPS for production"
158 );
159 }
160
161 tracing::info!(
162 server = %name,
163 url = %url,
164 "connecting to downstream MCP server (HTTP)"
165 );
166
167 let mut config = StreamableHttpClientTransportConfig::with_uri(url);
168
169 if let Some(ref hdrs) = headers {
171 check_http_credential_safety(url, hdrs)?;
172 }
173
174 let headers = headers.map(|mut h| {
176 sanitize_headers_for_transport(url, &mut h);
177 h
178 });
179
180 if let Some(hdrs) = &headers {
181 for (key, value) in hdrs {
182 if key.to_lowercase() == "authorization" {
183 tracing::debug!(server = %name, header = %key, "setting auth header (redacted)");
184 } else {
185 tracing::debug!(server = %name, header = %key, value = %value, "setting header");
186 }
187 }
188
189 let mut header_map = HashMap::new();
190 for (key, value) in hdrs {
191 let header_name = http::HeaderName::from_bytes(key.as_bytes())
192 .with_context(|| format!("invalid header name: {key}"))?;
193 let header_value = http::HeaderValue::from_str(value)
194 .with_context(|| format!("invalid header value for {key}"))?;
195 header_map.insert(header_name, header_value);
196 }
197 config = config.custom_headers(header_map);
198 }
199
200 let transport = StreamableHttpClientTransport::from_config(config);
201 let service: RunningService<RoleClient, ()> = ()
202 .serve(transport)
203 .await
204 .with_context(|| format!("MCP handshake failed for server '{}' (HTTP)", name))?;
205
206 tracing::info!(server = %name, "connected to downstream MCP server (HTTP)");
207
208 Ok(Self {
209 name,
210 inner: ClientInner::Http(service),
211 })
212 }
213
214 pub async fn connect(name: impl Into<String>, config: &TransportConfig) -> Result<Self> {
216 let name = name.into();
217 match config {
218 TransportConfig::Stdio { command, args } => {
219 let arg_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
220 Self::connect_stdio(name, command, &arg_refs).await
221 }
222 TransportConfig::Http { url, headers } => {
223 let hdrs = if headers.is_empty() {
224 None
225 } else {
226 Some(headers.clone())
227 };
228 Self::connect_http(name, url, hdrs).await
229 }
230 }
231 }
232
233 pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
235 let tools = self
236 .inner
237 .peer()
238 .list_all_tools()
239 .await
240 .with_context(|| format!("failed to list tools for server '{}'", self.name))?;
241
242 Ok(tools
243 .into_iter()
244 .map(|t| ToolInfo {
245 name: t.name.to_string(),
246 description: t.description.map(|d: Cow<'_, str>| d.to_string()),
247 input_schema: serde_json::to_value(&*t.input_schema)
248 .unwrap_or(Value::Object(Default::default())),
249 })
250 .collect())
251 }
252
253 pub async fn list_resources(&self) -> Result<Vec<ResourceInfo>> {
258 let result = self.inner.peer().list_resources(None).await;
259
260 match result {
261 Ok(list) => Ok(list
262 .resources
263 .into_iter()
264 .map(|r| ResourceInfo {
265 uri: r.uri.clone(),
266 name: r.name.clone(),
267 description: r.description.clone(),
268 mime_type: r.mime_type.clone(),
269 })
270 .collect()),
271 Err(e) => {
272 let err_str = e.to_string();
273 if err_str.contains("Method not found")
275 || err_str.contains("method not found")
276 || err_str.contains("-32601")
277 {
278 tracing::debug!(
279 server = %self.name,
280 "server does not support resources/list, returning empty"
281 );
282 Ok(vec![])
283 } else {
284 Err(anyhow::anyhow!(
285 "failed to list resources for server '{}': {}",
286 self.name,
287 e
288 ))
289 }
290 }
291 }
292 }
293
294 pub async fn read_resource(&self, uri: &str) -> Result<Value> {
296 let result = self
297 .inner
298 .peer()
299 .read_resource(rmcp::model::ReadResourceRequestParams {
300 uri: uri.to_string(),
301 meta: None,
302 })
303 .await
304 .with_context(|| {
305 format!(
306 "resource read failed: server='{}', uri='{}'",
307 self.name, uri
308 )
309 })?;
310
311 let contents = result.contents;
313 if contents.is_empty() {
314 return Ok(Value::Null);
315 }
316
317 if contents.len() == 1 {
318 resource_content_to_value(&contents[0])
319 } else {
320 let values: Vec<Value> = contents
321 .iter()
322 .filter_map(|c| resource_content_to_value(c).ok())
323 .collect();
324 Ok(Value::Array(values))
325 }
326 }
327
328 pub fn name(&self) -> &str {
330 &self.name
331 }
332
333 pub async fn disconnect(self) -> Result<()> {
335 tracing::info!(server = %self.name, "disconnecting from downstream MCP server");
336 match self.inner {
337 ClientInner::Stdio(s) => {
338 let _ = s.cancel().await;
339 }
340 ClientInner::Http(s) => {
341 let _ = s.cancel().await;
342 }
343 }
344 Ok(())
345 }
346}
347
348#[async_trait::async_trait]
349impl ToolDispatcher for McpClient {
350 async fn call_tool(
351 &self,
352 _server: &str,
353 tool: &str,
354 args: Value,
355 ) -> Result<Value, forge_error::DispatchError> {
356 let arguments = args.as_object().cloned().or_else(|| {
357 if args.is_null() {
358 Some(serde_json::Map::new())
359 } else {
360 None
361 }
362 });
363
364 let result: CallToolResult = self
365 .inner
366 .peer()
367 .call_tool(CallToolRequestParams {
368 meta: None,
369 name: Cow::Owned(tool.to_string()),
370 arguments,
371 task: None,
372 })
373 .await
374 .map_err(|e| forge_error::DispatchError::Upstream {
375 server: self.name.clone(),
376 message: format!("tool call failed: tool='{}': {}", tool, e),
377 })?;
378
379 call_tool_result_to_value(result).map_err(|e| forge_error::DispatchError::Upstream {
380 server: self.name.clone(),
381 message: e.to_string(),
382 })
383 }
384}
385
386#[async_trait::async_trait]
387impl ResourceDispatcher for McpClient {
388 async fn read_resource(
389 &self,
390 _server: &str,
391 uri: &str,
392 ) -> Result<Value, forge_error::DispatchError> {
393 self.read_resource(uri)
394 .await
395 .map_err(|e| forge_error::DispatchError::Upstream {
396 server: self.name.clone(),
397 message: format!("resource read failed: uri='{}': {}", uri, e),
398 })
399 }
400}
401
402fn resource_content_to_value(content: &rmcp::model::ResourceContents) -> Result<Value> {
404 match content {
405 rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
406 serde_json::from_str(text).or_else(|_| Ok(Value::String(text.clone())))
408 }
409 rmcp::model::ResourceContents::BlobResourceContents {
410 blob, mime_type, ..
411 } => Ok(serde_json::json!({
412 "_type": "blob",
413 "_encoding": "base64",
414 "data": blob,
415 "mime_type": mime_type.as_deref().unwrap_or("application/octet-stream"),
416 })),
417 }
418}
419
420fn call_tool_result_to_value(result: CallToolResult) -> Result<Value> {
422 if let Some(structured) = result.structured_content {
423 return Ok(structured);
424 }
425
426 if result.is_error == Some(true) {
427 let error_text = result
428 .content
429 .iter()
430 .filter_map(|c| match &c.raw {
431 RawContent::Text(t) => Some(t.text.as_str()),
432 _ => None,
433 })
434 .collect::<Vec<_>>()
435 .join("\n");
436 return Err(anyhow::anyhow!("tool returned error: {}", error_text));
437 }
438
439 if result.content.len() == 1 {
440 content_to_value(&result.content[0])
441 } else if result.content.is_empty() {
442 Ok(Value::Null)
443 } else {
444 let values: Vec<Value> = result
445 .content
446 .iter()
447 .filter_map(|c| content_to_value(c).ok())
448 .collect();
449 Ok(Value::Array(values))
450 }
451}
452
453const MAX_BINARY_CONTENT_SIZE: usize = 1_048_576; const MAX_TEXT_CONTENT_SIZE: usize = 10_485_760; fn content_to_value(content: &Content) -> Result<Value> {
465 match &content.raw {
466 RawContent::Text(t) => {
467 if t.text.len() > MAX_TEXT_CONTENT_SIZE {
468 Ok(serde_json::json!({
469 "type": "text",
470 "truncated": true,
471 "original_size": t.text.len(),
472 "preview": &t.text[..1024.min(t.text.len())],
473 }))
474 } else {
475 serde_json::from_str(&t.text).or_else(|_| Ok(Value::String(t.text.clone())))
476 }
477 }
478 RawContent::Image(img) => {
479 if img.data.len() > MAX_BINARY_CONTENT_SIZE {
480 Ok(serde_json::json!({
481 "type": "image",
482 "truncated": true,
483 "original_size": img.data.len(),
484 "mime_type": img.mime_type,
485 }))
486 } else {
487 Ok(serde_json::json!({
488 "type": "image",
489 "data": img.data,
490 "mime_type": img.mime_type,
491 }))
492 }
493 }
494 RawContent::Resource(r) => Ok(serde_json::json!({
495 "type": "resource",
496 "resource": serde_json::to_value(&r.resource).unwrap_or(Value::Null),
497 })),
498 RawContent::Audio(a) => {
499 if a.data.len() > MAX_BINARY_CONTENT_SIZE {
500 Ok(serde_json::json!({
501 "type": "audio",
502 "truncated": true,
503 "original_size": a.data.len(),
504 "mime_type": a.mime_type,
505 }))
506 } else {
507 Ok(serde_json::json!({
508 "type": "audio",
509 "data": a.data,
510 "mime_type": a.mime_type,
511 }))
512 }
513 }
514 _ => Ok(serde_json::json!({"type": "unknown"})),
515 }
516}
517
/// Lower-case substrings that mark an HTTP header as likely to carry credentials.
///
/// Matching is by substring, so this is intentionally broad: `"key"` also catches
/// `X-Api-Key`, and `"auth"` also catches `X-Auth-Token`.
const SENSITIVE_HEADER_PATTERNS: &[&str] = &[
    "authorization",
    "cookie",
    "token",
    "secret",
    "key",
    "credential",
    "password",
    "auth",
];

/// Reports whether header `name` looks credential-bearing.
///
/// The name is lower-cased (Unicode-aware) and checked for any of
/// `SENSITIVE_HEADER_PATTERNS` as a substring.
fn is_sensitive_header(name: &str) -> bool {
    let folded = name.to_lowercase();
    for needle in SENSITIVE_HEADER_PATTERNS {
        if folded.contains(*needle) {
            return true;
        }
    }
    false
}
538
539fn check_http_credential_safety(
545 url: &str,
546 headers: &HashMap<String, String>,
547) -> Result<(), anyhow::Error> {
548 if url.starts_with("http://") {
549 let sensitive: Vec<&String> = headers.keys().filter(|k| is_sensitive_header(k)).collect();
550 if !sensitive.is_empty() {
551 return Err(anyhow::anyhow!(
552 "refusing to send credentials over plain HTTP (headers: {}). \
553 Use HTTPS or remove sensitive headers.",
554 sensitive
555 .iter()
556 .map(|s| s.as_str())
557 .collect::<Vec<_>>()
558 .join(", ")
559 ));
560 }
561 }
562 Ok(())
563}
564
565fn sanitize_headers_for_transport(url: &str, headers: &mut HashMap<String, String>) {
572 if url.starts_with("http://") {
573 let removed: Vec<String> = headers
574 .keys()
575 .filter(|k| is_sensitive_header(k))
576 .cloned()
577 .collect();
578 for key in &removed {
579 headers.remove(key);
580 }
581 if !removed.is_empty() {
582 tracing::warn!(
583 url = %url,
584 removed_headers = ?removed,
585 "stripped sensitive headers from plain HTTP connection — use HTTPS to send credentials"
586 );
587 }
588 }
589}
590
// Unit tests for content conversion, plain-HTTP header sanitisation and the
// credential-safety check.
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{Content, RawContent};

    // Non-JSON text is passed through as a JSON string.
    #[test]
    fn content_to_value_text_string() {
        let content = Content::text("hello");
        let val = content_to_value(&content).unwrap();
        assert_eq!(val, Value::String("hello".into()));
    }

    // Text that is valid JSON is parsed into structured JSON.
    #[test]
    fn content_to_value_text_json() {
        let content = Content::text(r#"{"k":"v"}"#);
        let val = content_to_value(&content).unwrap();
        assert_eq!(val, serde_json::json!({"k": "v"}));
    }

    // An image under MAX_BINARY_CONTENT_SIZE keeps its data and is not marked truncated.
    #[test]
    fn content_to_value_small_image_preserved() {
        let small_data = "a".repeat(1024); // 1 KiB, well under the binary limit
        let content = Content::image(small_data.clone(), "image/png");
        let val = content_to_value(&content).unwrap();
        assert_eq!(val["type"], "image");
        assert_eq!(val["data"], small_data);
        assert!(val.get("truncated").is_none());
    }

    // An image over the limit drops `data` and reports its original size.
    #[test]
    fn content_to_value_oversized_image_truncated() {
        let large_data = "a".repeat(2 * 1024 * 1024); // 2 MiB, over the 1 MiB limit
        let content = Content::image(large_data, "image/png");
        let val = content_to_value(&content).unwrap();
        assert_eq!(val["type"], "image");
        assert_eq!(val["truncated"], true);
        assert!(val.get("data").is_none());
        assert!(val["original_size"].as_u64().unwrap() > MAX_BINARY_CONTENT_SIZE as u64);
    }

    // Audio follows the same size limit as images.
    #[test]
    fn content_to_value_oversized_audio_truncated() {
        let large_data = "a".repeat(2 * 1024 * 1024); // 2 MiB, over the 1 MiB limit
        let content = Content {
            raw: RawContent::Audio(rmcp::model::RawAudioContent {
                data: large_data,
                mime_type: "audio/wav".into(),
            }),
            annotations: None,
        };
        let val = content_to_value(&content).unwrap();
        assert_eq!(val["type"], "audio");
        assert_eq!(val["truncated"], true);
        assert!(val.get("data").is_none());
    }

    // Text over MAX_TEXT_CONTENT_SIZE is summarised with a preview of at most 1024 bytes.
    #[test]
    fn content_to_value_oversized_text_truncated() {
        let large_text = "x".repeat(11 * 1024 * 1024); // 11 MiB, over the 10 MiB limit
        let content = Content::text(large_text);
        let val = content_to_value(&content).unwrap();
        assert_eq!(val["type"], "text");
        assert_eq!(val["truncated"], true);
        assert!(val["original_size"].as_u64().unwrap() > MAX_TEXT_CONTENT_SIZE as u64);
        assert!(val["preview"].as_str().unwrap().len() <= 1024);
    }

    // Text within the limit is returned unchanged.
    #[test]
    fn content_to_value_normal_text_not_truncated() {
        let normal_text = "x".repeat(1024); // 1 KiB, well under the text limit
        let content = Content::text(normal_text.clone());
        let val = content_to_value(&content).unwrap();
        assert_eq!(val, Value::String(normal_text));
    }

    // Over plain HTTP, Authorization is stripped while benign headers survive.
    #[test]
    fn sanitize_headers_strips_auth_on_http() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".into(), "Bearer secret".into());
        headers.insert("Content-Type".into(), "application/json".into());
        sanitize_headers_for_transport("http://example.com/mcp", &mut headers);
        assert!(!headers.contains_key("Authorization"));
        assert!(headers.contains_key("Content-Type"));
    }

    // API-key style headers are stripped over plain HTTP.
    #[test]
    fn sanitize_headers_strips_api_key_on_http() {
        let mut headers = HashMap::new();
        headers.insert("X-Api-Key".into(), "sk-123".into());
        headers.insert("Content-Type".into(), "application/json".into());
        sanitize_headers_for_transport("http://example.com/mcp", &mut headers);
        assert!(!headers.contains_key("X-Api-Key"));
        assert!(headers.contains_key("Content-Type"));
    }

    // Cookies are stripped over plain HTTP.
    #[test]
    fn sanitize_headers_strips_cookie_on_http() {
        let mut headers = HashMap::new();
        headers.insert("Cookie".into(), "session=abc123".into());
        sanitize_headers_for_transport("http://example.com/mcp", &mut headers);
        assert!(!headers.contains_key("Cookie"));
    }

    // Custom header names are caught by substring matching on the sensitive patterns.
    #[test]
    fn sanitize_headers_strips_custom_token_on_http() {
        let mut headers = HashMap::new();
        headers.insert("X-Auth-Token".into(), "tok_secret".into());
        headers.insert("X-Secret-Key".into(), "s3cr3t".into());
        headers.insert("X-Custom-Credential".into(), "cred".into());
        headers.insert("X-Password".into(), "pass".into());
        headers.insert("Accept".into(), "application/json".into());
        sanitize_headers_for_transport("http://example.com/mcp", &mut headers);
        assert!(!headers.contains_key("X-Auth-Token"));
        assert!(!headers.contains_key("X-Secret-Key"));
        assert!(!headers.contains_key("X-Custom-Credential"));
        assert!(!headers.contains_key("X-Password"));
        assert!(headers.contains_key("Accept"));
    }

    // Over HTTPS nothing is stripped.
    #[test]
    fn sanitize_headers_preserves_all_on_https() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".into(), "Bearer secret".into());
        headers.insert("X-Api-Key".into(), "sk-123".into());
        headers.insert("Cookie".into(), "session=abc".into());
        sanitize_headers_for_transport("https://example.com/mcp", &mut headers);
        assert!(headers.contains_key("Authorization"));
        assert!(headers.contains_key("X-Api-Key"));
        assert!(headers.contains_key("Cookie"));
    }

    // HTTP-SEC-01: credentials over plain HTTP are rejected with an explanatory error.
    #[test]
    fn http_sec_01_rejects_creds_on_http() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".into(), "Bearer secret".into());
        let err = check_http_credential_safety("http://example.com/mcp", &headers);
        assert!(err.is_err(), "should reject creds on HTTP");
        let msg = err.unwrap_err().to_string();
        assert!(
            msg.contains("plain HTTP"),
            "error should mention plain HTTP: {msg}"
        );
    }

    // HTTP-SEC-02: plain HTTP is allowed when no sensitive headers are present.
    #[test]
    fn http_sec_02_allows_http_without_creds() {
        let mut headers = HashMap::new();
        headers.insert("Content-Type".into(), "application/json".into());
        assert!(check_http_credential_safety("http://example.com/mcp", &headers).is_ok());
        // An empty header map is also fine.
        assert!(check_http_credential_safety("http://example.com/mcp", &HashMap::new()).is_ok());
    }

    // HTTP-SEC-03: credentials are allowed over HTTPS.
    #[test]
    fn http_sec_03_allows_https_with_creds() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".into(), "Bearer secret".into());
        headers.insert("X-Api-Key".into(), "sk-123".into());
        assert!(check_http_credential_safety("https://example.com/mcp", &headers).is_ok());
    }

    // Case-insensitive substring matching against SENSITIVE_HEADER_PATTERNS.
    #[test]
    fn is_sensitive_header_matches() {
        assert!(is_sensitive_header("Authorization"));
        assert!(is_sensitive_header("x-api-key"));
        assert!(is_sensitive_header("Cookie"));
        assert!(is_sensitive_header("X-Auth-Token"));
        assert!(is_sensitive_header("X-Secret-Key"));
        assert!(is_sensitive_header("X-Custom-Credential"));
        assert!(is_sensitive_header("X-Password"));
        assert!(!is_sensitive_header("Content-Type"));
        assert!(!is_sensitive_header("Accept"));
        assert!(!is_sensitive_header("User-Agent"));
    }

    #[test]
    // Inside the defining crate `#[non_exhaustive]` does not force a wildcard arm, so
    // `_` is unreachable here; the allow silences that expected lint. The attribute
    // only affects matches written in other crates.
    #[allow(unreachable_patterns)]
    fn ne_transport_config_is_non_exhaustive() {
        let config = TransportConfig::Stdio {
            command: "test".into(),
            args: vec![],
        };
        match config {
            TransportConfig::Stdio { .. } | TransportConfig::Http { .. } => {}
            _ => {}
        }
    }
}