1use serde::{Deserialize, Serialize};
34use serde_json::{json, Value};
35use std::fmt;
36use std::io::{BufRead, BufReader, Write};
37use std::process::{Child, ChildStdin, Command, Stdio};
38use std::sync::mpsc::{self, Receiver, RecvTimeoutError};
39use std::sync::{Arc, Mutex};
40use std::thread::{self, JoinHandle};
41use std::time::Duration;
42
/// MCP protocol revision this client sends as `protocolVersion` in its `initialize` request.
pub const PROTOCOL_VERSION: &str = "2024-11-05";

/// How long the client waits on the stdout reader channel for a response
/// (the argument passed to `recv_timeout`).
pub const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Maximum length in bytes of one line of server stdout; the reader thread treats a longer
/// line as a fatal error (4 MiB).
pub const MAX_LINE_BYTES: usize = 4 << 20;
60
/// Errors reported by the MCP client.
#[derive(Debug)]
pub enum McpError {
    /// Spawning the server or reading/writing its pipes failed; carries a description.
    Io(String),
    /// The server's output did not have the expected JSON-RPC/MCP shape, or the server
    /// answered with an error; carries a description.
    Protocol(String),
    /// No response arrived before the request timeout expired.
    Timeout,
    /// The connection to the server is no longer usable.
    Closed,
}
77
78impl fmt::Display for McpError {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 match self {
81 McpError::Io(m) => write!(f, "mcp io: {}", m),
82 McpError::Protocol(m) => write!(f, "mcp protocol: {}", m),
83 McpError::Timeout => write!(f, "mcp timeout"),
84 McpError::Closed => write!(f, "mcp channel closed"),
85 }
86 }
87}
88
89impl std::error::Error for McpError {}
90
91impl From<std::io::Error> for McpError {
92 fn from(e: std::io::Error) -> Self {
93 McpError::Io(e.to_string())
94 }
95}
96
/// Outgoing JSON-RPC 2.0 request, serialized with `serde_json::to_string` into a single
/// line.
///
/// Field declaration order is the key order on the wire; `params` is omitted entirely when
/// it is `None`.
#[derive(Serialize)]
struct JsonRpcRequest<'a> {
    /// Protocol marker; the call sites in this file always pass `"2.0"`.
    jsonrpc: &'static str,
    /// Identifier the server echoes back in its response.
    id: u64,
    /// Method name, e.g. `initialize`, `tools/list` or `tools/call`.
    method: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
}
109
/// Outgoing JSON-RPC 2.0 notification: like a request but without an `id`, so the server
/// sends no response.
///
/// Field declaration order is the key order on the wire; `params` is omitted entirely when
/// it is `None`.
#[derive(Serialize)]
struct JsonRpcNotification<'a> {
    /// Protocol marker; the call sites in this file always pass `"2.0"`.
    jsonrpc: &'static str,
    /// Notification name, e.g. `notifications/initialized`.
    method: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
}
117
/// Incoming JSON-RPC 2.0 message from the server, parsed leniently.
///
/// Every field is optional (missing keys become `None`) and unknown keys are ignored, so
/// messages without an `id`, such as notifications, also deserialize into this shape.
#[derive(Deserialize, Debug)]
struct JsonRpcResponse {
    // Parsed so the struct mirrors the JSON-RPC envelope, but never inspected by this
    // client; the allow silences the unused-field warning.
    #[allow(dead_code)]
    jsonrpc: Option<String>,
    /// Id of the request being answered. Kept as a raw `Value` because JSON-RPC ids may be
    /// numbers, strings or null.
    id: Option<Value>,
    /// Success payload; absent on error responses.
    #[serde(default)]
    result: Option<Value>,
    /// Error payload; absent on success.
    #[serde(default)]
    error: Option<JsonRpcErrorPayload>,
}
128
/// The `error` object of a JSON-RPC 2.0 error response.
#[derive(Deserialize, Debug)]
struct JsonRpcErrorPayload {
    /// Numeric JSON-RPC error code (e.g. -32601 for "method not found").
    code: i64,
    /// Short human-readable description supplied by the server.
    message: String,
    // Optional extra details from the server. Accepted so deserialization tolerates it,
    // but not currently read by this client, hence the allow.
    #[serde(default)]
    #[allow(dead_code)]
    data: Option<Value>,
}
137
/// One entry of a server's `tools/list` result.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpToolInfo {
    /// Tool name; this is also the `name` sent back in `tools/call`.
    pub name: String,
    /// Human-readable description; an empty string when the server omits the key.
    #[serde(default)]
    pub description: String,
    /// JSON Schema for the tool's arguments (wire key `inputSchema`). Defaults to
    /// `{"type": "object"}` via `default_schema` when the server omits the key.
    #[serde(rename = "inputSchema", default = "default_schema")]
    pub input_schema: Value,
}
147
148fn default_schema() -> Value {
149 json!({"type": "object"})
150}
151
/// Successful outcome of one `read_bounded_line` call.
enum BoundedRead {
    /// A line was appended to the caller's buffer.
    Line,
    /// The reader reached end of input.
    Eof,
}
165
/// Failure modes of `read_bounded_line`.
#[derive(Debug)]
enum BoundedReadError {
    /// The underlying reader failed with an error other than `Interrupted` (which is retried).
    // The payload is only observable through `Debug`: the stdout reader thread matches
    // `Io(_)` and discards it, so the field is never read directly and would otherwise
    // trigger the dead-code lint.
    #[allow(dead_code)]
    Io(std::io::Error),
    /// The line exceeded the caller's byte limit.
    Overflow,
}
177
178fn read_bounded_line<R: BufRead>(
189 reader: &mut R,
190 buf: &mut Vec<u8>,
191 limit: usize,
192) -> Result<BoundedRead, BoundedReadError> {
193 loop {
194 let available = match reader.fill_buf() {
195 Ok(b) => {
196 if b.is_empty() {
197 return Ok(BoundedRead::Eof);
201 }
202 b
203 }
204 Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
205 Err(e) => return Err(BoundedReadError::Io(e)),
206 };
207 let (chunk, done) = match available.iter().position(|&b| b == b'\n') {
208 Some(i) => (&available[..=i], true),
209 None => (available, false),
210 };
211 if buf.len() + chunk.len() > limit.saturating_add(1) {
212 let take = limit.saturating_add(1).saturating_sub(buf.len());
215 buf.extend_from_slice(&chunk[..take]);
216 let consumed = take;
217 reader.consume(consumed);
218 return Err(BoundedReadError::Overflow);
219 }
220 buf.extend_from_slice(chunk);
221 let consumed = chunk.len();
222 reader.consume(consumed);
223 if done {
224 if buf.last() == Some(&b'\n') {
226 buf.pop();
227 }
228 if buf.last() == Some(&b'\r') {
229 buf.pop();
230 }
231 return Ok(BoundedRead::Line);
232 }
233 }
234}
235
/// Synchronous client for one MCP server child process, speaking newline-delimited
/// JSON-RPC over the child's stdin/stdout.
///
/// Every resource is held in an `Option` so shutdown can `take()` each one exactly once.
/// `Drop` runs the same shutdown, so a second shutdown finds nothing left to release.
///
/// Requests go through `&mut self` and block until answered, so only one is issued at a
/// time. Share a client between callers with `Arc<Mutex<McpClient>>`, as `McpTool` does.
pub struct McpClient {
    /// The spawned server process; `None` once shut down.
    child: Option<Child>,
    /// Write end of the server's stdin; dropping it closes the pipe.
    stdin: Option<ChildStdin>,
    /// Messages from the background thread that reads the server's stdout.
    rx: Option<Receiver<ReaderMsg>>,
    /// Handle of that reader thread.
    reader_thread: Option<JoinHandle<()>>,
    /// Last request id handed out. Starts at 0 and is pre-incremented, so the first request
    /// uses id 1.
    next_id: u64,
}
252
/// Message sent from the stdout reader thread to the client over an mpsc channel.
enum ReaderMsg {
    /// The reader hit a fatal condition (non-UTF-8 output or an over-long line) and stops
    /// after sending this.
    Error(String),
    /// One line of server output with its terminator removed.
    Line(String),
    /// Server stdout reached end of input; read I/O errors are reported this way too.
    Eof,
}
265
266impl McpClient {
267 pub fn start(command: &str, args: &[&str]) -> Result<Self, McpError> {
269 let mut child = Command::new(command)
270 .args(args)
271 .stdin(Stdio::piped())
272 .stdout(Stdio::piped())
273 .stderr(Stdio::piped())
274 .spawn()
275 .map_err(|e| McpError::Io(format!("spawn {}: {}", command, e)))?;
276
277 let stdin = child
278 .stdin
279 .take()
280 .ok_or_else(|| McpError::Io("missing child stdin".into()))?;
281 let stdout = child
282 .stdout
283 .take()
284 .ok_or_else(|| McpError::Io("missing child stdout".into()))?;
285
286 let (tx, rx) = mpsc::channel();
287 let reader_thread = thread::spawn(move || {
288 let mut reader = BufReader::new(stdout);
289 let mut buf: Vec<u8> = Vec::with_capacity(4096);
294 loop {
295 buf.clear();
296 match read_bounded_line(&mut reader, &mut buf, MAX_LINE_BYTES) {
297 Ok(BoundedRead::Eof) => {
298 let _ = tx.send(ReaderMsg::Eof);
299 return;
300 }
301 Ok(BoundedRead::Line) => {
302 match std::str::from_utf8(&buf) {
303 Ok(s) => {
304 if tx.send(ReaderMsg::Line(s.to_string())).is_err() {
305 return;
306 }
307 }
308 Err(_) => {
309 let _ = tx.send(ReaderMsg::Error(
310 "non-utf8 bytes on mcp stdout".into(),
311 ));
312 return;
313 }
314 }
315 }
316 Err(BoundedReadError::Overflow) => {
317 let _ = tx.send(ReaderMsg::Error(format!(
318 "mcp line exceeded {} bytes",
319 MAX_LINE_BYTES
320 )));
321 return;
322 }
323 Err(BoundedReadError::Io(_)) => {
324 let _ = tx.send(ReaderMsg::Eof);
325 return;
326 }
327 }
328 }
329 });
330
331 let mut this = McpClient {
332 child: Some(child),
333 stdin: Some(stdin),
334 rx: Some(rx),
335 reader_thread: Some(reader_thread),
336 next_id: 0,
337 };
338
339 let params = json!({
341 "protocolVersion": PROTOCOL_VERSION,
342 "capabilities": {},
343 "clientInfo": { "name": "agnt-mcp", "version": "0.3.1" }
344 });
345 let _ = this.request("initialize", Some(params))?;
346
347 let _ = this.notify("notifications/initialized", None);
349
350 Ok(this)
351 }
352
353 pub fn list_tools(&mut self) -> Result<Vec<McpToolInfo>, McpError> {
355 let result = self.request("tools/list", None)?;
356 let tools = result
357 .get("tools")
358 .and_then(|v| v.as_array())
359 .ok_or_else(|| McpError::Protocol("tools/list: missing tools array".into()))?;
360 let mut out = Vec::with_capacity(tools.len());
361 for t in tools {
362 let info: McpToolInfo = serde_json::from_value(t.clone())
363 .map_err(|e| McpError::Protocol(format!("tools/list entry: {}", e)))?;
364 out.push(info);
365 }
366 Ok(out)
367 }
368
369 pub fn call_tool(&mut self, name: &str, args: Value) -> Result<String, McpError> {
372 let span = tracing::info_span!("mcp.call", name = %name);
373 let _enter = span.enter();
374 let params = json!({ "name": name, "arguments": args });
375 let result = self.request("tools/call", Some(params))?;
376
377 if result
378 .get("isError")
379 .and_then(|v| v.as_bool())
380 .unwrap_or(false)
381 {
382 return Err(McpError::Protocol(format!(
383 "tools/call isError: {}",
384 result
385 )));
386 }
387
388 let content = result
389 .get("content")
390 .and_then(|v| v.as_array())
391 .ok_or_else(|| McpError::Protocol("tools/call: missing content".into()))?;
392
393 let mut buf = String::new();
394 for block in content {
395 if block.get("type").and_then(|v| v.as_str()) == Some("text") {
396 if let Some(text) = block.get("text").and_then(|v| v.as_str()) {
397 if !buf.is_empty() {
398 buf.push('\n');
399 }
400 buf.push_str(text);
401 }
402 }
403 }
404 Ok(buf)
405 }
406
407 pub fn shutdown(mut self) -> Result<(), McpError> {
410 self.shutdown_inner()
411 }
412
413 fn shutdown_inner(&mut self) -> Result<(), McpError> {
414 let _ = self.notify("shutdown", None);
416
417 drop(self.stdin.take());
419
420 if let Some(mut child) = self.child.take() {
421 match child.try_wait() {
423 Ok(Some(_)) => {}
424 _ => {
425 let _ = child.kill();
426 let _ = child.wait();
427 }
428 }
429 }
430 if let Some(handle) = self.reader_thread.take() {
432 let _ = handle.join();
433 }
434 self.rx.take();
435 Ok(())
436 }
437
438 fn alloc_id(&mut self) -> u64 {
443 self.next_id += 1;
444 self.next_id
445 }
446
447 fn request(&mut self, method: &str, params: Option<Value>) -> Result<Value, McpError> {
448 let id = self.alloc_id();
449 let req = JsonRpcRequest {
450 jsonrpc: "2.0",
451 id,
452 method,
453 params,
454 };
455 let mut line = serde_json::to_string(&req)
456 .map_err(|e| McpError::Protocol(format!("serialize request: {}", e)))?;
457 line.push('\n');
458
459 {
460 let stdin = self
461 .stdin
462 .as_mut()
463 .ok_or(McpError::Closed)?;
464 stdin
465 .write_all(line.as_bytes())
466 .map_err(|e| McpError::Io(format!("write: {}", e)))?;
467 stdin.flush().map_err(|e| McpError::Io(format!("flush: {}", e)))?;
468 }
469
470 self.await_response(id)
471 }
472
473 fn notify(&mut self, method: &str, params: Option<Value>) -> Result<(), McpError> {
474 let n = JsonRpcNotification {
475 jsonrpc: "2.0",
476 method,
477 params,
478 };
479 let mut line = serde_json::to_string(&n)
480 .map_err(|e| McpError::Protocol(format!("serialize notify: {}", e)))?;
481 line.push('\n');
482 let stdin = self.stdin.as_mut().ok_or(McpError::Closed)?;
483 stdin
484 .write_all(line.as_bytes())
485 .map_err(|e| McpError::Io(format!("write notify: {}", e)))?;
486 stdin
487 .flush()
488 .map_err(|e| McpError::Io(format!("flush notify: {}", e)))?;
489 Ok(())
490 }
491
492 fn await_response(&mut self, id: u64) -> Result<Value, McpError> {
493 let rx = self.rx.as_ref().ok_or(McpError::Closed)?;
494 loop {
495 match rx.recv_timeout(REQUEST_TIMEOUT) {
496 Ok(ReaderMsg::Line(line)) => {
497 let trimmed = line.trim();
498 if trimmed.is_empty() {
499 continue;
500 }
501 let resp: JsonRpcResponse = match serde_json::from_str(trimmed) {
502 Ok(r) => r,
503 Err(e) => {
504 return Err(McpError::Protocol(format!(
505 "parse response: {} (line: {})",
506 e, trimmed
507 )));
508 }
509 };
510 let resp_id = match &resp.id {
512 Some(Value::Number(n)) => n.as_u64(),
513 _ => None,
514 };
515 if resp_id != Some(id) {
516 continue;
517 }
518 if let Some(err) = resp.error {
519 return Err(McpError::Protocol(format!(
520 "jsonrpc error {}: {}",
521 err.code, err.message
522 )));
523 }
524 return Ok(resp.result.unwrap_or(Value::Null));
525 }
526 Ok(ReaderMsg::Eof) => return Err(McpError::Closed),
527 Ok(ReaderMsg::Error(msg)) => return Err(McpError::Protocol(msg)),
528 Err(RecvTimeoutError::Timeout) => return Err(McpError::Timeout),
529 Err(RecvTimeoutError::Disconnected) => return Err(McpError::Closed),
530 }
531 }
532 }
533}
534
535impl fmt::Debug for McpClient {
536 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
537 f.debug_struct("McpClient")
538 .field("next_id", &self.next_id)
539 .field("alive", &self.child.is_some())
540 .finish()
541 }
542}
543
544impl Drop for McpClient {
545 fn drop(&mut self) {
546 let _ = self.shutdown_inner();
547 }
548}
549
550pub struct McpTool {
559 client: Arc<Mutex<McpClient>>,
560 name: String,
561 description: String,
562 schema: Value,
563}
564
565impl McpTool {
566 pub fn new(client: Arc<Mutex<McpClient>>, info: McpToolInfo) -> Self {
567 Self {
568 client,
569 name: info.name,
570 description: info.description,
571 schema: info.input_schema,
572 }
573 }
574}
575
576impl agnt_core::tool::Tool for McpTool {
577 fn name(&self) -> &str {
578 &self.name
579 }
580
581 fn description(&self) -> &str {
582 &self.description
583 }
584
585 fn schema(&self) -> Value {
586 self.schema.clone()
587 }
588
589 fn call(&self, args: Value) -> Result<String, String> {
590 let span = tracing::info_span!("mcp.call", name = %self.name);
591 let _enter = span.enter();
592 let mut guard = self
593 .client
594 .lock()
595 .map_err(|e| format!("mcp client mutex poisoned: {}", e))?;
596 guard.call_tool(&self.name, args).map_err(|e| e.to_string())
597 }
598}
599
// Unit tests. The client tests drive `McpClient` against a scripted `/bin/sh` "server",
// so they assume a Unix-like host with `/bin/sh` available.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `/bin/sh` script that plays a scripted MCP server.
    ///
    /// For each entry of `responses`, the script reads one line from stdin (the client's
    /// request) and prints the entry verbatim plus a newline. After the first response it
    /// reads one extra line. That consumes the client's `notifications/initialized`
    /// notification, so each later response stays paired with the next request. The script
    /// does not look at request ids; the tests hard-code id 1 for `initialize` and id 2 for
    /// the request that follows.
    fn mock_server_script(responses: &[&str]) -> String {
        let mut s = String::new();
        for (i, r) in responses.iter().enumerate() {
            // Make the response safe inside a single-quoted sh string: ' becomes '\''
            let escaped = r.replace('\'', "'\\''");
            s.push_str(&format!("read line; printf '%s\\n' '{}'\n", escaped));
            if i == 0 {
                s.push_str("read line\n");
            }
        }
        // NOTE(review): presumably keeps the server alive briefly so the client can read
        // the last response before the pipes close — confirm.
        s.push_str("sleep 0.2\n");
        s
    }

    /// Starts a client against a mock server that replies with `responses` in order.
    /// Panics if the `initialize` handshake fails.
    fn start_mock(responses: &[&str]) -> McpClient {
        let script = mock_server_script(responses);
        McpClient::start("/bin/sh", &["-c", &script]).expect("start mock")
    }

    #[test]
    fn initialize_handshake_completes() {
        let init = r#"{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{}}}"#;
        let client = start_mock(&[init]);
        // Dropping the client runs the shutdown path against the mock.
        drop(client);
    }

    #[test]
    fn list_tools_parses_server_response() {
        let init = r#"{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{}}}"#;
        let list = r#"{"jsonrpc":"2.0","id":2,"result":{"tools":[{"name":"echo","description":"Echo text","inputSchema":{"type":"object","properties":{"text":{"type":"string"}}}}]}}"#;
        let mut client = start_mock(&[init, list]);
        let tools = client.list_tools().expect("list");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(tools[0].description, "Echo text");
        assert_eq!(
            tools[0].input_schema,
            serde_json::json!({"type":"object","properties":{"text":{"type":"string"}}})
        );
    }

    #[test]
    fn call_tool_joins_text_content_blocks() {
        let init = r#"{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{}}}"#;
        let call = r#"{"jsonrpc":"2.0","id":2,"result":{"content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}],"isError":false}}"#;
        let mut client = start_mock(&[init, call]);
        let out = client
            .call_tool("echo", serde_json::json!({"text":"hi"}))
            .expect("call");
        assert_eq!(out, "hello\nworld");
    }

    #[test]
    fn call_tool_is_error_surfaces_protocol_error() {
        let init = r#"{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{}}}"#;
        let call = r#"{"jsonrpc":"2.0","id":2,"result":{"content":[{"type":"text","text":"bad"}],"isError":true}}"#;
        let mut client = start_mock(&[init, call]);
        let err = client
            .call_tool("echo", serde_json::json!({}))
            .expect_err("should error");
        assert!(matches!(err, McpError::Protocol(_)), "got {:?}", err);
    }

    #[test]
    fn jsonrpc_error_response_maps_to_protocol_error() {
        let init = r#"{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{}}}"#;
        let err_resp =
            r#"{"jsonrpc":"2.0","id":2,"error":{"code":-32601,"message":"method not found"}}"#;
        let mut client = start_mock(&[init, err_resp]);
        let err = client.list_tools().expect_err("should error");
        match err {
            McpError::Protocol(m) => assert!(m.contains("method not found"), "got: {}", m),
            other => panic!("expected Protocol, got {:?}", other),
        }
    }

    #[test]
    fn closed_pipe_yields_closed_error() {
        let err = McpClient::start("/bin/sh", &["-c", "exit 0"]).expect_err("should fail");
        // The server exits immediately, so the outcome depends on timing. A write to the
        // closed stdin gives Io; reading EOF on stdout gives Closed. NOTE(review): Protocol
        // is also accepted — presumably defensive; it is unclear which path produces it.
        assert!(
            matches!(err, McpError::Closed | McpError::Io(_) | McpError::Protocol(_)),
            "got {:?}",
            err
        );
    }

    #[test]
    fn spawn_nonexistent_binary_is_io_error() {
        let err = McpClient::start("/definitely/not/a/real/binary-xyz", &[])
            .expect_err("should fail");
        assert!(matches!(err, McpError::Io(_)), "got {:?}", err);
    }

    #[test]
    fn mcp_tool_bridges_to_agnt_core_tool_trait() {
        use agnt_core::tool::Tool;
        let init = r#"{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{}}}"#;
        let call = r#"{"jsonrpc":"2.0","id":2,"result":{"content":[{"type":"text","text":"routed"}]}}"#;
        let client = start_mock(&[init, call]);
        let shared = Arc::new(Mutex::new(client));
        let info = McpToolInfo {
            name: "echo".into(),
            description: "Echo text".into(),
            input_schema: serde_json::json!({"type":"object"}),
        };
        let tool = McpTool::new(Arc::clone(&shared), info);
        assert_eq!(tool.name(), "echo");
        assert_eq!(tool.description(), "Echo text");
        assert_eq!(tool.schema(), serde_json::json!({"type":"object"}));
        let out = tool.call(serde_json::json!({})).expect("call");
        assert_eq!(out, "routed");
    }

    #[test]
    fn mcp_tool_info_deserializes_with_missing_description() {
        let info: McpToolInfo = serde_json::from_value(serde_json::json!({
            "name": "bare",
            "inputSchema": {"type":"object"}
        }))
        .expect("deserialize");
        assert_eq!(info.name, "bare");
        assert_eq!(info.description, "");
    }

    #[test]
    fn mcp_tool_info_deserializes_with_missing_schema() {
        let info: McpToolInfo = serde_json::from_value(serde_json::json!({
            "name": "bare"
        }))
        .expect("deserialize");
        assert_eq!(info.input_schema, serde_json::json!({"type":"object"}));
    }

    #[test]
    fn mcp_error_display_is_stable() {
        assert_eq!(McpError::Timeout.to_string(), "mcp timeout");
        assert_eq!(McpError::Closed.to_string(), "mcp channel closed");
        assert!(McpError::Io("x".into()).to_string().contains("io"));
        assert!(McpError::Protocol("x".into()).to_string().contains("protocol"));
    }

    // The remaining tests exercise `read_bounded_line` directly over in-memory input.

    #[test]
    fn bounded_reader_accepts_short_line() {
        let input: &[u8] = b"hello\n";
        let mut r = std::io::BufReader::new(input);
        let mut buf = Vec::new();
        let outcome = read_bounded_line(&mut r, &mut buf, 1024).unwrap_or_else(|_| {
            panic!("should accept short line")
        });
        assert!(matches!(outcome, BoundedRead::Line));
        assert_eq!(buf, b"hello");
    }

    #[test]
    fn bounded_reader_strips_crlf() {
        let input: &[u8] = b"crlf\r\n";
        let mut r = std::io::BufReader::new(input);
        let mut buf = Vec::new();
        read_bounded_line(&mut r, &mut buf, 1024).expect("ok");
        assert_eq!(buf, b"crlf");
    }

    #[test]
    fn bounded_reader_reports_eof_on_empty() {
        let input: &[u8] = b"";
        let mut r = std::io::BufReader::new(input);
        let mut buf = Vec::new();
        match read_bounded_line(&mut r, &mut buf, 1024).expect("ok") {
            BoundedRead::Eof => {}
            BoundedRead::Line => panic!("expected EOF"),
        }
    }

    #[test]
    fn bounded_reader_rejects_oversized_line() {
        let big: Vec<u8> = vec![b'x'; 32 * 1024];
        // No newline at all: the limit must trip while the line is still being buffered.
        let mut r = std::io::BufReader::new(&big[..]);
        let mut buf = Vec::new();
        let err = read_bounded_line(&mut r, &mut buf, 8 * 1024);
        assert!(matches!(err, Err(BoundedReadError::Overflow)));
    }

    #[test]
    fn bounded_reader_rejects_line_just_over_limit() {
        let mut big: Vec<u8> = vec![b'a'; 1025];
        // 1025 content bytes against a 1024-byte limit: one byte too many.
        big.push(b'\n');
        let mut r = std::io::BufReader::new(&big[..]);
        let mut buf = Vec::new();
        let err = read_bounded_line(&mut r, &mut buf, 1024);
        assert!(matches!(err, Err(BoundedReadError::Overflow)));
    }

    #[test]
    fn bounded_reader_handles_multi_line_stream() {
        let input: &[u8] = b"one\ntwo\nthree\n";
        let mut r = std::io::BufReader::new(input);
        let mut buf = Vec::new();
        read_bounded_line(&mut r, &mut buf, 1024).expect("one");
        assert_eq!(buf, b"one");
        // The reader appends to `buf`, so it is cleared between lines.
        buf.clear();
        read_bounded_line(&mut r, &mut buf, 1024).expect("two");
        assert_eq!(buf, b"two");
        buf.clear();
        read_bounded_line(&mut r, &mut buf, 1024).expect("three");
        assert_eq!(buf, b"three");
        buf.clear();
        match read_bounded_line(&mut r, &mut buf, 1024).expect("eof") {
            BoundedRead::Eof => {}
            BoundedRead::Line => panic!("expected EOF after exhausting input"),
        }
    }
}
840}