Skip to main content

batuta/agent/tool/
mcp_server.rs

//! MCP Server — expose agent tools to external MCP clients.
//!
//! Implements handler dispatch for agent tools (memory, rag, compute)
//! so external LLM clients (Claude Code, other agents) can call
//! the agent's tools over the MCP protocol.
//!
//! Uses a trait-based handler abstraction aligned with pmcp (v2.3).
//! Refs: arXiv:2505.02279, arXiv:2503.23278
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use super::ToolResult;
16use crate::agent::memory::MemorySubstrate;
17
18/// Handler for a single MCP tool endpoint.
19///
20/// Mirrors pmcp `Handler` trait pattern for forward compatibility.
21#[async_trait]
22pub trait McpHandler: Send + Sync {
23    /// Tool name as exposed via MCP (e.g., `memory_store`).
24    fn name(&self) -> &'static str;
25
26    /// Human-readable description for tool discovery.
27    fn description(&self) -> &'static str;
28
29    /// JSON Schema for the tool's input parameters.
30    fn input_schema(&self) -> serde_json::Value;
31
32    /// Execute the tool with the given parameters.
33    async fn handle(&self, params: serde_json::Value) -> ToolResult;
34}
35
36/// Registry of MCP handlers for dispatch.
37pub struct HandlerRegistry {
38    handlers: HashMap<String, Box<dyn McpHandler>>,
39}
40
41impl HandlerRegistry {
42    /// Create an empty registry.
43    pub fn new() -> Self {
44        Self { handlers: HashMap::new() }
45    }
46
47    /// Register a handler.
48    pub fn register(&mut self, handler: Box<dyn McpHandler>) {
49        let name = handler.name().to_string();
50        self.handlers.insert(name, handler);
51    }
52
53    /// Dispatch a tool call to the appropriate handler.
54    pub async fn dispatch(&self, method: &str, params: serde_json::Value) -> ToolResult {
55        match self.handlers.get(method) {
56            Some(handler) => handler.handle(params).await,
57            None => ToolResult::error(format!("unknown method: {method}")),
58        }
59    }
60
61    /// List available tools for MCP discovery.
62    pub fn list_tools(&self) -> Vec<McpToolInfo> {
63        self.handlers
64            .values()
65            .map(|h| McpToolInfo {
66                name: h.name().to_string(),
67                description: h.description().to_string(),
68                input_schema: h.input_schema(),
69            })
70            .collect()
71    }
72
73    /// Number of registered handlers.
74    pub fn len(&self) -> usize {
75        self.handlers.len()
76    }
77
78    /// Whether the registry is empty.
79    pub fn is_empty(&self) -> bool {
80        self.handlers.is_empty()
81    }
82}
83
84impl Default for HandlerRegistry {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90/// Tool info returned by MCP tools/list.
91#[derive(Debug, Clone, serde::Serialize)]
92pub struct McpToolInfo {
93    /// Tool name.
94    pub name: String,
95    /// Tool description.
96    pub description: String,
97    /// JSON Schema for input.
98    pub input_schema: serde_json::Value,
99}
100
101/// Memory handler — exposes agent memory via MCP.
102///
103/// Supports `store` (remember) and `recall` (search) actions.
104pub struct MemoryHandler {
105    memory: Arc<dyn MemorySubstrate>,
106    agent_id: String,
107}
108
109impl MemoryHandler {
110    /// Create a new memory handler.
111    pub fn new(memory: Arc<dyn MemorySubstrate>, agent_id: impl Into<String>) -> Self {
112        Self { memory, agent_id: agent_id.into() }
113    }
114}
115
116#[async_trait]
117impl McpHandler for MemoryHandler {
118    fn name(&self) -> &'static str {
119        "memory"
120    }
121
122    fn description(&self) -> &'static str {
123        "Store and recall agent memory fragments"
124    }
125
126    fn input_schema(&self) -> serde_json::Value {
127        serde_json::json!({
128            "type": "object",
129            "properties": {
130                "action": {
131                    "type": "string",
132                    "enum": ["store", "recall"]
133                },
134                "content": { "type": "string" },
135                "query": { "type": "string" },
136                "limit": { "type": "integer" }
137            },
138            "required": ["action"]
139        })
140    }
141
142    async fn handle(&self, params: serde_json::Value) -> ToolResult {
143        let action = params.get("action").and_then(|v| v.as_str()).unwrap_or("");
144
145        match action {
146            "store" => {
147                let content = params.get("content").and_then(|v| v.as_str()).unwrap_or("");
148                if content.is_empty() {
149                    return ToolResult::error("content is required for store");
150                }
151                match self
152                    .memory
153                    .remember(
154                        &self.agent_id,
155                        content,
156                        crate::agent::memory::MemorySource::User,
157                        None,
158                    )
159                    .await
160                {
161                    Ok(id) => ToolResult::success(format!("Stored with id: {id}")),
162                    Err(e) => ToolResult::error(format!("store failed: {e}")),
163                }
164            }
165            "recall" => {
166                let query = params.get("query").and_then(|v| v.as_str()).unwrap_or("");
167                let limit = params
168                    .get("limit")
169                    .and_then(serde_json::Value::as_u64)
170                    .map_or(5, |v| usize::try_from(v).unwrap_or(5));
171                match self.memory.recall(query, limit, None, None).await {
172                    Ok(fragments) => {
173                        if fragments.is_empty() {
174                            return ToolResult::success("No matching memories found.");
175                        }
176                        let mut out = String::new();
177                        for f in &fragments {
178                            use std::fmt::Write;
179                            let _ =
180                                writeln!(out, "- {} (score: {:.2})", f.content, f.relevance_score,);
181                        }
182                        ToolResult::success(out)
183                    }
184                    Err(e) => ToolResult::error(format!("recall failed: {e}")),
185                }
186            }
187            _ => ToolResult::error(format!("unknown action: {action} (expected: store, recall)")),
188        }
189    }
190}
191
/// MCP handler for searching indexed Sovereign AI Stack documentation.
///
/// Thin adapter over `RagOracle` so external clients can run documentation
/// queries.
#[cfg(feature = "rag")]
pub struct RagHandler {
    /// Retrieval backend queried for every request.
    oracle: Arc<crate::oracle::rag::RagOracle>,
    /// Result cap used when a request does not supply a usable `limit`.
    max_results: usize,
}
201
#[cfg(feature = "rag")]
impl RagHandler {
    /// Wrap `oracle` for MCP use.
    ///
    /// `max_results` becomes the fallback result cap for requests without a
    /// usable `limit`.
    pub fn new(oracle: Arc<crate::oracle::rag::RagOracle>, max_results: usize) -> Self {
        RagHandler { max_results, oracle }
    }
}
209
#[cfg(feature = "rag")]
#[async_trait]
impl McpHandler for RagHandler {
    fn name(&self) -> &'static str {
        "rag"
    }

    fn description(&self) -> &'static str {
        "Search indexed Sovereign AI Stack documentation"
    }

    /// Object schema with a required `query` string and an optional `limit`.
    fn input_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query for documentation"
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum results (default: 5)"
                }
            },
            "required": ["query"]
        })
    }

    /// Run `query` against the oracle and format at most `limit` hits.
    ///
    /// `limit` falls back to `max_results` when absent or not representable as
    /// `usize`. Each hit is rendered as `N. [component] source (score: x.xxx)`,
    /// followed by an indented line holding its content.
    async fn handle(&self, params: serde_json::Value) -> ToolResult {
        use std::fmt::Write;

        let query = match params.get("query").and_then(serde_json::Value::as_str) {
            Some(q) if !q.is_empty() => q,
            _ => return ToolResult::error("query is required for search"),
        };

        let limit = match params.get("limit").and_then(serde_json::Value::as_u64) {
            Some(n) => usize::try_from(n).unwrap_or(self.max_results),
            None => self.max_results,
        };

        let mut hits = self.oracle.query(query).into_iter().take(limit).peekable();
        if hits.peek().is_none() {
            return ToolResult::success("No results found.");
        }

        let mut out = String::new();
        for (rank, hit) in (1usize..).zip(hits) {
            let _ = writeln!(
                out,
                "{}. [{}] {} (score: {:.3})",
                rank, hit.component, hit.source, hit.score
            );
            let _ = writeln!(out, "   {}", hit.content);
        }
        ToolResult::success(out)
    }
}
266
/// MCP handler that runs shell commands and returns their captured output.
///
/// Actions:
/// - `run` executes one command.
/// - `parallel` executes a list of commands, one after another, and joins
///   their outputs.
///
/// Output is capped at `max_output_bytes` per command so responses cannot
/// overflow the caller's context.
///
/// Security: commands come straight from the MCP client and run with this
/// process's privileges. Expose this handler only to trusted clients.
pub struct ComputeHandler {
    /// Directory every command is started in.
    working_dir: String,
    /// Byte cap on each command's returned output text. A truncation marker
    /// may be appended past it.
    max_output_bytes: usize,
}
275
276impl ComputeHandler {
277    /// Create a new compute handler.
278    pub fn new(working_dir: impl Into<String>) -> Self {
279        Self { working_dir: working_dir.into(), max_output_bytes: 8192 }
280    }
281}
282
283#[async_trait]
284impl McpHandler for ComputeHandler {
285    fn name(&self) -> &'static str {
286        "compute"
287    }
288
289    fn description(&self) -> &'static str {
290        "Execute shell commands with output capture"
291    }
292
293    fn input_schema(&self) -> serde_json::Value {
294        serde_json::json!({
295            "type": "object",
296            "properties": {
297                "action": {
298                    "type": "string",
299                    "enum": ["run", "parallel"]
300                },
301                "command": { "type": "string" },
302                "commands": {
303                    "type": "array",
304                    "items": { "type": "string" }
305                }
306            },
307            "required": ["action"]
308        })
309    }
310
311    async fn handle(&self, params: serde_json::Value) -> ToolResult {
312        let action = params.get("action").and_then(|v| v.as_str()).unwrap_or("");
313
314        match action {
315            "run" => {
316                let command = params.get("command").and_then(|v| v.as_str()).unwrap_or("");
317                if command.is_empty() {
318                    return ToolResult::error("command is required for run");
319                }
320                execute_command(command, &self.working_dir, self.max_output_bytes).await
321            }
322            "parallel" => {
323                let commands: Vec<String> = params
324                    .get("commands")
325                    .and_then(|v| v.as_array())
326                    .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
327                    .unwrap_or_default();
328                if commands.is_empty() {
329                    return ToolResult::error("commands array is required for parallel");
330                }
331                let mut results = Vec::new();
332                for cmd in &commands {
333                    let r = execute_command(cmd, &self.working_dir, self.max_output_bytes).await;
334                    results.push(format!("$ {cmd}\n{}", r.content));
335                }
336                ToolResult::success(results.join("\n---\n"))
337            }
338            _ => ToolResult::error(format!("unknown action: {action} (expected: run, parallel)")),
339        }
340    }
341}
342
343/// Execute a single shell command and capture output.
344async fn execute_command(command: &str, working_dir: &str, max_bytes: usize) -> ToolResult {
345    let output = tokio::process::Command::new("sh")
346        .arg("-c")
347        .arg(command)
348        .current_dir(working_dir)
349        .output()
350        .await;
351
352    match output {
353        Ok(out) => {
354            let stdout = String::from_utf8_lossy(&out.stdout);
355            let stderr = String::from_utf8_lossy(&out.stderr);
356            let mut text = stdout.to_string();
357            if !stderr.is_empty() {
358                text.push_str("\nstderr: ");
359                text.push_str(&stderr);
360            }
361            if text.len() > max_bytes {
362                text.truncate(max_bytes);
363                text.push_str("\n[truncated]");
364            }
365            if out.status.success() {
366                ToolResult::success(text)
367            } else {
368                ToolResult::error(format!("exit {}: {}", out.status.code().unwrap_or(-1), text,))
369            }
370        }
371        Err(e) => ToolResult::error(format!("exec failed: {e}")),
372    }
373}
374
// Unit tests for this module live in `mcp_server_tests.rs`, loaded via `#[path]`.
#[cfg(test)]
#[path = "mcp_server_tests.rs"]
mod tests;