Skip to main content

hematite/agent/
mcp_manager.rs

1use crate::agent::mcp::*;
2use crate::agent::truncation::safe_head;
3use crate::agent::types::McpRuntimeState;
4use crate::tools::file_ops::hematite_dir;
5use anyhow::{anyhow, Result};
6use serde::{Deserialize, Serialize};
7use serde_json::Value as JsonValue;
8use std::collections::HashMap;
9use std::path::Path;
10
11#[derive(Debug, Serialize, Deserialize, Default)]
12pub struct McpConfig {
13    #[serde(rename = "mcpServers")]
14    pub servers: HashMap<String, McpServerConfig>,
15}
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct McpServerConfig {
19    pub command: String,
20    pub args: Option<Vec<String>>,
21    pub env: Option<HashMap<String, String>>,
22}
23
24pub struct McpManager {
25    pub connections: HashMap<String, McpProcess>,
26    pub tool_map: HashMap<String, String>, // qualified_name -> server_name
27    pub discovered_tools: Vec<McpTool>,
28    pub active_config_signature: Option<String>,
29    pub configured_servers: usize,
30    pub startup_errors: Vec<String>,
31    pub discovery_errors: Vec<String>,
32    pub next_id: u64,
33}
34
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub struct McpRuntimeReport {
37    pub state: McpRuntimeState,
38    pub configured_servers: usize,
39    pub connected_servers: usize,
40    pub active_tools: usize,
41    pub error_count: usize,
42    pub summary: String,
43}
44
45impl Default for McpManager {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl McpManager {
52    pub fn new() -> Self {
53        Self {
54            connections: HashMap::new(),
55            tool_map: HashMap::new(),
56            discovered_tools: Vec::new(),
57            active_config_signature: None,
58            configured_servers: 0,
59            startup_errors: Vec::new(),
60            discovery_errors: Vec::new(),
61            next_id: 1,
62        }
63    }
64
65    pub async fn initialize_all(&mut self) -> Result<()> {
66        let config = self.load_mcp_config();
67        self.configured_servers = config.servers.len();
68        let signature = self.config_signature(&config);
69        let all_connected = self.connections.len() == config.servers.len();
70        if self.active_config_signature.as_deref() == Some(signature.as_str())
71            && (all_connected || config.servers.is_empty())
72        {
73            return Ok(());
74        }
75
76        self.shutdown_all().await;
77        self.tool_map.clear();
78        self.discovered_tools.clear();
79        self.startup_errors.clear();
80        self.discovery_errors.clear();
81        self.active_config_signature = Some(signature);
82
83        for (name, cfg) in config.servers {
84            let args = cfg.args.clone().unwrap_or_default();
85            let env = cfg.env.clone().unwrap_or_default();
86
87            match self
88                .spawn_and_initialize_server(&cfg.command, &args, &env)
89                .await
90            {
91                Ok(proc) => {
92                    self.connections.insert(name.clone(), proc);
93                }
94                Err(e) => {
95                    self.startup_errors.push(format!("{}: {}", name, e));
96                }
97            }
98        }
99        Ok(())
100    }
101
102    pub fn load_mcp_config(&self) -> McpConfig {
103        let mut config = McpConfig::default();
104
105        // 1. Load GLOBAL config (~/.hematite/mcp_servers.json)
106        if let Some(mut global_path) = home::home_dir() {
107            global_path.push(".hematite");
108            global_path.push("mcp_servers.json");
109            if let Ok(global_cfg) = self.read_mcp_file(&global_path) {
110                self.merge_configs(&mut config, global_cfg);
111            }
112        }
113
114        // 2. Load LOCAL config (.hematite/mcp_servers.json in workspace)
115        let local_path = hematite_dir().join("mcp_servers.json");
116        if let Ok(local_cfg) = self.read_mcp_file(&local_path) {
117            self.merge_configs(&mut config, local_cfg);
118        }
119
120        config
121    }
122
123    fn read_mcp_file(&self, path: &Path) -> Result<McpConfig> {
124        let data = std::fs::read_to_string(path)?;
125        let config: McpConfig = serde_json::from_str(&data)?;
126        Ok(config)
127    }
128
129    fn merge_configs(&self, base: &mut McpConfig, new: McpConfig) {
130        for (name, server) in new.servers {
131            base.servers.insert(name, server);
132        }
133    }
134
135    fn config_signature(&self, config: &McpConfig) -> String {
136        let mut servers: Vec<_> = config.servers.iter().collect();
137        servers.sort_by(|a, b| a.0.cmp(b.0));
138
139        let mut signature = String::with_capacity(config.servers.len() * 80);
140        for (name, server) in servers {
141            signature.push_str(name);
142            signature.push('|');
143            signature.push_str(&server.command);
144            signature.push('|');
145
146            if let Some(args) = &server.args {
147                for arg in args {
148                    signature.push_str(arg);
149                    signature.push('\u{1f}');
150                }
151            }
152            signature.push('|');
153
154            let mut env_pairs = server
155                .env
156                .as_ref()
157                .map(|env| env.iter().collect::<Vec<_>>())
158                .unwrap_or_default();
159            env_pairs.sort_by(|a, b| a.0.cmp(b.0));
160            for (key, value) in env_pairs {
161                signature.push_str(key);
162                signature.push('=');
163                signature.push_str(value);
164                signature.push(';');
165            }
166            signature.push('\n');
167        }
168
169        signature
170    }
171
172    async fn shutdown_all(&mut self) {
173        let connections = std::mem::take(&mut self.connections);
174        for (_, proc) in connections {
175            proc.shutdown().await;
176        }
177    }
178
179    async fn spawn_and_initialize_server(
180        &mut self,
181        command: &str,
182        args: &[String],
183        env: &HashMap<String, String>,
184    ) -> Result<McpProcess> {
185        let mut last_error = None;
186
187        for framing in [McpFraming::NewlineDelimited, McpFraming::ContentLength] {
188            let mut proc = McpProcess::spawn_with_framing(command, args, env, framing)?;
189            let init_result = tokio::time::timeout(
190                std::time::Duration::from_secs(5),
191                proc.initialize(self.next_id),
192            )
193            .await;
194
195            match init_result {
196                Ok(Ok(())) => {
197                    self.next_id += 1;
198                    return Ok(proc);
199                }
200                Ok(Err(err)) => {
201                    last_error = Some(Self::format_mcp_init_error(&proc, err.to_string()));
202                    proc.shutdown().await;
203                }
204                Err(_) => {
205                    last_error = Some(Self::format_mcp_init_error(
206                        &proc,
207                        "initialize timed out after 5s".to_string(),
208                    ));
209                    proc.shutdown().await;
210                }
211            }
212        }
213
214        Err(anyhow!(last_error.unwrap_or_else(|| {
215            "server did not complete initialize using newline or content-length framing".to_string()
216        })))
217    }
218
219    fn format_mcp_init_error(proc: &McpProcess, base_error: String) -> String {
220        match proc.stderr_summary() {
221            Some(stderr) => format!("{base_error}; stderr: {stderr}"),
222            None => base_error,
223        }
224    }
225
226    pub async fn discover_tools(&mut self) -> Vec<McpTool> {
227        if !self.discovered_tools.is_empty() {
228            return self.discovered_tools.clone();
229        }
230
231        let mut all_tools = Vec::new();
232        self.tool_map.clear();
233        self.discovery_errors.clear();
234        let server_names: Vec<String> = self.connections.keys().cloned().collect();
235
236        for name in server_names {
237            if let Some(proc) = self.connections.get_mut(&name) {
238                match proc.list_tools(self.next_id).await {
239                    Ok(tools) => {
240                        self.next_id += 1;
241                        for mut tool in tools {
242                            let original_name = tool.name.clone();
243                            // Prefix to avoid collisions
244                            tool.name = format!("mcp__{}__{}", name, original_name);
245                            self.tool_map.insert(tool.name.clone(), name.clone());
246                            all_tools.push(tool);
247                        }
248                    }
249                    Err(e) => {
250                        self.discovery_errors.push(format!("{}: {}", name, e));
251                    }
252                }
253            }
254        }
255        self.discovered_tools = all_tools.clone();
256        all_tools
257    }
258
259    pub async fn call_tool(&mut self, full_name: &str, args: &JsonValue) -> Result<String> {
260        let server_name = self
261            .tool_map
262            .get(full_name)
263            .ok_or_else(|| anyhow!("Unknown MCP tool: {}", full_name))?;
264        let proc = self
265            .connections
266            .get_mut(server_name)
267            .ok_or_else(|| anyhow!("Server not connected: {}", server_name))?;
268
269        // Strip prefix to get original name
270        let prefix = format!("mcp__{}__", server_name);
271        let original_name = full_name.strip_prefix(&prefix).unwrap_or(full_name);
272
273        let result = proc
274            .call_tool(self.next_id, original_name, args.clone())
275            .await?;
276        self.next_id += 1;
277
278        let mut output = String::with_capacity(result.content.len() * 256);
279        for content in result.content {
280            match content {
281                McpContent::Text { text } => output.push_str(&text),
282                McpContent::Image { .. } => {
283                    output.push_str("\n[Image Data Not Supported in TUI]\n")
284                }
285            }
286        }
287
288        if result.is_error.unwrap_or(false) {
289            Err(anyhow!(output))
290        } else {
291            // VRAM Guard: Truncate massive outputs to protect the local context window.
292            if output.len() > 2500 {
293                let safe_end = safe_head(&output, 2500).len();
294                output.truncate(safe_end);
295                output.push_str("\n\n[Output Truncated by Hematite for VRAM Safety]");
296            }
297            Ok(output)
298        }
299    }
300
301    pub fn runtime_report(&self) -> McpRuntimeReport {
302        let first_error = self
303            .startup_errors
304            .first()
305            .or_else(|| self.discovery_errors.first())
306            .map(String::as_str);
307        runtime_report_from_snapshot(
308            self.configured_servers,
309            self.connections.len(),
310            self.discovered_tools.len(),
311            self.startup_errors.len() + self.discovery_errors.len(),
312            first_error,
313        )
314    }
315}
316
317fn runtime_report_from_snapshot(
318    configured_servers: usize,
319    connected_servers: usize,
320    active_tools: usize,
321    error_count: usize,
322    first_error: Option<&str>,
323) -> McpRuntimeReport {
324    let state = if configured_servers == 0 {
325        McpRuntimeState::Unconfigured
326    } else if connected_servers == 0 {
327        McpRuntimeState::Failed
328    } else if error_count > 0 {
329        McpRuntimeState::Degraded
330    } else {
331        McpRuntimeState::Healthy
332    };
333
334    let detail = summarize_runtime_error(first_error);
335
336    let summary = match state {
337        McpRuntimeState::Unconfigured => "No MCP servers configured.".to_string(),
338        McpRuntimeState::Healthy => format!(
339            "MCP healthy: {}/{} servers connected; {} tools active.",
340            connected_servers, configured_servers, active_tools
341        ),
342        McpRuntimeState::Degraded => format!(
343            "MCP degraded: {}/{} servers connected; {} tools active; {} startup/discovery issue(s){}",
344            connected_servers, configured_servers, active_tools, error_count, detail
345        ),
346        McpRuntimeState::Failed => format!(
347            "MCP failed: 0/{} servers connected; {} startup/discovery issue(s){}",
348            configured_servers, error_count, detail
349        ),
350    };
351
352    McpRuntimeReport {
353        state,
354        configured_servers,
355        connected_servers,
356        active_tools,
357        error_count,
358        summary,
359    }
360}
361
/// Produces the tail of a runtime summary sentence from an optional error.
///
/// Returns `"."` when there is no error or it is blank after trimming.
/// Otherwise returns `" First issue: <error>"`, where the trimmed error is
/// limited to its first 160 characters and followed by `"..."` if anything
/// was cut off.
fn summarize_runtime_error(first_error: Option<&str>) -> String {
    const MAX_CHARS: usize = 160;

    let message = match first_error.map(str::trim) {
        Some(text) if !text.is_empty() => text,
        _ => return ".".to_string(),
    };

    // The byte offset of character number MAX_CHARS (zero-based) exists only
    // when the message is longer than the limit; it is also exactly where the
    // kept prefix ends, so slicing there stays on a character boundary.
    match message.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!(" First issue: {}...", &message[..cut]),
        None => format!(" First issue: {message}"),
    }
}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero configured servers is reported as "unconfigured", not as a failure.
    #[test]
    fn runtime_report_marks_unconfigured_when_no_servers_exist() {
        let McpRuntimeReport { state, summary, .. } =
            runtime_report_from_snapshot(0, 0, 0, 0, None);
        assert_eq!(state, McpRuntimeState::Unconfigured);
        assert!(summary.contains("No MCP servers configured"));
    }

    /// Configured servers with no connection at all is a hard failure, and the
    /// first error is quoted in the summary.
    #[test]
    fn runtime_report_marks_failed_when_servers_exist_but_none_connect() {
        let McpRuntimeReport { state, summary, .. } =
            runtime_report_from_snapshot(2, 0, 0, 2, Some("filesystem: spawn failed"));
        assert_eq!(state, McpRuntimeState::Failed);
        assert!(summary.contains("0/2"));
        assert!(summary.contains("filesystem: spawn failed"));
    }

    /// At least one connection plus at least one error is "degraded".
    #[test]
    fn runtime_report_marks_degraded_when_some_servers_or_discovery_steps_fail() {
        let McpRuntimeReport { state, summary, .. } =
            runtime_report_from_snapshot(2, 1, 3, 1, Some("filesystem: tools/list failed"));
        assert_eq!(state, McpRuntimeState::Degraded);
        assert!(summary.contains("1/2"));
        assert!(summary.contains("tools/list failed"));
    }

    /// Connections with no errors is "healthy" and reports the tool count.
    #[test]
    fn runtime_report_marks_healthy_when_all_servers_connect_without_errors() {
        let McpRuntimeReport { state, summary, .. } =
            runtime_report_from_snapshot(2, 2, 5, 0, None);
        assert_eq!(state, McpRuntimeState::Healthy);
        assert!(summary.contains("5 tools active"));
    }
}
410}