Skip to main content

mcp_memory/
server.rs

1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::Arc;
4
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::TcpListener;
7use tracing::{error, info};
8
9use crate::actions::memory;
10use crate::config::Config;
11use crate::errors::{MCSError, Result};
12use crate::kg::GraphHandle;
13use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
14use crate::tools;
15
16/// Outcome of processing a request: either a pre-escaped JSON Value (small
17/// payloads) or a pre-serialized JSON *string* of the `result` field (avoids
18/// a second serialization pass for large payloads such as `read_graph`).
19enum HandlerResult {
20    Value(Value),
21    RawResult(String),
22}
23
/// Capacity of the buffered reader wrapped around each line transport, and the
/// initial capacity of the reusable response buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Frame terminator for the newline-delimited (stdio / TCP) transports.
const NEWLINE: &[u8] = &[b'\n'];
/// Maximum size of a single inbound JSON-RPC message (shared by all transports).
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
28
/// Outcome of one bounded line read from a newline-delimited transport.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LineRead {
    /// A complete line (or a final unterminated line at end of stream) was read.
    Line,
    /// The stream ended with no pending bytes.
    Eof,
    /// The line exceeded the size limit and was rejected.
    TooLong,
}
34
35async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
36where
37    R: AsyncBufReadExt + Unpin,
38{
39    out.clear();
40    let mut buf: Vec<u8> = Vec::new();
41    loop {
42        let available = reader.fill_buf().await?;
43        if available.is_empty() {
44            if buf.is_empty() {
45                return Ok(LineRead::Eof);
46            }
47            *out = String::from_utf8(buf.clone()).map_err(|_| {
48                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
49            })?;
50            return Ok(LineRead::Line);
51        }
52        match available.iter().position(|&b| b == b'\n') {
53            Some(i) => {
54                if buf.len() + i + 1 > max {
55                    reader.consume(i + 1);
56                    return Ok(LineRead::TooLong);
57                }
58                buf.extend_from_slice(&available[..=i]);
59                reader.consume(i + 1);
60                *out = String::from_utf8(buf.clone()).map_err(|_| {
61                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
62                })?;
63                return Ok(LineRead::Line);
64            }
65            None => {
66                let take = available.len();
67                if buf.len() + take > max {
68                    reader.consume(take);
69                    return Ok(LineRead::TooLong);
70                }
71                buf.extend_from_slice(available);
72                reader.consume(take);
73            }
74        }
75    }
76}
77
78fn parse_error(msg: String) -> JsonRpcResponse {
79    let mcp_error = MCSError::ParseError(msg);
80    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
81}
82
83/// Process one parsed JSON-RPC message. `None` means "no reply" — the message
84/// was a notification (no `id`), per JSON-RPC.
85pub fn process_value(value: Value, kg: &GraphHandle) -> Option<Value> {
86    let req: JsonRpcRequest = match serde_json::from_value(value) {
87        Ok(r) => r,
88        Err(e) => return Some(to_value(parse_error(e.to_string()))),
89    };
90    req.id.as_ref()?;
91    let response = match process_request(&req, kg) {
92        Ok(HandlerResult::Value(result)) => Some(to_value(JsonRpcResponse::success(req.id, result))),
93        Ok(HandlerResult::RawResult(_)) => {
94            // RawResult cannot pass through Value — dispatch_line and
95            // dispatch_http_body handle it via separate code paths.
96            unreachable!("RawResult must be handled at the dispatch level, not via process_value");
97        }
98        Err(e) => Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string()))),
99    };
100    response
101}
102
103/// Dispatch one framed line (stdio / tcp). Returns the serialized response, or
104/// `None` for a notification.
105pub fn dispatch_line(line: &str, kg: &GraphHandle) -> Option<String> {
106    let trimmed = line.trim();
107    if trimmed.is_empty() {
108        return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
109    }
110    let raw: Value = match serde_json::from_str(trimmed) {
111        Ok(v) => v,
112        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
113    };
114    let req: JsonRpcRequest = match serde_json::from_value(raw) {
115        Ok(r) => r,
116        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
117    };
118    req.id.as_ref()?;
119    match process_request(&req, kg) {
120        Ok(HandlerResult::Value(result)) => {
121            let resp = JsonRpcResponse::success(req.id, result);
122            Some(serde_json::to_string(&resp).unwrap())
123        }
124        Ok(HandlerResult::RawResult(result_json)) => {
125            let id_json = serde_json::to_string(&req.id).unwrap();
126            let mut out = String::with_capacity(64 + id_json.len() + result_json.len());
127            out.push_str(r#"{"jsonrpc":"2.0","id":"#);
128            out.push_str(&id_json);
129            out.push_str(r#","result":"#);
130            out.push_str(&result_json);
131            out.push('}');
132            Some(out)
133        }
134        Err(e) => {
135            let resp = JsonRpcResponse::error(req.id, e.error_code(), e.to_string());
136            Some(serde_json::to_string(&resp).unwrap())
137        }
138    }
139}
140
141/// Dispatch a Streamable-HTTP POST body, which may be a single JSON-RPC message
142/// or a batch array. `Ok(None)` means the body held only notifications (HTTP
143/// 202, empty body); `Err` means the body was not valid JSON.
144pub fn dispatch_http_body(
145    body: &str,
146    kg: &GraphHandle,
147) -> std::result::Result<Option<Value>, String> {
148    let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
149    match value {
150        Value::Array(items) => {
151            // Batches are rare and never huge — keep Value path for simplicity.
152            let responses: Vec<Value> =
153                items.into_iter().filter_map(|v| process_value_http(v, kg)).collect();
154            Ok((!responses.is_empty()).then_some(Value::Array(responses)))
155        }
156        other => Ok(process_value_http(other, kg)),
157    }
158}
159
160/// HTTP variant of process_value that handles RawResult by converting to Value
161/// (acceptable since HTTP payloads are typically much smaller in this context).
162fn process_value_http(value: Value, kg: &GraphHandle) -> Option<Value> {
163    let req: JsonRpcRequest = match serde_json::from_value(value) {
164        Ok(r) => r,
165        Err(e) => return Some(to_value(parse_error(e.to_string()))),
166    };
167    req.id.as_ref()?;
168    match process_request(&req, kg) {
169        Ok(HandlerResult::Value(result)) => {
170            Some(to_value(JsonRpcResponse::success(req.id, result)))
171        }
172        Ok(HandlerResult::RawResult(result_json)) => {
173            // Parse the pre-serialized result back into a Value for HTTP delivery.
174            // This is a small extra cost for the HTTP transport; the stdio/TCP
175            // path (dispatch_line) avoids it entirely.
176            let result_val: Value = serde_json::from_str(&result_json)
177                .unwrap_or(Value::Null);
178            Some(to_value(JsonRpcResponse::success(req.id, result_val)))
179        }
180        Err(e) => {
181            Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string())))
182        }
183    }
184}
185
186#[inline]
187fn to_value(resp: JsonRpcResponse) -> Value {
188    serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
189}
190
191pub struct MCPServer {
192    _config: Arc<Config>,
193    kg: Arc<GraphHandle>,
194}
195
196impl MCPServer {
197    pub fn new(config: Config) -> Result<Self> {
198        let path = Path::new(&config.memory_file_path);
199        let kg = GraphHandle::new(path).map_err(MCSError::IoError)?;
200
201        Ok(Self {
202            _config: Arc::new(config),
203            kg: Arc::new(kg),
204        })
205    }
206
207    /// Expose the shared graph handle (used to drive the HTTP transport).
208    pub fn graph(&self) -> Arc<GraphHandle> {
209        Arc::clone(&self.kg)
210    }
211
212    /// stdio transport: newline-delimited JSON-RPC over stdin/stdout.
213    pub async fn run_stdio(&self) -> Result<()> {
214        let stdin = tokio::io::stdin();
215        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
216        let mut stdout = tokio::io::stdout();
217        serve_line_conn(&mut reader, &mut stdout, &self.kg).await
218    }
219
220    /// TCP transport: each accepted connection speaks newline-delimited
221    /// JSON-RPC, exactly like stdio. Connections are served concurrently and
222    /// share the one graph behind its mutex.
223    pub async fn run_tcp(&self, addr: &str) -> Result<()> {
224        let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
225        info!("Listening for TCP MCP connections on {addr}");
226        loop {
227            let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
228            let kg = Arc::clone(&self.kg);
229            tokio::spawn(async move {
230                let (read_half, mut write_half) = socket.into_split();
231                let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
232                if let Err(e) = serve_line_conn(&mut reader, &mut write_half, &kg).await {
233                    error!("TCP connection {peer} error: {e}");
234                }
235            });
236        }
237    }
238
239    /// MCP Streamable HTTP transport (POST/GET `/mcp`, JSON or SSE responses).
240    pub async fn run_http(&self, addr: &str) -> Result<()> {
241        crate::http::run(addr, self.graph()).await
242    }
243}
244
245/// Drive one line-framed connection (stdio or a single TCP socket): read
246/// newline-delimited JSON-RPC requests, write newline-delimited responses.
247/// Notifications produce no output. Returns when the peer closes the stream.
248async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: &GraphHandle) -> Result<()>
249where
250    R: AsyncBufReadExt + Unpin,
251    W: AsyncWriteExt + Unpin,
252{
253    let mut line = String::with_capacity(1024);
254    let mut out = Vec::with_capacity(BUFFER_CAPACITY);
255
256    loop {
257        match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
258            Ok(LineRead::Eof) => break,
259            Ok(LineRead::Line) => {
260                if let Some(resp) = dispatch_line(&line, kg) {
261                    out.clear();
262                    out.extend_from_slice(resp.as_bytes());
263                    out.extend_from_slice(NEWLINE);
264                    writer.write_all(&out).await.map_err(MCSError::IoError)?;
265                    writer.flush().await.map_err(MCSError::IoError)?;
266                }
267            }
268            Ok(LineRead::TooLong) => {
269                let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
270                let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
271                out.clear();
272                serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
273                out.extend_from_slice(NEWLINE);
274                writer.write_all(&out).await.map_err(MCSError::IoError)?;
275                writer.flush().await.map_err(MCSError::IoError)?;
276                break;
277            }
278            Err(e) => {
279                error!("IO error: {}", e);
280                break;
281            }
282        }
283    }
284    Ok(())
285}
286
287fn process_request(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
288    match req.method.as_str() {
289        "initialize" => Ok(HandlerResult::Value(handle_initialize())),
290        "tools/list" => Ok(HandlerResult::Value(handle_tools_list())),
291        "tools/call" => handle_tools_call(req, kg),
292        "ping" => Ok(HandlerResult::Value(Value::Null)),
293        method if method.starts_with("notifications/") => {
294            tracing::trace!("Received notification: {method}");
295            Ok(HandlerResult::Value(Value::Null))
296        }
297        _ => Err(MCSError::MethodNotFound(req.method.clone())),
298    }
299}
300
301fn handle_initialize() -> Value {
302    json!({
303        "protocolVersion": "2024-11-05",
304        "capabilities": {
305            "tools": { "listChanged": false }
306        },
307        "serverInfo": {
308            "name": "mcp-memory",
309            "version": env!("CARGO_PKG_VERSION")
310        }
311    })
312}
313
314fn handle_tools_list() -> Value {
315    static CACHED: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
316    if let Some(cached) = CACHED.get() {
317        return cached.clone();
318    }
319    let tools_json = include_str!("../tools.json");
320    let tools: Vec<Value> =
321        serde_json::from_str(tools_json).map_err(MCSError::JsonError).unwrap();
322    let result = json!({ "tools": tools });
323    let _ = CACHED.set(result.clone());
324    result
325}
326
327fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
328    let tool_name = req
329        .params
330        .as_ref()
331        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
332        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
333
334    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
335
336    if !tools::tool_exists(tool_name) {
337        return Err(MCSError::MethodNotFound(tool_name.to_string()));
338    }
339
340    match tool_name {
341        // Raw-result handlers (large payloads, avoid second serialization pass).
342        "read_graph" => Ok(HandlerResult::RawResult(memory::handle_read_graph(kg, tool_args)?)),
343        "search_nodes" => Ok(HandlerResult::RawResult(memory::handle_search_nodes(kg, tool_args)?)),
344        // Standard Value handlers.
345        "create_entities" => Ok(HandlerResult::Value(memory::handle_create_entities(kg, tool_args)?)),
346        "create_relations" => Ok(HandlerResult::Value(memory::handle_create_relations(kg, tool_args)?)),
347        "add_observations" => Ok(HandlerResult::Value(memory::handle_add_observations(kg, tool_args)?)),
348        "delete_entities" => Ok(HandlerResult::Value(memory::handle_delete_entities(kg, tool_args)?)),
349        "delete_observations" => Ok(HandlerResult::Value(memory::handle_delete_observations(kg, tool_args)?)),
350        "delete_relations" => Ok(HandlerResult::Value(memory::handle_delete_relations(kg, tool_args)?)),
351        "open_nodes" => Ok(HandlerResult::Value(memory::handle_open_nodes(kg, tool_args)?)),
352        "get_entity" => Ok(HandlerResult::Value(memory::handle_get_entity(kg, tool_args)?)),
353        "graph_stats" => Ok(HandlerResult::Value(memory::handle_graph_stats(kg)?)),
354        "search_relations" => Ok(HandlerResult::Value(memory::handle_search_relations(kg, tool_args)?)),
355        "find_path" => Ok(HandlerResult::Value(memory::handle_find_path(kg, tool_args)?)),
356        "compact" => Ok(HandlerResult::Value(memory::handle_compact(kg)?)),
357        "get_neighbors" => Ok(HandlerResult::Value(memory::handle_get_neighbors(kg, tool_args)?)),
358        "describe_entity" => Ok(HandlerResult::Value(memory::handle_describe_entity(kg, tool_args)?)),
359        "list_entity_types" => Ok(HandlerResult::Value(memory::handle_list_entity_types(kg)?)),
360        "list_relation_types" => Ok(HandlerResult::Value(memory::handle_list_relation_types(kg)?)),
361        "upsert_entities" => Ok(HandlerResult::Value(memory::handle_upsert_entities(kg, tool_args)?)),
362        "export_graph" => Ok(HandlerResult::Value(memory::handle_export_graph(kg, tool_args)?)),
363        "merge_entities" => Ok(HandlerResult::Value(memory::handle_merge_entities(kg, tool_args)?)),
364        "extract_subgraph" => Ok(HandlerResult::Value(memory::handle_extract_subgraph(kg, tool_args)?)),
365        "batch_get_entities" => Ok(HandlerResult::Value(memory::handle_batch_get_entities(kg, tool_args)?)),
366        "find_all_paths" => Ok(HandlerResult::Value(memory::handle_find_all_paths(kg, tool_args)?)),
367        tool => Err(MCSError::MethodNotFound(tool.to_string())),
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Create a fresh graph backed by a unique file in the platform temp dir.
    /// The file path is returned so the test can clean it up.
    fn setup_kg() -> (Arc<GraphHandle>, String) {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        // Use the platform temp dir rather than a hard-coded `/tmp`, so the
        // suite also runs on Windows and honours `TMPDIR`.
        let path = std::env::temp_dir()
            .join(format!("mcp_mem_test_{pid}_{seq}.bin"))
            .to_string_lossy()
            .into_owned();
        let kg = GraphHandle::new(Path::new(&path)).unwrap();
        (Arc::new(kg), path)
    }

    fn cleanup(path: &str) {
        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let resp = dispatch_line(line, &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("{invalid}", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_empty() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("   \n", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        cleanup(&path);
    }

    #[test]
    fn test_notification_has_no_response() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &kg).is_none());
        cleanup(&path);
    }

    #[test]
    fn test_unknown_method_error() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601);
        cleanup(&path);
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let (kg, path) = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(read, &kg).unwrap()).unwrap();
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
        cleanup(&path);
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let (kg, path) = setup_kg();
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &kg).unwrap().is_none());

        assert!(dispatch_http_body("{bad", &kg).is_err());
        cleanup(&path);
    }

    #[test]
    fn test_handle_initialize_response() {
        let (kg, path) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = match process_request(&req, &kg).unwrap() {
            HandlerResult::Value(v) => v,
            HandlerResult::RawResult(_) => panic!("expected Value"),
        };
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }
}