Skip to main content

hematite/agent/
mcp_manager.rs

1use crate::agent::mcp::*;
2use crate::agent::types::McpRuntimeState;
3use crate::tools::file_ops::hematite_dir;
4use anyhow::{anyhow, Result};
5use serde::{Deserialize, Serialize};
6use serde_json::Value as JsonValue;
7use std::collections::HashMap;
8use std::path::Path;
9
10#[derive(Debug, Serialize, Deserialize, Default)]
11pub struct McpConfig {
12    #[serde(rename = "mcpServers")]
13    pub servers: HashMap<String, McpServerConfig>,
14}
15
16#[derive(Debug, Serialize, Deserialize, Clone)]
17pub struct McpServerConfig {
18    pub command: String,
19    pub args: Option<Vec<String>>,
20    pub env: Option<HashMap<String, String>>,
21}
22
23pub struct McpManager {
24    pub connections: HashMap<String, McpProcess>,
25    pub tool_map: HashMap<String, String>, // qualified_name -> server_name
26    pub discovered_tools: Vec<McpTool>,
27    pub active_config_signature: Option<String>,
28    pub configured_servers: usize,
29    pub startup_errors: Vec<String>,
30    pub discovery_errors: Vec<String>,
31    pub next_id: u64,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq)]
35pub struct McpRuntimeReport {
36    pub state: McpRuntimeState,
37    pub configured_servers: usize,
38    pub connected_servers: usize,
39    pub active_tools: usize,
40    pub error_count: usize,
41    pub summary: String,
42}
43
44impl McpManager {
45    pub fn new() -> Self {
46        Self {
47            connections: HashMap::new(),
48            tool_map: HashMap::new(),
49            discovered_tools: Vec::new(),
50            active_config_signature: None,
51            configured_servers: 0,
52            startup_errors: Vec::new(),
53            discovery_errors: Vec::new(),
54            next_id: 1,
55        }
56    }
57
58    pub async fn initialize_all(&mut self) -> Result<()> {
59        let config = self.load_mcp_config();
60        self.configured_servers = config.servers.len();
61        let signature = self.config_signature(&config);
62        let all_connected = self.connections.len() == config.servers.len();
63        if self.active_config_signature.as_deref() == Some(signature.as_str())
64            && (all_connected || config.servers.is_empty())
65        {
66            return Ok(());
67        }
68
69        self.shutdown_all().await;
70        self.tool_map.clear();
71        self.discovered_tools.clear();
72        self.startup_errors.clear();
73        self.discovery_errors.clear();
74        self.active_config_signature = Some(signature);
75
76        for (name, cfg) in config.servers {
77            let args = cfg.args.clone().unwrap_or_default();
78            let env = cfg.env.clone().unwrap_or_default();
79
80            match self
81                .spawn_and_initialize_server(&cfg.command, &args, &env)
82                .await
83            {
84                Ok(proc) => {
85                    self.connections.insert(name.clone(), proc);
86                }
87                Err(e) => {
88                    self.startup_errors.push(format!("{}: {}", name, e));
89                }
90            }
91        }
92        Ok(())
93    }
94
95    pub fn load_mcp_config(&self) -> McpConfig {
96        let mut config = McpConfig::default();
97
98        // 1. Load GLOBAL config (~/.hematite/mcp_servers.json)
99        if let Some(mut global_path) = home::home_dir() {
100            global_path.push(".hematite");
101            global_path.push("mcp_servers.json");
102            if let Ok(global_cfg) = self.read_mcp_file(&global_path) {
103                self.merge_configs(&mut config, global_cfg);
104            }
105        }
106
107        // 2. Load LOCAL config (.hematite/mcp_servers.json in workspace)
108        let local_path = hematite_dir().join("mcp_servers.json");
109        if let Ok(local_cfg) = self.read_mcp_file(&local_path) {
110            self.merge_configs(&mut config, local_cfg);
111        }
112
113        config
114    }
115
116    fn read_mcp_file(&self, path: &Path) -> Result<McpConfig> {
117        let data = std::fs::read_to_string(path)?;
118        let config: McpConfig = serde_json::from_str(&data)?;
119        Ok(config)
120    }
121
122    fn merge_configs(&self, base: &mut McpConfig, new: McpConfig) {
123        for (name, server) in new.servers {
124            base.servers.insert(name, server);
125        }
126    }
127
128    fn config_signature(&self, config: &McpConfig) -> String {
129        let mut servers: Vec<_> = config.servers.iter().collect();
130        servers.sort_by(|a, b| a.0.cmp(b.0));
131
132        let mut signature = String::new();
133        for (name, server) in servers {
134            signature.push_str(name);
135            signature.push('|');
136            signature.push_str(&server.command);
137            signature.push('|');
138
139            if let Some(args) = &server.args {
140                for arg in args {
141                    signature.push_str(arg);
142                    signature.push('\u{1f}');
143                }
144            }
145            signature.push('|');
146
147            let mut env_pairs = server
148                .env
149                .as_ref()
150                .map(|env| env.iter().collect::<Vec<_>>())
151                .unwrap_or_default();
152            env_pairs.sort_by(|a, b| a.0.cmp(b.0));
153            for (key, value) in env_pairs {
154                signature.push_str(key);
155                signature.push('=');
156                signature.push_str(value);
157                signature.push(';');
158            }
159            signature.push('\n');
160        }
161
162        signature
163    }
164
165    async fn shutdown_all(&mut self) {
166        let connections = std::mem::take(&mut self.connections);
167        for (_, proc) in connections {
168            proc.shutdown().await;
169        }
170    }
171
172    async fn spawn_and_initialize_server(
173        &mut self,
174        command: &str,
175        args: &[String],
176        env: &HashMap<String, String>,
177    ) -> Result<McpProcess> {
178        let mut last_error = None;
179
180        for framing in [McpFraming::NewlineDelimited, McpFraming::ContentLength] {
181            let mut proc = McpProcess::spawn_with_framing(command, args, env, framing)?;
182            let init_result = tokio::time::timeout(
183                std::time::Duration::from_secs(5),
184                proc.initialize(self.next_id),
185            )
186            .await;
187
188            match init_result {
189                Ok(Ok(())) => {
190                    self.next_id += 1;
191                    return Ok(proc);
192                }
193                Ok(Err(err)) => {
194                    last_error = Some(Self::format_mcp_init_error(&proc, err.to_string()));
195                    proc.shutdown().await;
196                }
197                Err(_) => {
198                    last_error = Some(Self::format_mcp_init_error(
199                        &proc,
200                        "initialize timed out after 5s".to_string(),
201                    ));
202                    proc.shutdown().await;
203                }
204            }
205        }
206
207        Err(anyhow!(last_error.unwrap_or_else(|| {
208            "server did not complete initialize using newline or content-length framing".to_string()
209        })))
210    }
211
212    fn format_mcp_init_error(proc: &McpProcess, base_error: String) -> String {
213        match proc.stderr_summary() {
214            Some(stderr) => format!("{base_error}; stderr: {stderr}"),
215            None => base_error,
216        }
217    }
218
219    pub async fn discover_tools(&mut self) -> Vec<McpTool> {
220        if !self.discovered_tools.is_empty() {
221            return self.discovered_tools.clone();
222        }
223
224        let mut all_tools = Vec::new();
225        self.tool_map.clear();
226        self.discovery_errors.clear();
227        let server_names: Vec<String> = self.connections.keys().cloned().collect();
228
229        for name in server_names {
230            if let Some(proc) = self.connections.get_mut(&name) {
231                match proc.list_tools(self.next_id).await {
232                    Ok(tools) => {
233                        self.next_id += 1;
234                        for mut tool in tools {
235                            let original_name = tool.name.clone();
236                            // Prefix to avoid collisions
237                            tool.name = format!("mcp__{}__{}", name, original_name);
238                            self.tool_map.insert(tool.name.clone(), name.clone());
239                            all_tools.push(tool);
240                        }
241                    }
242                    Err(e) => {
243                        self.discovery_errors.push(format!("{}: {}", name, e));
244                    }
245                }
246            }
247        }
248        self.discovered_tools = all_tools.clone();
249        all_tools
250    }
251
252    pub async fn call_tool(&mut self, full_name: &str, args: &JsonValue) -> Result<String> {
253        let server_name = self
254            .tool_map
255            .get(full_name)
256            .ok_or_else(|| anyhow!("Unknown MCP tool: {}", full_name))?;
257        let proc = self
258            .connections
259            .get_mut(server_name)
260            .ok_or_else(|| anyhow!("Server not connected: {}", server_name))?;
261
262        // Strip prefix to get original name
263        let prefix = format!("mcp__{}__", server_name);
264        let original_name = full_name.strip_prefix(&prefix).unwrap_or(full_name);
265
266        let result = proc
267            .call_tool(self.next_id, original_name, args.clone())
268            .await?;
269        self.next_id += 1;
270
271        let mut output = String::new();
272        for content in result.content {
273            match content {
274                McpContent::Text { text } => output.push_str(&text),
275                McpContent::Image { .. } => {
276                    output.push_str("\n[Image Data Not Supported in TUI]\n")
277                }
278            }
279        }
280
281        if result.is_error.unwrap_or(false) {
282            Err(anyhow!(output))
283        } else {
284            // VRAM Guard: Truncate massive outputs to protect the local context window.
285            if output.len() > 2500 {
286                output.truncate(2500);
287                output.push_str("\n\n[Output Truncated by Hematite for VRAM Safety]");
288            }
289            Ok(output)
290        }
291    }
292
293    pub fn runtime_report(&self) -> McpRuntimeReport {
294        let first_error = self
295            .startup_errors
296            .first()
297            .or_else(|| self.discovery_errors.first())
298            .map(String::as_str);
299        runtime_report_from_snapshot(
300            self.configured_servers,
301            self.connections.len(),
302            self.discovered_tools.len(),
303            self.startup_errors.len() + self.discovery_errors.len(),
304            first_error,
305        )
306    }
307}
308
309fn runtime_report_from_snapshot(
310    configured_servers: usize,
311    connected_servers: usize,
312    active_tools: usize,
313    error_count: usize,
314    first_error: Option<&str>,
315) -> McpRuntimeReport {
316    let state = if configured_servers == 0 {
317        McpRuntimeState::Unconfigured
318    } else if connected_servers == 0 {
319        McpRuntimeState::Failed
320    } else if error_count > 0 {
321        McpRuntimeState::Degraded
322    } else {
323        McpRuntimeState::Healthy
324    };
325
326    let detail = summarize_runtime_error(first_error);
327
328    let summary = match state {
329        McpRuntimeState::Unconfigured => "No MCP servers configured.".to_string(),
330        McpRuntimeState::Healthy => format!(
331            "MCP healthy: {}/{} servers connected; {} tools active.",
332            connected_servers, configured_servers, active_tools
333        ),
334        McpRuntimeState::Degraded => format!(
335            "MCP degraded: {}/{} servers connected; {} tools active; {} startup/discovery issue(s){}",
336            connected_servers, configured_servers, active_tools, error_count, detail
337        ),
338        McpRuntimeState::Failed => format!(
339            "MCP failed: 0/{} servers connected; {} startup/discovery issue(s){}",
340            configured_servers, error_count, detail
341        ),
342    };
343
344    McpRuntimeReport {
345        state,
346        configured_servers,
347        connected_servers,
348        active_tools,
349        error_count,
350        summary,
351    }
352}
353
354fn summarize_runtime_error(first_error: Option<&str>) -> String {
355    let Some(error) = first_error.map(str::trim).filter(|value| !value.is_empty()) else {
356        return ".".to_string();
357    };
358
359    const MAX_CHARS: usize = 160;
360    let mut truncated = error.chars().take(MAX_CHARS).collect::<String>();
361    if error.chars().count() > MAX_CHARS {
362        truncated.push_str("...");
363    }
364    format!(" First issue: {truncated}")
365}
366
367#[cfg(test)]
368mod tests {
369    use super::*;
370
371    #[test]
372    fn runtime_report_marks_unconfigured_when_no_servers_exist() {
373        let report = runtime_report_from_snapshot(0, 0, 0, 0, None);
374        assert_eq!(report.state, McpRuntimeState::Unconfigured);
375        assert!(report.summary.contains("No MCP servers configured"));
376    }
377
378    #[test]
379    fn runtime_report_marks_failed_when_servers_exist_but_none_connect() {
380        let report = runtime_report_from_snapshot(2, 0, 0, 2, Some("filesystem: spawn failed"));
381        assert_eq!(report.state, McpRuntimeState::Failed);
382        assert!(report.summary.contains("0/2"));
383        assert!(report.summary.contains("filesystem: spawn failed"));
384    }
385
386    #[test]
387    fn runtime_report_marks_degraded_when_some_servers_or_discovery_steps_fail() {
388        let report =
389            runtime_report_from_snapshot(2, 1, 3, 1, Some("filesystem: tools/list failed"));
390        assert_eq!(report.state, McpRuntimeState::Degraded);
391        assert!(report.summary.contains("1/2"));
392        assert!(report.summary.contains("tools/list failed"));
393    }
394
395    #[test]
396    fn runtime_report_marks_healthy_when_all_servers_connect_without_errors() {
397        let report = runtime_report_from_snapshot(2, 2, 5, 0, None);
398        assert_eq!(report.state, McpRuntimeState::Healthy);
399        assert!(report.summary.contains("5 tools active"));
400    }
401}