use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use async_trait::async_trait;
use parking_lot::Mutex as ParkingMutex;
use serde_json::{json, Value};
use tokio::sync::{oneshot, Mutex};
use tokio::task::JoinHandle;
use tokio::time::timeout;
use tracing::{debug, trace, warn};
use crate::error::{Error, Result};
use crate::tools::{Tool, ToolContext, ToolRunner};
use crate::types::McpServerConfig;
mod protocol;
mod transport;
pub use protocol::McpToolDecl;
use protocol::{
ClientInfo, InitializeParams, InitializeResult, Notification, Request, Response,
ToolCallParams, ToolCallResult, ToolsListResult, MCP_PROTOCOL_VERSION,
};
use transport::StdioTransport;
/// Upper bound on a single `tools/call` round trip.
const CALL_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Upper bound applied to each handshake step (`initialize`, then `tools/list`).
const HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(15_000);
pub struct McpClient {
transport: Arc<StdioTransport>,
pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
next_id: AtomicU64,
dispatcher: ParkingMutex<Option<JoinHandle<()>>>,
pub server_name: String,
pub tools: Vec<McpToolDecl>,
}
impl McpClient {
pub async fn connect_stdio(command: &str, args: &[String]) -> Result<Arc<Self>> {
let transport = Arc::new(StdioTransport::spawn(command, args).await?);
let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>> =
Arc::new(Mutex::new(HashMap::new()));
let dispatcher = spawn_dispatcher(transport.clone(), pending.clone());
let next_id = AtomicU64::new(1);
let init_result = match timeout(
HANDSHAKE_TIMEOUT,
initialize_via(&transport, &pending, &next_id),
)
.await
{
Ok(Ok(r)) => r,
Ok(Err(e)) => {
dispatcher.abort();
transport.shutdown().await;
return Err(e);
}
Err(_) => {
dispatcher.abort();
transport.shutdown().await;
return Err(Error::Timeout(HANDSHAKE_TIMEOUT));
}
};
let tools = match timeout(
HANDSHAKE_TIMEOUT,
list_tools_via(&transport, &pending, &next_id),
)
.await
{
Ok(Ok(t)) => t,
Ok(Err(e)) => {
dispatcher.abort();
transport.shutdown().await;
return Err(e);
}
Err(_) => {
dispatcher.abort();
transport.shutdown().await;
return Err(Error::Timeout(HANDSHAKE_TIMEOUT));
}
};
let server_name = init_result
.server_info
.map(|s| s.name)
.unwrap_or_else(|| command.to_string());
debug!(server = %server_name, count = tools.len(), "mcp connected");
Ok(Arc::new(Self {
transport,
pending,
next_id,
dispatcher: ParkingMutex::new(Some(dispatcher)),
server_name,
tools,
}))
}
pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
let params = serde_json::to_value(ToolCallParams {
name,
arguments,
})
.map_err(|e| Error::other(format!("tools/call encode: {e}")))?;
let resp = timeout(CALL_TIMEOUT, self.request("tools/call", Some(params)))
.await
.map_err(|_| Error::Timeout(CALL_TIMEOUT))??;
let result: ToolCallResult = serde_json::from_value(
resp.ok_or_else(|| Error::other("tools/call returned no result"))?,
)
.map_err(|e| Error::other(format!("tools/call decode: {e}")))?;
Ok(result.flatten())
}
async fn request(&self, method: &str, params: Option<Value>) -> Result<Option<Value>> {
request_via(&self.transport, &self.pending, &self.next_id, method, params).await
}
pub async fn shutdown(&self) {
let h = self.dispatcher.lock().take();
if let Some(h) = h {
h.abort();
}
self.transport.shutdown().await;
}
}
/// Drop guard that removes a request's entry from the pending table unless it
/// has been disarmed.
///
/// `request_via` arms one per request. If the awaiting future is dropped
/// early (for example by an outer `timeout`), the parked sender is removed
/// instead of lingering in the table.
struct PendingGuard {
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
    id: u64,
    /// When `false`, dropping the guard does nothing.
    armed: bool,
}

impl PendingGuard {
    /// Consumes the guard without touching the pending table.
    fn disarm(mut self) {
        self.armed = false;
    }
}

impl Drop for PendingGuard {
    fn drop(&mut self) {
        if !self.armed {
            return;
        }
        // Fast path: when the table is uncontended, remove the entry right
        // away so it is gone as soon as the guard is.
        if let Ok(mut table) = self.pending.try_lock() {
            table.remove(&self.id);
            return;
        }
        // `Drop` cannot await, so defer the removal to a task. `try_current`
        // avoids the panic `tokio::spawn` raises outside a Tokio runtime; in
        // that (rare) case the entry is left for the sender side to clear.
        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
            let pending = self.pending.clone();
            let id = self.id;
            runtime.spawn(async move {
                pending.lock().await.remove(&id);
            });
        }
    }
}
async fn request_via(
transport: &StdioTransport,
pending: &Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
next_id: &AtomicU64,
method: &str,
params: Option<Value>,
) -> Result<Option<Value>> {
let id = next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel();
pending.lock().await.insert(id, tx);
let guard = PendingGuard {
pending: pending.clone(),
id,
armed: true,
};
let req = Request::new(id, method, params);
let payload = serde_json::to_string(&req)
.map_err(|e| Error::other(format!("mcp encode: {e}")))?;
trace!(method, %payload, "mcp request");
if let Err(e) = transport.send(&payload).await {
pending.lock().await.remove(&id);
guard.disarm();
return Err(e);
}
let resp = match rx.await {
Ok(r) => {
guard.disarm();
r
}
Err(_) => return Err(Error::Closed),
};
if let Some(err) = resp.error {
return Err(Error::other(format!(
"mcp '{method}' rpc error {}: {}",
err.code, err.message
)));
}
Ok(resp.result)
}
/// Runs the MCP `initialize` request and, once it succeeds, sends the
/// `notifications/initialized` notification.
///
/// Advertises an empty capability object and identifies this client as
/// `localharness` with the crate's package version.
///
/// # Errors
/// Propagates encode, request and transport errors. Also fails if the reply
/// has no `result` or the result does not decode as `InitializeResult`.
async fn initialize_via(
    transport: &StdioTransport,
    pending: &Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
    next_id: &AtomicU64,
) -> Result<InitializeResult> {
    let params = serde_json::to_value(InitializeParams {
        protocol_version: MCP_PROTOCOL_VERSION,
        capabilities: json!({}),
        client_info: ClientInfo {
            name: "localharness",
            version: env!("CARGO_PKG_VERSION"),
        },
    })
    .map_err(|e| Error::other(format!("mcp initialize encode: {e}")))?;
    let resp = request_via(transport, pending, next_id, "initialize", Some(params)).await?;
    let result: InitializeResult = serde_json::from_value(
        resp.ok_or_else(|| Error::other("mcp initialize returned no result"))?,
    )
    .map_err(|e| Error::other(format!("mcp initialize decode: {e}")))?;
    // A notification carries no id and gets no reply, so it is written
    // straight to the transport instead of going through `request_via`.
    let notif = Notification::new("notifications/initialized", None);
    let payload = serde_json::to_string(&notif)
        .map_err(|e| Error::other(format!("mcp notify encode: {e}")))?;
    transport.send(&payload).await?;
    Ok(result)
}
async fn list_tools_via(
transport: &StdioTransport,
pending: &Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
next_id: &AtomicU64,
) -> Result<Vec<McpToolDecl>> {
let resp = request_via(transport, pending, next_id, "tools/list", None).await?;
let result: ToolsListResult = serde_json::from_value(
resp.ok_or_else(|| Error::other("tools/list returned no result"))?,
)
.map_err(|e| Error::other(format!("tools/list decode: {e}")))?;
Ok(result.tools)
}
fn spawn_dispatcher(
transport: Arc<StdioTransport>,
pending: Arc<Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
) -> JoinHandle<()> {
tokio::spawn(async move {
loop {
let line = {
let mut rx = transport.inbound.lock().await;
match rx.recv().await {
Some(l) => l,
None => return,
}
};
route_line(&line, &pending).await;
}
})
}
async fn route_line(line: &str, pending: &Mutex<HashMap<u64, oneshot::Sender<Response>>>) {
let resp: Response = match serde_json::from_str(line) {
Ok(r) => r,
Err(e) => {
trace!(?e, %line, "mcp: undecodable line (likely a notification)");
return;
}
};
if let Some(id) = resp.id {
if let Some(tx) = pending.lock().await.remove(&id) {
let _ = tx.send(resp);
}
}
}
/// Set of connected MCP servers whose tools can be exposed to a `ToolRunner`.
#[derive(Default)]
pub struct McpBridge {
    /// Connected clients, in the order `connect` succeeded.
    clients: Vec<Arc<McpClient>>,
}

impl McpBridge {
    /// Creates a bridge with no connected servers.
    pub fn new() -> Self {
        Self::default()
    }

    /// Connects to the server described by `config` and keeps its client.
    ///
    /// # Errors
    /// Propagates errors from `McpClient::connect_stdio`. SSE and HTTP
    /// configurations are rejected with a config error because only the stdio
    /// transport is implemented.
    pub async fn connect(&mut self, config: &McpServerConfig) -> Result<()> {
        let client = match config {
            McpServerConfig::Stdio { command, args } => {
                McpClient::connect_stdio(command, args).await?
            }
            McpServerConfig::Sse { .. } => {
                return Err(Error::config(
                    "MCP SSE transport not implemented yet (use Stdio)",
                ))
            }
            McpServerConfig::Http { .. } => {
                return Err(Error::config(
                    "MCP HTTP transport not implemented yet (use Stdio)",
                ))
            }
        };
        self.clients.push(client);
        Ok(())
    }

    /// Registers every server tool into `runner` as an `McpTool` and returns
    /// the names that were registered.
    ///
    /// A tool is skipped when `runner` already knows its name or an earlier
    /// tool in this same call took it. When several servers advertise the same
    /// name, the first client in connection order wins.
    pub fn register_into(&self, runner: &ToolRunner) -> Vec<String> {
        let existing = runner.names();
        let mut registered: Vec<String> = Vec::new();
        for client in &self.clients {
            for decl in &client.tools {
                // Checking `registered` as well as the pre-existing names keeps
                // two servers (or one server listing a name twice) from
                // registering conflicting tools under one name.
                let taken = existing.iter().any(|n| n == &decl.name)
                    || registered.contains(&decl.name);
                if taken {
                    debug!(
                        server = %client.server_name,
                        name = %decl.name,
                        "mcp: skipping (already registered)"
                    );
                    continue;
                }
                let tool: Arc<dyn Tool> = Arc::new(McpTool {
                    client: client.clone(),
                    decl: decl.clone(),
                });
                runner.register(tool);
                registered.push(decl.name.clone());
            }
        }
        registered
    }

    /// Shuts down every connected client, one at a time in connection order.
    pub async fn shutdown(&self) {
        for c in &self.clients {
            c.shutdown().await;
        }
    }
}
/// Adapter exposing one remote MCP tool through the local `Tool` trait.
struct McpTool {
    /// Connection to the server that owns this tool.
    client: Arc<McpClient>,
    /// The tool's declaration as advertised by that server.
    decl: McpToolDecl,
}

#[async_trait]
impl Tool for McpTool {
    /// The tool name exactly as the server declared it.
    fn name(&self) -> &str {
        &self.decl.name
    }

    /// The server-supplied description, or a placeholder when none was given.
    fn description(&self) -> &str {
        match self.decl.description.as_deref() {
            Some(text) => text,
            None => "(no description provided by MCP server)",
        }
    }

    /// The server-supplied input schema, or an empty object schema by default.
    fn input_schema(&self) -> Value {
        match &self.decl.input_schema {
            Some(schema) => schema.clone(),
            None => json!({ "type": "object", "properties": {} }),
        }
    }

    /// Forwards `args` to the server via `tools/call`. A failure is logged
    /// with the server and tool name, then returned unchanged.
    async fn execute(&self, args: Value, _ctx: Option<Arc<ToolContext>>) -> Result<Value> {
        let outcome = self.client.call_tool(&self.decl.name, args).await;
        if let Err(e) = &outcome {
            warn!(
                server = %self.client.server_name,
                tool = %self.decl.name,
                error = %e,
                "mcp tool call failed"
            );
        }
        outcome
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The pending-reply table type shared by every test below.
    type PendingTable = Mutex<HashMap<u64, oneshot::Sender<Response>>>;

    /// Returns an empty pending table wrapped for sharing.
    fn pending_map() -> Arc<PendingTable> {
        Arc::new(Mutex::new(HashMap::new()))
    }

    /// Parks a waiter under `id` and returns the receiver the test awaits on.
    async fn register(pending: &PendingTable, id: u64) -> oneshot::Receiver<Response> {
        let (reply_tx, reply_rx) = oneshot::channel();
        pending.lock().await.insert(id, reply_tx);
        reply_rx
    }

    /// Reports whether a waiter is still parked under `id`.
    async fn is_parked(pending: &PendingTable, id: u64) -> bool {
        pending.lock().await.contains_key(&id)
    }

    #[tokio::test]
    async fn routes_result_response_to_waiter() {
        let pending = pending_map();
        let waiter = register(&pending, 1).await;
        route_line(r#"{"jsonrpc":"2.0","id":1,"result":{"v":42}}"#, &pending).await;
        let reply = waiter.await.expect("delivered");
        assert_eq!(reply.result, Some(serde_json::json!({"v": 42})));
        assert!(pending.lock().await.is_empty());
    }

    #[tokio::test]
    async fn routes_error_response_to_waiter() {
        let pending = pending_map();
        let waiter = register(&pending, 5).await;
        let line = r#"{"jsonrpc":"2.0","id":5,"error":{"code":-32000,"message":"nope"}}"#;
        route_line(line, &pending).await;
        let rpc_err = waiter.await.expect("delivered").error.expect("error present");
        assert_eq!(rpc_err.code, -32000);
        assert_eq!(rpc_err.message, "nope");
    }

    #[tokio::test]
    async fn out_of_order_responses_match_by_id() {
        let pending = pending_map();
        let first = register(&pending, 1).await;
        let second = register(&pending, 2).await;
        route_line(r#"{"jsonrpc":"2.0","id":2,"result":"second"}"#, &pending).await;
        route_line(r#"{"jsonrpc":"2.0","id":1,"result":"first"}"#, &pending).await;
        assert_eq!(first.await.unwrap().result, Some(serde_json::json!("first")));
        assert_eq!(second.await.unwrap().result, Some(serde_json::json!("second")));
    }

    #[tokio::test]
    async fn unmatched_id_is_dropped_without_panic() {
        let pending = pending_map();
        let waiter = register(&pending, 1).await;
        route_line(r#"{"jsonrpc":"2.0","id":999,"result":"ghost"}"#, &pending).await;
        assert!(is_parked(&pending, 1).await);
        route_line(r#"{"jsonrpc":"2.0","id":1,"result":"ok"}"#, &pending).await;
        assert_eq!(waiter.await.unwrap().result, Some(serde_json::json!("ok")));
    }

    #[tokio::test]
    async fn duplicate_response_for_same_id_does_not_panic() {
        let pending = pending_map();
        let waiter = register(&pending, 1).await;
        for line in [
            r#"{"jsonrpc":"2.0","id":1,"result":"a"}"#,
            r#"{"jsonrpc":"2.0","id":1,"result":"b"}"#,
        ] {
            route_line(line, &pending).await;
        }
        assert_eq!(waiter.await.unwrap().result, Some(serde_json::json!("a")));
        assert!(pending.lock().await.is_empty());
    }

    #[tokio::test]
    async fn notification_without_id_does_not_consume_a_waiter() {
        let pending = pending_map();
        let waiter = register(&pending, 1).await;
        route_line(
            r#"{"jsonrpc":"2.0","method":"notifications/message","params":{"level":"info"}}"#,
            &pending,
        )
        .await;
        assert!(is_parked(&pending, 1).await);
        route_line(r#"{"jsonrpc":"2.0","id":1,"result":"done"}"#, &pending).await;
        assert_eq!(waiter.await.unwrap().result, Some(serde_json::json!("done")));
    }

    #[tokio::test]
    async fn undecodable_noise_line_is_ignored() {
        let pending = pending_map();
        let waiter = register(&pending, 1).await;
        for noise in ["INFO server ready", "{ not valid json", ""] {
            route_line(noise, &pending).await;
        }
        assert!(is_parked(&pending, 1).await);
        route_line(r#"{"jsonrpc":"2.0","id":1,"result":1}"#, &pending).await;
        assert_eq!(waiter.await.unwrap().result, Some(serde_json::json!(1)));
    }

    #[tokio::test]
    async fn response_with_null_id_is_dropped() {
        let pending = pending_map();
        let waiter = register(&pending, 1).await;
        route_line(
            r#"{"jsonrpc":"2.0","id":null,"error":{"code":-32700,"message":"parse error"}}"#,
            &pending,
        )
        .await;
        assert!(is_parked(&pending, 1).await);
        drop(waiter);
    }

    #[tokio::test]
    async fn dropped_sender_yields_recv_error_for_caller() {
        let pending = pending_map();
        let waiter = register(&pending, 1).await;
        pending.lock().await.clear();
        assert!(waiter.await.is_err());
    }

    #[tokio::test]
    async fn armed_pending_guard_removes_entry_on_drop() {
        let pending = pending_map();
        let _rx = register(&pending, 7).await;
        assert!(is_parked(&pending, 7).await);
        drop(PendingGuard {
            pending: pending.clone(),
            id: 7,
            armed: true,
        });
        // Removal may be deferred to a spawned task; give it a bounded number
        // of scheduler turns to run.
        for _ in 0..100 {
            if !is_parked(&pending, 7).await {
                break;
            }
            tokio::task::yield_now().await;
        }
        assert!(
            !is_parked(&pending, 7).await,
            "armed guard must remove the pending entry on drop"
        );
    }

    #[tokio::test]
    async fn disarmed_pending_guard_leaves_entry_alone() {
        let pending = pending_map();
        let _rx = register(&pending, 9).await;
        PendingGuard {
            pending: pending.clone(),
            id: 9,
            armed: true,
        }
        .disarm();
        // Give any (incorrectly) spawned cleanup task a chance to run.
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
        assert!(
            is_parked(&pending, 9).await,
            "disarmed guard must not touch the table"
        );
    }
}