Skip to main content

mcp_memory/
server.rs

1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::Arc;
4
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::TcpListener;
7use tokio::sync::Semaphore;
8use tracing::{error, info};
9
10use crate::actions::memory;
11use crate::config::Config;
12use crate::errors::{MCSError, Result};
13use crate::kg::GraphHandle;
14use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
15use crate::tools;
16
17/// Outcome of processing a request: either a pre-escaped JSON Value (small
18/// payloads) or a pre-serialized JSON *string* of the `result` field (avoids
19/// a second serialization pass for large payloads such as `read_graph`).
20enum HandlerResult {
21    Value(Value),
22    RawResult(String),
23}
24
/// Capacity of the buffered readers and of the reusable output buffer (64 KiB).
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Terminator appended after every response on the line-delimited transports.
const NEWLINE: &[u8] = b"\n";
/// Maximum size of a single inbound JSON-RPC message (shared by all transports).
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
/// Maximum number of concurrent TCP connections (C4).
const MAX_TCP_CONNECTIONS: usize = 128;
31
/// Result of one capped line read.
enum LineRead {
    /// A line was decoded into the caller's buffer (at EOF it may lack a
    /// trailing newline).
    Line,
    /// The stream ended and no bytes were pending.
    Eof,
    /// The line exceeded the size cap; the bytes read so far were discarded.
    TooLong,
}
37
38async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
39where
40    R: AsyncBufReadExt + Unpin,
41{
42    out.clear();
43    let mut buf: Vec<u8> = Vec::new();
44    loop {
45        let available = reader.fill_buf().await?;
46        if available.is_empty() {
47            if buf.is_empty() {
48                return Ok(LineRead::Eof);
49            }
50            *out = String::from_utf8(buf.clone()).map_err(|_| {
51                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
52            })?;
53            return Ok(LineRead::Line);
54        }
55        match available.iter().position(|&b| b == b'\n') {
56            Some(i) => {
57                if buf.len() + i + 1 > max {
58                    reader.consume(i + 1);
59                    return Ok(LineRead::TooLong);
60                }
61                buf.extend_from_slice(&available[..=i]);
62                reader.consume(i + 1);
63                *out = String::from_utf8(buf.clone()).map_err(|_| {
64                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
65                })?;
66                return Ok(LineRead::Line);
67            }
68            None => {
69                let take = available.len();
70                if buf.len() + take > max {
71                    reader.consume(take);
72                    return Ok(LineRead::TooLong);
73                }
74                buf.extend_from_slice(available);
75                reader.consume(take);
76            }
77        }
78    }
79}
80
81fn parse_error(msg: String) -> JsonRpcResponse {
82    let mcp_error = MCSError::ParseError(msg);
83    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
84}
85
86/// Process one parsed JSON-RPC message. `None` means "no reply" — the message
87/// was a notification (no `id`), per JSON-RPC.
88pub fn process_value(value: Value, kg: &GraphHandle) -> Option<Value> {
89    let req: JsonRpcRequest = match serde_json::from_value(value) {
90        Ok(r) => r,
91        Err(e) => return Some(to_value(parse_error(e.to_string()))),
92    };
93    req.id.as_ref()?;
94    let response = match process_request(&req, kg) {
95        Ok(HandlerResult::Value(result)) => Some(to_value(JsonRpcResponse::success(req.id, result))),
96        Ok(HandlerResult::RawResult(_)) => {
97            // RawResult cannot pass through Value — dispatch_line and
98            // dispatch_http_body handle it via separate code paths.
99            unreachable!("RawResult must be handled at the dispatch level, not via process_value");
100        }
101        Err(e) => Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string()))),
102    };
103    response
104}
105
106/// Dispatch one framed line (stdio / tcp). Returns the serialized response, or
107/// `None` for a notification.
108pub fn dispatch_line(line: &str, kg: &GraphHandle) -> Option<String> {
109    let trimmed = line.trim();
110    if trimmed.is_empty() {
111        return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
112    }
113    let raw: Value = match serde_json::from_str(trimmed) {
114        Ok(v) => v,
115        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
116    };
117    let req: JsonRpcRequest = match serde_json::from_value(raw) {
118        Ok(r) => r,
119        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
120    };
121    req.id.as_ref()?;
122    match process_request(&req, kg) {
123        Ok(HandlerResult::Value(result)) => {
124            let resp = JsonRpcResponse::success(req.id, result);
125            Some(serde_json::to_string(&resp).unwrap())
126        }
127        Ok(HandlerResult::RawResult(result_json)) => {
128            let id_json = serde_json::to_string(&req.id).unwrap();
129            let mut out = String::with_capacity(64 + id_json.len() + result_json.len());
130            out.push_str(r#"{"jsonrpc":"2.0","id":"#);
131            out.push_str(&id_json);
132            out.push_str(r#","result":"#);
133            out.push_str(&result_json);
134            out.push('}');
135            Some(out)
136        }
137        Err(e) => {
138            let resp = JsonRpcResponse::error(req.id, e.error_code(), e.to_string());
139            Some(serde_json::to_string(&resp).unwrap())
140        }
141    }
142}
143
144/// Dispatch a Streamable-HTTP POST body, which may be a single JSON-RPC message
145/// or a batch array. `Ok(None)` means the body held only notifications (HTTP
146/// 202, empty body); `Err` means the body was not valid JSON.
147pub fn dispatch_http_body(
148    body: &str,
149    kg: &GraphHandle,
150) -> std::result::Result<Option<Value>, String> {
151    let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
152    match value {
153        Value::Array(items) => {
154            // Batches are rare and never huge — keep Value path for simplicity.
155            let responses: Vec<Value> =
156                items.into_iter().filter_map(|v| process_value_http(v, kg)).collect();
157            Ok((!responses.is_empty()).then_some(Value::Array(responses)))
158        }
159        other => Ok(process_value_http(other, kg)),
160    }
161}
162
163/// HTTP variant of process_value that handles RawResult by converting to Value
164/// (acceptable since HTTP payloads are typically much smaller in this context).
165fn process_value_http(value: Value, kg: &GraphHandle) -> Option<Value> {
166    let req: JsonRpcRequest = match serde_json::from_value(value) {
167        Ok(r) => r,
168        Err(e) => return Some(to_value(parse_error(e.to_string()))),
169    };
170    req.id.as_ref()?;
171    match process_request(&req, kg) {
172        Ok(HandlerResult::Value(result)) => {
173            Some(to_value(JsonRpcResponse::success(req.id, result)))
174        }
175        Ok(HandlerResult::RawResult(result_json)) => {
176            // Parse the pre-serialized result back into a Value for HTTP delivery.
177            // This is a small extra cost for the HTTP transport; the stdio/TCP
178            // path (dispatch_line) avoids it entirely.
179            let result_val: Value = serde_json::from_str(&result_json)
180                .unwrap_or(Value::Null);
181            Some(to_value(JsonRpcResponse::success(req.id, result_val)))
182        }
183        Err(e) => {
184            Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string())))
185        }
186    }
187}
188
189#[inline]
190fn to_value(resp: JsonRpcResponse) -> Value {
191    serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
192}
193
194pub struct MCPServer {
195    _config: Arc<Config>,
196    kg: Arc<GraphHandle>,
197}
198
199impl MCPServer {
200    pub fn new(config: Config) -> Result<Self> {
201        let path = Path::new(&config.memory_file_path);
202        let kg = GraphHandle::new(path, config.durability).map_err(MCSError::IoError)?;
203
204        Ok(Self {
205            _config: Arc::new(config),
206            kg: Arc::new(kg),
207        })
208    }
209
210    /// Expose the shared graph handle (used to drive the HTTP transport).
211    pub fn graph(&self) -> Arc<GraphHandle> {
212        Arc::clone(&self.kg)
213    }
214
215    /// stdio transport: newline-delimited JSON-RPC over stdin/stdout.
216    pub async fn run_stdio(&self) -> Result<()> {
217        let stdin = tokio::io::stdin();
218        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
219        let mut stdout = tokio::io::stdout();
220        serve_line_conn(&mut reader, &mut stdout, Arc::clone(&self.kg)).await
221    }
222
223    /// TCP transport: each accepted connection speaks newline-delimited
224    /// JSON-RPC, exactly like stdio. Connections are served concurrently (up to
225    /// [`MAX_TCP_CONNECTIONS`]) and share the one graph behind its mutex.
226    pub async fn run_tcp(&self, addr: &str) -> Result<()> {
227        let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
228        let semaphore = Arc::new(Semaphore::new(MAX_TCP_CONNECTIONS));
229        info!("Listening for TCP MCP connections on {addr} (max {MAX_TCP_CONNECTIONS})");
230        loop {
231            let permit = Arc::clone(&semaphore).acquire_owned().await;
232            let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
233            let kg = Arc::clone(&self.kg);
234            tokio::spawn(async move {
235                let _permit = permit; // held for the connection lifetime
236                let (read_half, mut write_half) = socket.into_split();
237                let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
238                if let Err(e) = serve_line_conn(&mut reader, &mut write_half, kg).await {
239                    error!("TCP connection {peer} error: {e}");
240                }
241            });
242        }
243    }
244
245    /// MCP Streamable HTTP transport (POST/GET `/mcp`, JSON or SSE responses).
246    pub async fn run_http(&self, addr: &str) -> Result<()> {
247        crate::http::run(addr, self.graph()).await
248    }
249}
250
251/// Drive one line-framed connection (stdio or a single TCP socket): read
252/// newline-delimited JSON-RPC requests, write newline-delimited responses.
253/// Notifications produce no output. Returns when the peer closes the stream.
254/// The dispatch path (graph lock + optional fsync) is offloaded to
255/// [`tokio::task::spawn_blocking`] to keep the async reactor responsive (C3).
256async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: Arc<GraphHandle>) -> Result<()>
257where
258    R: AsyncBufReadExt + Unpin,
259    W: AsyncWriteExt + Unpin,
260{
261    let mut line = String::with_capacity(1024);
262    let mut out = Vec::with_capacity(BUFFER_CAPACITY);
263
264    loop {
265        match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
266            Ok(LineRead::Eof) => break,
267            Ok(LineRead::Line) => {
268                let line_copy = line.clone();
269                let kg_clone = Arc::clone(&kg);
270                let resp = tokio::task::spawn_blocking(move || dispatch_line(&line_copy, &kg_clone))
271                    .await
272                    .map_err(|join_err| {
273                        error!("dispatch task panicked: {join_err}");
274                        MCSError::IoError(std::io::Error::other("dispatch task panicked"))
275                    })?;
276                if let Some(resp) = resp {
277                    out.clear();
278                    out.extend_from_slice(resp.as_bytes());
279                    out.extend_from_slice(NEWLINE);
280                    writer.write_all(&out).await.map_err(MCSError::IoError)?;
281                    writer.flush().await.map_err(MCSError::IoError)?;
282                }
283            }
284            Ok(LineRead::TooLong) => {
285                let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
286                let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
287                out.clear();
288                serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
289                out.extend_from_slice(NEWLINE);
290                writer.write_all(&out).await.map_err(MCSError::IoError)?;
291                writer.flush().await.map_err(MCSError::IoError)?;
292                break;
293            }
294            Err(e) => {
295                error!("IO error: {}", e);
296                break;
297            }
298        }
299    }
300    Ok(())
301}
302
303fn process_request(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
304    match req.method.as_str() {
305        "initialize" => Ok(HandlerResult::Value(handle_initialize())),
306        "tools/list" => Ok(HandlerResult::Value(handle_tools_list())),
307        "tools/call" => handle_tools_call(req, kg),
308        "ping" => Ok(HandlerResult::Value(Value::Null)),
309        method if method.starts_with("notifications/") => {
310            tracing::trace!("Received notification: {method}");
311            Ok(HandlerResult::Value(Value::Null))
312        }
313        _ => Err(MCSError::MethodNotFound(req.method.clone())),
314    }
315}
316
317fn handle_initialize() -> Value {
318    json!({
319        "protocolVersion": "2024-11-05",
320        "capabilities": {
321            "tools": { "listChanged": false }
322        },
323        "serverInfo": {
324            "name": "mcp-memory",
325            "version": env!("CARGO_PKG_VERSION")
326        }
327    })
328}
329
330fn handle_tools_list() -> Value {
331    static CACHED: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
332    if let Some(cached) = CACHED.get() {
333        return cached.clone();
334    }
335    let tools_json = include_str!("../tools.json");
336    let tools: Vec<Value> =
337        serde_json::from_str(tools_json).map_err(MCSError::JsonError).unwrap();
338    let result = json!({ "tools": tools });
339    let _ = CACHED.set(result.clone());
340    result
341}
342
343fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
344    let tool_name = req
345        .params
346        .as_ref()
347        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
348        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
349
350    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
351
352    if !tools::tool_exists(tool_name) {
353        return Err(MCSError::MethodNotFound(tool_name.to_string()));
354    }
355
356    match tool_name {
357        // Raw-result handlers (large payloads, avoid second serialization pass).
358        "read_graph" => Ok(HandlerResult::RawResult(memory::handle_read_graph(kg, tool_args)?)),
359        "search_nodes" => Ok(HandlerResult::RawResult(memory::handle_search_nodes(kg, tool_args)?)),
360        // Standard Value handlers.
361        "create_entities" => Ok(HandlerResult::Value(memory::handle_create_entities(kg, tool_args)?)),
362        "create_relations" => Ok(HandlerResult::Value(memory::handle_create_relations(kg, tool_args)?)),
363        "add_observations" => Ok(HandlerResult::Value(memory::handle_add_observations(kg, tool_args)?)),
364        "delete_entities" => Ok(HandlerResult::Value(memory::handle_delete_entities(kg, tool_args)?)),
365        "delete_observations" => Ok(HandlerResult::Value(memory::handle_delete_observations(kg, tool_args)?)),
366        "delete_relations" => Ok(HandlerResult::Value(memory::handle_delete_relations(kg, tool_args)?)),
367        "open_nodes" => Ok(HandlerResult::Value(memory::handle_open_nodes(kg, tool_args)?)),
368        "get_entity" => Ok(HandlerResult::Value(memory::handle_get_entity(kg, tool_args)?)),
369        "graph_stats" => Ok(HandlerResult::Value(memory::handle_graph_stats(kg)?)),
370        "search_relations" => Ok(HandlerResult::Value(memory::handle_search_relations(kg, tool_args)?)),
371        "find_path" => Ok(HandlerResult::Value(memory::handle_find_path(kg, tool_args)?)),
372        "compact" => Ok(HandlerResult::Value(memory::handle_compact(kg)?)),
373        "get_neighbors" => Ok(HandlerResult::Value(memory::handle_get_neighbors(kg, tool_args)?)),
374        "describe_entity" => Ok(HandlerResult::Value(memory::handle_describe_entity(kg, tool_args)?)),
375        "list_entity_types" => Ok(HandlerResult::Value(memory::handle_list_entity_types(kg)?)),
376        "list_relation_types" => Ok(HandlerResult::Value(memory::handle_list_relation_types(kg)?)),
377        "upsert_entities" => Ok(HandlerResult::Value(memory::handle_upsert_entities(kg, tool_args)?)),
378        "export_graph" => Ok(HandlerResult::Value(memory::handle_export_graph(kg, tool_args)?)),
379        "merge_entities" => Ok(HandlerResult::Value(memory::handle_merge_entities(kg, tool_args)?)),
380        "extract_subgraph" => Ok(HandlerResult::Value(memory::handle_extract_subgraph(kg, tool_args)?)),
381        "batch_get_entities" => Ok(HandlerResult::Value(memory::handle_batch_get_entities(kg, tool_args)?)),
382        "find_all_paths" => Ok(HandlerResult::Value(memory::handle_find_all_paths(kg, tool_args)?)),
383        tool => Err(MCSError::MethodNotFound(tool.to_string())),
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Durability;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Create a graph backed by a unique file in the platform temp directory.
    /// Returns the handle plus the file path, for `cleanup`.
    fn setup_kg() -> (Arc<GraphHandle>, String) {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        // Use the platform temp dir instead of a hard-coded `/tmp`, so the
        // suite also runs on Windows.
        let path = std::env::temp_dir()
            .join(format!("mcp_mem_test_{pid}_{seq}.bin"))
            .to_string_lossy()
            .into_owned();
        let kg = GraphHandle::new(Path::new(&path), Durability::Async).unwrap();
        (Arc::new(kg), path)
    }

    /// Best-effort removal of the file created by `setup_kg`.
    fn cleanup(path: &str) {
        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let resp = dispatch_line(line, &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("{invalid}", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_empty() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("   \n", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        cleanup(&path);
    }

    #[test]
    fn test_notification_has_no_response() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &kg).is_none());
        cleanup(&path);
    }

    #[test]
    fn test_unknown_method_error() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601);
        cleanup(&path);
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let (kg, path) = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(read, &kg).unwrap()).unwrap();
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
        cleanup(&path);
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let (kg, path) = setup_kg();
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &kg).unwrap().is_none());

        assert!(dispatch_http_body("{bad", &kg).is_err());
        cleanup(&path);
    }

    #[test]
    fn test_handle_initialize_response() {
        let (kg, path) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = match process_request(&req, &kg).unwrap() {
            HandlerResult::Value(v) => v,
            HandlerResult::RawResult(_) => panic!("expected Value"),
        };
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }
}