1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::Arc;
4
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::TcpListener;
7use tracing::{error, info};
8
9use crate::actions::memory;
10use crate::config::Config;
11use crate::errors::{MCSError, Result};
12use crate::kg::GraphHandle;
13use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
14use crate::tools;
15
/// Capacity (64 KiB) used for the buffered readers and the response scratch buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Terminator appended after every newline-delimited response frame.
const NEWLINE: &[u8] = &[b'\n'];
/// Largest single newline-delimited request accepted, in bytes (16 MiB).
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
20
/// Outcome of a single capped line read from a buffered stream.
enum LineRead {
    /// A line (terminated by `\n`, or the unterminated tail before EOF) was stored.
    Line,
    /// The stream ended before any byte of a new line arrived.
    Eof,
    /// The line grew past the byte cap and was abandoned.
    TooLong,
}
26
27async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
28where
29 R: AsyncBufReadExt + Unpin,
30{
31 out.clear();
32 let mut buf: Vec<u8> = Vec::new();
33 loop {
34 let available = reader.fill_buf().await?;
35 if available.is_empty() {
36 if buf.is_empty() {
37 return Ok(LineRead::Eof);
38 }
39 *out = String::from_utf8(buf.clone()).map_err(|_| {
40 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
41 })?;
42 return Ok(LineRead::Line);
43 }
44 match available.iter().position(|&b| b == b'\n') {
45 Some(i) => {
46 if buf.len() + i + 1 > max {
47 reader.consume(i + 1);
48 return Ok(LineRead::TooLong);
49 }
50 buf.extend_from_slice(&available[..=i]);
51 reader.consume(i + 1);
52 *out = String::from_utf8(buf.clone()).map_err(|_| {
53 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
54 })?;
55 return Ok(LineRead::Line);
56 }
57 None => {
58 let take = available.len();
59 if buf.len() + take > max {
60 reader.consume(take);
61 return Ok(LineRead::TooLong);
62 }
63 buf.extend_from_slice(available);
64 reader.consume(take);
65 }
66 }
67 }
68}
69
70fn parse_error(msg: String) -> JsonRpcResponse {
71 let mcp_error = MCSError::ParseError(msg);
72 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
73}
74
75pub fn process_value(value: Value, kg: &GraphHandle) -> Option<Value> {
78 let req: JsonRpcRequest = match serde_json::from_value(value) {
79 Ok(r) => r,
80 Err(e) => return Some(to_value(parse_error(e.to_string()))),
81 };
82 req.id.as_ref()?;
83 let response = match process_request(&req, kg) {
84 Ok(result) => JsonRpcResponse::success(req.id, result),
85 Err(e) => JsonRpcResponse::error(req.id, e.error_code(), e.to_string()),
86 };
87 Some(to_value(response))
88}
89
90pub fn dispatch_line(line: &str, kg: &GraphHandle) -> Option<String> {
93 let trimmed = line.trim();
94 if trimmed.is_empty() {
95 return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
96 }
97 let value: Value = match serde_json::from_str(trimmed) {
98 Ok(v) => v,
99 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
100 };
101 process_value(value, kg).map(|v| serde_json::to_string(&v).unwrap())
102}
103
104pub fn dispatch_http_body(
108 body: &str,
109 kg: &GraphHandle,
110) -> std::result::Result<Option<Value>, String> {
111 let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
112 match value {
113 Value::Array(items) => {
114 let responses: Vec<Value> =
115 items.into_iter().filter_map(|v| process_value(v, kg)).collect();
116 Ok((!responses.is_empty()).then_some(Value::Array(responses)))
117 }
118 other => Ok(process_value(other, kg)),
119 }
120}
121
122#[inline]
123fn to_value(resp: JsonRpcResponse) -> Value {
124 serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
125}
126
127pub struct MCPServer {
128 _config: Arc<Config>,
129 kg: Arc<GraphHandle>,
130}
131
132impl MCPServer {
133 pub fn new(config: Config) -> Result<Self> {
134 let path = Path::new(&config.memory_file_path);
135 let kg = GraphHandle::new(path).map_err(MCSError::IoError)?;
136
137 Ok(Self {
138 _config: Arc::new(config),
139 kg: Arc::new(kg),
140 })
141 }
142
143 pub fn graph(&self) -> Arc<GraphHandle> {
145 Arc::clone(&self.kg)
146 }
147
148 pub async fn run_stdio(&self) -> Result<()> {
150 let stdin = tokio::io::stdin();
151 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
152 let mut stdout = tokio::io::stdout();
153 serve_line_conn(&mut reader, &mut stdout, &self.kg).await
154 }
155
156 pub async fn run_tcp(&self, addr: &str) -> Result<()> {
160 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
161 info!("Listening for TCP MCP connections on {addr}");
162 loop {
163 let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
164 let kg = Arc::clone(&self.kg);
165 tokio::spawn(async move {
166 let (read_half, mut write_half) = socket.into_split();
167 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
168 if let Err(e) = serve_line_conn(&mut reader, &mut write_half, &kg).await {
169 error!("TCP connection {peer} error: {e}");
170 }
171 });
172 }
173 }
174
175 pub async fn run_http(&self, addr: &str) -> Result<()> {
177 crate::http::run(addr, self.graph()).await
178 }
179}
180
181async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: &GraphHandle) -> Result<()>
185where
186 R: AsyncBufReadExt + Unpin,
187 W: AsyncWriteExt + Unpin,
188{
189 let mut line = String::with_capacity(1024);
190 let mut out = Vec::with_capacity(BUFFER_CAPACITY);
191
192 loop {
193 match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
194 Ok(LineRead::Eof) => break,
195 Ok(LineRead::Line) => {
196 if let Some(resp) = dispatch_line(&line, kg) {
197 out.clear();
198 out.extend_from_slice(resp.as_bytes());
199 out.extend_from_slice(NEWLINE);
200 writer.write_all(&out).await.map_err(MCSError::IoError)?;
201 writer.flush().await.map_err(MCSError::IoError)?;
202 }
203 }
204 Ok(LineRead::TooLong) => {
205 let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
206 let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
207 out.clear();
208 serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
209 out.extend_from_slice(NEWLINE);
210 writer.write_all(&out).await.map_err(MCSError::IoError)?;
211 writer.flush().await.map_err(MCSError::IoError)?;
212 break;
213 }
214 Err(e) => {
215 error!("IO error: {}", e);
216 break;
217 }
218 }
219 }
220 Ok(())
221}
222
223fn process_request(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<Value> {
224 match req.method.as_str() {
225 "initialize" => handle_initialize(),
226 "tools/list" => handle_tools_list(),
227 "tools/call" => handle_tools_call(req, kg),
228 "ping" => handle_ping(),
229 method if method.starts_with("notifications/") => handle_notification(method),
230 _ => Err(MCSError::MethodNotFound(req.method.clone())),
231 }
232}
233
234const fn handle_ping() -> Result<Value> {
235 Ok(Value::Null)
236}
237
238fn handle_notification(method: &str) -> Result<Value> {
239 tracing::trace!("Received notification: {method}");
240 Ok(Value::Null)
241}
242
243fn handle_initialize() -> Result<Value> {
244 Ok(json!({
245 "protocolVersion": "2024-11-05",
246 "capabilities": {
247 "tools": { "listChanged": false }
248 },
249 "serverInfo": {
250 "name": "mcp-memory",
251 "version": env!("CARGO_PKG_VERSION")
252 }
253 }))
254}
255
256fn handle_tools_list() -> Result<Value> {
257 static CACHED: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
258 if let Some(cached) = CACHED.get() {
259 return Ok(cached.clone());
260 }
261 let tools_json = include_str!("../tools.json");
262 let tools: Vec<Value> =
263 serde_json::from_str(tools_json).map_err(MCSError::JsonError)?;
264 let result = json!({ "tools": tools });
265 let _ = CACHED.set(result.clone());
266 Ok(result)
267}
268
269fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<Value> {
270 let tool_name = req
271 .params
272 .as_ref()
273 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
274 .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
275
276 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
277
278 if !tools::tool_exists(tool_name) {
279 return Err(MCSError::MethodNotFound(tool_name.to_string()));
280 }
281
282 match tool_name {
283 "create_entities" => memory::handle_create_entities(kg, tool_args),
284 "create_relations" => memory::handle_create_relations(kg, tool_args),
285 "add_observations" => memory::handle_add_observations(kg, tool_args),
286 "delete_entities" => memory::handle_delete_entities(kg, tool_args),
287 "delete_observations" => memory::handle_delete_observations(kg, tool_args),
288 "delete_relations" => memory::handle_delete_relations(kg, tool_args),
289 "read_graph" => memory::handle_read_graph(kg, tool_args),
290 "search_nodes" => memory::handle_search_nodes(kg, tool_args),
291 "open_nodes" => memory::handle_open_nodes(kg, tool_args),
292 "get_entity" => memory::handle_get_entity(kg, tool_args),
293 "graph_stats" => memory::handle_graph_stats(kg),
294 "search_relations" => memory::handle_search_relations(kg, tool_args),
295 "find_path" => memory::handle_find_path(kg, tool_args),
296 "compact" => memory::handle_compact(kg),
297 "get_neighbors" => memory::handle_get_neighbors(kg, tool_args),
298 "describe_entity" => memory::handle_describe_entity(kg, tool_args),
299 "list_entity_types" => memory::handle_list_entity_types(kg),
300 "list_relation_types" => memory::handle_list_relation_types(kg),
301 "upsert_entities" => memory::handle_upsert_entities(kg, tool_args),
302 "export_graph" => memory::handle_export_graph(kg, tool_args),
303 "merge_entities" => memory::handle_merge_entities(kg, tool_args),
304 "extract_subgraph" => memory::handle_extract_subgraph(kg, tool_args),
305 "batch_get_entities" => memory::handle_batch_get_entities(kg, tool_args),
306 "find_all_paths" => memory::handle_find_all_paths(kg, tool_args),
307 tool => Err(MCSError::MethodNotFound(tool.to_string())),
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Per-process sequence number so tests running in parallel never share a graph file.
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Creates a fresh graph backed by a unique file in the OS temp directory.
    ///
    /// Returns the handle plus the backing path, which the caller passes to [`cleanup`].
    /// `std::env::temp_dir()` is used instead of a hard-coded `/tmp` so the tests also run
    /// on platforms without `/tmp` (e.g. Windows).
    fn setup_kg() -> (Arc<GraphHandle>, PathBuf) {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
        let kg = GraphHandle::new(path.as_path()).unwrap();
        (Arc::new(kg), path)
    }

    /// Best-effort removal of a test graph file; a missing file is not an error.
    fn cleanup(path: &Path) {
        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let resp = dispatch_line(line, &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("{invalid}", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_empty() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("  \n", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        cleanup(&path);
    }

    #[test]
    fn test_notification_has_no_response() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &kg).is_none());
        cleanup(&path);
    }

    #[test]
    fn test_unknown_method_error() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601);
        cleanup(&path);
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let (kg, path) = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(read, &kg).unwrap()).unwrap();
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
        cleanup(&path);
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let (kg, path) = setup_kg();
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &kg).unwrap().is_none());

        assert!(dispatch_http_body("{bad", &kg).is_err());
        cleanup(&path);
    }

    #[test]
    fn test_handle_initialize_response() {
        let (kg, path) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = process_request(&req, &kg).unwrap();
        assert_eq!(result["protocolVersion"], "2024-11-05");
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }
}