claude_agent/mcp/
mod.rs

1//! MCP (Model Context Protocol) server integration.
2
3pub mod client;
4pub mod manager;
5pub mod resources;
6pub mod toolset;
7
8pub use client::McpClient;
9pub use manager::McpManager;
10pub use resources::{ResourceManager, ResourceQuery};
11pub use toolset::{McpToolset, McpToolsetRegistry, ToolLoadConfig};
12
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::time::Duration;
16
/// Prefix that marks a tool name as an MCP-qualified name (`mcp__<server>_<tool>`).
const MCP_TOOL_PREFIX: &str = "mcp__";

/// Default timeout for establishing an MCP connection (30 seconds).
pub const MCP_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default timeout for a single MCP tool call (60 seconds).
pub const MCP_CALL_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Default timeout for reading an MCP resource (30 seconds).
pub const MCP_RESOURCE_TIMEOUT: Duration = Duration::from_millis(30_000);
25
26/// MCP server configuration
27#[derive(Clone, Debug, Serialize, Deserialize)]
28#[serde(tag = "type", rename_all = "lowercase")]
29pub enum McpServerConfig {
30    /// stdio transport - communicates with server via stdin/stdout
31    Stdio {
32        command: String,
33        #[serde(default)]
34        args: Vec<String>,
35        #[serde(default)]
36        env: HashMap<String, String>,
37    },
38    /// Server-Sent Events transport (requires rmcp SSE support)
39    Sse {
40        url: String,
41        #[serde(default)]
42        headers: HashMap<String, String>,
43    },
44}
45
/// Reconnection policy with exponential backoff and jitter.
///
/// The delay before attempt `n` is `base_delay_ms * 2^min(n, 10)`, plus random jitter of up
/// to `jitter_factor` times that value, with the total capped at `max_delay_ms`.
#[derive(Clone, Debug)]
pub struct ReconnectPolicy {
    /// Maximum number of reconnection attempts. Not consulted by
    /// [`ReconnectPolicy::delay_for_attempt`]; presumably enforced by the caller's retry loop.
    pub max_retries: u32,
    /// Backoff delay for attempt 0, in milliseconds, before jitter is added.
    pub base_delay_ms: u64,
    /// Upper bound on any returned delay, in milliseconds (applied after jitter).
    pub max_delay_ms: u64,
    /// Fraction of the backoff delay added as random jitter (`0.3` adds up to 30%).
    pub jitter_factor: f64,
}

impl Default for ReconnectPolicy {
    /// 3 retries, 1 s base delay, 30 s cap, up to 30% jitter.
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_delay_ms: 1000,
            max_delay_ms: 30000,
            jitter_factor: 0.3,
        }
    }
}

impl ReconnectPolicy {
    /// Delay to wait before reconnection attempt `attempt` (0-based).
    ///
    /// The exponent is clamped at 10, so the backoff stops doubling after ten attempts. All
    /// integer arithmetic saturates, so an oversized `base_delay_ms` yields `max_delay_ms`
    /// instead of overflowing, which would panic in debug builds.
    pub fn delay_for_attempt(&self, attempt: u32) -> std::time::Duration {
        let base = self
            .base_delay_ms
            .saturating_mul(2u64.pow(attempt.min(10)));
        // Float-to-int `as` casts saturate (NaN/negative -> 0), so a bad factor cannot panic.
        let jitter = (base as f64 * self.jitter_factor * rand_factor()) as u64;
        std::time::Duration::from_millis(base.saturating_add(jitter).min(self.max_delay_ms))
    }
}

/// Pseudo-random value in `[0.0, 1.0)` used to jitter reconnect delays.
///
/// This is not cryptographically secure. The hasher comes from a fresh `RandomState`, which
/// is initialized with random keys. The current time, thread id and process id are mixed in
/// as well, so independent callers are unlikely to pick identical delays.
fn rand_factor() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::time::SystemTime;

    let mut hasher = RandomState::new().build_hasher();
    SystemTime::now().hash(&mut hasher);
    std::thread::current().id().hash(&mut hasher);
    std::process::id().hash(&mut hasher);

    let hash = hasher.finish();
    (hash % 10000) as f64 / 10000.0
}
87
88/// Parse MCP qualified name (mcp__server_tool) into (server, tool)
89pub fn parse_mcp_name(name: &str) -> Option<(&str, &str)> {
90    name.strip_prefix(MCP_TOOL_PREFIX)?.split_once('_')
91}
92
93/// Create MCP qualified name from server and tool names
94pub fn make_mcp_name(server: &str, tool: &str) -> String {
95    format!("{}{server}_{tool}", MCP_TOOL_PREFIX)
96}
97
98/// Check if a name matches MCP naming pattern
99pub fn is_mcp_name(name: &str) -> bool {
100    name.starts_with(MCP_TOOL_PREFIX)
101}
102
/// Lifecycle state of the connection to an MCP server.
///
/// Defaults to [`McpConnectionStatus::Connecting`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum McpConnectionStatus {
    /// A connection attempt is in progress (the default state).
    #[default]
    Connecting,
    /// The connection is established.
    Connected,
    /// The connection has been closed.
    Disconnected,
    /// The connection attempt failed.
    Failed,
    /// The server requires authentication before it can be used.
    NeedsAuth,
}
118
119/// MCP server information returned during initialization
120#[derive(Clone, Debug, Default, Serialize, Deserialize)]
121#[serde(rename_all = "camelCase")]
122pub struct McpServerInfo {
123    /// Server name
124    pub name: String,
125    /// Server version
126    pub version: String,
127    /// Protocol version (e.g., "2025-06-18")
128    #[serde(default)]
129    pub protocol_version: String,
130}
131
132/// MCP tool definition from server
133#[derive(Clone, Debug, Serialize, Deserialize)]
134#[serde(rename_all = "camelCase")]
135pub struct McpToolDefinition {
136    /// Tool name
137    pub name: String,
138    /// Tool description
139    #[serde(default)]
140    pub description: String,
141    /// Input schema (JSON Schema)
142    #[serde(default)]
143    pub input_schema: serde_json::Value,
144}
145
146/// MCP resource definition
147#[derive(Clone, Debug, Serialize, Deserialize)]
148#[serde(rename_all = "camelCase")]
149pub struct McpResourceDefinition {
150    /// Resource URI
151    pub uri: String,
152    /// Resource name
153    pub name: String,
154    /// Resource description
155    #[serde(default)]
156    pub description: Option<String>,
157    /// MIME type of the resource
158    #[serde(default)]
159    pub mime_type: Option<String>,
160}
161
162/// MCP server state tracking
163#[derive(Clone, Debug)]
164pub struct McpServerState {
165    /// Server name (unique identifier)
166    pub name: String,
167    /// Server configuration
168    pub config: McpServerConfig,
169    /// Connection status
170    pub status: McpConnectionStatus,
171    /// Server info (available after connection)
172    pub server_info: Option<McpServerInfo>,
173    /// Available tools
174    pub tools: Vec<McpToolDefinition>,
175    /// Available resources
176    pub resources: Vec<McpResourceDefinition>,
177}
178
179impl McpServerState {
180    /// Create a new server state with the given name and config
181    pub fn new(name: impl Into<String>, config: McpServerConfig) -> Self {
182        Self {
183            name: name.into(),
184            config,
185            status: McpConnectionStatus::Connecting,
186            server_info: None,
187            tools: Vec::new(),
188            resources: Vec::new(),
189        }
190    }
191
192    /// Check if the server is connected
193    pub fn is_connected(&self) -> bool {
194        self.status == McpConnectionStatus::Connected
195    }
196}
197
198/// MCP error types
199#[derive(Debug, thiserror::Error)]
200pub enum McpError {
201    /// Connection to server failed
202    #[error("Connection failed: {message}")]
203    ConnectionFailed {
204        /// Error message
205        message: String,
206    },
207
208    /// Protocol error (invalid messages, etc.)
209    #[error("Protocol error: {message}")]
210    Protocol {
211        /// Error message
212        message: String,
213    },
214
215    /// JSON-RPC error from server
216    #[error("JSON-RPC error {code}: {message}")]
217    JsonRpc {
218        /// Error code
219        code: i32,
220        /// Error message
221        message: String,
222    },
223
224    /// Tool execution error
225    #[error("Tool error: {message}")]
226    ToolError {
227        /// Error message
228        message: String,
229    },
230
231    /// Version mismatch between client and server
232    #[error("Version mismatch: server supports {supported:?}, client requested {requested}")]
233    VersionMismatch {
234        /// Versions supported by server
235        supported: Vec<String>,
236        /// Version requested by client
237        requested: String,
238    },
239
240    /// Server not found
241    #[error("Server not found: {name}")]
242    ServerNotFound {
243        /// Server name
244        name: String,
245    },
246
247    /// Tool not found
248    #[error("Tool not found: {name}")]
249    ToolNotFound {
250        /// Tool name
251        name: String,
252    },
253
254    /// Resource not found
255    #[error("Resource not found: {uri}")]
256    ResourceNotFound {
257        /// Resource URI
258        uri: String,
259    },
260
261    /// IO error
262    #[error("IO error: {0}")]
263    Io(#[from] std::io::Error),
264
265    /// JSON serialization error
266    #[error("JSON error: {0}")]
267    Json(#[from] serde_json::Error),
268}
269
270/// Result type for MCP operations
271pub type McpResult<T> = std::result::Result<T, McpError>;
272
273/// MCP tool call result
274#[derive(Clone, Debug, Serialize, Deserialize)]
275pub struct McpToolResult {
276    /// Result content
277    pub content: Vec<McpContent>,
278    /// Whether the call resulted in an error
279    #[serde(default)]
280    pub is_error: bool,
281}
282
283/// MCP content types (returned from tool calls and resources)
284#[derive(Clone, Debug, Serialize, Deserialize)]
285#[serde(tag = "type", rename_all = "lowercase")]
286pub enum McpContent {
287    /// Text content
288    Text {
289        /// Text value
290        text: String,
291    },
292    /// Image content (base64 encoded)
293    Image {
294        /// Base64-encoded image data
295        data: String,
296        /// MIME type (e.g., "image/png")
297        mime_type: String,
298    },
299    /// Resource reference
300    Resource {
301        /// Resource URI
302        uri: String,
303        /// Resource text content (if available)
304        #[serde(default)]
305        text: Option<String>,
306        /// Resource blob content (if available, base64)
307        #[serde(default)]
308        blob: Option<String>,
309        /// MIME type
310        #[serde(default)]
311        mime_type: Option<String>,
312    },
313}
314
315impl McpContent {
316    /// Get text content if this is a text content type
317    pub fn as_text(&self) -> Option<&str> {
318        match self {
319            McpContent::Text { text } => Some(text),
320            _ => None,
321        }
322    }
323}
324
325impl McpToolResult {
326    /// Convert to a string representation
327    pub fn to_string_content(&self) -> String {
328        self.content
329            .iter()
330            .filter_map(|c| c.as_text())
331            .collect::<Vec<_>>()
332            .join("\n")
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a stdio config with an empty environment.
    fn stdio_config(command: &str, args: &[&str]) -> McpServerConfig {
        McpServerConfig::Stdio {
            command: command.to_owned(),
            args: args.iter().map(|&a| a.to_owned()).collect(),
            env: HashMap::new(),
        }
    }

    #[test]
    fn test_parse_mcp_name() {
        let cases: [(&str, Option<(&str, &str)>); 4] = [
            ("mcp__server_tool", Some(("server", "tool"))),
            ("mcp__fs_read_file", Some(("fs", "read_file"))),
            ("Read", None),
            ("mcp_invalid", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(parse_mcp_name(input), expected, "parse_mcp_name({input:?})");
        }
    }

    #[test]
    fn test_make_mcp_name() {
        let cases = [
            ("server", "tool", "mcp__server_tool"),
            ("fs", "read_file", "mcp__fs_read_file"),
        ];
        for &(server, tool, expected) in &cases {
            assert_eq!(make_mcp_name(server, tool), expected);
        }
    }

    #[test]
    fn test_is_mcp_name() {
        assert!(is_mcp_name("mcp__server_tool"));
        for &name in &["Read", "mcp_invalid"] {
            assert!(!is_mcp_name(name), "{name:?} should not be an MCP name");
        }
    }

    #[test]
    fn test_reconnect_policy_delay() {
        let policy = ReconnectPolicy::default();
        let first = policy.delay_for_attempt(0);
        let second = policy.delay_for_attempt(1);
        assert!(second > first);
        assert!((1000..=1300).contains(&first.as_millis()));
    }

    #[test]
    fn test_mcp_server_config_serde() {
        let json = serde_json::to_string(&stdio_config("npx", &["server"])).unwrap();
        for &needle in &["stdio", "npx"] {
            assert!(json.contains(needle), "missing {needle:?} in {json}");
        }
    }

    #[test]
    fn test_mcp_server_state_new() {
        let state = McpServerState::new("test", stdio_config("test", &[]));

        assert_eq!(state.name, "test");
        assert_eq!(state.status, McpConnectionStatus::Connecting);
        assert!(!state.is_connected());
    }

    #[test]
    fn test_mcp_content_as_text() {
        let text = McpContent::Text {
            text: "hello".to_owned(),
        };
        let image = McpContent::Image {
            data: "base64".to_owned(),
            mime_type: "image/png".to_owned(),
        };

        assert_eq!(text.as_text(), Some("hello"));
        assert_eq!(image.as_text(), None);
    }
}