1use std::collections::HashMap;
5use std::ops::Deref;
6use std::path::{Path, PathBuf};
7use std::process::Stdio;
8use std::sync::Arc;
9use std::time::Duration;
10
11use motosan_agent_loop::mcp::{McpServer, McpServerHttp};
12use motosan_agent_loop::{AgentError, Result};
13use motosan_agent_tool::ToolDef;
14use rmcp::model::{CallToolRequestParams, CallToolResult, RawContent};
15use rmcp::service::{Peer, RoleClient, RunningService};
16use rmcp::transport::child_process::TokioChildProcess;
17use rmcp::ServiceExt;
18use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
19
20use crate::mcp::config::{McpConfig, McpServerConfig};
21
/// A server whose `connect()` completed successfully, paired with the name it
/// was configured under.
pub struct StartedServer {
    /// Name the server was registered under in the MCP configuration.
    pub name: String,
    /// Shared handle to the connected server implementation.
    pub server: Arc<dyn McpServer>,
}
26
/// Running rmcp client service as produced by `().serve(transport)` in
/// `LoggedStdioServer::connect`.
type StdioService = RunningService<RoleClient, ()>;
28
/// Stdio-based MCP server whose child-process stderr is teed to `tracing`
/// and to a per-server log file.
struct LoggedStdioServer {
    /// Server name; also used (sanitized) to derive the log file name.
    name: String,
    /// Program to spawn on `connect`.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Extra environment variables set on the child process.
    env: HashMap<String, String>,
    /// File that receives the child's stderr lines.
    log_path: PathBuf,
    /// Active client session; `None` until `connect` succeeds and again after `disconnect`.
    session: futures::lock::Mutex<Option<StdioService>>,
    /// Handle of the background task that drains the child's stderr, if one was started.
    stderr_task: futures::lock::Mutex<Option<tokio::task::JoinHandle<()>>>,
}
39
40impl LoggedStdioServer {
41 fn new(
42 name: String,
43 command: String,
44 args: Vec<String>,
45 env: HashMap<String, String>,
46 log_dir: &Path,
47 ) -> Self {
48 Self {
49 log_path: log_dir.join(format!("{}.stderr.log", sanitize_log_name(&name))),
50 name,
51 command,
52 args,
53 env,
54 session: futures::lock::Mutex::new(None),
55 stderr_task: futures::lock::Mutex::new(None),
56 }
57 }
58}
59
60#[async_trait::async_trait]
61impl McpServer for LoggedStdioServer {
62 fn name(&self) -> &str {
63 &self.name
64 }
65
66 async fn connect(&self) -> Result<()> {
67 let mut cmd = tokio::process::Command::new(&self.command);
68 cmd.args(&self.args).stderr(Stdio::piped());
69 for (k, v) in &self.env {
70 cmd.env(k, v);
71 }
72
73 let (transport, stderr) = TokioChildProcess::builder(cmd)
74 .stderr(Stdio::piped())
75 .spawn()
76 .map_err(|e| AgentError::Mcp(e.to_string()))?;
77 if let Some(stderr) = stderr {
78 let task = spawn_stderr_tee(self.name.clone(), stderr, self.log_path.clone());
79 let mut stderr_task = self.stderr_task.lock().await;
80 *stderr_task = Some(task);
81 }
82
83 let service: StdioService = ().serve(transport).await.map_err(map_init_error)?;
84 let mut session = self.session.lock().await;
85 *session = Some(service);
86 Ok(())
87 }
88
89 async fn list_tools(&self) -> Result<Vec<ToolDef>> {
90 let session = self.session.lock().await;
91 let peer = peer_from_session(&session)?;
92 let tools = peer.list_all_tools().await.map_err(map_service_error)?;
93 Ok(tools.into_iter().map(rmcp_tool_to_tool_def).collect())
94 }
95
96 async fn call_tool(&self, name: &str, args: serde_json::Value) -> Result<String> {
97 let session = self.session.lock().await;
98 let peer = peer_from_session(&session)?;
99 let params = build_call_params(name, args);
100 let result = peer.call_tool(params).await.map_err(map_service_error)?;
101 Ok(call_tool_result_to_string(&result))
102 }
103
104 async fn disconnect(&self) -> Result<()> {
105 let mut session = self.session.lock().await;
106 if let Some(s) = session.take() {
107 let _ = s.cancel().await;
108 }
109 drop(session);
110
111 let mut stderr_task = self.stderr_task.lock().await;
112 if let Some(task) = stderr_task.take() {
113 let _ = tokio::time::timeout(Duration::from_secs(1), task).await;
114 }
115 Ok(())
116 }
117}
118
119pub async fn connect_all(config: &McpConfig, log_dir: &Path) -> Vec<StartedServer> {
123 let mut futures = Vec::new();
124 for (name, server_cfg) in &config.servers {
125 let enabled = match server_cfg {
126 McpServerConfig::Stdio { enabled, .. } => *enabled,
127 McpServerConfig::Http { enabled, .. } => *enabled,
128 };
129 if !enabled {
130 continue;
131 }
132 let name = name.clone();
133 let cfg = server_cfg.clone();
134 let log_dir = log_dir.to_path_buf();
135 futures.push(tokio::spawn(async move {
136 try_connect(name, cfg, log_dir).await
137 }));
138 }
139 let mut started = Vec::new();
140 for f in futures {
141 if let Ok(Some(s)) = f.await {
142 started.push(s);
143 }
144 }
145 started
146}
147
148async fn try_connect(
149 name: String,
150 cfg: McpServerConfig,
151 log_dir: PathBuf,
152) -> Option<StartedServer> {
153 let timeout = match &cfg {
154 McpServerConfig::Stdio {
155 startup_timeout_ms, ..
156 } => *startup_timeout_ms,
157 McpServerConfig::Http {
158 startup_timeout_ms, ..
159 } => *startup_timeout_ms,
160 };
161 let server: Arc<dyn McpServer> = match cfg {
162 McpServerConfig::Stdio {
163 command, args, env, ..
164 } => Arc::new(LoggedStdioServer::new(
165 name.clone(),
166 command,
167 args,
168 env,
169 &log_dir,
170 )),
171 McpServerConfig::Http { url, headers, .. } => {
172 let mut s = McpServerHttp::new(&url).with_name(name.clone());
173 s.headers = headers;
174 Arc::new(s)
175 }
176 };
177 let connect_result =
178 tokio::time::timeout(Duration::from_millis(timeout), server.connect()).await;
179 match connect_result {
180 Ok(Ok(())) => Some(StartedServer { name, server }),
181 Ok(Err(e)) => {
182 tracing::error!(target: "mcp", server = %name, "connect failed: {e}");
183 write_lifecycle_log(&log_dir, &name, &format!("connect failed: {e}\n")).await;
184 None
185 }
186 Err(_) => {
187 tracing::error!(target: "mcp", server = %name, "startup_timeout_ms exceeded");
188 write_lifecycle_log(&log_dir, &name, "startup_timeout_ms exceeded\n").await;
189 None
190 }
191 }
192}
193
194pub async fn disconnect_all(servers: &[StartedServer]) {
197 for s in servers {
198 let _ = tokio::time::timeout(Duration::from_secs(2), s.server.disconnect()).await;
199 }
200}
201
202pub fn into_pairs(started: Vec<StartedServer>) -> Vec<(String, Arc<dyn McpServer>)> {
205 started.into_iter().map(|s| (s.name, s.server)).collect()
206}
207
208fn spawn_stderr_tee(
209 name: String,
210 stderr: tokio::process::ChildStderr,
211 log_path: PathBuf,
212) -> tokio::task::JoinHandle<()> {
213 tokio::spawn(async move {
214 if let Some(parent) = log_path.parent() {
215 if let Err(e) = tokio::fs::create_dir_all(parent).await {
216 tracing::warn!(target: "mcp", server = %name, path = %parent.display(), "create log dir failed: {e}");
217 }
218 }
219 let mut file = match tokio::fs::OpenOptions::new()
220 .create(true)
221 .append(true)
222 .open(&log_path)
223 .await
224 {
225 Ok(file) => Some(file),
226 Err(e) => {
227 tracing::warn!(target: "mcp", server = %name, path = %log_path.display(), "open stderr log failed: {e}");
228 None
229 }
230 };
231 let mut lines = BufReader::new(stderr).lines();
232 loop {
233 match lines.next_line().await {
234 Ok(Some(line)) => {
235 tracing::warn!(target: "mcp", server = %name, "{line}");
236 if let Some(file) = file.as_mut() {
237 let _ = file.write_all(line.as_bytes()).await;
238 let _ = file.write_all(b"\n").await;
239 }
240 }
241 Ok(None) => break,
242 Err(e) => {
243 tracing::warn!(target: "mcp", server = %name, "read stderr failed: {e}");
244 break;
245 }
246 }
247 }
248 })
249}
250
251async fn write_lifecycle_log(log_dir: &Path, name: &str, message: &str) {
252 let path = log_dir.join(format!("{}.stderr.log", sanitize_log_name(name)));
253 if let Some(parent) = path.parent() {
254 if let Err(e) = tokio::fs::create_dir_all(parent).await {
255 tracing::warn!(target: "mcp", server = %name, path = %parent.display(), "create log dir failed: {e}");
256 return;
257 }
258 }
259 match tokio::fs::OpenOptions::new()
260 .create(true)
261 .append(true)
262 .open(&path)
263 .await
264 {
265 Ok(mut file) => {
266 let _ = file.write_all(message.as_bytes()).await;
267 }
268 Err(e) => {
269 tracing::warn!(target: "mcp", server = %name, path = %path.display(), "write lifecycle log failed: {e}");
270 }
271 }
272}
273
/// Turns a server name into a string usable as a file-name stem.
///
/// ASCII letters, ASCII digits, `-`, `_` and `.` are kept; every other
/// character (including non-ASCII ones) is replaced by a single `_`.
fn sanitize_log_name(name: &str) -> String {
    let mut sanitized = String::with_capacity(name.len());
    for ch in name.chars() {
        let allowed = ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.';
        sanitized.push(if allowed { ch } else { '_' });
    }
    sanitized
}
285
286fn peer_from_session(
287 session: &Option<impl Deref<Target = Peer<RoleClient>>>,
288) -> Result<Peer<RoleClient>> {
289 session
290 .as_ref()
291 .map(|s| Deref::deref(s).clone())
292 .ok_or_else(|| AgentError::Mcp("not connected".into()))
293}
294
295fn rmcp_tool_to_tool_def(tool: rmcp::model::Tool) -> ToolDef {
296 ToolDef {
297 name: tool.name.into_owned(),
298 description: tool.description.map(|d| d.into_owned()).unwrap_or_default(),
299 input_schema: serde_json::Value::Object(tool.input_schema.as_ref().clone()),
300 }
301}
302
303fn build_call_params(name: &str, args: serde_json::Value) -> CallToolRequestParams {
304 let arguments = match args {
305 serde_json::Value::Object(map) => Some(map),
306 _ => None,
307 };
308 let mut params = CallToolRequestParams::new(name.to_string());
309 params.arguments = arguments;
310 params
311}
312
313fn call_tool_result_to_string(result: &CallToolResult) -> String {
314 result
315 .content
316 .iter()
317 .filter_map(|c| match &c.raw {
318 RawContent::Text(t) => Some(t.text.clone()),
319 _ => None,
320 })
321 .collect::<Vec<_>>()
322 .join("\n")
323}
324
325fn map_service_error(e: rmcp::service::ServiceError) -> AgentError {
326 AgentError::Mcp(e.to_string())
327}
328
329fn map_init_error(e: rmcp::service::ClientInitializeError) -> AgentError {
330 AgentError::Mcp(e.to_string())
331}