1#![warn(missing_docs)]
2
3pub mod circuit_breaker;
12pub mod router;
13pub mod timeout;
14
15use std::borrow::Cow;
16use std::collections::HashMap;
17
18use anyhow::{Context, Result};
19use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
20use rmcp::model::{CallToolRequestParams, CallToolResult, Content, RawContent};
21use rmcp::service::RunningService;
22use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
23use rmcp::transport::{ConfigureCommandExt, StreamableHttpClientTransport, TokioChildProcess};
24use rmcp::{RoleClient, ServiceExt};
25use serde_json::Value;
26use tokio::process::Command;
27
28pub use circuit_breaker::{
29 CircuitBreakerConfig, CircuitBreakerDispatcher, CircuitBreakerResourceDispatcher,
30};
31pub use router::{RouterDispatcher, RouterResourceDispatcher};
32pub use timeout::{TimeoutDispatcher, TimeoutResourceDispatcher};
33
/// How to reach a downstream MCP server.
///
/// Marked `#[non_exhaustive]` so further transports can be added without a
/// breaking change for code outside this crate.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum TransportConfig {
    /// Spawn a child process and speak MCP over its stdio.
    Stdio {
        /// Executable to launch.
        command: String,
        /// Arguments passed to `command`, in order.
        args: Vec<String>,
    },
    /// Connect to a server over the streamable HTTP transport.
    Http {
        /// Endpoint URL of the server.
        url: String,
        /// Extra HTTP headers attached to the transport. An empty map means
        /// "no custom headers". Credential-like headers are refused when
        /// `url` is plain `http://`.
        headers: HashMap<String, String>,
    },
}
53
54pub struct McpClient {
59 name: String,
60 inner: ClientInner,
61}
62
63enum ClientInner {
64 Stdio(RunningService<RoleClient, ()>),
65 Http(RunningService<RoleClient, ()>),
66}
67
68impl ClientInner {
69 fn peer(&self) -> &rmcp::Peer<RoleClient> {
70 match self {
71 ClientInner::Stdio(s) => s,
72 ClientInner::Http(s) => s,
73 }
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct ToolInfo {
80 pub name: String,
82 pub description: Option<String>,
84 pub input_schema: Value,
86}
87
/// Description of a resource exposed by a downstream server, as returned by
/// [`McpClient::list_resources`].
#[derive(Debug, Clone)]
pub struct ResourceInfo {
    /// URI identifying the resource; pass it to [`McpClient::read_resource`].
    pub uri: String,
    /// Resource name reported by the server.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional MIME type of the resource contents.
    pub mime_type: Option<String>,
}
100
101impl McpClient {
102 pub async fn connect_stdio(
106 name: impl Into<String>,
107 command: &str,
108 args: &[&str],
109 ) -> Result<Self> {
110 let name = name.into();
111 let args_owned: Vec<String> = args.iter().map(|s| s.to_string()).collect();
112
113 tracing::info!(
114 server = %name,
115 command = %command,
116 args = ?args_owned,
117 "connecting to downstream MCP server (stdio)"
118 );
119
120 let transport = TokioChildProcess::new(Command::new(command).configure(|cmd| {
121 for arg in &args_owned {
122 cmd.arg(arg);
123 }
124 }))
125 .with_context(|| {
126 format!(
127 "failed to spawn stdio transport for server '{}' (command: {})",
128 name, command
129 )
130 })?;
131
132 let service: RunningService<RoleClient, ()> = ()
133 .serve(transport)
134 .await
135 .with_context(|| format!("MCP handshake failed for server '{}'", name))?;
136
137 tracing::info!(server = %name, "connected to downstream MCP server (stdio)");
138
139 Ok(Self {
140 name,
141 inner: ClientInner::Stdio(service),
142 })
143 }
144
145 pub async fn connect_http(
147 name: impl Into<String>,
148 url: &str,
149 headers: Option<HashMap<String, String>>,
150 ) -> Result<Self> {
151 let name = name.into();
152
153 if url.starts_with("http://") {
154 tracing::warn!(
155 server = %name,
156 url = %url,
157 "connecting over plain HTTP — consider using HTTPS for production"
158 );
159 }
160
161 tracing::info!(
162 server = %name,
163 url = %url,
164 "connecting to downstream MCP server (HTTP)"
165 );
166
167 let mut config = StreamableHttpClientTransportConfig::with_uri(url);
168
169 if let Some(ref hdrs) = headers {
171 check_http_credential_safety(url, hdrs)?;
172 }
173
174 let headers = headers.map(|mut h| {
176 sanitize_headers_for_transport(url, &mut h);
177 h
178 });
179
180 if let Some(hdrs) = &headers {
181 for (key, value) in hdrs {
182 if key.to_lowercase() == "authorization" {
183 tracing::debug!(server = %name, header = %key, "setting auth header (redacted)");
184 } else {
185 tracing::debug!(server = %name, header = %key, value = %value, "setting header");
186 }
187 }
188
189 let mut header_map = HashMap::new();
190 for (key, value) in hdrs {
191 let header_name = http::HeaderName::from_bytes(key.as_bytes())
192 .with_context(|| format!("invalid header name: {key}"))?;
193 let header_value = http::HeaderValue::from_str(value)
194 .with_context(|| format!("invalid header value for {key}"))?;
195 header_map.insert(header_name, header_value);
196 }
197 config = config.custom_headers(header_map);
198 }
199
200 let transport = StreamableHttpClientTransport::from_config(config);
201 let service: RunningService<RoleClient, ()> = ()
202 .serve(transport)
203 .await
204 .with_context(|| format!("MCP handshake failed for server '{}' (HTTP)", name))?;
205
206 tracing::info!(server = %name, "connected to downstream MCP server (HTTP)");
207
208 Ok(Self {
209 name,
210 inner: ClientInner::Http(service),
211 })
212 }
213
214 pub async fn connect(name: impl Into<String>, config: &TransportConfig) -> Result<Self> {
216 let name = name.into();
217 match config {
218 TransportConfig::Stdio { command, args } => {
219 let arg_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
220 Self::connect_stdio(name, command, &arg_refs).await
221 }
222 TransportConfig::Http { url, headers } => {
223 let hdrs = if headers.is_empty() {
224 None
225 } else {
226 Some(headers.clone())
227 };
228 Self::connect_http(name, url, hdrs).await
229 }
230 }
231 }
232
233 pub async fn list_tools(&self) -> Result<Vec<ToolInfo>> {
235 let tools = self
236 .inner
237 .peer()
238 .list_all_tools()
239 .await
240 .with_context(|| format!("failed to list tools for server '{}'", self.name))?;
241
242 Ok(tools
243 .into_iter()
244 .map(|t| ToolInfo {
245 name: t.name.to_string(),
246 description: t.description.map(|d: Cow<'_, str>| d.to_string()),
247 input_schema: serde_json::to_value(&*t.input_schema)
248 .unwrap_or(Value::Object(Default::default())),
249 })
250 .collect())
251 }
252
253 pub async fn list_resources(&self) -> Result<Vec<ResourceInfo>> {
258 let result = self.inner.peer().list_resources(None).await;
259
260 match result {
261 Ok(list) => Ok(list
262 .resources
263 .into_iter()
264 .map(|r| ResourceInfo {
265 uri: r.uri.clone(),
266 name: r.name.clone(),
267 description: r.description.clone(),
268 mime_type: r.mime_type.clone(),
269 })
270 .collect()),
271 Err(e) => {
272 let err_str = e.to_string();
273 if err_str.contains("Method not found")
275 || err_str.contains("method not found")
276 || err_str.contains("-32601")
277 {
278 tracing::debug!(
279 server = %self.name,
280 "server does not support resources/list, returning empty"
281 );
282 Ok(vec![])
283 } else {
284 Err(anyhow::anyhow!(
285 "failed to list resources for server '{}': {}",
286 self.name,
287 e
288 ))
289 }
290 }
291 }
292 }
293
294 pub async fn read_resource(&self, uri: &str) -> Result<Value> {
296 let result = self
297 .inner
298 .peer()
299 .read_resource(rmcp::model::ReadResourceRequestParams {
300 uri: uri.to_string(),
301 meta: None,
302 })
303 .await
304 .with_context(|| {
305 format!(
306 "resource read failed: server='{}', uri='{}'",
307 self.name, uri
308 )
309 })?;
310
311 let contents = result.contents;
313 if contents.is_empty() {
314 return Ok(Value::Null);
315 }
316
317 if contents.len() == 1 {
318 resource_content_to_value(&contents[0])
319 } else {
320 let values: Vec<Value> = contents
321 .iter()
322 .filter_map(|c| resource_content_to_value(c).ok())
323 .collect();
324 Ok(Value::Array(values))
325 }
326 }
327
328 pub fn name(&self) -> &str {
330 &self.name
331 }
332
333 pub async fn disconnect(self) -> Result<()> {
335 tracing::info!(server = %self.name, "disconnecting from downstream MCP server");
336 match self.inner {
337 ClientInner::Stdio(s) => {
338 let _ = s.cancel().await;
339 }
340 ClientInner::Http(s) => {
341 let _ = s.cancel().await;
342 }
343 }
344 Ok(())
345 }
346}
347
348#[async_trait::async_trait]
349impl ToolDispatcher for McpClient {
350 async fn call_tool(
351 &self,
352 _server: &str,
353 tool: &str,
354 args: Value,
355 ) -> Result<Value, forge_error::DispatchError> {
356 let arguments = args.as_object().cloned().or_else(|| {
357 if args.is_null() {
358 Some(serde_json::Map::new())
359 } else {
360 None
361 }
362 });
363
364 let result: CallToolResult = self
365 .inner
366 .peer()
367 .call_tool(CallToolRequestParams {
368 meta: None,
369 name: Cow::Owned(tool.to_string()),
370 arguments,
371 task: None,
372 })
373 .await
374 .map_err(|e| forge_error::DispatchError::Upstream {
375 server: self.name.clone(),
376 message: format!("tool call failed: tool='{}': {}", tool, e),
377 })?;
378
379 if result.is_error == Some(true) && result.structured_content.is_none() {
383 let error_text = result
384 .content
385 .iter()
386 .filter_map(|c| match &c.raw {
387 RawContent::Text(t) => Some(t.text.as_str()),
388 _ => None,
389 })
390 .collect::<Vec<_>>()
391 .join("\n");
392 return Err(forge_error::DispatchError::ToolError {
393 server: self.name.clone(),
394 tool: tool.to_string(),
395 message: format!("tool returned error: {}", error_text),
396 });
397 }
398
399 call_tool_result_to_value(result).map_err(|e| forge_error::DispatchError::Upstream {
400 server: self.name.clone(),
401 message: e.to_string(),
402 })
403 }
404}
405
406#[async_trait::async_trait]
407impl ResourceDispatcher for McpClient {
408 async fn read_resource(
409 &self,
410 _server: &str,
411 uri: &str,
412 ) -> Result<Value, forge_error::DispatchError> {
413 self.read_resource(uri)
414 .await
415 .map_err(|e| forge_error::DispatchError::Upstream {
416 server: self.name.clone(),
417 message: format!("resource read failed: uri='{}': {}", uri, e),
418 })
419 }
420}
421
422fn resource_content_to_value(content: &rmcp::model::ResourceContents) -> Result<Value> {
424 match content {
425 rmcp::model::ResourceContents::TextResourceContents { text, .. } => {
426 serde_json::from_str(text).or_else(|_| Ok(Value::String(text.clone())))
428 }
429 rmcp::model::ResourceContents::BlobResourceContents {
430 blob, mime_type, ..
431 } => Ok(serde_json::json!({
432 "_type": "blob",
433 "_encoding": "base64",
434 "data": blob,
435 "mime_type": mime_type.as_deref().unwrap_or("application/octet-stream"),
436 })),
437 }
438}
439
440fn call_tool_result_to_value(result: CallToolResult) -> Result<Value> {
442 if let Some(structured) = result.structured_content {
443 return Ok(structured);
444 }
445
446 if result.is_error == Some(true) {
447 let error_text = result
448 .content
449 .iter()
450 .filter_map(|c| match &c.raw {
451 RawContent::Text(t) => Some(t.text.as_str()),
452 _ => None,
453 })
454 .collect::<Vec<_>>()
455 .join("\n");
456 return Err(anyhow::anyhow!("tool returned error: {}", error_text));
457 }
458
459 if result.content.len() == 1 {
460 content_to_value(&result.content[0])
461 } else if result.content.is_empty() {
462 Ok(Value::Null)
463 } else {
464 let values: Vec<Value> = result
465 .content
466 .iter()
467 .filter_map(|c| content_to_value(c).ok())
468 .collect();
469 Ok(Value::Array(values))
470 }
471}
472
473const MAX_BINARY_CONTENT_SIZE: usize = 1_048_576; const MAX_TEXT_CONTENT_SIZE: usize = 10_485_760; fn content_to_value(content: &Content) -> Result<Value> {
485 match &content.raw {
486 RawContent::Text(t) => {
487 if t.text.len() > MAX_TEXT_CONTENT_SIZE {
488 Ok(serde_json::json!({
489 "type": "text",
490 "truncated": true,
491 "original_size": t.text.len(),
492 "preview": &t.text[..1024.min(t.text.len())],
493 }))
494 } else {
495 serde_json::from_str(&t.text).or_else(|_| Ok(Value::String(t.text.clone())))
496 }
497 }
498 RawContent::Image(img) => {
499 if img.data.len() > MAX_BINARY_CONTENT_SIZE {
500 Ok(serde_json::json!({
501 "type": "image",
502 "truncated": true,
503 "original_size": img.data.len(),
504 "mime_type": img.mime_type,
505 }))
506 } else {
507 Ok(serde_json::json!({
508 "type": "image",
509 "data": img.data,
510 "mime_type": img.mime_type,
511 }))
512 }
513 }
514 RawContent::Resource(r) => Ok(serde_json::json!({
515 "type": "resource",
516 "resource": serde_json::to_value(&r.resource).unwrap_or(Value::Null),
517 })),
518 RawContent::Audio(a) => {
519 if a.data.len() > MAX_BINARY_CONTENT_SIZE {
520 Ok(serde_json::json!({
521 "type": "audio",
522 "truncated": true,
523 "original_size": a.data.len(),
524 "mime_type": a.mime_type,
525 }))
526 } else {
527 Ok(serde_json::json!({
528 "type": "audio",
529 "data": a.data,
530 "mime_type": a.mime_type,
531 }))
532 }
533 }
534 _ => Ok(serde_json::json!({"type": "unknown"})),
535 }
536}
537
/// Lowercase substrings that mark an HTTP header name as credential-bearing.
///
/// Matching is by substring, so "X-Api-Key" matches "key" and "X-Auth-Token"
/// matches both "auth" and "token".
const SENSITIVE_HEADER_PATTERNS: &[&str] = &[
    "authorization",
    "cookie",
    "token",
    "secret",
    "key",
    "credential",
    "password",
    "auth",
];

/// Return `true` when `name`, lowercased, contains any entry of
/// [`SENSITIVE_HEADER_PATTERNS`].
fn is_sensitive_header(name: &str) -> bool {
    let haystack = name.to_lowercase();
    for &needle in SENSITIVE_HEADER_PATTERNS {
        if haystack.contains(needle) {
            return true;
        }
    }
    false
}
558
559fn check_http_credential_safety(
565 url: &str,
566 headers: &HashMap<String, String>,
567) -> Result<(), anyhow::Error> {
568 if url.starts_with("http://") {
569 let sensitive: Vec<&String> = headers.keys().filter(|k| is_sensitive_header(k)).collect();
570 if !sensitive.is_empty() {
571 return Err(anyhow::anyhow!(
572 "refusing to send credentials over plain HTTP (headers: {}). \
573 Use HTTPS or remove sensitive headers.",
574 sensitive
575 .iter()
576 .map(|s| s.as_str())
577 .collect::<Vec<_>>()
578 .join(", ")
579 ));
580 }
581 }
582 Ok(())
583}
584
585fn sanitize_headers_for_transport(url: &str, headers: &mut HashMap<String, String>) {
592 if url.starts_with("http://") {
593 let removed: Vec<String> = headers
594 .keys()
595 .filter(|k| is_sensitive_header(k))
596 .cloned()
597 .collect();
598 for key in &removed {
599 headers.remove(key);
600 }
601 if !removed.is_empty() {
602 tracing::warn!(
603 url = %url,
604 removed_headers = ?removed,
605 "stripped sensitive headers from plain HTTP connection — use HTTPS to send credentials"
606 );
607 }
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::model::{Content, RawContent};

    /// Build an owned header map from `(name, value)` pairs.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    /// Build a `CallToolResult` holding one text content item.
    fn text_tool_result(
        text: &str,
        is_error: Option<bool>,
        structured: Option<Value>,
    ) -> CallToolResult {
        CallToolResult {
            content: vec![Content::text(text)],
            is_error,
            structured_content: structured,
            meta: None,
        }
    }

    #[test]
    fn content_to_value_text_string() {
        let value = content_to_value(&Content::text("hello")).unwrap();
        assert_eq!(value, Value::String("hello".into()));
    }

    #[test]
    fn content_to_value_text_json() {
        let value = content_to_value(&Content::text(r#"{"k":"v"}"#)).unwrap();
        assert_eq!(value, serde_json::json!({"k": "v"}));
    }

    #[test]
    fn content_to_value_small_image_preserved() {
        // 1 KiB is well below the binary cap, so data stays inline.
        let payload = "a".repeat(1024);
        let value = content_to_value(&Content::image(payload.clone(), "image/png")).unwrap();
        assert_eq!(value["type"], "image");
        assert_eq!(value["data"], payload);
        assert!(value.get("truncated").is_none());
    }

    #[test]
    fn content_to_value_oversized_image_truncated() {
        // 2 MiB exceeds the 1 MiB binary cap.
        let image = Content::image("a".repeat(2 * 1024 * 1024), "image/png");
        let value = content_to_value(&image).unwrap();
        assert_eq!(value["type"], "image");
        assert_eq!(value["truncated"], true);
        assert!(value.get("data").is_none());
        let size = value["original_size"].as_u64().unwrap();
        assert!(size > MAX_BINARY_CONTENT_SIZE as u64);
    }

    #[test]
    fn content_to_value_oversized_audio_truncated() {
        // 2 MiB exceeds the 1 MiB binary cap.
        let audio = Content {
            raw: RawContent::Audio(rmcp::model::RawAudioContent {
                data: "a".repeat(2 * 1024 * 1024),
                mime_type: "audio/wav".into(),
            }),
            annotations: None,
        };
        let value = content_to_value(&audio).unwrap();
        assert_eq!(value["type"], "audio");
        assert_eq!(value["truncated"], true);
        assert!(value.get("data").is_none());
    }

    #[test]
    fn content_to_value_oversized_text_truncated() {
        // 11 MiB exceeds the 10 MiB text cap.
        let value = content_to_value(&Content::text("x".repeat(11 * 1024 * 1024))).unwrap();
        assert_eq!(value["type"], "text");
        assert_eq!(value["truncated"], true);
        let size = value["original_size"].as_u64().unwrap();
        assert!(size > MAX_TEXT_CONTENT_SIZE as u64);
        let preview = value["preview"].as_str().unwrap();
        assert!(preview.len() <= 1024);
    }

    #[test]
    fn content_to_value_normal_text_not_truncated() {
        let body = "x".repeat(1024);
        let value = content_to_value(&Content::text(body.clone())).unwrap();
        assert_eq!(value, Value::String(body));
    }

    #[test]
    fn sanitize_headers_strips_auth_on_http() {
        let mut headers = header_map(&[
            ("Authorization", "Bearer secret"),
            ("Content-Type", "application/json"),
        ]);
        sanitize_headers_for_transport("http://example.com/mcp", &mut headers);
        assert!(!headers.contains_key("Authorization"));
        assert!(headers.contains_key("Content-Type"));
    }

    #[test]
    fn sanitize_headers_strips_api_key_on_http() {
        let mut headers = header_map(&[
            ("X-Api-Key", "sk-123"),
            ("Content-Type", "application/json"),
        ]);
        sanitize_headers_for_transport("http://example.com/mcp", &mut headers);
        assert!(!headers.contains_key("X-Api-Key"));
        assert!(headers.contains_key("Content-Type"));
    }

    #[test]
    fn sanitize_headers_strips_cookie_on_http() {
        let mut headers = header_map(&[("Cookie", "session=abc123")]);
        sanitize_headers_for_transport("http://example.com/mcp", &mut headers);
        assert!(!headers.contains_key("Cookie"));
    }

    #[test]
    fn sanitize_headers_strips_custom_token_on_http() {
        let secret_names = [
            "X-Auth-Token",
            "X-Secret-Key",
            "X-Custom-Credential",
            "X-Password",
        ];
        let mut headers = header_map(&[
            ("X-Auth-Token", "tok_secret"),
            ("X-Secret-Key", "s3cr3t"),
            ("X-Custom-Credential", "cred"),
            ("X-Password", "pass"),
            ("Accept", "application/json"),
        ]);
        sanitize_headers_for_transport("http://example.com/mcp", &mut headers);
        for &name in &secret_names {
            assert!(!headers.contains_key(name), "{name} should have been stripped");
        }
        assert!(headers.contains_key("Accept"));
    }

    #[test]
    fn sanitize_headers_preserves_all_on_https() {
        let mut headers = header_map(&[
            ("Authorization", "Bearer secret"),
            ("X-Api-Key", "sk-123"),
            ("Cookie", "session=abc"),
        ]);
        sanitize_headers_for_transport("https://example.com/mcp", &mut headers);
        for &name in &["Authorization", "X-Api-Key", "Cookie"] {
            assert!(headers.contains_key(name), "{name} should survive on HTTPS");
        }
    }

    #[test]
    fn http_sec_01_rejects_creds_on_http() {
        let headers = header_map(&[("Authorization", "Bearer secret")]);
        let err = check_http_credential_safety("http://example.com/mcp", &headers)
            .expect_err("should reject creds on HTTP");
        let msg = err.to_string();
        assert!(
            msg.contains("plain HTTP"),
            "error should mention plain HTTP: {msg}"
        );
    }

    #[test]
    fn http_sec_02_allows_http_without_creds() {
        let harmless = header_map(&[("Content-Type", "application/json")]);
        assert!(check_http_credential_safety("http://example.com/mcp", &harmless).is_ok());
        let empty: HashMap<String, String> = HashMap::new();
        assert!(check_http_credential_safety("http://example.com/mcp", &empty).is_ok());
    }

    #[test]
    fn http_sec_03_allows_https_with_creds() {
        let headers = header_map(&[
            ("Authorization", "Bearer secret"),
            ("X-Api-Key", "sk-123"),
        ]);
        assert!(check_http_credential_safety("https://example.com/mcp", &headers).is_ok());
    }

    #[test]
    fn is_sensitive_header_matches() {
        let sensitive = [
            "Authorization",
            "x-api-key",
            "Cookie",
            "X-Auth-Token",
            "X-Secret-Key",
            "X-Custom-Credential",
            "X-Password",
        ];
        for name in sensitive.iter() {
            assert!(is_sensitive_header(name), "{name} should be sensitive");
        }
        for name in ["Content-Type", "Accept", "User-Agent"].iter() {
            assert!(!is_sensitive_header(name), "{name} should not be sensitive");
        }
    }

    #[test]
    fn call_tool_result_is_error_true_returns_err() {
        let result = text_tool_result("Invalid params: missing field 'base_url'", Some(true), None);
        let msg = call_tool_result_to_value(result)
            .expect_err("is_error should produce Err")
            .to_string();
        assert!(
            msg.contains("Invalid params"),
            "expected error text, got: {msg}"
        );
    }

    #[test]
    fn call_tool_result_success_returns_ok() {
        let result = text_tool_result(r#"{"status": "ok"}"#, None, None);
        let value = call_tool_result_to_value(result).unwrap();
        assert_eq!(value["status"], "ok");
    }

    #[test]
    fn call_tool_result_structured_content_takes_priority_over_is_error() {
        let structured = serde_json::json!({"data": "important"});
        let result = text_tool_result("error text", Some(true), Some(structured.clone()));
        let value = call_tool_result_to_value(result).unwrap();
        assert_eq!(value, structured);
    }

    #[test]
    // Inside the defining crate `#[non_exhaustive]` has no effect, so the
    // wildcard arm below is unreachable; it mirrors what external callers need.
    #[allow(unreachable_patterns)]
    fn ne_transport_config_is_non_exhaustive() {
        let config = TransportConfig::Stdio {
            command: "test".into(),
            args: Vec::new(),
        };
        match config {
            TransportConfig::Stdio { .. } | TransportConfig::Http { .. } => {}
            _ => {}
        }
    }
}