1use std::convert::Infallible;
2use std::path::Path;
3use std::sync::Arc;
4use std::time::Duration;
5
6use serde_json::{Value, json};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::net::TcpListener;
9use tokio::sync::Semaphore;
10use tracing::{error, info};
11
12use crate::config::Config;
13use crate::errors::{MCSError, Result};
14use crate::kg::GraphHandle;
15use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
16use crate::tools;
17use crate::vector_actions;
18use crate::vector_store::{VectorConfig, VectorStore};
19
/// Successful outcome of handling one request.
enum HandlerResult {
    /// Result already available as a JSON value.
    Value(Value),
    /// Result held as a string. `dispatch_line` splices it verbatim into the
    /// response text, while `process_request_value` parses it back into a
    /// `Value` (falling back to `null` if parsing fails).
    RawResult(String),
}
24
/// Capacity used for the buffered readers and for the per-connection output buffer.
const BUFFER_CAPACITY: usize = 65536;
/// Terminator appended after every response written by `serve_line_conn`.
const NEWLINE: &[u8] = b"\n";
/// Upper bound, in bytes, passed to `read_line_capped` for a single line (16 MiB).
const MAX_REQUEST_BYTES: usize = 16 * 1024 * 1024;
/// Number of permits in the semaphore that gates accepted TCP connections.
const MAX_TCP_CONNECTIONS: usize = 128;
29
/// Outcome of one `read_line_capped` call.
#[derive(Clone, Copy, PartialEq, Eq)]
enum LineRead {
    /// A line was read into the output string (also used for a final line
    /// that ends at end-of-stream without a trailing newline).
    Line,
    /// The stream ended before any byte of a new line was read.
    Eof,
    /// The line would have exceeded the caller-supplied byte limit.
    TooLong,
}
36
37async fn read_line_capped<R>(
38 reader: &mut R,
39 out: &mut String,
40 max: usize,
41) -> std::io::Result<LineRead>
42where
43 R: AsyncBufReadExt + Unpin,
44{
45 out.clear();
46 let mut buf: Vec<u8> = Vec::new();
47 loop {
48 let available = reader.fill_buf().await?;
49 if available.is_empty() {
50 if buf.is_empty() {
51 return Ok(LineRead::Eof);
52 }
53 *out = String::from_utf8(buf).map_err(|_| {
54 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
55 })?;
56 return Ok(LineRead::Line);
57 }
58 match available.iter().position(|&b| b == b'\n') {
59 Some(i) => {
60 if buf.len() + i + 1 > max {
61 reader.consume(i + 1);
62 return Ok(LineRead::TooLong);
63 }
64 buf.extend_from_slice(&available[..=i]);
65 reader.consume(i + 1);
66 *out = String::from_utf8(buf).map_err(|_| {
67 std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
68 })?;
69 return Ok(LineRead::Line);
70 }
71 None => {
72 let take = available.len();
73 if buf.len() + take > max {
74 reader.consume(take);
75 return Ok(LineRead::TooLong);
76 }
77 buf.extend_from_slice(available);
78 reader.consume(take);
79 }
80 }
81 }
82}
83
84fn parse_error(msg: String) -> JsonRpcResponse {
85 let mcp_error = MCSError::ParseError(msg);
86 JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
87}
88
/// Protocol versions that `handle_initialize` echoes back when a client requests them.
const SUPPORTED_PROTOCOL_VERSIONS: &[&str] =
    &["2025-11-25", "2025-06-18", "2025-03-26", "2024-11-05"];
/// Version reported by `handle_initialize` when the client's requested version
/// is missing or not in `SUPPORTED_PROTOCOL_VERSIONS`.
const LATEST_PROTOCOL_VERSION: &str = "2025-11-25";

/// Text returned in the `instructions` field of the `initialize` result.
const VECTOR_SERVER_INSTRUCTIONS: &str = "Knowledge-graph memory MCP server with vector search. \
Entity names are unique and case-sensitive. Use `create_entities`/`create_relations` to build the \
graph, and `vector_upsert_embedding` to add vector embeddings. Search semantically with \
`vector_search_entities` or combine text + vector with `hybrid_search`. Tool failures are \
returned with `isError: true` rather than as protocol errors.";
98
99fn handle_initialize(req: &JsonRpcRequest) -> Value {
100 let protocol_version = req
101 .params
102 .as_ref()
103 .and_then(|p| p.get("protocolVersion"))
104 .and_then(Value::as_str)
105 .filter(|v| SUPPORTED_PROTOCOL_VERSIONS.contains(v))
106 .unwrap_or(LATEST_PROTOCOL_VERSION);
107
108 json!({
109 "protocolVersion": protocol_version,
110 "capabilities": {
111 "tools": { "listChanged": false }
112 },
113 "serverInfo": {
114 "name": "mcp-memory-vec",
115 "version": env!("CARGO_PKG_VERSION")
116 },
117 "instructions": VECTOR_SERVER_INSTRUCTIONS
118 })
119}
120
121static VECTOR_TOOLS_LIST: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
122
123fn handle_tools_list() -> Value {
124 if let Some(cached) = VECTOR_TOOLS_LIST.get() {
125 return cached.clone();
126 }
127 let base_tools: Vec<Value> = serde_json::from_str(include_str!("../tools.json"))
128 .expect("tools.json is valid JSON");
129 let vec_tools: Vec<Value> = serde_json::from_str(include_str!("../vector_tools.json"))
130 .expect("vector_tools.json is valid JSON");
131 let mut all = base_tools;
132 all.extend(vec_tools);
133 let result = json!({ "tools": all });
134 let _ = VECTOR_TOOLS_LIST.set(result.clone());
135 result
136}
137
138#[inline]
139fn tool_error(message: &str) -> Value {
140 json!({
141 "content": [{ "type": "text", "text": message }],
142 "isError": true
143 })
144}
145
146fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle, vs: &VectorStore) -> Result<HandlerResult> {
147 let tool_name = req
148 .params
149 .as_ref()
150 .and_then(|p| p.get("name").and_then(|v| v.as_str()))
151 .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
152
153 let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
154
155 if !tools::tool_exists(tool_name) && !is_vector_tool_name(tool_name) {
156 return Err(MCSError::MethodNotFound(tool_name.to_string()));
157 }
158
159 let result = match tool_name {
160 "vector_upsert_embedding" => {
162 vector_actions::handle_vector_upsert_embedding(vs, kg, tool_args)
163 .map(HandlerResult::Value)
164 }
165 "vector_search_entities" => {
166 vector_actions::handle_vector_search_entities(vs, kg, tool_args)
167 .map(HandlerResult::RawResult)
168 }
169 "vector_delete_embedding" => {
170 vector_actions::handle_vector_delete_embedding(vs, kg, tool_args)
171 .map(HandlerResult::Value)
172 }
173 "hybrid_search" => {
174 vector_actions::handle_hybrid_search(vs, kg, tool_args)
175 .map(HandlerResult::RawResult)
176 }
177 "vector_refresh_graph_cache" => {
178 vector_actions::handle_refresh_graph_cache(vs, kg, tool_args)
179 .map(HandlerResult::Value)
180 }
181 "vector_store_stats" => {
182 vector_actions::handle_vector_store_stats(vs, kg, tool_args)
183 .map(HandlerResult::Value)
184 }
185 "read_graph" | "search_nodes" => {
187 let kg_only = crate::server::dispatch_line(
188 &serialize_request(req),
189 kg,
190 );
191 match kg_only {
192 Some(resp) => {
193 let v: Value = serde_json::from_str(&resp)
194 .map_err(MCSError::JsonError)?;
195 if let Some(result_val) = v.get("result") {
196 Ok(HandlerResult::Value(result_val.clone()))
197 } else {
198 Err(MCSError::MemoryError("KG dispatch failed".into()))
199 }
200 }
201 None => Ok(HandlerResult::Value(Value::Null)),
202 }
203 }
204 _ => {
205 let kg_only = crate::server::dispatch_line(
207 &serialize_request(req),
208 kg,
209 );
210 match kg_only {
211 Some(resp) => {
212 let v: Value = serde_json::from_str(&resp)
213 .map_err(MCSError::JsonError)?;
214 if let Some(result_val) = v.get("result") {
215 Ok(HandlerResult::Value(result_val.clone()))
216 } else {
217 Err(MCSError::MemoryError("KG dispatch failed".into()))
218 }
219 }
220 None => Ok(HandlerResult::Value(Value::Null)),
221 }
222 }
223 };
224
225 Ok(result.unwrap_or_else(|e| {
226 error!("Tool '{tool_name}' error: {e}");
227 HandlerResult::Value(tool_error(&e.to_string()))
228 }))
229}
230
/// Reports whether `name` is one of the tools served by the vector layer
/// (exact, case-sensitive match).
fn is_vector_tool_name(name: &str) -> bool {
    const VECTOR_TOOL_NAMES: [&str; 6] = [
        "vector_upsert_embedding",
        "vector_search_entities",
        "vector_delete_embedding",
        "hybrid_search",
        "vector_refresh_graph_cache",
        "vector_store_stats",
    ];
    VECTOR_TOOL_NAMES.contains(&name)
}
242
243fn serialize_request(req: &JsonRpcRequest) -> String {
244 let params = req.params.as_ref().map(|p| {
245 let name = p.get("name").cloned().unwrap_or(Value::Null);
246 let args = p.get("arguments").cloned();
247 json!({
248 "name": name,
249 "arguments": args
250 })
251 });
252 let wrapped = JsonRpcRequest {
253 jsonrpc: req.jsonrpc.clone(),
254 id: req.id.clone(),
255 method: req.method.clone(),
256 params,
257 };
258 serde_json::to_string(&wrapped).unwrap_or_default()
259}
260
261fn process_request_value(value: Value, kg: &GraphHandle, vs: &VectorStore) -> Option<Value> {
262 let req: JsonRpcRequest = match serde_json::from_value(value) {
263 Ok(r) => r,
264 Err(e) => return Some(to_value(parse_error(e.to_string()))),
265 };
266 req.id.as_ref()?;
267
268 match process_request(&req, kg, vs) {
269 Ok(HandlerResult::Value(result)) => {
270 Some(to_value(JsonRpcResponse::success(req.id, result)))
271 }
272 Ok(HandlerResult::RawResult(result_json)) => {
273 let result_val: Value = serde_json::from_str(&result_json).unwrap_or(Value::Null);
274 Some(to_value(JsonRpcResponse::success(req.id, result_val)))
275 }
276 Err(e) => Some(to_value(JsonRpcResponse::error(
277 req.id,
278 e.error_code(),
279 e.to_string(),
280 ))),
281 }
282}
283
284#[inline]
285fn to_value(resp: JsonRpcResponse) -> Value {
286 serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
287}
288
289fn process_request(req: &JsonRpcRequest, kg: &GraphHandle, vs: &VectorStore) -> Result<HandlerResult> {
290 match req.method.as_str() {
291 "initialize" => Ok(HandlerResult::Value(handle_initialize(req))),
292 "tools/list" => Ok(HandlerResult::Value(handle_tools_list())),
293 "tools/call" => handle_tools_call(req, kg, vs),
294 "ping" => Ok(HandlerResult::Value(Value::Null)),
295 method if method.starts_with("notifications/") => {
296 tracing::trace!("Received notification: {method}");
297 Ok(HandlerResult::Value(Value::Null))
298 }
299 _ => Err(MCSError::MethodNotFound(req.method.clone())),
300 }
301}
302
303pub fn dispatch_line(line: &str, kg: &GraphHandle, vs: &VectorStore) -> Option<String> {
304 let trimmed = line.trim();
305 if trimmed.is_empty() {
306 return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
307 }
308 let raw: Value = match serde_json::from_str(trimmed) {
309 Ok(v) => v,
310 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
311 };
312 let req: JsonRpcRequest = match serde_json::from_value(raw) {
313 Ok(r) => r,
314 Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
315 };
316 req.id.as_ref()?;
317
318 match process_request(&req, kg, vs) {
319 Ok(HandlerResult::Value(result)) => {
320 let resp = JsonRpcResponse::success(req.id, result);
321 Some(serde_json::to_string(&resp).unwrap())
322 }
323 Ok(HandlerResult::RawResult(result_json)) => {
324 let id_json = serde_json::to_string(&req.id).unwrap();
325 let mut out = String::with_capacity(64 + id_json.len() + result_json.len());
326 out.push_str(r#"{"jsonrpc":"2.0","id":"#);
327 out.push_str(&id_json);
328 out.push_str(",\"result\":");
329 out.push_str(&result_json);
330 out.push('}');
331 Some(out)
332 }
333 Err(e) => {
334 let resp = JsonRpcResponse::error(req.id, e.error_code(), e.to_string());
335 Some(serde_json::to_string(&resp).unwrap())
336 }
337 }
338}
339
/// Newline-terminated JSON-RPC error (code -32001) that `run_tcp` writes to a
/// client whose first line does not match the configured token, before the
/// connection task returns.
const AUTH_REQUIRED_LINE: &str = "{\"jsonrpc\":\"2.0\",\"error\":{\"code\":-32001,\
\"message\":\"Authentication required: send the bearer token as the first line\"},\"id\":null}\n";
342
343async fn authenticate_line_conn<R>(reader: &mut R, expected: &str) -> Result<bool>
344where
345 R: AsyncBufReadExt + Unpin,
346{
347 let mut line = String::new();
348 match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES)
349 .await
350 .map_err(MCSError::IoError)?
351 {
352 LineRead::Line => Ok(crate::server::token_matches(&line, expected)),
353 _ => Ok(false),
354 }
355}
356
357async fn serve_line_conn<R, W>(
358 reader: &mut R,
359 writer: &mut W,
360 kg: Arc<GraphHandle>,
361 vs: Arc<VectorStore>,
362) -> Result<()>
363where
364 R: AsyncBufReadExt + Unpin,
365 W: AsyncWriteExt + Unpin,
366{
367 let mut line = String::with_capacity(1024);
368 let mut out = Vec::with_capacity(BUFFER_CAPACITY);
369
370 loop {
371 match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
372 Ok(LineRead::Eof) => break,
373 Ok(LineRead::Line) => {
374 let line_copy = line.clone();
375 let kg_clone = Arc::clone(&kg);
376 let vs_clone = Arc::clone(&vs);
377 let resp = tokio::task::spawn_blocking(move || {
378 dispatch_line(&line_copy, &kg_clone, &vs_clone)
379 })
380 .await
381 .map_err(|join_err| {
382 error!("dispatch task panicked: {join_err}");
383 MCSError::IoError(std::io::Error::other("dispatch task panicked"))
384 })?;
385
386 if let Some(resp) = resp {
387 out.clear();
388 out.extend_from_slice(resp.as_bytes());
389 out.extend_from_slice(NEWLINE);
390 writer.write_all(&out).await.map_err(MCSError::IoError)?;
391 writer.flush().await.map_err(MCSError::IoError)?;
392 }
393 }
394 Ok(LineRead::TooLong) => {
395 let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
396 let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
397 out.clear();
398 serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
399 out.extend_from_slice(NEWLINE);
400 writer.write_all(&out).await.map_err(MCSError::IoError)?;
401 writer.flush().await.map_err(MCSError::IoError)?;
402 break;
403 }
404 Err(e) => {
405 error!("IO error: {}", e);
406 break;
407 }
408 }
409 }
410 Ok(())
411}
412
413fn spawn_maintenance(kg: Arc<GraphHandle>) {
414 tokio::spawn(async move {
415 let mut interval = tokio::time::interval(Duration::from_secs(300));
416 interval.tick().await;
417 loop {
418 interval.tick().await;
419 let kg = kg.clone();
420 tokio::task::spawn_blocking(move || {
421 if let Err(e) = kg.run_maintenance() {
422 tracing::warn!("Maintenance error: {e}");
423 }
424 })
425 .await
426 .ok();
427 }
428 });
429}
430
431pub fn dispatch_http_body(
432 body: &str,
433 kg: &GraphHandle,
434 vs: &VectorStore,
435) -> std::result::Result<Option<Value>, String> {
436 let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
437 match value {
438 Value::Array(items) => {
439 let responses: Vec<Value> = items
440 .into_iter()
441 .filter_map(|v| process_request_value(v, kg, vs))
442 .collect();
443 Ok((!responses.is_empty()).then_some(Value::Array(responses)))
444 }
445 other => Ok(process_request_value(other, kg, vs)),
446 }
447}
448
/// MCP server that pairs a knowledge graph with a vector store and can be
/// served over stdio, TCP, or HTTP.
pub struct VectorServer {
    /// Server configuration; the transports read `auth_token`, `tls_cert`
    /// and `tls_key` from it.
    config: Arc<Config>,
    /// Shared handle to the knowledge graph.
    kg: Arc<GraphHandle>,
    /// Shared handle to the vector store.
    vs: Arc<VectorStore>,
}
454
455impl VectorServer {
456 pub fn new(config: Config, vec_config: VectorConfig) -> Result<Self> {
457 let path = Path::new(&config.memory_file_path);
458 let lru_cache = std::num::NonZeroUsize::new(config.lru_cache_size).unwrap_or_else(|| {
459 std::num::NonZeroUsize::new(10000).expect("10000 > 0")
460 });
461 let kg = GraphHandle::new(
462 path,
463 config.durability,
464 config.mmap_size,
465 lru_cache,
466 config.read_pool_size,
467 )?;
468 let vs = VectorStore::with_config(path, &vec_config)?;
469
470 Ok(Self {
471 config: Arc::new(config),
472 kg: Arc::new(kg),
473 vs: Arc::new(vs),
474 })
475 }
476
477 pub fn graph(&self) -> Arc<GraphHandle> {
478 Arc::clone(&self.kg)
479 }
480
481 pub fn vector_store(&self) -> Arc<VectorStore> {
482 Arc::clone(&self.vs)
483 }
484
485 pub async fn run_stdio(&self) -> Result<()> {
486 spawn_maintenance(self.kg.clone());
487 let stdin = tokio::io::stdin();
488 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
489 let mut stdout = tokio::io::stdout();
490 serve_line_conn(&mut reader, &mut stdout, Arc::clone(&self.kg), Arc::clone(&self.vs)).await
491 }
492
493 pub async fn run_tcp(&self, addr: &str) -> Result<()> {
494 spawn_maintenance(self.kg.clone());
495 let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
496 let semaphore = Arc::new(Semaphore::new(MAX_TCP_CONNECTIONS));
497 let auth_token = self.config.auth_token.clone();
498 info!(
499 "Listening for TCP MCP connections on {addr} (max {MAX_TCP_CONNECTIONS}, auth {})",
500 if auth_token.is_some() { "on" } else { "off" }
501 );
502 loop {
503 let permit = Arc::clone(&semaphore).acquire_owned().await;
504 let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
505 let kg = Arc::clone(&self.kg);
506 let vs = Arc::clone(&self.vs);
507 let auth_token = auth_token.clone();
508 tokio::spawn(async move {
509 let _permit = permit;
510 let (read_half, mut write_half) = socket.into_split();
511 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
512 if let Some(ref expected) = auth_token {
513 match authenticate_line_conn(&mut reader, expected).await {
514 Ok(true) => {}
515 Ok(false) => {
516 let _ = write_half.write_all(AUTH_REQUIRED_LINE.as_bytes()).await;
517 let _ = write_half.flush().await;
518 return;
519 }
520 Err(e) => {
521 error!("TCP auth error for {peer}: {e}");
522 return;
523 }
524 }
525 }
526 if let Err(e) = serve_line_conn(&mut reader, &mut write_half, kg, vs).await {
527 error!("TCP connection {peer} error: {e}");
528 }
529 });
530 }
531 }
532
533 pub async fn run_http(&self, addr: &str) -> Result<()> {
534 spawn_maintenance(self.kg.clone());
535 self.run_http_inner(addr).await
536 }
537
538 async fn run_http_inner(&self, addr: &str) -> Result<()> {
539 use axum::routing::{get, post};
540 use axum::Router;
541
542 let kg = Arc::clone(&self.kg);
543 let vs = Arc::clone(&self.vs);
544 let auth_token = self.config.auth_token.clone();
545
546 let app = Router::new()
547 .route("/mcp", post(handle_http_post))
548 .route("/mcp", get(handle_http_get))
549 .with_state(HttpState { kg, vs, auth_token });
550
551 let listener = tokio::net::TcpListener::bind(addr)
552 .await
553 .map_err(MCSError::IoError)?;
554 info!("MCP Streamable HTTP listening on {addr}");
555
556 if let (Some(cert), Some(key)) = (
557 self.config.tls_cert.clone(),
558 self.config.tls_key.clone(),
559 ) {
560 let tls_config = crate::tls::server_config(&cert, &key)
561 .await
562 .map_err(MCSError::IoError)?;
563 axum_server::bind_rustls(listener.local_addr().unwrap(), tls_config)
564 .serve(app.into_make_service())
565 .await
566 .map_err(|e| MCSError::IoError(std::io::Error::other(e)))?;
567 } else {
568 axum::serve(listener, app)
569 .await
570 .map_err(|e| MCSError::IoError(std::io::Error::other(e)))?;
571 }
572 Ok(())
573 }
574}
575
/// State shared with the HTTP handlers.
#[derive(Clone)]
struct HttpState {
    /// Shared handle to the knowledge graph.
    kg: Arc<GraphHandle>,
    /// Shared handle to the vector store.
    vs: Arc<VectorStore>,
    /// Expected token; `None` disables the authorization check.
    auth_token: Option<Arc<str>>,
}
582
583fn http_authorized(state: &HttpState, headers: &axum::http::HeaderMap) -> bool {
586 match state.auth_token {
587 None => true,
588 Some(ref expected) => headers
589 .get(axum::http::header::AUTHORIZATION)
590 .and_then(|v| v.to_str().ok())
591 .is_some_and(|presented| crate::server::token_matches(presented, expected)),
592 }
593}
594
595async fn handle_http_post(
596 axum::extract::State(state): axum::extract::State<HttpState>,
597 headers: axum::http::HeaderMap,
598 body: String,
599) -> axum::response::Response {
600 use axum::response::sse::Event;
601 use axum::response::{IntoResponse, Json};
602 use axum::http::StatusCode;
603
604 if !http_authorized(&state, &headers) {
605 return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
606 }
607
608 let result = tokio::task::spawn_blocking(move || {
609 dispatch_http_body(&body, &state.kg, &state.vs)
610 })
611 .await;
612
613 let outcome = match result {
614 Ok(inner) => inner,
615 Err(join_err) => {
616 error!("dispatch task panicked: {join_err}");
617 return (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response();
618 }
619 };
620
621 match outcome {
622 Ok(None) => StatusCode::ACCEPTED.into_response(),
623 Ok(Some(value)) => {
624 let wants_sse = headers
625 .get(axum::http::header::ACCEPT)
626 .and_then(|v| v.to_str().ok())
627 .is_some_and(|a| a.contains("text/event-stream"));
628 if wants_sse {
629 let json = serde_json::to_string(&value).unwrap();
630 let stream = futures::stream::once(async move {
631 Ok::<Event, Infallible>(Event::default().data(json))
632 });
633 axum::response::sse::Sse::new(stream).into_response()
634 } else {
635 Json(value).into_response()
636 }
637 }
638 Err(e) => {
639 let resp = json!({
640 "jsonrpc": "2.0",
641 "error": { "code": -32700, "message": format!("Parse error: {e}") },
642 "id": null
643 });
644 (StatusCode::BAD_REQUEST, Json(resp)).into_response()
645 }
646 }
647}
648
649async fn handle_http_get(
650 axum::extract::State(state): axum::extract::State<HttpState>,
651 headers: axum::http::HeaderMap,
652) -> axum::response::Response {
653 use axum::response::sse::{Event, KeepAlive, Sse};
654 use axum::response::IntoResponse;
655 use axum::http::StatusCode;
656
657 if !http_authorized(&state, &headers) {
658 return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
659 }
660
661 let stream = futures::stream::pending::<std::result::Result<Event, Infallible>>();
662 Sse::new(stream)
663 .keep_alive(KeepAlive::default())
664 .into_response()
665}