use std::collections::HashMap;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use motosan_agent_loop::mcp::{McpServer, McpServerHttp};
use motosan_agent_loop::{AgentError, Result};
use motosan_agent_tool::ToolDef;
use rmcp::model::{CallToolRequestParams, CallToolResult, RawContent};
use rmcp::service::{Peer, RoleClient, RunningService};
use rmcp::transport::child_process::TokioChildProcess;
use rmcp::ServiceExt;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use crate::mcp::config::{McpConfig, McpServerConfig};
/// An MCP server whose `connect` succeeded, paired with its name.
pub struct StartedServer {
    /// Server name; `connect_all` fills this from the key of the server's config entry.
    pub name: String,
    /// The connected server handle (stdio or HTTP), shared behind the `McpServer` trait.
    pub server: Arc<dyn McpServer>,
}
/// rmcp client session for a stdio server: client role with a unit (`()`) handler.
type StdioService = RunningService<RoleClient, ()>;
/// MCP server reached over the stdio of a spawned child process.
///
/// The child's stderr is mirrored line-by-line to `tracing` and to a per-server log file.
struct LoggedStdioServer {
    /// Name returned by `McpServer::name` and used to tag log output.
    name: String,
    /// Program to spawn.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Environment variables set on the child, on top of the inherited environment.
    env: HashMap<String, String>,
    /// File receiving the child's stderr (`<log_dir>/<sanitized name>.stderr.log`).
    log_path: PathBuf,
    /// Active rmcp session: `None` until `connect` succeeds, and again after `disconnect`.
    session: futures::lock::Mutex<Option<StdioService>>,
    /// Task teeing the child's stderr; set by `connect`, awaited (up to 1s) by `disconnect`.
    stderr_task: futures::lock::Mutex<Option<tokio::task::JoinHandle<()>>>,
}
impl LoggedStdioServer {
    /// Builds an unconnected stdio server.
    ///
    /// Its stderr log path is `<log_dir>/<sanitized name>.stderr.log`. Nothing is
    /// spawned here; the process starts in `connect`.
    fn new(
        name: String,
        command: String,
        args: Vec<String>,
        env: HashMap<String, String>,
        log_dir: &Path,
    ) -> Self {
        // Derive the path before `name` is moved into the struct.
        let file_name = format!("{}.stderr.log", sanitize_log_name(&name));
        let log_path = log_dir.join(file_name);
        Self {
            name,
            command,
            args,
            env,
            log_path,
            session: futures::lock::Mutex::new(None),
            stderr_task: futures::lock::Mutex::new(None),
        }
    }
}
#[async_trait::async_trait]
impl McpServer for LoggedStdioServer {
    fn name(&self) -> &str {
        &self.name
    }

    /// Spawns the child process, starts teeing its stderr, then runs the MCP handshake.
    ///
    /// The stderr tee is started before `serve`, so output written during
    /// initialization is captured too. On success the new session replaces any
    /// previous one; the previous session is dropped, not explicitly cancelled.
    ///
    /// # Errors
    /// Returns `AgentError::Mcp` if the process cannot be spawned or the
    /// handshake fails.
    async fn connect(&self) -> Result<()> {
        let mut cmd = tokio::process::Command::new(&self.command);
        cmd.args(&self.args).stderr(Stdio::piped());
        for (k, v) in &self.env {
            cmd.env(k, v);
        }
        let (transport, stderr) = TokioChildProcess::builder(cmd)
            .stderr(Stdio::piped())
            .spawn()
            .map_err(|e| AgentError::Mcp(e.to_string()))?;
        if let Some(stderr) = stderr {
            let task = spawn_stderr_tee(self.name.clone(), stderr, self.log_path.clone());
            let mut stderr_task = self.stderr_task.lock().await;
            *stderr_task = Some(task);
        }
        let service: StdioService = ().serve(transport).await.map_err(map_init_error)?;
        let mut session = self.session.lock().await;
        *session = Some(service);
        Ok(())
    }

    /// Lists all tools exposed by the server.
    ///
    /// # Errors
    /// `AgentError::Mcp("not connected")` without a session, or the mapped rmcp error.
    async fn list_tools(&self) -> Result<Vec<ToolDef>> {
        // Clone the peer and release the session lock *before* the RPC, so an
        // in-flight request does not serialize other calls or block `disconnect`.
        let peer = {
            let session = self.session.lock().await;
            peer_from_session(&session)?
        };
        let tools = peer.list_all_tools().await.map_err(map_service_error)?;
        Ok(tools.into_iter().map(rmcp_tool_to_tool_def).collect())
    }

    /// Calls tool `name` with `args` and returns its text content joined by newlines.
    ///
    /// # Errors
    /// `AgentError::Mcp("not connected")` without a session, or the mapped rmcp error.
    async fn call_tool(&self, name: &str, args: serde_json::Value) -> Result<String> {
        // As in `list_tools`: hold the lock only long enough to clone the peer,
        // so concurrent tool calls can proceed in parallel.
        let peer = {
            let session = self.session.lock().await;
            peer_from_session(&session)?
        };
        let params = build_call_params(name, args);
        let result = peer.call_tool(params).await.map_err(map_service_error)?;
        Ok(call_tool_result_to_string(&result))
    }

    /// Cancels the session (if any), then waits up to 1s for the stderr tee to finish.
    ///
    /// Best-effort: cancel and join results are ignored, and this always returns `Ok(())`.
    async fn disconnect(&self) -> Result<()> {
        let mut session = self.session.lock().await;
        if let Some(s) = session.take() {
            let _ = s.cancel().await;
        }
        drop(session);
        let mut stderr_task = self.stderr_task.lock().await;
        if let Some(task) = stderr_task.take() {
            let _ = tokio::time::timeout(Duration::from_secs(1), task).await;
        }
        Ok(())
    }
}
pub async fn connect_all(config: &McpConfig, log_dir: &Path) -> Vec<StartedServer> {
let mut futures = Vec::new();
for (name, server_cfg) in &config.servers {
let enabled = match server_cfg {
McpServerConfig::Stdio { enabled, .. } => *enabled,
McpServerConfig::Http { enabled, .. } => *enabled,
};
if !enabled {
continue;
}
let name = name.clone();
let cfg = server_cfg.clone();
let log_dir = log_dir.to_path_buf();
futures.push(tokio::spawn(async move {
try_connect(name, cfg, log_dir).await
}));
}
let mut started = Vec::new();
for f in futures {
if let Ok(Some(s)) = f.await {
started.push(s);
}
}
started
}
/// Builds the transport described by `cfg` and connects it within its startup timeout.
///
/// Returns `Some` on success. On a connect error, or when `startup_timeout_ms`
/// elapses first, the failure is reported via `tracing` and appended to the
/// server's log file under `log_dir`, and `None` is returned.
async fn try_connect(
    name: String,
    cfg: McpServerConfig,
    log_dir: PathBuf,
) -> Option<StartedServer> {
    let (server, timeout_ms) = match cfg {
        McpServerConfig::Stdio {
            command,
            args,
            env,
            startup_timeout_ms,
            ..
        } => {
            let stdio: Arc<dyn McpServer> = Arc::new(LoggedStdioServer::new(
                name.clone(),
                command,
                args,
                env,
                &log_dir,
            ));
            (stdio, startup_timeout_ms)
        }
        McpServerConfig::Http {
            url,
            headers,
            startup_timeout_ms,
            ..
        } => {
            let mut http = McpServerHttp::new(&url).with_name(name.clone());
            http.headers = headers;
            let http: Arc<dyn McpServer> = Arc::new(http);
            (http, startup_timeout_ms)
        }
    };

    let deadline = Duration::from_millis(timeout_ms);
    let outcome = tokio::time::timeout(deadline, server.connect()).await;

    let failure_line = match outcome {
        Ok(Ok(())) => return Some(StartedServer { name, server }),
        Ok(Err(e)) => {
            tracing::error!(target: "mcp", server = %name, "connect failed: {e}");
            format!("connect failed: {e}\n")
        }
        Err(_) => {
            tracing::error!(target: "mcp", server = %name, "startup_timeout_ms exceeded");
            String::from("startup_timeout_ms exceeded\n")
        }
    };
    write_lifecycle_log(&log_dir, &name, &failure_line).await;
    None
}
pub async fn disconnect_all(servers: &[StartedServer]) {
for s in servers {
let _ = tokio::time::timeout(Duration::from_secs(2), s.server.disconnect()).await;
}
}
pub fn into_pairs(started: Vec<StartedServer>) -> Vec<(String, Arc<dyn McpServer>)> {
started.into_iter().map(|s| (s.name, s.server)).collect()
}
/// Spawns a task that forwards a child's stderr line-by-line to `tracing` and a log file.
///
/// Every line is emitted as a `warn` event (target `mcp`, tagged with `name`)
/// and appended with a trailing newline to `log_path`. The parent directory is
/// created first if needed. File logging is best-effort: if the file cannot be
/// opened, or a write fails, a warning is logged and the task continues
/// forwarding lines to `tracing` only. The task ends on EOF or on a read error.
/// Before returning, it flushes the file so pending background writes complete
/// and their errors are observed.
fn spawn_stderr_tee(
    name: String,
    stderr: tokio::process::ChildStderr,
    log_path: PathBuf,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        if let Some(parent) = log_path.parent() {
            if let Err(e) = tokio::fs::create_dir_all(parent).await {
                tracing::warn!(target: "mcp", server = %name, path = %parent.display(), "create log dir failed: {e}");
            }
        }
        let mut file = match tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(&log_path)
            .await
        {
            Ok(file) => Some(file),
            Err(e) => {
                tracing::warn!(target: "mcp", server = %name, path = %log_path.display(), "open stderr log failed: {e}");
                None
            }
        };
        let mut lines = BufReader::new(stderr).lines();
        loop {
            match lines.next_line().await {
                Ok(Some(mut line)) => {
                    tracing::warn!(target: "mcp", server = %name, "{line}");
                    let write_result = match file.as_mut() {
                        Some(f) => {
                            // Single write for text + newline so the two cannot be split.
                            line.push('\n');
                            f.write_all(line.as_bytes()).await
                        }
                        None => Ok(()),
                    };
                    if let Err(e) = write_result {
                        // Stop retrying a broken file on every line; keep tracing output.
                        tracing::warn!(target: "mcp", server = %name, path = %log_path.display(), "write stderr log failed, disabling file logging: {e}");
                        file = None;
                    }
                }
                Ok(None) => break,
                Err(e) => {
                    tracing::warn!(target: "mcp", server = %name, "read stderr failed: {e}");
                    break;
                }
            }
        }
        // tokio::fs::File writes run in the background; flush so they complete
        // (and any error surfaces) before the file is dropped.
        if let Some(mut f) = file {
            if let Err(e) = f.flush().await {
                tracing::warn!(target: "mcp", server = %name, path = %log_path.display(), "flush stderr log failed: {e}");
            }
        }
    })
}
/// Appends `message` verbatim to `<log_dir>/<sanitized name>.stderr.log`.
///
/// This is the same file name `LoggedStdioServer` uses for the stderr tee.
/// Best-effort: if creating the directory, opening, writing, or flushing
/// fails, a `warn` event is logged and the function returns without error.
/// The file is flushed before being dropped so the write actually completes
/// (tokio file writes run in the background).
async fn write_lifecycle_log(log_dir: &Path, name: &str, message: &str) {
    let path = log_dir.join(format!("{}.stderr.log", sanitize_log_name(name)));
    if let Some(parent) = path.parent() {
        if let Err(e) = tokio::fs::create_dir_all(parent).await {
            tracing::warn!(target: "mcp", server = %name, path = %parent.display(), "create log dir failed: {e}");
            return;
        }
    }
    let mut file = match tokio::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&path)
        .await
    {
        Ok(file) => file,
        Err(e) => {
            tracing::warn!(target: "mcp", server = %name, path = %path.display(), "write lifecycle log failed: {e}");
            return;
        }
    };
    if let Err(e) = file.write_all(message.as_bytes()).await {
        tracing::warn!(target: "mcp", server = %name, path = %path.display(), "write lifecycle log failed: {e}");
        return;
    }
    if let Err(e) = file.flush().await {
        tracing::warn!(target: "mcp", server = %name, path = %path.display(), "flush lifecycle log failed: {e}");
    }
}
/// Makes `name` safe to embed in a file name.
///
/// ASCII letters, ASCII digits, `-`, `_` and `.` are kept; every other char
/// (including path separators and non-ASCII characters) becomes a single `_`.
fn sanitize_log_name(name: &str) -> String {
    let mut safe = String::with_capacity(name.len());
    for ch in name.chars() {
        let allowed = ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.';
        safe.push(if allowed { ch } else { '_' });
    }
    safe
}
/// Returns a clone of the session's client peer.
///
/// # Errors
/// `AgentError::Mcp("not connected")` when `session` is `None`.
fn peer_from_session(
    session: &Option<impl Deref<Target = Peer<RoleClient>>>,
) -> Result<Peer<RoleClient>> {
    match session {
        Some(service) => Ok((**service).clone()),
        None => Err(AgentError::Mcp("not connected".into())),
    }
}
/// Converts an rmcp tool description into the agent loop's `ToolDef`.
///
/// A missing description becomes an empty string; the input schema object is
/// copied into a `serde_json::Value::Object`.
fn rmcp_tool_to_tool_def(tool: rmcp::model::Tool) -> ToolDef {
    let description = match tool.description {
        Some(text) => text.into_owned(),
        None => String::new(),
    };
    let input_schema = serde_json::Value::Object((*tool.input_schema).clone());
    ToolDef {
        name: tool.name.into_owned(),
        description,
        input_schema,
    }
}
fn build_call_params(name: &str, args: serde_json::Value) -> CallToolRequestParams {
let arguments = match args {
serde_json::Value::Object(map) => Some(map),
_ => None,
};
let mut params = CallToolRequestParams::new(name.to_string());
params.arguments = arguments;
params
}
/// Joins the text parts of a tool result with `\n`; non-text content is skipped.
fn call_tool_result_to_string(result: &CallToolResult) -> String {
    let mut texts = Vec::new();
    for item in result.content.iter() {
        if let RawContent::Text(part) = &item.raw {
            texts.push(part.text.clone());
        }
    }
    texts.join("\n")
}
/// Wraps an rmcp service error's display text in `AgentError::Mcp`.
fn map_service_error(e: rmcp::service::ServiceError) -> AgentError {
    let message = format!("{e}");
    AgentError::Mcp(message)
}
/// Wraps an rmcp client-initialization error's display text in `AgentError::Mcp`.
fn map_init_error(e: rmcp::service::ClientInitializeError) -> AgentError {
    let message = format!("{e}");
    AgentError::Mcp(message)
}