Skip to main content

mcp_memory/
vector_actions.rs

1use serde_json::{Value, json};
2
3use crate::errors::{MCSError, Result};
4use crate::kg::{GraphHandle, push_json_str};
5use crate::vector_store::{EntityId, VectorStore, with_scratch};
6use rustc_hash::FxHashMap;
7
8type HybridResult = Vec<(String, String, f64, f64, f64)>;
9
10use rusqlite::params;
11
/// Upper bound on the number of components accepted in one embedding.
const MAX_EMBEDDING_DIMS: usize = 4_096;
/// Largest `topK` a search handler honors; larger requests are clamped down to it.
const MAX_TOP_K: usize = 100;
/// `topK` used when the caller does not supply one.
const DEFAULT_TOP_K: usize = 10;
/// Maximum entity-name length, measured in bytes.
const MAX_NAME_BYTES: usize = 1_024;
/// Cap on items in a single `vector_batch_upsert` call.
const MAX_BATCH_ITEMS: usize = 1_024;
18
19fn validate_name(name: &str) -> Result<()> {
20    if name.is_empty() {
21        return Err(MCSError::InvalidParams("Name must not be empty".into()));
22    }
23    if name.len() > MAX_NAME_BYTES {
24        return Err(MCSError::InvalidParams(format!(
25            "Name too long (max {MAX_NAME_BYTES} bytes)"
26        )));
27    }
28    Ok(())
29}
30
31fn parse_embedding(val: &Value) -> Result<Vec<f64>> {
32    let arr = val
33        .as_array()
34        .ok_or_else(|| MCSError::InvalidParams("'embedding' must be an array of numbers".into()))?;
35    if arr.is_empty() {
36        return Err(MCSError::InvalidParams("Embedding must not be empty".into()));
37    }
38    if arr.len() > MAX_EMBEDDING_DIMS {
39        return Err(MCSError::InvalidParams(format!(
40            "Embedding too large (max {MAX_EMBEDDING_DIMS} dimensions)"
41        )));
42    }
43    let emb: Vec<f64> = arr
44        .iter()
45        .map(|v| {
46            v.as_f64()
47                .ok_or_else(|| MCSError::InvalidParams("Embedding values must be numbers".into()))
48        })
49        .collect::<Result<_>>()?;
50    Ok(emb)
51}
52
53fn opt_usize(params: &Value, key: &str, default: usize) -> Result<usize> {
54    match params.get(key) {
55        None | Some(Value::Null) => Ok(default),
56        Some(v) => v.as_u64().map(|n| n as usize).ok_or_else(|| {
57            MCSError::InvalidParams(format!("'{key}' must be a non-negative integer"))
58        }),
59    }
60}
61
62fn opt_f64(params: &Value, key: &str, default: f64) -> Result<f64> {
63    match params.get(key) {
64        None | Some(Value::Null) => Ok(default),
65        Some(v) => v.as_f64().ok_or_else(|| {
66            MCSError::InvalidParams(format!("'{key}' must be a number"))
67        }),
68    }
69}
70
71fn text_content(text: &str) -> Value {
72    json!({
73        "content": [{
74            "type": "text",
75            "text": text
76        }]
77    })
78}
79
80fn build_content_response(inner_json: &str) -> String {
81    let mut out = String::with_capacity(64 + inner_json.len() + (inner_json.len() / 8));
82    out.push_str(r#"{"content":[{"type":"text","text":"#);
83    push_json_str(&mut out, inner_json);
84    out.push_str(r#"}]}"#);
85    out
86}
87
88pub fn handle_vector_upsert_embedding(
89    vs: &VectorStore,
90    _kg: &GraphHandle,
91    args: Option<&Value>,
92) -> Result<Value> {
93    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
94
95    let entity_name = params
96        .get("entityName")
97        .and_then(|v| v.as_str())
98        .ok_or_else(|| MCSError::InvalidParams("Missing 'entityName' parameter".into()))?;
99    validate_name(entity_name)?;
100
101    let embedding = parse_embedding(
102        params
103            .get("embedding")
104            .ok_or_else(|| MCSError::InvalidParams("Missing 'embedding' parameter".into()))?,
105    )?;
106
107    let model = params
108        .get("model")
109        .and_then(|v| v.as_str())
110        .unwrap_or("");
111
112    with_scratch(|buf| {
113        buf.reserve(embedding.len());
114        buf.extend(embedding.iter().map(|&v| v as f32));
115        vs.upsert_embedding(entity_name, buf, model)
116    })?;
117
118    let text = serde_json::to_string(&json!({
119        "entityName": entity_name,
120        "dims": vs.dims(),
121        "model": model,
122    }))
123    .map_err(MCSError::JsonError)?;
124
125    Ok(text_content(&text))
126}
127
128pub fn handle_vector_search_entities(
129    vs: &VectorStore,
130    _kg: &GraphHandle,
131    args: Option<&Value>,
132) -> Result<String> {
133    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
134
135    let embedding = parse_embedding(
136        params
137            .get("embedding")
138            .ok_or_else(|| MCSError::InvalidParams("Missing 'embedding' parameter".into()))?,
139    )?;
140
141    let top_k = opt_usize(params, "topK", DEFAULT_TOP_K)?
142        .clamp(1, MAX_TOP_K);
143
144    let entity_type = params
145        .get("entityType")
146        .and_then(|v| v.as_str())
147        .filter(|s| !s.is_empty());
148
149    let json = with_scratch(|buf| {
150        buf.reserve(embedding.len());
151        buf.extend(embedding.iter().map(|&v| v as f32));
152        vs.search_entities_json(buf, top_k, entity_type)
153    })?;
154
155    Ok(build_content_response(&json))
156}
157
158pub fn handle_vector_delete_embedding(
159    vs: &VectorStore,
160    _kg: &GraphHandle,
161    args: Option<&Value>,
162) -> Result<Value> {
163    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
164
165    let entity_name = params
166        .get("entityName")
167        .and_then(|v| v.as_str())
168        .ok_or_else(|| MCSError::InvalidParams("Missing 'entityName' parameter".into()))?;
169    validate_name(entity_name)?;
170
171    let deleted = vs.delete_embedding(entity_name)?;
172
173    let text = serde_json::to_string(&json!({
174        "deleted": deleted,
175        "entityName": entity_name,
176    }))
177    .map_err(MCSError::JsonError)?;
178
179    Ok(text_content(&text))
180}
181
182pub fn handle_hybrid_search(
183    vs: &VectorStore,
184    kg: &GraphHandle,
185    args: Option<&Value>,
186) -> Result<String> {
187    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
188
189    let query_text = params
190        .get("queryText")
191        .and_then(|v| v.as_str())
192        .ok_or_else(|| MCSError::InvalidParams("Missing 'queryText' parameter".into()))?;
193
194    let query_embedding = parse_embedding(
195        params
196            .get("queryEmbedding")
197            .ok_or_else(|| MCSError::InvalidParams("Missing 'queryEmbedding' parameter".into()))?,
198    )?;
199
200    let text_weight = opt_f64(params, "textWeight", 0.5)?;
201    let vec_weight = opt_f64(params, "vecWeight", 0.5)?;
202    let top_k = opt_usize(params, "topK", DEFAULT_TOP_K)?
203        .clamp(1, MAX_TOP_K);
204
205    let results = with_scratch(|buf| {
206        buf.reserve(query_embedding.len());
207        buf.extend(query_embedding.iter().map(|&v| v as f32));
208        perform_hybrid_search(vs, kg, query_text, buf, text_weight, vec_weight, top_k)
209    })?;
210
211    let mut out = String::with_capacity(128 + results.len() * 80);
212    out.push_str(r#"{"results":["#);
213    for (i, (name, etype, score, txt_score, vec_score)) in results.iter().enumerate() {
214        if i > 0 {
215            out.push(',');
216        }
217        out.push_str(r#"{"name":"#);
218        push_json_str(&mut out, name);
219        out.push_str(r#","entityType":"#);
220        push_json_str(&mut out, etype);
221        use std::fmt::Write;
222        write!(
223            out,
224            r#","score":{:.6},"textScore":{:.6},"vecScore":{:.6}}}"#,
225            score, txt_score, vec_score
226        )
227        .unwrap();
228    }
229    out.push_str(r#"],"count":"#);
230    out.push_str(&results.len().to_string());
231    out.push('}');
232
233    Ok(build_content_response(&out))
234}
235
236fn perform_hybrid_search(
237    vs: &VectorStore,
238    kg: &GraphHandle,
239    query_text: &str,
240    query_emb: &[f32],
241    text_weight: f64,
242    vec_weight: f64,
243    top_k: usize,
244) -> Result<HybridResult> {
245    let fetch_k = top_k * 3;
246    let rrf_constant = 60.0;
247
248    let vec_matches = vs.search_embeddings(query_emb, fetch_k)?;
249
250    let kg_results = kg.search_nodes_filtered(query_text, None, 0, fetch_k);
251    let mut text_matches: Vec<EntityIdAndName> = Vec::with_capacity(kg_results.len());
252    for entity in &kg_results {
253        if let Ok(Some(_)) = vs.get_entity_type(
254            vs.name_to_id.get(&entity.name).map(|r| *r.value()).unwrap_or(-1),
255        ) {
256            let id = vs.name_to_id.get(&entity.name).map(|r| *r.value());
257            text_matches.push(EntityIdAndName {
258                id: id.unwrap_or(-1),
259            });
260        } else {
261            let conn = vs.db.lock();
262            let h = crate::kg::name_hash(&entity.name);
263            let id: Option<i64> = conn
264                .query_row(
265                    "SELECT id FROM entity WHERE name_hash = ?1 AND name = ?2 AND flags = 0",
266                    params![h, entity.name],
267                    |row| row.get(0),
268                )
269                .ok();
270            text_matches.push(EntityIdAndName {
271                id: id.unwrap_or(-1),
272            });
273        }
274    }
275
276    let mut score_map: FxHashMap<EntityId, AggScore> = FxHashMap::with_capacity_and_hasher(
277        vec_matches.len() + text_matches.len(),
278        rustc_hash::FxBuildHasher,
279    );
280
281    for (rank, (id, _dist)) in vec_matches.iter().enumerate() {
282        let entry = score_map.entry(*id).or_insert_with(|| AggScore {
283            id: *id,
284            total: 0.0,
285            vec_score: 0.0,
286            text_score: 0.0,
287        });
288        let rrf = vec_weight * (1.0 / (rrf_constant + rank as f64));
289        entry.total += rrf;
290        entry.vec_score += rrf;
291    }
292
293    for (rank, tm) in text_matches.iter().enumerate() {
294        let entry = score_map.entry(tm.id).or_insert_with(|| AggScore {
295            id: tm.id,
296            total: 0.0,
297            vec_score: 0.0,
298            text_score: 0.0,
299        });
300        let rrf = text_weight * (1.0 / (rrf_constant + rank as f64));
301        entry.total += rrf;
302        entry.text_score += rrf;
303    }
304
305    let mut scored: Vec<AggScore> = score_map.into_values().collect();
306    scored.sort_unstable_by(|a, b| b.total.partial_cmp(&a.total).unwrap_or(std::cmp::Ordering::Equal));
307
308    if vs.graph_node_count() > 0 {
309        let g = vs.graph.read();
310        for entry in &mut scored {
311            if let Some(nx) = vs.node_map.get(&entry.id) {
312                let deg = g.neighbors(*nx).count() as f64;
313                if deg > 0.0 {
314                    let boost = 0.1 * (deg / (deg + 5.0));
315                    entry.total += boost;
316                }
317            }
318        }
319        scored.sort_unstable_by(|a, b| b.total.partial_cmp(&a.total).unwrap_or(std::cmp::Ordering::Equal));
320    }
321
322    let conn = vs.db.lock();
323    let mut results = Vec::with_capacity(top_k.min(scored.len()));
324    for entry in scored.iter().take(top_k) {
325        let name = vs
326            .id_to_name
327            .get(&entry.id)
328            .map(|r| r.value().clone())
329            .or_else(|| {
330                conn.query_row(
331                    "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
332                    params![entry.id],
333                    |row| row.get::<_, String>(0),
334                )
335                .ok()
336            })
337            .unwrap_or_default();
338
339        let etype: String = conn
340            .query_row(
341                "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
342                params![entry.id],
343                |row| row.get(0),
344            )
345            .unwrap_or_default();
346
347        results.push((name, etype, entry.total, entry.text_score, entry.vec_score));
348    }
349
350    Ok(results)
351}
352
353struct EntityIdAndName {
354    id: EntityId,
355}
356
357struct AggScore {
358    id: EntityId,
359    total: f64,
360    vec_score: f64,
361    text_score: f64,
362}
363
364pub fn handle_refresh_graph_cache(
365    vs: &VectorStore,
366    _kg: &GraphHandle,
367    _args: Option<&Value>,
368) -> Result<Value> {
369    vs.rebuild_graph_cache()?;
370    let text = serde_json::to_string(&json!({
371        "nodes": vs.graph_node_count(),
372        "edges": vs.graph_edge_count(),
373    }))
374    .map_err(MCSError::JsonError)?;
375    Ok(text_content(&text))
376}
377
378pub fn handle_vector_store_stats(
379    vs: &VectorStore,
380    _kg: &GraphHandle,
381    _args: Option<&Value>,
382) -> Result<Value> {
383    let (graph_bytes, vectors_bytes) = vs.index_memory_breakdown();
384    let index_kind = match vs.index_kind() {
385        crate::vector_store::IndexKind::Hnsw => "hnsw",
386        crate::vector_store::IndexKind::Ivf => "ivf",
387    };
388    let text = serde_json::to_string(&json!({
389        "embeddingCount": vs.count(),
390        "dims": vs.dims(),
391        "indexKind": index_kind,
392        "petgraphNodes": vs.graph_node_count(),
393        "petgraphEdges": vs.graph_edge_count(),
394        "indexCapacity": vs.index_capacity(),
395        "indexMemoryBytes": vs.index_memory_bytes(),
396        "indexGraphBytes": graph_bytes,
397        "indexVectorsBytes": vectors_bytes,
398    }))
399    .map_err(MCSError::JsonError)?;
400    Ok(text_content(&text))
401}
402
/// Narrows parsed `f64` components into a freshly allocated `Vec<f32>`.
fn to_f32(emb: &[f64]) -> Vec<f32> {
    let mut narrowed = Vec::with_capacity(emb.len());
    for &component in emb {
        narrowed.push(component as f32);
    }
    narrowed
}
407
/// Cosine similarity of two vectors, accumulated in `f64`.
///
/// Only the common prefix is compared when the slices differ in length.
/// Returns `0.0` when the product of the norms is zero (e.g. a zero vector or
/// an empty slice).
#[inline]
fn cosine_sim(a: &[f32], b: &[f32]) -> f64 {
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
421
422/// Render resolved `(name, entityType, score)` rows as the standard results JSON.
423fn build_named_results(rows: &[(String, String, f64)]) -> String {
424    use std::fmt::Write;
425    let mut out = String::with_capacity(64 + rows.len() * 64);
426    out.push_str(r#"{"results":["#);
427    for (i, (name, etype, score)) in rows.iter().enumerate() {
428        if i > 0 {
429            out.push(',');
430        }
431        out.push_str(r#"{"name":"#);
432        push_json_str(&mut out, name);
433        out.push_str(r#","entityType":"#);
434        push_json_str(&mut out, etype);
435        write!(out, r#","score":{score:.6}}}"#).unwrap();
436    }
437    out.push_str(r#"],"count":"#);
438    out.push_str(&rows.len().to_string());
439    out.push('}');
440    out
441}
442
443/// Bulk-ingest embeddings: `{ items: [{entityName, embedding, model?}, ...] }`.
444/// Each item is upserted independently; per-item failures are reported rather
445/// than aborting the batch — the shape RAG ingestion pipelines expect.
446pub fn handle_vector_batch_upsert(
447    vs: &VectorStore,
448    _kg: &GraphHandle,
449    args: Option<&Value>,
450) -> Result<Value> {
451    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
452    let items = params
453        .get("items")
454        .and_then(|v| v.as_array())
455        .ok_or_else(|| MCSError::InvalidParams("'items' must be an array".into()))?;
456    if items.len() > MAX_BATCH_ITEMS {
457        return Err(MCSError::InvalidParams(format!(
458            "Too many items (max {MAX_BATCH_ITEMS})"
459        )));
460    }
461
462    let mut upserted = 0usize;
463    let mut errors: Vec<Value> = Vec::new();
464    for item in items {
465        let name = match item.get("entityName").and_then(|v| v.as_str()) {
466            Some(n) if !n.is_empty() && n.len() <= MAX_NAME_BYTES => n,
467            _ => {
468                errors.push(json!({"entityName": item.get("entityName"), "error": "invalid entityName"}));
469                continue;
470            }
471        };
472        let emb = match item.get("embedding").map(parse_embedding) {
473            Some(Ok(e)) => e,
474            Some(Err(e)) => {
475                errors.push(json!({"entityName": name, "error": e.to_string()}));
476                continue;
477            }
478            None => {
479                errors.push(json!({"entityName": name, "error": "missing embedding"}));
480                continue;
481            }
482        };
483        let model = item.get("model").and_then(|v| v.as_str()).unwrap_or("");
484        let buf = to_f32(&emb);
485        match vs.upsert_embedding(name, &buf, model) {
486            Ok(()) => upserted += 1,
487            Err(e) => errors.push(json!({"entityName": name, "error": e.to_string()})),
488        }
489    }
490
491    let text = serde_json::to_string(&json!({
492        "upserted": upserted,
493        "failed": errors.len(),
494        "errors": errors,
495    }))
496    .map_err(MCSError::JsonError)?;
497    Ok(text_content(&text))
498}
499
500/// Fetch the stored embedding for an entity: `{ entityName }`.
501pub fn handle_vector_get_embedding(
502    vs: &VectorStore,
503    _kg: &GraphHandle,
504    args: Option<&Value>,
505) -> Result<Value> {
506    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
507    let name = params
508        .get("entityName")
509        .and_then(|v| v.as_str())
510        .ok_or_else(|| MCSError::InvalidParams("Missing 'entityName' parameter".into()))?;
511    validate_name(name)?;
512
513    match vs.get_embedding_by_name(name)? {
514        Some((_id, emb, model)) => {
515            let text = serde_json::to_string(&json!({
516                "entityName": name,
517                "dims": emb.len(),
518                "model": model,
519                "embedding": emb,
520            }))
521            .map_err(MCSError::JsonError)?;
522            Ok(text_content(&text))
523        }
524        None => {
525            let text = serde_json::to_string(&json!({
526                "entityName": name,
527                "embedding": Value::Null,
528                "found": false,
529            }))
530            .map_err(MCSError::JsonError)?;
531            Ok(text_content(&text))
532        }
533    }
534}
535
536/// "More like this": find entities nearest to a given entity's own embedding.
537/// `{ entityName, topK?, entityType?, excludeSelf? }`.
538pub fn handle_vector_search_by_entity(
539    vs: &VectorStore,
540    _kg: &GraphHandle,
541    args: Option<&Value>,
542) -> Result<String> {
543    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
544    let name = params
545        .get("entityName")
546        .and_then(|v| v.as_str())
547        .ok_or_else(|| MCSError::InvalidParams("Missing 'entityName' parameter".into()))?;
548    validate_name(name)?;
549    let top_k = opt_usize(params, "topK", DEFAULT_TOP_K)?.clamp(1, MAX_TOP_K);
550    let entity_type = params
551        .get("entityType")
552        .and_then(|v| v.as_str())
553        .filter(|s| !s.is_empty());
554    let exclude_self = params
555        .get("excludeSelf")
556        .and_then(|v| v.as_bool())
557        .unwrap_or(true);
558
559    let (id, emb, _model) = vs.get_embedding_by_name(name)?.ok_or_else(|| {
560        MCSError::InvalidParams(format!("Entity '{name}' has no embedding"))
561    })?;
562
563    let mut exclude = std::collections::HashSet::new();
564    if exclude_self {
565        exclude.insert(id);
566    }
567    let rows = vs.search_resolved(&emb, top_k, entity_type, &exclude)?;
568    let named: Vec<(String, String, f64)> = rows
569        .into_iter()
570        .map(|(_, n, t, d)| (n, t, f64::from(d)))
571        .collect();
572    Ok(build_content_response(&build_named_results(&named)))
573}
574
575/// Example-based recommendation: build a query from positive (and optional
576/// negative) example entities and search. `{ positive: [names], negative?:
577/// [names], topK?, entityType? }`. The example entities are excluded from results.
578pub fn handle_vector_recommend(
579    vs: &VectorStore,
580    _kg: &GraphHandle,
581    args: Option<&Value>,
582) -> Result<String> {
583    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
584    let top_k = opt_usize(params, "topK", DEFAULT_TOP_K)?.clamp(1, MAX_TOP_K);
585    let entity_type = params
586        .get("entityType")
587        .and_then(|v| v.as_str())
588        .filter(|s| !s.is_empty());
589
590    let positive = collect_names(params, "positive")?;
591    if positive.is_empty() {
592        return Err(MCSError::InvalidParams(
593            "'positive' must contain at least one entity name".into(),
594        ));
595    }
596    let negative = collect_names(params, "negative").unwrap_or_default();
597
598    let dims = vs.dims() as usize;
599    let mut query = vec![0.0f64; dims];
600    let mut exclude = std::collections::HashSet::new();
601
602    let mut pos_count = 0usize;
603    for n in &positive {
604        if let Some((id, emb, _)) = vs.get_embedding_by_name(n)? {
605            if emb.len() != dims {
606                continue;
607            }
608            for (q, &e) in query.iter_mut().zip(&emb) {
609                *q += f64::from(e);
610            }
611            exclude.insert(id);
612            pos_count += 1;
613        }
614    }
615    if pos_count == 0 {
616        return Err(MCSError::InvalidParams(
617            "None of the 'positive' entities have embeddings".into(),
618        ));
619    }
620    for q in query.iter_mut() {
621        *q /= pos_count as f64;
622    }
623
624    let mut neg_count = 0usize;
625    let mut neg = vec![0.0f64; dims];
626    for n in &negative {
627        if let Some((id, emb, _)) = vs.get_embedding_by_name(n)? {
628            if emb.len() != dims {
629                continue;
630            }
631            for (q, &e) in neg.iter_mut().zip(&emb) {
632                *q += f64::from(e);
633            }
634            exclude.insert(id);
635            neg_count += 1;
636        }
637    }
638    if neg_count > 0 {
639        for (q, n) in query.iter_mut().zip(&neg) {
640            *q -= n / neg_count as f64;
641        }
642    }
643
644    let qf = to_f32(&query);
645    let rows = vs.search_resolved(&qf, top_k, entity_type, &exclude)?;
646    let named: Vec<(String, String, f64)> = rows
647        .into_iter()
648        .map(|(_, n, t, d)| (n, t, f64::from(d)))
649        .collect();
650    Ok(build_content_response(&build_named_results(&named)))
651}
652
653/// Maximal Marginal Relevance search: diversified semantic retrieval.
654/// `{ embedding, topK?, fetchK?, lambda?, entityType? }`. `lambda` in `[0,1]`
655/// trades relevance (1.0) against diversity (0.0). Reduces near-duplicate hits —
656/// a common RAG context-selection step. The reported `score` is the MMR score.
657pub fn handle_vector_mmr_search(
658    vs: &VectorStore,
659    _kg: &GraphHandle,
660    args: Option<&Value>,
661) -> Result<String> {
662    let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
663    let embedding = parse_embedding(
664        params
665            .get("embedding")
666            .ok_or_else(|| MCSError::InvalidParams("Missing 'embedding' parameter".into()))?,
667    )?;
668    let top_k = opt_usize(params, "topK", DEFAULT_TOP_K)?.clamp(1, MAX_TOP_K);
669    let fetch_k = opt_usize(params, "fetchK", (top_k * 4).max(20))?.clamp(top_k, MAX_TOP_K);
670    let lambda = opt_f64(params, "lambda", 0.5)?.clamp(0.0, 1.0);
671    let entity_type = params
672        .get("entityType")
673        .and_then(|v| v.as_str())
674        .filter(|s| !s.is_empty());
675
676    let query = to_f32(&embedding);
677
678    // Fetch a candidate pool, then greedily select for MMR.
679    let pool = vs.search_embeddings(&query, fetch_k)?;
680    let mut cands: Vec<MmrCand> = Vec::with_capacity(pool.len());
681    for (id, _dist) in pool {
682        let (name, etype) = vs.resolve_name_type(id);
683        if name.is_empty() {
684            continue;
685        }
686        if let Some(ft) = entity_type
687            && etype != ft
688        {
689            continue;
690        }
691        if let Some(emb) = vs.get_embedding_by_id(id)? {
692            let rel = cosine_sim(&query, &emb);
693            cands.push(MmrCand { name, etype, emb, rel });
694        }
695    }
696
697    let mut selected: Vec<MmrCand> = Vec::with_capacity(top_k.min(cands.len()));
698    let mut scores: Vec<f64> = Vec::with_capacity(top_k.min(cands.len()));
699    while selected.len() < top_k && !cands.is_empty() {
700        let mut best_idx = 0usize;
701        let mut best_mmr = f64::NEG_INFINITY;
702        for (i, c) in cands.iter().enumerate() {
703            let max_sim = selected
704                .iter()
705                .map(|s| cosine_sim(&c.emb, &s.emb))
706                .fold(0.0f64, f64::max);
707            let mmr = lambda * c.rel - (1.0 - lambda) * max_sim;
708            if mmr > best_mmr {
709                best_mmr = mmr;
710                best_idx = i;
711            }
712        }
713        let chosen = cands.swap_remove(best_idx);
714        selected.push(chosen);
715        scores.push(best_mmr);
716    }
717
718    let named: Vec<(String, String, f64)> = selected
719        .into_iter()
720        .zip(scores)
721        .map(|(c, s)| (c.name, c.etype, s))
722        .collect();
723    Ok(build_content_response(&build_named_results(&named)))
724}
725
726struct MmrCand {
727    name: String,
728    etype: String,
729    emb: Vec<f32>,
730    rel: f64,
731}
732
733/// Rebuild/retrain the ANN index (IVF k-means; HNSW is a no-op). `{}`.
734pub fn handle_vector_reindex(
735    vs: &VectorStore,
736    _kg: &GraphHandle,
737    _args: Option<&Value>,
738) -> Result<Value> {
739    vs.reindex()?;
740    let kind = match vs.index_kind() {
741        crate::vector_store::IndexKind::Hnsw => "hnsw",
742        crate::vector_store::IndexKind::Ivf => "ivf",
743    };
744    let text = serde_json::to_string(&json!({
745        "reindexed": true,
746        "indexKind": kind,
747        "embeddingCount": vs.count(),
748    }))
749    .map_err(MCSError::JsonError)?;
750    Ok(text_content(&text))
751}
752
753fn collect_names(params: &Value, key: &str) -> Result<Vec<String>> {
754    match params.get(key) {
755        None | Some(Value::Null) => Ok(Vec::new()),
756        Some(Value::Array(arr)) => {
757            let mut out = Vec::with_capacity(arr.len());
758            for v in arr {
759                let s = v.as_str().ok_or_else(|| {
760                    MCSError::InvalidParams(format!("'{key}' must be an array of strings"))
761                })?;
762                out.push(s.to_string());
763            }
764            Ok(out)
765        }
766        Some(_) => Err(MCSError::InvalidParams(format!(
767            "'{key}' must be an array of strings"
768        ))),
769    }
770}