// mcp_memory/server.rs

1use serde_json::{Value, json};
2use std::num::NonZeroUsize;
3use std::path::Path;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::net::TcpListener;
9use tokio::sync::Semaphore;
10use tracing::{error, info};
11
12use crate::actions::memory;
13use crate::config::Config;
14use crate::errors::{MCSError, Result};
15use crate::kg::GraphHandle;
16use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
17use crate::tools;
18
19/// Outcome of processing a request: either a pre-escaped JSON Value (small
20/// payloads) or a pre-serialized JSON *string* of the `result` field (avoids
21/// a second serialization pass for large payloads such as `read_graph`).
22enum HandlerResult {
23    Value(Value),
24    RawResult(String),
25}
26
/// Initial capacity (64 KiB) of the line transports' `BufReader` and per-connection
/// output buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Terminator appended to every framed response line.
const NEWLINE: &[u8] = b"\n";
/// Upper bound (16 MiB) on one inbound JSON-RPC message, shared by all transports.
pub const MAX_REQUEST_BYTES: usize = 16 * 1024 * 1024;
/// Cap on simultaneously served TCP connections (C4).
const MAX_TCP_CONNECTIONS: usize = 128;
33
/// Outcome of one [`read_line_capped`] call.
enum LineRead {
    /// A line (or a final unterminated line at EOF) was stored in the output buffer.
    Line,
    /// The stream ended before any byte of a new line arrived.
    Eof,
    /// The line exceeded the byte cap; the bytes seen so far were discarded.
    TooLong,
}
39
40async fn read_line_capped<R>(
41    reader: &mut R,
42    out: &mut String,
43    max: usize,
44) -> std::io::Result<LineRead>
45where
46    R: AsyncBufReadExt + Unpin,
47{
48    out.clear();
49    let mut buf: Vec<u8> = Vec::new();
50    loop {
51        let available = reader.fill_buf().await?;
52        if available.is_empty() {
53            if buf.is_empty() {
54                return Ok(LineRead::Eof);
55            }
56            // Move `buf` into the String — no copy. `buf` is not used afterward.
57            *out = String::from_utf8(buf).map_err(|_| {
58                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
59            })?;
60            return Ok(LineRead::Line);
61        }
62        match available.iter().position(|&b| b == b'\n') {
63            Some(i) => {
64                if buf.len() + i + 1 > max {
65                    reader.consume(i + 1);
66                    return Ok(LineRead::TooLong);
67                }
68                buf.extend_from_slice(&available[..=i]);
69                reader.consume(i + 1);
70                *out = String::from_utf8(buf).map_err(|_| {
71                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
72                })?;
73                return Ok(LineRead::Line);
74            }
75            None => {
76                let take = available.len();
77                if buf.len() + take > max {
78                    reader.consume(take);
79                    return Ok(LineRead::TooLong);
80                }
81                buf.extend_from_slice(available);
82                reader.consume(take);
83            }
84        }
85    }
86}
87
88fn parse_error(msg: String) -> JsonRpcResponse {
89    let mcp_error = MCSError::ParseError(msg);
90    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
91}
92
93/// Process one parsed JSON-RPC message. `None` means "no reply" — the message
94/// was a notification (no `id`), per JSON-RPC.
95pub fn process_value(value: Value, kg: &GraphHandle) -> Option<Value> {
96    let req: JsonRpcRequest = match serde_json::from_value(value) {
97        Ok(r) => r,
98        Err(e) => return Some(to_value(parse_error(e.to_string()))),
99    };
100    req.id.as_ref()?;
101    
102    match process_request(&req, kg) {
103        Ok(HandlerResult::Value(result)) => {
104            Some(to_value(JsonRpcResponse::success(req.id, result)))
105        }
106        Ok(HandlerResult::RawResult(_)) => {
107            // RawResult cannot pass through Value — dispatch_line and
108            // dispatch_http_body handle it via separate code paths.
109            unreachable!("RawResult must be handled at the dispatch level, not via process_value");
110        }
111        Err(e) => Some(to_value(JsonRpcResponse::error(
112            req.id,
113            e.error_code(),
114            e.to_string(),
115        ))),
116    }
117}
118
119/// Dispatch one framed line (stdio / tcp). Returns the serialized response, or
120/// `None` for a notification.
121pub fn dispatch_line(line: &str, kg: &GraphHandle) -> Option<String> {
122    let trimmed = line.trim();
123    if trimmed.is_empty() {
124        return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
125    }
126    let raw: Value = match serde_json::from_str(trimmed) {
127        Ok(v) => v,
128        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
129    };
130    let req: JsonRpcRequest = match serde_json::from_value(raw) {
131        Ok(r) => r,
132        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
133    };
134    req.id.as_ref()?;
135    match process_request(&req, kg) {
136        Ok(HandlerResult::Value(result)) => {
137            let resp = JsonRpcResponse::success(req.id, result);
138            Some(serde_json::to_string(&resp).unwrap())
139        }
140        Ok(HandlerResult::RawResult(result_json)) => {
141            let id_json = serde_json::to_string(&req.id).unwrap();
142            let mut out = String::with_capacity(64 + id_json.len() + result_json.len());
143            out.push_str(r#"{"jsonrpc":"2.0","id":"#);
144            out.push_str(&id_json);
145            out.push_str(",\"result\":");
146            out.push_str(&result_json);
147            out.push('}');
148            Some(out)
149        }
150        Err(e) => {
151            let resp = JsonRpcResponse::error(req.id, e.error_code(), e.to_string());
152            Some(serde_json::to_string(&resp).unwrap())
153        }
154    }
155}
156
157/// Dispatch a Streamable-HTTP POST body, which may be a single JSON-RPC message
158/// or a batch array. `Ok(None)` means the body held only notifications (HTTP
159/// 202, empty body); `Err` means the body was not valid JSON.
160pub fn dispatch_http_body(
161    body: &str,
162    kg: &GraphHandle,
163) -> std::result::Result<Option<Value>, String> {
164    let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
165    match value {
166        Value::Array(items) => {
167            // Batches are rare and never huge — keep Value path for simplicity.
168            let responses: Vec<Value> = items
169                .into_iter()
170                .filter_map(|v| process_value_http(v, kg))
171                .collect();
172            Ok((!responses.is_empty()).then_some(Value::Array(responses)))
173        }
174        other => Ok(process_value_http(other, kg)),
175    }
176}
177
178/// HTTP variant of process_value that handles RawResult by converting to Value
179/// (acceptable since HTTP payloads are typically much smaller in this context).
180fn process_value_http(value: Value, kg: &GraphHandle) -> Option<Value> {
181    let req: JsonRpcRequest = match serde_json::from_value(value) {
182        Ok(r) => r,
183        Err(e) => return Some(to_value(parse_error(e.to_string()))),
184    };
185    req.id.as_ref()?;
186    match process_request(&req, kg) {
187        Ok(HandlerResult::Value(result)) => {
188            Some(to_value(JsonRpcResponse::success(req.id, result)))
189        }
190        Ok(HandlerResult::RawResult(result_json)) => {
191            // Parse the pre-serialized result back into a Value for HTTP delivery.
192            // This is a small extra cost for the HTTP transport; the stdio/TCP
193            // path (dispatch_line) avoids it entirely.
194            let result_val: Value = serde_json::from_str(&result_json).unwrap_or(Value::Null);
195            Some(to_value(JsonRpcResponse::success(req.id, result_val)))
196        }
197        Err(e) => Some(to_value(JsonRpcResponse::error(
198            req.id,
199            e.error_code(),
200            e.to_string(),
201        ))),
202    }
203}
204
205#[inline]
206fn to_value(resp: JsonRpcResponse) -> Value {
207    serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
208}
209
210pub struct MCPServer {
211    config: Arc<Config>,
212    kg: Arc<GraphHandle>,
213}
214
215impl MCPServer {
216    pub fn new(config: Config) -> Result<Self> {
217        let path = Path::new(&config.memory_file_path);
218        let lru_cache = NonZeroUsize::new(config.lru_cache_size).unwrap_or_else(|| {
219            NonZeroUsize::new(10000).expect("10000 > 0")
220        });
221        let kg = GraphHandle::new(
222            path,
223            config.durability,
224            config.mmap_size,
225            lru_cache,
226            config.read_pool_size,
227        )?;
228
229        Ok(Self {
230            config: Arc::new(config),
231            kg: Arc::new(kg),
232        })
233    }
234
235    /// Expose the shared graph handle (used to drive the HTTP transport).
236    pub fn graph(&self) -> Arc<GraphHandle> {
237        Arc::clone(&self.kg)
238    }
239
240    /// stdio transport: newline-delimited JSON-RPC over stdin/stdout.
241    pub async fn run_stdio(&self) -> Result<()> {
242        spawn_maintenance(self.kg.clone());
243        let stdin = tokio::io::stdin();
244        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
245        let mut stdout = tokio::io::stdout();
246        serve_line_conn(&mut reader, &mut stdout, Arc::clone(&self.kg)).await
247    }
248
249    /// TCP transport: each accepted connection speaks newline-delimited
250    /// JSON-RPC, exactly like stdio. Connections are served concurrently (up to
251    /// [`MAX_TCP_CONNECTIONS`]) and share the one graph behind its mutex.
252    pub async fn run_tcp(&self, addr: &str) -> Result<()> {
253        spawn_maintenance(self.kg.clone());
254        let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
255        let semaphore = Arc::new(Semaphore::new(MAX_TCP_CONNECTIONS));
256        let auth_token = self.config.auth_token.clone();
257        info!(
258            "Listening for TCP MCP connections on {addr} (max {MAX_TCP_CONNECTIONS}, auth {})",
259            if auth_token.is_some() { "on" } else { "off" }
260        );
261        loop {
262            let permit = Arc::clone(&semaphore).acquire_owned().await;
263            let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
264            let kg = Arc::clone(&self.kg);
265            let auth_token = auth_token.clone();
266            tokio::spawn(async move {
267                let _permit = permit; // held for the connection lifetime
268                let (read_half, mut write_half) = socket.into_split();
269                let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
270                // When a token is configured, the client must send it as the
271                // first line before any JSON-RPC traffic.
272                if let Some(ref expected) = auth_token {
273                    match authenticate_line_conn(&mut reader, expected).await {
274                        Ok(true) => {}
275                        Ok(false) => {
276                            let _ = write_half.write_all(AUTH_REQUIRED_LINE.as_bytes()).await;
277                            let _ = write_half.flush().await;
278                            return;
279                        }
280                        Err(e) => {
281                            error!("TCP auth error for {peer}: {e}");
282                            return;
283                        }
284                    }
285                }
286                if let Err(e) = serve_line_conn(&mut reader, &mut write_half, kg).await {
287                    error!("TCP connection {peer} error: {e}");
288                }
289            });
290        }
291    }
292
293    /// MCP Streamable HTTP transport (POST/GET `/mcp`, JSON or SSE responses).
294    pub async fn run_http(&self, addr: &str) -> Result<()> {
295        spawn_maintenance(self.kg.clone());
296        crate::http::run(
297            addr,
298            self.graph(),
299            self.config.auth_token.clone(),
300            self.config.tls_cert.clone(),
301            self.config.tls_key.clone(),
302        )
303        .await
304    }
305}
306
307/// Spawn a background task that runs periodic database maintenance every
308/// 5 minutes until the runtime shuts down.
309fn spawn_maintenance(kg: Arc<GraphHandle>) {
310    tokio::spawn(async move {
311        let mut interval = tokio::time::interval(Duration::from_secs(300));
312        interval.tick().await; // skip immediate first tick
313        loop {
314            interval.tick().await;
315            let kg = kg.clone();
316            tokio::task::spawn_blocking(move || {
317                if let Err(e) = kg.run_maintenance() {
318                    tracing::warn!("Maintenance error: {e}");
319                }
320            })
321            .await
322            .ok();
323        }
324    });
325}
326
/// Newline-terminated JSON-RPC error (code `-32001`, `id: null`) sent to a TCP
/// peer whose first line is not the expected bearer token, just before the
/// connection is dropped.
const AUTH_REQUIRED_LINE: &str = concat!(
    r#"{"jsonrpc":"2.0","error":{"code":-32001,"#,
    r#""message":"Authentication required: send the bearer token as the first line"},"#,
    r#""id":null}"#,
    "\n"
);
330
331/// Read the first line of a connection and compare it (constant-time) to the
332/// expected bearer token. Returns `Ok(false)` on EOF / oversized first line.
333async fn authenticate_line_conn<R>(reader: &mut R, expected: &str) -> Result<bool>
334where
335    R: AsyncBufReadExt + Unpin,
336{
337    let mut line = String::new();
338    match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES)
339        .await
340        .map_err(MCSError::IoError)?
341    {
342        LineRead::Line => Ok(token_matches(&line, expected)),
343        _ => Ok(false),
344    }
345}
346
347/// Drive one line-framed connection (stdio or a single TCP socket): read
348/// newline-delimited JSON-RPC requests, write newline-delimited responses.
349/// Notifications produce no output. Returns when the peer closes the stream.
350/// The dispatch path (graph lock + optional fsync) is offloaded to
351/// [`tokio::task::spawn_blocking`] to keep the async reactor responsive (C3).
352async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: Arc<GraphHandle>) -> Result<()>
353where
354    R: AsyncBufReadExt + Unpin,
355    W: AsyncWriteExt + Unpin,
356{
357    let mut line = String::with_capacity(1024);
358    let mut out = Vec::with_capacity(BUFFER_CAPACITY);
359
360    loop {
361        match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
362            Ok(LineRead::Eof) => break,
363            Ok(LineRead::Line) => {
364                let line_copy = line.clone();
365                let kg_clone = Arc::clone(&kg);
366                let resp =
367                    tokio::task::spawn_blocking(move || dispatch_line(&line_copy, &kg_clone))
368                        .await
369                        .map_err(|join_err| {
370                            error!("dispatch task panicked: {join_err}");
371                            MCSError::IoError(std::io::Error::other("dispatch task panicked"))
372                        })?;
373                if let Some(resp) = resp {
374                    out.clear();
375                    out.extend_from_slice(resp.as_bytes());
376                    out.extend_from_slice(NEWLINE);
377                    writer.write_all(&out).await.map_err(MCSError::IoError)?;
378                    writer.flush().await.map_err(MCSError::IoError)?;
379                }
380            }
381            Ok(LineRead::TooLong) => {
382                let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
383                let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
384                out.clear();
385                serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
386                out.extend_from_slice(NEWLINE);
387                writer.write_all(&out).await.map_err(MCSError::IoError)?;
388                writer.flush().await.map_err(MCSError::IoError)?;
389                break;
390            }
391            Err(e) => {
392                error!("IO error: {}", e);
393                break;
394            }
395        }
396    }
397    Ok(())
398}
399
400fn process_request(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
401    match req.method.as_str() {
402        "initialize" => Ok(HandlerResult::Value(handle_initialize(req))),
403        "tools/list" => Ok(HandlerResult::Value(handle_tools_list())),
404        "tools/call" => handle_tools_call(req, kg),
405        "ping" => Ok(HandlerResult::Value(Value::Null)),
406        method if method.starts_with("notifications/") => {
407            tracing::trace!("Received notification: {method}");
408            Ok(HandlerResult::Value(Value::Null))
409        }
410        _ => Err(MCSError::MethodNotFound(req.method.clone())),
411    }
412}
413
/// MCP protocol revisions this server implements, newest first. `initialize`
/// echoes the client's requested revision when it appears in this list.
const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[
    "2025-11-25",
    "2025-06-18",
    "2025-03-26",
    "2024-11-05",
];
/// Revision offered when the client requests one we do not support. It is taken
/// from the head of `SUPPORTED_PROTOCOL_VERSIONS` so the two cannot drift apart.
const LATEST_PROTOCOL_VERSION: &str = SUPPORTED_PROTOCOL_VERSIONS[0];
420
/// Usage guidance returned as the `instructions` field of the `initialize`
/// result. Clients may surface it to the model (e.g. in its system prompt).
const SERVER_INSTRUCTIONS: &str = "Knowledge-graph memory MCP server. \
    Entity names are unique and case-sensitive. \
    Use `create_entities`/`create_relations` to build the graph, \
    `add_observations` to attach facts, \
    and `search_nodes`/`open_nodes`/`read_graph` to retrieve. \
    Prefer `upsert_entities` for idempotent writes \
    and `merge_entities` to collapse duplicates. \
    Tool failures are returned with `isError: true` rather than as protocol errors \
    — read the message and retry.";
427
428fn handle_initialize(req: &JsonRpcRequest) -> Value {
429    // Version negotiation: echo a supported requested revision, else offer latest.
430    let protocol_version = req
431        .params
432        .as_ref()
433        .and_then(|p| p.get("protocolVersion"))
434        .and_then(Value::as_str)
435        .filter(|v| SUPPORTED_PROTOCOL_VERSIONS.contains(v))
436        .unwrap_or(LATEST_PROTOCOL_VERSION);
437
438    json!({
439        "protocolVersion": protocol_version,
440        "capabilities": {
441            "tools": { "listChanged": false }
442        },
443        "serverInfo": {
444            "name": "mcp-memory",
445            "version": env!("CARGO_PKG_VERSION")
446        },
447        "instructions": SERVER_INSTRUCTIONS
448    })
449}
450
451/// Wrap a tool execution failure as an MCP `CallToolResult` with `isError: true`
452/// so the model sees the message and can self-correct, instead of receiving an
453/// opaque JSON-RPC protocol error. (Successful results are already content-
454/// wrapped by the action handlers.)
455#[inline]
456fn tool_error(message: &str) -> Value {
457    json!({
458        "content": [{ "type": "text", "text": message }],
459        "isError": true
460    })
461}
462
463/// Constant-time bearer-token check. Accepts the raw token or a `Bearer <token>`
464/// form; surrounding whitespace is trimmed.
465pub fn token_matches(presented: &str, expected: &str) -> bool {
466    use subtle::ConstantTimeEq;
467    let presented = presented.trim();
468    let presented = presented
469        .strip_prefix("Bearer ")
470        .unwrap_or(presented)
471        .trim();
472    presented.as_bytes().ct_eq(expected.as_bytes()).into()
473}
474
475fn handle_tools_list() -> Value {
476    static CACHED: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
477    if let Some(cached) = CACHED.get() {
478        return cached.clone();
479    }
480    let tools_json = include_str!("../tools.json");
481    let tools: Vec<Value> = serde_json::from_str(tools_json)
482        .expect("tools.json is valid JSON compiled at build time");
483    let result = json!({ "tools": tools });
484    let _ = CACHED.set(result.clone());
485    result
486}
487
488fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
489    let tool_name = req
490        .params
491        .as_ref()
492        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
493        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
494
495    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
496
497    if !tools::tool_exists(tool_name) {
498        return Err(MCSError::MethodNotFound(tool_name.to_string()));
499    }
500
501    let result = match tool_name {
502        // Raw-result handlers (large payloads, avoid second serialization pass).
503        "read_graph" => memory::handle_read_graph(kg, tool_args).map(HandlerResult::RawResult),
504        "search_nodes" => memory::handle_search_nodes(kg, tool_args).map(HandlerResult::RawResult),
505        // Standard Value handlers.
506        "create_entities" => {
507            memory::handle_create_entities(kg, tool_args).map(HandlerResult::Value)
508        }
509        "create_relations" => {
510            memory::handle_create_relations(kg, tool_args).map(HandlerResult::Value)
511        }
512        "add_observations" => {
513            memory::handle_add_observations(kg, tool_args).map(HandlerResult::Value)
514        }
515        "delete_entities" => {
516            memory::handle_delete_entities(kg, tool_args).map(HandlerResult::Value)
517        }
518        "delete_observations" => {
519            memory::handle_delete_observations(kg, tool_args).map(HandlerResult::Value)
520        }
521        "delete_relations" => {
522            memory::handle_delete_relations(kg, tool_args).map(HandlerResult::Value)
523        }
524        "open_nodes" => memory::handle_open_nodes(kg, tool_args).map(HandlerResult::Value),
525        "get_entity" => memory::handle_get_entity(kg, tool_args).map(HandlerResult::Value),
526        "graph_stats" => memory::handle_graph_stats(kg).map(HandlerResult::Value),
527        "search_relations" => {
528            memory::handle_search_relations(kg, tool_args).map(HandlerResult::Value)
529        }
530        "find_path" => memory::handle_find_path(kg, tool_args).map(HandlerResult::Value),
531        "compact" => memory::handle_compact(kg).map(HandlerResult::Value),
532        "get_neighbors" => memory::handle_get_neighbors(kg, tool_args).map(HandlerResult::Value),
533        "describe_entity" => {
534            memory::handle_describe_entity(kg, tool_args).map(HandlerResult::Value)
535        }
536        "list_entity_types" => memory::handle_list_entity_types(kg).map(HandlerResult::Value),
537        "list_relation_types" => memory::handle_list_relation_types(kg).map(HandlerResult::Value),
538        "upsert_entities" => {
539            memory::handle_upsert_entities(kg, tool_args).map(HandlerResult::Value)
540        }
541        "export_graph" => memory::handle_export_graph(kg, tool_args).map(HandlerResult::Value),
542        "merge_entities" => memory::handle_merge_entities(kg, tool_args).map(HandlerResult::Value),
543        "extract_subgraph" => {
544            memory::handle_extract_subgraph(kg, tool_args).map(HandlerResult::Value)
545        }
546        "batch_get_entities" => {
547            memory::handle_batch_get_entities(kg, tool_args).map(HandlerResult::Value)
548        }
549        "find_all_paths" => memory::handle_find_all_paths(kg, tool_args).map(HandlerResult::Value),
550        "entity_exists" => memory::handle_entity_exists(kg, tool_args).map(HandlerResult::Value),
551        "degree" => memory::handle_degree(kg, tool_args).map(HandlerResult::Value),
552        tool => Err(MCSError::MethodNotFound(tool.to_string())),
553    };
554
555    // Tool execution failures become isError CallToolResults so the model can
556    // read the message and self-correct, instead of an opaque protocol error.
557    Ok(result.unwrap_or_else(|e| {
558        error!("Tool '{tool_name}' error: {e}");
559        HandlerResult::Value(tool_error(&e.to_string()))
560    }))
561}
562
563