use std::process::Stdio;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use aonyx_core::{AonyxError, Result, SafetyClass, ToolCall, ToolHandler, ToolResult};
use async_trait::async_trait;
use serde_json::{json, Value};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::Mutex;
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
#[derive(Debug, Clone)]
pub struct McpToolDef {
pub name: String,
pub description: String,
pub input_schema: Value,
}
#[async_trait]
pub trait McpTransport: Send + Sync {
fn server_name(&self) -> &str;
async fn list_tools(&self) -> Result<Vec<McpToolDef>>;
async fn call_tool(&self, name: &str, args: Value) -> Result<Value>;
}
pub struct StdioMcpClient {
server_name: String,
_child: Child,
io: Mutex<ClientIo>,
next_id: AtomicI64,
}
struct ClientIo {
stdin: ChildStdin,
stdout: BufReader<ChildStdout>,
}
impl StdioMcpClient {
pub async fn connect(
server_name: impl Into<String>,
command: &str,
args: &[String],
) -> Result<Self> {
let mut child = Command::new(command)
.args(args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null())
.spawn()
.map_err(|e| AonyxError::Mcp(format!("spawn {command}: {e}")))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| AonyxError::Mcp("child stdin unavailable".into()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| AonyxError::Mcp("child stdout unavailable".into()))?;
let client = Self {
server_name: server_name.into(),
_child: child,
io: Mutex::new(ClientIo {
stdin,
stdout: BufReader::new(stdout),
}),
next_id: AtomicI64::new(1),
};
client.handshake().await?;
Ok(client)
}
async fn handshake(&self) -> Result<()> {
let params = json!({
"protocolVersion": MCP_PROTOCOL_VERSION,
"capabilities": {},
"clientInfo": { "name": "aonyx-agent", "version": env!("CARGO_PKG_VERSION") },
});
let _ = self.request("initialize", params).await?;
self.notify("notifications/initialized", json!({})).await?;
Ok(())
}
async fn request(&self, method: &str, params: Value) -> Result<Value> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let line = build_request(id, method, ¶ms);
let mut io = self.io.lock().await;
io.stdin
.write_all(line.as_bytes())
.await
.map_err(|e| AonyxError::Mcp(format!("write {method}: {e}")))?;
io.stdin
.flush()
.await
.map_err(|e| AonyxError::Mcp(format!("flush {method}: {e}")))?;
let mut buf = String::new();
loop {
buf.clear();
let n = io
.stdout
.read_line(&mut buf)
.await
.map_err(|e| AonyxError::Mcp(format!("read {method}: {e}")))?;
if n == 0 {
return Err(AonyxError::Mcp(format!(
"{method}: server closed the connection"
)));
}
let trimmed = buf.trim();
if trimmed.is_empty() {
continue;
}
match match_response(trimmed, id) {
ResponseMatch::Result(v) => return Ok(v),
ResponseMatch::Error(msg) => {
return Err(AonyxError::Mcp(format!("{method}: {msg}")))
}
ResponseMatch::Other => continue,
}
}
}
async fn notify(&self, method: &str, params: Value) -> Result<()> {
let line = build_notification(method, ¶ms);
let mut io = self.io.lock().await;
io.stdin
.write_all(line.as_bytes())
.await
.map_err(|e| AonyxError::Mcp(format!("notify {method}: {e}")))?;
io.stdin
.flush()
.await
.map_err(|e| AonyxError::Mcp(format!("flush notify {method}: {e}")))?;
Ok(())
}
}
#[async_trait]
impl McpTransport for StdioMcpClient {
fn server_name(&self) -> &str {
&self.server_name
}
async fn list_tools(&self) -> Result<Vec<McpToolDef>> {
let resp = self.request("tools/list", json!({})).await?;
Ok(parse_tools_list(&resp))
}
async fn call_tool(&self, name: &str, args: Value) -> Result<Value> {
let params = json!({ "name": name, "arguments": args });
let resp = self.request("tools/call", params).await?;
Ok(extract_call_result(&resp))
}
}
pub struct HttpMcpClient {
server_name: String,
url: String,
bearer: Option<String>,
http: reqwest::Client,
next_id: AtomicI64,
session_id: Mutex<Option<String>>,
}
impl HttpMcpClient {
pub async fn connect(
server_name: impl Into<String>,
url: impl Into<String>,
bearer: Option<String>,
) -> Result<Self> {
let client = Self {
server_name: server_name.into(),
url: url.into(),
bearer,
http: reqwest::Client::new(),
next_id: AtomicI64::new(1),
session_id: Mutex::new(None),
};
let params = json!({
"protocolVersion": MCP_PROTOCOL_VERSION,
"capabilities": {},
"clientInfo": { "name": "aonyx-agent", "version": env!("CARGO_PKG_VERSION") },
});
let _ = client.request("initialize", params).await?;
Ok(client)
}
async fn request(&self, method: &str, params: Value) -> Result<Value> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let body = json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
});
let mut req = self
.http
.post(&self.url)
.header("content-type", "application/json")
.header("accept", "application/json, text/event-stream")
.json(&body);
if let Some(token) = &self.bearer {
req = req.header("authorization", format!("Bearer {token}"));
}
if let Some(sid) = self.session_id.lock().await.clone() {
req = req.header("mcp-session-id", sid);
}
let resp = req
.send()
.await
.map_err(|e| AonyxError::Mcp(format!("{method} POST: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let txt = resp.text().await.unwrap_or_default();
return Err(AonyxError::Mcp(format!("{method}: HTTP {status}: {txt}")));
}
if let Some(sid) = resp
.headers()
.get("mcp-session-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
{
*self.session_id.lock().await = Some(sid);
}
let content_type = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let text = resp
.text()
.await
.map_err(|e| AonyxError::Mcp(format!("{method} body: {e}")))?;
let payload = if content_type.contains("text/event-stream") {
extract_sse_json(&text)
.ok_or_else(|| AonyxError::Mcp(format!("{method}: no JSON in SSE stream")))?
} else {
serde_json::from_str::<Value>(text.trim())
.map_err(|e| AonyxError::Mcp(format!("{method}: bad JSON: {e}")))?
};
match match_response_value(&payload, id) {
ResponseMatch::Result(v) => Ok(v),
ResponseMatch::Error(msg) => Err(AonyxError::Mcp(format!("{method}: {msg}"))),
ResponseMatch::Other => Err(AonyxError::Mcp(format!("{method}: unexpected response"))),
}
}
}
#[async_trait]
impl McpTransport for HttpMcpClient {
fn server_name(&self) -> &str {
&self.server_name
}
async fn list_tools(&self) -> Result<Vec<McpToolDef>> {
let resp = self.request("tools/list", json!({})).await?;
Ok(parse_tools_list(&resp))
}
async fn call_tool(&self, name: &str, args: Value) -> Result<Value> {
let params = json!({ "name": name, "arguments": args });
let resp = self.request("tools/call", params).await?;
Ok(extract_call_result(&resp))
}
}
pub struct McpToolHandler {
qualified_name: String,
remote_name: String,
schema: Value,
transport: Arc<dyn McpTransport>,
}
impl McpToolHandler {
pub fn new(transport: Arc<dyn McpTransport>, def: McpToolDef) -> Self {
let qualified_name = format!("{}__{}", transport.server_name(), def.name);
Self {
qualified_name,
remote_name: def.name,
schema: def.input_schema,
transport,
}
}
}
#[async_trait]
impl ToolHandler for McpToolHandler {
fn name(&self) -> &str {
&self.qualified_name
}
fn schema(&self) -> Value {
self.schema.clone()
}
fn classify(&self) -> SafetyClass {
SafetyClass::Caution
}
async fn invoke(&self, call: ToolCall) -> Result<ToolResult> {
let output = self
.transport
.call_tool(&self.remote_name, call.args)
.await?;
Ok(ToolResult {
call_id: call.id,
output,
error: None,
})
}
}
pub async fn connect_and_register(
registry: &mut aonyx_tools::ToolRegistry,
server_name: &str,
command: &str,
args: &[String],
) -> Result<usize> {
let client: Arc<dyn McpTransport> =
Arc::new(StdioMcpClient::connect(server_name, command, args).await?);
register_transport(registry, client).await
}
pub async fn connect_http_and_register(
registry: &mut aonyx_tools::ToolRegistry,
server_name: &str,
url: &str,
bearer: Option<String>,
) -> Result<usize> {
let client: Arc<dyn McpTransport> =
Arc::new(HttpMcpClient::connect(server_name, url, bearer).await?);
register_transport(registry, client).await
}
async fn register_transport(
registry: &mut aonyx_tools::ToolRegistry,
transport: Arc<dyn McpTransport>,
) -> Result<usize> {
let tools = transport.list_tools().await?;
let count = tools.len();
for def in tools {
registry.register(Arc::new(McpToolHandler::new(Arc::clone(&transport), def)));
}
Ok(count)
}
fn build_request(id: i64, method: &str, params: &Value) -> String {
let msg = json!({
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
});
format!("{msg}\n")
}
fn build_notification(method: &str, params: &Value) -> String {
let msg = json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
});
format!("{msg}\n")
}
enum ResponseMatch {
Result(Value),
Error(String),
Other,
}
fn match_response(line: &str, expected_id: i64) -> ResponseMatch {
let Ok(v) = serde_json::from_str::<Value>(line) else {
return ResponseMatch::Other;
};
match_response_value(&v, expected_id)
}
fn match_response_value(v: &Value, expected_id: i64) -> ResponseMatch {
let id_matches = v.get("id").and_then(|i| i.as_i64()) == Some(expected_id);
if !id_matches {
return ResponseMatch::Other;
}
if let Some(err) = v.get("error") {
let msg = err
.get("message")
.and_then(|m| m.as_str())
.unwrap_or("unknown error")
.to_string();
return ResponseMatch::Error(msg);
}
ResponseMatch::Result(v.get("result").cloned().unwrap_or(Value::Null))
}
fn extract_sse_json(body: &str) -> Option<Value> {
let mut data = String::new();
let flush = |buf: &mut String| -> Option<Value> {
if buf.is_empty() {
return None;
}
let parsed = serde_json::from_str::<Value>(buf.trim()).ok();
buf.clear();
parsed.filter(|v| v.get("id").is_some())
};
for line in body.lines() {
if let Some(payload) = line.strip_prefix("data:") {
if !data.is_empty() {
data.push('\n');
}
data.push_str(payload.trim_start());
} else if line.trim().is_empty() {
if let Some(v) = flush(&mut data) {
return Some(v);
}
}
}
flush(&mut data)
}
fn parse_tools_list(result: &Value) -> Vec<McpToolDef> {
result
.get("tools")
.and_then(|t| t.as_array())
.map(|arr| {
arr.iter()
.filter_map(|t| {
let name = t.get("name")?.as_str()?.to_string();
let description = t
.get("description")
.and_then(|d| d.as_str())
.unwrap_or("")
.to_string();
let input_schema = t
.get("inputSchema")
.cloned()
.unwrap_or_else(|| json!({ "type": "object" }));
Some(McpToolDef {
name,
description,
input_schema,
})
})
.collect()
})
.unwrap_or_default()
}
fn extract_call_result(result: &Value) -> Value {
let Some(content) = result.get("content").and_then(|c| c.as_array()) else {
return result.clone();
};
let text = content
.iter()
.filter_map(|c| c.get("text").and_then(|t| t.as_str()))
.collect::<Vec<_>>()
.join("\n");
if text.is_empty() {
result.clone()
} else {
Value::String(text)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn build_request_is_newline_terminated_jsonrpc() {
let line = build_request(7, "tools/list", &json!({}));
assert!(line.ends_with('\n'));
let v: Value = serde_json::from_str(line.trim()).unwrap();
assert_eq!(v["jsonrpc"], "2.0");
assert_eq!(v["id"], 7);
assert_eq!(v["method"], "tools/list");
}
#[test]
fn build_notification_has_no_id() {
let line = build_notification("notifications/initialized", &json!({}));
let v: Value = serde_json::from_str(line.trim()).unwrap();
assert!(v.get("id").is_none());
assert_eq!(v["method"], "notifications/initialized");
}
#[test]
fn match_response_returns_result_for_matching_id() {
let line = r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#;
match match_response(line, 3) {
ResponseMatch::Result(v) => assert_eq!(v["ok"], true),
_ => panic!("expected result"),
}
}
#[test]
fn match_response_skips_other_ids_and_notifications() {
assert!(matches!(
match_response(r#"{"jsonrpc":"2.0","id":99,"result":{}}"#, 3),
ResponseMatch::Other
));
assert!(matches!(
match_response(r#"{"jsonrpc":"2.0","method":"log","params":{}}"#, 3),
ResponseMatch::Other
));
assert!(matches!(
match_response("not json", 3),
ResponseMatch::Other
));
}
#[test]
fn match_response_surfaces_errors() {
let line = r#"{"jsonrpc":"2.0","id":3,"error":{"code":-32601,"message":"no such method"}}"#;
match match_response(line, 3) {
ResponseMatch::Error(m) => assert!(m.contains("no such method")),
_ => panic!("expected error"),
}
}
#[test]
fn parse_tools_list_extracts_defs() {
let result = json!({
"tools": [
{ "name": "search", "description": "web search",
"inputSchema": { "type": "object", "properties": { "q": { "type": "string" } } } },
{ "name": "fetch" }
]
});
let tools = parse_tools_list(&result);
assert_eq!(tools.len(), 2);
assert_eq!(tools[0].name, "search");
assert_eq!(tools[0].description, "web search");
assert_eq!(tools[1].name, "fetch");
assert_eq!(tools[1].description, "");
assert_eq!(tools[1].input_schema["type"], "object");
}
#[test]
fn parse_tools_list_handles_missing_tools_key() {
assert!(parse_tools_list(&json!({})).is_empty());
}
#[test]
fn extract_call_result_joins_text_content() {
let result = json!({
"content": [
{ "type": "text", "text": "line one" },
{ "type": "text", "text": "line two" }
]
});
assert_eq!(extract_call_result(&result), json!("line one\nline two"));
}
#[test]
fn extract_call_result_falls_back_to_raw() {
let result = json!({ "data": 42 });
assert_eq!(extract_call_result(&result), result);
}
#[test]
fn extract_sse_json_pulls_the_response_frame() {
let body =
"event: message\ndata: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n\n";
let v = extract_sse_json(body).expect("frame");
assert_eq!(v["id"], 1);
assert_eq!(v["result"]["ok"], true);
}
#[test]
fn extract_sse_json_joins_multiline_data_and_skips_keepalives() {
let body = ": keep-alive\ndata: {\"jsonrpc\":\"2.0\",\ndata: \"id\":5,\"result\":{}}\n\n";
let v = extract_sse_json(body).expect("frame");
assert_eq!(v["id"], 5);
}
#[test]
fn extract_sse_json_none_when_no_data() {
assert!(extract_sse_json(": only a comment\n\n").is_none());
assert!(extract_sse_json("").is_none());
}
#[test]
fn match_response_value_matches_parsed_json() {
let v = json!({"jsonrpc":"2.0","id":4,"result":{"x":1}});
match match_response_value(&v, 4) {
ResponseMatch::Result(r) => assert_eq!(r["x"], 1),
_ => panic!("expected result"),
}
assert!(matches!(match_response_value(&v, 99), ResponseMatch::Other));
}
#[test]
fn mcp_tool_handler_qualifies_the_name() {
let def = McpToolDef {
name: "search".into(),
description: String::new(),
input_schema: json!({ "type": "object" }),
};
let qualified = format!("{}__{}", "brave", def.name);
assert_eq!(qualified, "brave__search");
}
}