1use crate::scoring::compute_score;
4use crate::CodememEngine;
5use chrono::Utc;
6use codemem_core::{
7 CodememError, GraphBackend, MemoryNode, MemoryType, NodeKind, SearchResult, VectorBackend,
8};
9use std::collections::{HashMap, HashSet};
10
/// A search hit produced by [`CodememEngine::recall_with_expansion`], paired with a
/// human-readable description of how the memory was reached.
#[derive(Debug, Clone)]
pub struct ExpandedResult {
    /// The scored memory.
    pub result: SearchResult,
    /// `"direct"` for memories found by the initial search; for memories pulled in by
    /// graph expansion, either `"via <relationship> from <direct id>"` or
    /// `"via graph from <direct id>"`.
    pub expansion_path: String,
}
17
/// Aggregate statistics over the memories of a single namespace, as returned by
/// [`CodememEngine::namespace_stats`].
#[derive(Debug, Clone)]
pub struct NamespaceStats {
    /// The namespace these statistics describe.
    pub namespace: String,
    /// Number of memories that were fetched for the namespace.
    pub count: usize,
    /// Mean `importance` over those memories (0.0 when there are none).
    pub avg_importance: f64,
    /// Mean `confidence` over those memories (0.0 when there are none).
    pub avg_confidence: f64,
    /// Number of memories per memory type, keyed by the type's string form.
    pub type_distribution: HashMap<String, usize>,
    /// Occurrence count of each tag across the namespace's memories.
    pub tag_frequency: HashMap<String, usize>,
    /// Earliest `created_at` among the memories, or `None` when there are none.
    pub oldest: Option<chrono::DateTime<chrono::Utc>>,
    /// Latest `created_at` among the memories, or `None` when there are none.
    pub newest: Option<chrono::DateTime<chrono::Utc>>,
}
30
31impl CodememEngine {
32 #[allow(clippy::too_many_arguments)]
39 pub fn recall(
40 &self,
41 query: &str,
42 k: usize,
43 memory_type_filter: Option<MemoryType>,
44 namespace_filter: Option<&str>,
45 exclude_tags: &[String],
46 min_importance: Option<f64>,
47 min_confidence: Option<f64>,
48 ) -> Result<Vec<SearchResult>, CodememError> {
49 let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
51 match emb_guard.embed(query) {
52 Ok(query_embedding) => {
53 drop(emb_guard);
54 let vec = self.lock_vector()?;
55 vec.search(&query_embedding, k * 2) .unwrap_or_default()
57 }
58 Err(e) => {
59 tracing::warn!("Query embedding failed: {e}");
60 vec![]
61 }
62 }
63 } else {
64 vec![]
65 };
66
67 let query_tokens: Vec<String> = crate::bm25::tokenize(query);
71 let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
72
73 let mut graph = self.lock_graph()?;
77 graph.ensure_betweenness_computed();
80 let bm25 = self.lock_bm25()?;
81 let now = Utc::now();
82
83 let entity_memory_ids = self.resolve_entity_memories(query, &graph, now);
87
88 let mut results: Vec<SearchResult> = Vec::new();
89 let weights = self.scoring_weights()?;
90
91 if vector_results.is_empty() {
92 let type_str = memory_type_filter.as_ref().map(|t| t.to_string());
94 let all_memories = self
95 .storage
96 .list_memories_filtered(namespace_filter, type_str.as_deref())?;
97
98 for memory in all_memories {
99 if !exclude_tags.is_empty() && memory.tags.iter().any(|t| exclude_tags.contains(t))
101 {
102 continue;
103 }
104 if let Some(min) = min_importance {
105 if memory.importance < min {
106 continue;
107 }
108 }
109 if let Some(min) = min_confidence {
110 if memory.confidence < min {
111 continue;
112 }
113 }
114
115 let breakdown = compute_score(&memory, &query_token_refs, 0.0, &graph, &bm25, now);
116 let score = breakdown.total_with_weights(&weights);
117 if score > 0.01 {
118 results.push(SearchResult {
119 memory,
120 score,
121 score_breakdown: breakdown,
122 });
123 }
124 }
125 } else {
126 let mut all_candidate_ids: HashSet<&str> =
128 vector_results.iter().map(|(id, _)| id.as_str()).collect();
129
130 for eid in &entity_memory_ids {
132 all_candidate_ids.insert(eid.as_str());
133 }
134
135 let candidate_id_vec: Vec<&str> = all_candidate_ids.into_iter().collect();
136 let candidate_memories = self.storage.get_memories_batch(&candidate_id_vec)?;
137
138 let sim_map: HashMap<&str, f64> = vector_results
140 .iter()
141 .map(|(id, sim)| (id.as_str(), *sim as f64))
142 .collect();
143
144 for memory in candidate_memories {
145 if let Some(ref filter_type) = memory_type_filter {
147 if memory.memory_type != *filter_type {
148 continue;
149 }
150 }
151 if let Some(ns) = namespace_filter {
153 if memory.namespace.as_deref() != Some(ns) {
154 continue;
155 }
156 }
157 if !exclude_tags.is_empty() && memory.tags.iter().any(|t| exclude_tags.contains(t))
159 {
160 continue;
161 }
162 if let Some(min) = min_importance {
163 if memory.importance < min {
164 continue;
165 }
166 }
167 if let Some(min) = min_confidence {
168 if memory.confidence < min {
169 continue;
170 }
171 }
172
173 let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
174 let breakdown =
175 compute_score(&memory, &query_token_refs, similarity, &graph, &bm25, now);
176 let score = breakdown.total_with_weights(&weights);
177 if score > 0.01 {
178 results.push(SearchResult {
179 memory,
180 score,
181 score_breakdown: breakdown,
182 });
183 }
184 }
185 }
186
187 results.sort_by(|a, b| {
189 b.score
190 .partial_cmp(&a.score)
191 .unwrap_or(std::cmp::Ordering::Equal)
192 });
193 results.truncate(k);
194
195 Ok(results)
196 }
197
198 pub fn recall_with_expansion(
202 &self,
203 query: &str,
204 k: usize,
205 expansion_depth: usize,
206 namespace_filter: Option<&str>,
207 ) -> Result<Vec<ExpandedResult>, CodememError> {
208 let query_tokens: Vec<String> = crate::bm25::tokenize(query);
210 let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
211
212 let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
214 match emb_guard.embed(query) {
215 Ok(query_embedding) => {
216 drop(emb_guard);
217 let vec = self.lock_vector()?;
218 vec.search(&query_embedding, k * 2).unwrap_or_default()
219 }
220 Err(e) => {
221 tracing::warn!("Query embedding failed: {e}");
222 vec![]
223 }
224 }
225 } else {
226 vec![]
227 };
228
229 let mut graph = self.lock_graph()?;
230 graph.ensure_betweenness_computed();
232 let bm25 = self.lock_bm25()?;
233 let now = Utc::now();
234
235 struct ScoredMemory {
237 memory: MemoryNode,
238 vector_sim: f64,
239 expansion_path: String,
240 }
241
242 let mut all_memories: Vec<ScoredMemory> = Vec::new();
243 let mut seen_ids: HashSet<String> = HashSet::new();
244
245 if vector_results.is_empty() {
246 let all = self
248 .storage
249 .list_memories_filtered(namespace_filter, None)?;
250 let weights = self.scoring_weights()?;
251
252 for memory in all {
253 let breakdown = compute_score(&memory, &query_token_refs, 0.0, &graph, &bm25, now);
254 let score = breakdown.total_with_weights(&weights);
255 if score > 0.01 {
256 seen_ids.insert(memory.id.clone());
257 all_memories.push(ScoredMemory {
258 memory,
259 vector_sim: 0.0,
260 expansion_path: "direct".to_string(),
261 });
262 }
263 }
264 } else {
265 let candidate_ids: Vec<&str> =
267 vector_results.iter().map(|(id, _)| id.as_str()).collect();
268 let candidate_memories = self.storage.get_memories_batch(&candidate_ids)?;
269
270 let sim_map: HashMap<&str, f64> = vector_results
271 .iter()
272 .map(|(id, sim)| (id.as_str(), *sim as f64))
273 .collect();
274
275 for memory in candidate_memories {
276 if let Some(ns) = namespace_filter {
277 if memory.namespace.as_deref() != Some(ns) {
278 continue;
279 }
280 }
281 let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
282 seen_ids.insert(memory.id.clone());
283 all_memories.push(ScoredMemory {
284 memory,
285 vector_sim: similarity,
286 expansion_path: "direct".to_string(),
287 });
288 }
289 }
290
291 let direct_ids: Vec<String> = all_memories.iter().map(|m| m.memory.id.clone()).collect();
296
297 for direct_id in &direct_ids {
298 let direct_edges: Vec<_> = graph
301 .get_edges(direct_id)
302 .unwrap_or_default()
303 .into_iter()
304 .filter(|e| is_edge_active(e, now))
305 .collect();
306
307 if let Ok(expanded_nodes) =
310 graph.bfs_filtered(direct_id, expansion_depth, &[NodeKind::Chunk], None)
311 {
312 for expanded_node in &expanded_nodes {
313 if expanded_node.id == *direct_id {
315 continue;
316 }
317
318 if expanded_node.kind != NodeKind::Memory {
321 continue;
322 }
323
324 let memory_id = expanded_node
326 .memory_id
327 .as_deref()
328 .unwrap_or(&expanded_node.id);
329
330 if seen_ids.contains(memory_id) {
332 continue;
333 }
334
335 if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
337 if let Some(ns) = namespace_filter {
338 if memory.namespace.as_deref() != Some(ns) {
339 continue;
340 }
341 }
342
343 let expansion_path = direct_edges
345 .iter()
346 .find(|e| e.dst == expanded_node.id || e.src == expanded_node.id)
347 .map(|e| format!("via {} from {}", e.relationship, direct_id))
348 .unwrap_or_else(|| format!("via graph from {direct_id}"));
349
350 seen_ids.insert(memory_id.to_string());
351 all_memories.push(ScoredMemory {
352 memory,
353 vector_sim: 0.0,
354 expansion_path,
355 });
356 }
357 }
358 }
359 }
360
361 let weights = self.scoring_weights()?;
363 let mut scored_results: Vec<ExpandedResult> = all_memories
364 .into_iter()
365 .map(|sm| {
366 let breakdown = compute_score(
367 &sm.memory,
368 &query_token_refs,
369 sm.vector_sim,
370 &graph,
371 &bm25,
372 now,
373 );
374 let score = breakdown.total_with_weights(&weights);
375 ExpandedResult {
376 result: SearchResult {
377 memory: sm.memory,
378 score,
379 score_breakdown: breakdown,
380 },
381 expansion_path: sm.expansion_path,
382 }
383 })
384 .collect();
385
386 scored_results.sort_by(|a, b| {
387 b.result
388 .score
389 .partial_cmp(&a.result.score)
390 .unwrap_or(std::cmp::Ordering::Equal)
391 });
392 scored_results.truncate(k);
393
394 Ok(scored_results)
395 }
396
397 pub(crate) fn resolve_entity_memories(
404 &self,
405 query: &str,
406 graph: &codemem_storage::graph::GraphEngine,
407 now: chrono::DateTime<chrono::Utc>,
408 ) -> HashSet<String> {
409 let entity_refs = crate::search::extract_code_references(query);
410 let mut memory_ids: HashSet<String> = HashSet::new();
411
412 for entity_ref in &entity_refs {
413 let candidate_ids = [
415 format!("sym:{entity_ref}"),
416 format!("file:{entity_ref}"),
417 entity_ref.clone(),
418 ];
419
420 for candidate_id in &candidate_ids {
421 if graph.get_node_ref(candidate_id).is_none() {
422 continue;
423 }
424 for edge in graph.get_edges_ref(candidate_id) {
426 if !is_edge_active(edge, now) {
427 continue;
428 }
429 let neighbor_id = if edge.src == *candidate_id {
430 &edge.dst
431 } else {
432 &edge.src
433 };
434 if let Some(node) = graph.get_node_ref(neighbor_id) {
435 if node.kind == NodeKind::Memory {
436 let mem_id = node.memory_id.as_deref().unwrap_or(&node.id);
437 memory_ids.insert(mem_id.to_string());
438 }
439 }
440 }
441 break; }
443 }
444
445 memory_ids
446 }
447
448 pub fn namespace_stats(&self, namespace: &str) -> Result<NamespaceStats, CodememError> {
451 let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
452
453 if ids.is_empty() {
454 return Ok(NamespaceStats {
455 namespace: namespace.to_string(),
456 count: 0,
457 avg_importance: 0.0,
458 avg_confidence: 0.0,
459 type_distribution: HashMap::new(),
460 tag_frequency: HashMap::new(),
461 oldest: None,
462 newest: None,
463 });
464 }
465
466 let mut total_importance = 0.0;
467 let mut total_confidence = 0.0;
468 let mut type_distribution: HashMap<String, usize> = HashMap::new();
469 let mut tag_frequency: HashMap<String, usize> = HashMap::new();
470 let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
471 let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
472 let mut count = 0usize;
473
474 let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
478 let memories = self.storage.get_memories_batch(&id_refs)?;
479
480 for memory in &memories {
481 count += 1;
482 total_importance += memory.importance;
483 total_confidence += memory.confidence;
484
485 *type_distribution
486 .entry(memory.memory_type.to_string())
487 .or_insert(0) += 1;
488
489 for tag in &memory.tags {
490 *tag_frequency.entry(tag.clone()).or_insert(0) += 1;
491 }
492
493 match oldest {
494 None => oldest = Some(memory.created_at),
495 Some(ref o) if memory.created_at < *o => oldest = Some(memory.created_at),
496 _ => {}
497 }
498 match newest {
499 None => newest = Some(memory.created_at),
500 Some(ref n) if memory.created_at > *n => newest = Some(memory.created_at),
501 _ => {}
502 }
503 }
504
505 let avg_importance = if count > 0 {
506 total_importance / count as f64
507 } else {
508 0.0
509 };
510 let avg_confidence = if count > 0 {
511 total_confidence / count as f64
512 } else {
513 0.0
514 };
515
516 Ok(NamespaceStats {
517 namespace: namespace.to_string(),
518 count,
519 avg_importance,
520 avg_confidence,
521 type_distribution,
522 tag_frequency,
523 oldest,
524 newest,
525 })
526 }
527
528 pub fn delete_namespace(&self, namespace: &str) -> Result<usize, CodememError> {
531 let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
532
533 let mut deleted = 0usize;
534 let mut graph = self.lock_graph()?;
535 let mut vector = self.lock_vector()?;
536 let mut bm25 = self.lock_bm25()?;
537
538 for id in &ids {
539 if let Ok(true) = self.storage.delete_memory(id) {
541 deleted += 1;
542
543 let _ = vector.remove(id);
545
546 let _ = graph.remove_node(id);
548
549 let _ = self.storage.delete_graph_edges_for_node(id);
551 let _ = self.storage.delete_graph_node(id);
552
553 let _ = self.storage.delete_embedding(id);
555
556 bm25.remove_document(id);
558 }
559 }
560
561 drop(graph);
563 drop(vector);
564 drop(bm25);
565
566 self.save_index();
568
569 Ok(deleted)
570 }
571}
572
573pub(crate) fn is_edge_active(
578 edge: &codemem_core::Edge,
579 now: chrono::DateTime<chrono::Utc>,
580) -> bool {
581 if let Some(valid_to) = edge.valid_to {
582 if valid_to < now {
583 return false;
584 }
585 }
586 if let Some(valid_from) = edge.valid_from {
587 if valid_from > now {
588 return false;
589 }
590 }
591 true
592}