1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::Arc;
4
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::TcpListener;
7use tracing::{error, info};
8
9use crate::actions::memory;
10use crate::config::Config;
11use crate::errors::{MCSError, Result};
12use crate::kg::GraphHandle;
13use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
14use crate::tools;
15
16enum HandlerResult {
20 Value(Value),
21 RawResult(String),
22}
23
/// Capacity (64 KiB) of the buffered reader and of the reusable output buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Terminator appended after every response frame on line-based transports.
const NEWLINE: &[u8] = &[b'\n'];
/// Largest accepted request line in bytes (16 MiB), newline included.
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
28
/// Outcome of one `read_line_capped` call.
enum LineRead {
    /// A line (possibly without trailing newline at end of input) was read.
    Line,
    /// The input ended before any byte of a new line arrived.
    Eof,
    /// The line exceeded the size cap and was not returned.
    TooLong,
}
34
35async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
36where
37 R: AsyncBufReadExt + Unpin,
38{
39 out.clear();
40 let mut buf: Vec<u8> = Vec::new();
41 loop {
42 let available = reader.fill_buf().await?;
43 if available.is_empty() {
44 if buf.is_empty() {
45 return Ok(LineRead::Eof);
46 }
47 *out = String::from_utf8(buf.clone()).map_err(|_| {
48 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
49 })?;
50 return Ok(LineRead::Line);
51 }
52 match available.iter().position(|&b| b == b'\n') {
53 Some(i) => {
54 if buf.len() + i + 1 > max {
55 reader.consume(i + 1);
56 return Ok(LineRead::TooLong);
57 }
58 buf.extend_from_slice(&available[..=i]);
59 reader.consume(i + 1);
60 *out = String::from_utf8(buf.clone()).map_err(|_| {
61 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
62 })?;
63 return Ok(LineRead::Line);
64 }
65 None => {
66 let take = available.len();
67 if buf.len() + take > max {
68 reader.consume(take);
69 return Ok(LineRead::TooLong);
70 }
71 buf.extend_from_slice(available);
72 reader.consume(take);
73 }
74 }
75 }
76}
77
78fn parse_error(msg: String) -> JsonRpcResponse {
79 let mcp_error = MCSError::ParseError(msg);
80 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
81}
82
83pub fn process_value(value: Value, kg: &GraphHandle) -> Option<Value> {
86 let req: JsonRpcRequest = match serde_json::from_value(value) {
87 Ok(r) => r,
88 Err(e) => return Some(to_value(parse_error(e.to_string()))),
89 };
90 req.id.as_ref()?;
91 let response = match process_request(&req, kg) {
92 Ok(HandlerResult::Value(result)) => Some(to_value(JsonRpcResponse::success(req.id, result))),
93 Ok(HandlerResult::RawResult(_)) => {
94 unreachable!("RawResult must be handled at the dispatch level, not via process_value");
97 }
98 Err(e) => Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string()))),
99 };
100 response
101}
102
103pub fn dispatch_line(line: &str, kg: &GraphHandle) -> Option<String> {
106 let trimmed = line.trim();
107 if trimmed.is_empty() {
108 return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
109 }
110 let raw: Value = match serde_json::from_str(trimmed) {
111 Ok(v) => v,
112 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
113 };
114 let req: JsonRpcRequest = match serde_json::from_value(raw) {
115 Ok(r) => r,
116 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
117 };
118 req.id.as_ref()?;
119 match process_request(&req, kg) {
120 Ok(HandlerResult::Value(result)) => {
121 let resp = JsonRpcResponse::success(req.id, result);
122 Some(serde_json::to_string(&resp).unwrap())
123 }
124 Ok(HandlerResult::RawResult(result_json)) => {
125 let id_json = serde_json::to_string(&req.id).unwrap();
126 let mut out = String::with_capacity(64 + id_json.len() + result_json.len());
127 out.push_str(r#"{"jsonrpc":"2.0","id":"#);
128 out.push_str(&id_json);
129 out.push_str(r#","result":"#);
130 out.push_str(&result_json);
131 out.push('}');
132 Some(out)
133 }
134 Err(e) => {
135 let resp = JsonRpcResponse::error(req.id, e.error_code(), e.to_string());
136 Some(serde_json::to_string(&resp).unwrap())
137 }
138 }
139}
140
141pub fn dispatch_http_body(
145 body: &str,
146 kg: &GraphHandle,
147) -> std::result::Result<Option<Value>, String> {
148 let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
149 match value {
150 Value::Array(items) => {
151 let responses: Vec<Value> =
153 items.into_iter().filter_map(|v| process_value_http(v, kg)).collect();
154 Ok((!responses.is_empty()).then_some(Value::Array(responses)))
155 }
156 other => Ok(process_value_http(other, kg)),
157 }
158}
159
160fn process_value_http(value: Value, kg: &GraphHandle) -> Option<Value> {
163 let req: JsonRpcRequest = match serde_json::from_value(value) {
164 Ok(r) => r,
165 Err(e) => return Some(to_value(parse_error(e.to_string()))),
166 };
167 req.id.as_ref()?;
168 match process_request(&req, kg) {
169 Ok(HandlerResult::Value(result)) => {
170 Some(to_value(JsonRpcResponse::success(req.id, result)))
171 }
172 Ok(HandlerResult::RawResult(result_json)) => {
173 let result_val: Value = serde_json::from_str(&result_json)
177 .unwrap_or(Value::Null);
178 Some(to_value(JsonRpcResponse::success(req.id, result_val)))
179 }
180 Err(e) => {
181 Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string())))
182 }
183 }
184}
185
186#[inline]
187fn to_value(resp: JsonRpcResponse) -> Value {
188 serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
189}
190
191pub struct MCPServer {
192 _config: Arc<Config>,
193 kg: Arc<GraphHandle>,
194}
195
196impl MCPServer {
197 pub fn new(config: Config) -> Result<Self> {
198 let path = Path::new(&config.memory_file_path);
199 let kg = GraphHandle::new(path).map_err(MCSError::IoError)?;
200
201 Ok(Self {
202 _config: Arc::new(config),
203 kg: Arc::new(kg),
204 })
205 }
206
207 pub fn graph(&self) -> Arc<GraphHandle> {
209 Arc::clone(&self.kg)
210 }
211
212 pub async fn run_stdio(&self) -> Result<()> {
214 let stdin = tokio::io::stdin();
215 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
216 let mut stdout = tokio::io::stdout();
217 serve_line_conn(&mut reader, &mut stdout, &self.kg).await
218 }
219
220 pub async fn run_tcp(&self, addr: &str) -> Result<()> {
224 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
225 info!("Listening for TCP MCP connections on {addr}");
226 loop {
227 let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
228 let kg = Arc::clone(&self.kg);
229 tokio::spawn(async move {
230 let (read_half, mut write_half) = socket.into_split();
231 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
232 if let Err(e) = serve_line_conn(&mut reader, &mut write_half, &kg).await {
233 error!("TCP connection {peer} error: {e}");
234 }
235 });
236 }
237 }
238
239 pub async fn run_http(&self, addr: &str) -> Result<()> {
241 crate::http::run(addr, self.graph()).await
242 }
243}
244
245async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: &GraphHandle) -> Result<()>
249where
250 R: AsyncBufReadExt + Unpin,
251 W: AsyncWriteExt + Unpin,
252{
253 let mut line = String::with_capacity(1024);
254 let mut out = Vec::with_capacity(BUFFER_CAPACITY);
255
256 loop {
257 match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
258 Ok(LineRead::Eof) => break,
259 Ok(LineRead::Line) => {
260 if let Some(resp) = dispatch_line(&line, kg) {
261 out.clear();
262 out.extend_from_slice(resp.as_bytes());
263 out.extend_from_slice(NEWLINE);
264 writer.write_all(&out).await.map_err(MCSError::IoError)?;
265 writer.flush().await.map_err(MCSError::IoError)?;
266 }
267 }
268 Ok(LineRead::TooLong) => {
269 let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
270 let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
271 out.clear();
272 serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
273 out.extend_from_slice(NEWLINE);
274 writer.write_all(&out).await.map_err(MCSError::IoError)?;
275 writer.flush().await.map_err(MCSError::IoError)?;
276 break;
277 }
278 Err(e) => {
279 error!("IO error: {}", e);
280 break;
281 }
282 }
283 }
284 Ok(())
285}
286
287fn process_request(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
288 match req.method.as_str() {
289 "initialize" => Ok(HandlerResult::Value(handle_initialize())),
290 "tools/list" => Ok(HandlerResult::Value(handle_tools_list())),
291 "tools/call" => handle_tools_call(req, kg),
292 "ping" => Ok(HandlerResult::Value(Value::Null)),
293 method if method.starts_with("notifications/") => {
294 tracing::trace!("Received notification: {method}");
295 Ok(HandlerResult::Value(Value::Null))
296 }
297 _ => Err(MCSError::MethodNotFound(req.method.clone())),
298 }
299}
300
301fn handle_initialize() -> Value {
302 json!({
303 "protocolVersion": "2024-11-05",
304 "capabilities": {
305 "tools": { "listChanged": false }
306 },
307 "serverInfo": {
308 "name": "mcp-memory",
309 "version": env!("CARGO_PKG_VERSION")
310 }
311 })
312}
313
314fn handle_tools_list() -> Value {
315 static CACHED: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
316 if let Some(cached) = CACHED.get() {
317 return cached.clone();
318 }
319 let tools_json = include_str!("../tools.json");
320 let tools: Vec<Value> =
321 serde_json::from_str(tools_json).map_err(MCSError::JsonError).unwrap();
322 let result = json!({ "tools": tools });
323 let _ = CACHED.set(result.clone());
324 result
325}
326
327fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
328 let tool_name = req
329 .params
330 .as_ref()
331 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
332 .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
333
334 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
335
336 if !tools::tool_exists(tool_name) {
337 return Err(MCSError::MethodNotFound(tool_name.to_string()));
338 }
339
340 match tool_name {
341 "read_graph" => Ok(HandlerResult::RawResult(memory::handle_read_graph(kg, tool_args)?)),
343 "search_nodes" => Ok(HandlerResult::RawResult(memory::handle_search_nodes(kg, tool_args)?)),
344 "create_entities" => Ok(HandlerResult::Value(memory::handle_create_entities(kg, tool_args)?)),
346 "create_relations" => Ok(HandlerResult::Value(memory::handle_create_relations(kg, tool_args)?)),
347 "add_observations" => Ok(HandlerResult::Value(memory::handle_add_observations(kg, tool_args)?)),
348 "delete_entities" => Ok(HandlerResult::Value(memory::handle_delete_entities(kg, tool_args)?)),
349 "delete_observations" => Ok(HandlerResult::Value(memory::handle_delete_observations(kg, tool_args)?)),
350 "delete_relations" => Ok(HandlerResult::Value(memory::handle_delete_relations(kg, tool_args)?)),
351 "open_nodes" => Ok(HandlerResult::Value(memory::handle_open_nodes(kg, tool_args)?)),
352 "get_entity" => Ok(HandlerResult::Value(memory::handle_get_entity(kg, tool_args)?)),
353 "graph_stats" => Ok(HandlerResult::Value(memory::handle_graph_stats(kg)?)),
354 "search_relations" => Ok(HandlerResult::Value(memory::handle_search_relations(kg, tool_args)?)),
355 "find_path" => Ok(HandlerResult::Value(memory::handle_find_path(kg, tool_args)?)),
356 "compact" => Ok(HandlerResult::Value(memory::handle_compact(kg)?)),
357 "get_neighbors" => Ok(HandlerResult::Value(memory::handle_get_neighbors(kg, tool_args)?)),
358 "describe_entity" => Ok(HandlerResult::Value(memory::handle_describe_entity(kg, tool_args)?)),
359 "list_entity_types" => Ok(HandlerResult::Value(memory::handle_list_entity_types(kg)?)),
360 "list_relation_types" => Ok(HandlerResult::Value(memory::handle_list_relation_types(kg)?)),
361 "upsert_entities" => Ok(HandlerResult::Value(memory::handle_upsert_entities(kg, tool_args)?)),
362 "export_graph" => Ok(HandlerResult::Value(memory::handle_export_graph(kg, tool_args)?)),
363 "merge_entities" => Ok(HandlerResult::Value(memory::handle_merge_entities(kg, tool_args)?)),
364 "extract_subgraph" => Ok(HandlerResult::Value(memory::handle_extract_subgraph(kg, tool_args)?)),
365 "batch_get_entities" => Ok(HandlerResult::Value(memory::handle_batch_get_entities(kg, tool_args)?)),
366 "find_all_paths" => Ok(HandlerResult::Value(memory::handle_find_all_paths(kg, tool_args)?)),
367 tool => Err(MCSError::MethodNotFound(tool.to_string())),
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Per-process counter so each test gets its own backing file.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Creates a graph backed by a fresh file in the platform temp directory.
    ///
    /// Uses `std::env::temp_dir()` instead of a hard-coded `/tmp`, so the suite also runs
    /// on platforms without a `/tmp` directory (e.g. Windows).
    fn setup_kg() -> (Arc<GraphHandle>, PathBuf) {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
        let kg = GraphHandle::new(path.as_path()).unwrap();
        (Arc::new(kg), path)
    }

    /// Best-effort removal of a test's backing file; a missing file is ignored.
    fn cleanup(path: &Path) {
        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let resp = dispatch_line(line, &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("{invalid}", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_empty() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line(" \n", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        cleanup(&path);
    }

    #[test]
    fn test_notification_has_no_response() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &kg).is_none());
        cleanup(&path);
    }

    #[test]
    fn test_unknown_method_error() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601);
        cleanup(&path);
    }

    #[test]
    fn test_ping_returns_null_result() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"ping","id":"p1"}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["id"], "p1");
        assert!(v["result"].is_null());
        assert!(v["error"].is_null());
        cleanup(&path);
    }

    #[test]
    fn test_tools_call_missing_name_is_error() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","id":3,"method":"tools/call","params":{}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["id"], 3);
        assert!(v["error"]["code"].is_i64());
        cleanup(&path);
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let (kg, path) = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(read, &kg).unwrap()).unwrap();
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
        cleanup(&path);
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let (kg, path) = setup_kg();
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &kg).unwrap().is_none());

        assert!(dispatch_http_body("{bad", &kg).is_err());
        cleanup(&path);
    }

    #[test]
    fn test_handle_initialize_response() {
        let (kg, path) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = match process_request(&req, &kg).unwrap() {
            HandlerResult::Value(v) => v,
            HandlerResult::RawResult(_) => panic!("expected Value"),
        };
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }
}