1use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use super::ToolResult;
16use crate::agent::memory::MemorySubstrate;
17
18#[async_trait]
22pub trait McpHandler: Send + Sync {
23 fn name(&self) -> &'static str;
25
26 fn description(&self) -> &'static str;
28
29 fn input_schema(&self) -> serde_json::Value;
31
32 async fn handle(&self, params: serde_json::Value) -> ToolResult;
34}
35
36pub struct HandlerRegistry {
38 handlers: HashMap<String, Box<dyn McpHandler>>,
39}
40
41impl HandlerRegistry {
42 pub fn new() -> Self {
44 Self { handlers: HashMap::new() }
45 }
46
47 pub fn register(&mut self, handler: Box<dyn McpHandler>) {
49 let name = handler.name().to_string();
50 self.handlers.insert(name, handler);
51 }
52
53 pub async fn dispatch(&self, method: &str, params: serde_json::Value) -> ToolResult {
55 match self.handlers.get(method) {
56 Some(handler) => handler.handle(params).await,
57 None => ToolResult::error(format!("unknown method: {method}")),
58 }
59 }
60
61 pub fn list_tools(&self) -> Vec<McpToolInfo> {
63 self.handlers
64 .values()
65 .map(|h| McpToolInfo {
66 name: h.name().to_string(),
67 description: h.description().to_string(),
68 input_schema: h.input_schema(),
69 })
70 .collect()
71 }
72
73 pub fn len(&self) -> usize {
75 self.handlers.len()
76 }
77
78 pub fn is_empty(&self) -> bool {
80 self.handlers.is_empty()
81 }
82}
83
84impl Default for HandlerRegistry {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90#[derive(Debug, Clone, serde::Serialize)]
92pub struct McpToolInfo {
93 pub name: String,
95 pub description: String,
97 pub input_schema: serde_json::Value,
99}
100
101pub struct MemoryHandler {
105 memory: Arc<dyn MemorySubstrate>,
106 agent_id: String,
107}
108
109impl MemoryHandler {
110 pub fn new(memory: Arc<dyn MemorySubstrate>, agent_id: impl Into<String>) -> Self {
112 Self { memory, agent_id: agent_id.into() }
113 }
114}
115
116#[async_trait]
117impl McpHandler for MemoryHandler {
118 fn name(&self) -> &'static str {
119 "memory"
120 }
121
122 fn description(&self) -> &'static str {
123 "Store and recall agent memory fragments"
124 }
125
126 fn input_schema(&self) -> serde_json::Value {
127 serde_json::json!({
128 "type": "object",
129 "properties": {
130 "action": {
131 "type": "string",
132 "enum": ["store", "recall"]
133 },
134 "content": { "type": "string" },
135 "query": { "type": "string" },
136 "limit": { "type": "integer" }
137 },
138 "required": ["action"]
139 })
140 }
141
142 async fn handle(&self, params: serde_json::Value) -> ToolResult {
143 let action = params.get("action").and_then(|v| v.as_str()).unwrap_or("");
144
145 match action {
146 "store" => {
147 let content = params.get("content").and_then(|v| v.as_str()).unwrap_or("");
148 if content.is_empty() {
149 return ToolResult::error("content is required for store");
150 }
151 match self
152 .memory
153 .remember(
154 &self.agent_id,
155 content,
156 crate::agent::memory::MemorySource::User,
157 None,
158 )
159 .await
160 {
161 Ok(id) => ToolResult::success(format!("Stored with id: {id}")),
162 Err(e) => ToolResult::error(format!("store failed: {e}")),
163 }
164 }
165 "recall" => {
166 let query = params.get("query").and_then(|v| v.as_str()).unwrap_or("");
167 let limit = params
168 .get("limit")
169 .and_then(serde_json::Value::as_u64)
170 .map_or(5, |v| usize::try_from(v).unwrap_or(5));
171 match self.memory.recall(query, limit, None, None).await {
172 Ok(fragments) => {
173 if fragments.is_empty() {
174 return ToolResult::success("No matching memories found.");
175 }
176 let mut out = String::new();
177 for f in &fragments {
178 use std::fmt::Write;
179 let _ =
180 writeln!(out, "- {} (score: {:.2})", f.content, f.relevance_score,);
181 }
182 ToolResult::success(out)
183 }
184 Err(e) => ToolResult::error(format!("recall failed: {e}")),
185 }
186 }
187 _ => ToolResult::error(format!("unknown action: {action} (expected: store, recall)")),
188 }
189 }
190}
191
/// MCP handler that answers documentation searches from a shared `RagOracle`.
#[cfg(feature = "rag")]
pub struct RagHandler {
    /// Index queried on every request.
    oracle: Arc<crate::oracle::rag::RagOracle>,
    /// Result cap used when a request's `limit` is absent or unusable.
    max_results: usize,
}

#[cfg(feature = "rag")]
impl RagHandler {
    /// Builds a handler over `oracle` that returns up to `max_results` hits by default.
    pub fn new(oracle: Arc<crate::oracle::rag::RagOracle>, max_results: usize) -> Self {
        Self { max_results, oracle }
    }
}
209
#[cfg(feature = "rag")]
#[async_trait]
impl McpHandler for RagHandler {
    fn name(&self) -> &'static str {
        "rag"
    }

    fn description(&self) -> &'static str {
        "Search indexed Sovereign AI Stack documentation"
    }

    /// Schema: a required string `query` and an optional integer `limit`.
    fn input_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query for documentation"
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum results (default: 5)"
                }
            },
            "required": ["query"]
        })
    }

    /// Queries the oracle and formats up to `limit` hits as a numbered list.
    ///
    /// Each hit becomes a `N. [component] source (score: x.xxx)` line followed
    /// by a line with its content. An empty or missing `query` is an error.
    async fn handle(&self, params: serde_json::Value) -> ToolResult {
        use std::fmt::Write;

        let query = match params.get("query").and_then(serde_json::Value::as_str) {
            Some(text) if !text.is_empty() => text,
            _ => return ToolResult::error("query is required for search"),
        };

        // Fall back to the configured cap when `limit` is missing, not a u64,
        // or does not fit in usize.
        let limit = params
            .get("limit")
            .and_then(serde_json::Value::as_u64)
            .and_then(|raw| usize::try_from(raw).ok())
            .unwrap_or(self.max_results);

        let mut listing = String::new();
        let hits = self.oracle.query(query).into_iter().take(limit);
        for (rank, hit) in (1usize..).zip(hits) {
            // Writing into a String cannot fail; the Results are ignored.
            let _ = writeln!(
                listing,
                "{}. [{}] {} (score: {:.3})",
                rank, hit.component, hit.source, hit.score
            );
            let _ = writeln!(listing, " {}", hit.content);
        }

        // Every hit writes a non-empty header line, so an empty listing means
        // there were no hits at all.
        if listing.is_empty() {
            return ToolResult::success("No results found.");
        }
        ToolResult::success(listing)
    }
}
266
/// MCP handler that runs shell commands inside a fixed working directory.
pub struct ComputeHandler {
    /// Directory each command is started in.
    working_dir: String,
    /// Captured-output size, in bytes, beyond which output is truncated.
    max_output_bytes: usize,
}

impl ComputeHandler {
    /// Builds a handler rooted at `working_dir` with an 8192-byte output cap.
    pub fn new(working_dir: impl Into<String>) -> Self {
        let working_dir = working_dir.into();
        Self { working_dir, max_output_bytes: 8192 }
    }
}
282
283#[async_trait]
284impl McpHandler for ComputeHandler {
285 fn name(&self) -> &'static str {
286 "compute"
287 }
288
289 fn description(&self) -> &'static str {
290 "Execute shell commands with output capture"
291 }
292
293 fn input_schema(&self) -> serde_json::Value {
294 serde_json::json!({
295 "type": "object",
296 "properties": {
297 "action": {
298 "type": "string",
299 "enum": ["run", "parallel"]
300 },
301 "command": { "type": "string" },
302 "commands": {
303 "type": "array",
304 "items": { "type": "string" }
305 }
306 },
307 "required": ["action"]
308 })
309 }
310
311 async fn handle(&self, params: serde_json::Value) -> ToolResult {
312 let action = params.get("action").and_then(|v| v.as_str()).unwrap_or("");
313
314 match action {
315 "run" => {
316 let command = params.get("command").and_then(|v| v.as_str()).unwrap_or("");
317 if command.is_empty() {
318 return ToolResult::error("command is required for run");
319 }
320 execute_command(command, &self.working_dir, self.max_output_bytes).await
321 }
322 "parallel" => {
323 let commands: Vec<String> = params
324 .get("commands")
325 .and_then(|v| v.as_array())
326 .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
327 .unwrap_or_default();
328 if commands.is_empty() {
329 return ToolResult::error("commands array is required for parallel");
330 }
331 let mut results = Vec::new();
332 for cmd in &commands {
333 let r = execute_command(cmd, &self.working_dir, self.max_output_bytes).await;
334 results.push(format!("$ {cmd}\n{}", r.content));
335 }
336 ToolResult::success(results.join("\n---\n"))
337 }
338 _ => ToolResult::error(format!("unknown action: {action} (expected: run, parallel)")),
339 }
340 }
341}
342
343async fn execute_command(command: &str, working_dir: &str, max_bytes: usize) -> ToolResult {
345 let output = tokio::process::Command::new("sh")
346 .arg("-c")
347 .arg(command)
348 .current_dir(working_dir)
349 .output()
350 .await;
351
352 match output {
353 Ok(out) => {
354 let stdout = String::from_utf8_lossy(&out.stdout);
355 let stderr = String::from_utf8_lossy(&out.stderr);
356 let mut text = stdout.to_string();
357 if !stderr.is_empty() {
358 text.push_str("\nstderr: ");
359 text.push_str(&stderr);
360 }
361 if text.len() > max_bytes {
362 text.truncate(max_bytes);
363 text.push_str("\n[truncated]");
364 }
365 if out.status.success() {
366 ToolResult::success(text)
367 } else {
368 ToolResult::error(format!("exit {}: {}", out.status.code().unwrap_or(-1), text,))
369 }
370 }
371 Err(e) => ToolResult::error(format!("exec failed: {e}")),
372 }
373}
374
// Unit tests for this module live in the separate file `mcp_server_tests.rs`
// (loaded via `#[path]`) and are compiled only for test builds.
#[cfg(test)]
#[path = "mcp_server_tests.rs"]
mod tests;