Skip to main content

cersei_mcp/
lib.rs

//! cersei-mcp: Model Context Protocol (MCP) client.
//!
//! Implements JSON-RPC 2.0 messaging over a stdio transport for connecting to
//! MCP servers (SSE servers can be configured but are not supported yet).
//! Discovered tools are exposed as standard Cersei tool definitions, and
//! discovered resources can be listed and read through the client.
6
7pub mod jsonrpc;
8pub mod transport;
9
10use cersei_types::*;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
16// ─── MCP server config ──────────────────────────────────────────────────────
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct McpServerConfig {
20    pub name: String,
21    pub command: Option<String>,
22    #[serde(default)]
23    pub args: Vec<String>,
24    #[serde(default)]
25    pub env: HashMap<String, String>,
26    pub url: Option<String>,
27    #[serde(rename = "type", default = "default_type")]
28    pub server_type: String,
29}
30
31fn default_type() -> String {
32    "stdio".to_string()
33}
34
35impl McpServerConfig {
36    pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: &[&str]) -> Self {
37        Self {
38            name: name.into(),
39            command: Some(command.into()),
40            args: args.iter().map(|s| s.to_string()).collect(),
41            env: HashMap::new(),
42            url: None,
43            server_type: "stdio".to_string(),
44        }
45    }
46
47    pub fn sse(name: impl Into<String>, url: impl Into<String>) -> Self {
48        Self {
49            name: name.into(),
50            command: None,
51            args: Vec::new(),
52            env: HashMap::new(),
53            url: Some(url.into()),
54            server_type: "sse".to_string(),
55        }
56    }
57}
58
59// ─── MCP protocol types ─────────────────────────────────────────────────────
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
62#[serde(rename_all = "camelCase")]
63pub struct McpToolDef {
64    pub name: String,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub description: Option<String>,
67    pub input_schema: serde_json::Value,
68}
69
70impl From<&McpToolDef> for ToolDefinition {
71    fn from(t: &McpToolDef) -> Self {
72        ToolDefinition {
73            name: t.name.clone(),
74            description: t.description.clone().unwrap_or_default(),
75            input_schema: t.input_schema.clone(),
76        }
77    }
78}
79
/// A resource advertised by an MCP server, as returned in `resources/list` results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpResource {
    /// Resource URI; `McpManager::read_resource` routes requests by exact match on it.
    pub uri: String,
    /// Resource name reported by the server (required when deserializing).
    pub name: String,
    /// Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional MIME type (`mimeType` on the wire).
    #[serde(skip_serializing_if = "Option::is_none", rename = "mimeType")]
    pub mime_type: Option<String>,
}

/// One entry of the `content` array in a `tools/call` result.
///
/// Internally tagged by the JSON `"type"` field, whose value is the lowercase
/// variant name: `"text"`, `"image"` or `"resource"`.
///
/// NOTE(review): any other `"type"` value, or an embedded resource object that
/// lacks a `name` field, fails to deserialize into this enum — confirm this covers
/// the content types your MCP servers emit.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum McpContent {
    /// Plain text; the only variant whose payload `McpClient::call_tool` returns.
    Text {
        text: String,
    },
    /// Inline image data with its MIME type (`mimeType` on the wire).
    /// `data` is presumably base64-encoded — TODO confirm.
    Image {
        data: String,
        #[serde(rename = "mimeType")]
        mime_type: String,
    },
    /// A resource embedded in the result.
    Resource {
        resource: McpResource,
    },
}
105
106// ─── Server status ───────────────────────────────────────────────────────────
107
/// Connection state of an MCP server, as reported by `McpManager::server_statuses`.
#[derive(Debug, Clone, PartialEq)]
pub enum McpServerStatus {
    /// Connection not yet established. NOTE(review): not assigned anywhere in this file.
    Connecting,
    /// Handshake completed; `McpClient::connect` sets this on success.
    Connected,
    /// Failure, carrying a message. NOTE(review): not assigned anywhere in this file.
    Error(String),
    /// Not connected. NOTE(review): not assigned anywhere in this file.
    Disconnected,
}
115
116// ─── MCP client (per-server) ─────────────────────────────────────────────────
117
118/// A client connected to a single MCP server.
119pub struct McpClient {
120    pub config: McpServerConfig,
121    pub status: McpServerStatus,
122    pub tools: Vec<McpToolDef>,
123    pub resources: Vec<McpResource>,
124    transport: Option<transport::StdioTransport>,
125}
126
127impl McpClient {
128    /// Connect to an MCP server and perform the handshake.
129    pub async fn connect(config: McpServerConfig) -> Result<Self> {
130        let config_expanded = expand_server_config(&config);
131
132        if config_expanded.server_type == "stdio" {
133            let command = config_expanded
134                .command
135                .as_deref()
136                .ok_or_else(|| CerseiError::Mcp("stdio server requires 'command'".into()))?;
137
138            let mut transport = transport::StdioTransport::spawn(
139                command,
140                &config_expanded.args,
141                &config_expanded.env,
142            )
143            .await?;
144
145            // Initialize handshake
146            let init_params = serde_json::json!({
147                "protocolVersion": "2024-11-05",
148                "capabilities": {
149                    "roots": { "listChanged": true }
150                },
151                "clientInfo": {
152                    "name": "cersei",
153                    "version": env!("CARGO_PKG_VERSION")
154                }
155            });
156
157            let init_result = transport.request("initialize", Some(init_params)).await?;
158            tracing::debug!("MCP initialize result: {:?}", init_result);
159
160            // Send initialized notification
161            transport.notify("notifications/initialized", None).await?;
162
163            // Discover tools
164            let tools_result = transport.request("tools/list", None).await?;
165            let tools: Vec<McpToolDef> = tools_result
166                .get("tools")
167                .and_then(|t| serde_json::from_value(t.clone()).ok())
168                .unwrap_or_default();
169
170            // Discover resources
171            let resources = match transport.request("resources/list", None).await {
172                Ok(res) => res
173                    .get("resources")
174                    .and_then(|r| serde_json::from_value(r.clone()).ok())
175                    .unwrap_or_default(),
176                Err(_) => Vec::new(), // resources are optional
177            };
178
179            tracing::info!(
180                server = %config.name,
181                tools = tools.len(),
182                resources = resources.len(),
183                "MCP server connected"
184            );
185
186            Ok(Self {
187                config,
188                status: McpServerStatus::Connected,
189                tools,
190                resources,
191                transport: Some(transport),
192            })
193        } else {
194            // SSE transport placeholder
195            Err(CerseiError::Mcp(format!(
196                "SSE transport not yet implemented for server '{}'",
197                config.name
198            )))
199        }
200    }
201
202    /// Call a tool on this MCP server.
203    pub async fn call_tool(
204        &mut self,
205        tool_name: &str,
206        arguments: Option<serde_json::Value>,
207    ) -> Result<String> {
208        let transport = self
209            .transport
210            .as_mut()
211            .ok_or_else(|| CerseiError::Mcp("Not connected".into()))?;
212
213        let params = serde_json::json!({
214            "name": tool_name,
215            "arguments": arguments.unwrap_or(serde_json::Value::Object(Default::default())),
216        });
217
218        let result = transport.request("tools/call", Some(params)).await?;
219
220        // Parse content array
221        let content: Vec<McpContent> = result
222            .get("content")
223            .and_then(|c| serde_json::from_value(c.clone()).ok())
224            .unwrap_or_default();
225
226        let is_error = result
227            .get("isError")
228            .and_then(|v| v.as_bool())
229            .unwrap_or(false);
230
231        let text: String = content
232            .iter()
233            .filter_map(|c| match c {
234                McpContent::Text { text } => Some(text.as_str()),
235                _ => None,
236            })
237            .collect::<Vec<_>>()
238            .join("\n");
239
240        if is_error {
241            Err(CerseiError::Mcp(text))
242        } else {
243            Ok(text)
244        }
245    }
246
247    /// Read a resource from this MCP server.
248    pub async fn read_resource(&mut self, uri: &str) -> Result<String> {
249        let transport = self
250            .transport
251            .as_mut()
252            .ok_or_else(|| CerseiError::Mcp("Not connected".into()))?;
253
254        let params = serde_json::json!({ "uri": uri });
255        let result = transport.request("resources/read", Some(params)).await?;
256
257        let contents = result
258            .get("contents")
259            .and_then(|c| c.as_array())
260            .map(|arr| {
261                arr.iter()
262                    .filter_map(|item| item.get("text").and_then(|t| t.as_str()))
263                    .collect::<Vec<_>>()
264                    .join("\n")
265            })
266            .unwrap_or_default();
267
268        Ok(contents)
269    }
270
271    /// Get tool definitions for the provider.
272    pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
273        self.tools.iter().map(ToolDefinition::from).collect()
274    }
275}
276
277// ─── MCP manager (multi-server) ──────────────────────────────────────────────
278
279/// Manages connections to multiple MCP servers.
280pub struct McpManager {
281    clients: Arc<Mutex<HashMap<String, McpClient>>>,
282}
283
284impl McpManager {
285    /// Connect to all configured MCP servers.
286    pub async fn connect(configs: &[McpServerConfig]) -> Result<Self> {
287        let mut clients = HashMap::new();
288
289        for config in configs {
290            match McpClient::connect(config.clone()).await {
291                Ok(client) => {
292                    clients.insert(config.name.clone(), client);
293                }
294                Err(e) => {
295                    tracing::warn!(server = %config.name, error = %e, "Failed to connect MCP server");
296                }
297            }
298        }
299
300        Ok(Self {
301            clients: Arc::new(Mutex::new(clients)),
302        })
303    }
304
305    /// Get all discovered tool definitions across all servers.
306    pub async fn tool_definitions(&self) -> Vec<ToolDefinition> {
307        let clients = self.clients.lock().await;
308        clients
309            .values()
310            .flat_map(|c| c.tool_definitions())
311            .collect()
312    }
313
314    /// Call a tool by name (routes to the correct server).
315    pub async fn call_tool(
316        &self,
317        tool_name: &str,
318        arguments: Option<serde_json::Value>,
319    ) -> Result<String> {
320        let mut clients = self.clients.lock().await;
321
322        for client in clients.values_mut() {
323            if client.tools.iter().any(|t| t.name == tool_name) {
324                return client.call_tool(tool_name, arguments).await;
325            }
326        }
327
328        Err(CerseiError::Mcp(format!(
329            "No MCP server has tool '{}'",
330            tool_name
331        )))
332    }
333
334    /// List all resources across all servers.
335    pub async fn list_resources(&self) -> Vec<McpResource> {
336        let clients = self.clients.lock().await;
337        clients.values().flat_map(|c| c.resources.clone()).collect()
338    }
339
340    /// Read a resource by URI (routes to the correct server).
341    pub async fn read_resource(&self, uri: &str) -> Result<String> {
342        let mut clients = self.clients.lock().await;
343
344        for client in clients.values_mut() {
345            if client.resources.iter().any(|r| r.uri == uri) {
346                return client.read_resource(uri).await;
347            }
348        }
349
350        Err(CerseiError::Mcp(format!(
351            "No MCP server has resource '{}'",
352            uri
353        )))
354    }
355
356    /// Get the status of all connected servers.
357    pub async fn server_statuses(&self) -> HashMap<String, McpServerStatus> {
358        let clients = self.clients.lock().await;
359        clients
360            .iter()
361            .map(|(name, client)| (name.clone(), client.status.clone()))
362            .collect()
363    }
364
365    /// Get server configs.
366    pub async fn configs(&self) -> Vec<McpServerConfig> {
367        let clients = self.clients.lock().await;
368        clients.values().map(|c| c.config.clone()).collect()
369    }
370}
371
372// ─── Env var expansion ───────────────────────────────────────────────────────
373
/// Expands `${VAR}` and `${VAR:-default}` placeholders in `input`.
///
/// - `${VAR}` is replaced by the value of environment variable `VAR`. If `VAR` is
///   unset (or its value is not valid Unicode), the placeholder is left untouched.
/// - `${VAR:-default}` follows POSIX shell semantics: `default` is used when `VAR`
///   is unset **or empty** (or not valid Unicode).
///
/// Substituted values are not re-scanned, so a value that itself contains `${...}`
/// is inserted literally. A `${` with no closing `}` ends expansion, and the
/// remainder is copied verbatim. The first `}` closes a placeholder, so defaults
/// cannot contain `}`.
///
/// Works in a single left-to-right pass, in time linear in the input length.
pub fn expand_env_vars(input: &str) -> String {
    let mut expanded = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(open) = rest.find("${") {
        let body_start = open + 2;
        let close = match rest[body_start..].find('}') {
            Some(offset) => body_start + offset,
            // No closing brace anywhere after this point: nothing more can expand.
            None => break,
        };

        let inner = &rest[body_start..close];
        let (var_name, default_value) = match inner.find(":-") {
            Some(pos) => (&inner[..pos], Some(&inner[pos + 2..])),
            None => (inner, None),
        };

        let value = std::env::var(var_name).ok();
        let replacement: Option<&str> = match (value.as_deref(), default_value) {
            // `:-` treats a set-but-empty variable like an unset one.
            (Some(v), Some(default)) if v.is_empty() => Some(default),
            (Some(v), _) => Some(v),
            (None, default) => default,
        };

        expanded.push_str(&rest[..open]);
        match replacement {
            Some(text) => expanded.push_str(text),
            // Unset with no default: keep the placeholder verbatim.
            None => expanded.push_str(&rest[open..=close]),
        }
        rest = &rest[close + 1..];
    }

    expanded.push_str(rest);
    expanded
}
415
416/// Expand env vars in all string fields of a server config.
417pub fn expand_server_config(config: &McpServerConfig) -> McpServerConfig {
418    McpServerConfig {
419        name: config.name.clone(),
420        command: config.command.as_deref().map(expand_env_vars),
421        args: config.args.iter().map(|a| expand_env_vars(a)).collect(),
422        env: config
423            .env
424            .iter()
425            .map(|(k, v)| (k.clone(), expand_env_vars(v)))
426            .collect(),
427        url: config.url.as_deref().map(expand_env_vars),
428        server_type: config.server_type.clone(),
429    }
430}
431
432// ─── Tests ───────────────────────────────────────────────────────────────────
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets an environment variable for the lifetime of the guard and removes it on
    /// drop, so cleanup happens even when an assertion in the test panics.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            std::env::set_var(key, value);
            EnvVarGuard(key)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.0);
        }
    }

    #[test]
    fn test_expand_env_vars_simple() {
        let _var = EnvVarGuard::set("CERSEI_TEST_VAR", "hello");
        assert_eq!(expand_env_vars("${CERSEI_TEST_VAR}"), "hello");
    }

    #[test]
    fn test_expand_env_vars_default() {
        assert_eq!(expand_env_vars("${NONEXISTENT_VAR:-fallback}"), "fallback");
    }

    #[test]
    fn test_expand_env_vars_missing_no_default() {
        let result = expand_env_vars("${CERSEI_MISSING_XYZ}");
        assert_eq!(result, "${CERSEI_MISSING_XYZ}"); // left as-is
    }

    #[test]
    fn test_expand_env_vars_multiple() {
        // Bind to named guards: `let _ = ...` would drop (and unset) immediately.
        let _a = EnvVarGuard::set("CERSEI_A", "one");
        let _b = EnvVarGuard::set("CERSEI_B", "two");
        assert_eq!(expand_env_vars("${CERSEI_A}-${CERSEI_B}"), "one-two");
    }

    #[test]
    fn test_stdio_config() {
        let config = McpServerConfig::stdio("test", "node", &["server.js"]);
        assert_eq!(config.server_type, "stdio");
        assert_eq!(config.command.as_deref(), Some("node"));
        assert_eq!(config.args, vec!["server.js"]);
    }

    #[test]
    fn test_sse_config() {
        let config = McpServerConfig::sse("remote", "https://mcp.example.com");
        assert_eq!(config.server_type, "sse");
        assert_eq!(config.url.as_deref(), Some("https://mcp.example.com"));
    }

    #[test]
    fn test_tool_def_conversion() {
        let mcp_tool = McpToolDef {
            name: "search".into(),
            description: Some("Search docs".into()),
            input_schema: serde_json::json!({"type": "object"}),
        };
        let tool_def: ToolDefinition = ToolDefinition::from(&mcp_tool);
        assert_eq!(tool_def.name, "search");
        assert_eq!(tool_def.description, "Search docs");
    }

    #[test]
    fn test_expand_server_config() {
        let _cmd = EnvVarGuard::set("CERSEI_MCP_CMD", "/usr/bin/node");
        let config = McpServerConfig {
            name: "test".into(),
            command: Some("${CERSEI_MCP_CMD}".into()),
            args: vec!["${CERSEI_MCP_CMD}".into()],
            env: HashMap::from([("KEY".into(), "${CERSEI_MCP_CMD}".into())]),
            url: None,
            server_type: "stdio".into(),
        };
        let expanded = expand_server_config(&config);
        assert_eq!(expanded.command.as_deref(), Some("/usr/bin/node"));
        assert_eq!(expanded.args[0], "/usr/bin/node");
        assert_eq!(expanded.env["KEY"], "/usr/bin/node");
    }
}