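//! JSON-RPC request dispatch plus the line-framed (stdio, TCP) and
//! Streamable-HTTP transports of the MCP memory server (`mcp_memory/server.rs`).
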
1use serde_json::{Value, json};
2use std::path::Path;
3use std::sync::Arc;
4
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::TcpListener;
7use tokio::sync::Semaphore;
8use tracing::{error, info};
9
10use crate::actions::memory;
11use crate::config::Config;
12use crate::errors::{MCSError, Result};
13use crate::kg::GraphHandle;
14use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
15use crate::tools;
16
17/// Outcome of processing a request: either a pre-escaped JSON Value (small
18/// payloads) or a pre-serialized JSON *string* of the `result` field (avoids
19/// a second serialization pass for large payloads such as `read_graph`).
20enum HandlerResult {
21    Value(Value),
22    RawResult(String),
23}
24
/// Size of the I/O buffers used by the line-framed (stdio / TCP) transports.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Frame terminator for newline-delimited JSON-RPC.
const NEWLINE: &[u8] = b"\n";
/// Largest inbound JSON-RPC message accepted (16 MiB); shared by all transports.
pub const MAX_REQUEST_BYTES: usize = 16 << 20;
/// Cap on simultaneously served TCP connections (C4).
const MAX_TCP_CONNECTIONS: usize = 128;
31
/// Result of one capped line read from a line-framed stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LineRead {
    /// A line was read into the caller's output buffer.
    Line,
    /// The stream ended before any byte of a new line arrived.
    Eof,
    /// The line exceeded the caller's byte cap; no line was returned for it.
    TooLong,
}
37
38async fn read_line_capped<R>(reader: &mut R, out: &mut String, max: usize) -> std::io::Result<LineRead>
39where
40    R: AsyncBufReadExt + Unpin,
41{
42    out.clear();
43    let mut buf: Vec<u8> = Vec::new();
44    loop {
45        let available = reader.fill_buf().await?;
46        if available.is_empty() {
47            if buf.is_empty() {
48                return Ok(LineRead::Eof);
49            }
50            *out = String::from_utf8(buf.clone()).map_err(|_| {
51                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
52            })?;
53            return Ok(LineRead::Line);
54        }
55        match available.iter().position(|&b| b == b'\n') {
56            Some(i) => {
57                if buf.len() + i + 1 > max {
58                    reader.consume(i + 1);
59                    return Ok(LineRead::TooLong);
60                }
61                buf.extend_from_slice(&available[..=i]);
62                reader.consume(i + 1);
63                *out = String::from_utf8(buf.clone()).map_err(|_| {
64                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
65                })?;
66                return Ok(LineRead::Line);
67            }
68            None => {
69                let take = available.len();
70                if buf.len() + take > max {
71                    reader.consume(take);
72                    return Ok(LineRead::TooLong);
73                }
74                buf.extend_from_slice(available);
75                reader.consume(take);
76            }
77        }
78    }
79}
80
81fn parse_error(msg: String) -> JsonRpcResponse {
82    let mcp_error = MCSError::ParseError(msg);
83    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
84}
85
86/// Process one parsed JSON-RPC message. `None` means "no reply" — the message
87/// was a notification (no `id`), per JSON-RPC.
88pub fn process_value(value: Value, kg: &GraphHandle) -> Option<Value> {
89    let req: JsonRpcRequest = match serde_json::from_value(value) {
90        Ok(r) => r,
91        Err(e) => return Some(to_value(parse_error(e.to_string()))),
92    };
93    req.id.as_ref()?;
94    let response = match process_request(&req, kg) {
95        Ok(HandlerResult::Value(result)) => Some(to_value(JsonRpcResponse::success(req.id, result))),
96        Ok(HandlerResult::RawResult(_)) => {
97            // RawResult cannot pass through Value — dispatch_line and
98            // dispatch_http_body handle it via separate code paths.
99            unreachable!("RawResult must be handled at the dispatch level, not via process_value");
100        }
101        Err(e) => Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string()))),
102    };
103    response
104}
105
106/// Dispatch one framed line (stdio / tcp). Returns the serialized response, or
107/// `None` for a notification.
108pub fn dispatch_line(line: &str, kg: &GraphHandle) -> Option<String> {
109    let trimmed = line.trim();
110    if trimmed.is_empty() {
111        return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
112    }
113    let raw: Value = match serde_json::from_str(trimmed) {
114        Ok(v) => v,
115        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
116    };
117    let req: JsonRpcRequest = match serde_json::from_value(raw) {
118        Ok(r) => r,
119        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
120    };
121    req.id.as_ref()?;
122    match process_request(&req, kg) {
123        Ok(HandlerResult::Value(result)) => {
124            let resp = JsonRpcResponse::success(req.id, result);
125            Some(serde_json::to_string(&resp).unwrap())
126        }
127        Ok(HandlerResult::RawResult(result_json)) => {
128            let id_json = serde_json::to_string(&req.id).unwrap();
129            let mut out = String::with_capacity(64 + id_json.len() + result_json.len());
130            out.push_str(r#"{"jsonrpc":"2.0","id":"#);
131            out.push_str(&id_json);
132            out.push_str(r#","result":"#);
133            out.push_str(&result_json);
134            out.push('}');
135            Some(out)
136        }
137        Err(e) => {
138            let resp = JsonRpcResponse::error(req.id, e.error_code(), e.to_string());
139            Some(serde_json::to_string(&resp).unwrap())
140        }
141    }
142}
143
144/// Dispatch a Streamable-HTTP POST body, which may be a single JSON-RPC message
145/// or a batch array. `Ok(None)` means the body held only notifications (HTTP
146/// 202, empty body); `Err` means the body was not valid JSON.
147pub fn dispatch_http_body(
148    body: &str,
149    kg: &GraphHandle,
150) -> std::result::Result<Option<Value>, String> {
151    let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
152    match value {
153        Value::Array(items) => {
154            // Batches are rare and never huge — keep Value path for simplicity.
155            let responses: Vec<Value> =
156                items.into_iter().filter_map(|v| process_value_http(v, kg)).collect();
157            Ok((!responses.is_empty()).then_some(Value::Array(responses)))
158        }
159        other => Ok(process_value_http(other, kg)),
160    }
161}
162
163/// HTTP variant of process_value that handles RawResult by converting to Value
164/// (acceptable since HTTP payloads are typically much smaller in this context).
165fn process_value_http(value: Value, kg: &GraphHandle) -> Option<Value> {
166    let req: JsonRpcRequest = match serde_json::from_value(value) {
167        Ok(r) => r,
168        Err(e) => return Some(to_value(parse_error(e.to_string()))),
169    };
170    req.id.as_ref()?;
171    match process_request(&req, kg) {
172        Ok(HandlerResult::Value(result)) => {
173            Some(to_value(JsonRpcResponse::success(req.id, result)))
174        }
175        Ok(HandlerResult::RawResult(result_json)) => {
176            // Parse the pre-serialized result back into a Value for HTTP delivery.
177            // This is a small extra cost for the HTTP transport; the stdio/TCP
178            // path (dispatch_line) avoids it entirely.
179            let result_val: Value = serde_json::from_str(&result_json)
180                .unwrap_or(Value::Null);
181            Some(to_value(JsonRpcResponse::success(req.id, result_val)))
182        }
183        Err(e) => {
184            Some(to_value(JsonRpcResponse::error(req.id, e.error_code(), e.to_string())))
185        }
186    }
187}
188
189#[inline]
190fn to_value(resp: JsonRpcResponse) -> Value {
191    serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
192}
193
194pub struct MCPServer {
195    config: Arc<Config>,
196    kg: Arc<GraphHandle>,
197}
198
199impl MCPServer {
200    pub fn new(config: Config) -> Result<Self> {
201        let path = Path::new(&config.memory_file_path);
202        let kg = GraphHandle::new(path, config.durability).map_err(MCSError::IoError)?;
203
204        Ok(Self {
205            config: Arc::new(config),
206            kg: Arc::new(kg),
207        })
208    }
209
210    /// Expose the shared graph handle (used to drive the HTTP transport).
211    pub fn graph(&self) -> Arc<GraphHandle> {
212        Arc::clone(&self.kg)
213    }
214
215    /// stdio transport: newline-delimited JSON-RPC over stdin/stdout.
216    pub async fn run_stdio(&self) -> Result<()> {
217        let stdin = tokio::io::stdin();
218        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
219        let mut stdout = tokio::io::stdout();
220        serve_line_conn(&mut reader, &mut stdout, Arc::clone(&self.kg)).await
221    }
222
223    /// TCP transport: each accepted connection speaks newline-delimited
224    /// JSON-RPC, exactly like stdio. Connections are served concurrently (up to
225    /// [`MAX_TCP_CONNECTIONS`]) and share the one graph behind its mutex.
226    pub async fn run_tcp(&self, addr: &str) -> Result<()> {
227        let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
228        let semaphore = Arc::new(Semaphore::new(MAX_TCP_CONNECTIONS));
229        let auth_token = self.config.auth_token.clone();
230        info!(
231            "Listening for TCP MCP connections on {addr} (max {MAX_TCP_CONNECTIONS}, auth {})",
232            if auth_token.is_some() { "on" } else { "off" }
233        );
234        loop {
235            let permit = Arc::clone(&semaphore).acquire_owned().await;
236            let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
237            let kg = Arc::clone(&self.kg);
238            let auth_token = auth_token.clone();
239            tokio::spawn(async move {
240                let _permit = permit; // held for the connection lifetime
241                let (read_half, mut write_half) = socket.into_split();
242                let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
243                // When a token is configured, the client must send it as the
244                // first line before any JSON-RPC traffic.
245                if let Some(ref expected) = auth_token {
246                    match authenticate_line_conn(&mut reader, expected).await {
247                        Ok(true) => {}
248                        Ok(false) => {
249                            let _ = write_half
250                                .write_all(AUTH_REQUIRED_LINE.as_bytes())
251                                .await;
252                            let _ = write_half.flush().await;
253                            return;
254                        }
255                        Err(e) => {
256                            error!("TCP auth error for {peer}: {e}");
257                            return;
258                        }
259                    }
260                }
261                if let Err(e) = serve_line_conn(&mut reader, &mut write_half, kg).await {
262                    error!("TCP connection {peer} error: {e}");
263                }
264            });
265        }
266    }
267
268    /// MCP Streamable HTTP transport (POST/GET `/mcp`, JSON or SSE responses).
269    pub async fn run_http(&self, addr: &str) -> Result<()> {
270        crate::http::run(addr, self.graph(), self.config.auth_token.clone()).await
271    }
272}
273
/// A single newline-terminated JSON-RPC error frame (`id: null`, code -32001).
/// It is sent to a TCP client whose first line is not the expected bearer token.
const AUTH_REQUIRED_LINE: &str = concat!(
    r#"{"jsonrpc":"2.0","error":{"code":-32001,"#,
    r#""message":"Authentication required: send the bearer token as the first line"},"#,
    r#""id":null}"#,
    "\n",
);
277
278/// Read the first line of a connection and compare it (constant-time) to the
279/// expected bearer token. Returns `Ok(false)` on EOF / oversized first line.
280async fn authenticate_line_conn<R>(reader: &mut R, expected: &str) -> Result<bool>
281where
282    R: AsyncBufReadExt + Unpin,
283{
284    let mut line = String::new();
285    match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES)
286        .await
287        .map_err(MCSError::IoError)?
288    {
289        LineRead::Line => Ok(token_matches(&line, expected)),
290        _ => Ok(false),
291    }
292}
293
294/// Drive one line-framed connection (stdio or a single TCP socket): read
295/// newline-delimited JSON-RPC requests, write newline-delimited responses.
296/// Notifications produce no output. Returns when the peer closes the stream.
297/// The dispatch path (graph lock + optional fsync) is offloaded to
298/// [`tokio::task::spawn_blocking`] to keep the async reactor responsive (C3).
299async fn serve_line_conn<R, W>(reader: &mut R, writer: &mut W, kg: Arc<GraphHandle>) -> Result<()>
300where
301    R: AsyncBufReadExt + Unpin,
302    W: AsyncWriteExt + Unpin,
303{
304    let mut line = String::with_capacity(1024);
305    let mut out = Vec::with_capacity(BUFFER_CAPACITY);
306
307    loop {
308        match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
309            Ok(LineRead::Eof) => break,
310            Ok(LineRead::Line) => {
311                let line_copy = line.clone();
312                let kg_clone = Arc::clone(&kg);
313                let resp = tokio::task::spawn_blocking(move || dispatch_line(&line_copy, &kg_clone))
314                    .await
315                    .map_err(|join_err| {
316                        error!("dispatch task panicked: {join_err}");
317                        MCSError::IoError(std::io::Error::other("dispatch task panicked"))
318                    })?;
319                if let Some(resp) = resp {
320                    out.clear();
321                    out.extend_from_slice(resp.as_bytes());
322                    out.extend_from_slice(NEWLINE);
323                    writer.write_all(&out).await.map_err(MCSError::IoError)?;
324                    writer.flush().await.map_err(MCSError::IoError)?;
325                }
326            }
327            Ok(LineRead::TooLong) => {
328                let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
329                let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
330                out.clear();
331                serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
332                out.extend_from_slice(NEWLINE);
333                writer.write_all(&out).await.map_err(MCSError::IoError)?;
334                writer.flush().await.map_err(MCSError::IoError)?;
335                break;
336            }
337            Err(e) => {
338                error!("IO error: {}", e);
339                break;
340            }
341        }
342    }
343    Ok(())
344}
345
346fn process_request(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
347    match req.method.as_str() {
348        "initialize" => Ok(HandlerResult::Value(handle_initialize(req))),
349        "tools/list" => Ok(HandlerResult::Value(handle_tools_list())),
350        "tools/call" => handle_tools_call(req, kg),
351        "ping" => Ok(HandlerResult::Value(Value::Null)),
352        method if method.starts_with("notifications/") => {
353            tracing::trace!("Received notification: {method}");
354            Ok(HandlerResult::Value(Value::Null))
355        }
356        _ => Err(MCSError::MethodNotFound(req.method.clone())),
357    }
358}
359
/// MCP protocol revisions this server can speak, newest first (for `initialize`
/// version negotiation).
const SUPPORTED_PROTOCOL_VERSIONS: &[&str] =
    &["2025-11-25", "2025-06-18", "2025-03-26", "2024-11-05"];
/// Newest revision we implement; offered when the client requests an unknown
/// one. Taken from the head of the list above so the two cannot drift apart.
const LATEST_PROTOCOL_VERSION: &str = SUPPORTED_PROTOCOL_VERSIONS[0];
366
/// Usage guidance returned in the `initialize` result's `instructions` field,
/// which clients may surface to the model (for example in its system prompt).
const SERVER_INSTRUCTIONS: &str = concat!(
    "Knowledge-graph memory MCP server. ",
    "Entity names are unique and case-sensitive. ",
    "Use `create_entities`/`create_relations` to build the graph, ",
    "`add_observations` to attach facts, ",
    "and `search_nodes`/`open_nodes`/`read_graph` to retrieve. ",
    "Prefer `upsert_entities` for idempotent writes ",
    "and `merge_entities` to collapse duplicates. ",
    "Tool failures are returned with `isError: true` rather than as protocol errors",
    " — read the message and retry.",
);
373
374fn handle_initialize(req: &JsonRpcRequest) -> Value {
375    // Version negotiation: echo a supported requested revision, else offer latest.
376    let protocol_version = req
377        .params
378        .as_ref()
379        .and_then(|p| p.get("protocolVersion"))
380        .and_then(Value::as_str)
381        .filter(|v| SUPPORTED_PROTOCOL_VERSIONS.contains(v))
382        .unwrap_or(LATEST_PROTOCOL_VERSION);
383
384    json!({
385        "protocolVersion": protocol_version,
386        "capabilities": {
387            "tools": { "listChanged": false }
388        },
389        "serverInfo": {
390            "name": "mcp-memory",
391            "version": env!("CARGO_PKG_VERSION")
392        },
393        "instructions": SERVER_INSTRUCTIONS
394    })
395}
396
397/// Wrap a tool execution failure as an MCP `CallToolResult` with `isError: true`
398/// so the model sees the message and can self-correct, instead of receiving an
399/// opaque JSON-RPC protocol error. (Successful results are already content-
400/// wrapped by the action handlers.)
401#[inline]
402fn tool_error(message: &str) -> Value {
403    json!({
404        "content": [{ "type": "text", "text": message }],
405        "isError": true
406    })
407}
408
409/// Constant-time bearer-token check. Accepts the raw token or a `Bearer <token>`
410/// form; surrounding whitespace is trimmed.
411pub fn token_matches(presented: &str, expected: &str) -> bool {
412    use subtle::ConstantTimeEq;
413    let presented = presented.trim();
414    let presented = presented.strip_prefix("Bearer ").unwrap_or(presented).trim();
415    presented.as_bytes().ct_eq(expected.as_bytes()).into()
416}
417
418fn handle_tools_list() -> Value {
419    static CACHED: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
420    if let Some(cached) = CACHED.get() {
421        return cached.clone();
422    }
423    let tools_json = include_str!("../tools.json");
424    let tools: Vec<Value> =
425        serde_json::from_str(tools_json).map_err(MCSError::JsonError).unwrap();
426    let result = json!({ "tools": tools });
427    let _ = CACHED.set(result.clone());
428    result
429}
430
431fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle) -> Result<HandlerResult> {
432    let tool_name = req
433        .params
434        .as_ref()
435        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
436        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
437
438    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
439
440    if !tools::tool_exists(tool_name) {
441        return Err(MCSError::MethodNotFound(tool_name.to_string()));
442    }
443
444    let result = match tool_name {
445        // Raw-result handlers (large payloads, avoid second serialization pass).
446        "read_graph" => memory::handle_read_graph(kg, tool_args).map(HandlerResult::RawResult),
447        "search_nodes" => memory::handle_search_nodes(kg, tool_args).map(HandlerResult::RawResult),
448        // Standard Value handlers.
449        "create_entities" => memory::handle_create_entities(kg, tool_args).map(HandlerResult::Value),
450        "create_relations" => memory::handle_create_relations(kg, tool_args).map(HandlerResult::Value),
451        "add_observations" => memory::handle_add_observations(kg, tool_args).map(HandlerResult::Value),
452        "delete_entities" => memory::handle_delete_entities(kg, tool_args).map(HandlerResult::Value),
453        "delete_observations" => memory::handle_delete_observations(kg, tool_args).map(HandlerResult::Value),
454        "delete_relations" => memory::handle_delete_relations(kg, tool_args).map(HandlerResult::Value),
455        "open_nodes" => memory::handle_open_nodes(kg, tool_args).map(HandlerResult::Value),
456        "get_entity" => memory::handle_get_entity(kg, tool_args).map(HandlerResult::Value),
457        "graph_stats" => memory::handle_graph_stats(kg).map(HandlerResult::Value),
458        "search_relations" => memory::handle_search_relations(kg, tool_args).map(HandlerResult::Value),
459        "find_path" => memory::handle_find_path(kg, tool_args).map(HandlerResult::Value),
460        "compact" => memory::handle_compact(kg).map(HandlerResult::Value),
461        "get_neighbors" => memory::handle_get_neighbors(kg, tool_args).map(HandlerResult::Value),
462        "describe_entity" => memory::handle_describe_entity(kg, tool_args).map(HandlerResult::Value),
463        "list_entity_types" => memory::handle_list_entity_types(kg).map(HandlerResult::Value),
464        "list_relation_types" => memory::handle_list_relation_types(kg).map(HandlerResult::Value),
465        "upsert_entities" => memory::handle_upsert_entities(kg, tool_args).map(HandlerResult::Value),
466        "export_graph" => memory::handle_export_graph(kg, tool_args).map(HandlerResult::Value),
467        "merge_entities" => memory::handle_merge_entities(kg, tool_args).map(HandlerResult::Value),
468        "extract_subgraph" => memory::handle_extract_subgraph(kg, tool_args).map(HandlerResult::Value),
469        "batch_get_entities" => memory::handle_batch_get_entities(kg, tool_args).map(HandlerResult::Value),
470        "find_all_paths" => memory::handle_find_all_paths(kg, tool_args).map(HandlerResult::Value),
471        tool => Err(MCSError::MethodNotFound(tool.to_string())),
472    };
473
474    // Tool execution failures become isError CallToolResults so the model can
475    // read the message and self-correct, instead of an opaque protocol error.
476    Ok(result.unwrap_or_else(|e| {
477        error!("Tool '{tool_name}' error: {e}");
478        HandlerResult::Value(tool_error(&e.to_string()))
479    }))
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Durability;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Open a graph backed by a fresh file in the platform temp directory.
    /// `std::env::temp_dir()` keeps this portable, since a hard-coded `/tmp`
    /// does not exist on Windows. The pid and counter keep parallel tests apart.
    fn setup_kg() -> (Arc<GraphHandle>, PathBuf) {
        let pid = std::process::id();
        let seq = COUNTER.fetch_add(1, Ordering::SeqCst);
        let path = std::env::temp_dir().join(format!("mcp_mem_test_{pid}_{seq}.bin"));
        let kg = GraphHandle::new(path.as_path(), Durability::Async).unwrap();
        (Arc::new(kg), path)
    }

    /// Best-effort removal of a test's backing file.
    fn cleanup(path: &Path) {
        let _ = std::fs::remove_file(path);
    }

    #[test]
    fn test_dispatch_line_valid_request() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
        let resp = dispatch_line(line, &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["id"], 1);
        assert_eq!(v["result"]["serverInfo"]["name"], "mcp-memory");
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_invalid_json() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("{invalid}", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        assert!(v["id"].is_null());
        cleanup(&path);
    }

    #[test]
    fn test_dispatch_line_empty() {
        let (kg, path) = setup_kg();
        let resp = dispatch_line("   \n", &kg).unwrap();
        let v: Value = serde_json::from_str(&resp).unwrap();
        assert_eq!(v["error"]["code"], -32700);
        cleanup(&path);
    }

    #[test]
    fn test_notification_has_no_response() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_line(line, &kg).is_none());
        cleanup(&path);
    }

    #[test]
    fn test_unknown_method_error() {
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","method":"does/not/exist","id":7}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["id"], 7);
        assert_eq!(v["error"]["code"], -32601);
        cleanup(&path);
    }

    #[test]
    fn test_tools_call_roundtrip_via_dispatch() {
        let (kg, path) = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        assert!(dispatch_line(create, &kg).is_some());

        let read = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_graph","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(read, &kg).unwrap()).unwrap();
        let text = v["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("Ada"));
        cleanup(&path);
    }

    #[test]
    fn test_http_body_batch_and_notifications() {
        let (kg, path) = setup_kg();
        let batch = r#"[
            {"jsonrpc":"2.0","method":"initialize","id":1},
            {"jsonrpc":"2.0","method":"notifications/initialized"}
        ]"#;
        let out = dispatch_http_body(batch, &kg).unwrap().unwrap();
        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["id"], 1);

        let notif = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#;
        assert!(dispatch_http_body(notif, &kg).unwrap().is_none());

        assert!(dispatch_http_body("{bad", &kg).is_err());
        cleanup(&path);
    }

    #[test]
    fn test_handle_initialize_response() {
        let (kg, path) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: None,
            id: Some(Value::Number(1.into())),
        };
        let result = match process_request(&req, &kg).unwrap() {
            HandlerResult::Value(v) => v,
            HandlerResult::RawResult(_) => panic!("expected Value"),
        };
        assert_eq!(result["protocolVersion"], LATEST_PROTOCOL_VERSION);
        assert_eq!(result["serverInfo"]["name"], "mcp-memory");
        assert!(result["instructions"].is_string());
        cleanup(&path);
    }

    #[test]
    fn test_initialize_version_negotiation() {
        let (kg, path) = setup_kg();
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params: Some(json!({ "protocolVersion": "2024-11-05" })),
            id: Some(Value::Number(1.into())),
        };
        let result = match process_request(&req, &kg).unwrap() {
            HandlerResult::Value(v) => v,
            HandlerResult::RawResult(_) => panic!("expected Value"),
        };
        assert_eq!(result["protocolVersion"], "2024-11-05");
        cleanup(&path);
    }

    #[test]
    fn test_tool_error_on_bad_args() {
        // A tool that fails returns an isError CallToolResult, not a protocol error.
        let (kg, path) = setup_kg();
        let line = r#"{"jsonrpc":"2.0","id":9,"method":"tools/call","params":{"name":"get_entity","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert!(v["error"].is_null(), "should not be a protocol error: {v}");
        assert_eq!(v["result"]["isError"], json!(true));
        cleanup(&path);
    }

    #[test]
    fn test_token_matches() {
        assert!(token_matches("secret", "secret"));
        assert!(token_matches("Bearer secret", "secret"));
        assert!(token_matches("  Bearer secret  ", "secret"));
        assert!(!token_matches("wrong", "secret"));
        assert!(!token_matches("", "secret"));
        // Length-mismatch and prefix cases.
        assert!(!token_matches("secre", "secret"));
        assert!(!token_matches("secretx", "secret"));
        assert!(!token_matches("Bearer ", "secret"));
    }

    #[tokio::test]
    async fn test_read_line_capped_framing() {
        let mut reader = tokio::io::BufReader::new(&b"abc\ndefgh\ntail"[..]);
        let mut line = String::new();

        let outcome = read_line_capped(&mut reader, &mut line, 5).await.unwrap();
        assert!(matches!(outcome, LineRead::Line));
        assert_eq!(line, "abc\n");

        // "defgh\n" is 6 bytes including the newline, over the cap of 5.
        let outcome = read_line_capped(&mut reader, &mut line, 5).await.unwrap();
        assert!(matches!(outcome, LineRead::TooLong));

        // A final line without a trailing newline is still delivered.
        let outcome = read_line_capped(&mut reader, &mut line, 5).await.unwrap();
        assert!(matches!(outcome, LineRead::Line));
        assert_eq!(line, "tail");

        let outcome = read_line_capped(&mut reader, &mut line, 5).await.unwrap();
        assert!(matches!(outcome, LineRead::Eof));
    }

    #[tokio::test]
    async fn test_read_line_capped_rejects_non_utf8() {
        let mut reader = tokio::io::BufReader::new(&b"\xff\xfe\n{}\n"[..]);
        let mut line = String::new();

        let Err(err) = read_line_capped(&mut reader, &mut line, 64).await else {
            panic!("expected an InvalidData error for non-UTF-8 input");
        };
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);

        // The bad line was consumed, so the next one is still readable.
        let outcome = read_line_capped(&mut reader, &mut line, 64).await.unwrap();
        assert!(matches!(outcome, LineRead::Line));
        assert_eq!(line, "{}\n");
    }

    // ── MCP 2025-11-25 compliance ─────────────────────────────────────────

    fn init_result(params: Option<Value>, kg: &GraphHandle) -> Value {
        let req = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            method: "initialize".to_string(),
            params,
            id: Some(json!(1)),
        };
        match process_request(&req, kg).unwrap() {
            HandlerResult::Value(v) => v,
            HandlerResult::RawResult(_) => panic!("expected Value"),
        }
    }

    #[test]
    fn test_compliance_negotiation_matrix() {
        let (kg, path) = setup_kg();
        for v in SUPPORTED_PROTOCOL_VERSIONS {
            let r = init_result(Some(json!({ "protocolVersion": v })), &kg);
            assert_eq!(&r["protocolVersion"], v);
        }
        // Unsupported / malformed → latest.
        assert_eq!(
            init_result(Some(json!({ "protocolVersion": "1900-01-01" })), &kg)["protocolVersion"],
            LATEST_PROTOCOL_VERSION
        );
        assert_eq!(init_result(None, &kg)["protocolVersion"], "2025-11-25");
        cleanup(&path);
    }

    #[test]
    fn test_compliance_initialize_honest_with_instructions() {
        let (kg, path) = setup_kg();
        let r = init_result(None, &kg);
        assert!(r["capabilities"]["tools"].is_object());
        for cap in ["resources", "prompts", "logging", "completions"] {
            assert!(r["capabilities"][cap].is_null(), "must not advertise {cap}");
        }
        assert!(r["instructions"].as_str().is_some_and(|s| !s.is_empty()));
        cleanup(&path);
    }

    #[test]
    fn test_compliance_tool_success_is_content_wrapped() {
        let (kg, path) = setup_kg();
        let create = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"create_entities","arguments":{"entities":[{"name":"Ada","entityType":"person","observations":["math"]}]}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(create, &kg).unwrap()).unwrap();
        let content = v["result"]["content"].as_array().expect("content array");
        assert!(!content.is_empty());
        assert_eq!(content[0]["type"], "text");
        assert!(v["error"].is_null());
        cleanup(&path);
    }

    #[test]
    fn test_compliance_protocol_errors_remain_protocol_errors() {
        let (kg, path) = setup_kg();
        // Unknown tool → JSON-RPC error, not an isError result.
        let line = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"no_such_tool","arguments":{}}}"#;
        let v: Value = serde_json::from_str(&dispatch_line(line, &kg).unwrap()).unwrap();
        assert_eq!(v["error"]["code"], -32601);
        assert!(v["result"].is_null());
        cleanup(&path);
    }

    #[tokio::test]
    async fn test_compliance_tcp_auth_handshake() {
        // First-line token handshake used by the TCP transport.
        let mut ok = tokio::io::BufReader::new(&b"Bearer s3cr3t\n"[..]);
        assert!(authenticate_line_conn(&mut ok, "s3cr3t").await.unwrap());

        let mut bad = tokio::io::BufReader::new(&b"nope\n"[..]);
        assert!(!authenticate_line_conn(&mut bad, "s3cr3t").await.unwrap());

        // EOF / no first line → rejected.
        let mut empty = tokio::io::BufReader::new(&b""[..]);
        assert!(!authenticate_line_conn(&mut empty, "s3cr3t").await.unwrap());
    }
}