use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::model::{ImageData, ImageSource, UserAttachment};
use crate::tools::{
ToolFailure, ToolFailureKind, ToolInvocation, ToolOutcome, ToolRuntime, ToolRuntimeError,
ToolSpec,
};
/// MCP protocol revision announced in the `initialize` request.
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
/// Per-request timeout applied by the `McpServerConfig::http` / `McpServerConfig::stdio`
/// constructors (30 seconds).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Configuration for one MCP server: how to reach it and how its tools are exposed.
#[derive(Debug, Clone)]
pub struct McpServerConfig {
    /// Server identifier; used as the tool-name prefix by `McpToolRuntime` and in log fields.
    pub name: String,
    /// Per-request timeout. HTTP: the `reqwest` client timeout; stdio: how long each call
    /// waits for its reply.
    pub timeout: Duration,
    /// Startup time budget for the server (the `http` / `stdio` constructors default it
    /// to 10 s).
    pub startup_timeout: Duration,
    /// When set, `McpToolRuntime::discover` reports an initialization failure of this server
    /// at error level instead of warn.
    pub required: bool,
    /// Allow-list of server-local tool names to expose; an empty list exposes every tool.
    pub enabled_tools: Vec<String>,
    /// How to connect to the server.
    pub transport: McpTransport,
}
/// Transport used to reach an MCP server.
#[derive(Debug, Clone)]
pub enum McpTransport {
    /// JSON-RPC over HTTP POST; a response may be plain JSON or an SSE stream.
    Http {
        /// Endpoint that every request and notification is POSTed to.
        url: String,
        /// Extra headers added to every request.
        headers: HashMap<String, String>,
    },
    /// A child process speaking newline-delimited JSON-RPC on stdin/stdout.
    Stdio {
        /// Program to execute.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// The child's whole environment. The parent environment is cleared and only these
        /// variables are set, plus the parent's `PATH` when it is not given here.
        env: HashMap<String, String>,
        /// Working directory for the child; `None` keeps the parent's.
        working_dir: Option<String>,
    },
}
impl McpServerConfig {
pub fn new(name: impl Into<String>, url: impl Into<String>) -> Self {
Self::http(name, url)
}
pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
Self {
name: name.into(),
timeout: DEFAULT_TIMEOUT,
startup_timeout: Duration::from_secs(10),
required: false,
enabled_tools: Vec::new(),
transport: McpTransport::Http {
url: url.into(),
headers: HashMap::new(),
},
}
}
pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
Self {
name: name.into(),
timeout: DEFAULT_TIMEOUT,
startup_timeout: Duration::from_secs(10),
required: false,
enabled_tools: Vec::new(),
transport: McpTransport::Stdio {
command: command.into(),
args,
env: HashMap::new(),
working_dir: None,
},
}
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_startup_timeout(mut self, timeout: Duration) -> Self {
self.startup_timeout = timeout;
self
}
pub fn with_required(mut self, required: bool) -> Self {
self.required = required;
self
}
pub fn with_enabled_tools(mut self, tools: Vec<String>) -> Self {
self.enabled_tools = tools;
self
}
pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
if let McpTransport::Http { headers, .. } = &mut self.transport {
headers.insert(key.into(), value.into());
}
self
}
pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
if let McpTransport::Stdio { env, .. } = &mut self.transport {
env.insert(key.into(), value.into());
}
self
}
pub fn with_working_dir(mut self, dir: impl Into<String>) -> Self {
if let McpTransport::Stdio { working_dir, .. } = &mut self.transport {
*working_dir = Some(dir.into());
}
self
}
}
/// Errors produced while talking to an MCP server.
#[derive(Debug, thiserror::Error)]
pub enum McpError {
    /// A request did not complete in time (HTTP client timeout or stdio reply wait).
    #[error("timeout: {0}")]
    Timeout(String),
    /// Connection, process, or internal channel failure below the JSON-RPC layer.
    #[error("transport: {0}")]
    Transport(String),
    /// The HTTP server answered with a non-success status; `body` is a truncated excerpt.
    #[error("HTTP {status}: {body}")]
    Http { status: u16, body: String },
    /// A message could not be encoded, or a reply could not be decoded.
    #[error("decode: {0}")]
    Decode(String),
    /// The server replied with a JSON-RPC error object.
    #[error("server error code={code} message={message}")]
    Server { code: i64, message: String },
    /// A JSON-RPC reply lacked a required member (e.g. `result`).
    #[error("missing field {0}")]
    MissingField(&'static str),
}
/// Outgoing JSON-RPC 2.0 request. Serde writes fields in declaration order, so this is also
/// the key order on the wire.
#[derive(Debug, Serialize)]
struct McpRequest<'a> {
    /// Protocol marker; `McpClient::call` always sets `"2.0"`.
    jsonrpc: &'static str,
    /// Request id from `McpClient::next_id`. Stdio and SSE replies are matched on it.
    id: u64,
    /// JSON-RPC method name, e.g. `tools/list`.
    method: &'a str,
    /// Method parameters; omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
}
/// Outgoing JSON-RPC 2.0 notification: like [`McpRequest`] but without an `id`, so no reply
/// is awaited.
#[derive(Debug, Serialize)]
struct McpNotification<'a> {
    /// Protocol marker; `McpClient::notify` always sets `"2.0"`.
    jsonrpc: &'static str,
    /// Notification method, e.g. `notifications/initialized`.
    method: &'a str,
    /// Parameters; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
}
/// Incoming JSON-RPC 2.0 message as decoded by the HTTP, SSE and stdio paths.
#[derive(Debug, Deserialize)]
struct McpResponse {
    // Must be present for deserialization to succeed (no serde default), but it is never
    // read afterwards, hence the lint allowance.
    #[allow(dead_code)]
    jsonrpc: String,
    /// Reply id. The stdio reader and SSE parser skip messages where it is absent.
    /// NOTE(review): a string-typed id fails to deserialize into `u64`.
    id: Option<u64>,
    /// Success payload.
    result: Option<Value>,
    /// Error payload; every caller checks it before falling back to `result`.
    error: Option<McpResponseError>,
}
/// JSON-RPC error object carried in [`McpResponse::error`].
#[derive(Debug, Deserialize)]
struct McpResponseError {
    /// JSON-RPC error code (e.g. -32601 for "method not found").
    code: i64,
    /// Human-readable error message.
    message: String,
    // Optional extra payload. It is accepted so deserialization tolerates it, but it is not
    // surfaced anywhere, hence the lint allowance.
    #[serde(default)]
    #[allow(dead_code)]
    data: Option<Value>,
}
/// One tool entry from a `tools/list` result.
#[derive(Debug, Deserialize)]
struct McpToolDef {
    /// Server-local tool name.
    name: String,
    /// Tool description; empty when the server omits it.
    #[serde(default)]
    description: String,
    /// JSON Schema for the tool arguments; an empty object schema when omitted.
    #[serde(rename = "inputSchema", default = "default_input_schema")]
    input_schema: Value,
}
/// Serde default for a missing `inputSchema`: a permissive object schema that declares no
/// properties, i.e. `{"type": "object", "properties": {}}`.
fn default_input_schema() -> Value {
    let no_properties = json!({});
    json!({ "type": "object", "properties": no_properties })
}
/// One page of a `tools/list` result. Other members of the result object are ignored by
/// this struct (serde's default for unknown fields).
#[derive(Debug, Deserialize)]
struct McpToolsListResult {
    /// Tools advertised on this page; required.
    tools: Vec<McpToolDef>,
}
/// Result of a `tools/call` request.
#[derive(Debug, Deserialize)]
struct McpToolsCallResult {
    /// Content items returned by the tool; empty when omitted.
    #[serde(default)]
    content: Vec<McpContent>,
    /// Tool-level failure flag (distinct from a JSON-RPC error); `false` when omitted.
    #[serde(default, rename = "isError")]
    is_error: bool,
}
/// One content item of a `tools/call` result, internally tagged by its `type` member.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum McpContent {
    /// `{"type": "text", "text": ...}`.
    #[serde(rename = "text")]
    Text { text: String },
    /// `{"type": "image", "mimeType": ..., "data": ...}`; both members default to empty.
    /// `data` is treated as base64 by `McpClient::tools_call`.
    #[serde(rename = "image")]
    Image {
        #[serde(default, rename = "mimeType")]
        mime_type: String,
        #[serde(default)]
        data: String,
    },
    /// Any other `type` tag; its members are discarded.
    #[serde(other)]
    Other,
}
/// Client for a single MCP server over HTTP or stdio.
pub struct McpClient {
    /// Server name from the config; used in log fields.
    name: String,
    /// Per-request timeout. The HTTP client is built with it; the stdio path applies it to
    /// each reply wait.
    timeout: Duration,
    /// Source of JSON-RPC request ids; starts at 1.
    next_id: AtomicU64,
    /// Transport-specific state.
    inner: McpClientInner,
}
/// Transport-specific state of an [`McpClient`].
enum McpClientInner {
    /// HTTP POST transport.
    Http(HttpInner),
    /// Child-process transport.
    Stdio(StdioInner),
}
/// State for the HTTP transport.
struct HttpInner {
    /// Client built with the configured per-request timeout.
    http: reqwest::Client,
    /// Endpoint all messages are POSTed to.
    url: String,
    /// Extra headers added to every POST.
    headers: HashMap<String, String>,
    /// `Mcp-Session-Id` assigned by the server. The first non-empty value received is kept
    /// and sent on later requests and notifications.
    session_id: Arc<RwLock<Option<String>>>,
}
/// State for the stdio transport.
struct StdioInner {
    /// Queue feeding the stdin writer task.
    request_tx: tokio::sync::mpsc::Sender<StdioRequest>,
    /// Handle of the spawned server process. `Drop for McpClient` takes it and calls
    /// `start_kill` on the child.
    pending_kill: Option<Arc<std::sync::Mutex<Option<tokio::process::Child>>>>,
}
/// Message for the stdio writer task. `body` is one serialized JSON-RPC message without a
/// trailing newline; `write_stdio_line` appends it.
enum StdioRequest {
    /// A request whose reply is delivered on `reply` once the reader sees a message with `id`.
    Call {
        id: u64,
        body: String,
        reply: tokio::sync::oneshot::Sender<Result<Value, McpError>>,
    },
    /// A notification; nothing is awaited.
    Notify { body: String },
}
/// In-flight stdio calls keyed by JSON-RPC request id. The writer task inserts each entry
/// before writing its request. The reader task removes it to deliver the reply, or drains
/// the map when stdout closes.
type PendingReplies =
    Arc<std::sync::Mutex<HashMap<u64, tokio::sync::oneshot::Sender<Result<Value, McpError>>>>>;
impl McpClient {
pub fn new(config: McpServerConfig) -> Result<Self, McpError> {
let name = config.name.clone();
let timeout = config.timeout;
let inner = match config.transport {
McpTransport::Http { url, headers } => {
let http = reqwest::Client::builder()
.timeout(timeout)
.build()
.map_err(|e| McpError::Transport(e.to_string()))?;
McpClientInner::Http(HttpInner {
http,
url,
headers,
session_id: Arc::new(RwLock::new(None)),
})
}
McpTransport::Stdio {
command,
args,
env,
working_dir,
} => spawn_stdio(&name, command, args, env, working_dir)?,
};
Ok(Self {
name,
timeout,
next_id: AtomicU64::new(1),
inner,
})
}
pub fn name(&self) -> &str {
&self.name
}
pub async fn initialize(&self) -> Result<(), McpError> {
let params = json!({
"protocolVersion": MCP_PROTOCOL_VERSION,
"capabilities": {},
"clientInfo": {
"name": "agentmatrix-runtime-driver",
"version": env!("CARGO_PKG_VERSION"),
}
});
let _ = self.call("initialize", Some(params)).await?;
if let Err(e) = self.notify("notifications/initialized", None).await {
tracing::warn!(
target: "harness::mcp",
server = %self.name,
error = %e,
"notifications/initialized fire-and-forget failed; continuing"
);
}
Ok(())
}
pub async fn tools_list(&self) -> Result<Vec<ToolSpec>, McpError> {
let value = self.call("tools/list", None).await?;
let result: McpToolsListResult = serde_json::from_value(value)
.map_err(|e| McpError::Decode(format!("tools/list result: {e}")))?;
Ok(result
.tools
.into_iter()
.map(|t| ToolSpec {
name: t.name,
description: t.description,
input_schema: t.input_schema,
})
.collect())
}
pub async fn tools_call(&self, name: &str, arguments: Value) -> Result<ToolOutcome, McpError> {
let params = json!({
"name": name,
"arguments": arguments,
});
let value = self.call("tools/call", Some(params)).await?;
let result: McpToolsCallResult = serde_json::from_value(value)
.map_err(|e| McpError::Decode(format!("tools/call result: {e}")))?;
let mut text_parts: Vec<String> = Vec::new();
let mut attachments: Vec<UserAttachment> = Vec::new();
for c in result.content {
match c {
McpContent::Text { text } => text_parts.push(text),
McpContent::Image { mime_type, data } => {
if data.is_empty() {
text_parts.push(format!("[image {mime_type} returned with empty data]"));
} else {
attachments.push(UserAttachment::Image(ImageSource {
media_type: if mime_type.is_empty() {
"image/png".to_string()
} else {
mime_type
},
data: ImageData::Base64(data),
}));
}
}
McpContent::Other => {
text_parts.push("[non-text MCP content elided]".into());
}
}
}
let content_str = text_parts.join("\n");
let output = if result.is_error {
Err(ToolFailure::new(
ToolFailureKind::Runtime,
if content_str.is_empty() {
format!("MCP tool {name} reported error")
} else {
format!("MCP tool {name} error: {content_str}")
},
))
} else {
Ok(json!({"content": content_str}))
};
Ok(ToolOutcome {
output,
attachments,
})
}
async fn call(&self, method: &str, params: Option<Value>) -> Result<Value, McpError> {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let body = McpRequest {
jsonrpc: "2.0",
id,
method,
params,
};
match &self.inner {
McpClientInner::Http(http) => self.call_http(http, id, &body).await,
McpClientInner::Stdio(stdio) => self.call_stdio(stdio, id, &body).await,
}
}
async fn notify(&self, method: &str, params: Option<Value>) -> Result<(), McpError> {
let body = McpNotification {
jsonrpc: "2.0",
method,
params,
};
match &self.inner {
McpClientInner::Http(http) => self.notify_http(http, &body).await,
McpClientInner::Stdio(stdio) => self.notify_stdio(stdio, &body).await,
}
}
async fn call_http(
&self,
http: &HttpInner,
id: u64,
body: &McpRequest<'_>,
) -> Result<Value, McpError> {
let mut req = http
.http
.post(&http.url)
.header("Accept", "application/json, text/event-stream")
.json(body);
for (k, v) in &http.headers {
req = req.header(k.as_str(), v.as_str());
}
if let Some(sid) = cached_session_id(&http.session_id) {
req = req.header("Mcp-Session-Id", sid);
}
let resp = req.send().await.map_err(|e| {
if e.is_timeout() {
McpError::Timeout(e.to_string())
} else {
McpError::Transport(e.to_string())
}
})?;
if let Some(sid) = resp
.headers()
.get("mcp-session-id")
.and_then(|v| v.to_str().ok())
{
if !sid.is_empty() {
if let Ok(mut guard) = http.session_id.write() {
if guard.is_none() {
*guard = Some(sid.to_string());
}
}
}
}
let status = resp.status();
if !status.is_success() {
let body_text = resp.text().await.unwrap_or_default();
return Err(McpError::Http {
status: status.as_u16(),
body: body_text.chars().take(512).collect(),
});
}
let content_type = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_ascii_lowercase())
.unwrap_or_default();
if content_type.starts_with("text/event-stream") {
parse_mcp_sse_response(resp, id, &self.name).await
} else {
let body_text = resp.text().await.unwrap_or_default();
let parsed: McpResponse = serde_json::from_str(&body_text)
.map_err(|e| McpError::Decode(format!("response body: {e}; raw={body_text}")))?;
if let Some(err) = parsed.error {
return Err(McpError::Server {
code: err.code,
message: err.message,
});
}
parsed.result.ok_or(McpError::MissingField("result"))
}
}
async fn notify_http(
&self,
http: &HttpInner,
body: &McpNotification<'_>,
) -> Result<(), McpError> {
let mut req = http
.http
.post(&http.url)
.header("Accept", "application/json, text/event-stream")
.json(body);
for (k, v) in &http.headers {
req = req.header(k.as_str(), v.as_str());
}
if let Some(sid) = cached_session_id(&http.session_id) {
req = req.header("Mcp-Session-Id", sid);
}
req.send()
.await
.map_err(|e| McpError::Transport(e.to_string()))?;
Ok(())
}
async fn call_stdio(
&self,
stdio: &StdioInner,
id: u64,
body: &McpRequest<'_>,
) -> Result<Value, McpError> {
let line = serde_json::to_string(body)
.map_err(|e| McpError::Decode(format!("encode request: {e}")))?;
let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
stdio
.request_tx
.send(StdioRequest::Call {
id,
body: line,
reply: reply_tx,
})
.await
.map_err(|_| McpError::Transport("stdio worker gone".into()))?;
match tokio::time::timeout(self.timeout, reply_rx).await {
Ok(Ok(result)) => result,
Ok(Err(_)) => Err(McpError::Transport(
"stdio reply channel closed before response".into(),
)),
Err(_) => Err(McpError::Timeout(format!(
"stdio request timed out after {:?}",
self.timeout
))),
}
}
async fn notify_stdio(
&self,
stdio: &StdioInner,
body: &McpNotification<'_>,
) -> Result<(), McpError> {
let line = serde_json::to_string(body)
.map_err(|e| McpError::Decode(format!("encode notification: {e}")))?;
stdio
.request_tx
.send(StdioRequest::Notify { body: line })
.await
.map_err(|_| McpError::Transport("stdio worker gone".into()))?;
Ok(())
}
}
/// Returns a copy of the stored `Mcp-Session-Id`, or `None` if none has been assigned yet
/// or the lock is poisoned.
fn cached_session_id(slot: &Arc<RwLock<Option<String>>>) -> Option<String> {
    match slot.read() {
        Ok(guard) => guard.as_deref().map(str::to_owned),
        Err(_) => None,
    }
}
async fn parse_mcp_sse_response(
resp: reqwest::Response,
expected_id: u64,
server_name: &str,
) -> Result<Value, McpError> {
use eventsource_stream::Eventsource;
use futures::StreamExt;
let mut events = resp.bytes_stream().eventsource();
while let Some(ev) = events.next().await {
let ev = ev.map_err(|e| McpError::Transport(format!("SSE transport error: {e}")))?;
if !ev.event.is_empty() && ev.event != "message" {
tracing::debug!(
target: "harness::mcp",
server = %server_name,
event = %ev.event,
"ignoring non-message SSE event"
);
continue;
}
let trimmed = ev.data.trim();
if trimmed.is_empty() {
continue;
}
let parsed: McpResponse = match serde_json::from_str(trimmed) {
Ok(v) => v,
Err(e) => {
tracing::warn!(
target: "harness::mcp",
server = %server_name,
error = %e,
"SSE event body is not a JSON-RPC envelope; skipping"
);
continue;
}
};
let Some(rid) = parsed.id else {
tracing::debug!(
target: "harness::mcp",
server = %server_name,
"ignoring server-initiated notification mid-SSE stream"
);
continue;
};
if rid != expected_id {
tracing::debug!(
target: "harness::mcp",
server = %server_name,
rid,
expected_id,
"ignoring SSE response with mismatched id"
);
continue;
}
if let Some(err) = parsed.error {
return Err(McpError::Server {
code: err.code,
message: err.message,
});
}
return parsed.result.ok_or(McpError::MissingField("result"));
}
Err(McpError::Transport(format!(
"SSE stream closed without a JSON-RPC response matching id={expected_id}"
)))
}
/// Asks a stdio server process to terminate when its client goes away.
///
/// HTTP clients hold no process, so nothing happens for them. The kill is best-effort:
/// a poisoned lock or a failed `start_kill` is ignored.
impl Drop for McpClient {
    fn drop(&mut self) {
        let McpClientInner::Stdio(stdio) = &mut self.inner else {
            return;
        };
        let Some(child_slot) = stdio.pending_kill.take() else {
            return;
        };
        let Ok(mut guard) = child_slot.lock() else {
            return;
        };
        if let Some(mut child) = guard.take() {
            let _ = child.start_kill();
        }
    }
}
fn spawn_stdio(
name: &str,
command: String,
args: Vec<String>,
env: HashMap<String, String>,
working_dir: Option<String>,
) -> Result<McpClientInner, McpError> {
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
let mut cmd = Command::new(&command);
cmd.args(&args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.env_clear();
for (k, v) in &env {
cmd.env(k, v);
}
if !env.contains_key("PATH") {
if let Some(path) = std::env::var_os("PATH") {
cmd.env("PATH", path);
}
}
if let Some(dir) = working_dir.as_deref() {
cmd.current_dir(dir);
}
cmd.kill_on_drop(true);
let mut child = cmd
.spawn()
.map_err(|e| McpError::Transport(format!("stdio spawn {command:?}: {e}")))?;
let stdin = child
.stdin
.take()
.ok_or_else(|| McpError::Transport("stdio child has no stdin".into()))?;
let stdout = child
.stdout
.take()
.ok_or_else(|| McpError::Transport("stdio child has no stdout".into()))?;
let stderr = child.stderr.take();
let pending: PendingReplies = Arc::new(std::sync::Mutex::new(HashMap::new()));
let (request_tx, mut request_rx) = tokio::sync::mpsc::channel::<StdioRequest>(32);
{
let pending = pending.clone();
let server_name = name.to_string();
tokio::spawn(async move {
let mut stdin = stdin;
while let Some(req) = request_rx.recv().await {
match req {
StdioRequest::Call { id, body, reply } => {
if let Ok(mut guard) = pending.lock() {
guard.insert(id, reply);
}
if let Err(e) = write_stdio_line(&mut stdin, &body).await {
if let Ok(mut guard) = pending.lock() {
if let Some(slot) = guard.remove(&id) {
let _ = slot.send(Err(McpError::Transport(format!(
"stdio write failed: {e}"
))));
}
}
tracing::warn!(
target: "harness::mcp::stdio",
server = %server_name,
error = %e,
"stdio writer terminated"
);
break;
}
}
StdioRequest::Notify { body } => {
if let Err(e) = write_stdio_line(&mut stdin, &body).await {
tracing::warn!(
target: "harness::mcp::stdio",
server = %server_name,
error = %e,
"stdio writer terminated during notify"
);
break;
}
}
}
}
drop(stdin);
});
}
{
let pending = pending.clone();
let server_name = name.to_string();
tokio::spawn(async move {
let mut reader = BufReader::new(stdout);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => {
if let Ok(mut guard) = pending.lock() {
for (_, slot) in guard.drain() {
let _ = slot.send(Err(McpError::Transport(
"stdio server closed stdout".into(),
)));
}
}
break;
}
Ok(_) => {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let parsed: McpResponse = match serde_json::from_str(trimmed) {
Ok(v) => v,
Err(e) => {
tracing::warn!(
target: "harness::mcp::stdio",
server = %server_name,
error = %e,
line = %trimmed.chars().take(256).collect::<String>(),
"stdio reader could not parse JSON-RPC envelope"
);
continue;
}
};
let Some(id) = parsed.id else {
tracing::debug!(
target: "harness::mcp::stdio",
server = %server_name,
"ignoring server-initiated notification"
);
continue;
};
let slot = if let Ok(mut guard) = pending.lock() {
guard.remove(&id)
} else {
None
};
if let Some(slot) = slot {
let result = if let Some(err) = parsed.error {
Err(McpError::Server {
code: err.code,
message: err.message,
})
} else {
parsed.result.ok_or(McpError::MissingField("result"))
};
let _ = slot.send(result);
} else {
tracing::debug!(
target: "harness::mcp::stdio",
server = %server_name,
id,
"stdio reply for unknown id (timeout already fired?)"
);
}
}
Err(e) => {
tracing::warn!(
target: "harness::mcp::stdio",
server = %server_name,
error = %e,
"stdio reader I/O error"
);
break;
}
}
}
});
}
if let Some(stderr) = stderr {
let server_name = name.to_string();
tokio::spawn(async move {
let mut reader = BufReader::new(stderr);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) | Err(_) => break,
Ok(_) => {
let trimmed = line.trim_end();
if !trimmed.is_empty() {
tracing::debug!(
target: "harness::mcp::stdio",
server = %server_name,
stderr = %trimmed,
);
}
}
}
}
});
}
let child_slot = Arc::new(std::sync::Mutex::new(Some(child)));
Ok(McpClientInner::Stdio(StdioInner {
request_tx,
pending_kill: Some(child_slot),
}))
}
/// Writes `body` followed by a single `\n` to `stdin`, then flushes.
///
/// This is the newline-delimited framing used for stdio MCP. `body` must not contain a raw
/// newline, which holds for the compact JSON produced by `serde_json::to_string`.
async fn write_stdio_line<W: tokio::io::AsyncWrite + Unpin>(
    stdin: &mut W,
    body: &str,
) -> std::io::Result<()> {
    use tokio::io::AsyncWriteExt;
    for chunk in [body.as_bytes(), &b"\n"[..]] {
        stdin.write_all(chunk).await?;
    }
    stdin.flush().await
}
/// [`ToolRuntime`] exposing tools discovered from MCP servers under model-facing names of
/// the form `<server>__<tool>`.
///
/// Cloning only clones the shared `Arc`.
#[derive(Clone)]
pub struct McpToolRuntime {
    inner: Arc<McpToolRuntimeInner>,
}
/// Immutable state shared by all clones of an [`McpToolRuntime`].
struct McpToolRuntimeInner {
    /// Clients whose `initialize` and `tools/list` both succeeded during discovery.
    clients: Vec<McpClient>,
    /// Model-facing (prefixed) tool specs.
    specs: Vec<ToolSpec>,
    /// Prefixed tool name -> index into `clients`.
    tool_to_client: HashMap<String, usize>,
}
impl McpToolRuntime {
pub const NAME_SEPARATOR: &'static str = "__";
pub async fn discover(servers: Vec<McpServerConfig>) -> Self {
let mut clients: Vec<McpClient> = Vec::with_capacity(servers.len());
let mut specs: Vec<ToolSpec> = Vec::new();
let mut tool_to_client: HashMap<String, usize> = HashMap::new();
for config in servers {
let server_name = config.name.clone();
let required = config.required;
let enabled_tools: std::collections::HashSet<String> =
config.enabled_tools.iter().cloned().collect();
let client = match McpClient::new(config) {
Ok(c) => c,
Err(e) => {
tracing::warn!(
target: "harness::mcp",
server = %server_name,
error = %e,
"McpClient::new failed; skipping server"
);
continue;
}
};
if let Err(e) = client.initialize().await {
if required {
tracing::error!(
target: "harness::mcp",
server = %server_name,
error = %e,
"required MCP server failed to initialize; session boot will fail"
);
} else {
tracing::warn!(
target: "harness::mcp",
server = %server_name,
error = %e,
"MCP initialize failed; skipping server"
);
}
continue;
}
let server_specs = match client.tools_list().await {
Ok(s) => s,
Err(e) => {
tracing::warn!(
target: "harness::mcp",
server = %server_name,
error = %e,
"MCP tools/list failed; skipping server"
);
continue;
}
};
let client_idx = clients.len();
for mut spec in server_specs {
let original = spec.name.clone();
if !enabled_tools.is_empty() && !enabled_tools.contains(&original) {
continue;
}
spec.name = format!("{server_name}{}{original}", Self::NAME_SEPARATOR);
if tool_to_client.contains_key(&spec.name) {
tracing::warn!(
target: "harness::mcp",
tool = %spec.name,
"duplicate MCP tool name after prefixing; later registration wins"
);
}
tool_to_client.insert(spec.name.clone(), client_idx);
specs.push(spec);
}
clients.push(client);
}
Self {
inner: Arc::new(McpToolRuntimeInner {
clients,
specs,
tool_to_client,
}),
}
}
pub fn server_count(&self) -> usize {
self.inner.clients.len()
}
}
#[async_trait]
impl ToolRuntime for McpToolRuntime {
    /// Prefixed specs of every tool collected by [`McpToolRuntime::discover`].
    fn specs(&self) -> Vec<ToolSpec> {
        self.inner.specs.clone()
    }

    /// Routes a prefixed tool name to its server and calls the server-local tool.
    ///
    /// A tool reporting `isError` is not an error here: it comes back inside the
    /// `ToolOutcome` output.
    ///
    /// # Errors
    /// * `UnknownTool` if the name was not registered by discovery.
    /// * `Timeout` for MCP timeouts.
    /// * `Runtime` for any other MCP failure.
    async fn invoke(&self, invocation: ToolInvocation) -> Result<ToolOutcome, ToolRuntimeError> {
        let Some(&idx) = self.inner.tool_to_client.get(&invocation.name) else {
            return Err(ToolRuntimeError::UnknownTool(invocation.name));
        };
        let client = &self.inner.clients[idx];
        // Strip exactly "<server>__" using the owning server's name. Splitting at the first
        // separator would pick the wrong tool name when the server name itself contains it.
        let original_name = invocation
            .name
            .strip_prefix(client.name())
            .and_then(|rest| rest.strip_prefix(Self::NAME_SEPARATOR))
            .or_else(|| {
                invocation
                    .name
                    .split_once(Self::NAME_SEPARATOR)
                    .map(|(_, name)| name)
            })
            .unwrap_or(invocation.name.as_str())
            .to_string();
        client
            .tools_call(&original_name, invocation.input)
            .await
            .map_err(mcp_error_to_tool_runtime_error)
    }
}
/// Converts an MCP failure into the runtime's error type.
///
/// Timeouts stay timeouts; every other variant becomes a generic runtime error. The message
/// is always the MCP error's display text prefixed with `MCP: `.
fn mcp_error_to_tool_runtime_error(err: McpError) -> ToolRuntimeError {
    let message = format!("MCP: {err}");
    match err {
        McpError::Timeout(_) => ToolRuntimeError::Timeout(message),
        _ => ToolRuntimeError::Runtime(message),
    }
}
/// Combines two runtimes.
///
/// Calls go to `primary` first and fall back to `secondary` only when `primary` reports
/// `ToolRuntimeError::UnknownTool`.
#[derive(Clone)]
pub struct CompositeToolRuntime {
    /// Consulted first for every invocation; its specs are listed first.
    primary: Arc<dyn ToolRuntime>,
    /// Fallback for tool names the primary does not know.
    secondary: Arc<dyn ToolRuntime>,
}
impl CompositeToolRuntime {
    /// Creates a composite that prefers `primary` over `secondary`.
    pub fn new(primary: Arc<dyn ToolRuntime>, secondary: Arc<dyn ToolRuntime>) -> Self {
        Self { primary, secondary }
    }
}
#[async_trait]
impl ToolRuntime for CompositeToolRuntime {
    /// Primary specs followed by secondary specs; duplicates are not removed.
    fn specs(&self) -> Vec<ToolSpec> {
        self.primary
            .specs()
            .into_iter()
            .chain(self.secondary.specs())
            .collect()
    }

    /// Tries the primary runtime. Only an `UnknownTool` error falls through to the secondary;
    /// every other result, success or failure, is returned as is.
    async fn invoke(&self, invocation: ToolInvocation) -> Result<ToolOutcome, ToolRuntimeError> {
        let primary_result = self.primary.invoke(invocation.clone()).await;
        if matches!(primary_result, Err(ToolRuntimeError::UnknownTool(_))) {
            return self.secondary.invoke(invocation).await;
        }
        primary_result
    }

    /// Cancellable variant with the same fallback rule; `cancel` is passed to whichever
    /// runtime handles the call.
    async fn invoke_cancellable(
        &self,
        invocation: ToolInvocation,
        cancel: Option<&tokio_util::sync::CancellationToken>,
    ) -> Result<ToolOutcome, ToolRuntimeError> {
        let primary_result = self
            .primary
            .invoke_cancellable(invocation.clone(), cancel)
            .await;
        if matches!(primary_result, Err(ToolRuntimeError::UnknownTool(_))) {
            return self.secondary.invoke_cancellable(invocation, cancel).await;
        }
        primary_result
    }
}
/// Lets an `Arc<dyn ToolRuntime>` be used wherever a `ToolRuntime` is expected, by
/// forwarding each method below to the inner trait object.
///
/// NOTE: keep the explicit `(**self)` derefs. `Arc<dyn ToolRuntime>` itself implements
/// `ToolRuntime` through this impl, so a call that resolves against `self` or `*self`
/// instead of the inner object would recurse into this impl forever.
/// NOTE(review): if the trait has methods beyond these three, they run their default bodies
/// here rather than the inner runtime's overrides — confirm against the trait definition.
#[async_trait]
impl ToolRuntime for Arc<dyn ToolRuntime> {
    fn specs(&self) -> Vec<ToolSpec> {
        (**self).specs()
    }
    async fn invoke(&self, invocation: ToolInvocation) -> Result<ToolOutcome, ToolRuntimeError> {
        (**self).invoke(invocation).await
    }
    async fn invoke_cancellable(
        &self,
        invocation: ToolInvocation,
        cancel: Option<&tokio_util::sync::CancellationToken>,
    ) -> Result<ToolOutcome, ToolRuntimeError> {
        (**self).invoke_cancellable(invocation, cancel).await
    }
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
async fn spawn_mock_mcp_server(
scripted_responses: Vec<String>,
) -> (String, tokio::task::JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let url = format!("http://{addr}/mcp");
let handle = tokio::spawn(async move {
let mut remaining = scripted_responses.into_iter();
while let Some(response_body) = remaining.next() {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = Vec::with_capacity(2048);
let mut header_end = 0;
loop {
let mut tmp = [0u8; 1024];
let n = stream.read(&mut tmp).await.unwrap();
if n == 0 {
break;
}
buf.extend_from_slice(&tmp[..n]);
if let Some(pos) = find_header_end(&buf) {
header_end = pos + 4;
break;
}
}
let headers = std::str::from_utf8(&buf[..header_end.saturating_sub(4)])
.unwrap()
.to_lowercase();
let mut content_length = 0usize;
for line in headers.lines() {
if let Some(v) = line.strip_prefix("content-length:") {
content_length = v.trim().parse().unwrap_or(0);
}
}
let mut already_read = buf.len() - header_end;
while already_read < content_length {
let mut tmp = [0u8; 1024];
let n = stream.read(&mut tmp).await.unwrap();
if n == 0 {
break;
}
buf.extend_from_slice(&tmp[..n]);
already_read += n;
}
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
response_body.len(),
response_body
);
stream.write_all(response.as_bytes()).await.unwrap();
stream.flush().await.unwrap();
let _ = stream.shutdown().await;
}
});
(url, handle)
}
fn find_header_end(buf: &[u8]) -> Option<usize> {
buf.windows(4).position(|w| w == b"\r\n\r\n")
}
async fn spawn_mock_mcp_sse_server(
scripted_responses: Vec<Vec<String>>,
) -> (String, tokio::task::JoinHandle<()>) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let url = format!("http://{addr}/mcp");
let handle = tokio::spawn(async move {
let mut remaining = scripted_responses.into_iter();
while let Some(events) = remaining.next() {
let (mut stream, _) = listener.accept().await.unwrap();
let mut buf = Vec::with_capacity(2048);
let mut header_end = 0;
loop {
let mut tmp = [0u8; 1024];
let n = stream.read(&mut tmp).await.unwrap();
if n == 0 {
break;
}
buf.extend_from_slice(&tmp[..n]);
if let Some(pos) = find_header_end(&buf) {
header_end = pos + 4;
break;
}
}
let headers = std::str::from_utf8(&buf[..header_end.saturating_sub(4)])
.unwrap()
.to_lowercase();
let mut content_length = 0usize;
for line in headers.lines() {
if let Some(v) = line.strip_prefix("content-length:") {
content_length = v.trim().parse().unwrap_or(0);
}
}
let mut already_read = buf.len() - header_end;
while already_read < content_length {
let mut tmp = [0u8; 1024];
let n = stream.read(&mut tmp).await.unwrap();
if n == 0 {
break;
}
buf.extend_from_slice(&tmp[..n]);
already_read += n;
}
let body: String = events.concat();
let header_block = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Cache-Control: no-cache\r\n\
Connection: close\r\n\r\n";
stream.write_all(header_block.as_bytes()).await.unwrap();
stream.write_all(body.as_bytes()).await.unwrap();
stream.flush().await.unwrap();
let _ = stream.shutdown().await;
}
});
(url, handle)
}
fn sse_event(body: &str) -> String {
format!("event: message\ndata: {body}\n\n")
}
fn jsonrpc_result(id: u64, result: Value) -> String {
json!({"jsonrpc": "2.0", "id": id, "result": result}).to_string()
}
#[tokio::test]
async fn mcp_client_initializes_lists_and_calls_a_tool() {
let (url, _server) = spawn_mock_mcp_server(vec![
jsonrpc_result(1, json!({"protocolVersion": "2024-11-05", "capabilities": {}})),
json!({}).to_string(),
jsonrpc_result(
3,
json!({
"tools": [{
"name": "echo",
"description": "echo back input",
"inputSchema": {"type": "object", "properties": {"text": {"type": "string"}}}
}]
}),
),
jsonrpc_result(
4,
json!({
"content": [{"type": "text", "text": "hello back"}],
"isError": false
}),
),
])
.await;
let client = McpClient::new(McpServerConfig::new("fs", url)).unwrap();
client.initialize().await.expect("init");
let specs = client.tools_list().await.expect("list");
assert_eq!(specs.len(), 1);
assert_eq!(specs[0].name, "echo");
let outcome = client
.tools_call("echo", json!({"text": "hi"}))
.await
.expect("call");
let v = outcome.output.expect("ok output");
assert_eq!(v["content"], "hello back");
}
#[tokio::test]
async fn mcp_client_extracts_image_content_into_attachments() {
let (url, _server) = spawn_mock_mcp_server(vec![
jsonrpc_result(
1,
json!({"protocolVersion": "2024-11-05", "capabilities": {}}),
),
json!({}).to_string(),
jsonrpc_result(
3,
json!({
"content": [
{"type": "text", "text": "captured"},
{"type": "image", "mimeType": "image/png", "data": "PNGBYTES"}
],
"isError": false
}),
),
])
.await;
let client = McpClient::new(McpServerConfig::new("screen", url)).unwrap();
client.initialize().await.unwrap();
let outcome = client
.tools_call("screenshot", json!({}))
.await
.expect("call");
let v = outcome.output.expect("ok output");
assert_eq!(v["content"], "captured");
assert_eq!(outcome.attachments.len(), 1);
let UserAttachment::Image(src) = &outcome.attachments[0];
assert_eq!(src.media_type, "image/png");
match &src.data {
ImageData::Base64(b) => assert_eq!(b, "PNGBYTES"),
ImageData::Url(_) => panic!("expected base64, got url"),
}
}
#[tokio::test]
async fn mcp_client_surfaces_tool_error_as_tool_failure() {
let (url, _server) = spawn_mock_mcp_server(vec![
jsonrpc_result(
1,
json!({"protocolVersion": "2024-11-05", "capabilities": {}}),
),
json!({}).to_string(),
jsonrpc_result(
3,
json!({
"content": [{"type": "text", "text": "file not found"}],
"isError": true
}),
),
])
.await;
let client = McpClient::new(McpServerConfig::new("fs", url)).unwrap();
client.initialize().await.unwrap();
let outcome = client
.tools_call("read", json!({"path": "/none"}))
.await
.unwrap();
let failure = outcome.output.expect_err("expected ToolFailure");
assert_eq!(failure.kind, ToolFailureKind::Runtime);
assert!(failure.message.contains("file not found"));
}
#[tokio::test]
async fn mcp_client_surfaces_jsonrpc_error_as_mcp_server_error() {
let (url, _server) = spawn_mock_mcp_server(vec![
jsonrpc_result(
1,
json!({"protocolVersion": "2024-11-05", "capabilities": {}}),
),
json!({}).to_string(),
json!({
"jsonrpc": "2.0",
"id": 3,
"error": {"code": -32601, "message": "method not found"}
})
.to_string(),
])
.await;
let client = McpClient::new(McpServerConfig::new("fs", url)).unwrap();
client.initialize().await.unwrap();
let err = client.tools_list().await.unwrap_err();
match err {
McpError::Server { code, message } => {
assert_eq!(code, -32601);
assert!(message.contains("method not found"));
}
other => panic!("expected Server error, got {other:?}"),
}
}
#[tokio::test]
async fn mcp_tool_runtime_prefixes_tool_names_and_routes_calls() {
let (url, _server) = spawn_mock_mcp_server(vec![
jsonrpc_result(
1,
json!({"protocolVersion": "2024-11-05", "capabilities": {}}),
),
json!({}).to_string(),
jsonrpc_result(
3,
json!({
"tools": [{
"name": "echo",
"description": "echo",
"inputSchema": {"type": "object"}
}]
}),
),
jsonrpc_result(
4,
json!({"content": [{"type": "text", "text": "routed"}], "isError": false}),
),
])
.await;
let rt = McpToolRuntime::discover(vec![McpServerConfig::new("fs", url)]).await;
assert_eq!(rt.server_count(), 1);
let specs = rt.specs();
assert_eq!(specs.len(), 1);
assert_eq!(specs[0].name, "fs__echo");
let outcome = rt
.invoke(ToolInvocation {
id: "tc1".into(),
name: "fs__echo".into(),
input: json!({"text": "x"}),
})
.await
.unwrap();
assert_eq!(outcome.output.unwrap()["content"], "routed");
}
#[tokio::test]
async fn mcp_tool_runtime_unknown_tool_returns_runtime_error() {
let rt = McpToolRuntime::discover(vec![]).await;
let err = rt
.invoke(ToolInvocation {
id: "tc".into(),
name: "nope__whatever".into(),
input: json!({}),
})
.await
.unwrap_err();
assert!(matches!(err, ToolRuntimeError::UnknownTool(ref s) if s == "nope__whatever"));
}
#[derive(Clone, Default)]
struct FakeNativeRuntime {
names: Vec<&'static str>,
}
#[async_trait]
impl ToolRuntime for FakeNativeRuntime {
fn specs(&self) -> Vec<ToolSpec> {
self.names
.iter()
.map(|n| ToolSpec {
name: n.to_string(),
description: "fake".into(),
input_schema: json!({"type": "object"}),
})
.collect()
}
async fn invoke(&self, inv: ToolInvocation) -> Result<ToolOutcome, ToolRuntimeError> {
if self.names.contains(&inv.name.as_str()) {
Ok(ToolOutcome {
output: Ok(json!({"served_by": "native", "name": inv.name})),
attachments: vec![],
})
} else {
Err(ToolRuntimeError::UnknownTool(inv.name))
}
}
}
#[derive(Clone, Default)]
struct FakeMcpRuntime {
names: Vec<&'static str>,
}
#[async_trait]
impl ToolRuntime for FakeMcpRuntime {
fn specs(&self) -> Vec<ToolSpec> {
self.names
.iter()
.map(|n| ToolSpec {
name: n.to_string(),
description: "mcp".into(),
input_schema: json!({"type": "object"}),
})
.collect()
}
async fn invoke(&self, inv: ToolInvocation) -> Result<ToolOutcome, ToolRuntimeError> {
if self.names.contains(&inv.name.as_str()) {
Ok(ToolOutcome {
output: Ok(json!({"served_by": "mcp", "name": inv.name})),
attachments: vec![],
})
} else {
Err(ToolRuntimeError::UnknownTool(inv.name))
}
}
}
/// `CompositeToolRuntime` lists the primary's specs before the secondary's.
/// Each call is routed to whichever runtime owns the tool, and a name neither
/// runtime owns surfaces as `UnknownTool`.
#[tokio::test]
async fn composite_runtime_merges_specs_and_falls_back_to_secondary() {
    let primary: Arc<dyn ToolRuntime> = Arc::new(FakeNativeRuntime {
        names: vec!["bash", "read"],
    });
    let secondary: Arc<dyn ToolRuntime> = Arc::new(FakeMcpRuntime {
        names: vec!["fs__list", "git__diff"],
    });
    let composite = CompositeToolRuntime::new(primary, secondary);

    let advertised: Vec<String> = composite
        .specs()
        .into_iter()
        .map(|spec| spec.name)
        .collect();
    assert_eq!(advertised, vec!["bash", "read", "fs__list", "git__diff"]);

    // One tool from each side; the fakes tag their output with `served_by`.
    for (tool, expected_server) in [("bash", "native"), ("fs__list", "mcp")] {
        let routed = composite
            .invoke(ToolInvocation {
                id: "tc".into(),
                name: tool.into(),
                input: json!({}),
            })
            .await
            .unwrap();
        assert_eq!(routed.output.unwrap()["served_by"], expected_server);
    }

    let missing = composite
        .invoke(ToolInvocation {
            id: "tc".into(),
            name: "ghost".into(),
            input: json!({}),
        })
        .await
        .unwrap_err();
    assert!(matches!(missing, ToolRuntimeError::UnknownTool(_)));
}
/// Writes a bash script into `dir` that acts as a minimal line-oriented stdio
/// MCP server, marks it executable (`0o755`) and returns its path.
///
/// For each stdin line, the script extracts `method` and a numeric `id` with
/// `sed`. This is best-effort; the tests only send well-formed JSON. It then
/// prints one canned single-line JSON-RPC reply:
/// - `initialize`: handshake result with protocol `2024-11-05`
/// - `tools/list`: a single `echo` tool
/// - `tools/call`: text content `"stdio-routed"`
/// - `notifications/initialized`: no reply
/// - any other method: error `-32601`, but only if the message carries an
///   `id`. Id-less messages are notifications and must not be answered, and
///   replying would also emit invalid JSON with an empty `id`.
///
/// # Panics
/// Panics if the script cannot be written or its permissions cannot be set.
///
/// Unix-only: relies on `/usr/bin/env bash` and Unix permission bits.
#[cfg(unix)]
fn write_mock_stdio_server(dir: &std::path::Path) -> std::path::PathBuf {
    use std::os::unix::fs::PermissionsExt;

    let path = dir.join("mock-mcp-stdio.sh");
    let body = r#"#!/usr/bin/env bash
set -u
while IFS= read -r line; do
    # Pull out method + id (best-effort sed; the test driver only sends
    # well-formed JSON so we don't need a real parser).
    method=$(echo "$line" | sed -n 's/.*"method"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/p')
    id=$(echo "$line" | sed -n 's/.*"id"[[:space:]]*:[[:space:]]*\([0-9]*\).*/\1/p')
    case "$method" in
        initialize)
            printf '{"jsonrpc":"2.0","id":%s,"result":{"protocolVersion":"2024-11-05","capabilities":{}}}\n' "$id"
            ;;
        tools/list)
            printf '{"jsonrpc":"2.0","id":%s,"result":{"tools":[{"name":"echo","description":"d","inputSchema":{"type":"object"}}]}}\n' "$id"
            ;;
        tools/call)
            printf '{"jsonrpc":"2.0","id":%s,"result":{"content":[{"type":"text","text":"stdio-routed"}],"isError":false}}\n' "$id"
            ;;
        notifications/initialized)
            # No response for notifications.
            ;;
        *)
            # Messages without an id are notifications: JSON-RPC forbids a
            # reply, and an empty id would produce invalid JSON anyway.
            if [ -n "$id" ]; then
                printf '{"jsonrpc":"2.0","id":%s,"error":{"code":-32601,"message":"method not found"}}\n' "$id"
            fi
            ;;
    esac
done
"#;
    std::fs::write(&path, body).unwrap();
    let mut perms = std::fs::metadata(&path).unwrap().permissions();
    perms.set_mode(0o755);
    std::fs::set_permissions(&path, perms).unwrap();
    path
}

/// End-to-end stdio round trip against the bash mock from
/// `write_mock_stdio_server`. It performs the handshake, lists the single
/// `echo` tool, calls it, and checks that the output's `content` equals the
/// mock's text reply.
#[cfg(unix)]
#[tokio::test]
async fn mcp_stdio_initialize_lists_and_calls_a_tool() {
    /// Removes the scratch directory when dropped, so it is cleaned up even
    /// if an assertion below panics. Best-effort: removal errors are ignored.
    struct RemoveDirOnDrop(std::path::PathBuf);
    impl Drop for RemoveDirOnDrop {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    // The pid plus wall-clock nanos keep concurrent and repeated runs from
    // colliding.
    let tmp = std::env::temp_dir().join(format!(
        "rd-mock-mcp-stdio-{}-{}",
        std::process::id(),
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    ));
    std::fs::create_dir_all(&tmp).unwrap();
    // Declared before `client`, so it is dropped (and the directory removed)
    // after the client.
    let _cleanup = RemoveDirOnDrop(tmp.clone());
    let script = write_mock_stdio_server(&tmp);
    let config =
        McpServerConfig::stdio("local-fs", script.to_string_lossy().into_owned(), vec![])
            .with_timeout(Duration::from_secs(5));
    let client = McpClient::new(config).expect("spawn");
    client.initialize().await.expect("init over stdio");
    let specs = client.tools_list().await.expect("list over stdio");
    assert_eq!(specs.len(), 1);
    assert_eq!(specs[0].name, "echo");
    let outcome = client
        .tools_call("echo", json!({"text": "hi"}))
        .await
        .expect("call over stdio");
    let v = outcome.output.expect("ok output");
    assert_eq!(v["content"], "stdio-routed");
}
/// Pointing the stdio transport at a nonexistent executable must fail in
/// `McpClient::new` with an `McpError::Transport` that mentions `stdio spawn`.
#[tokio::test]
async fn mcp_stdio_returns_transport_error_when_command_missing() {
    let bogus_binary = "/definitely/not/a/real/binary-xyz";
    let config = McpServerConfig::stdio("nope", bogus_binary, Vec::new());
    // `.err()` avoids needing `McpClient: Debug`, which `expect_err` would.
    let failure = McpClient::new(config).err().expect("must fail to spawn");
    if let McpError::Transport(msg) = &failure {
        assert!(msg.contains("stdio spawn"), "got: {msg}");
    } else {
        panic!("expected Transport, got {failure:?}");
    }
}
/// `sleep 30` spawns successfully but never writes a reply. `initialize`
/// must therefore give up after the configured 250 ms and return
/// `McpError::Timeout` with a message containing `timed out`.
///
/// Unix-only: relies on a `sleep` executable being on `PATH`.
#[cfg(unix)]
#[tokio::test]
async fn mcp_stdio_call_times_out_when_server_doesnt_reply() {
    let config = McpServerConfig::stdio("silent", "sleep", vec!["30".to_string()])
        .with_timeout(Duration::from_millis(250));
    let client = McpClient::new(config).expect("spawn sleep");
    let err = client.initialize().await.expect_err("must time out");
    match err {
        McpError::Timeout(msg) => {
            assert!(msg.contains("timed out"), "got: {msg}");
        }
        other => panic!("expected Timeout, got {other:?}"),
    }
}
/// HTTP transport where every reply arrives as an SSE stream. The handshake,
/// the tool list and the tool call must each resolve to the JSON-RPC result
/// carried in the event for their request id.
#[tokio::test]
async fn mcp_http_sse_response_yields_jsonrpc_result() {
    let (url, _server) = spawn_mock_mcp_sse_server(vec![
        // id 1: `initialize`.
        vec![sse_event(&jsonrpc_result(
            1,
            json!({"protocolVersion": "2024-11-05", "capabilities": {}}),
        ))],
        // No events for the second request; presumably this is the id-less
        // `notifications/initialized` sent after the handshake.
        vec![],
        // id 2: `tools/list`.
        vec![sse_event(&jsonrpc_result(
            2,
            json!({
                "tools": [{
                    "name": "echo",
                    "description": "d",
                    "inputSchema": {"type": "object"}
                }]
            }),
        ))],
        // id 3: `tools/call`.
        vec![sse_event(&jsonrpc_result(
            3,
            json!({
                "content": [{"type": "text", "text": "sse-routed"}],
                "isError": false
            }),
        ))],
    ])
    .await;
    let client = McpClient::new(McpServerConfig::http("sse-fs", url)).expect("build client");
    client.initialize().await.expect("init over sse");
    let specs = client.tools_list().await.expect("list over sse");
    assert_eq!(specs.len(), 1);
    // Mirror the stdio test: the listed tool must be the one the mock advertised.
    assert_eq!(specs[0].name, "echo");
    let outcome = client
        .tools_call("echo", json!({}))
        .await
        .expect("call over sse");
    assert_eq!(outcome.output.unwrap()["content"], "sse-routed");
}
/// Before the real reply, the SSE stream for `initialize` contains a server
/// notification (no `id`) and a response for an unrelated id (999). The
/// client must skip both and accept the event whose id matches its request.
#[tokio::test]
async fn mcp_http_sse_response_skips_server_notifications() {
    let progress_note =
        r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{"percent":42}}"#;
    let stray_response = r#"{"jsonrpc":"2.0","id":999,"result":{"unrelated":true}}"#;
    let handshake = jsonrpc_result(
        1,
        json!({"protocolVersion": "2024-11-05", "capabilities": {}}),
    );
    let first_stream = vec![
        sse_event(progress_note),
        sse_event(stray_response),
        sse_event(&handshake),
    ];
    // The second exchange carries no events.
    let (url, _server) = spawn_mock_mcp_sse_server(vec![first_stream, vec![]]).await;

    let config = McpServerConfig::http("sse-noisy", url);
    let client = McpClient::new(config).expect("build client");
    let handshake_result = client.initialize().await;
    handshake_result.expect("must skip notifications and find id match");
}
/// A JSON-RPC `error` object delivered over SSE must surface as
/// `McpError::Server`, with the original `code` and `message` preserved.
#[tokio::test]
async fn mcp_http_sse_response_propagates_jsonrpc_error() {
    let error_frame = json!({
        "jsonrpc": "2.0",
        "id": 1,
        "error": {"code": -32601, "message": "method not found"}
    })
    .to_string();
    let (url, _server) = spawn_mock_mcp_sse_server(vec![vec![sse_event(&error_frame)]]).await;

    let client = McpClient::new(McpServerConfig::http("sse-err", url)).expect("build");
    match client.initialize().await.unwrap_err() {
        McpError::Server { code, message } => {
            assert_eq!(code, -32601);
            assert!(message.contains("method not found"));
        }
        other => panic!("expected Server, got {other:?}"),
    }
}
}