use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::Mutex;
use traitclaw_core::{Error, Result};
use crate::protocol::{JsonRpcRequest, JsonRpcResponse, ToolsListResponse};
use crate::tool::McpTool;
/// Client-side handle to an MCP server running as a child process, talking
/// JSON-RPC over the child's stdin/stdout.
pub struct McpServer {
    /// The child's stdin, shared (via `Arc`) with every discovered [`McpTool`].
    stdin: Arc<Mutex<ChildStdin>>,
    /// The child's buffered stdout, shared (via `Arc`) with every discovered [`McpTool`].
    stdout: Arc<Mutex<BufReader<ChildStdout>>>,
    /// Counter used to allocate outgoing JSON-RPC request ids.
    next_id: AtomicU64,
    /// Tools reported by the server's `tools/list` response.
    tools: Vec<Arc<McpTool>>,
    /// Owned so the spawned process handle stays with the server.
    // `allow(dead_code)` suppresses the unused-field lint for this handle.
    #[allow(dead_code)]
    child: Child,
}
impl McpServer {
pub async fn stdio(program: &str, args: &[&str]) -> Result<Self> {
let mut child = Command::new(program)
.args(args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::null())
.spawn()
.map_err(|e| Error::Runtime(format!("Failed to spawn MCP server: {e}")))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| Error::Runtime("Failed to capture stdin".into()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| Error::Runtime("Failed to capture stdout".into()))?;
let stdin = Arc::new(Mutex::new(stdin));
let stdout = Arc::new(Mutex::new(BufReader::new(stdout)));
let mut server = Self {
stdin,
stdout,
next_id: AtomicU64::new(1),
tools: Vec::new(),
child,
};
server.initialize().await?;
server.discover_tools().await?;
Ok(server)
}
#[must_use]
pub fn tools(&self) -> &[Arc<McpTool>] {
&self.tools
}
#[must_use]
pub fn erased_tools(&self) -> Vec<Arc<dyn traitclaw_core::traits::tool::ErasedTool>> {
self.tools
.iter()
.map(|t| Arc::clone(t) as Arc<dyn traitclaw_core::traits::tool::ErasedTool>)
.collect()
}
async fn initialize(&mut self) -> Result<()> {
let req = JsonRpcRequest::new(
self.next_id(),
"initialize",
Some(serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "traitclaw",
"version": env!("CARGO_PKG_VERSION")
}
})),
);
let _resp = self.send_request(req).await?;
let notif = serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/initialized"
});
let mut stdin = self.stdin.lock().await;
let line = serde_json::to_string(¬if)
.map_err(|e| Error::Runtime(format!("JSON error: {e}")))?;
stdin
.write_all(line.as_bytes())
.await
.map_err(|e| Error::Runtime(format!("Write error: {e}")))?;
stdin
.write_all(b"\n")
.await
.map_err(|e| Error::Runtime(format!("Write error: {e}")))?;
stdin
.flush()
.await
.map_err(|e| Error::Runtime(format!("Flush error: {e}")))?;
Ok(())
}
async fn discover_tools(&mut self) -> Result<()> {
let req = JsonRpcRequest::new(self.next_id(), "tools/list", None);
let resp = self.send_request(req).await?;
let result = resp
.result
.ok_or_else(|| Error::Runtime("No result in tools/list response".into()))?;
let tools_resp: ToolsListResponse = serde_json::from_value(result)
.map_err(|e| Error::Runtime(format!("Parse tools/list error: {e}")))?;
self.tools = tools_resp
.tools
.into_iter()
.map(|def| {
Arc::new(McpTool::new(
def.name,
def.description,
def.input_schema,
Arc::clone(&self.stdin),
Arc::clone(&self.stdout),
self.next_id.load(Ordering::Relaxed),
))
})
.collect();
Ok(())
}
async fn send_request(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse> {
let line =
serde_json::to_string(&req).map_err(|e| Error::Runtime(format!("JSON error: {e}")))?;
let mut stdin = self.stdin.lock().await;
stdin
.write_all(line.as_bytes())
.await
.map_err(|e| Error::Runtime(format!("Write error: {e}")))?;
stdin
.write_all(b"\n")
.await
.map_err(|e| Error::Runtime(format!("Write error: {e}")))?;
stdin
.flush()
.await
.map_err(|e| Error::Runtime(format!("Flush error: {e}")))?;
drop(stdin);
let mut stdout = self.stdout.lock().await;
let mut line = String::new();
stdout
.read_line(&mut line)
.await
.map_err(|e| Error::Runtime(format!("Read error: {e}")))?;
let resp: JsonRpcResponse = serde_json::from_str(&line)
.map_err(|e| Error::Runtime(format!("Parse response error: {e}")))?;
if let Some(err) = resp.error {
return Err(Error::Runtime(format!("MCP error: {}", err.message)));
}
Ok(resp)
}
fn next_id(&self) -> u64 {
self.next_id.fetch_add(1, Ordering::Relaxed)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A request without params serializes with the JSON-RPC envelope and no
    /// `params` key at all.
    #[test]
    fn test_json_rpc_request_serialization() {
        let request = JsonRpcRequest::new(1, "tools/list", None);
        let value = serde_json::to_value(&request).unwrap();
        let object = value.as_object().unwrap();
        assert_eq!(object.get("jsonrpc"), Some(&serde_json::json!("2.0")));
        assert_eq!(object.get("id"), Some(&serde_json::json!(1)));
        assert_eq!(object.get("method"), Some(&serde_json::json!("tools/list")));
        assert!(!object.contains_key("params"));
    }

    /// Params passed to the constructor are serialized verbatim under `params`.
    #[test]
    fn test_json_rpc_request_with_params() {
        let params = serde_json::json!({"name": "test"});
        let request = JsonRpcRequest::new(2, "tools/call", Some(params));
        let value = serde_json::to_value(&request).unwrap();
        let name = value
            .pointer("/params/name")
            .and_then(serde_json::Value::as_str);
        assert_eq!(name, Some("test"));
    }
}