Skip to main content

mcp_memory/
vector_actions.rs

1use serde_json::{Value, json};
2
3use crate::errors::{MCSError, Result};
4use crate::kg::{GraphHandle, push_json_str};
5use crate::vector_store::{EntityId, VectorStore, with_scratch};
6
/// Rows produced by hybrid search, best first, one tuple per hit.
type HybridResult = Vec<(
    String, // entity name
    String, // entity type name
    f64,    // fused total score (includes any graph bonus)
    f64,    // text-ranking RRF contribution
    f64,    // vector-ranking RRF contribution
)>;
8
9use rusqlite::params;
10
/// Longest embedding (number of components) accepted from a caller.
const MAX_EMBEDDING_DIMS: usize = 4_096;
/// Number of results returned when `topK` is absent or `null`.
const DEFAULT_TOP_K: usize = 10;
/// Upper bound applied to any requested `topK`.
const MAX_TOP_K: usize = 100;
/// Longest entity name accepted, measured in UTF-8 bytes.
const MAX_NAME_BYTES: usize = 1_024;
15
16fn validate_name(name: &str) -> Result<()> {
17    if name.is_empty() {
18        return Err(MCSError::InvalidParams("Name must not be empty".into()));
19    }
20    if name.len() > MAX_NAME_BYTES {
21        return Err(MCSError::InvalidParams(format!(
22            "Name too long (max {MAX_NAME_BYTES} bytes)"
23        )));
24    }
25    Ok(())
26}
27
28fn parse_embedding(val: &Value) -> Result<Vec<f64>> {
29    let arr = val
30        .as_array()
31        .ok_or_else(|| MCSError::InvalidParams("'embedding' must be an array of numbers".into()))?;
32    if arr.is_empty() {
33        return Err(MCSError::InvalidParams("Embedding must not be empty".into()));
34    }
35    if arr.len() > MAX_EMBEDDING_DIMS {
36        return Err(MCSError::InvalidParams(format!(
37            "Embedding too large (max {MAX_EMBEDDING_DIMS} dimensions)"
38        )));
39    }
40    let emb: Vec<f64> = arr
41        .iter()
42        .map(|v| {
43            v.as_f64()
44                .ok_or_else(|| MCSError::InvalidParams("Embedding values must be numbers".into()))
45        })
46        .collect::<Result<_>>()?;
47    Ok(emb)
48}
49
50fn opt_usize(params: &Value, key: &str, default: usize) -> Result<usize> {
51    match params.get(key) {
52        None | Some(Value::Null) => Ok(default),
53        Some(v) => v.as_u64().map(|n| n as usize).ok_or_else(|| {
54            MCSError::InvalidParams(format!("'{key}' must be a non-negative integer"))
55        }),
56    }
57}
58
59fn opt_f64(params: &Value, key: &str, default: f64) -> Result<f64> {
60    match params.get(key) {
61        None | Some(Value::Null) => Ok(default),
62        Some(v) => v.as_f64().ok_or_else(|| {
63            MCSError::InvalidParams(format!("'{key}' must be a number"))
64        }),
65    }
66}
67
68fn text_content(text: &str) -> Value {
69    json!({
70        "content": [{
71            "type": "text",
72            "text": text
73        }]
74    })
75}
76
77fn build_content_response(inner_json: &str) -> String {
78    let mut out = String::with_capacity(64 + inner_json.len() + (inner_json.len() / 8));
79    out.push_str(r#"{"content":[{"type":"text","text":"#);
80    push_json_str(&mut out, inner_json);
81    out.push_str(r#"}]}"#);
82    out
83}
84
85pub fn handle_vector_upsert_embedding(
86    vs: &VectorStore,
87    _kg: &GraphHandle,
88    args: Option<&Value>,
89) -> Result<Value> {
90    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
91
92    let entity_name = params
93        .get("entityName")
94        .and_then(|v| v.as_str())
95        .ok_or_else(|| MCSError::InvalidParams("Missing 'entityName' parameter".into()))?;
96    validate_name(entity_name)?;
97
98    let embedding = parse_embedding(
99        params
100            .get("embedding")
101            .ok_or_else(|| MCSError::InvalidParams("Missing 'embedding' parameter".into()))?,
102    )?;
103
104    let model = params
105        .get("model")
106        .and_then(|v| v.as_str())
107        .unwrap_or("");
108
109    with_scratch(|buf| {
110        buf.reserve(embedding.len());
111        buf.extend(embedding.iter().map(|&v| v as f32));
112        vs.upsert_embedding(entity_name, buf, model)
113    })?;
114
115    let text = serde_json::to_string(&json!({
116        "entityName": entity_name,
117        "dims": vs.dims(),
118        "model": model,
119    }))
120    .map_err(MCSError::JsonError)?;
121
122    Ok(text_content(&text))
123}
124
125pub fn handle_vector_search_entities(
126    vs: &VectorStore,
127    _kg: &GraphHandle,
128    args: Option<&Value>,
129) -> Result<String> {
130    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
131
132    let embedding = parse_embedding(
133        params
134            .get("embedding")
135            .ok_or_else(|| MCSError::InvalidParams("Missing 'embedding' parameter".into()))?,
136    )?;
137
138    let top_k = opt_usize(params, "topK", DEFAULT_TOP_K)?
139        .clamp(1, MAX_TOP_K);
140
141    let entity_type = params
142        .get("entityType")
143        .and_then(|v| v.as_str())
144        .filter(|s| !s.is_empty());
145
146    let json = with_scratch(|buf| {
147        buf.reserve(embedding.len());
148        buf.extend(embedding.iter().map(|&v| v as f32));
149        vs.search_entities_json(buf, top_k, entity_type)
150    })?;
151
152    Ok(build_content_response(&json))
153}
154
155pub fn handle_vector_delete_embedding(
156    vs: &VectorStore,
157    _kg: &GraphHandle,
158    args: Option<&Value>,
159) -> Result<Value> {
160    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
161
162    let entity_name = params
163        .get("entityName")
164        .and_then(|v| v.as_str())
165        .ok_or_else(|| MCSError::InvalidParams("Missing 'entityName' parameter".into()))?;
166    validate_name(entity_name)?;
167
168    let deleted = vs.delete_embedding(entity_name)?;
169
170    let text = serde_json::to_string(&json!({
171        "deleted": deleted,
172        "entityName": entity_name,
173    }))
174    .map_err(MCSError::JsonError)?;
175
176    Ok(text_content(&text))
177}
178
179pub fn handle_hybrid_search(
180    vs: &VectorStore,
181    kg: &GraphHandle,
182    args: Option<&Value>,
183) -> Result<String> {
184    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
185
186    let query_text = params
187        .get("queryText")
188        .and_then(|v| v.as_str())
189        .ok_or_else(|| MCSError::InvalidParams("Missing 'queryText' parameter".into()))?;
190
191    let query_embedding = parse_embedding(
192        params
193            .get("queryEmbedding")
194            .ok_or_else(|| MCSError::InvalidParams("Missing 'queryEmbedding' parameter".into()))?,
195    )?;
196
197    let text_weight = opt_f64(params, "textWeight", 0.5)?;
198    let vec_weight = opt_f64(params, "vecWeight", 0.5)?;
199    let top_k = opt_usize(params, "topK", DEFAULT_TOP_K)?
200        .clamp(1, MAX_TOP_K);
201
202    let results = with_scratch(|buf| {
203        buf.reserve(query_embedding.len());
204        buf.extend(query_embedding.iter().map(|&v| v as f32));
205        perform_hybrid_search(vs, kg, query_text, buf, text_weight, vec_weight, top_k)
206    })?;
207
208    let mut out = String::with_capacity(128 + results.len() * 80);
209    out.push_str(r#"{"results":["#);
210    for (i, (name, etype, score, txt_score, vec_score)) in results.iter().enumerate() {
211        if i > 0 {
212            out.push(',');
213        }
214        out.push_str(r#"{"name":"#);
215        push_json_str(&mut out, name);
216        out.push_str(r#","entityType":"#);
217        push_json_str(&mut out, etype);
218        use std::fmt::Write;
219        write!(
220            out,
221            r#","score":{:.6},"textScore":{:.6},"vecScore":{:.6}}}"#,
222            score, txt_score, vec_score
223        )
224        .unwrap();
225    }
226    out.push_str(r#"],"count":"#);
227    out.push_str(&results.len().to_string());
228    out.push('}');
229
230    Ok(build_content_response(&out))
231}
232
233fn perform_hybrid_search(
234    vs: &VectorStore,
235    kg: &GraphHandle,
236    query_text: &str,
237    query_emb: &[f32],
238    text_weight: f64,
239    vec_weight: f64,
240    top_k: usize,
241) -> Result<HybridResult> {
242    let fetch_k = top_k * 3;
243    let rrf_constant = 60.0;
244
245    let vec_matches = vs.search_embeddings(query_emb, fetch_k)?;
246
247    let kg_results = kg.search_nodes_filtered(query_text, None, 0, fetch_k);
248    let mut text_matches: Vec<EntityIdAndName> = Vec::with_capacity(kg_results.len());
249    for entity in &kg_results {
250        if let Ok(Some(_)) = vs.get_entity_type(
251            vs.name_to_id.get(&entity.name).map(|r| *r.value()).unwrap_or(-1),
252        ) {
253            let id = vs.name_to_id.get(&entity.name).map(|r| *r.value());
254            text_matches.push(EntityIdAndName {
255                id: id.unwrap_or(-1),
256            });
257        } else {
258            let conn = vs.db.lock();
259            let h = crate::kg::name_hash(&entity.name);
260            let id: Option<i64> = conn
261                .query_row(
262                    "SELECT id FROM entity WHERE name_hash = ?1 AND name = ?2 AND flags = 0",
263                    params![h, entity.name],
264                    |row| row.get(0),
265                )
266                .ok();
267            text_matches.push(EntityIdAndName {
268                id: id.unwrap_or(-1),
269            });
270        }
271    }
272
273    let mut score_map: std::collections::HashMap<EntityId, AggScore> =
274        std::collections::HashMap::with_capacity(vec_matches.len() + text_matches.len());
275
276    for (rank, (id, _dist)) in vec_matches.iter().enumerate() {
277        let entry = score_map.entry(*id).or_insert_with(|| AggScore {
278            id: *id,
279            total: 0.0,
280            vec_score: 0.0,
281            text_score: 0.0,
282        });
283        let rrf = vec_weight * (1.0 / (rrf_constant + rank as f64));
284        entry.total += rrf;
285        entry.vec_score += rrf;
286    }
287
288    for (rank, tm) in text_matches.iter().enumerate() {
289        let entry = score_map.entry(tm.id).or_insert_with(|| AggScore {
290            id: tm.id,
291            total: 0.0,
292            vec_score: 0.0,
293            text_score: 0.0,
294        });
295        let rrf = text_weight * (1.0 / (rrf_constant + rank as f64));
296        entry.total += rrf;
297        entry.text_score += rrf;
298    }
299
300    let mut scored: Vec<AggScore> = score_map.into_values().collect();
301    scored.sort_unstable_by(|a, b| b.total.partial_cmp(&a.total).unwrap_or(std::cmp::Ordering::Equal));
302
303    if vs.graph_node_count() > 0 {
304        let g = vs.graph.read();
305        for entry in &mut scored {
306            if let Some(nx) = vs.node_map.get(&entry.id) {
307                let deg = g.neighbors(*nx).count() as f64;
308                if deg > 0.0 {
309                    let boost = 0.1 * (deg / (deg + 5.0));
310                    entry.total += boost;
311                }
312            }
313        }
314        scored.sort_unstable_by(|a, b| b.total.partial_cmp(&a.total).unwrap_or(std::cmp::Ordering::Equal));
315    }
316
317    let conn = vs.db.lock();
318    let mut results = Vec::with_capacity(top_k.min(scored.len()));
319    for entry in scored.iter().take(top_k) {
320        let name = vs
321            .id_to_name
322            .get(&entry.id)
323            .map(|r| r.value().clone())
324            .or_else(|| {
325                conn.query_row(
326                    "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
327                    params![entry.id],
328                    |row| row.get::<_, String>(0),
329                )
330                .ok()
331            })
332            .unwrap_or_default();
333
334        let etype: String = conn
335            .query_row(
336                "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
337                params![entry.id],
338                |row| row.get(0),
339            )
340            .unwrap_or_default();
341
342        results.push((name, etype, entry.total, entry.text_score, entry.vec_score));
343    }
344
345    Ok(results)
346}
347
348struct EntityIdAndName {
349    id: EntityId,
350}
351
352struct AggScore {
353    id: EntityId,
354    total: f64,
355    vec_score: f64,
356    text_score: f64,
357}
358
359pub fn handle_refresh_graph_cache(
360    vs: &VectorStore,
361    _kg: &GraphHandle,
362    _args: Option<&Value>,
363) -> Result<Value> {
364    vs.rebuild_graph_cache()?;
365    let text = serde_json::to_string(&json!({
366        "nodes": vs.graph_node_count(),
367        "edges": vs.graph_edge_count(),
368    }))
369    .map_err(MCSError::JsonError)?;
370    Ok(text_content(&text))
371}
372
373pub fn handle_vector_store_stats(
374    vs: &VectorStore,
375    _kg: &GraphHandle,
376    _args: Option<&Value>,
377) -> Result<Value> {
378    let text = serde_json::to_string(&json!({
379        "embeddingCount": vs.count(),
380        "dims": vs.dims(),
381        "petgraphNodes": vs.graph_node_count(),
382        "petgraphEdges": vs.graph_edge_count(),
383    }))
384    .map_err(MCSError::JsonError)?;
385    Ok(text_content(&text))
386}