Skip to main content

capo_agent/mcp/
lifecycle.rs

1//! MCP server lifecycle: parallel `connect()` with per-server timeout,
2//! log capture on failure, ordered `disconnect()` on shutdown.
3
4use std::collections::HashMap;
5use std::ops::Deref;
6use std::path::{Path, PathBuf};
7use std::process::Stdio;
8use std::sync::Arc;
9use std::time::Duration;
10
11use motosan_agent_loop::mcp::{McpServer, McpServerHttp};
12use motosan_agent_loop::{AgentError, Result};
13use motosan_agent_tool::ToolDef;
14use rmcp::model::{CallToolRequestParams, CallToolResult, RawContent};
15use rmcp::service::{Peer, RoleClient, RunningService};
16use rmcp::transport::child_process::TokioChildProcess;
17use rmcp::ServiceExt;
18use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
19
20use crate::mcp::config::{McpConfig, McpServerConfig};
21
/// A server that finished `connect()` successfully, paired with the name it
/// was configured under.
pub struct StartedServer {
    /// Name of the server (the key it had in the MCP config).
    pub name: String,
    /// Shared handle to the connected server.
    pub server: Arc<dyn McpServer>,
}
26
/// The running client-role service produced by `().serve(transport)` in
/// `LoggedStdioServer::connect`.
type StdioService = RunningService<RoleClient, ()>;
28
29/// Stdio MCP server with stderr tee'd into `<log_dir>/<server>.stderr.log`.
30struct LoggedStdioServer {
31    name: String,
32    command: String,
33    args: Vec<String>,
34    env: HashMap<String, String>,
35    log_path: PathBuf,
36    session: futures::lock::Mutex<Option<StdioService>>,
37    stderr_task: futures::lock::Mutex<Option<tokio::task::JoinHandle<()>>>,
38}
39
40impl LoggedStdioServer {
41    fn new(
42        name: String,
43        command: String,
44        args: Vec<String>,
45        env: HashMap<String, String>,
46        log_dir: &Path,
47    ) -> Self {
48        Self {
49            log_path: log_dir.join(format!("{}.stderr.log", sanitize_log_name(&name))),
50            name,
51            command,
52            args,
53            env,
54            session: futures::lock::Mutex::new(None),
55            stderr_task: futures::lock::Mutex::new(None),
56        }
57    }
58}
59
60#[async_trait::async_trait]
61impl McpServer for LoggedStdioServer {
62    fn name(&self) -> &str {
63        &self.name
64    }
65
66    async fn connect(&self) -> Result<()> {
67        let mut cmd = tokio::process::Command::new(&self.command);
68        cmd.args(&self.args).stderr(Stdio::piped());
69        for (k, v) in &self.env {
70            cmd.env(k, v);
71        }
72
73        let (transport, stderr) = TokioChildProcess::builder(cmd)
74            .stderr(Stdio::piped())
75            .spawn()
76            .map_err(|e| AgentError::Mcp(e.to_string()))?;
77        if let Some(stderr) = stderr {
78            let task = spawn_stderr_tee(self.name.clone(), stderr, self.log_path.clone());
79            let mut stderr_task = self.stderr_task.lock().await;
80            *stderr_task = Some(task);
81        }
82
83        let service: StdioService = ().serve(transport).await.map_err(map_init_error)?;
84        let mut session = self.session.lock().await;
85        *session = Some(service);
86        Ok(())
87    }
88
89    async fn list_tools(&self) -> Result<Vec<ToolDef>> {
90        let session = self.session.lock().await;
91        let peer = peer_from_session(&session)?;
92        let tools = peer.list_all_tools().await.map_err(map_service_error)?;
93        Ok(tools.into_iter().map(rmcp_tool_to_tool_def).collect())
94    }
95
96    async fn call_tool(&self, name: &str, args: serde_json::Value) -> Result<String> {
97        let session = self.session.lock().await;
98        let peer = peer_from_session(&session)?;
99        let params = build_call_params(name, args);
100        let result = peer.call_tool(params).await.map_err(map_service_error)?;
101        Ok(call_tool_result_to_string(&result))
102    }
103
104    async fn disconnect(&self) -> Result<()> {
105        let mut session = self.session.lock().await;
106        if let Some(s) = session.take() {
107            let _ = s.cancel().await;
108        }
109        drop(session);
110
111        let mut stderr_task = self.stderr_task.lock().await;
112        if let Some(task) = stderr_task.take() {
113            let _ = tokio::time::timeout(Duration::from_secs(1), task).await;
114        }
115        Ok(())
116    }
117}
118
119/// Build & connect every enabled server in parallel. Failed/timed-out
120/// servers are skipped (logged via `tracing`). Returns only the
121/// successfully-connected servers, in arbitrary order.
122pub async fn connect_all(config: &McpConfig, log_dir: &Path) -> Vec<StartedServer> {
123    let mut futures = Vec::new();
124    for (name, server_cfg) in &config.servers {
125        let enabled = match server_cfg {
126            McpServerConfig::Stdio { enabled, .. } => *enabled,
127            McpServerConfig::Http { enabled, .. } => *enabled,
128        };
129        if !enabled {
130            continue;
131        }
132        let name = name.clone();
133        let cfg = server_cfg.clone();
134        let log_dir = log_dir.to_path_buf();
135        futures.push(tokio::spawn(async move {
136            try_connect(name, cfg, log_dir).await
137        }));
138    }
139    let mut started = Vec::new();
140    for f in futures {
141        if let Ok(Some(s)) = f.await {
142            started.push(s);
143        }
144    }
145    started
146}
147
148async fn try_connect(
149    name: String,
150    cfg: McpServerConfig,
151    log_dir: PathBuf,
152) -> Option<StartedServer> {
153    let timeout = match &cfg {
154        McpServerConfig::Stdio {
155            startup_timeout_ms, ..
156        } => *startup_timeout_ms,
157        McpServerConfig::Http {
158            startup_timeout_ms, ..
159        } => *startup_timeout_ms,
160    };
161    let server: Arc<dyn McpServer> = match cfg {
162        McpServerConfig::Stdio {
163            command, args, env, ..
164        } => Arc::new(LoggedStdioServer::new(
165            name.clone(),
166            command,
167            args,
168            env,
169            &log_dir,
170        )),
171        McpServerConfig::Http { url, headers, .. } => {
172            let mut s = McpServerHttp::new(&url).with_name(name.clone());
173            s.headers = headers;
174            Arc::new(s)
175        }
176    };
177    let connect_result =
178        tokio::time::timeout(Duration::from_millis(timeout), server.connect()).await;
179    match connect_result {
180        Ok(Ok(())) => Some(StartedServer { name, server }),
181        Ok(Err(e)) => {
182            tracing::error!(target: "mcp", server = %name, "connect failed: {e}");
183            write_lifecycle_log(&log_dir, &name, &format!("connect failed: {e}\n")).await;
184            None
185        }
186        Err(_) => {
187            tracing::error!(target: "mcp", server = %name, "startup_timeout_ms exceeded");
188            write_lifecycle_log(&log_dir, &name, "startup_timeout_ms exceeded\n").await;
189            None
190        }
191    }
192}
193
194/// Disconnect every server with a 2-second per-server timeout. Errors
195/// are logged but never raised — shutdown is best-effort.
196pub async fn disconnect_all(servers: &[StartedServer]) {
197    for s in servers {
198        let _ = tokio::time::timeout(Duration::from_secs(2), s.server.disconnect()).await;
199    }
200}
201
202/// Convenience: convert started servers into `(name, Arc<dyn McpServer>)`
203/// pairs for storage on `App`.
204pub fn into_pairs(started: Vec<StartedServer>) -> Vec<(String, Arc<dyn McpServer>)> {
205    started.into_iter().map(|s| (s.name, s.server)).collect()
206}
207
208fn spawn_stderr_tee(
209    name: String,
210    stderr: tokio::process::ChildStderr,
211    log_path: PathBuf,
212) -> tokio::task::JoinHandle<()> {
213    tokio::spawn(async move {
214        if let Some(parent) = log_path.parent() {
215            if let Err(e) = tokio::fs::create_dir_all(parent).await {
216                tracing::warn!(target: "mcp", server = %name, path = %parent.display(), "create log dir failed: {e}");
217            }
218        }
219        let mut file = match tokio::fs::OpenOptions::new()
220            .create(true)
221            .append(true)
222            .open(&log_path)
223            .await
224        {
225            Ok(file) => Some(file),
226            Err(e) => {
227                tracing::warn!(target: "mcp", server = %name, path = %log_path.display(), "open stderr log failed: {e}");
228                None
229            }
230        };
231        let mut lines = BufReader::new(stderr).lines();
232        loop {
233            match lines.next_line().await {
234                Ok(Some(line)) => {
235                    tracing::warn!(target: "mcp", server = %name, "{line}");
236                    if let Some(file) = file.as_mut() {
237                        let _ = file.write_all(line.as_bytes()).await;
238                        let _ = file.write_all(b"\n").await;
239                    }
240                }
241                Ok(None) => break,
242                Err(e) => {
243                    tracing::warn!(target: "mcp", server = %name, "read stderr failed: {e}");
244                    break;
245                }
246            }
247        }
248    })
249}
250
251async fn write_lifecycle_log(log_dir: &Path, name: &str, message: &str) {
252    let path = log_dir.join(format!("{}.stderr.log", sanitize_log_name(name)));
253    if let Some(parent) = path.parent() {
254        if let Err(e) = tokio::fs::create_dir_all(parent).await {
255            tracing::warn!(target: "mcp", server = %name, path = %parent.display(), "create log dir failed: {e}");
256            return;
257        }
258    }
259    match tokio::fs::OpenOptions::new()
260        .create(true)
261        .append(true)
262        .open(&path)
263        .await
264    {
265        Ok(mut file) => {
266            let _ = file.write_all(message.as_bytes()).await;
267        }
268        Err(e) => {
269            tracing::warn!(target: "mcp", server = %name, path = %path.display(), "write lifecycle log failed: {e}");
270        }
271    }
272}
273
/// Turns a server name into a safe file-name stem: ASCII letters, digits,
/// `-`, `_` and `.` are kept; every other character becomes `_`.
fn sanitize_log_name(name: &str) -> String {
    let mut safe = String::with_capacity(name.len());
    for ch in name.chars() {
        let keep = ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.';
        safe.push(if keep { ch } else { '_' });
    }
    safe
}
285
286fn peer_from_session(
287    session: &Option<impl Deref<Target = Peer<RoleClient>>>,
288) -> Result<Peer<RoleClient>> {
289    session
290        .as_ref()
291        .map(|s| Deref::deref(s).clone())
292        .ok_or_else(|| AgentError::Mcp("not connected".into()))
293}
294
295fn rmcp_tool_to_tool_def(tool: rmcp::model::Tool) -> ToolDef {
296    ToolDef {
297        name: tool.name.into_owned(),
298        description: tool.description.map(|d| d.into_owned()).unwrap_or_default(),
299        input_schema: serde_json::Value::Object(tool.input_schema.as_ref().clone()),
300    }
301}
302
303fn build_call_params(name: &str, args: serde_json::Value) -> CallToolRequestParams {
304    let arguments = match args {
305        serde_json::Value::Object(map) => Some(map),
306        _ => None,
307    };
308    let mut params = CallToolRequestParams::new(name.to_string());
309    params.arguments = arguments;
310    params
311}
312
313fn call_tool_result_to_string(result: &CallToolResult) -> String {
314    result
315        .content
316        .iter()
317        .filter_map(|c| match &c.raw {
318            RawContent::Text(t) => Some(t.text.clone()),
319            _ => None,
320        })
321        .collect::<Vec<_>>()
322        .join("\n")
323}
324
325fn map_service_error(e: rmcp::service::ServiceError) -> AgentError {
326    AgentError::Mcp(e.to_string())
327}
328
329fn map_init_error(e: rmcp::service::ClientInitializeError) -> AgentError {
330    AgentError::Mcp(e.to_string())
331}