Skip to main content

mcp_memory/
server.rs

1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::{Arc, Mutex, OnceLock};
4use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
5use tokio::net::TcpListener;
6use tracing::{error, info};
7
8use crate::actions::memory;
9use crate::config::Config;
10use crate::errors::{MCSError, Result};
11use crate::kg::KnowledgeGraph;
12use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
13use crate::tools;
14
/// Capacity of the `BufReader` wrapped around each line-framed input stream,
/// and initial capacity of the per-connection output buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Frame delimiter used by the line-based (stdio / TCP) transports.
const NEWLINE: &[u8] = b"\n";
/// Maximum size of a single inbound JSON-RPC message (shared by all transports).
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
19
/// Outcome of one bounded line read (see `read_line_capped`).
enum LineRead {
    /// A line (or a final unterminated fragment) was read into the output buffer.
    Line,
    /// The stream ended with no pending bytes.
    Eof,
    /// The line exceeded the byte limit and was rejected.
    TooLong,
}
25
26async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
27where
28    R: AsyncBufReadExt + Unpin,
29{
30    out.clear();
31    let mut buf: Vec<u8> = Vec::new();
32    loop {
33        let available = reader.fill_buf().await?;
34        if available.is_empty() {
35            if buf.is_empty() {
36                return Ok(LineRead::Eof);
37            }
38            *out = String::from_utf8(buf.clone()).map_err(|_| {
39                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
40            })?;
41            return Ok(LineRead::Line);
42        }
43        match available.iter().position(|&b| b == b'\n') {
44            Some(i) => {
45                if buf.len() + i + 1 > max {
46                    reader.consume(i + 1);
47                    return Ok(LineRead::TooLong);
48                }
49                buf.extend_from_slice(&available[..=i]);
50                reader.consume(i + 1);
51                *out = String::from_utf8(buf.clone()).map_err(|_| {
52                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
53                })?;
54                return Ok(LineRead::Line);
55            }
56            None => {
57                let take = available.len();
58                if buf.len() + take > max {
59                    reader.consume(take);
60                    return Ok(LineRead::TooLong);
61                }
62                buf.extend_from_slice(available);
63                reader.consume(take);
64            }
65        }
66    }
67}
68
69fn parse_error(msg: String) -> JsonRpcResponse {
70    let mcp_error = MCSError::ParseError(msg);
71    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
72}
73
74// ---------------------------------------------------------------------------
75// Transport-agnostic dispatch core.
76//
77// All three transports (stdio, tcp, http) share the exact same JSON-RPC/MCP
78// semantics; they differ only in framing. `process_value` is the single source
79// of truth: it takes one parsed JSON-RPC message and returns the response
80// value, or `None` for a notification (which gets no reply anywhere).
81// ---------------------------------------------------------------------------
82
83/// Process one parsed JSON-RPC message. `None` means "no reply" — the message
84/// was a notification (no `id`), per JSON-RPC.
85pub fn process_value(value: Value, kg: &Mutex<KnowledgeGraph>) -> Option<Value> {
86    let req: JsonRpcRequest = match serde_json::from_value(value) {
87        Ok(r) => r,
88        Err(e) => return Some(to_value(parse_error(e.to_string()))),
89    };
90    // Notifications never get a reply, even on error.
91    if req.id.is_none() {
92        return None;
93    }
94    let response = match process_request(&req, kg) {
95        Ok(result) => JsonRpcResponse::success(req.id, result),
96        Err(e) => JsonRpcResponse::error(req.id, e.error_code(), e.to_string()),
97    };
98    Some(to_value(response))
99}
100
101/// Dispatch one framed line (stdio / tcp). Returns the serialized response, or
102/// `None` for a notification.
103pub fn dispatch_line(line: &str, kg: &Mutex<KnowledgeGraph>) -> Option<String> {
104    let trimmed = line.trim();
105    if trimmed.is_empty() {
106        return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
107    }
108    let value: Value = match serde_json::from_str(trimmed) {
109        Ok(v) => v,
110        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
111    };
112    process_value(value, kg).map(|v| serde_json::to_string(&v).unwrap())
113}
114
115/// Dispatch a Streamable-HTTP POST body, which may be a single JSON-RPC message
116/// or a batch array. `Ok(None)` means the body held only notifications (HTTP
117/// 202, empty body); `Err` means the body was not valid JSON.
118pub fn dispatch_http_body(
119    body: &str,
120    kg: &Mutex<KnowledgeGraph>,
121) -> std::result::Result<Option<Value>, String> {
122    let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
123    match value {
124        Value::Array(items) => {
125            let responses: Vec<Value> =
126                items.into_iter().filter_map(|v| process_value(v, kg)).collect();
127            Ok((!responses.is_empty()).then_some(Value::Array(responses)))
128        }
129        other => Ok(process_value(other, kg)),
130    }
131}
132
133#[inline]
134fn to_value(resp: JsonRpcResponse) -> Value {
135    serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
136}
137
138pub struct MCPServer {
139    _config: Arc<Config>,
140    kg: Arc<Mutex<KnowledgeGraph>>,
141}
142
143impl MCPServer {
144    pub fn new(config: Config) -> Result<Self> {
145        let path = Path::new(&config.memory_file_path);
146        let kg = KnowledgeGraph::new(path)
147            .map_err(MCSError::IoError)?;
148
149        Ok(Self {
150            _config: Arc::new(config),
151            kg: Arc::new(Mutex::new(kg)),
152        })
153    }
154
155    /// Expose the shared graph handle (used to drive the HTTP transport).
156    pub fn graph(&self) -> Arc<Mutex<KnowledgeGraph>> {
157        Arc::clone(&self.kg)
158    }
159
160    /// stdio transport: newline-delimited JSON-RPC over stdin/stdout.
161    pub async fn run_stdio(&self) -> Result<()> {
162        let stdin = tokio::io::stdin();
163        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
164        let mut stdout = tokio::io::stdout();
165        serve_line_conn(&mut reader, &mut stdout, &self.kg).await
166    }
167
168    /// TCP transport: each accepted connection speaks newline-delimited
169    /// JSON-RPC, exactly like stdio. Connections are served concurrently and
170    /// share the one graph behind its mutex.
171    pub async fn run_tcp(&self, addr: &str) -> Result<()> {
172        let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
173        info!("Listening for TCP MCP connections on {addr}");
174        loop {
175            let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
176            let kg = Arc::clone(&self.kg);
177            tokio::spawn(async move {
178                let (read_half, mut write_half) = socket.into_split();
179                let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
180                if let Err(e) = serve_line_conn(&mut reader, &mut write_half, &kg).await {
181                    error!("TCP connection {peer} error: {e}");
182                }
183            });
184        }
185    }
186
187    /// MCP Streamable HTTP transport (POST/GET `/mcp`, JSON or SSE responses).
188    pub async fn run_http(&self, addr: &str) -> Result<()> {
189        crate::http::run(addr, self.graph()).await
190    }
191}
192
193/// Drive one line-framed connection (stdio or a single TCP socket): read
194/// newline-delimited JSON-RPC requests, write newline-delimited responses.
195/// Notifications produce no output. Returns when the peer closes the stream.
196async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: &Mutex<KnowledgeGraph>) -> Result<()>
197where
198    R: AsyncBufReadExt + Unpin,
199    W: AsyncWriteExt + Unpin,
200{
201    let mut line = String::with_capacity(1024);
202    let mut out = Vec::with_capacity(BUFFER_CAPACITY);
203
204    loop {
205        match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
206            Ok(LineRead::Eof) => break,
207            Ok(LineRead::Line) => {
208                if let Some(resp) = dispatch_line(&line, kg) {
209                    out.clear();
210                    out.extend_from_slice(resp.as_bytes());
211                    out.extend_from_slice(NEWLINE);
212                    writer.write_all(&out).await.map_err(MCSError::IoError)?;
213                    writer.flush().await.map_err(MCSError::IoError)?;
214                }
215            }
216            Ok(LineRead::TooLong) => {
217                let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
218                let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
219                out.clear();
220                serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
221                out.extend_from_slice(NEWLINE);
222                writer.write_all(&out).await.map_err(MCSError::IoError)?;
223                writer.flush().await.map_err(MCSError::IoError)?;
224                break;
225            }
226            Err(e) => {
227                error!("IO error: {}", e);
228                break;
229            }
230        }
231    }
232    Ok(())
233}
234
235fn process_request(req: &JsonRpcRequest, kg: &Mutex<KnowledgeGraph>) -> Result<Value> {
236    match req.method.as_str() {
237        "initialize" => handle_initialize(),
238        "tools/list" => handle_tools_list(),
239        "tools/call" => handle_tools_call(req, kg),
240        "ping" => handle_ping(),
241        method if method.starts_with("notifications/") => handle_notification(method),
242        _ => Err(MCSError::MethodNotFound(req.method.clone())),
243    }
244}
245
246const fn handle_ping() -> Result<Value> {
247    Ok(Value::Null)
248}
249
250fn handle_notification(method: &str) -> Result<Value> {
251    tracing::trace!("Received notification: {method}");
252    Ok(Value::Null)
253}
254
255fn handle_initialize() -> Result<Value> {
256    Ok(json!({
257        "protocolVersion": "2024-11-05",
258        "capabilities": {
259            "tools": { "listChanged": false }
260        },
261        "serverInfo": {
262            "name": "mcp-memory",
263            "version": env!("CARGO_PKG_VERSION")
264        }
265    }))
266}
267
268fn handle_tools_list() -> Result<Value> {
269    static CACHED: OnceLock<Value> = OnceLock::new();
270    if let Some(cached) = CACHED.get() {
271        return Ok(cached.clone());
272    }
273    let tools_json = include_str!("../tools.json");
274    let tools: Vec<Value> =
275        serde_json::from_str(tools_json).map_err(MCSError::JsonError)?;
276    let result = json!({ "tools": tools });
277    let _ = CACHED.set(result.clone());
278    Ok(result)
279}
280
281fn handle_tools_call(req: &JsonRpcRequest, kg: &Mutex<KnowledgeGraph>) -> Result<Value> {
282    let tool_name = req
283        .params
284        .as_ref()
285        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
286        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
287
288    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
289
290    if !tools::tool_exists(tool_name) {
291        return Err(MCSError::MethodNotFound(tool_name.to_string()));
292    }
293
294    match tool_name {
295        "create_entities" => memory::handle_create_entities(kg, tool_args),
296        "create_relations" => memory::handle_create_relations(kg, tool_args),
297        "add_observations" => memory::handle_add_observations(kg, tool_args),
298        "delete_entities" => memory::handle_delete_entities(kg, tool_args),
299        "delete_observations" => memory::handle_delete_observations(kg, tool_args),
300        "delete_relations" => memory::handle_delete_relations(kg, tool_args),
301        "read_graph" => memory::handle_read_graph(kg, tool_args),
302        "search_nodes" => memory::handle_search_nodes(kg, tool_args),
303        "open_nodes" => memory::handle_open_nodes(kg, tool_args),
304        "get_entity" => memory::handle_get_entity(kg, tool_args),
305        "graph_stats" => memory::handle_graph_stats(kg),
306        "search_relations" => memory::handle_search_relations(kg, tool_args),
307        "find_path" => memory::handle_find_path(kg, tool_args),
308        "compact" => memory::handle_compact(kg),
309        "get_neighbors" => memory::handle_get_neighbors(kg, tool_args),
310        "describe_entity" => memory::handle_describe_entity(kg, tool_args),
311        "list_entity_types" => memory::handle_list_entity_types(kg),
312        "list_relation_types" => memory::handle_list_relation_types(kg),
313        "upsert_entities" => memory::handle_upsert_entities(kg, tool_args),
314        "export_graph" => memory::handle_export_graph(kg, tool_args),
315        tool => Err(MCSError::MethodNotFound(tool.to_string())),
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Deletes the wrapped file when dropped (best effort).
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// A graph backed by a throw-away file. Fields drop in declaration order,
    /// so the graph is released before its file is deleted. Deletion also runs
    /// when an assertion panics, unlike a trailing `cleanup` call.
    struct TestGraph {
        kg: Arc<Mutex<KnowledgeGraph>>,
        _file: TempFile,
    }

    /// Create a fresh graph on a unique file in the platform temp directory
    /// (not a hard-coded `/tmp`, which does not exist on every OS).
    fn setup_kg() -> TestGraph {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
        let kg = KnowledgeGraph::new(path.as_path()).unwrap();
        TestGraph {
            kg: Arc::new(Mutex::new(kg)),
            _file: TempFile(path),
        }
    }

    /// Dispatch `line`, which must produce a reply, and parse that reply as JSON.
    fn reply(line: &str, kg: &Mutex<KnowledgeGraph>) -> Value {
        let raw = dispatch_line(line, kg).expect("expected a response");
        serde_json::from_str(&raw).unwrap()
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let g = setup_kg();
        let v = reply(r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#, &g.kg);
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let g = setup_kg();
        let v = reply("{invalid}", &g.kg);
        // Parse error per JSON-RPC: code -32700, null id.
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
    }

    #[test]
    fn test_dispatch_line_empty() {
        let g = setup_kg();
        let v = reply("   \n", &g.kg);
        assert_eq!(v["error"]["code"], -32700);
    }

    #[test]
    fn test_notification_has_no_response() {
        let g = setup_kg();
        // A message with no `id` is a notification → no reply.
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &g.kg).is_none());
    }

    #[test]
    fn test_unknown_method_error() {
        let g = setup_kg();
        let v = reply(r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#, &g.kg);
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601); // method not found
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let g = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &g.kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v = reply(read, &g.kg);
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let g = setup_kg();
        // Batch: one request + one notification → array with a single response.
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &g.kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        // Notification-only body → no response (HTTP 202).
        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &g.kg).unwrap().is_none());

        // Invalid JSON → Err.
        assert!(dispatch_http_body("{bad", &g.kg).is_err());
    }

    #[test]
    fn test_handle_initialize_response() {
        let g = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = process_request(&req, &g.kg).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
    }
}