capo-agent 0.7.0

Coding-agent library built on motosan-agent-loop. Composable, embeddable.
Documentation
//! MCP server lifecycle: parallel `connect()` with per-server timeout,
//! log capture on failure, ordered `disconnect()` on shutdown.

use std::collections::HashMap;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;

use motosan_agent_loop::mcp::{McpServer, McpServerHttp};
use motosan_agent_loop::{AgentError, Result};
use motosan_agent_tool::ToolDef;
use rmcp::model::{CallToolRequestParams, CallToolResult, RawContent};
use rmcp::service::{Peer, RoleClient, RunningService};
use rmcp::transport::child_process::TokioChildProcess;
use rmcp::ServiceExt;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

use crate::mcp::config::{McpConfig, McpServerConfig};

/// rmcp client session used for stdio servers, with `()` as the client handler.
type StdioService = RunningService<RoleClient, ()>;

/// A server whose `connect()` succeeded, paired with its configured name.
pub struct StartedServer {
    /// Key of this server in the MCP config's `servers` map.
    pub name: String,
    /// The connected server, ready for `list_tools` / `call_tool`.
    pub server: Arc<dyn McpServer>,
}

/// Stdio-transport MCP server whose child stderr is copied into
/// `<log_dir>/<sanitized server name>.stderr.log` (and echoed to `tracing`).
///
/// Nothing is spawned until `connect`; after a successful connect, `session`
/// holds the running rmcp client and `stderr_task` the stderr-draining task.
struct LoggedStdioServer {
    /// Configured server name; also the basis of the log file name.
    name: String,
    /// Executable to spawn.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Extra environment variables set on the child process.
    env: HashMap<String, String>,
    /// File the child's stderr is appended to.
    log_path: PathBuf,
    session: futures::lock::Mutex<Option<StdioService>>,
    stderr_task: futures::lock::Mutex<Option<tokio::task::JoinHandle<()>>>,
}

impl LoggedStdioServer {
    /// Describe a not-yet-connected stdio server; `log_dir` only determines
    /// where the stderr log will live.
    fn new(
        name: String,
        command: String,
        args: Vec<String>,
        env: HashMap<String, String>,
        log_dir: &Path,
    ) -> Self {
        let file_name = format!("{}.stderr.log", sanitize_log_name(&name));
        let log_path = log_dir.join(file_name);
        Self {
            name,
            command,
            args,
            env,
            log_path,
            session: futures::lock::Mutex::new(None),
            stderr_task: futures::lock::Mutex::new(None),
        }
    }
}

#[async_trait::async_trait]
impl McpServer for LoggedStdioServer {
    fn name(&self) -> &str {
        &self.name
    }

    /// Spawn the child process, start tee-ing its stderr into `log_path`,
    /// then run the MCP initialize handshake over the child's stdio.
    ///
    /// The stderr tee starts *before* the handshake so output from a server
    /// that dies during initialization still reaches the log.
    ///
    /// # Errors
    /// `AgentError::Mcp` if the process cannot be spawned or the handshake fails.
    async fn connect(&self) -> Result<()> {
        let mut cmd = tokio::process::Command::new(&self.command);
        cmd.args(&self.args);
        cmd.envs(&self.env);

        // stderr piping is configured on the builder, which applies the final
        // stdio setup when spawning; setting it on `cmd` as well was redundant.
        let (transport, stderr) = TokioChildProcess::builder(cmd)
            .stderr(Stdio::piped())
            .spawn()
            .map_err(|e| AgentError::Mcp(e.to_string()))?;
        if let Some(stderr) = stderr {
            let task = spawn_stderr_tee(self.name.clone(), stderr, self.log_path.clone());
            *self.stderr_task.lock().await = Some(task);
        }

        let service: StdioService = ().serve(transport).await.map_err(map_init_error)?;
        *self.session.lock().await = Some(service);
        Ok(())
    }

    /// List every tool the server exposes (all pages).
    ///
    /// # Errors
    /// `AgentError::Mcp("not connected")` before a successful `connect`, or
    /// the mapped service error if the request fails.
    async fn list_tools(&self) -> Result<Vec<ToolDef>> {
        // Take a peer handle and release the lock *before* the RPC: holding
        // the guard across `.await` would serialize every request to this
        // server and make `disconnect` wait behind slow calls.
        let peer = {
            let session = self.session.lock().await;
            peer_from_session(&session)?
        };
        let tools = peer.list_all_tools().await.map_err(map_service_error)?;
        Ok(tools.into_iter().map(rmcp_tool_to_tool_def).collect())
    }

    /// Invoke tool `name` with `args` and return its text content joined by `\n`.
    ///
    /// Non-object `args` are sent as "no arguments".
    ///
    /// # Errors
    /// `AgentError::Mcp("not connected")` before a successful `connect`, or
    /// the mapped service error if the request fails.
    async fn call_tool(&self, name: &str, args: serde_json::Value) -> Result<String> {
        // Same lock discipline as `list_tools`: don't hold the session lock
        // while the tool runs, so concurrent calls can proceed.
        let peer = {
            let session = self.session.lock().await;
            peer_from_session(&session)?
        };
        let params = build_call_params(name, args);
        let result = peer.call_tool(params).await.map_err(map_service_error)?;
        Ok(call_tool_result_to_string(&result))
    }

    /// Cancel the session (if connected), then wait up to 1 s for the stderr
    /// tee task to finish. Best-effort: always returns `Ok(())`.
    async fn disconnect(&self) -> Result<()> {
        let mut session = self.session.lock().await;
        if let Some(s) = session.take() {
            let _ = s.cancel().await;
        }
        drop(session);

        let mut stderr_task = self.stderr_task.lock().await;
        if let Some(task) = stderr_task.take() {
            let _ = tokio::time::timeout(Duration::from_secs(1), task).await;
        }
        Ok(())
    }
}

/// Build & connect every enabled server in parallel. Failed/timed-out
/// servers are skipped (logged via `tracing`). Returns only the
/// successfully-connected servers, in arbitrary order.
pub async fn connect_all(config: &McpConfig, log_dir: &Path) -> Vec<StartedServer> {
    let mut futures = Vec::new();
    for (name, server_cfg) in &config.servers {
        let enabled = match server_cfg {
            McpServerConfig::Stdio { enabled, .. } => *enabled,
            McpServerConfig::Http { enabled, .. } => *enabled,
        };
        if !enabled {
            continue;
        }
        let name = name.clone();
        let cfg = server_cfg.clone();
        let log_dir = log_dir.to_path_buf();
        futures.push(tokio::spawn(async move {
            try_connect(name, cfg, log_dir).await
        }));
    }
    let mut started = Vec::new();
    for f in futures {
        if let Ok(Some(s)) = f.await {
            started.push(s);
        }
    }
    started
}

async fn try_connect(
    name: String,
    cfg: McpServerConfig,
    log_dir: PathBuf,
) -> Option<StartedServer> {
    let timeout = match &cfg {
        McpServerConfig::Stdio {
            startup_timeout_ms, ..
        } => *startup_timeout_ms,
        McpServerConfig::Http {
            startup_timeout_ms, ..
        } => *startup_timeout_ms,
    };
    let server: Arc<dyn McpServer> = match cfg {
        McpServerConfig::Stdio {
            command, args, env, ..
        } => Arc::new(LoggedStdioServer::new(
            name.clone(),
            command,
            args,
            env,
            &log_dir,
        )),
        McpServerConfig::Http { url, headers, .. } => {
            let mut s = McpServerHttp::new(&url).with_name(name.clone());
            s.headers = headers;
            Arc::new(s)
        }
    };
    let connect_result =
        tokio::time::timeout(Duration::from_millis(timeout), server.connect()).await;
    match connect_result {
        Ok(Ok(())) => Some(StartedServer { name, server }),
        Ok(Err(e)) => {
            tracing::error!(target: "mcp", server = %name, "connect failed: {e}");
            write_lifecycle_log(&log_dir, &name, &format!("connect failed: {e}\n")).await;
            None
        }
        Err(_) => {
            tracing::error!(target: "mcp", server = %name, "startup_timeout_ms exceeded");
            write_lifecycle_log(&log_dir, &name, "startup_timeout_ms exceeded\n").await;
            None
        }
    }
}

/// Disconnect every server with a 2-second per-server timeout. Errors
/// are logged but never raised — shutdown is best-effort.
pub async fn disconnect_all(servers: &[StartedServer]) {
    for s in servers {
        let _ = tokio::time::timeout(Duration::from_secs(2), s.server.disconnect()).await;
    }
}

/// Convenience: convert started servers into `(name, Arc<dyn McpServer>)`
/// pairs for storage on `App`.
pub fn into_pairs(started: Vec<StartedServer>) -> Vec<(String, Arc<dyn McpServer>)> {
    started.into_iter().map(|s| (s.name, s.server)).collect()
}

/// Spawn a task that drains `stderr` until EOF. Each line is echoed to
/// `tracing` (target `mcp`, WARN) and appended to `log_path`.
///
/// Lines are read as raw bytes instead of with `AsyncBufReadExt::lines`. A
/// single non-UTF-8 byte makes `lines()` return an error; that used to end the
/// loop and stop draining the pipe, so the child could block forever writing
/// to stderr once the pipe buffer filled. Invalid UTF-8 is now shown lossily in
/// `tracing` and written verbatim to the file.
///
/// Log-file problems never stop the drain: if the file can't be opened, or a
/// write fails (warned once), lines keep going to `tracing` only. The file is
/// flushed before the task ends, so awaiting the returned handle means the log
/// has been written.
fn spawn_stderr_tee(
    name: String,
    stderr: tokio::process::ChildStderr,
    log_path: PathBuf,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        if let Some(parent) = log_path.parent() {
            if let Err(e) = tokio::fs::create_dir_all(parent).await {
                tracing::warn!(target: "mcp", server = %name, path = %parent.display(), "create log dir failed: {e}");
            }
        }
        let mut file = match tokio::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(&log_path)
            .await
        {
            Ok(file) => Some(file),
            Err(e) => {
                tracing::warn!(target: "mcp", server = %name, path = %log_path.display(), "open stderr log failed: {e}");
                None
            }
        };

        let mut reader = BufReader::new(stderr);
        // One buffer reused for every line.
        let mut buf = Vec::new();
        loop {
            buf.clear();
            match reader.read_until(b'\n', &mut buf).await {
                Ok(0) => break,
                Ok(_) => {
                    strip_line_ending(&mut buf);
                    tracing::warn!(target: "mcp", server = %name, "{}", String::from_utf8_lossy(&buf));

                    let write_error = match file.as_mut() {
                        Some(f) => {
                            // Line + newline in a single append so another
                            // append-mode writer cannot land between them.
                            buf.push(b'\n');
                            f.write_all(&buf).await.err()
                        }
                        None => None,
                    };
                    if let Some(e) = write_error {
                        tracing::warn!(target: "mcp", server = %name, path = %log_path.display(), "write stderr log failed: {e}");
                        file = None;
                    }
                }
                Err(e) => {
                    tracing::warn!(target: "mcp", server = %name, "read stderr failed: {e}");
                    break;
                }
            }
        }

        // tokio's File finishes writes in the background; flush so they are
        // complete (and any error surfaces) before the task resolves.
        if let Some(mut f) = file {
            if let Err(e) = f.flush().await {
                tracing::warn!(target: "mcp", server = %name, path = %log_path.display(), "flush stderr log failed: {e}");
            }
        }
    })
}

/// Drop a trailing `\n`, and a `\r` directly before it, mirroring what
/// `AsyncBufReadExt::lines` strips from each line.
fn strip_line_ending(buf: &mut Vec<u8>) {
    if buf.last() == Some(&b'\n') {
        buf.pop();
        if buf.last() == Some(&b'\r') {
            buf.pop();
        }
    }
}

/// Append a lifecycle `message` (e.g. a connect failure) to the same
/// `<log_dir>/<sanitized name>.stderr.log` file the stderr tee writes.
///
/// `message` is written verbatim; callers supply the trailing newline.
/// Best-effort: every failure is reported through `tracing` and otherwise
/// swallowed.
async fn write_lifecycle_log(log_dir: &Path, name: &str, message: &str) {
    let path = log_dir.join(format!("{}.stderr.log", sanitize_log_name(name)));
    if let Some(parent) = path.parent() {
        if let Err(e) = tokio::fs::create_dir_all(parent).await {
            tracing::warn!(target: "mcp", server = %name, path = %parent.display(), "create log dir failed: {e}");
            return;
        }
    }
    let mut file = match tokio::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(&path)
        .await
    {
        Ok(file) => file,
        Err(e) => {
            tracing::warn!(target: "mcp", server = %name, path = %path.display(), "write lifecycle log failed: {e}");
            return;
        }
    };
    // tokio's File completes writes in the background; flushing makes the write
    // finish (and report its error) before this function returns.
    let written = match file.write_all(message.as_bytes()).await {
        Ok(()) => file.flush().await,
        Err(e) => Err(e),
    };
    if let Err(e) = written {
        tracing::warn!(target: "mcp", server = %name, path = %path.display(), "write lifecycle log failed: {e}");
    }
}

/// Map a server name to a filesystem-safe file stem: ASCII alphanumerics,
/// `-`, `_` and `.` are kept; every other character (including each non-ASCII
/// char and `/`) becomes a single `_`.
fn sanitize_log_name(name: &str) -> String {
    let mut safe = String::with_capacity(name.len());
    for ch in name.chars() {
        let allowed = ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.';
        safe.push(if allowed { ch } else { '_' });
    }
    safe
}

/// Return a clone of the session's peer handle, or `AgentError::Mcp("not
/// connected")` when there is no session.
fn peer_from_session(
    session: &Option<impl Deref<Target = Peer<RoleClient>>>,
) -> Result<Peer<RoleClient>> {
    let Some(running) = session else {
        return Err(AgentError::Mcp("not connected".into()));
    };
    Ok((**running).clone())
}

/// Convert an rmcp tool descriptor into a `ToolDef`; a missing description
/// becomes the empty string and the input schema is copied as a JSON object.
fn rmcp_tool_to_tool_def(tool: rmcp::model::Tool) -> ToolDef {
    let schema = tool.input_schema.as_ref().clone();
    let description = match tool.description {
        Some(text) => text.into_owned(),
        None => String::new(),
    };
    ToolDef {
        name: tool.name.into_owned(),
        description,
        input_schema: serde_json::Value::Object(schema),
    }
}

/// Build request params for tool `name`. Only a JSON object is forwarded as
/// arguments; any other `args` value (null, array, scalar) means no arguments.
fn build_call_params(name: &str, args: serde_json::Value) -> CallToolRequestParams {
    let mut params = CallToolRequestParams::new(name.to_string());
    params.arguments = if let serde_json::Value::Object(map) = args {
        Some(map)
    } else {
        None
    };
    params
}

/// Concatenate the text items of a tool result, separated by `\n`, in order.
/// Non-text content (images, resources, ...) is skipped.
fn call_tool_result_to_string(result: &CallToolResult) -> String {
    let texts = result.content.iter().filter_map(|item| match &item.raw {
        RawContent::Text(t) => Some(&t.text),
        _ => None,
    });

    let mut joined = String::new();
    let mut first = true;
    for text in texts {
        if !first {
            joined.push('\n');
        }
        joined.push_str(text);
        first = false;
    }
    joined
}

/// Wrap an rmcp request error as `AgentError::Mcp` with its display text.
fn map_service_error(e: rmcp::service::ServiceError) -> AgentError {
    display_as_mcp_error(e)
}

/// Wrap an rmcp client-initialization error as `AgentError::Mcp` with its display text.
fn map_init_error(e: rmcp::service::ClientInitializeError) -> AgentError {
    display_as_mcp_error(e)
}

/// Shared body of the two mappers above: render `e` via `Display`.
fn display_as_mcp_error(e: impl std::fmt::Display) -> AgentError {
    let message = e.to_string();
    AgentError::Mcp(message)
}