1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::{Arc, OnceLock};
4
5use parking_lot::RwLock;
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::net::TcpListener;
8use tracing::{error, info};
9
10use crate::actions::memory;
11use crate::config::Config;
12use crate::errors::{MCSError, Result};
13use crate::kg::KnowledgeGraph;
14use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
15use crate::tools;
16
/// Size in bytes of the buffered readers wrapped around each input stream, and the
/// initial capacity of the per-connection response staging buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Frame terminator appended after every response on line-oriented transports.
const NEWLINE: &[u8] = &[b'\n'];
/// Largest accepted line-delimited request, in bytes, counting its trailing newline.
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
21
/// Outcome of one `read_line_capped` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LineRead {
    /// A complete line, or a final unterminated line at end of stream, was decoded
    /// into the caller's output buffer.
    Line,
    /// The stream ended with no pending bytes for a new line.
    Eof,
    /// The line would exceed the size cap; part of it may still be unread.
    TooLong,
}
27
28async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
29where
30 R: AsyncBufReadExt + Unpin,
31{
32 out.clear();
33 let mut buf: Vec<u8> = Vec::new();
34 loop {
35 let available = reader.fill_buf().await?;
36 if available.is_empty() {
37 if buf.is_empty() {
38 return Ok(LineRead::Eof);
39 }
40 *out = String::from_utf8(buf.clone()).map_err(|_| {
41 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
42 })?;
43 return Ok(LineRead::Line);
44 }
45 match available.iter().position(|&b| b == b'\n') {
46 Some(i) => {
47 if buf.len() + i + 1 > max {
48 reader.consume(i + 1);
49 return Ok(LineRead::TooLong);
50 }
51 buf.extend_from_slice(&available[..=i]);
52 reader.consume(i + 1);
53 *out = String::from_utf8(buf.clone()).map_err(|_| {
54 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
55 })?;
56 return Ok(LineRead::Line);
57 }
58 None => {
59 let take = available.len();
60 if buf.len() + take > max {
61 reader.consume(take);
62 return Ok(LineRead::TooLong);
63 }
64 buf.extend_from_slice(available);
65 reader.consume(take);
66 }
67 }
68 }
69}
70
71fn parse_error(msg: String) -> JsonRpcResponse {
72 let mcp_error = MCSError::ParseError(msg);
73 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
74}
75
76pub fn process_value(value: Value, kg: &RwLock<KnowledgeGraph>) -> Option<Value> {
88 let req: JsonRpcRequest = match serde_json::from_value(value) {
89 Ok(r) => r,
90 Err(e) => return Some(to_value(parse_error(e.to_string()))),
91 };
92 req.id.as_ref()?;
94 let response = match process_request(&req, kg) {
95 Ok(result) => JsonRpcResponse::success(req.id, result),
96 Err(e) => JsonRpcResponse::error(req.id, e.error_code(), e.to_string()),
97 };
98 Some(to_value(response))
99}
100
101pub fn dispatch_line(line: &str, kg: &RwLock<KnowledgeGraph>) -> Option<String> {
104 let trimmed = line.trim();
105 if trimmed.is_empty() {
106 return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
107 }
108 let value: Value = match serde_json::from_str(trimmed) {
109 Ok(v) => v,
110 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
111 };
112 process_value(value, kg).map(|v| serde_json::to_string(&v).unwrap())
113}
114
115pub fn dispatch_http_body(
119 body: &str,
120 kg: &RwLock<KnowledgeGraph>,
121) -> std::result::Result<Option<Value>, String> {
122 let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
123 match value {
124 Value::Array(items) => {
125 let responses: Vec<Value> =
126 items.into_iter().filter_map(|v| process_value(v, kg)).collect();
127 Ok((!responses.is_empty()).then_some(Value::Array(responses)))
128 }
129 other => Ok(process_value(other, kg)),
130 }
131}
132
133#[inline]
134fn to_value(resp: JsonRpcResponse) -> Value {
135 serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
136}
137
138pub struct MCPServer {
139 _config: Arc<Config>,
140 kg: Arc<RwLock<KnowledgeGraph>>,
141}
142
143impl MCPServer {
144 pub fn new(config: Config) -> Result<Self> {
145 let path = Path::new(&config.memory_file_path);
146 let kg = KnowledgeGraph::new(path)
147 .map_err(MCSError::IoError)?;
148
149 Ok(Self {
150 _config: Arc::new(config),
151 kg: Arc::new(RwLock::new(kg)),
152 })
153 }
154
155 pub fn graph(&self) -> Arc<RwLock<KnowledgeGraph>> {
157 Arc::clone(&self.kg)
158 }
159
160 pub async fn run_stdio(&self) -> Result<()> {
162 let stdin = tokio::io::stdin();
163 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
164 let mut stdout = tokio::io::stdout();
165 serve_line_conn(&mut reader, &mut stdout, &self.kg).await
166 }
167
168 pub async fn run_tcp(&self, addr: &str) -> Result<()> {
172 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
173 info!("Listening for TCP MCP connections on {addr}");
174 loop {
175 let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
176 let kg = Arc::clone(&self.kg);
177 tokio::spawn(async move {
178 let (read_half, mut write_half) = socket.into_split();
179 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
180 if let Err(e) = serve_line_conn(&mut reader, &mut write_half, &kg).await {
181 error!("TCP connection {peer} error: {e}");
182 }
183 });
184 }
185 }
186
187 pub async fn run_http(&self, addr: &str) -> Result<()> {
189 crate::http::run(addr, self.graph()).await
190 }
191}
192
193async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: &RwLock<KnowledgeGraph>) -> Result<()>
197where
198 R: AsyncBufReadExt + Unpin,
199 W: AsyncWriteExt + Unpin,
200{
201 let mut line = String::with_capacity(1024);
202 let mut out = Vec::with_capacity(BUFFER_CAPACITY);
203
204 loop {
205 match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
206 Ok(LineRead::Eof) => break,
207 Ok(LineRead::Line) => {
208 if let Some(resp) = dispatch_line(&line, kg) {
209 out.clear();
210 out.extend_from_slice(resp.as_bytes());
211 out.extend_from_slice(NEWLINE);
212 writer.write_all(&out).await.map_err(MCSError::IoError)?;
213 writer.flush().await.map_err(MCSError::IoError)?;
214 }
215 }
216 Ok(LineRead::TooLong) => {
217 let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
218 let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
219 out.clear();
220 serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
221 out.extend_from_slice(NEWLINE);
222 writer.write_all(&out).await.map_err(MCSError::IoError)?;
223 writer.flush().await.map_err(MCSError::IoError)?;
224 break;
225 }
226 Err(e) => {
227 error!("IO error: {}", e);
228 break;
229 }
230 }
231 }
232 Ok(())
233}
234
235fn process_request(req: &JsonRpcRequest, kg: &RwLock<KnowledgeGraph>) -> Result<Value> {
236 match req.method.as_str() {
237 "initialize" => handle_initialize(),
238 "tools/list" => handle_tools_list(),
239 "tools/call" => handle_tools_call(req, kg),
240 "ping" => handle_ping(),
241 method if method.starts_with("notifications/") => handle_notification(method),
242 _ => Err(MCSError::MethodNotFound(req.method.clone())),
243 }
244}
245
246const fn handle_ping() -> Result<Value> {
247 Ok(Value::Null)
248}
249
250fn handle_notification(method: &str) -> Result<Value> {
251 tracing::trace!("Received notification: {method}");
252 Ok(Value::Null)
253}
254
255fn handle_initialize() -> Result<Value> {
256 Ok(json!({
257 "protocolVersion": "2024-11-05",
258 "capabilities": {
259 "tools": { "listChanged": false }
260 },
261 "serverInfo": {
262 "name": "mcp-memory",
263 "version": env!("CARGO_PKG_VERSION")
264 }
265 }))
266}
267
268fn handle_tools_list() -> Result<Value> {
269 static CACHED: OnceLock<Value> = OnceLock::new();
270 if let Some(cached) = CACHED.get() {
271 return Ok(cached.clone());
272 }
273 let tools_json = include_str!("../tools.json");
274 let tools: Vec<Value> =
275 serde_json::from_str(tools_json).map_err(MCSError::JsonError)?;
276 let result = json!({ "tools": tools });
277 let _ = CACHED.set(result.clone());
278 Ok(result)
279}
280
281fn handle_tools_call(req: &JsonRpcRequest, kg: &RwLock<KnowledgeGraph>) -> Result<Value> {
282 let tool_name = req
283 .params
284 .as_ref()
285 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
286 .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
287
288 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
289
290 if !tools::tool_exists(tool_name) {
291 return Err(MCSError::MethodNotFound(tool_name.to_string()));
292 }
293
294 match tool_name {
295 "create_entities" => memory::handle_create_entities(kg, tool_args),
296 "create_relations" => memory::handle_create_relations(kg, tool_args),
297 "add_observations" => memory::handle_add_observations(kg, tool_args),
298 "delete_entities" => memory::handle_delete_entities(kg, tool_args),
299 "delete_observations" => memory::handle_delete_observations(kg, tool_args),
300 "delete_relations" => memory::handle_delete_relations(kg, tool_args),
301 "read_graph" => memory::handle_read_graph(kg, tool_args),
302 "search_nodes" => memory::handle_search_nodes(kg, tool_args),
303 "open_nodes" => memory::handle_open_nodes(kg, tool_args),
304 "get_entity" => memory::handle_get_entity(kg, tool_args),
305 "graph_stats" => memory::handle_graph_stats(kg),
306 "search_relations" => memory::handle_search_relations(kg, tool_args),
307 "find_path" => memory::handle_find_path(kg, tool_args),
308 "compact" => memory::handle_compact(kg),
309 "get_neighbors" => memory::handle_get_neighbors(kg, tool_args),
310 "describe_entity" => memory::handle_describe_entity(kg, tool_args),
311 "list_entity_types" => memory::handle_list_entity_types(kg),
312 "list_relation_types" => memory::handle_list_relation_types(kg),
313 "upsert_entities" => memory::handle_upsert_entities(kg, tool_args),
314 "export_graph" => memory::handle_export_graph(kg, tool_args),
315 "merge_entities" => memory::handle_merge_entities(kg, tool_args),
316 "extract_subgraph" => memory::handle_extract_subgraph(kg, tool_args),
317 "batch_get_entities" => memory::handle_batch_get_entities(kg, tool_args),
318 "find_all_paths" => memory::handle_find_all_paths(kg, tool_args),
319 tool => Err(MCSError::MethodNotFound(tool.to_string())),
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Deletes a test graph's backing file when dropped, so cleanup also happens
    /// when an assertion fails part-way through a test.
    struct TempFile(PathBuf);

    impl Drop for TempFile {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Creates a fresh knowledge graph backed by a unique file in the platform temp
    /// directory. A hard-coded `/tmp` does not exist on every platform.
    ///
    /// Bind the result as `let (_tmp, kg) = setup_kg();`. Locals drop in reverse
    /// declaration order, so the graph is dropped before its file is removed.
    fn setup_kg() -> (TempFile, Arc<RwLock<KnowledgeGraph>>) {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
        let kg = KnowledgeGraph::new(path.as_path()).unwrap();
        (TempFile(path), Arc::new(RwLock::new(kg)))
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let (_tmp, kg) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let resp = dispatch_line(line, &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let (_tmp, kg) = setup_kg();
        let resp = dispatch_line("{invalid}", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
    }

    #[test]
    fn test_dispatch_line_empty() {
        let (_tmp, kg) = setup_kg();
        let resp = dispatch_line("  \n", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
    }

    #[test]
    fn test_notification_has_no_response() {
        let (_tmp, kg) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &kg).is_none());
    }

    #[test]
    fn test_unknown_method_error() {
        let (_tmp, kg) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601);
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let (_tmp, kg) = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(read, &kg).unwrap()).unwrap();
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let (_tmp, kg) = setup_kg();
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &kg).unwrap().is_none());

        assert!(dispatch_http_body("{bad", &kg).is_err());
    }

    #[test]
    fn test_handle_initialize_response() {
        let (_tmp, kg) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = process_request(&req, &kg).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
    }
}