1pub mod jsonrpc;
8pub mod transport;
9
10use cersei_types::*;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
/// Configuration describing how to reach a single MCP server.
///
/// When deserializing, `args` and `env` fall back to empty collections and
/// `server_type` falls back to `"stdio"` if the corresponding keys are absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
    /// Server name; used as the key in `McpManager`'s client map and in log output.
    pub name: String,
    /// Command to launch; `McpClient::connect` requires it for `"stdio"` servers.
    pub command: Option<String>,
    /// Arguments passed along with `command` when the transport is spawned.
    #[serde(default)]
    pub args: Vec<String>,
    /// Environment entries passed to the transport when it is spawned.
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Server URL; populated by the `sse` constructor.
    pub url: Option<String>,
    /// Transport kind, (de)serialized under the key `type`; defaults to `"stdio"`.
    #[serde(rename = "type", default = "default_type")]
    pub server_type: String,
}
30
/// Serde default for `McpServerConfig::server_type`: the `"stdio"` transport.
fn default_type() -> String {
    String::from("stdio")
}
34
35impl McpServerConfig {
36 pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: &[&str]) -> Self {
37 Self {
38 name: name.into(),
39 command: Some(command.into()),
40 args: args.iter().map(|s| s.to_string()).collect(),
41 env: HashMap::new(),
42 url: None,
43 server_type: "stdio".to_string(),
44 }
45 }
46
47 pub fn sse(name: impl Into<String>, url: impl Into<String>) -> Self {
48 Self {
49 name: name.into(),
50 command: None,
51 args: Vec::new(),
52 env: HashMap::new(),
53 url: Some(url.into()),
54 server_type: "sse".to_string(),
55 }
56 }
57}
58
/// A tool definition as listed by an MCP server.
///
/// Field names are (de)serialized in camelCase, so `input_schema` appears on
/// the wire as `inputSchema`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpToolDef {
    /// Tool name; `McpManager::call_tool` matches on it to pick a server.
    pub name: String,
    /// Optional description; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Schema for the tool's arguments, kept as raw JSON.
    pub input_schema: serde_json::Value,
}
69
70impl From<&McpToolDef> for ToolDefinition {
71 fn from(t: &McpToolDef) -> Self {
72 ToolDefinition {
73 name: t.name.clone(),
74 description: t.description.clone().unwrap_or_default(),
75 input_schema: t.input_schema.clone(),
76 }
77 }
78}
79
/// A resource entry as listed by an MCP server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpResource {
    /// Resource URI; `McpManager::read_resource` matches on it to pick a server.
    pub uri: String,
    /// Resource name.
    pub name: String,
    /// Optional description; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional MIME type, (de)serialized as `mimeType` and omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "mimeType")]
    pub mime_type: Option<String>,
}
89
/// One content item of an MCP result, discriminated by its `type` field.
///
/// The tag value is the lowercase variant name: `"text"`, `"image"`, or
/// `"resource"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum McpContent {
    /// Textual content.
    Text {
        text: String,
    },
    /// Image content.
    Image {
        // Presumably base64-encoded image bytes — TODO confirm against the MCP spec.
        data: String,
        /// MIME type of `data`, (de)serialized as `mimeType`.
        #[serde(rename = "mimeType")]
        mime_type: String,
    },
    /// A resource carried inside the content list.
    Resource {
        // NOTE(review): `McpResource` requires both `uri` and `name`; confirm that
        // servers always send `name` here, otherwise this variant fails to parse.
        resource: McpResource,
    },
}
105
/// Connection state recorded for an MCP server.
#[derive(Debug, Clone, PartialEq)]
pub enum McpServerStatus {
    /// A connection attempt is in progress.
    Connecting,
    /// The server is connected; `McpClient::connect` sets this on success.
    Connected,
    /// The connection failed; carries a message describing the failure.
    Error(String),
    /// The server is not connected.
    Disconnected,
}
115
/// A client for a single MCP server, holding what the server advertised when
/// the connection was established.
pub struct McpClient {
    /// The configuration this client was created from (as passed to `connect`).
    pub config: McpServerConfig,
    /// Current connection status.
    pub status: McpServerStatus,
    /// Tools returned by the server's `tools/list` response.
    pub tools: Vec<McpToolDef>,
    /// Resources returned by the server's `resources/list` response.
    pub resources: Vec<McpResource>,
    // When `None`, `call_tool` and `read_resource` fail with "Not connected".
    transport: Option<transport::StdioTransport>,
}
126
127impl McpClient {
128 pub async fn connect(config: McpServerConfig) -> Result<Self> {
130 let config_expanded = expand_server_config(&config);
131
132 if config_expanded.server_type == "stdio" {
133 let command = config_expanded
134 .command
135 .as_deref()
136 .ok_or_else(|| CerseiError::Mcp("stdio server requires 'command'".into()))?;
137
138 let mut transport = transport::StdioTransport::spawn(
139 command,
140 &config_expanded.args,
141 &config_expanded.env,
142 )
143 .await?;
144
145 let init_params = serde_json::json!({
147 "protocolVersion": "2024-11-05",
148 "capabilities": {
149 "roots": { "listChanged": true }
150 },
151 "clientInfo": {
152 "name": "cersei",
153 "version": env!("CARGO_PKG_VERSION")
154 }
155 });
156
157 let init_result = transport.request("initialize", Some(init_params)).await?;
158 tracing::debug!("MCP initialize result: {:?}", init_result);
159
160 transport.notify("notifications/initialized", None).await?;
162
163 let tools_result = transport.request("tools/list", None).await?;
165 let tools: Vec<McpToolDef> = tools_result
166 .get("tools")
167 .and_then(|t| serde_json::from_value(t.clone()).ok())
168 .unwrap_or_default();
169
170 let resources = match transport.request("resources/list", None).await {
172 Ok(res) => res
173 .get("resources")
174 .and_then(|r| serde_json::from_value(r.clone()).ok())
175 .unwrap_or_default(),
176 Err(_) => Vec::new(), };
178
179 tracing::info!(
180 server = %config.name,
181 tools = tools.len(),
182 resources = resources.len(),
183 "MCP server connected"
184 );
185
186 Ok(Self {
187 config,
188 status: McpServerStatus::Connected,
189 tools,
190 resources,
191 transport: Some(transport),
192 })
193 } else {
194 Err(CerseiError::Mcp(format!(
196 "SSE transport not yet implemented for server '{}'",
197 config.name
198 )))
199 }
200 }
201
202 pub async fn call_tool(
204 &mut self,
205 tool_name: &str,
206 arguments: Option<serde_json::Value>,
207 ) -> Result<String> {
208 let transport = self
209 .transport
210 .as_mut()
211 .ok_or_else(|| CerseiError::Mcp("Not connected".into()))?;
212
213 let params = serde_json::json!({
214 "name": tool_name,
215 "arguments": arguments.unwrap_or(serde_json::Value::Object(Default::default())),
216 });
217
218 let result = transport.request("tools/call", Some(params)).await?;
219
220 let content: Vec<McpContent> = result
222 .get("content")
223 .and_then(|c| serde_json::from_value(c.clone()).ok())
224 .unwrap_or_default();
225
226 let is_error = result
227 .get("isError")
228 .and_then(|v| v.as_bool())
229 .unwrap_or(false);
230
231 let text: String = content
232 .iter()
233 .filter_map(|c| match c {
234 McpContent::Text { text } => Some(text.as_str()),
235 _ => None,
236 })
237 .collect::<Vec<_>>()
238 .join("\n");
239
240 if is_error {
241 Err(CerseiError::Mcp(text))
242 } else {
243 Ok(text)
244 }
245 }
246
247 pub async fn read_resource(&mut self, uri: &str) -> Result<String> {
249 let transport = self
250 .transport
251 .as_mut()
252 .ok_or_else(|| CerseiError::Mcp("Not connected".into()))?;
253
254 let params = serde_json::json!({ "uri": uri });
255 let result = transport.request("resources/read", Some(params)).await?;
256
257 let contents = result
258 .get("contents")
259 .and_then(|c| c.as_array())
260 .map(|arr| {
261 arr.iter()
262 .filter_map(|item| item.get("text").and_then(|t| t.as_str()))
263 .collect::<Vec<_>>()
264 .join("\n")
265 })
266 .unwrap_or_default();
267
268 Ok(contents)
269 }
270
271 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
273 self.tools.iter().map(ToolDefinition::from).collect()
274 }
275}
276
/// Holds the connected MCP clients and routes tool and resource requests to them.
pub struct McpManager {
    // Clients keyed by server name, shared behind an async mutex.
    clients: Arc<Mutex<HashMap<String, McpClient>>>,
}
283
284impl McpManager {
285 pub async fn connect(configs: &[McpServerConfig]) -> Result<Self> {
287 let mut clients = HashMap::new();
288
289 for config in configs {
290 match McpClient::connect(config.clone()).await {
291 Ok(client) => {
292 clients.insert(config.name.clone(), client);
293 }
294 Err(e) => {
295 tracing::warn!(server = %config.name, error = %e, "Failed to connect MCP server");
296 }
297 }
298 }
299
300 Ok(Self {
301 clients: Arc::new(Mutex::new(clients)),
302 })
303 }
304
305 pub async fn tool_definitions(&self) -> Vec<ToolDefinition> {
307 let clients = self.clients.lock().await;
308 clients
309 .values()
310 .flat_map(|c| c.tool_definitions())
311 .collect()
312 }
313
314 pub async fn call_tool(
316 &self,
317 tool_name: &str,
318 arguments: Option<serde_json::Value>,
319 ) -> Result<String> {
320 let mut clients = self.clients.lock().await;
321
322 for client in clients.values_mut() {
323 if client.tools.iter().any(|t| t.name == tool_name) {
324 return client.call_tool(tool_name, arguments).await;
325 }
326 }
327
328 Err(CerseiError::Mcp(format!(
329 "No MCP server has tool '{}'",
330 tool_name
331 )))
332 }
333
334 pub async fn list_resources(&self) -> Vec<McpResource> {
336 let clients = self.clients.lock().await;
337 clients.values().flat_map(|c| c.resources.clone()).collect()
338 }
339
340 pub async fn read_resource(&self, uri: &str) -> Result<String> {
342 let mut clients = self.clients.lock().await;
343
344 for client in clients.values_mut() {
345 if client.resources.iter().any(|r| r.uri == uri) {
346 return client.read_resource(uri).await;
347 }
348 }
349
350 Err(CerseiError::Mcp(format!(
351 "No MCP server has resource '{}'",
352 uri
353 )))
354 }
355
356 pub async fn server_statuses(&self) -> HashMap<String, McpServerStatus> {
358 let clients = self.clients.lock().await;
359 clients
360 .iter()
361 .map(|(name, client)| (name.clone(), client.status.clone()))
362 .collect()
363 }
364
365 pub async fn configs(&self) -> Vec<McpServerConfig> {
367 let clients = self.clients.lock().await;
368 clients.values().map(|c| c.config.clone()).collect()
369 }
370}
371
/// Expands `${VAR}` and `${VAR:-default}` references in `input` using the
/// process environment.
///
/// * `${VAR}` is replaced by the variable's value when `std::env::var`
///   succeeds; otherwise the reference is left in the output unchanged.
/// * `${VAR:-default}` is replaced by the variable's value when it can be
///   read, and by `default` otherwise. A variable that is set to an empty
///   string counts as set, so the default is not used for it.
/// * A `${` with no closing `}` ends the scan; the rest of the input is
///   copied verbatim.
///
/// Substituted text is never scanned again, so a value that itself contains
/// `${...}` is inserted literally.
///
/// The input is processed in a single pass that appends to one output buffer,
/// rather than rebuilding the whole string after every substitution.
pub fn expand_env_vars(input: &str) -> String {
    let mut expanded = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(start) = rest.find("${") {
        // Position of the closing brace, relative to `rest`.
        let end = match rest[start..].find('}') {
            Some(rel_end) => start + rel_end,
            // Unterminated reference: stop and copy the tail as-is below.
            None => break,
        };

        let inner = &rest[start + 2..end];
        let (var_name, default_value) = match inner.split_once(":-") {
            Some((name, default)) => (name, Some(default)),
            None => (inner, None),
        };

        expanded.push_str(&rest[..start]);
        match std::env::var(var_name) {
            Ok(value) => expanded.push_str(&value),
            // No value and no default: keep the original `${...}` text.
            Err(_) => expanded.push_str(default_value.unwrap_or(&rest[start..=end])),
        }

        rest = &rest[end + 1..];
    }

    expanded.push_str(rest);
    expanded
}
415
416pub fn expand_server_config(config: &McpServerConfig) -> McpServerConfig {
418 McpServerConfig {
419 name: config.name.clone(),
420 command: config.command.as_deref().map(expand_env_vars),
421 args: config.args.iter().map(|a| expand_env_vars(a)).collect(),
422 env: config
423 .env
424 .iter()
425 .map(|(k, v)| (k.clone(), expand_env_vars(v)))
426 .collect(),
427 url: config.url.as_deref().map(expand_env_vars),
428 server_type: config.server_type.clone(),
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// A variable that is set replaces its `${NAME}` reference.
    #[test]
    fn test_expand_env_vars_simple() {
        std::env::set_var("CERSEI_TEST_VAR", "hello");
        let expanded = expand_env_vars("${CERSEI_TEST_VAR}");
        assert_eq!(expanded, "hello");
        std::env::remove_var("CERSEI_TEST_VAR");
    }

    /// The `:-` default is used when the variable is not set.
    #[test]
    fn test_expand_env_vars_default() {
        let expanded = expand_env_vars("${NONEXISTENT_VAR:-fallback}");
        assert_eq!(expanded, "fallback");
    }

    /// A missing variable without a default is left in place.
    #[test]
    fn test_expand_env_vars_missing_no_default() {
        let input = "${CERSEI_MISSING_XYZ}";
        assert_eq!(expand_env_vars(input), input);
    }

    /// Several references in one string are all expanded.
    #[test]
    fn test_expand_env_vars_multiple() {
        let vars = [("CERSEI_A", "one"), ("CERSEI_B", "two")];
        for &(key, value) in vars.iter() {
            std::env::set_var(key, value);
        }

        assert_eq!(expand_env_vars("${CERSEI_A}-${CERSEI_B}"), "one-two");

        for &(key, _) in vars.iter() {
            std::env::remove_var(key);
        }
    }

    /// The `stdio` constructor records the command, arguments, and type.
    #[test]
    fn test_stdio_config() {
        let cfg = McpServerConfig::stdio("test", "node", &["server.js"]);
        assert_eq!(cfg.server_type, "stdio");
        assert_eq!(cfg.command.as_deref(), Some("node"));
        assert_eq!(cfg.args, vec!["server.js"]);
    }

    /// The `sse` constructor records the URL and type.
    #[test]
    fn test_sse_config() {
        let cfg = McpServerConfig::sse("remote", "https://mcp.example.com");
        assert_eq!(cfg.server_type, "sse");
        assert_eq!(cfg.url.as_deref(), Some("https://mcp.example.com"));
    }

    /// Converting an `McpToolDef` keeps its name and description.
    #[test]
    fn test_tool_def_conversion() {
        let source = McpToolDef {
            name: "search".into(),
            description: Some("Search docs".into()),
            input_schema: serde_json::json!({"type": "object"}),
        };
        let converted: ToolDefinition = (&source).into();
        assert_eq!(converted.name, "search");
        assert_eq!(converted.description, "Search docs");
    }

    /// Command, arguments, and environment values are all expanded.
    #[test]
    fn test_expand_server_config() {
        std::env::set_var("CERSEI_MCP_CMD", "/usr/bin/node");

        let placeholder = "${CERSEI_MCP_CMD}";
        let mut env = HashMap::new();
        env.insert(String::from("KEY"), String::from(placeholder));
        let cfg = McpServerConfig {
            name: "test".into(),
            command: Some(placeholder.into()),
            args: vec![placeholder.into()],
            env,
            url: None,
            server_type: "stdio".into(),
        };

        let expanded = expand_server_config(&cfg);
        assert_eq!(expanded.command.as_deref(), Some("/usr/bin/node"));
        assert_eq!(expanded.args[0], "/usr/bin/node");
        assert_eq!(expanded.env["KEY"], "/usr/bin/node");

        std::env::remove_var("CERSEI_MCP_CMD");
    }
}