1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::path::PathBuf;
5use std::sync::Arc;
6use tokio::io::{AsyncWriteExt, BufReader};
7use tokio::process::Command as TokioCommand;
8
9use crate::event::RiskLevel;
10use crate::tools::{Tool, ToolCtx, ToolResult};
11
/// Configuration for one MCP server, persisted as an entry in `mcp_servers.json`.
///
/// The stdio transport reads `command`, `args` and `env`; the `url`/`sse` transports read `url`.
/// Every field except `name` may be omitted in the JSON and then takes its `Default` value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServer {
    /// Identifier of the entry; `add_server` replaces an existing entry with the same name.
    pub name: String,
    /// How the server is reached; `Transport::default()` when the key is absent.
    #[serde(default)]
    pub transport: Transport,
    /// Program to spawn; required by the stdio transport.
    #[serde(default)]
    pub command: Option<String>,
    /// Arguments passed to `command`.
    #[serde(default)]
    pub args: Vec<String>,
    /// Endpoint; required by the `url`/`sse` transports.
    #[serde(default)]
    pub url: Option<String>,
    /// Extra environment variables set on the spawned `command`.
    #[serde(default)]
    pub env: std::collections::HashMap<String, String>,
    /// When non-empty, only tools whose names appear here are exposed; empty exposes all.
    #[serde(default)]
    pub allow_tools: Vec<String>,
}
34
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
36pub enum Transport {
37 #[serde(rename = "stdio")]
38 Stdio,
39 #[serde(rename = "sse")]
40 Sse,
41 #[serde(rename = "url")]
42 Url,
43}
44
45impl Default for Transport {
46 fn default() -> Self {
47 Transport::Stdio
48 }
49}
50
51#[derive(Debug, Serialize, Deserialize)]
54struct JsonRpcRequest {
55 jsonrpc: String,
56 id: u64,
57 method: String,
58 #[serde(default)]
59 params: Value,
60}
61
62#[derive(Debug, Deserialize)]
63struct ToolsListResult {
64 tools: Vec<McpToolDef>,
65}
66
67#[derive(Debug, Deserialize)]
68struct McpToolDef {
69 name: String,
70 #[serde(default)]
71 description: String,
72 #[serde(default)]
73 #[serde(rename = "inputSchema")]
74 input_schema: Value,
75}
76
77use tokio::sync::mpsc;
80
81struct McpToolWrapper {
83 tool_def: McpToolDef,
84 backend: McpBackend,
85}
86
87enum McpBackend {
88 Stdio {
90 request_tx: mpsc::Sender<McpRequest>,
91 },
92 Http {
94 url: String,
95 client: reqwest::Client,
96 },
97}
98
99struct McpRequest {
100 tool_name: String,
101 args: Value,
102 response_tx: tokio::sync::oneshot::Sender<anyhow::Result<ToolResult>>,
103}
104
105#[async_trait]
106impl Tool for McpToolWrapper {
107 fn name(&self) -> &str {
108 &self.tool_def.name
109 }
110 fn description(&self) -> &str {
111 &self.tool_def.description
112 }
113 fn schema(&self) -> Value {
114 self.tool_def.input_schema.clone()
115 }
116 fn risk(&self) -> RiskLevel {
117 RiskLevel::Exec
118 }
119
120 async fn call(&self, args: Value, _ctx: &ToolCtx) -> anyhow::Result<ToolResult> {
121 match &self.backend {
122 McpBackend::Stdio { request_tx } => {
123 let (tx, rx) = tokio::sync::oneshot::channel();
124 request_tx
125 .send(McpRequest {
126 tool_name: self.tool_def.name.clone(),
127 args,
128 response_tx: tx,
129 })
130 .await
131 .map_err(|_| anyhow::anyhow!("MCP server process has stopped"))?;
132
133 tokio::time::timeout(std::time::Duration::from_secs(30), rx)
134 .await
135 .map_err(|_| anyhow::anyhow!("MCP tool call timed out"))?
136 .map_err(|_| anyhow::anyhow!("MCP tool call channel closed"))?
137 }
138 McpBackend::Http { url, client } => {
139 static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
140 let id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
141 let body = serde_json::json!({
142 "jsonrpc": "2.0",
143 "id": id,
144 "method": "tools/call",
145 "params": {
146 "name": self.tool_def.name,
147 "arguments": args,
148 }
149 });
150 let resp = tokio::time::timeout(
151 std::time::Duration::from_secs(30),
152 client.post(url).json(&body).send(),
153 )
154 .await
155 .map_err(|_| anyhow::anyhow!("MCP HTTP call timed out"))??;
156 if !resp.status().is_success() {
157 let status = resp.status();
158 let body = resp.text().await.unwrap_or_default();
159 return Ok(ToolResult::error(format!(
160 "MCP HTTP error {}: {}",
161 status, body
162 )));
163 }
164 let value: Value = resp.json().await?;
165 if let Some(err) = value.get("error") {
166 return Ok(ToolResult::error(format!("MCP error: {}", err)));
167 }
168 if let Some(result) = value.get("result") {
169 Ok(ToolResult::text(result.to_string()))
170 } else {
171 Ok(ToolResult::text("(empty MCP response)"))
172 }
173 }
174 }
175 }
176}
177
178#[async_trait]
181pub trait McpClient: Send + Sync {
182 async fn connect(&self, server: &McpServer) -> anyhow::Result<Vec<Arc<dyn Tool>>>;
183 async fn disconnect(&self, server_name: &str) -> anyhow::Result<()>;
184 async fn list_servers(&self) -> Vec<McpServer>;
185}
186
/// File-backed [`McpClient`]; its server list lives in `<config_dir>/mcp_servers.json`.
pub struct BasicMcpClient {
    /// Directory holding `mcp_servers.json` (created on first save).
    config_dir: PathBuf,
}
192
193impl BasicMcpClient {
194 pub fn new(config_dir: PathBuf) -> Self {
195 Self { config_dir }
196 }
197
198 fn servers_file(&self) -> PathBuf {
199 self.config_dir.join("mcp_servers.json")
200 }
201
202 fn load_servers(&self) -> Vec<McpServer> {
203 let path = self.servers_file();
204 if !path.exists() {
205 return vec![];
206 }
207 std::fs::read_to_string(&path)
208 .ok()
209 .and_then(|s| serde_json::from_str(&s).ok())
210 .unwrap_or_default()
211 }
212
213 fn save_servers(&self, servers: &[McpServer]) -> anyhow::Result<()> {
214 std::fs::create_dir_all(&self.config_dir)?;
215 let json = serde_json::to_string_pretty(servers)?;
216 std::fs::write(self.servers_file(), json)?;
217 Ok(())
218 }
219
220 pub fn add_server(&self, server: McpServer) -> anyhow::Result<()> {
221 let mut servers = self.load_servers();
222 servers.retain(|s| s.name != server.name);
223 servers.push(server);
224 self.save_servers(&servers)
225 }
226
227 pub fn remove_server(&self, name: &str) -> anyhow::Result<()> {
228 let mut servers = self.load_servers();
229 servers.retain(|s| s.name != name);
230 self.save_servers(&servers)
231 }
232
233 pub fn get_server(&self, name: &str) -> Option<McpServer> {
234 self.load_servers().into_iter().find(|s| s.name == name)
235 }
236}
237
238#[async_trait]
239impl McpClient for BasicMcpClient {
240 async fn connect(&self, server: &McpServer) -> anyhow::Result<Vec<Arc<dyn Tool>>> {
241 match server.transport {
242 Transport::Stdio => self.connect_stdio(server).await,
243 Transport::Url | Transport::Sse => self.connect_http(server).await,
244 }
245 }
246
247 async fn disconnect(&self, _server_name: &str) -> anyhow::Result<()> {
248 Ok(())
251 }
252
253 async fn list_servers(&self) -> Vec<McpServer> {
254 self.load_servers()
255 }
256}
257
258impl BasicMcpClient {
259 async fn connect_stdio(&self, server: &McpServer) -> anyhow::Result<Vec<Arc<dyn Tool>>> {
261 let command = server
262 .command
263 .as_ref()
264 .ok_or_else(|| anyhow::anyhow!("stdio transport requires 'command'"))?;
265
266 let mut child = TokioCommand::new(command)
267 .args(&server.args)
268 .envs(&server.env)
269 .stdin(std::process::Stdio::piped())
270 .stdout(std::process::Stdio::piped())
271 .stderr(std::process::Stdio::piped())
272 .kill_on_drop(true)
273 .spawn()?;
274
275 let stdin = child.stdin.take().unwrap();
276 let stdout = child.stdout.take().unwrap();
277
278 let (mut writer, mut reader) = (tokio::io::BufWriter::new(stdin), BufReader::new(stdout));
279
280 let init_req = JsonRpcRequest {
282 jsonrpc: "2.0".into(),
283 id: 1,
284 method: "initialize".into(),
285 params: serde_json::json!({
286 "protocolVersion": "2024-11-05",
287 "capabilities": {},
288 "clientInfo": {
289 "name": "sparrow",
290 "version": "0.1.0"
291 }
292 }),
293 };
294
295 let req_json = serde_json::to_string(&init_req)? + "\n";
296 writer.write_all(req_json.as_bytes()).await?;
297 writer.flush().await?;
298
299 let _ = read_jsonrpc_response(&mut reader, 1).await?;
301
302 let notif = serde_json::json!({
304 "jsonrpc": "2.0",
305 "method": "notifications/initialized",
306 "params": {}
307 });
308 writer
309 .write_all((serde_json::to_string(¬if)? + "\n").as_bytes())
310 .await?;
311 writer.flush().await?;
312
313 let list_req = JsonRpcRequest {
315 jsonrpc: "2.0".into(),
316 id: 2,
317 method: "tools/list".into(),
318 params: Value::Null,
319 };
320
321 writer
322 .write_all((serde_json::to_string(&list_req)? + "\n").as_bytes())
323 .await?;
324 writer.flush().await?;
325
326 let tools_resp_value = read_jsonrpc_response(&mut reader, 2).await?;
327
328 let (request_tx, mut request_rx) = mpsc::channel::<McpRequest>(32);
330
331 tokio::spawn(async move {
332 let _child_guard = child;
335 let mut call_id: u64 = 3; while let Some(req) = request_rx.recv().await {
337 call_id += 1;
338 let call_req = serde_json::json!({
339 "jsonrpc": "2.0",
340 "id": call_id,
341 "method": "tools/call",
342 "params": {
343 "name": req.tool_name,
344 "arguments": req.args,
345 }
346 });
347 if writer
348 .write_all((serde_json::to_string(&call_req).unwrap() + "\n").as_bytes())
349 .await
350 .is_err()
351 || writer.flush().await.is_err()
352 {
353 let _ = req
354 .response_tx
355 .send(Err(anyhow::anyhow!("MCP stdin closed")));
356 break;
357 }
358
359 match read_jsonrpc_response(&mut reader, call_id).await {
361 Ok(value) => {
362 let result = if let Some(err) = value.get("error") {
363 Ok(ToolResult::error(format!("MCP error: {}", err)))
364 } else if let Some(val) = value.get("result") {
365 Ok(ToolResult::text(val.to_string()))
366 } else {
367 Ok(ToolResult::text("(empty MCP response)"))
368 };
369 let _ = req.response_tx.send(result);
370 }
371 Err(e) => {
372 let _ = req
373 .response_tx
374 .send(Err(anyhow::anyhow!("MCP read error: {}", e)));
375 break;
376 }
377 }
378 }
379 });
380
381 let server_name = server.name.clone();
383 let allow_list = server.allow_tools.clone();
384
385 let tools: Vec<Arc<dyn Tool>> = if let Some(result) = tools_resp_value.get("result") {
386 if let Ok(list) = serde_json::from_value::<ToolsListResult>(result.clone()) {
387 list.tools
388 .into_iter()
389 .filter(|t| allow_list.is_empty() || allow_list.contains(&t.name))
390 .map(|t| {
391 let _srv = server_name.clone();
392 Arc::new(McpToolWrapper {
393 tool_def: t,
394 backend: McpBackend::Stdio {
395 request_tx: request_tx.clone(),
396 },
397 }) as Arc<dyn Tool>
398 })
399 .collect()
400 } else {
401 vec![]
402 }
403 } else {
404 tracing::warn!("MCP server {} returned no tools/list result", server.name);
405 vec![]
406 };
407
408 Ok(tools)
409 }
410
411 async fn connect_http(&self, server: &McpServer) -> anyhow::Result<Vec<Arc<dyn Tool>>> {
413 let url = server
414 .url
415 .as_ref()
416 .ok_or_else(|| anyhow::anyhow!("url/sse transport requires 'url'"))?;
417
418 let client = reqwest::Client::new();
419
420 let _init_resp: Value = client
422 .post(url)
423 .json(&serde_json::json!({
424 "jsonrpc": "2.0",
425 "id": 1,
426 "method": "initialize",
427 "params": {
428 "protocolVersion": "2024-11-05",
429 "capabilities": {},
430 "clientInfo": { "name": "sparrow", "version": "0.1.0" }
431 }
432 }))
433 .send()
434 .await?
435 .json()
436 .await?;
437
438 let tools_resp: Value = client
440 .post(url)
441 .json(&serde_json::json!({
442 "jsonrpc": "2.0",
443 "id": 2,
444 "method": "tools/list",
445 "params": {}
446 }))
447 .send()
448 .await?
449 .json()
450 .await?;
451
452 let server_name = server.name.clone();
453 let allow_list = server.allow_tools.clone();
454
455 let tools: Vec<Arc<dyn Tool>> = if let Some(result) = tools_resp.get("result") {
456 if let Ok(list) = serde_json::from_value::<ToolsListResult>(result.clone()) {
457 list.tools
458 .into_iter()
459 .filter(|t| allow_list.is_empty() || allow_list.contains(&t.name))
460 .map(|t| {
461 let _srv = server_name.clone();
462 Arc::new(McpToolWrapper {
463 tool_def: t,
464 backend: McpBackend::Http {
465 url: url.clone(),
466 client: client.clone(),
467 },
468 }) as Arc<dyn Tool>
469 })
470 .collect()
471 } else {
472 vec![]
473 }
474 } else {
475 vec![]
476 };
477
478 Ok(tools)
479 }
480}
481
482async fn read_jsonrpc_response<R: tokio::io::AsyncBufRead + Unpin>(
486 reader: &mut R,
487 expected_id: u64,
488) -> anyhow::Result<Value> {
489 use tokio::io::AsyncBufReadExt;
490 let mut line = String::new();
491 for _ in 0..64 {
492 line.clear();
493 let n = reader.read_line(&mut line).await?;
494 if n == 0 {
495 anyhow::bail!("MCP server closed stdout");
496 }
497 let trimmed = line.trim();
498 if trimmed.is_empty() {
499 continue;
500 }
501 let value: Value = match serde_json::from_str(trimmed) {
502 Ok(v) => v,
503 Err(_) => {
504 tracing::debug!("MCP non-JSON stdout line: {}", trimmed);
506 continue;
507 }
508 };
509 match value.get("id").and_then(|v| v.as_u64()) {
511 Some(id) if id == expected_id => return Ok(value),
512 Some(_) => continue, None => continue, }
515 }
516 anyhow::bail!(
517 "MCP server did not respond to id={} within 64 frames",
518 expected_id
519 )
520}