1use serde_json::{Value, json};
2use std::num::NonZeroUsize;
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::net::TcpListener;
9use tokio::sync::Semaphore;
10use tracing::{error, info};
11
12use crate::actions::memory;
13use crate::config::Config;
14use crate::errors::{MCSError, Result};
15use crate::kg::GraphHandle;
16use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
17use crate::tools;
18
19enum HandlerResult {
23 Value(Value),
24 RawResult(String),
25}
26
/// Capacity in bytes (64 KiB) used for the buffered readers and for the
/// reusable response buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Terminator appended to every response on line-oriented transports.
const NEWLINE: &[u8] = &[b'\n'];
/// Largest request line accepted, in bytes (16 MiB), newline included.
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
/// Number of permits in the semaphore that bounds concurrent TCP connections.
const MAX_TCP_CONNECTIONS: usize = 128;
33
/// Outcome of one `read_line_capped` call.
///
/// The derives make outcomes cheap to copy, comparable with `==`, and
/// printable in log messages and assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LineRead {
    /// A line was read and stored in the output string.
    Line,
    /// The input ended before any byte of a new line was read.
    Eof,
    /// The line would have exceeded the size limit; nothing was stored.
    TooLong,
}
39
40async fn read_line_capped<R>(
41 reader: &mut R,
42 out: &mut String,
43 max: usize,
44) -> std::io::Result<LineRead>
45where
46 R: AsyncBufReadExt + Unpin,
47{
48 out.clear();
49 let mut buf: Vec<u8> = Vec::new();
50 loop {
51 let available = reader.fill_buf().await?;
52 if available.is_empty() {
53 if buf.is_empty() {
54 return Ok(LineRead::Eof);
55 }
56 *out = String::from_utf8(buf).map_err(|_| {
58 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
59 })?;
60 return Ok(LineRead::Line);
61 }
62 match available.iter().position(|&b| b == b'\n') {
63 Some(i) => {
64 if buf.len() + i + 1 > max {
65 reader.consume(i + 1);
66 return Ok(LineRead::TooLong);
67 }
68 buf.extend_from_slice(&available[..=i]);
69 reader.consume(i + 1);
70 *out = String::from_utf8(buf).map_err(|_| {
71 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
72 })?;
73 return Ok(LineRead::Line);
74 }
75 None => {
76 let take = available.len();
77 if buf.len() + take > max {
78 reader.consume(take);
79 return Ok(LineRead::TooLong);
80 }
81 buf.extend_from_slice(available);
82 reader.consume(take);
83 }
84 }
85 }
86}
87
88fn parse_error(msg: String) -> JsonRpcResponse {
89 let mcp_error = MCSError::ParseError(msg);
90 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
91}
92
93pub fn process_value(value: Value, kg: &GraphHandle) -> Option<Value> {
96 let req: JsonRpcRequest = match serde_json::from_value(value) {
97 Ok(r) => r,
98 Err(e) => return Some(to_value(parse_error(e.to_string()))),
99 };
100 req.id.as_ref()?;
101
102 match process_request(&req, kg) {
103 Ok(HandlerResult::Value(result)) => {
104 Some(to_value(JsonRpcResponse::success(req.id, result)))
105 }
106 Ok(HandlerResult::RawResult(_)) => {
107 unreachable!("RawResult must be handled at the dispatch level, not via process_value");
110 }
111 Err(e) => Some(to_value(JsonRpcResponse::error(
112 req.id,
113 e.error_code(),
114 e.to_string(),
115 ))),
116 }
117}
118
119pub fn dispatch_line(line: &str, kg: &GraphHandle) -> Option<String> {
122 let trimmed = line.trim();
123 if trimmed.is_empty() {
124 return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
125 }
126 let raw: Value = match serde_json::from_str(trimmed) {
127 Ok(v) => v,
128 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
129 };
130 let req: JsonRpcRequest = match serde_json::from_value(raw) {
131 Ok(r) => r,
132 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
133 };
134 req.id.as_ref()?;
135 match process_request(&req, kg) {
136 Ok(HandlerResult::Value(result)) => {
137 let resp = JsonRpcResponse::success(req.id, result);
138 Some(serde_json::to_string(&resp).unwrap())
139 }
140 Ok(HandlerResult::RawResult(result_json)) => {
141 let id_json = serde_json::to_string(&req.id).unwrap();
142 let mut out = String::with_capacity(64 + id_json.len() + result_json.len());
143 out.push_str(r#"{"jsonrpc":"2.0","id":"#);
144 out.push_str(&id_json);
145 out.push_str(",\"result\":");
146 out.push_str(&result_json);
147 out.push('}');
148 Some(out)
149 }
150 Err(e) => {
151 let resp = JsonRpcResponse::error(req.id, e.error_code(), e.to_string());
152 Some(serde_json::to_string(&resp).unwrap())
153 }
154 }
155}
156
157pub fn dispatch_http_body(
161 body: &str,
162 kg: &GraphHandle,
163) -> std::result::Result<Option<Value>, String> {
164 let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
165 match value {
166 Value::Array(items) => {
167 let responses: Vec<Value> = items
169 .into_iter()
170 .filter_map(|v| process_value_http(v, kg))
171 .collect();
172 Ok((!responses.is_empty()).then_some(Value::Array(responses)))
173 }
174 other => Ok(process_value_http(other, kg)),
175 }
176}
177
178fn process_value_http(value: Value, kg: &GraphHandle) -> Option<Value> {
181 let req: JsonRpcRequest = match serde_json::from_value(value) {
182 Ok(r) => r,
183 Err(e) => return Some(to_value(parse_error(e.to_string()))),
184 };
185 req.id.as_ref()?;
186 match process_request(&req, kg) {
187 Ok(HandlerResult::Value(result)) => {
188 Some(to_value(JsonRpcResponse::success(req.id, result)))
189 }
190 Ok(HandlerResult::RawResult(result_json)) => {
191 let result_val: Value = serde_json::from_str(&result_json).unwrap_or(Value::Null);
195 Some(to_value(JsonRpcResponse::success(req.id, result_val)))
196 }
197 Err(e) => Some(to_value(JsonRpcResponse::error(
198 req.id,
199 e.error_code(),
200 e.to_string(),
201 ))),
202 }
203}
204
205#[inline]
206fn to_value(resp: JsonRpcResponse) -> Value {
207 serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
208}
209
210pub struct MCPServer {
211 config: Arc<Config>,
212 kg: Arc<GraphHandle>,
213}
214
215impl MCPServer {
216 pub fn new(config: Config) -> Result<Self> {
217 let path = Path::new(&config.memory_file_path);
218 let lru_cache = NonZeroUsize::new(config.lru_cache_size).unwrap_or_else(|| {
219 NonZeroUsize::new(10000).expect("10000 > 0")
220 });
221 let kg = GraphHandle::new(
222 path,
223 config.durability,
224 config.mmap_size,
225 lru_cache,
226 config.read_pool_size,
227 )?;
228
229 Ok(Self {
230 config: Arc::new(config),
231 kg: Arc::new(kg),
232 })
233 }
234
235 pub fn graph(&self) -> Arc<GraphHandle> {
237 Arc::clone(&self.kg)
238 }
239
240 pub async fn run_stdio(&self) -> Result<()> {
242 spawn_maintenance(self.kg.clone());
243 let stdin = tokio::io::stdin();
244 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
245 let mut stdout = tokio::io::stdout();
246 serve_line_conn(&mut reader, &mut stdout, Arc::clone(&self.kg)).await
247 }
248
249 pub async fn run_tcp(&self, addr: &str) -> Result<()> {
253 spawn_maintenance(self.kg.clone());
254 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
255 let semaphore = Arc::new(Semaphore::new(MAX_TCP_CONNECTIONS));
256 let auth_token = self.config.auth_token.clone();
257 info!(
258 "Listening for TCP MCP connections on {addr} (max {MAX_TCP_CONNECTIONS}, auth {})",
259 if auth_token.is_some() { "on" } else { "off" }
260 );
261 loop {
262 let permit = Arc::clone(&semaphore).acquire_owned().await;
263 let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
264 let kg = Arc::clone(&self.kg);
265 let auth_token = auth_token.clone();
266 tokio::spawn(async move {
267 let _permit = permit; let (read_half, mut write_half) = socket.into_split();
269 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
270 if let Some(ref expected) = auth_token {
273 match authenticate_line_conn(&mut reader, expected).await {
274 Ok(true) => {}
275 Ok(false) => {
276 let _ = write_half.write_all(AUTH_REQUIRED_LINE.as_bytes()).await;
277 let _ = write_half.flush().await;
278 return;
279 }
280 Err(e) => {
281 error!("TCP auth error for {peer}: {e}");
282 return;
283 }
284 }
285 }
286 if let Err(e) = serve_line_conn(&mut reader, &mut write_half, kg).await {
287 error!("TCP connection {peer} error: {e}");
288 }
289 });
290 }
291 }
292
293 pub async fn run_http(&self, addr: &str) -> Result<()> {
295 spawn_maintenance(self.kg.clone());
296 crate::http::run(
297 addr,
298 self.graph(),
299 self.config.auth_token.clone(),
300 self.config.tls_cert.clone(),
301 self.config.tls_key.clone(),
302 )
303 .await
304 }
305}
306
307fn spawn_maintenance(kg: Arc<GraphHandle>) {
310 tokio::spawn(async move {
311 let mut interval = tokio::time::interval(Duration::from_secs(300));
312 interval.tick().await; loop {
314 interval.tick().await;
315 let kg = kg.clone();
316 tokio::task::spawn_blocking(move || {
317 if let Err(e) = kg.run_maintenance() {
318 tracing::warn!("Maintenance error: {e}");
319 }
320 })
321 .await
322 .ok();
323 }
324 });
325}
326
/// Newline-terminated JSON-RPC error (code -32001) written to a TCP peer whose
/// first line did not carry the expected token.
const AUTH_REQUIRED_LINE: &str = concat!(
    r#"{"jsonrpc":"2.0","error":{"code":-32001,"#,
    r#""message":"Authentication required: send the bearer token as the first line"},"id":null}"#,
    "\n",
);
330
331async fn authenticate_line_conn<R>(reader: &mut R, expected: &str) -> Result<bool>
334where
335 R: AsyncBufReadExt + Unpin,
336{
337 let mut line = String::new();
338 match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES)
339 .await
340 .map_err(MCSError::IoError)?
341 {
342 LineRead::Line => Ok(token_matches(&line, expected)),
343 _ => Ok(false),
344 }
345}
346
347async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: Arc<GraphHandle>) -> Result<()>
353where
354 R: AsyncBufReadExt + Unpin,
355 W: AsyncWriteExt + Unpin,
356{
357 let mut line = String::with_capacity(1024);
358 let mut out = Vec::with_capacity(BUFFER_CAPACITY);
359
360 loop {
361 match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
362 Ok(LineRead::Eof) => break,
363 Ok(LineRead::Line) => {
364 let line_copy = line.clone();
365 let kg_clone = Arc::clone(&kg);
366 let resp =
367 tokio::task::spawn_blocking(move || dispatch_line(&line_copy, &kg_clone))
368 .await
369 .map_err(|join_err| {
370 error!("dispatch task panicked: {join_err}");
371 MCSError::IoError(std::io::Error::other("dispatch task panicked"))
372 })?;
373 if let Some(resp) = resp {
374 out.clear();
375 out.extend_from_slice(resp.as_bytes());
376 out.extend_from_slice(NEWLINE);
377 writer.write_all(&out).await.map_err(MCSError::IoError)?;
378 writer.flush().await.map_err(MCSError::IoError)?;
379 }
380 }
381 Ok(LineRead::TooLong) => {
382 let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
383 let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
384 out.clear();
385 serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
386 out.extend_from_slice(NEWLINE);
387 writer.write_all(&out).await.map_err(MCSError::IoError)?;
388 writer.flush().await.map_err(MCSError::IoError)?;
389 break;
390 }
391 Err(e) => {
392 error!("IO error: {}", e);
393 break;
394 }
395 }
396 }
397 Ok(())
398}
399
400fn process_request(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
401 match req.method.as_str() {
402 "initialize" => Ok(HandlerResult::Value(handle_initialize(req))),
403 "tools/list" => Ok(HandlerResult::Value(handle_tools_list())),
404 "tools/call" => handle_tools_call(req, kg),
405 "ping" => Ok(HandlerResult::Value(Value::Null)),
406 method if method.starts_with("notifications/") => {
407 tracing::trace!("Received notification: {method}");
408 Ok(HandlerResult::Value(Value::Null))
409 }
410 _ => Err(MCSError::MethodNotFound(req.method.clone())),
411 }
412}
413
/// Protocol versions `handle_initialize` echoes back when a client requests
/// them, newest first.
const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[
    LATEST_PROTOCOL_VERSION,
    "2025-06-18",
    "2025-03-26",
    "2024-11-05",
];
/// Version offered when the client requests none or an unsupported one.
const LATEST_PROTOCOL_VERSION: &str = "2025-11-25";
420
/// Usage guidance returned to clients in the `instructions` field of the
/// `initialize` result.
const SERVER_INSTRUCTIONS: &str = concat!(
    "Knowledge-graph memory MCP server. Entity names are unique and ",
    "case-sensitive. Use `create_entities`/`create_relations` to build the graph, `add_observations` to ",
    "attach facts, and `search_nodes`/`open_nodes`/`read_graph` to retrieve. Prefer `upsert_entities` for ",
    "idempotent writes and `merge_entities` to collapse duplicates. Tool failures are returned with ",
    "`isError: true` rather than as protocol errors — read the message and retry.",
);
427
428fn handle_initialize(req: &JsonRpcRequest) -> Value {
429 let protocol_version = req
431 .params
432 .as_ref()
433 .and_then(|p| p.get("protocolVersion"))
434 .and_then(Value::as_str)
435 .filter(|v| SUPPORTED_PROTOCOL_VERSIONS.contains(v))
436 .unwrap_or(LATEST_PROTOCOL_VERSION);
437
438 json!({
439 "protocolVersion": protocol_version,
440 "capabilities": {
441 "tools": { "listChanged": false }
442 },
443 "serverInfo": {
444 "name": "mcp-memory",
445 "version": env!("CARGO_PKG_VERSION")
446 },
447 "instructions": SERVER_INSTRUCTIONS
448 })
449}
450
451#[inline]
456fn tool_error(message: &str) -> Value {
457 json!({
458 "content": [{ "type": "text", "text": message }],
459 "isError": true
460 })
461}
462
463pub fn token_matches(presented: &str, expected: &str) -> bool {
466 use subtle::ConstantTimeEq;
467 let presented = presented.trim();
468 let presented = presented
469 .strip_prefix("Bearer ")
470 .unwrap_or(presented)
471 .trim();
472 presented.as_bytes().ct_eq(expected.as_bytes()).into()
473}
474
475fn handle_tools_list() -> Value {
476 static CACHED: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
477 if let Some(cached) = CACHED.get() {
478 return cached.clone();
479 }
480 let tools_json = include_str!("../tools.json");
481 let tools: Vec<Value> = serde_json::from_str(tools_json)
482 .expect("tools.json is valid JSON compiled at build time");
483 let result = json!({ "tools": tools });
484 let _ = CACHED.set(result.clone());
485 result
486}
487
488fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
489 let tool_name = req
490 .params
491 .as_ref()
492 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
493 .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
494
495 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
496
497 if !tools::tool_exists(tool_name) {
498 return Err(MCSError::MethodNotFound(tool_name.to_string()));
499 }
500
501 let result = match tool_name {
502 "read_graph" => memory::handle_read_graph(kg, tool_args).map(HandlerResult::RawResult),
504 "search_nodes" => memory::handle_search_nodes(kg, tool_args).map(HandlerResult::RawResult),
505 "create_entities" => {
507 memory::handle_create_entities(kg, tool_args).map(HandlerResult::Value)
508 }
509 "create_relations" => {
510 memory::handle_create_relations(kg, tool_args).map(HandlerResult::Value)
511 }
512 "add_observations" => {
513 memory::handle_add_observations(kg, tool_args).map(HandlerResult::Value)
514 }
515 "delete_entities" => {
516 memory::handle_delete_entities(kg, tool_args).map(HandlerResult::Value)
517 }
518 "delete_observations" => {
519 memory::handle_delete_observations(kg, tool_args).map(HandlerResult::Value)
520 }
521 "delete_relations" => {
522 memory::handle_delete_relations(kg, tool_args).map(HandlerResult::Value)
523 }
524 "open_nodes" => memory::handle_open_nodes(kg, tool_args).map(HandlerResult::Value),
525 "get_entity" => memory::handle_get_entity(kg, tool_args).map(HandlerResult::Value),
526 "graph_stats" => memory::handle_graph_stats(kg).map(HandlerResult::Value),
527 "search_relations" => {
528 memory::handle_search_relations(kg, tool_args).map(HandlerResult::Value)
529 }
530 "find_path" => memory::handle_find_path(kg, tool_args).map(HandlerResult::Value),
531 "compact" => memory::handle_compact(kg).map(HandlerResult::Value),
532 "get_neighbors" => memory::handle_get_neighbors(kg, tool_args).map(HandlerResult::Value),
533 "describe_entity" => {
534 memory::handle_describe_entity(kg, tool_args).map(HandlerResult::Value)
535 }
536 "list_entity_types" => memory::handle_list_entity_types(kg).map(HandlerResult::Value),
537 "list_relation_types" => memory::handle_list_relation_types(kg).map(HandlerResult::Value),
538 "upsert_entities" => {
539 memory::handle_upsert_entities(kg, tool_args).map(HandlerResult::Value)
540 }
541 "export_graph" => memory::handle_export_graph(kg, tool_args).map(HandlerResult::Value),
542 "merge_entities" => memory::handle_merge_entities(kg, tool_args).map(HandlerResult::Value),
543 "extract_subgraph" => {
544 memory::handle_extract_subgraph(kg, tool_args).map(HandlerResult::Value)
545 }
546 "batch_get_entities" => {
547 memory::handle_batch_get_entities(kg, tool_args).map(HandlerResult::Value)
548 }
549 "find_all_paths" => memory::handle_find_all_paths(kg, tool_args).map(HandlerResult::Value),
550 "entity_exists" => memory::handle_entity_exists(kg, tool_args).map(HandlerResult::Value),
551 "degree" => memory::handle_degree(kg, tool_args).map(HandlerResult::Value),
552 tool => Err(MCSError::MethodNotFound(tool.to_string())),
553 };
554
555 Ok(result.unwrap_or_else(|e| {
558 error!("Tool '{tool_name}' error: {e}");
559 HandlerResult::Value(tool_error(&e.to_string()))
560 }))
561}
562
563