1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::{Arc, OnceLock};
4
5use parking_lot::RwLock;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::net::TcpListener;
8use tracing::{error, info};
9
10use crate::actions::memory;
11use crate::config::Config;
12use crate::errors::{MCSError, Result};
13use crate::kg::KnowledgeGraph;
14use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
15use crate::tools;
16
/// Capacity in bytes of the buffered reader and of the per-connection response scratch buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Terminator appended after every serialized response frame on line-based transports.
const NEWLINE: &[u8] = b"\n";
/// Largest request line (16 MiB) that the line-based transports will buffer before rejecting it.
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
21
/// Outcome of one call to `read_line_capped`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LineRead {
    /// A complete line (or a final unterminated line at EOF) was stored in the output buffer.
    Line,
    /// The stream ended with no pending bytes.
    Eof,
    /// The line exceeded the configured byte cap and was discarded.
    TooLong,
}
27
28async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
29where
30 R: AsyncBufReadExt + Unpin,
31{
32 out.clear();
33 let mut buf: Vec<u8> = Vec::new();
34 loop {
35 let available = reader.fill_buf().await?;
36 if available.is_empty() {
37 if buf.is_empty() {
38 return Ok(LineRead::Eof);
39 }
40 *out = String::from_utf8(buf.clone()).map_err(|_| {
41 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
42 })?;
43 return Ok(LineRead::Line);
44 }
45 match available.iter().position(|&b| b == b'\n') {
46 Some(i) => {
47 if buf.len() + i + 1 > max {
48 reader.consume(i + 1);
49 return Ok(LineRead::TooLong);
50 }
51 buf.extend_from_slice(&available[..=i]);
52 reader.consume(i + 1);
53 *out = String::from_utf8(buf.clone()).map_err(|_| {
54 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
55 })?;
56 return Ok(LineRead::Line);
57 }
58 None => {
59 let take = available.len();
60 if buf.len() + take > max {
61 reader.consume(take);
62 return Ok(LineRead::TooLong);
63 }
64 buf.extend_from_slice(available);
65 reader.consume(take);
66 }
67 }
68 }
69}
70
71fn parse_error(msg: String) -> JsonRpcResponse {
72 let mcp_error = MCSError::ParseError(msg);
73 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
74}
75
76pub fn process_value(value: Value, kg: &RwLock<KnowledgeGraph>) -> Option<Value> {
88 let req: JsonRpcRequest = match serde_json::from_value(value) {
89 Ok(r) => r,
90 Err(e) => return Some(to_value(parse_error(e.to_string()))),
91 };
92 if req.id.is_none() {
94 return None;
95 }
96 let response = match process_request(&req, kg) {
97 Ok(result) => JsonRpcResponse::success(req.id, result),
98 Err(e) => JsonRpcResponse::error(req.id, e.error_code(), e.to_string()),
99 };
100 Some(to_value(response))
101}
102
103pub fn dispatch_line(line: &str, kg: &RwLock<KnowledgeGraph>) -> Option<String> {
106 let trimmed = line.trim();
107 if trimmed.is_empty() {
108 return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
109 }
110 let value: Value = match serde_json::from_str(trimmed) {
111 Ok(v) => v,
112 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
113 };
114 process_value(value, kg).map(|v| serde_json::to_string(&v).unwrap())
115}
116
117pub fn dispatch_http_body(
121 body: &str,
122 kg: &RwLock<KnowledgeGraph>,
123) -> std::result::Result<Option<Value>, String> {
124 let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
125 match value {
126 Value::Array(items) => {
127 let responses: Vec<Value> =
128 items.into_iter().filter_map(|v| process_value(v, kg)).collect();
129 Ok((!responses.is_empty()).then_some(Value::Array(responses)))
130 }
131 other => Ok(process_value(other, kg)),
132 }
133}
134
135#[inline]
136fn to_value(resp: JsonRpcResponse) -> Value {
137 serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
138}
139
140pub struct MCPServer {
141 _config: Arc<Config>,
142 kg: Arc<RwLock<KnowledgeGraph>>,
143}
144
145impl MCPServer {
146 pub fn new(config: Config) -> Result<Self> {
147 let path = Path::new(&config.memory_file_path);
148 let kg = KnowledgeGraph::new(path)
149 .map_err(MCSError::IoError)?;
150
151 Ok(Self {
152 _config: Arc::new(config),
153 kg: Arc::new(RwLock::new(kg)),
154 })
155 }
156
157 pub fn graph(&self) -> Arc<RwLock<KnowledgeGraph>> {
159 Arc::clone(&self.kg)
160 }
161
162 pub async fn run_stdio(&self) -> Result<()> {
164 let stdin = tokio::io::stdin();
165 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
166 let mut stdout = tokio::io::stdout();
167 serve_line_conn(&mut reader, &mut stdout, &self.kg).await
168 }
169
170 pub async fn run_tcp(&self, addr: &str) -> Result<()> {
174 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
175 info!("Listening for TCP MCP connections on {addr}");
176 loop {
177 let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
178 let kg = Arc::clone(&self.kg);
179 tokio::spawn(async move {
180 let (read_half, mut write_half) = socket.into_split();
181 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
182 if let Err(e) = serve_line_conn(&mut reader, &mut write_half, &kg).await {
183 error!("TCP connection {peer} error: {e}");
184 }
185 });
186 }
187 }
188
189 pub async fn run_http(&self, addr: &str) -> Result<()> {
191 crate::http::run(addr, self.graph()).await
192 }
193}
194
195async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: &RwLock<KnowledgeGraph>) -> Result<()>
199where
200 R: AsyncBufReadExt + Unpin,
201 W: AsyncWriteExt + Unpin,
202{
203 let mut line = String::with_capacity(1024);
204 let mut out = Vec::with_capacity(BUFFER_CAPACITY);
205
206 loop {
207 match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
208 Ok(LineRead::Eof) => break,
209 Ok(LineRead::Line) => {
210 if let Some(resp) = dispatch_line(&line, kg) {
211 out.clear();
212 out.extend_from_slice(resp.as_bytes());
213 out.extend_from_slice(NEWLINE);
214 writer.write_all(&out).await.map_err(MCSError::IoError)?;
215 writer.flush().await.map_err(MCSError::IoError)?;
216 }
217 }
218 Ok(LineRead::TooLong) => {
219 let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
220 let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
221 out.clear();
222 serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
223 out.extend_from_slice(NEWLINE);
224 writer.write_all(&out).await.map_err(MCSError::IoError)?;
225 writer.flush().await.map_err(MCSError::IoError)?;
226 break;
227 }
228 Err(e) => {
229 error!("IO error: {}", e);
230 break;
231 }
232 }
233 }
234 Ok(())
235}
236
237fn process_request(req: &JsonRpcRequest, kg: &RwLock<KnowledgeGraph>) -> Result<Value> {
238 match req.method.as_str() {
239 "initialize" => handle_initialize(),
240 "tools/list" => handle_tools_list(),
241 "tools/call" => handle_tools_call(req, kg),
242 "ping" => handle_ping(),
243 method if method.starts_with("notifications/") => handle_notification(method),
244 _ => Err(MCSError::MethodNotFound(req.method.clone())),
245 }
246}
247
248const fn handle_ping() -> Result<Value> {
249 Ok(Value::Null)
250}
251
252fn handle_notification(method: &str) -> Result<Value> {
253 tracing::trace!("Received notification: {method}");
254 Ok(Value::Null)
255}
256
257fn handle_initialize() -> Result<Value> {
258 Ok(json!({
259 "protocolVersion": "2024-11-05",
260 "capabilities": {
261 "tools": { "listChanged": false }
262 },
263 "serverInfo": {
264 "name": "mcp-memory",
265 "version": env!("CARGO_PKG_VERSION")
266 }
267 }))
268}
269
270fn handle_tools_list() -> Result<Value> {
271 static CACHED: OnceLock<Value> = OnceLock::new();
272 if let Some(cached) = CACHED.get() {
273 return Ok(cached.clone());
274 }
275 let tools_json = include_str!("../tools.json");
276 let tools: Vec<Value> =
277 serde_json::from_str(tools_json).map_err(MCSError::JsonError)?;
278 let result = json!({ "tools": tools });
279 let _ = CACHED.set(result.clone());
280 Ok(result)
281}
282
283fn handle_tools_call(req: &JsonRpcRequest, kg: &RwLock<KnowledgeGraph>) -> Result<Value> {
284 let tool_name = req
285 .params
286 .as_ref()
287 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
288 .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
289
290 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
291
292 if !tools::tool_exists(tool_name) {
293 return Err(MCSError::MethodNotFound(tool_name.to_string()));
294 }
295
296 match tool_name {
297 "create_entities" => memory::handle_create_entities(kg, tool_args),
298 "create_relations" => memory::handle_create_relations(kg, tool_args),
299 "add_observations" => memory::handle_add_observations(kg, tool_args),
300 "delete_entities" => memory::handle_delete_entities(kg, tool_args),
301 "delete_observations" => memory::handle_delete_observations(kg, tool_args),
302 "delete_relations" => memory::handle_delete_relations(kg, tool_args),
303 "read_graph" => memory::handle_read_graph(kg, tool_args),
304 "search_nodes" => memory::handle_search_nodes(kg, tool_args),
305 "open_nodes" => memory::handle_open_nodes(kg, tool_args),
306 "get_entity" => memory::handle_get_entity(kg, tool_args),
307 "graph_stats" => memory::handle_graph_stats(kg),
308 "search_relations" => memory::handle_search_relations(kg, tool_args),
309 "find_path" => memory::handle_find_path(kg, tool_args),
310 "compact" => memory::handle_compact(kg),
311 "get_neighbors" => memory::handle_get_neighbors(kg, tool_args),
312 "describe_entity" => memory::handle_describe_entity(kg, tool_args),
313 "list_entity_types" => memory::handle_list_entity_types(kg),
314 "list_relation_types" => memory::handle_list_relation_types(kg),
315 "upsert_entities" => memory::handle_upsert_entities(kg, tool_args),
316 "export_graph" => memory::handle_export_graph(kg, tool_args),
317 "merge_entities" => memory::handle_merge_entities(kg, tool_args),
318 "extract_subgraph" => memory::handle_extract_subgraph(kg, tool_args),
319 "batch_get_entities" => memory::handle_batch_get_entities(kg, tool_args),
320 "find_all_paths" => memory::handle_find_all_paths(kg, tool_args),
321 tool => Err(MCSError::MethodNotFound(tool.to_string())),
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Test fixture: a knowledge graph backed by a uniquely named file in the system temp
    /// directory. The file is removed when the fixture is dropped, so cleanup also happens
    /// when an assertion fails and the test unwinds.
    struct TestGraph {
        kg: Arc<RwLock<KnowledgeGraph>>,
        path: PathBuf,
    }

    impl Drop for TestGraph {
        fn drop(&mut self) {
            // Best effort: the file may not exist.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Creates a fresh graph whose file name is unique per process and per call.
    fn setup_kg() -> TestGraph {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        // `temp_dir` rather than a hard-coded `/tmp` keeps the tests portable.
        let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
        let kg = KnowledgeGraph::new(path.as_path()).unwrap();
        TestGraph {
            kg: Arc::new(RwLock::new(kg)),
            path,
        }
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let graph = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let resp = dispatch_line(line, &graph.kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let graph = setup_kg();
        let resp = dispatch_line("{invalid}", &graph.kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
    }

    #[test]
    fn test_dispatch_line_empty() {
        let graph = setup_kg();
        let resp = dispatch_line("   \n", &graph.kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
    }

    #[test]
    fn test_notification_has_no_response() {
        let graph = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &graph.kg).is_none());
    }

    #[test]
    fn test_unknown_method_error() {
        let graph = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &graph.kg).unwrap()).unwrap();
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601);
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let graph = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &graph.kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(read, &graph.kg).unwrap()).unwrap();
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let graph = setup_kg();
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &graph.kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &graph.kg).unwrap().is_none());

        assert!(dispatch_http_body("{bad", &graph.kg).is_err());
    }

    #[test]
    fn test_handle_initialize_response() {
        let graph = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = process_request(&req, &graph.kg).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
    }
}