Skip to main content

mcp_memory/
server.rs

1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::Arc;
4
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::TcpListener;
7use tracing::{error, info};
8
9use crate::actions::memory;
10use crate::config::Config;
11use crate::errors::{MCSError, Result};
12use crate::kg::GraphHandle;
13use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
14use crate::tools;
15
/// Capacity, in bytes, of each buffered reader and of the reusable response buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Frame terminator used by the line-delimited transports (stdio and TCP).
const NEWLINE: &[u8] = b"\n";
/// Maximum size of a single inbound JSON-RPC message (shared by all transports).
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
20
/// Outcome of a single bounded line read from a line-framed transport.
enum LineRead {
    /// A line was read into the caller's buffer.
    Line,
    /// The stream ended with no pending bytes.
    Eof,
    /// The line exceeded the size limit and was not buffered.
    TooLong,
}
26
27async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
28where
29    R: AsyncBufReadExt + Unpin,
30{
31    out.clear();
32    let mut buf: Vec<u8> = Vec::new();
33    loop {
34        let available = reader.fill_buf().await?;
35        if available.is_empty() {
36            if buf.is_empty() {
37                return Ok(LineRead::Eof);
38            }
39            *out = String::from_utf8(buf.clone()).map_err(|_| {
40                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
41            })?;
42            return Ok(LineRead::Line);
43        }
44        match available.iter().position(|&b| b == b'\n') {
45            Some(i) => {
46                if buf.len() + i + 1 > max {
47                    reader.consume(i + 1);
48                    return Ok(LineRead::TooLong);
49                }
50                buf.extend_from_slice(&available[..=i]);
51                reader.consume(i + 1);
52                *out = String::from_utf8(buf.clone()).map_err(|_| {
53                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
54                })?;
55                return Ok(LineRead::Line);
56            }
57            None => {
58                let take = available.len();
59                if buf.len() + take > max {
60                    reader.consume(take);
61                    return Ok(LineRead::TooLong);
62                }
63                buf.extend_from_slice(available);
64                reader.consume(take);
65            }
66        }
67    }
68}
69
70fn parse_error(msg: String) -> JsonRpcResponse {
71    let mcp_error = MCSError::ParseError(msg);
72    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
73}
74
75/// Process one parsed JSON-RPC message. `None` means "no reply" — the message
76/// was a notification (no `id`), per JSON-RPC.
77pub fn process_value(value: Value, kg: &GraphHandle) -> Option<Value> {
78    let req: JsonRpcRequest = match serde_json::from_value(value) {
79        Ok(r) => r,
80        Err(e) => return Some(to_value(parse_error(e.to_string()))),
81    };
82    req.id.as_ref()?;
83    let response = match process_request(&req, kg) {
84        Ok(result) => JsonRpcResponse::success(req.id, result),
85        Err(e) => JsonRpcResponse::error(req.id, e.error_code(), e.to_string()),
86    };
87    Some(to_value(response))
88}
89
90/// Dispatch one framed line (stdio / tcp). Returns the serialized response, or
91/// `None` for a notification.
92pub fn dispatch_line(line: &str, kg: &GraphHandle) -> Option<String> {
93    let trimmed = line.trim();
94    if trimmed.is_empty() {
95        return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
96    }
97    let value: Value = match serde_json::from_str(trimmed) {
98        Ok(v) => v,
99        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
100    };
101    process_value(value, kg).map(|v| serde_json::to_string(&v).unwrap())
102}
103
104/// Dispatch a Streamable-HTTP POST body, which may be a single JSON-RPC message
105/// or a batch array. `Ok(None)` means the body held only notifications (HTTP
106/// 202, empty body); `Err` means the body was not valid JSON.
107pub fn dispatch_http_body(
108    body: &str,
109    kg: &GraphHandle,
110) -> std::result::Result<Option<Value>, String> {
111    let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
112    match value {
113        Value::Array(items) => {
114            let responses: Vec<Value> =
115                items.into_iter().filter_map(|v| process_value(v, kg)).collect();
116            Ok((!responses.is_empty()).then_some(Value::Array(responses)))
117        }
118        other => Ok(process_value(other, kg)),
119    }
120}
121
122#[inline]
123fn to_value(resp: JsonRpcResponse) -> Value {
124    serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
125}
126
127pub struct MCPServer {
128    _config: Arc<Config>,
129    kg: Arc<GraphHandle>,
130}
131
132impl MCPServer {
133    pub fn new(config: Config) -> Result<Self> {
134        let path = Path::new(&config.memory_file_path);
135        let kg = GraphHandle::new(path).map_err(MCSError::IoError)?;
136
137        Ok(Self {
138            _config: Arc::new(config),
139            kg: Arc::new(kg),
140        })
141    }
142
143    /// Expose the shared graph handle (used to drive the HTTP transport).
144    pub fn graph(&self) -> Arc<GraphHandle> {
145        Arc::clone(&self.kg)
146    }
147
148    /// stdio transport: newline-delimited JSON-RPC over stdin/stdout.
149    pub async fn run_stdio(&self) -> Result<()> {
150        let stdin = tokio::io::stdin();
151        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
152        let mut stdout = tokio::io::stdout();
153        serve_line_conn(&mut reader, &mut stdout, &self.kg).await
154    }
155
156    /// TCP transport: each accepted connection speaks newline-delimited
157    /// JSON-RPC, exactly like stdio. Connections are served concurrently and
158    /// share the one graph behind its mutex.
159    pub async fn run_tcp(&self, addr: &str) -> Result<()> {
160        let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
161        info!("Listening for TCP MCP connections on {addr}");
162        loop {
163            let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
164            let kg = Arc::clone(&self.kg);
165            tokio::spawn(async move {
166                let (read_half, mut write_half) = socket.into_split();
167                let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
168                if let Err(e) = serve_line_conn(&mut reader, &mut write_half, &kg).await {
169                    error!("TCP connection {peer} error: {e}");
170                }
171            });
172        }
173    }
174
175    /// MCP Streamable HTTP transport (POST/GET `/mcp`, JSON or SSE responses).
176    pub async fn run_http(&self, addr: &str) -> Result<()> {
177        crate::http::run(addr, self.graph()).await
178    }
179}
180
181/// Drive one line-framed connection (stdio or a single TCP socket): read
182/// newline-delimited JSON-RPC requests, write newline-delimited responses.
183/// Notifications produce no output. Returns when the peer closes the stream.
184async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: &GraphHandle) -> Result<()>
185where
186    R: AsyncBufReadExt + Unpin,
187    W: AsyncWriteExt + Unpin,
188{
189    let mut line = String::with_capacity(1024);
190    let mut out = Vec::with_capacity(BUFFER_CAPACITY);
191
192    loop {
193        match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
194            Ok(LineRead::Eof) => break,
195            Ok(LineRead::Line) => {
196                if let Some(resp) = dispatch_line(&line, kg) {
197                    out.clear();
198                    out.extend_from_slice(resp.as_bytes());
199                    out.extend_from_slice(NEWLINE);
200                    writer.write_all(&out).await.map_err(MCSError::IoError)?;
201                    writer.flush().await.map_err(MCSError::IoError)?;
202                }
203            }
204            Ok(LineRead::TooLong) => {
205                let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
206                let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
207                out.clear();
208                serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
209                out.extend_from_slice(NEWLINE);
210                writer.write_all(&out).await.map_err(MCSError::IoError)?;
211                writer.flush().await.map_err(MCSError::IoError)?;
212                break;
213            }
214            Err(e) => {
215                error!("IO error: {}", e);
216                break;
217            }
218        }
219    }
220    Ok(())
221}
222
223fn process_request(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<Value> {
224    match req.method.as_str() {
225        "initialize" => handle_initialize(),
226        "tools/list" => handle_tools_list(),
227        "tools/call" => handle_tools_call(req, kg),
228        "ping" => handle_ping(),
229        method if method.starts_with("notifications/") => handle_notification(method),
230        _ => Err(MCSError::MethodNotFound(req.method.clone())),
231    }
232}
233
234const fn handle_ping() -> Result<Value> {
235    Ok(Value::Null)
236}
237
238fn handle_notification(method: &str) -> Result<Value> {
239    tracing::trace!("Received notification: {method}");
240    Ok(Value::Null)
241}
242
243fn handle_initialize() -> Result<Value> {
244    Ok(json!({
245        "protocolVersion": "2024-11-05",
246        "capabilities": {
247            "tools": { "listChanged": false }
248        },
249        "serverInfo": {
250            "name": "mcp-memory",
251            "version": env!("CARGO_PKG_VERSION")
252        }
253    }))
254}
255
256fn handle_tools_list() -> Result<Value> {
257    static CACHED: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
258    if let Some(cached) = CACHED.get() {
259        return Ok(cached.clone());
260    }
261    let tools_json = include_str!("../tools.json");
262    let tools: Vec<Value> =
263        serde_json::from_str(tools_json).map_err(MCSError::JsonError)?;
264    let result = json!({ "tools": tools });
265    let _ = CACHED.set(result.clone());
266    Ok(result)
267}
268
269fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<Value> {
270    let tool_name = req
271        .params
272        .as_ref()
273        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
274        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
275
276    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
277
278    if !tools::tool_exists(tool_name) {
279        return Err(MCSError::MethodNotFound(tool_name.to_string()));
280    }
281
282    match tool_name {
283        "create_entities" => memory::handle_create_entities(kg, tool_args),
284        "create_relations" => memory::handle_create_relations(kg, tool_args),
285        "add_observations" => memory::handle_add_observations(kg, tool_args),
286        "delete_entities" => memory::handle_delete_entities(kg, tool_args),
287        "delete_observations" => memory::handle_delete_observations(kg, tool_args),
288        "delete_relations" => memory::handle_delete_relations(kg, tool_args),
289        "read_graph" => memory::handle_read_graph(kg, tool_args),
290        "search_nodes" => memory::handle_search_nodes(kg, tool_args),
291        "open_nodes" => memory::handle_open_nodes(kg, tool_args),
292        "get_entity" => memory::handle_get_entity(kg, tool_args),
293        "graph_stats" => memory::handle_graph_stats(kg),
294        "search_relations" => memory::handle_search_relations(kg, tool_args),
295        "find_path" => memory::handle_find_path(kg, tool_args),
296        "compact" => memory::handle_compact(kg),
297        "get_neighbors" => memory::handle_get_neighbors(kg, tool_args),
298        "describe_entity" => memory::handle_describe_entity(kg, tool_args),
299        "list_entity_types" => memory::handle_list_entity_types(kg),
300        "list_relation_types" => memory::handle_list_relation_types(kg),
301        "upsert_entities" => memory::handle_upsert_entities(kg, tool_args),
302        "export_graph" => memory::handle_export_graph(kg, tool_args),
303        "merge_entities" => memory::handle_merge_entities(kg, tool_args),
304        "extract_subgraph" => memory::handle_extract_subgraph(kg, tool_args),
305        "batch_get_entities" => memory::handle_batch_get_entities(kg, tool_args),
306        "find_all_paths" => memory::handle_find_all_paths(kg, tool_args),
307        tool => Err(MCSError::MethodNotFound(tool.to_string())),
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// A graph backed by a uniquely named file in the OS temp directory.
    ///
    /// The file is removed on drop, so it is cleaned up even when an assertion
    /// panics (an explicit cleanup call at the end of a test would be skipped).
    struct TempGraph {
        kg: Arc<GraphHandle>,
        path: PathBuf,
    }

    impl Drop for TempGraph {
        fn drop(&mut self) {
            // Best effort: the file may not exist.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    fn setup_kg() -> TempGraph {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        // `temp_dir()` rather than a hard-coded `/tmp` so the tests also run on
        // platforms without that directory (e.g. Windows).
        let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
        let kg = GraphHandle::new(path.as_path()).unwrap();
        TempGraph {
            kg: Arc::new(kg),
            path,
        }
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let g = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let resp = dispatch_line(line, &g.kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let g = setup_kg();
        let resp = dispatch_line("{invalid}", &g.kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
    }

    #[test]
    fn test_dispatch_line_empty() {
        let g = setup_kg();
        let resp = dispatch_line("   \n", &g.kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
    }

    #[test]
    fn test_notification_has_no_response() {
        let g = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &g.kg).is_none());
    }

    #[test]
    fn test_unknown_method_error() {
        let g = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &g.kg).unwrap()).unwrap();
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601);
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let g = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &g.kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(read, &g.kg).unwrap()).unwrap();
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let g = setup_kg();
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &g.kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &g.kg).unwrap().is_none());

        assert!(dispatch_http_body("{bad", &g.kg).is_err());
    }

    #[test]
    fn test_handle_initialize_response() {
        let g = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = process_request(&req, &g.kg).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
    }
}