// engram_server/mcp.rs

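//! MCP stdio server — JSON-RPC 2.0 over stdin/stdout.
//!
//! Messages are newline-delimited: one JSON-RPC message per line on stdin, and one
//! response per line on stdout.
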
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

// ---------------------------------------------------------------------------
// JSON-RPC types
// ---------------------------------------------------------------------------

15#[derive(Debug, Deserialize)]
16pub struct JsonRpcRequest {
17    pub jsonrpc: String,
18    pub id: Option<Value>,
19    pub method: String,
20    pub params: Option<Value>,
21}
22
23#[derive(Debug, Serialize)]
24pub struct JsonRpcResponse {
25    pub jsonrpc: String,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub id: Option<Value>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub result: Option<Value>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub error: Option<JsonRpcError>,
32}
33
34#[derive(Debug, Serialize)]
35pub struct JsonRpcError {
36    pub code: i64,
37    pub message: String,
38}
39
// ---------------------------------------------------------------------------
// MCP types
// ---------------------------------------------------------------------------

44#[derive(Debug, Clone, Serialize)]
45pub struct McpToolDef {
46    pub name: String,
47    pub description: String,
48    #[serde(rename = "inputSchema")]
49    pub input_schema: Value,
50}
51
// ---------------------------------------------------------------------------
// Tool handler type
// ---------------------------------------------------------------------------

56pub type ToolHandler = Box<
57    dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> + Send + Sync,
58>;
59
// ---------------------------------------------------------------------------
// McpServer
// ---------------------------------------------------------------------------

64pub struct McpServer {
65    tools: Vec<McpToolDef>,
66    handlers: HashMap<String, Arc<ToolHandler>>,
67}
68
69impl McpServer {
70    pub fn new() -> Self {
71        Self {
72            tools: Vec::new(),
73            handlers: HashMap::new(),
74        }
75    }
76
77    /// Register a tool with its definition and async handler.
78    pub fn tool<F, Fut>(mut self, def: McpToolDef, handler: F) -> Self
79    where
80        F: Fn(Value) -> Fut + Send + Sync + 'static,
81        Fut: Future<Output = Result<String, String>> + Send + 'static,
82    {
83        let name = def.name.clone();
84        self.tools.push(def);
85        self.handlers.insert(
86            name,
87            Arc::new(Box::new(move |args| Box::pin(handler(args)))),
88        );
89        self
90    }
91
92    /// Run the stdio JSON-RPC loop.
93    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
94        let stdin = tokio::io::stdin();
95        let mut stdout = tokio::io::stdout();
96        let mut reader = BufReader::new(stdin);
97        let mut line = String::new();
98
99        loop {
100            line.clear();
101            let n = reader.read_line(&mut line).await?;
102            if n == 0 {
103                break; // EOF
104            }
105
106            let trimmed = line.trim();
107            if trimmed.is_empty() {
108                continue;
109            }
110
111            let request: JsonRpcRequest = match serde_json::from_str(trimmed) {
112                Ok(r) => r,
113                Err(e) => {
114                    let err_resp = JsonRpcResponse {
115                        jsonrpc: "2.0".into(),
116                        id: None,
117                        result: None,
118                        error: Some(JsonRpcError {
119                            code: -32700,
120                            message: format!("parse error: {e}"),
121                        }),
122                    };
123                    let mut out = serde_json::to_string(&err_resp)?;
124                    out.push('\n');
125                    stdout.write_all(out.as_bytes()).await?;
126                    stdout.flush().await?;
127                    continue;
128                }
129            };
130
131            let response = self.handle_request(&request).await;
132            let mut out = serde_json::to_string(&response)?;
133            out.push('\n');
134            stdout.write_all(out.as_bytes()).await?;
135            stdout.flush().await?;
136        }
137
138        Ok(())
139    }
140
141    async fn handle_request(&self, req: &JsonRpcRequest) -> JsonRpcResponse {
142        match req.method.as_str() {
143            "initialize" => JsonRpcResponse {
144                jsonrpc: "2.0".into(),
145                id: req.id.clone(),
146                result: Some(serde_json::json!({
147                    "protocolVersion": "2024-11-05",
148                    "capabilities": {
149                        "tools": { "listChanged": false }
150                    },
151                    "serverInfo": {
152                        "name": "engram",
153                        "version": env!("CARGO_PKG_VERSION")
154                    }
155                })),
156                error: None,
157            },
158            "notifications/initialized" => JsonRpcResponse {
159                jsonrpc: "2.0".into(),
160                id: req.id.clone(),
161                result: Some(Value::Null),
162                error: None,
163            },
164            "tools/list" => JsonRpcResponse {
165                jsonrpc: "2.0".into(),
166                id: req.id.clone(),
167                result: Some(serde_json::json!({ "tools": self.tools })),
168                error: None,
169            },
170            "tools/call" => {
171                let params = req.params.as_ref().cloned().unwrap_or(Value::Null);
172                let name = params["name"].as_str().unwrap_or("");
173                let arguments = params
174                    .get("arguments")
175                    .cloned()
176                    .unwrap_or(Value::Object(Default::default()));
177
178                match self.handlers.get(name) {
179                    Some(handler) => {
180                        let result = handler(arguments).await;
181                        match result {
182                            Ok(text) => JsonRpcResponse {
183                                jsonrpc: "2.0".into(),
184                                id: req.id.clone(),
185                                result: Some(serde_json::json!({
186                                    "content": [{ "type": "text", "text": text }],
187                                    "isError": false
188                                })),
189                                error: None,
190                            },
191                            Err(e) => JsonRpcResponse {
192                                jsonrpc: "2.0".into(),
193                                id: req.id.clone(),
194                                result: Some(serde_json::json!({
195                                    "content": [{ "type": "text", "text": e }],
196                                    "isError": true
197                                })),
198                                error: None,
199                            },
200                        }
201                    }
202                    None => JsonRpcResponse {
203                        jsonrpc: "2.0".into(),
204                        id: req.id.clone(),
205                        result: None,
206                        error: Some(JsonRpcError {
207                            code: -32601,
208                            message: format!("unknown tool: {name}"),
209                        }),
210                    },
211                }
212            }
213            _ => JsonRpcResponse {
214                jsonrpc: "2.0".into(),
215                id: req.id.clone(),
216                result: None,
217                error: Some(JsonRpcError {
218                    code: -32601,
219                    message: format!("method not found: {}", req.method),
220                }),
221            },
222        }
223    }
224}
225
226impl Default for McpServer {
227    fn default() -> Self {
228        Self::new()
229    }
230}