1use crate::scoring::compute_score;
4use crate::CodememEngine;
5use chrono::Utc;
6use codemem_core::{CodememError, MemoryNode, MemoryType, NodeKind, SearchResult};
7use std::collections::{HashMap, HashSet};
8
9#[derive(Debug, Clone)]
11pub struct ExpandedResult {
12 pub result: SearchResult,
13 pub expansion_path: String,
14}
15
16#[derive(Debug, Clone)]
18pub struct NamespaceStats {
19 pub namespace: String,
20 pub count: usize,
21 pub avg_importance: f64,
22 pub avg_confidence: f64,
23 pub type_distribution: HashMap<String, usize>,
24 pub tag_frequency: HashMap<String, usize>,
25 pub oldest: Option<chrono::DateTime<chrono::Utc>>,
26 pub newest: Option<chrono::DateTime<chrono::Utc>>,
27}
28
29#[derive(Debug, Clone)]
31pub struct RecallQuery<'a> {
32 pub query: &'a str,
33 pub k: usize,
34 pub memory_type_filter: Option<MemoryType>,
35 pub namespace_filter: Option<&'a str>,
36 pub exclude_tags: &'a [String],
37 pub min_importance: Option<f64>,
38 pub min_confidence: Option<f64>,
39 pub git_ref_filter: Option<&'a str>,
41}
42
43impl<'a> RecallQuery<'a> {
44 pub fn new(query: &'a str, k: usize) -> Self {
46 Self {
47 query,
48 k,
49 memory_type_filter: None,
50 namespace_filter: None,
51 exclude_tags: &[],
52 min_importance: None,
53 min_confidence: None,
54 git_ref_filter: None,
55 }
56 }
57}
58
59impl CodememEngine {
60 pub fn recall(&self, q: &RecallQuery<'_>) -> Result<Vec<SearchResult>, CodememError> {
67 self.sweep_expired_memories();
69
70 let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
72 match emb_guard.embed(q.query) {
73 Ok(query_embedding) => {
74 drop(emb_guard);
75 let vec = self.lock_vector()?;
76 vec.search(&query_embedding, q.k * 2) .unwrap_or_default()
78 }
79 Err(e) => {
80 tracing::warn!("Query embedding failed: {e}");
81 vec![]
82 }
83 }
84 } else {
85 vec![]
86 };
87
88 let query_tokens: Vec<String> = crate::bm25::tokenize(q.query);
92 let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
93
94 let mut graph = self.lock_graph()?;
98 graph.ensure_betweenness_computed();
101 let bm25 = self.lock_bm25()?;
102 let now = Utc::now();
103
104 let entity_memory_ids = self.resolve_entity_memories(q.query, &**graph, now);
108
109 let mut results: Vec<SearchResult> = Vec::new();
110 let weights = self.scoring_weights()?;
111
112 if !vector_results.is_empty() {
113 let mut all_candidate_ids: HashSet<&str> =
115 vector_results.iter().map(|(id, _)| id.as_str()).collect();
116
117 for eid in &entity_memory_ids {
119 all_candidate_ids.insert(eid.as_str());
120 }
121
122 let candidate_id_vec: Vec<&str> = all_candidate_ids.into_iter().collect();
123 let candidate_memories = self.storage.get_memories_batch(&candidate_id_vec)?;
124
125 let sim_map: HashMap<&str, f64> = vector_results
127 .iter()
128 .map(|(id, sim)| (id.as_str(), *sim as f64))
129 .collect();
130
131 for memory in candidate_memories {
132 if let Some(ref filter_type) = q.memory_type_filter {
134 if memory.memory_type != *filter_type {
135 continue;
136 }
137 }
138 if let Some(ns) = q.namespace_filter {
140 if memory.namespace.as_deref() != Some(ns) {
141 continue;
142 }
143 }
144 if !Self::passes_quality_filters(&memory, q) {
145 continue;
146 }
147
148 let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
149 let breakdown =
150 compute_score(&memory, &query_token_refs, similarity, &**graph, &bm25, now);
151 let score = breakdown.total_with_weights(&weights);
152 if score > 0.01 {
153 results.push(SearchResult {
154 memory,
155 score,
156 score_breakdown: breakdown,
157 });
158 }
159 }
160 }
161
162 if results.is_empty() {
166 let type_str = q.memory_type_filter.as_ref().map(|t| t.to_string());
167 let all_memories = self
168 .storage
169 .list_memories_filtered(q.namespace_filter, type_str.as_deref())?;
170
171 for memory in all_memories {
172 if !Self::passes_quality_filters(&memory, q) {
173 continue;
174 }
175
176 let breakdown =
177 compute_score(&memory, &query_token_refs, 0.0, &**graph, &bm25, now);
178 let score = breakdown.total_with_weights(&weights);
179 if score > 0.01 {
180 results.push(SearchResult {
181 memory,
182 score,
183 score_breakdown: breakdown,
184 });
185 }
186 }
187 }
188
189 results.sort_by(|a, b| {
191 b.score
192 .partial_cmp(&a.score)
193 .unwrap_or(std::cmp::Ordering::Equal)
194 });
195 results.truncate(q.k);
196
197 Ok(results)
198 }
199
200 fn passes_quality_filters(memory: &MemoryNode, q: &RecallQuery<'_>) -> bool {
202 if memory.expires_at.is_some_and(|dt| dt <= Utc::now()) {
204 return false;
205 }
206 if !q.exclude_tags.is_empty() && memory.tags.iter().any(|t| q.exclude_tags.contains(t)) {
207 return false;
208 }
209 if let Some(min) = q.min_importance {
210 if memory.importance < min {
211 return false;
212 }
213 }
214 if let Some(min) = q.min_confidence {
215 if memory.confidence < min {
216 return false;
217 }
218 }
219 if let Some(ref_filter) = q.git_ref_filter {
220 if memory.git_ref.as_deref() != Some(ref_filter) {
221 return false;
222 }
223 }
224 true
225 }
226
227 pub fn recall_with_expansion(
231 &self,
232 query: &str,
233 k: usize,
234 expansion_depth: usize,
235 namespace_filter: Option<&str>,
236 ) -> Result<Vec<ExpandedResult>, CodememError> {
237 self.sweep_expired_memories();
239
240 let query_tokens: Vec<String> = crate::bm25::tokenize(query);
242 let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
243
244 let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
246 match emb_guard.embed(query) {
247 Ok(query_embedding) => {
248 drop(emb_guard);
249 let vec = self.lock_vector()?;
250 vec.search(&query_embedding, k * 2).unwrap_or_default()
251 }
252 Err(e) => {
253 tracing::warn!("Query embedding failed: {e}");
254 vec![]
255 }
256 }
257 } else {
258 vec![]
259 };
260
261 let mut graph = self.lock_graph()?;
262 graph.ensure_betweenness_computed();
264 let bm25 = self.lock_bm25()?;
265 let now = Utc::now();
266
267 struct ScoredMemory {
269 memory: MemoryNode,
270 vector_sim: f64,
271 expansion_path: String,
272 }
273
274 let mut all_memories: Vec<ScoredMemory> = Vec::new();
275 let mut seen_ids: HashSet<String> = HashSet::new();
276
277 if vector_results.is_empty() {
278 let all = self
280 .storage
281 .list_memories_filtered(namespace_filter, None)?;
282 let weights = self.scoring_weights()?;
283
284 for memory in all {
285 if memory.expires_at.is_some_and(|dt| dt <= now) {
286 continue;
287 }
288 let breakdown =
289 compute_score(&memory, &query_token_refs, 0.0, &**graph, &bm25, now);
290 let score = breakdown.total_with_weights(&weights);
291 if score > 0.01 {
292 seen_ids.insert(memory.id.clone());
293 all_memories.push(ScoredMemory {
294 memory,
295 vector_sim: 0.0,
296 expansion_path: "direct".to_string(),
297 });
298 }
299 }
300 } else {
301 let candidate_ids: Vec<&str> =
303 vector_results.iter().map(|(id, _)| id.as_str()).collect();
304 let candidate_memories = self.storage.get_memories_batch(&candidate_ids)?;
305
306 let sim_map: HashMap<&str, f64> = vector_results
307 .iter()
308 .map(|(id, sim)| (id.as_str(), *sim as f64))
309 .collect();
310
311 for memory in candidate_memories {
312 if memory.expires_at.is_some_and(|dt| dt <= now) {
313 continue;
314 }
315 if let Some(ns) = namespace_filter {
316 if memory.namespace.as_deref() != Some(ns) {
317 continue;
318 }
319 }
320 let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
321 seen_ids.insert(memory.id.clone());
322 all_memories.push(ScoredMemory {
323 memory,
324 vector_sim: similarity,
325 expansion_path: "direct".to_string(),
326 });
327 }
328 }
329
330 let direct_ids: Vec<String> = all_memories.iter().map(|m| m.memory.id.clone()).collect();
335
336 for direct_id in &direct_ids {
337 let direct_edges: Vec<_> = graph
340 .get_edges(direct_id)
341 .unwrap_or_default()
342 .into_iter()
343 .filter(|e| is_edge_active(e, now))
344 .collect();
345
346 if let Ok(expanded_nodes) =
349 graph.bfs_filtered(direct_id, expansion_depth, &[NodeKind::Chunk], None)
350 {
351 for expanded_node in &expanded_nodes {
352 if expanded_node.id == *direct_id {
354 continue;
355 }
356
357 if expanded_node.kind != NodeKind::Memory {
360 continue;
361 }
362
363 let memory_id = expanded_node
365 .memory_id
366 .as_deref()
367 .unwrap_or(&expanded_node.id);
368
369 if seen_ids.contains(memory_id) {
371 continue;
372 }
373
374 if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
376 if memory.expires_at.is_some_and(|dt| dt <= now) {
377 continue;
378 }
379 if let Some(ns) = namespace_filter {
380 if memory.namespace.as_deref() != Some(ns) {
381 continue;
382 }
383 }
384
385 let expansion_path = direct_edges
387 .iter()
388 .find(|e| e.dst == expanded_node.id || e.src == expanded_node.id)
389 .map(|e| format!("via {} from {}", e.relationship, direct_id))
390 .unwrap_or_else(|| format!("via graph from {direct_id}"));
391
392 seen_ids.insert(memory_id.to_string());
393 all_memories.push(ScoredMemory {
394 memory,
395 vector_sim: 0.0,
396 expansion_path,
397 });
398 }
399 }
400 }
401 }
402
403 let weights = self.scoring_weights()?;
405 let mut scored_results: Vec<ExpandedResult> = all_memories
406 .into_iter()
407 .map(|sm| {
408 let breakdown = compute_score(
409 &sm.memory,
410 &query_token_refs,
411 sm.vector_sim,
412 &**graph,
413 &bm25,
414 now,
415 );
416 let score = breakdown.total_with_weights(&weights);
417 ExpandedResult {
418 result: SearchResult {
419 memory: sm.memory,
420 score,
421 score_breakdown: breakdown,
422 },
423 expansion_path: sm.expansion_path,
424 }
425 })
426 .collect();
427
428 scored_results.sort_by(|a, b| {
429 b.result
430 .score
431 .partial_cmp(&a.result.score)
432 .unwrap_or(std::cmp::Ordering::Equal)
433 });
434 scored_results.truncate(k);
435
436 Ok(scored_results)
437 }
438
439 pub(crate) fn resolve_entity_memories(
446 &self,
447 query: &str,
448 graph: &dyn codemem_core::GraphBackend,
449 now: chrono::DateTime<chrono::Utc>,
450 ) -> HashSet<String> {
451 let entity_refs = crate::search::extract_code_references(query);
452 let mut memory_ids: HashSet<String> = HashSet::new();
453
454 for entity_ref in &entity_refs {
455 let candidate_ids = [
457 format!("sym:{entity_ref}"),
458 format!("file:{entity_ref}"),
459 entity_ref.clone(),
460 ];
461
462 for candidate_id in &candidate_ids {
463 if graph.get_node_ref(candidate_id).is_none() {
464 continue;
465 }
466 for edge in graph.get_edges_ref(candidate_id) {
468 if !is_edge_active(edge, now) {
469 continue;
470 }
471 let neighbor_id = if edge.src == *candidate_id {
472 &edge.dst
473 } else {
474 &edge.src
475 };
476 if let Some(node) = graph.get_node_ref(neighbor_id) {
477 if node.kind == NodeKind::Memory {
478 let mem_id = node.memory_id.as_deref().unwrap_or(&node.id);
479 memory_ids.insert(mem_id.to_string());
480 }
481 }
482 }
483 break; }
485 }
486
487 memory_ids
488 }
489
490 pub fn namespace_stats(&self, namespace: &str) -> Result<NamespaceStats, CodememError> {
493 let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
494
495 if ids.is_empty() {
496 return Ok(NamespaceStats {
497 namespace: namespace.to_string(),
498 count: 0,
499 avg_importance: 0.0,
500 avg_confidence: 0.0,
501 type_distribution: HashMap::new(),
502 tag_frequency: HashMap::new(),
503 oldest: None,
504 newest: None,
505 });
506 }
507
508 let mut total_importance = 0.0;
509 let mut total_confidence = 0.0;
510 let mut type_distribution: HashMap<String, usize> = HashMap::new();
511 let mut tag_frequency: HashMap<String, usize> = HashMap::new();
512 let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
513 let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
514 let mut count = 0usize;
515
516 let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
520 let memories = self.storage.get_memories_batch(&id_refs)?;
521
522 for memory in &memories {
523 count += 1;
524 total_importance += memory.importance;
525 total_confidence += memory.confidence;
526
527 *type_distribution
528 .entry(memory.memory_type.to_string())
529 .or_insert(0) += 1;
530
531 for tag in &memory.tags {
532 *tag_frequency.entry(tag.clone()).or_insert(0) += 1;
533 }
534
535 match oldest {
536 None => oldest = Some(memory.created_at),
537 Some(ref o) if memory.created_at < *o => oldest = Some(memory.created_at),
538 _ => {}
539 }
540 match newest {
541 None => newest = Some(memory.created_at),
542 Some(ref n) if memory.created_at > *n => newest = Some(memory.created_at),
543 _ => {}
544 }
545 }
546
547 let avg_importance = if count > 0 {
548 total_importance / count as f64
549 } else {
550 0.0
551 };
552 let avg_confidence = if count > 0 {
553 total_confidence / count as f64
554 } else {
555 0.0
556 };
557
558 Ok(NamespaceStats {
559 namespace: namespace.to_string(),
560 count,
561 avg_importance,
562 avg_confidence,
563 type_distribution,
564 tag_frequency,
565 oldest,
566 newest,
567 })
568 }
569
570 pub fn delete_namespace(&self, namespace: &str) -> Result<usize, CodememError> {
573 let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
574
575 let mut deleted = 0usize;
576 let mut graph = self.lock_graph()?;
577 let mut vector = self.lock_vector()?;
578 let mut bm25 = self.lock_bm25()?;
579
580 for id in &ids {
581 if let Ok(true) = self.storage.delete_memory_cascade(id) {
583 deleted += 1;
584
585 let _ = vector.remove(id);
587 let _ = graph.remove_node(id);
588 bm25.remove_document(id);
589 }
590 }
591
592 drop(graph);
594 drop(vector);
595 drop(bm25);
596
597 self.save_index();
599
600 Ok(deleted)
601 }
602}
603
604pub(crate) fn is_edge_active(
609 edge: &codemem_core::Edge,
610 now: chrono::DateTime<chrono::Utc>,
611) -> bool {
612 if let Some(valid_to) = edge.valid_to {
613 if valid_to < now {
614 return false;
615 }
616 }
617 if let Some(valid_from) = edge.valid_from {
618 if valid_from > now {
619 return false;
620 }
621 }
622 true
623}