1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::Arc;
4
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::TcpListener;
7use tokio::sync::Semaphore;
8use tracing::{error, info};
9
10use crate::actions::memory;
11use crate::config::Config;
12use crate::errors::{MCSError, Result};
13use crate::kg::GraphHandle;
14use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
15use crate::tools;
16
17enum HandlerResult {
21 Value(Value),
22 RawResult(String),
23}
24
/// Capacity used for connection read buffers and the reusable response buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Terminator appended after every newline-delimited JSON-RPC response.
const NEWLINE: &[u8] = &[b'\n'];
/// Largest accepted newline-delimited request, in bytes (16 MiB).
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
/// Maximum number of TCP connections served concurrently.
const MAX_TCP_CONNECTIONS: usize = 128;
31
/// What a single `read_line_capped` call produced.
enum LineRead {
    /// A line (newline-terminated, or the final unterminated one) was stored
    /// in the caller's buffer.
    Line,
    /// The stream ended with no pending bytes.
    Eof,
    /// The line exceeded the byte limit and was rejected.
    TooLong,
}
37
38async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
39where
40 R: AsyncBufReadExt + Unpin,
41{
42 out.clear();
43 let mut buf: Vec<u8> = Vec::new();
44 loop {
45 let available = reader.fill_buf().await?;
46 if available.is_empty() {
47 if buf.is_empty() {
48 return Ok(LineRead::Eof);
49 }
50 *out = String::from_utf8(buf.clone()).map_err(|_| {
51 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
52 })?;
53 return Ok(LineRead::Line);
54 }
55 match available.iter().position(|&b| b == b'\n') {
56 Some(i) => {
57 if buf.len() + i + 1 > max {
58 reader.consume(i + 1);
59 return Ok(LineRead::TooLong);
60 }
61 buf.extend_from_slice(&available[..=i]);
62 reader.consume(i + 1);
63 *out = String::from_utf8(buf.clone()).map_err(|_| {
64 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
65 })?;
66 return Ok(LineRead::Line);
67 }
68 None => {
69 let take = available.len();
70 if buf.len() + take > max {
71 reader.consume(take);
72 return Ok(LineRead::TooLong);
73 }
74 buf.extend_from_slice(available);
75 reader.consume(take);
76 }
77 }
78 }
79}
80
81fn parse_error(msg: String) -> JsonRpcResponse {
82 let mcp_error = MCSError::ParseError(msg);
83 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
84}
85
86pub fn process_value(value: Value, kg: &GraphHandle) -> Option<Value> {
89 let req: JsonRpcRequest = match serde_json::from_value(value) {
90 Ok(r) => r,
91 Err(e) => return Some(to_value(parse_error(e.to_string()))),
92 };
93 req.id.as_ref()?;
94 let response = match process_request(&req, kg) {
95 Ok(HandlerResult::Value(result)) => Some(to_value(JsonRpcResponse::success(req.id, result))),
96 Ok(HandlerResult::RawResult(_)) => {
97 unreachable!("RawResult must be handled at the dispatch level, not via process_value");
100 }
101 Err(e) => Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string()))),
102 };
103 response
104}
105
106pub fn dispatch_line(line: &str, kg: &GraphHandle) -> Option<String> {
109 let trimmed = line.trim();
110 if trimmed.is_empty() {
111 return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
112 }
113 let raw: Value = match serde_json::from_str(trimmed) {
114 Ok(v) => v,
115 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
116 };
117 let req: JsonRpcRequest = match serde_json::from_value(raw) {
118 Ok(r) => r,
119 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
120 };
121 req.id.as_ref()?;
122 match process_request(&req, kg) {
123 Ok(HandlerResult::Value(result)) => {
124 let resp = JsonRpcResponse::success(req.id, result);
125 Some(serde_json::to_string(&resp).unwrap())
126 }
127 Ok(HandlerResult::RawResult(result_json)) => {
128 let id_json = serde_json::to_string(&req.id).unwrap();
129 let mut out = String::with_capacity(64 + id_json.len() + result_json.len());
130 out.push_str(r#"{"jsonrpc":"2.0","id":"#);
131 out.push_str(&id_json);
132 out.push_str(r#","result":"#);
133 out.push_str(&result_json);
134 out.push('}');
135 Some(out)
136 }
137 Err(e) => {
138 let resp = JsonRpcResponse::error(req.id, e.error_code(), e.to_string());
139 Some(serde_json::to_string(&resp).unwrap())
140 }
141 }
142}
143
144pub fn dispatch_http_body(
148 body: &str,
149 kg: &GraphHandle,
150) -> std::result::Result<Option<Value>, String> {
151 let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
152 match value {
153 Value::Array(items) => {
154 let responses: Vec<Value> =
156 items.into_iter().filter_map(|v| process_value_http(v, kg)).collect();
157 Ok((!responses.is_empty()).then_some(Value::Array(responses)))
158 }
159 other => Ok(process_value_http(other, kg)),
160 }
161}
162
163fn process_value_http(value: Value, kg: &GraphHandle) -> Option<Value> {
166 let req: JsonRpcRequest = match serde_json::from_value(value) {
167 Ok(r) => r,
168 Err(e) => return Some(to_value(parse_error(e.to_string()))),
169 };
170 req.id.as_ref()?;
171 match process_request(&req, kg) {
172 Ok(HandlerResult::Value(result)) => {
173 Some(to_value(JsonRpcResponse::success(req.id, result)))
174 }
175 Ok(HandlerResult::RawResult(result_json)) => {
176 let result_val: Value = serde_json::from_str(&result_json)
180 .unwrap_or(Value::Null);
181 Some(to_value(JsonRpcResponse::success(req.id, result_val)))
182 }
183 Err(e) => {
184 Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string())))
185 }
186 }
187}
188
189#[inline]
190fn to_value(resp: JsonRpcResponse) -> Value {
191 serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
192}
193
194pub struct MCPServer {
195 _config: Arc<Config>,
196 kg: Arc<GraphHandle>,
197}
198
199impl MCPServer {
200 pub fn new(config: Config) -> Result<Self> {
201 let path = Path::new(&config.memory_file_path);
202 let kg = GraphHandle::new(path, config.durability).map_err(MCSError::IoError)?;
203
204 Ok(Self {
205 _config: Arc::new(config),
206 kg: Arc::new(kg),
207 })
208 }
209
210 pub fn graph(&self) -> Arc<GraphHandle> {
212 Arc::clone(&self.kg)
213 }
214
215 pub async fn run_stdio(&self) -> Result<()> {
217 let stdin = tokio::io::stdin();
218 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
219 let mut stdout = tokio::io::stdout();
220 serve_line_conn(&mut reader, &mut stdout, Arc::clone(&self.kg)).await
221 }
222
223 pub async fn run_tcp(&self, addr: &str) -> Result<()> {
227 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
228 let semaphore = Arc::new(Semaphore::new(MAX_TCP_CONNECTIONS));
229 info!("Listening for TCP MCP connections on {addr} (max {MAX_TCP_CONNECTIONS})");
230 loop {
231 let permit = Arc::clone(&semaphore).acquire_owned().await;
232 let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
233 let kg = Arc::clone(&self.kg);
234 tokio::spawn(async move {
235 let _permit = permit; let (read_half, mut write_half) = socket.into_split();
237 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
238 if let Err(e) = serve_line_conn(&mut reader, &mut write_half, kg).await {
239 error!("TCP connection {peer} error: {e}");
240 }
241 });
242 }
243 }
244
245 pub async fn run_http(&self, addr: &str) -> Result<()> {
247 crate::http::run(addr, self.graph()).await
248 }
249}
250
251async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: Arc<GraphHandle>) -> Result<()>
257where
258 R: AsyncBufReadExt + Unpin,
259 W: AsyncWriteExt + Unpin,
260{
261 let mut line = String::with_capacity(1024);
262 let mut out = Vec::with_capacity(BUFFER_CAPACITY);
263
264 loop {
265 match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
266 Ok(LineRead::Eof) => break,
267 Ok(LineRead::Line) => {
268 let line_copy = line.clone();
269 let kg_clone = Arc::clone(&kg);
270 let resp = tokio::task::spawn_blocking(move || dispatch_line(&line_copy, &kg_clone))
271 .await
272 .map_err(|join_err| {
273 error!("dispatch task panicked: {join_err}");
274 MCSError::IoError(std::io::Error::other("dispatch task panicked"))
275 })?;
276 if let Some(resp) = resp {
277 out.clear();
278 out.extend_from_slice(resp.as_bytes());
279 out.extend_from_slice(NEWLINE);
280 writer.write_all(&out).await.map_err(MCSError::IoError)?;
281 writer.flush().await.map_err(MCSError::IoError)?;
282 }
283 }
284 Ok(LineRead::TooLong) => {
285 let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
286 let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
287 out.clear();
288 serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
289 out.extend_from_slice(NEWLINE);
290 writer.write_all(&out).await.map_err(MCSError::IoError)?;
291 writer.flush().await.map_err(MCSError::IoError)?;
292 break;
293 }
294 Err(e) => {
295 error!("IO error: {}", e);
296 break;
297 }
298 }
299 }
300 Ok(())
301}
302
303fn process_request(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
304 match req.method.as_str() {
305 "initialize" => Ok(HandlerResult::Value(handle_initialize())),
306 "tools/list" => Ok(HandlerResult::Value(handle_tools_list())),
307 "tools/call" => handle_tools_call(req, kg),
308 "ping" => Ok(HandlerResult::Value(Value::Null)),
309 method if method.starts_with("notifications/") => {
310 tracing::trace!("Received notification: {method}");
311 Ok(HandlerResult::Value(Value::Null))
312 }
313 _ => Err(MCSError::MethodNotFound(req.method.clone())),
314 }
315}
316
317fn handle_initialize() -> Value {
318 json!({
319 "protocolVersion": "2024-11-05",
320 "capabilities": {
321 "tools": { "listChanged": false }
322 },
323 "serverInfo": {
324 "name": "mcp-memory",
325 "version": env!("CARGO_PKG_VERSION")
326 }
327 })
328}
329
330fn handle_tools_list() -> Value {
331 static CACHED: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
332 if let Some(cached) = CACHED.get() {
333 return cached.clone();
334 }
335 let tools_json = include_str!("../tools.json");
336 let tools: Vec<Value> =
337 serde_json::from_str(tools_json).map_err(MCSError::JsonError).unwrap();
338 let result = json!({ "tools": tools });
339 let _ = CACHED.set(result.clone());
340 result
341}
342
343fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
344 let tool_name = req
345 .params
346 .as_ref()
347 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
348 .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
349
350 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
351
352 if !tools::tool_exists(tool_name) {
353 return Err(MCSError::MethodNotFound(tool_name.to_string()));
354 }
355
356 match tool_name {
357 "read_graph" => Ok(HandlerResult::RawResult(memory::handle_read_graph(kg, tool_args)?)),
359 "search_nodes" => Ok(HandlerResult::RawResult(memory::handle_search_nodes(kg, tool_args)?)),
360 "create_entities" => Ok(HandlerResult::Value(memory::handle_create_entities(kg, tool_args)?)),
362 "create_relations" => Ok(HandlerResult::Value(memory::handle_create_relations(kg, tool_args)?)),
363 "add_observations" => Ok(HandlerResult::Value(memory::handle_add_observations(kg, tool_args)?)),
364 "delete_entities" => Ok(HandlerResult::Value(memory::handle_delete_entities(kg, tool_args)?)),
365 "delete_observations" => Ok(HandlerResult::Value(memory::handle_delete_observations(kg, tool_args)?)),
366 "delete_relations" => Ok(HandlerResult::Value(memory::handle_delete_relations(kg, tool_args)?)),
367 "open_nodes" => Ok(HandlerResult::Value(memory::handle_open_nodes(kg, tool_args)?)),
368 "get_entity" => Ok(HandlerResult::Value(memory::handle_get_entity(kg, tool_args)?)),
369 "graph_stats" => Ok(HandlerResult::Value(memory::handle_graph_stats(kg)?)),
370 "search_relations" => Ok(HandlerResult::Value(memory::handle_search_relations(kg, tool_args)?)),
371 "find_path" => Ok(HandlerResult::Value(memory::handle_find_path(kg, tool_args)?)),
372 "compact" => Ok(HandlerResult::Value(memory::handle_compact(kg)?)),
373 "get_neighbors" => Ok(HandlerResult::Value(memory::handle_get_neighbors(kg, tool_args)?)),
374 "describe_entity" => Ok(HandlerResult::Value(memory::handle_describe_entity(kg, tool_args)?)),
375 "list_entity_types" => Ok(HandlerResult::Value(memory::handle_list_entity_types(kg)?)),
376 "list_relation_types" => Ok(HandlerResult::Value(memory::handle_list_relation_types(kg)?)),
377 "upsert_entities" => Ok(HandlerResult::Value(memory::handle_upsert_entities(kg, tool_args)?)),
378 "export_graph" => Ok(HandlerResult::Value(memory::handle_export_graph(kg, tool_args)?)),
379 "merge_entities" => Ok(HandlerResult::Value(memory::handle_merge_entities(kg, tool_args)?)),
380 "extract_subgraph" => Ok(HandlerResult::Value(memory::handle_extract_subgraph(kg, tool_args)?)),
381 "batch_get_entities" => Ok(HandlerResult::Value(memory::handle_batch_get_entities(kg, tool_args)?)),
382 "find_all_paths" => Ok(HandlerResult::Value(memory::handle_find_all_paths(kg, tool_args)?)),
383 tool => Err(MCSError::MethodNotFound(tool.to_string())),
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Durability;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Creates a graph backed by a unique file in the platform temp directory.
    ///
    /// `std::env::temp_dir()` is used instead of a hard-coded `/tmp`, which
    /// does not exist on Windows. The process id and a counter keep
    /// concurrently running tests from sharing a file.
    fn setup_kg() -> (Arc<GraphHandle>, std::path::PathBuf) {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
        let kg = GraphHandle::new(&path, Durability::Async).unwrap();
        (Arc::new(kg), path)
    }

    /// Best-effort removal of a test graph file.
    fn cleanup(path: &Path) {
        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let resp = dispatch_line(line, &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("{invalid}", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_empty() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("   \n", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        cleanup(&path);
    }

    #[test]
    fn test_notification_has_no_response() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &kg).is_none());
        cleanup(&path);
    }

    #[test]
    fn test_unknown_method_error() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601);
        cleanup(&path);
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let (kg, path) = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(read, &kg).unwrap()).unwrap();
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
        cleanup(&path);
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let (kg, path) = setup_kg();
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &kg).unwrap().is_none());

        assert!(dispatch_http_body("{bad", &kg).is_err());
        cleanup(&path);
    }

    #[test]
    fn test_handle_initialize_response() {
        let (kg, path) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = match process_request(&req, &kg).unwrap() {
            HandlerResult::Value(v) => v,
            HandlerResult::RawResult(_) => panic!("expected Value"),
        };
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }
}