Skip to main content

mcp_memory/
server.rs

1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::{Arc, OnceLock};
4
5use parking_lot::RwLock;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::net::TcpListener;
8use tracing::{error, info};
9
10use crate::actions::memory;
11use crate::config::Config;
12use crate::errors::{MCSError, Result};
13use crate::kg::KnowledgeGraph;
14use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
15use crate::tools;
16
/// Capacity in bytes (64 KiB) used for the buffered readers and for the
/// reusable response buffer of a line-framed connection.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Frame terminator appended to each line-framed response.
const NEWLINE: &[u8] = &[b'\n'];
/// Maximum size of a single inbound JSON-RPC message (shared by all transports).
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
21
/// Outcome of one attempt to read a newline-delimited frame.
///
/// The derives make the outcome cheap to copy, comparable, and printable in
/// logs and assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LineRead {
    /// A line was read and stored in the caller's buffer.
    Line,
    /// The stream ended before any byte of a new line arrived.
    Eof,
    /// The line would exceed the caller-supplied byte limit.
    TooLong,
}
27
28async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
29where
30    R: AsyncBufReadExt + Unpin,
31{
32    out.clear();
33    let mut buf: Vec<u8> = Vec::new();
34    loop {
35        let available = reader.fill_buf().await?;
36        if available.is_empty() {
37            if buf.is_empty() {
38                return Ok(LineRead::Eof);
39            }
40            *out = String::from_utf8(buf.clone()).map_err(|_| {
41                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
42            })?;
43            return Ok(LineRead::Line);
44        }
45        match available.iter().position(|&b| b == b'\n') {
46            Some(i) => {
47                if buf.len() + i + 1 > max {
48                    reader.consume(i + 1);
49                    return Ok(LineRead::TooLong);
50                }
51                buf.extend_from_slice(&available[..=i]);
52                reader.consume(i + 1);
53                *out = String::from_utf8(buf.clone()).map_err(|_| {
54                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
55                })?;
56                return Ok(LineRead::Line);
57            }
58            None => {
59                let take = available.len();
60                if buf.len() + take > max {
61                    reader.consume(take);
62                    return Ok(LineRead::TooLong);
63                }
64                buf.extend_from_slice(available);
65                reader.consume(take);
66            }
67        }
68    }
69}
70
71fn parse_error(msg: String) -> JsonRpcResponse {
72    let mcp_error = MCSError::ParseError(msg);
73    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
74}
75
76// ---------------------------------------------------------------------------
77// Transport-agnostic dispatch core.
78//
79// All three transports (stdio, tcp, http) share the exact same JSON-RPC/MCP
80// semantics; they differ only in framing. `process_value` is the single source
81// of truth: it takes one parsed JSON-RPC message and returns the response
82// value, or `None` for a notification (which gets no reply anywhere).
83// ---------------------------------------------------------------------------
84
85/// Process one parsed JSON-RPC message. `None` means "no reply" — the message
86/// was a notification (no `id`), per JSON-RPC.
87pub fn process_value(value: Value, kg: &RwLock<KnowledgeGraph>) -> Option<Value> {
88    let req: JsonRpcRequest = match serde_json::from_value(value) {
89        Ok(r) => r,
90        Err(e) => return Some(to_value(parse_error(e.to_string()))),
91    };
92    // Notifications never get a reply, even on error.
93    if req.id.is_none() {
94        return None;
95    }
96    let response = match process_request(&req, kg) {
97        Ok(result) => JsonRpcResponse::success(req.id, result),
98        Err(e) => JsonRpcResponse::error(req.id, e.error_code(), e.to_string()),
99    };
100    Some(to_value(response))
101}
102
103/// Dispatch one framed line (stdio / tcp). Returns the serialized response, or
104/// `None` for a notification.
105pub fn dispatch_line(line: &str, kg: &RwLock<KnowledgeGraph>) -> Option<String> {
106    let trimmed = line.trim();
107    if trimmed.is_empty() {
108        return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
109    }
110    let value: Value = match serde_json::from_str(trimmed) {
111        Ok(v) => v,
112        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
113    };
114    process_value(value, kg).map(|v| serde_json::to_string(&v).unwrap())
115}
116
117/// Dispatch a Streamable-HTTP POST body, which may be a single JSON-RPC message
118/// or a batch array. `Ok(None)` means the body held only notifications (HTTP
119/// 202, empty body); `Err` means the body was not valid JSON.
120pub fn dispatch_http_body(
121    body: &str,
122    kg: &RwLock<KnowledgeGraph>,
123) -> std::result::Result<Option<Value>, String> {
124    let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
125    match value {
126        Value::Array(items) => {
127            let responses: Vec<Value> =
128                items.into_iter().filter_map(|v| process_value(v, kg)).collect();
129            Ok((!responses.is_empty()).then_some(Value::Array(responses)))
130        }
131        other => Ok(process_value(other, kg)),
132    }
133}
134
135#[inline]
136fn to_value(resp: JsonRpcResponse) -> Value {
137    serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
138}
139
140pub struct MCPServer {
141    _config: Arc<Config>,
142    kg: Arc<RwLock<KnowledgeGraph>>,
143}
144
145impl MCPServer {
146    pub fn new(config: Config) -> Result<Self> {
147        let path = Path::new(&config.memory_file_path);
148        let kg = KnowledgeGraph::new(path)
149            .map_err(MCSError::IoError)?;
150
151        Ok(Self {
152            _config: Arc::new(config),
153            kg: Arc::new(RwLock::new(kg)),
154        })
155    }
156
157    /// Expose the shared graph handle (used to drive the HTTP transport).
158    pub fn graph(&self) -> Arc<RwLock<KnowledgeGraph>> {
159        Arc::clone(&self.kg)
160    }
161
162    /// stdio transport: newline-delimited JSON-RPC over stdin/stdout.
163    pub async fn run_stdio(&self) -> Result<()> {
164        let stdin = tokio::io::stdin();
165        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
166        let mut stdout = tokio::io::stdout();
167        serve_line_conn(&mut reader, &mut stdout, &self.kg).await
168    }
169
170    /// TCP transport: each accepted connection speaks newline-delimited
171    /// JSON-RPC, exactly like stdio. Connections are served concurrently and
172    /// share the one graph behind its mutex.
173    pub async fn run_tcp(&self, addr: &str) -> Result<()> {
174        let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
175        info!("Listening for TCP MCP connections on {addr}");
176        loop {
177            let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
178            let kg = Arc::clone(&self.kg);
179            tokio::spawn(async move {
180                let (read_half, mut write_half) = socket.into_split();
181                let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
182                if let Err(e) = serve_line_conn(&mut reader, &mut write_half, &kg).await {
183                    error!("TCP connection {peer} error: {e}");
184                }
185            });
186        }
187    }
188
189    /// MCP Streamable HTTP transport (POST/GET `/mcp`, JSON or SSE responses).
190    pub async fn run_http(&self, addr: &str) -> Result<()> {
191        crate::http::run(addr, self.graph()).await
192    }
193}
194
195/// Drive one line-framed connection (stdio or a single TCP socket): read
196/// newline-delimited JSON-RPC requests, write newline-delimited responses.
197/// Notifications produce no output. Returns when the peer closes the stream.
198async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: &RwLock<KnowledgeGraph>) -> Result<()>
199where
200    R: AsyncBufReadExt + Unpin,
201    W: AsyncWriteExt + Unpin,
202{
203    let mut line = String::with_capacity(1024);
204    let mut out = Vec::with_capacity(BUFFER_CAPACITY);
205
206    loop {
207        match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
208            Ok(LineRead::Eof) => break,
209            Ok(LineRead::Line) => {
210                if let Some(resp) = dispatch_line(&line, kg) {
211                    out.clear();
212                    out.extend_from_slice(resp.as_bytes());
213                    out.extend_from_slice(NEWLINE);
214                    writer.write_all(&out).await.map_err(MCSError::IoError)?;
215                    writer.flush().await.map_err(MCSError::IoError)?;
216                }
217            }
218            Ok(LineRead::TooLong) => {
219                let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
220                let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
221                out.clear();
222                serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
223                out.extend_from_slice(NEWLINE);
224                writer.write_all(&out).await.map_err(MCSError::IoError)?;
225                writer.flush().await.map_err(MCSError::IoError)?;
226                break;
227            }
228            Err(e) => {
229                error!("IO error: {}", e);
230                break;
231            }
232        }
233    }
234    Ok(())
235}
236
237fn process_request(req: &JsonRpcRequest, kg: &RwLock<KnowledgeGraph>) -> Result<Value> {
238    match req.method.as_str() {
239        "initialize" => handle_initialize(),
240        "tools/list" => handle_tools_list(),
241        "tools/call" => handle_tools_call(req, kg),
242        "ping" => handle_ping(),
243        method if method.starts_with("notifications/") => handle_notification(method),
244        _ => Err(MCSError::MethodNotFound(req.method.clone())),
245    }
246}
247
248const fn handle_ping() -> Result<Value> {
249    Ok(Value::Null)
250}
251
252fn handle_notification(method: &str) -> Result<Value> {
253    tracing::trace!("Received notification: {method}");
254    Ok(Value::Null)
255}
256
257fn handle_initialize() -> Result<Value> {
258    Ok(json!({
259        "protocolVersion": "2024-11-05",
260        "capabilities": {
261            "tools": { "listChanged": false }
262        },
263        "serverInfo": {
264            "name": "mcp-memory",
265            "version": env!("CARGO_PKG_VERSION")
266        }
267    }))
268}
269
270fn handle_tools_list() -> Result<Value> {
271    static CACHED: OnceLock<Value> = OnceLock::new();
272    if let Some(cached) = CACHED.get() {
273        return Ok(cached.clone());
274    }
275    let tools_json = include_str!("../tools.json");
276    let tools: Vec<Value> =
277        serde_json::from_str(tools_json).map_err(MCSError::JsonError)?;
278    let result = json!({ "tools": tools });
279    let _ = CACHED.set(result.clone());
280    Ok(result)
281}
282
283fn handle_tools_call(req: &JsonRpcRequest, kg: &RwLock<KnowledgeGraph>) -> Result<Value> {
284    let tool_name = req
285        .params
286        .as_ref()
287        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
288        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
289
290    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
291
292    if !tools::tool_exists(tool_name) {
293        return Err(MCSError::MethodNotFound(tool_name.to_string()));
294    }
295
296    match tool_name {
297        "create_entities" => memory::handle_create_entities(kg, tool_args),
298        "create_relations" => memory::handle_create_relations(kg, tool_args),
299        "add_observations" => memory::handle_add_observations(kg, tool_args),
300        "delete_entities" => memory::handle_delete_entities(kg, tool_args),
301        "delete_observations" => memory::handle_delete_observations(kg, tool_args),
302        "delete_relations" => memory::handle_delete_relations(kg, tool_args),
303        "read_graph" => memory::handle_read_graph(kg, tool_args),
304        "search_nodes" => memory::handle_search_nodes(kg, tool_args),
305        "open_nodes" => memory::handle_open_nodes(kg, tool_args),
306        "get_entity" => memory::handle_get_entity(kg, tool_args),
307        "graph_stats" => memory::handle_graph_stats(kg),
308        "search_relations" => memory::handle_search_relations(kg, tool_args),
309        "find_path" => memory::handle_find_path(kg, tool_args),
310        "compact" => memory::handle_compact(kg),
311        "get_neighbors" => memory::handle_get_neighbors(kg, tool_args),
312        "describe_entity" => memory::handle_describe_entity(kg, tool_args),
313        "list_entity_types" => memory::handle_list_entity_types(kg),
314        "list_relation_types" => memory::handle_list_relation_types(kg),
315        "upsert_entities" => memory::handle_upsert_entities(kg, tool_args),
316        "export_graph" => memory::handle_export_graph(kg, tool_args),
317        "merge_entities" => memory::handle_merge_entities(kg, tool_args),
318        "extract_subgraph" => memory::handle_extract_subgraph(kg, tool_args),
319        "batch_get_entities" => memory::handle_batch_get_entities(kg, tool_args),
320        "find_all_paths" => memory::handle_find_all_paths(kg, tool_args),
321        tool => Err(MCSError::MethodNotFound(tool.to_string())),
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Distinguishes the backing files of tests running in the same process.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// A fresh knowledge graph plus its backing file.
    ///
    /// The file is removed when the fixture is dropped, so it is cleaned up
    /// even when an assertion panics part-way through a test.
    struct TestGraph {
        kg: Arc<RwLock<KnowledgeGraph>>,
        path: PathBuf,
    }

    impl TestGraph {
        /// Create a graph backed by a unique file in the platform temp dir.
        fn new() -> Self {
            let pid = std::process::id();
            let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
            let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
            let kg = KnowledgeGraph::new(path.as_path()).unwrap();
            Self {
                kg: Arc::new(RwLock::new(kg)),
                path,
            }
        }
    }

    impl Drop for TestGraph {
        fn drop(&mut self) {
            // Best effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Dispatch `line` and parse the response, which must exist and be JSON.
    fn dispatch_json(line: &str, kg: &RwLock<KnowledgeGraph>) -> Value {
        let resp = dispatch_line(line, kg).expect("request must produce a response");
        serde_json::from_str(&resp).expect("response must be valid JSON")
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let graph = TestGraph::new();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let v = dispatch_json(line, &graph.kg);
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let graph = TestGraph::new();
        let v = dispatch_json("{invalid}", &graph.kg);
        // Parse error per JSON-RPC: code -32700, null id.
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
    }

    #[test]
    fn test_dispatch_line_empty() {
        let graph = TestGraph::new();
        let v = dispatch_json("   \n", &graph.kg);
        assert_eq!(v["error"]["code"], -32700);
    }

    #[test]
    fn test_notification_has_no_response() {
        let graph = TestGraph::new();
        // A message with no `id` is a notification → no reply.
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &graph.kg).is_none());
    }

    #[test]
    fn test_unknown_method_error() {
        let graph = TestGraph::new();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v = dispatch_json(line, &graph.kg);
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601); // method not found
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let graph = TestGraph::new();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &graph.kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v = dispatch_json(read, &graph.kg);
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let graph = TestGraph::new();
        // Batch: one request + one notification → array with a single response.
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &graph.kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        // Notification-only body → no response (HTTP 202).
        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &graph.kg).unwrap().is_none());

        // Invalid JSON → Err.
        assert!(dispatch_http_body("{bad", &graph.kg).is_err());
    }

    #[test]
    fn test_handle_initialize_response() {
        let graph = TestGraph::new();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = process_request(&req, &graph.kg).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
    }
}
443}