1use serde_json::{Value, json};
2
3use crate::errors::{MCSError, Result};
4use crate::kg::{GraphHandle, push_json_str};
5use crate::vector_store::{EntityId, VectorStore, with_scratch};
6
/// Rows produced by `perform_hybrid_search`:
/// `(name, entityType, total score, text score, vector score)`.
type HybridResult = Vec<(String, String, f64, f64, f64)>;
8
9use rusqlite::params;
10
/// Largest embedding (number of elements) accepted from a request.
const MAX_EMBEDDING_DIMS: usize = 4096;
/// Number of results returned when a request omits `topK`.
const DEFAULT_TOP_K: usize = 10;
/// Upper bound that any requested `topK` is clamped to.
const MAX_TOP_K: usize = 100;
/// Longest accepted entity name, measured in UTF-8 bytes.
const MAX_NAME_BYTES: usize = 1024;
15
16fn validate_name(name: &str) -> Result<()> {
17 if name.is_empty() {
18 return Err(MCSError::InvalidParams("Name must not be empty".into()));
19 }
20 if name.len() > MAX_NAME_BYTES {
21 return Err(MCSError::InvalidParams(format!(
22 "Name too long (max {MAX_NAME_BYTES} bytes)"
23 )));
24 }
25 Ok(())
26}
27
28fn parse_embedding(val: &Value) -> Result<Vec<f64>> {
29 let arr = val
30 .as_array()
31 .ok_or_else(|| MCSError::InvalidParams("'embedding' must be an array of numbers".into()))?;
32 if arr.is_empty() {
33 return Err(MCSError::InvalidParams("Embedding must not be empty".into()));
34 }
35 if arr.len() > MAX_EMBEDDING_DIMS {
36 return Err(MCSError::InvalidParams(format!(
37 "Embedding too large (max {MAX_EMBEDDING_DIMS} dimensions)"
38 )));
39 }
40 let emb: Vec<f64> = arr
41 .iter()
42 .map(|v| {
43 v.as_f64()
44 .ok_or_else(|| MCSError::InvalidParams("Embedding values must be numbers".into()))
45 })
46 .collect::<Result<_>>()?;
47 Ok(emb)
48}
49
50fn opt_usize(params: &Value, key: &str, default: usize) -> Result<usize> {
51 match params.get(key) {
52 None | Some(Value::Null) => Ok(default),
53 Some(v) => v.as_u64().map(|n| n as usize).ok_or_else(|| {
54 MCSError::InvalidParams(format!("'{key}' must be a non-negative integer"))
55 }),
56 }
57}
58
59fn opt_f64(params: &Value, key: &str, default: f64) -> Result<f64> {
60 match params.get(key) {
61 None | Some(Value::Null) => Ok(default),
62 Some(v) => v.as_f64().ok_or_else(|| {
63 MCSError::InvalidParams(format!("'{key}' must be a number"))
64 }),
65 }
66}
67
68fn text_content(text: &str) -> Value {
69 json!({
70 "content": [{
71 "type": "text",
72 "text": text
73 }]
74 })
75}
76
77fn build_content_response(inner_json: &str) -> String {
78 let mut out = String::with_capacity(64 + inner_json.len() + (inner_json.len() / 8));
79 out.push_str(r#"{"content":[{"type":"text","text":"#);
80 push_json_str(&mut out, inner_json);
81 out.push_str(r#"}]}"#);
82 out
83}
84
85pub fn handle_vector_upsert_embedding(
86 vs: &VectorStore,
87 _kg: &GraphHandle,
88 args: Option<&Value>,
89) -> Result<Value> {
90 let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
91
92 let entity_name = params
93 .get("entityName")
94 .and_then(|v| v.as_str())
95 .ok_or_else(|| MCSError::InvalidParams("Missing 'entityName' parameter".into()))?;
96 validate_name(entity_name)?;
97
98 let embedding = parse_embedding(
99 params
100 .get("embedding")
101 .ok_or_else(|| MCSError::InvalidParams("Missing 'embedding' parameter".into()))?,
102 )?;
103
104 let model = params
105 .get("model")
106 .and_then(|v| v.as_str())
107 .unwrap_or("");
108
109 with_scratch(|buf| {
110 buf.reserve(embedding.len());
111 buf.extend(embedding.iter().map(|&v| v as f32));
112 vs.upsert_embedding(entity_name, buf, model)
113 })?;
114
115 let text = serde_json::to_string(&json!({
116 "entityName": entity_name,
117 "dims": vs.dims(),
118 "model": model,
119 }))
120 .map_err(MCSError::JsonError)?;
121
122 Ok(text_content(&text))
123}
124
125pub fn handle_vector_search_entities(
126 vs: &VectorStore,
127 _kg: &GraphHandle,
128 args: Option<&Value>,
129) -> Result<String> {
130 let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
131
132 let embedding = parse_embedding(
133 params
134 .get("embedding")
135 .ok_or_else(|| MCSError::InvalidParams("Missing 'embedding' parameter".into()))?,
136 )?;
137
138 let top_k = opt_usize(params, "topK", DEFAULT_TOP_K)?
139 .clamp(1, MAX_TOP_K);
140
141 let entity_type = params
142 .get("entityType")
143 .and_then(|v| v.as_str())
144 .filter(|s| !s.is_empty());
145
146 let json = with_scratch(|buf| {
147 buf.reserve(embedding.len());
148 buf.extend(embedding.iter().map(|&v| v as f32));
149 vs.search_entities_json(buf, top_k, entity_type)
150 })?;
151
152 Ok(build_content_response(&json))
153}
154
155pub fn handle_vector_delete_embedding(
156 vs: &VectorStore,
157 _kg: &GraphHandle,
158 args: Option<&Value>,
159) -> Result<Value> {
160 let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
161
162 let entity_name = params
163 .get("entityName")
164 .and_then(|v| v.as_str())
165 .ok_or_else(|| MCSError::InvalidParams("Missing 'entityName' parameter".into()))?;
166 validate_name(entity_name)?;
167
168 let deleted = vs.delete_embedding(entity_name)?;
169
170 let text = serde_json::to_string(&json!({
171 "deleted": deleted,
172 "entityName": entity_name,
173 }))
174 .map_err(MCSError::JsonError)?;
175
176 Ok(text_content(&text))
177}
178
179pub fn handle_hybrid_search(
180 vs: &VectorStore,
181 kg: &GraphHandle,
182 args: Option<&Value>,
183) -> Result<String> {
184 let params = args.ok_or_else(|| MCSError::InvalidParams("Missing parameters".into()))?;
185
186 let query_text = params
187 .get("queryText")
188 .and_then(|v| v.as_str())
189 .ok_or_else(|| MCSError::InvalidParams("Missing 'queryText' parameter".into()))?;
190
191 let query_embedding = parse_embedding(
192 params
193 .get("queryEmbedding")
194 .ok_or_else(|| MCSError::InvalidParams("Missing 'queryEmbedding' parameter".into()))?,
195 )?;
196
197 let text_weight = opt_f64(params, "textWeight", 0.5)?;
198 let vec_weight = opt_f64(params, "vecWeight", 0.5)?;
199 let top_k = opt_usize(params, "topK", DEFAULT_TOP_K)?
200 .clamp(1, MAX_TOP_K);
201
202 let results = with_scratch(|buf| {
203 buf.reserve(query_embedding.len());
204 buf.extend(query_embedding.iter().map(|&v| v as f32));
205 perform_hybrid_search(vs, kg, query_text, buf, text_weight, vec_weight, top_k)
206 })?;
207
208 let mut out = String::with_capacity(128 + results.len() * 80);
209 out.push_str(r#"{"results":["#);
210 for (i, (name, etype, score, txt_score, vec_score)) in results.iter().enumerate() {
211 if i > 0 {
212 out.push(',');
213 }
214 out.push_str(r#"{"name":"#);
215 push_json_str(&mut out, name);
216 out.push_str(r#","entityType":"#);
217 push_json_str(&mut out, etype);
218 use std::fmt::Write;
219 write!(
220 out,
221 r#","score":{:.6},"textScore":{:.6},"vecScore":{:.6}}}"#,
222 score, txt_score, vec_score
223 )
224 .unwrap();
225 }
226 out.push_str(r#"],"count":"#);
227 out.push_str(&results.len().to_string());
228 out.push('}');
229
230 Ok(build_content_response(&out))
231}
232
233fn perform_hybrid_search(
234 vs: &VectorStore,
235 kg: &GraphHandle,
236 query_text: &str,
237 query_emb: &[f32],
238 text_weight: f64,
239 vec_weight: f64,
240 top_k: usize,
241) -> Result<HybridResult> {
242 let fetch_k = top_k * 3;
243 let rrf_constant = 60.0;
244
245 let vec_matches = vs.search_embeddings(query_emb, fetch_k)?;
246
247 let kg_results = kg.search_nodes_filtered(query_text, None, 0, fetch_k);
248 let mut text_matches: Vec<EntityIdAndName> = Vec::with_capacity(kg_results.len());
249 for entity in &kg_results {
250 if let Ok(Some(_)) = vs.get_entity_type(
251 vs.name_to_id.get(&entity.name).map(|r| *r.value()).unwrap_or(-1),
252 ) {
253 let id = vs.name_to_id.get(&entity.name).map(|r| *r.value());
254 text_matches.push(EntityIdAndName {
255 id: id.unwrap_or(-1),
256 });
257 } else {
258 let conn = vs.db.lock();
259 let h = crate::kg::name_hash(&entity.name);
260 let id: Option<i64> = conn
261 .query_row(
262 "SELECT id FROM entity WHERE name_hash = ?1 AND name = ?2 AND flags = 0",
263 params![h, entity.name],
264 |row| row.get(0),
265 )
266 .ok();
267 text_matches.push(EntityIdAndName {
268 id: id.unwrap_or(-1),
269 });
270 }
271 }
272
273 let mut score_map: std::collections::HashMap<EntityId, AggScore> =
274 std::collections::HashMap::with_capacity(vec_matches.len() + text_matches.len());
275
276 for (rank, (id, _dist)) in vec_matches.iter().enumerate() {
277 let entry = score_map.entry(*id).or_insert_with(|| AggScore {
278 id: *id,
279 total: 0.0,
280 vec_score: 0.0,
281 text_score: 0.0,
282 });
283 let rrf = vec_weight * (1.0 / (rrf_constant + rank as f64));
284 entry.total += rrf;
285 entry.vec_score += rrf;
286 }
287
288 for (rank, tm) in text_matches.iter().enumerate() {
289 let entry = score_map.entry(tm.id).or_insert_with(|| AggScore {
290 id: tm.id,
291 total: 0.0,
292 vec_score: 0.0,
293 text_score: 0.0,
294 });
295 let rrf = text_weight * (1.0 / (rrf_constant + rank as f64));
296 entry.total += rrf;
297 entry.text_score += rrf;
298 }
299
300 let mut scored: Vec<AggScore> = score_map.into_values().collect();
301 scored.sort_unstable_by(|a, b| b.total.partial_cmp(&a.total).unwrap_or(std::cmp::Ordering::Equal));
302
303 if vs.graph_node_count() > 0 {
304 let g = vs.graph.read();
305 for entry in &mut scored {
306 if let Some(nx) = vs.node_map.get(&entry.id) {
307 let deg = g.neighbors(*nx).count() as f64;
308 if deg > 0.0 {
309 let boost = 0.1 * (deg / (deg + 5.0));
310 entry.total += boost;
311 }
312 }
313 }
314 scored.sort_unstable_by(|a, b| b.total.partial_cmp(&a.total).unwrap_or(std::cmp::Ordering::Equal));
315 }
316
317 let conn = vs.db.lock();
318 let mut results = Vec::with_capacity(top_k.min(scored.len()));
319 for entry in scored.iter().take(top_k) {
320 let name = vs
321 .id_to_name
322 .get(&entry.id)
323 .map(|r| r.value().clone())
324 .or_else(|| {
325 conn.query_row(
326 "SELECT name FROM entity WHERE id = ?1 AND flags = 0",
327 params![entry.id],
328 |row| row.get::<_, String>(0),
329 )
330 .ok()
331 })
332 .unwrap_or_default();
333
334 let etype: String = conn
335 .query_row(
336 "SELECT t.name FROM entity e JOIN type_dict t ON t.id = e.type_id WHERE e.id = ?1 AND e.flags = 0",
337 params![entry.id],
338 |row| row.get(0),
339 )
340 .unwrap_or_default();
341
342 results.push((name, etype, entry.total, entry.text_score, entry.vec_score));
343 }
344
345 Ok(results)
346}
347
/// A text-search hit as used by `perform_hybrid_search`.
///
/// NOTE(review): despite the name, only the entity id is stored; the name
/// itself is not kept here.
struct EntityIdAndName {
    /// Id of the entity matched by the text search.
    id: EntityId,
}
351
352struct AggScore {
353 id: EntityId,
354 total: f64,
355 vec_score: f64,
356 text_score: f64,
357}
358
359pub fn handle_refresh_graph_cache(
360 vs: &VectorStore,
361 _kg: &GraphHandle,
362 _args: Option<&Value>,
363) -> Result<Value> {
364 vs.rebuild_graph_cache()?;
365 let text = serde_json::to_string(&json!({
366 "nodes": vs.graph_node_count(),
367 "edges": vs.graph_edge_count(),
368 }))
369 .map_err(MCSError::JsonError)?;
370 Ok(text_content(&text))
371}
372
373pub fn handle_vector_store_stats(
374 vs: &VectorStore,
375 _kg: &GraphHandle,
376 _args: Option<&Value>,
377) -> Result<Value> {
378 let text = serde_json::to_string(&json!({
379 "embeddingCount": vs.count(),
380 "dims": vs.dims(),
381 "petgraphNodes": vs.graph_node_count(),
382 "petgraphEdges": vs.graph_edge_count(),
383 }))
384 .map_err(MCSError::JsonError)?;
385 Ok(text_content(&text))
386}