Skip to main content

mcp_memory/
vector_server.rs

1use std::convert::Infallible;
2use std::path::Path;
3use std::sync::Arc;
4use std::time::Duration;
5
6use serde_json::{Value, json};
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::net::TcpListener;
9use tokio::sync::Semaphore;
10use tracing::{error, info};
11
12use crate::config::Config;
13use crate::errors::{MCSError, Result};
14use crate::kg::GraphHandle;
15use crate::protocol::{JsonRpcRequest, JsonRpcResponse};
16use crate::tools;
17use crate::vector_actions;
18use crate::vector_store::{VectorConfig, VectorStore};
19
20enum HandlerResult {
21    Value(Value),
22    RawResult(String),
23}
24
/// Capacity in bytes (64 KiB) used for buffered readers and the reusable
/// response buffer.
const BUFFER_CAPACITY: usize = 64 * 1024;
/// Terminator appended after every line-delimited response.
const NEWLINE: &[u8] = &[b'\n'];
/// Largest request line accepted, in bytes (16 MiB).
const MAX_REQUEST_BYTES: usize = 16 << 20;
/// Number of semaphore permits bounding concurrent TCP connections.
const MAX_TCP_CONNECTIONS: usize = 128;
29
/// Outcome of one `read_line_capped` call.
///
/// `Debug` is derived so the value can be logged and compared with
/// `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LineRead {
    /// A line was read and stored in the output buffer.
    Line,
    /// The stream ended before any byte of a new line arrived.
    Eof,
    /// The line would have exceeded the caller's byte limit.
    TooLong,
}
36
37async fn read_line_capped<R>(
38    reader: &mut R,
39    out: &mut String,
40    max: usize,
41) -> std::io::Result<LineRead>
42where
43    R: AsyncBufReadExt + Unpin,
44{
45    out.clear();
46    let mut buf: Vec<u8> = Vec::new();
47    loop {
48        let available = reader.fill_buf().await?;
49        if available.is_empty() {
50            if buf.is_empty() {
51                return Ok(LineRead::Eof);
52            }
53            *out = String::from_utf8(buf).map_err(|_| {
54                std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
55            })?;
56            return Ok(LineRead::Line);
57        }
58        match available.iter().position(|&b| b == b'\n') {
59            Some(i) => {
60                if buf.len() + i + 1 > max {
61                    reader.consume(i + 1);
62                    return Ok(LineRead::TooLong);
63                }
64                buf.extend_from_slice(&available[..=i]);
65                reader.consume(i + 1);
66                *out = String::from_utf8(buf).map_err(|_| {
67                    std::io::Error::new(std::io::ErrorKind::InvalidData, "Non-UTF-8 input")
68                })?;
69                return Ok(LineRead::Line);
70            }
71            None => {
72                let take = available.len();
73                if buf.len() + take > max {
74                    reader.consume(take);
75                    return Ok(LineRead::TooLong);
76                }
77                buf.extend_from_slice(available);
78                reader.consume(take);
79            }
80        }
81    }
82}
83
84fn parse_error(msg: String) -> JsonRpcResponse {
85    let mcp_error = MCSError::ParseError(msg);
86    JsonRpcResponse::error(None, mcp_error.error_code(), mcp_error.to_string())
87}
88
/// Protocol version offered when the client requests none or an unsupported one.
const LATEST_PROTOCOL_VERSION: &str = "2025-11-25";

/// Protocol versions that `initialize` echoes back unchanged.
const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &[
    LATEST_PROTOCOL_VERSION,
    "2025-06-18",
    "2025-03-26",
    "2024-11-05",
];

/// Usage hint returned to clients in the `instructions` field of `initialize`.
const VECTOR_SERVER_INSTRUCTIONS: &str = concat!(
    "Knowledge-graph memory MCP server with vector search. ",
    "Entity names are unique and case-sensitive. ",
    "Use `create_entities`/`create_relations` to build the graph, ",
    "and `vector_upsert_embedding` to add vector embeddings. ",
    "Search semantically with `vector_search_entities` ",
    "or combine text + vector with `hybrid_search`. ",
    "Tool failures are returned with `isError: true` rather than as protocol errors.",
);
98
99fn handle_initialize(req: &JsonRpcRequest) -> Value {
100    let protocol_version = req
101        .params
102        .as_ref()
103        .and_then(|p| p.get("protocolVersion"))
104        .and_then(Value::as_str)
105        .filter(|v| SUPPORTED_PROTOCOL_VERSIONS.contains(v))
106        .unwrap_or(LATEST_PROTOCOL_VERSION);
107
108    json!({
109        "protocolVersion": protocol_version,
110        "capabilities": {
111            "tools": { "listChanged": false }
112        },
113        "serverInfo": {
114            "name": "mcp-memory-vec",
115            "version": env!("CARGO_PKG_VERSION")
116        },
117        "instructions": VECTOR_SERVER_INSTRUCTIONS
118    })
119}
120
121static VECTOR_TOOLS_LIST: std::sync::OnceLock<Value> = std::sync::OnceLock::new();
122
123fn handle_tools_list() -> Value {
124    if let Some(cached) = VECTOR_TOOLS_LIST.get() {
125        return cached.clone();
126    }
127    let base_tools: Vec<Value> = serde_json::from_str(include_str!("../tools.json"))
128        .expect("tools.json is valid JSON");
129    let vec_tools: Vec<Value> = serde_json::from_str(include_str!("../vector_tools.json"))
130        .expect("vector_tools.json is valid JSON");
131    let mut all = base_tools;
132    all.extend(vec_tools);
133    let result = json!({ "tools": all });
134    let _ = VECTOR_TOOLS_LIST.set(result.clone());
135    result
136}
137
138#[inline]
139fn tool_error(message: &str) -> Value {
140    json!({
141        "content": [{ "type": "text", "text": message }],
142        "isError": true
143    })
144}
145
146fn handle_tools_call(req: &JsonRpcRequest, kg: &GraphHandle, vs: &VectorStore) -> Result<HandlerResult> {
147    let tool_name = req
148        .params
149        .as_ref()
150        .and_then(|p| p.get("name").and_then(|v| v.as_str()))
151        .ok_or_else(|| MCSError::InvalidParams("Missing 'name' parameter".into()))?;
152
153    let tool_args = req.params.as_ref().and_then(|p| p.get("arguments"));
154
155    if !tools::tool_exists(tool_name) && !is_vector_tool_name(tool_name) {
156        return Err(MCSError::MethodNotFound(tool_name.to_string()));
157    }
158
159    let result = match tool_name {
160        // Vector tools
161        "vector_upsert_embedding" => {
162            vector_actions::handle_vector_upsert_embedding(vs, kg, tool_args)
163                .map(HandlerResult::Value)
164        }
165        "vector_search_entities" => {
166            vector_actions::handle_vector_search_entities(vs, kg, tool_args)
167                .map(HandlerResult::RawResult)
168        }
169        "vector_delete_embedding" => {
170            vector_actions::handle_vector_delete_embedding(vs, kg, tool_args)
171                .map(HandlerResult::Value)
172        }
173        "hybrid_search" => {
174            vector_actions::handle_hybrid_search(vs, kg, tool_args)
175                .map(HandlerResult::RawResult)
176        }
177        "vector_refresh_graph_cache" => {
178            vector_actions::handle_refresh_graph_cache(vs, kg, tool_args)
179                .map(HandlerResult::Value)
180        }
181        "vector_store_stats" => {
182            vector_actions::handle_vector_store_stats(vs, kg, tool_args)
183                .map(HandlerResult::Value)
184        }
185        // KG tools — delegate to existing handlers
186        "read_graph" | "search_nodes" => {
187            let kg_only = crate::server::dispatch_line(
188                &serialize_request(req),
189                kg,
190            );
191            match kg_only {
192                Some(resp) => {
193                    let v: Value = serde_json::from_str(&resp)
194                        .map_err(MCSError::JsonError)?;
195                    if let Some(result_val) = v.get("result") {
196                        Ok(HandlerResult::Value(result_val.clone()))
197                    } else {
198                        Err(MCSError::MemoryError("KG dispatch failed".into()))
199                    }
200                }
201                None => Ok(HandlerResult::Value(Value::Null)),
202            }
203        }
204        _ => {
205            // Delegate to existing KG handlers by calling dispatch_line
206            let kg_only = crate::server::dispatch_line(
207                &serialize_request(req),
208                kg,
209            );
210            match kg_only {
211                Some(resp) => {
212                    let v: Value = serde_json::from_str(&resp)
213                        .map_err(MCSError::JsonError)?;
214                    if let Some(result_val) = v.get("result") {
215                        Ok(HandlerResult::Value(result_val.clone()))
216                    } else {
217                        Err(MCSError::MemoryError("KG dispatch failed".into()))
218                    }
219                }
220                None => Ok(HandlerResult::Value(Value::Null)),
221            }
222        }
223    };
224
225    Ok(result.unwrap_or_else(|e| {
226        error!("Tool '{tool_name}' error: {e}");
227        HandlerResult::Value(tool_error(&e.to_string()))
228    }))
229}
230
/// Returns `true` when `name` is exactly one of the vector tools served by
/// this module (the comparison is case-sensitive).
fn is_vector_tool_name(name: &str) -> bool {
    const VECTOR_TOOL_NAMES: [&str; 6] = [
        "vector_upsert_embedding",
        "vector_search_entities",
        "vector_delete_embedding",
        "hybrid_search",
        "vector_refresh_graph_cache",
        "vector_store_stats",
    ];
    VECTOR_TOOL_NAMES.contains(&name)
}
242
243fn serialize_request(req: &JsonRpcRequest) -> String {
244    let params = req.params.as_ref().map(|p| {
245        let name = p.get("name").cloned().unwrap_or(Value::Null);
246        let args = p.get("arguments").cloned();
247        json!({
248            "name": name,
249            "arguments": args
250        })
251    });
252    let wrapped = JsonRpcRequest {
253        jsonrpc: req.jsonrpc.clone(),
254        id: req.id.clone(),
255        method: req.method.clone(),
256        params,
257    };
258    serde_json::to_string(&wrapped).unwrap_or_default()
259}
260
261fn process_request_value(value: Value, kg: &GraphHandle, vs: &VectorStore) -> Option<Value> {
262    let req: JsonRpcRequest = match serde_json::from_value(value) {
263        Ok(r) => r,
264        Err(e) => return Some(to_value(parse_error(e.to_string()))),
265    };
266    req.id.as_ref()?;
267
268    match process_request(&req, kg, vs) {
269        Ok(HandlerResult::Value(result)) => {
270            Some(to_value(JsonRpcResponse::success(req.id, result)))
271        }
272        Ok(HandlerResult::RawResult(result_json)) => {
273            let result_val: Value = serde_json::from_str(&result_json).unwrap_or(Value::Null);
274            Some(to_value(JsonRpcResponse::success(req.id, result_val)))
275        }
276        Err(e) => Some(to_value(JsonRpcResponse::error(
277            req.id,
278            e.error_code(),
279            e.to_string(),
280        ))),
281    }
282}
283
284#[inline]
285fn to_value(resp: JsonRpcResponse) -> Value {
286    serde_json::to_value(resp).expect("JsonRpcResponse always serializes")
287}
288
289fn process_request(req: &JsonRpcRequest, kg: &GraphHandle, vs: &VectorStore) -> Result<HandlerResult> {
290    match req.method.as_str() {
291        "initialize" => Ok(HandlerResult::Value(handle_initialize(req))),
292        "tools/list" => Ok(HandlerResult::Value(handle_tools_list())),
293        "tools/call" => handle_tools_call(req, kg, vs),
294        "ping" => Ok(HandlerResult::Value(Value::Null)),
295        method if method.starts_with("notifications/") => {
296            tracing::trace!("Received notification: {method}");
297            Ok(HandlerResult::Value(Value::Null))
298        }
299        _ => Err(MCSError::MethodNotFound(req.method.clone())),
300    }
301}
302
303pub fn dispatch_line(line: &str, kg: &GraphHandle, vs: &VectorStore) -> Option<String> {
304    let trimmed = line.trim();
305    if trimmed.is_empty() {
306        return Some(serde_json::to_string(&parse_error("Empty request".into())).unwrap());
307    }
308    let raw: Value = match serde_json::from_str(trimmed) {
309        Ok(v) => v,
310        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
311    };
312    let req: JsonRpcRequest = match serde_json::from_value(raw) {
313        Ok(r) => r,
314        Err(e) => return Some(serde_json::to_string(&parse_error(e.to_string())).unwrap()),
315    };
316    req.id.as_ref()?;
317
318    match process_request(&req, kg, vs) {
319        Ok(HandlerResult::Value(result)) => {
320            let resp = JsonRpcResponse::success(req.id, result);
321            Some(serde_json::to_string(&resp).unwrap())
322        }
323        Ok(HandlerResult::RawResult(result_json)) => {
324            let id_json = serde_json::to_string(&req.id).unwrap();
325            let mut out = String::with_capacity(64 + id_json.len() + result_json.len());
326            out.push_str(r#"{"jsonrpc":"2.0","id":"#);
327            out.push_str(&id_json);
328            out.push_str(",\"result\":");
329            out.push_str(&result_json);
330            out.push('}');
331            Some(out)
332        }
333        Err(e) => {
334            let resp = JsonRpcResponse::error(req.id, e.error_code(), e.to_string());
335            Some(serde_json::to_string(&resp).unwrap())
336        }
337    }
338}
339
/// Newline-terminated JSON-RPC error (code -32001) written to a TCP client
/// whose first line did not carry the expected token.
const AUTH_REQUIRED_LINE: &str = concat!(
    r#"{"jsonrpc":"2.0","error":{"code":-32001,"#,
    r#""message":"Authentication required: send the bearer token as the first line"},"#,
    r#""id":null}"#,
    "\n",
);
342
343async fn authenticate_line_conn<R>(reader: &mut R, expected: &str) -> Result<bool>
344where
345    R: AsyncBufReadExt + Unpin,
346{
347    let mut line = String::new();
348    match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES)
349        .await
350        .map_err(MCSError::IoError)?
351    {
352        LineRead::Line => Ok(crate::server::token_matches(&line, expected)),
353        _ => Ok(false),
354    }
355}
356
357async fn serve_line_conn<R, W>(
358    reader: &mut R,
359    writer: &mut W,
360    kg: Arc<GraphHandle>,
361    vs: Arc<VectorStore>,
362) -> Result<()>
363where
364    R: AsyncBufReadExt + Unpin,
365    W: AsyncWriteExt + Unpin,
366{
367    let mut line = String::with_capacity(1024);
368    let mut out = Vec::with_capacity(BUFFER_CAPACITY);
369
370    loop {
371        match read_line_capped(reader, &mut line, MAX_REQUEST_BYTES).await {
372            Ok(LineRead::Eof) => break,
373            Ok(LineRead::Line) => {
374                let line_copy = line.clone();
375                let kg_clone = Arc::clone(&kg);
376                let vs_clone = Arc::clone(&vs);
377                let resp = tokio::task::spawn_blocking(move || {
378                    dispatch_line(&line_copy, &kg_clone, &vs_clone)
379                })
380                .await
381                .map_err(|join_err| {
382                    error!("dispatch task panicked: {join_err}");
383                    MCSError::IoError(std::io::Error::other("dispatch task panicked"))
384                })?;
385
386                if let Some(resp) = resp {
387                    out.clear();
388                    out.extend_from_slice(resp.as_bytes());
389                    out.extend_from_slice(NEWLINE);
390                    writer.write_all(&out).await.map_err(MCSError::IoError)?;
391                    writer.flush().await.map_err(MCSError::IoError)?;
392                }
393            }
394            Ok(LineRead::TooLong) => {
395                let err = MCSError::InvalidParams("Request exceeds maximum size of 16MB".into());
396                let response = JsonRpcResponse::error(None, err.error_code(), err.to_string());
397                out.clear();
398                serde_json::to_writer(&mut out, &response).map_err(MCSError::JsonError)?;
399                out.extend_from_slice(NEWLINE);
400                writer.write_all(&out).await.map_err(MCSError::IoError)?;
401                writer.flush().await.map_err(MCSError::IoError)?;
402                break;
403            }
404            Err(e) => {
405                error!("IO error: {}", e);
406                break;
407            }
408        }
409    }
410    Ok(())
411}
412
413fn spawn_maintenance(kg: Arc<GraphHandle>) {
414    tokio::spawn(async move {
415        let mut interval = tokio::time::interval(Duration::from_secs(300));
416        interval.tick().await;
417        loop {
418            interval.tick().await;
419            let kg = kg.clone();
420            tokio::task::spawn_blocking(move || {
421                if let Err(e) = kg.run_maintenance() {
422                    tracing::warn!("Maintenance error: {e}");
423                }
424            })
425            .await
426            .ok();
427        }
428    });
429}
430
431pub fn dispatch_http_body(
432    body: &str,
433    kg: &GraphHandle,
434    vs: &VectorStore,
435) -> std::result::Result<Option<Value>, String> {
436    let value: Value = serde_json::from_str(body.trim()).map_err(|e| e.to_string())?;
437    match value {
438        Value::Array(items) => {
439            let responses: Vec<Value> = items
440                .into_iter()
441                .filter_map(|v| process_request_value(v, kg, vs))
442                .collect();
443            Ok((!responses.is_empty()).then_some(Value::Array(responses)))
444        }
445        other => Ok(process_request_value(other, kg, vs)),
446    }
447}
448
449pub struct VectorServer {
450    config: Arc<Config>,
451    kg: Arc<GraphHandle>,
452    vs: Arc<VectorStore>,
453}
454
455impl VectorServer {
456    pub fn new(config: Config, vec_config: VectorConfig) -> Result<Self> {
457        let path = Path::new(&config.memory_file_path);
458        let lru_cache = std::num::NonZeroUsize::new(config.lru_cache_size).unwrap_or_else(|| {
459            std::num::NonZeroUsize::new(10000).expect("10000 > 0")
460        });
461        let kg = GraphHandle::new(
462            path,
463            config.durability,
464            config.mmap_size,
465            lru_cache,
466            config.read_pool_size,
467        )?;
468        let vs = VectorStore::with_config(path, &vec_config)?;
469
470        Ok(Self {
471            config: Arc::new(config),
472            kg: Arc::new(kg),
473            vs: Arc::new(vs),
474        })
475    }
476
477    pub fn graph(&self) -> Arc<GraphHandle> {
478        Arc::clone(&self.kg)
479    }
480
481    pub fn vector_store(&self) -> Arc<VectorStore> {
482        Arc::clone(&self.vs)
483    }
484
485    pub async fn run_stdio(&self) -> Result<()> {
486        spawn_maintenance(self.kg.clone());
487        let stdin = tokio::io::stdin();
488        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stdin);
489        let mut stdout = tokio::io::stdout();
490        serve_line_conn(&mut reader, &mut stdout, Arc::clone(&self.kg), Arc::clone(&self.vs)).await
491    }
492
493    pub async fn run_tcp(&self, addr: &str) -> Result<()> {
494        spawn_maintenance(self.kg.clone());
495        let listener = TcpListener::bind(addr).await.map_err(MCSError::IoError)?;
496        let semaphore = Arc::new(Semaphore::new(MAX_TCP_CONNECTIONS));
497        let auth_token = self.config.auth_token.clone();
498        info!(
499            "Listening for TCP MCP connections on {addr} (max {MAX_TCP_CONNECTIONS}, auth {})",
500            if auth_token.is_some() { "on" } else { "off" }
501        );
502        loop {
503            let permit = Arc::clone(&semaphore).acquire_owned().await;
504            let (socket, peer) = listener.accept().await.map_err(MCSError::IoError)?;
505            let kg = Arc::clone(&self.kg);
506            let vs = Arc::clone(&self.vs);
507            let auth_token = auth_token.clone();
508            tokio::spawn(async move {
509                let _permit = permit;
510                let (read_half, mut write_half) = socket.into_split();
511                let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, read_half);
512                if let Some(ref expected) = auth_token {
513                    match authenticate_line_conn(&mut reader, expected).await {
514                        Ok(true) => {}
515                        Ok(false) => {
516                            let _ = write_half.write_all(AUTH_REQUIRED_LINE.as_bytes()).await;
517                            let _ = write_half.flush().await;
518                            return;
519                        }
520                        Err(e) => {
521                            error!("TCP auth error for {peer}: {e}");
522                            return;
523                        }
524                    }
525                }
526                if let Err(e) = serve_line_conn(&mut reader, &mut write_half, kg, vs).await {
527                    error!("TCP connection {peer} error: {e}");
528                }
529            });
530        }
531    }
532
533    pub async fn run_http(&self, addr: &str) -> Result<()> {
534        spawn_maintenance(self.kg.clone());
535        self.run_http_inner(addr).await
536    }
537
538    async fn run_http_inner(&self, addr: &str) -> Result<()> {
539        use axum::routing::{get, post};
540        use axum::Router;
541
542        let kg = Arc::clone(&self.kg);
543        let vs = Arc::clone(&self.vs);
544        let auth_token = self.config.auth_token.clone();
545
546        let app = Router::new()
547            .route("/mcp", post(handle_http_post))
548            .route("/mcp", get(handle_http_get))
549            .with_state(HttpState { kg, vs, auth_token });
550
551        let listener = tokio::net::TcpListener::bind(addr)
552            .await
553            .map_err(MCSError::IoError)?;
554        info!("MCP Streamable HTTP listening on {addr}");
555
556        if let (Some(cert), Some(key)) = (
557            self.config.tls_cert.clone(),
558            self.config.tls_key.clone(),
559        ) {
560            let tls_config = crate::tls::server_config(&cert, &key)
561                .await
562                .map_err(MCSError::IoError)?;
563            axum_server::bind_rustls(listener.local_addr().unwrap(), tls_config)
564                .serve(app.into_make_service())
565                .await
566                .map_err(|e| MCSError::IoError(std::io::Error::other(e)))?;
567        } else {
568            axum::serve(listener, app)
569                .await
570                .map_err(|e| MCSError::IoError(std::io::Error::other(e)))?;
571        }
572        Ok(())
573    }
574}
575
576#[derive(Clone)]
577struct HttpState {
578    kg: Arc<GraphHandle>,
579    vs: Arc<VectorStore>,
580    auth_token: Option<Arc<str>>,
581}
582
583/// `true` when the request is allowed: either no token is configured, or the
584/// `Authorization` header carries the expected bearer token.
585fn http_authorized(state: &HttpState, headers: &axum::http::HeaderMap) -> bool {
586    match state.auth_token {
587        None => true,
588        Some(ref expected) => headers
589            .get(axum::http::header::AUTHORIZATION)
590            .and_then(|v| v.to_str().ok())
591            .is_some_and(|presented| crate::server::token_matches(presented, expected)),
592    }
593}
594
595async fn handle_http_post(
596    axum::extract::State(state): axum::extract::State<HttpState>,
597    headers: axum::http::HeaderMap,
598    body: String,
599) -> axum::response::Response {
600    use axum::response::sse::Event;
601    use axum::response::{IntoResponse, Json};
602    use axum::http::StatusCode;
603
604    if !http_authorized(&state, &headers) {
605        return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
606    }
607
608    let result = tokio::task::spawn_blocking(move || {
609        dispatch_http_body(&body, &state.kg, &state.vs)
610    })
611    .await;
612
613    let outcome = match result {
614        Ok(inner) => inner,
615        Err(join_err) => {
616            error!("dispatch task panicked: {join_err}");
617            return (StatusCode::INTERNAL_SERVER_ERROR, "internal error").into_response();
618        }
619    };
620
621    match outcome {
622        Ok(None) => StatusCode::ACCEPTED.into_response(),
623        Ok(Some(value)) => {
624            let wants_sse = headers
625                .get(axum::http::header::ACCEPT)
626                .and_then(|v| v.to_str().ok())
627                .is_some_and(|a| a.contains("text/event-stream"));
628            if wants_sse {
629                let json = serde_json::to_string(&value).unwrap();
630                let stream = futures::stream::once(async move {
631                    Ok::<Event, Infallible>(Event::default().data(json))
632                });
633                axum::response::sse::Sse::new(stream).into_response()
634            } else {
635                Json(value).into_response()
636            }
637        }
638        Err(e) => {
639            let resp = json!({
640                "jsonrpc": "2.0",
641                "error": { "code": -32700, "message": format!("Parse error: {e}") },
642                "id": null
643            });
644            (StatusCode::BAD_REQUEST, Json(resp)).into_response()
645        }
646    }
647}
648
649async fn handle_http_get(
650    axum::extract::State(state): axum::extract::State<HttpState>,
651    headers: axum::http::HeaderMap,
652) -> axum::response::Response {
653    use axum::response::sse::{Event, KeepAlive, Sse};
654    use axum::response::IntoResponse;
655    use axum::http::StatusCode;
656
657    if !http_authorized(&state, &headers) {
658        return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
659    }
660
661    let stream = futures::stream::pending::<std::result::Result<Event, Infallible>>();
662    Sse::new(stream)
663        .keep_alive(KeepAlive::default())
664        .into_response()
665}