mod client;
use std::path::PathBuf;
use std::sync::OnceLock;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use anyhow::{Result, anyhow, bail};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use tokio::io::{BufReader, BufWriter};
use tokio::process::{Child, ChildStdin, ChildStdout};
use tracing::{debug, info, warn};
/// Process-wide registry of plugin log paths whose stderr redirect has already
/// been announced.
///
/// The set is created lazily on first access; every call returns a reference to
/// the same `Mutex`.
fn announced_plugins() -> &'static std::sync::Mutex<std::collections::HashSet<PathBuf>> {
    type Announced = std::sync::Mutex<std::collections::HashSet<PathBuf>>;
    static ANNOUNCED: OnceLock<Announced> = OnceLock::new();
    ANNOUNCED.get_or_init(|| Announced::new(Default::default()))
}
/// Derive a short name for a plugin from its binary path.
///
/// Uses the file stem of `binary`: the final path component without its last
/// extension. Falls back to `"plugin"` when there is no stem (empty path, `/`,
/// `..`) or the stem is not valid UTF-8.
fn plugin_slug(binary: &str) -> String {
    let stem = std::path::Path::new(binary).file_stem();
    match stem.and_then(std::ffi::OsStr::to_str) {
        Some(name) => name.to_owned(),
        None => String::from("plugin"),
    }
}
/// Resolve the file that receives a plugin's stderr.
///
/// With a non-empty `HOME`, logs go to `$HOME/.trusty-agents/logs/<slug>-stderr.log`.
/// Otherwise they go to `/tmp/trusty-agents-<slug>.log`. The longer file name
/// keeps them identifiable in the shared temp directory.
///
/// The log directory is created on a best-effort basis. A failure here is
/// ignored, and any real problem shows up when the file is opened.
fn plugin_log_path(slug: &str) -> PathBuf {
    // Read HOME exactly once so the directory and file-name choice cannot
    // disagree. An empty HOME is treated as unset: joining onto "" would
    // produce a path relative to the current working directory.
    let (dir, file_name) = match std::env::var_os("HOME").filter(|home| !home.is_empty()) {
        Some(home) => (
            PathBuf::from(home).join(".trusty-agents").join("logs"),
            format!("{slug}-stderr.log"),
        ),
        None => (PathBuf::from("/tmp"), format!("trusty-agents-{slug}.log")),
    };
    let _ = std::fs::create_dir_all(&dir);
    dir.join(file_name)
}
pub fn plugin_stderr_stdio(binary: &str) -> std::process::Stdio {
let slug = plugin_slug(binary);
let path = plugin_log_path(&slug);
let was_new = std::fs::metadata(&path).is_err();
if let Ok(mut set) = announced_plugins().lock()
&& !set.contains(&path)
{
set.insert(path.clone());
if was_new {
info!(
plugin = %slug,
log_path = %path.display(),
"(created) plugin stderr redirected to log file"
);
} else {
debug!(
plugin = %slug,
log_path = %path.display(),
"plugin stderr redirected to log file"
);
}
}
match std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(&path)
{
Ok(file) => std::process::Stdio::from(file),
Err(e) => {
warn!(
plugin = %slug,
log_path = %path.display(),
error = %e,
"StdioMcpClient: failed to open plugin log; suppressing stderr"
);
std::process::Stdio::null()
}
}
}
/// MCP protocol revision this client sends as `protocolVersion` in its
/// `initialize` request.
pub const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
/// Timeout for client calls: 30 seconds (30 000 ms).
/// NOTE(review): presumably applied per request by the client; confirm where it is enforced.
pub(super) const CALL_TIMEOUT: Duration = Duration::from_millis(30_000);
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerInfo {
pub name: String,
pub version: String,
pub protocol_version: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpTool {
pub name: String,
pub description: Option<String>,
pub input_schema: Value,
}
/// Client for an MCP server running as a child process, communicating over the
/// child's stdin and stdout.
///
/// Dropping the client requests that the child process be killed (see the
/// `Drop` impl).
// NOTE: fields drop in declaration order, so keep the field order stable.
pub struct StdioMcpClient {
    /// Handle to the spawned server process.
    pub(super) child: Child,
    /// Buffered writer to the child's stdin.
    pub(super) stdin: BufWriter<ChildStdin>,
    /// Buffered reader over the child's stdout, from which response frames are read.
    pub(super) stdout: BufReader<ChildStdout>,
    /// Counter for JSON-RPC request ids. Presumably backs `alloc_id`; the tests
    /// expect allocated ids to increase monotonically.
    pub(super) next_id: AtomicU64,
    /// Program the child was spawned from. Presumably reused to respawn the
    /// server (the respawn test overrides it); confirm in the `client` module.
    pub(super) binary: String,
    /// Arguments passed to `binary` when spawning.
    pub(super) args: Vec<String>,
    /// Client name. Presumably reported as `clientInfo.name` in `initialize`; TODO confirm.
    pub(super) client_name: String,
}
impl Drop for StdioMcpClient {
    /// Best-effort termination of the server process when the client goes away.
    ///
    /// `Drop` cannot report failures, so an error from the kill request is
    /// discarded. The child is only signalled here, never awaited.
    fn drop(&mut self) {
        self.child.start_kill().ok();
    }
}
/// Build the JSON-RPC 2.0 `initialize` request that opens an MCP session.
///
/// * `id` – JSON-RPC request id to embed.
/// * `client_name` – reported as `params.clientInfo.name`; the crate version
///   (`CARGO_PKG_VERSION`) is reported as `params.clientInfo.version`.
///
/// The request advertises `MCP_PROTOCOL_VERSION` and an empty capabilities object.
pub(super) fn build_initialize_request(id: u64, client_name: &str) -> Value {
    let client_info = json!({
        "name": client_name,
        "version": env!("CARGO_PKG_VERSION"),
    });
    let params = json!({
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "capabilities": {},
        "clientInfo": client_info,
    });
    json!({
        "jsonrpc": "2.0",
        "id": id,
        "method": "initialize",
        "params": params,
    })
}
pub(super) fn extract_result(resp: Value) -> Result<Value> {
if let Some(err) = resp.get("error") {
let code = err.get("code").and_then(|v| v.as_i64()).unwrap_or(0);
let message = err
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown error");
bail!("JSON-RPC error {code}: {message}");
}
resp.get("result")
.cloned()
.ok_or_else(|| anyhow!("JSON-RPC response missing result"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The `initialize` envelope carries the JSON-RPC header, protocol version,
    /// caller-supplied client name and an empty capabilities object.
    #[test]
    fn initialize_envelope_is_well_formed() {
        let req = build_initialize_request(7, "trusty-agents");
        assert_eq!(req["jsonrpc"], "2.0");
        assert_eq!(req["id"], 7);
        assert_eq!(req["method"], "initialize");
        assert_eq!(req["params"]["protocolVersion"], MCP_PROTOCOL_VERSION);
        assert_eq!(req["params"]["clientInfo"]["name"], "trusty-agents");
        assert!(req["params"]["capabilities"].is_object());
        let req2 = build_initialize_request(42, "trusty-console");
        assert_eq!(req2["params"]["clientInfo"]["name"], "trusty-console");
    }

    /// A JSON-RPC error object is surfaced with both its code and message.
    #[test]
    fn extract_result_maps_error_object() {
        let resp = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "error": { "code": -32601, "message": "method not found" }
        });
        let err = extract_result(resp).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("-32601"), "missing code: {msg}");
        assert!(msg.contains("method not found"), "missing message: {msg}");
    }

    #[test]
    fn extract_result_returns_inner_result() {
        let resp = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": { "tools": [] }
        });
        let v = extract_result(resp).unwrap();
        assert_eq!(v, json!({ "tools": [] }));
    }

    /// A present-but-null `result` is still a successful response.
    #[test]
    fn extract_result_passes_through_null_result() {
        let resp = json!({ "jsonrpc": "2.0", "id": 1, "result": null });
        assert_eq!(extract_result(resp).unwrap(), Value::Null);
    }

    #[test]
    fn extract_result_errors_when_missing_result() {
        let resp = json!({ "jsonrpc": "2.0", "id": 1 });
        assert!(extract_result(resp).is_err());
    }

    #[tokio::test]
    async fn spawn_missing_binary_errors() {
        let r = StdioMcpClient::spawn("/nonexistent/mcp/binary/xyzzy", &[], "test-client").await;
        assert!(r.is_err());
    }

    /// Polls for up to ~1s for the child to exit before asserting.
    #[tokio::test]
    #[cfg(unix)]
    async fn is_alive_returns_false_after_child_exits() {
        let mut client = StdioMcpClient::spawn("sh", &["-c", "exit 0"], "test-client")
            .await
            .unwrap();
        for _ in 0..50 {
            if !client.is_alive() {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }
        assert!(
            !client.is_alive(),
            "child should be reported dead after `exit 0`"
        );
    }

    /// With a dead child and an unspawnable respawn target, `call_tool` must
    /// fail quickly rather than waiting for the full call timeout.
    #[tokio::test]
    #[cfg(unix)]
    async fn call_tool_errors_when_respawn_unavailable() {
        let mut client = StdioMcpClient::spawn("sh", &["-c", "exit 0"], "test-client")
            .await
            .unwrap();
        client.binary = "/nonexistent/mcp/binary/xyzzy-respawn".to_string();
        client.args.clear();
        for _ in 0..50 {
            if !client.is_alive() {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }
        let start = std::time::Instant::now();
        let r = client.call_tool("anything", json!({})).await;
        let elapsed = start.elapsed();
        assert!(r.is_err(), "call_tool should error when respawn fails");
        assert!(
            elapsed < std::time::Duration::from_secs(5),
            "call_tool should fail fast (<5s), took {elapsed:?}"
        );
    }

    /// A banner line printed before the first JSON frame must be skipped.
    #[tokio::test]
    #[cfg(unix)]
    async fn read_line_skips_non_json_prefix_lines() {
        let script = r#"printf 'trusty-memory v0.1.14 — HTTP admin panel: http://127.0.0.1:9999\n{"jsonrpc":"2.0","id":1,"result":{"ok":true}}\n'; sleep 1"#;
        let mut client = StdioMcpClient::spawn("sh", &["-c", script], "test-client")
            .await
            .unwrap();
        let frame = client.read_line().await.unwrap();
        assert_eq!(frame["jsonrpc"], "2.0");
        assert_eq!(frame["id"], 1);
        assert_eq!(frame["result"]["ok"], true);
    }

    #[tokio::test]
    #[cfg(unix)]
    async fn ids_are_monotonic() {
        let client = StdioMcpClient::spawn("cat", &[], "test-client")
            .await
            .unwrap();
        let a = client.alloc_id();
        let b = client.alloc_id();
        let c = client.alloc_id();
        assert!(a < b && b < c);
    }

    #[test]
    fn plugin_slug_extracts_stem() {
        assert_eq!(plugin_slug("/usr/bin/trusty-search"), "trusty-search");
        assert_eq!(plugin_slug("trusty-memory"), "trusty-memory");
        // The final extension is stripped.
        assert_eq!(plugin_slug("/opt/plugins/trusty-search.exe"), "trusty-search");
        // Paths without a usable stem fall back to the generic name.
        assert_eq!(plugin_slug(""), "plugin");
        assert_eq!(plugin_slug("/"), "plugin");
    }
}