1use crate::scoring::compute_score;
4use crate::CodememEngine;
5use chrono::Utc;
6use codemem_core::{CodememError, MemoryNode, MemoryType, NodeKind, SearchResult};
7use std::collections::{HashMap, HashSet};
8
9#[derive(Debug, Clone)]
11pub struct ExpandedResult {
12 pub result: SearchResult,
13 pub expansion_path: String,
14}
15
16#[derive(Debug, Clone)]
18pub struct NamespaceStats {
19 pub namespace: String,
20 pub count: usize,
21 pub avg_importance: f64,
22 pub avg_confidence: f64,
23 pub type_distribution: HashMap<String, usize>,
24 pub tag_frequency: HashMap<String, usize>,
25 pub oldest: Option<chrono::DateTime<chrono::Utc>>,
26 pub newest: Option<chrono::DateTime<chrono::Utc>>,
27}
28
29#[derive(Debug, Clone)]
31pub struct RecallQuery<'a> {
32 pub query: &'a str,
33 pub k: usize,
34 pub memory_type_filter: Option<MemoryType>,
35 pub namespace_filter: Option<&'a str>,
36 pub exclude_tags: &'a [String],
37 pub min_importance: Option<f64>,
38 pub min_confidence: Option<f64>,
39 pub git_ref_filter: Option<&'a str>,
41}
42
43impl<'a> RecallQuery<'a> {
44 pub fn new(query: &'a str, k: usize) -> Self {
46 Self {
47 query,
48 k,
49 memory_type_filter: None,
50 namespace_filter: None,
51 exclude_tags: &[],
52 min_importance: None,
53 min_confidence: None,
54 git_ref_filter: None,
55 }
56 }
57}
58
59impl CodememEngine {
60 pub fn recall(&self, q: &RecallQuery<'_>) -> Result<Vec<SearchResult>, CodememError> {
67 self.sweep_expired_memories();
69
70 let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
72 match emb_guard.embed(q.query) {
73 Ok(query_embedding) => {
74 drop(emb_guard);
75 let vec = self.lock_vector()?;
76 vec.search(&query_embedding, q.k * 2) .unwrap_or_default()
78 }
79 Err(e) => {
80 tracing::warn!("Query embedding failed: {e}");
81 vec![]
82 }
83 }
84 } else {
85 vec![]
86 };
87
88 let query_tokens: Vec<String> = crate::bm25::tokenize(q.query);
92 let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
93
94 let mut graph = self.lock_graph()?;
98 graph.ensure_betweenness_computed();
101 let bm25 = self.lock_bm25()?;
102 let now = Utc::now();
103
104 let entity_memory_ids = self.resolve_entity_memories(q.query, &**graph, now);
108
109 let mut results: Vec<SearchResult> = Vec::new();
110 let weights = self.scoring_weights()?;
111
112 if vector_results.is_empty() {
113 let type_str = q.memory_type_filter.as_ref().map(|t| t.to_string());
115 let all_memories = self
116 .storage
117 .list_memories_filtered(q.namespace_filter, type_str.as_deref())?;
118
119 for memory in all_memories {
120 if !Self::passes_quality_filters(&memory, q) {
121 continue;
122 }
123
124 let breakdown =
125 compute_score(&memory, &query_token_refs, 0.0, &**graph, &bm25, now);
126 let score = breakdown.total_with_weights(&weights);
127 if score > 0.01 {
128 results.push(SearchResult {
129 memory,
130 score,
131 score_breakdown: breakdown,
132 });
133 }
134 }
135 } else {
136 let mut all_candidate_ids: HashSet<&str> =
138 vector_results.iter().map(|(id, _)| id.as_str()).collect();
139
140 for eid in &entity_memory_ids {
142 all_candidate_ids.insert(eid.as_str());
143 }
144
145 let candidate_id_vec: Vec<&str> = all_candidate_ids.into_iter().collect();
146 let candidate_memories = self.storage.get_memories_batch(&candidate_id_vec)?;
147
148 let sim_map: HashMap<&str, f64> = vector_results
150 .iter()
151 .map(|(id, sim)| (id.as_str(), *sim as f64))
152 .collect();
153
154 for memory in candidate_memories {
155 if let Some(ref filter_type) = q.memory_type_filter {
157 if memory.memory_type != *filter_type {
158 continue;
159 }
160 }
161 if let Some(ns) = q.namespace_filter {
163 if memory.namespace.as_deref() != Some(ns) {
164 continue;
165 }
166 }
167 if !Self::passes_quality_filters(&memory, q) {
168 continue;
169 }
170
171 let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
172 let breakdown =
173 compute_score(&memory, &query_token_refs, similarity, &**graph, &bm25, now);
174 let score = breakdown.total_with_weights(&weights);
175 if score > 0.01 {
176 results.push(SearchResult {
177 memory,
178 score,
179 score_breakdown: breakdown,
180 });
181 }
182 }
183 }
184
185 results.sort_by(|a, b| {
187 b.score
188 .partial_cmp(&a.score)
189 .unwrap_or(std::cmp::Ordering::Equal)
190 });
191 results.truncate(q.k);
192
193 Ok(results)
194 }
195
196 fn passes_quality_filters(memory: &MemoryNode, q: &RecallQuery<'_>) -> bool {
198 if memory.expires_at.is_some_and(|dt| dt <= Utc::now()) {
200 return false;
201 }
202 if !q.exclude_tags.is_empty() && memory.tags.iter().any(|t| q.exclude_tags.contains(t)) {
203 return false;
204 }
205 if let Some(min) = q.min_importance {
206 if memory.importance < min {
207 return false;
208 }
209 }
210 if let Some(min) = q.min_confidence {
211 if memory.confidence < min {
212 return false;
213 }
214 }
215 if let Some(ref_filter) = q.git_ref_filter {
216 if memory.git_ref.as_deref() != Some(ref_filter) {
217 return false;
218 }
219 }
220 true
221 }
222
223 pub fn recall_with_expansion(
227 &self,
228 query: &str,
229 k: usize,
230 expansion_depth: usize,
231 namespace_filter: Option<&str>,
232 ) -> Result<Vec<ExpandedResult>, CodememError> {
233 self.sweep_expired_memories();
235
236 let query_tokens: Vec<String> = crate::bm25::tokenize(query);
238 let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
239
240 let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
242 match emb_guard.embed(query) {
243 Ok(query_embedding) => {
244 drop(emb_guard);
245 let vec = self.lock_vector()?;
246 vec.search(&query_embedding, k * 2).unwrap_or_default()
247 }
248 Err(e) => {
249 tracing::warn!("Query embedding failed: {e}");
250 vec![]
251 }
252 }
253 } else {
254 vec![]
255 };
256
257 let mut graph = self.lock_graph()?;
258 graph.ensure_betweenness_computed();
260 let bm25 = self.lock_bm25()?;
261 let now = Utc::now();
262
263 struct ScoredMemory {
265 memory: MemoryNode,
266 vector_sim: f64,
267 expansion_path: String,
268 }
269
270 let mut all_memories: Vec<ScoredMemory> = Vec::new();
271 let mut seen_ids: HashSet<String> = HashSet::new();
272
273 if vector_results.is_empty() {
274 let all = self
276 .storage
277 .list_memories_filtered(namespace_filter, None)?;
278 let weights = self.scoring_weights()?;
279
280 for memory in all {
281 if memory.expires_at.is_some_and(|dt| dt <= now) {
282 continue;
283 }
284 let breakdown =
285 compute_score(&memory, &query_token_refs, 0.0, &**graph, &bm25, now);
286 let score = breakdown.total_with_weights(&weights);
287 if score > 0.01 {
288 seen_ids.insert(memory.id.clone());
289 all_memories.push(ScoredMemory {
290 memory,
291 vector_sim: 0.0,
292 expansion_path: "direct".to_string(),
293 });
294 }
295 }
296 } else {
297 let candidate_ids: Vec<&str> =
299 vector_results.iter().map(|(id, _)| id.as_str()).collect();
300 let candidate_memories = self.storage.get_memories_batch(&candidate_ids)?;
301
302 let sim_map: HashMap<&str, f64> = vector_results
303 .iter()
304 .map(|(id, sim)| (id.as_str(), *sim as f64))
305 .collect();
306
307 for memory in candidate_memories {
308 if memory.expires_at.is_some_and(|dt| dt <= now) {
309 continue;
310 }
311 if let Some(ns) = namespace_filter {
312 if memory.namespace.as_deref() != Some(ns) {
313 continue;
314 }
315 }
316 let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
317 seen_ids.insert(memory.id.clone());
318 all_memories.push(ScoredMemory {
319 memory,
320 vector_sim: similarity,
321 expansion_path: "direct".to_string(),
322 });
323 }
324 }
325
326 let direct_ids: Vec<String> = all_memories.iter().map(|m| m.memory.id.clone()).collect();
331
332 for direct_id in &direct_ids {
333 let direct_edges: Vec<_> = graph
336 .get_edges(direct_id)
337 .unwrap_or_default()
338 .into_iter()
339 .filter(|e| is_edge_active(e, now))
340 .collect();
341
342 if let Ok(expanded_nodes) =
345 graph.bfs_filtered(direct_id, expansion_depth, &[NodeKind::Chunk], None)
346 {
347 for expanded_node in &expanded_nodes {
348 if expanded_node.id == *direct_id {
350 continue;
351 }
352
353 if expanded_node.kind != NodeKind::Memory {
356 continue;
357 }
358
359 let memory_id = expanded_node
361 .memory_id
362 .as_deref()
363 .unwrap_or(&expanded_node.id);
364
365 if seen_ids.contains(memory_id) {
367 continue;
368 }
369
370 if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
372 if memory.expires_at.is_some_and(|dt| dt <= now) {
373 continue;
374 }
375 if let Some(ns) = namespace_filter {
376 if memory.namespace.as_deref() != Some(ns) {
377 continue;
378 }
379 }
380
381 let expansion_path = direct_edges
383 .iter()
384 .find(|e| e.dst == expanded_node.id || e.src == expanded_node.id)
385 .map(|e| format!("via {} from {}", e.relationship, direct_id))
386 .unwrap_or_else(|| format!("via graph from {direct_id}"));
387
388 seen_ids.insert(memory_id.to_string());
389 all_memories.push(ScoredMemory {
390 memory,
391 vector_sim: 0.0,
392 expansion_path,
393 });
394 }
395 }
396 }
397 }
398
399 let weights = self.scoring_weights()?;
401 let mut scored_results: Vec<ExpandedResult> = all_memories
402 .into_iter()
403 .map(|sm| {
404 let breakdown = compute_score(
405 &sm.memory,
406 &query_token_refs,
407 sm.vector_sim,
408 &**graph,
409 &bm25,
410 now,
411 );
412 let score = breakdown.total_with_weights(&weights);
413 ExpandedResult {
414 result: SearchResult {
415 memory: sm.memory,
416 score,
417 score_breakdown: breakdown,
418 },
419 expansion_path: sm.expansion_path,
420 }
421 })
422 .collect();
423
424 scored_results.sort_by(|a, b| {
425 b.result
426 .score
427 .partial_cmp(&a.result.score)
428 .unwrap_or(std::cmp::Ordering::Equal)
429 });
430 scored_results.truncate(k);
431
432 Ok(scored_results)
433 }
434
435 pub(crate) fn resolve_entity_memories(
442 &self,
443 query: &str,
444 graph: &dyn codemem_core::GraphBackend,
445 now: chrono::DateTime<chrono::Utc>,
446 ) -> HashSet<String> {
447 let entity_refs = crate::search::extract_code_references(query);
448 let mut memory_ids: HashSet<String> = HashSet::new();
449
450 for entity_ref in &entity_refs {
451 let candidate_ids = [
453 format!("sym:{entity_ref}"),
454 format!("file:{entity_ref}"),
455 entity_ref.clone(),
456 ];
457
458 for candidate_id in &candidate_ids {
459 if graph.get_node_ref(candidate_id).is_none() {
460 continue;
461 }
462 for edge in graph.get_edges_ref(candidate_id) {
464 if !is_edge_active(edge, now) {
465 continue;
466 }
467 let neighbor_id = if edge.src == *candidate_id {
468 &edge.dst
469 } else {
470 &edge.src
471 };
472 if let Some(node) = graph.get_node_ref(neighbor_id) {
473 if node.kind == NodeKind::Memory {
474 let mem_id = node.memory_id.as_deref().unwrap_or(&node.id);
475 memory_ids.insert(mem_id.to_string());
476 }
477 }
478 }
479 break; }
481 }
482
483 memory_ids
484 }
485
486 pub fn namespace_stats(&self, namespace: &str) -> Result<NamespaceStats, CodememError> {
489 let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
490
491 if ids.is_empty() {
492 return Ok(NamespaceStats {
493 namespace: namespace.to_string(),
494 count: 0,
495 avg_importance: 0.0,
496 avg_confidence: 0.0,
497 type_distribution: HashMap::new(),
498 tag_frequency: HashMap::new(),
499 oldest: None,
500 newest: None,
501 });
502 }
503
504 let mut total_importance = 0.0;
505 let mut total_confidence = 0.0;
506 let mut type_distribution: HashMap<String, usize> = HashMap::new();
507 let mut tag_frequency: HashMap<String, usize> = HashMap::new();
508 let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
509 let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
510 let mut count = 0usize;
511
512 let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
516 let memories = self.storage.get_memories_batch(&id_refs)?;
517
518 for memory in &memories {
519 count += 1;
520 total_importance += memory.importance;
521 total_confidence += memory.confidence;
522
523 *type_distribution
524 .entry(memory.memory_type.to_string())
525 .or_insert(0) += 1;
526
527 for tag in &memory.tags {
528 *tag_frequency.entry(tag.clone()).or_insert(0) += 1;
529 }
530
531 match oldest {
532 None => oldest = Some(memory.created_at),
533 Some(ref o) if memory.created_at < *o => oldest = Some(memory.created_at),
534 _ => {}
535 }
536 match newest {
537 None => newest = Some(memory.created_at),
538 Some(ref n) if memory.created_at > *n => newest = Some(memory.created_at),
539 _ => {}
540 }
541 }
542
543 let avg_importance = if count > 0 {
544 total_importance / count as f64
545 } else {
546 0.0
547 };
548 let avg_confidence = if count > 0 {
549 total_confidence / count as f64
550 } else {
551 0.0
552 };
553
554 Ok(NamespaceStats {
555 namespace: namespace.to_string(),
556 count,
557 avg_importance,
558 avg_confidence,
559 type_distribution,
560 tag_frequency,
561 oldest,
562 newest,
563 })
564 }
565
566 pub fn delete_namespace(&self, namespace: &str) -> Result<usize, CodememError> {
569 let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
570
571 let mut deleted = 0usize;
572 let mut graph = self.lock_graph()?;
573 let mut vector = self.lock_vector()?;
574 let mut bm25 = self.lock_bm25()?;
575
576 for id in &ids {
577 if let Ok(true) = self.storage.delete_memory_cascade(id) {
579 deleted += 1;
580
581 let _ = vector.remove(id);
583 let _ = graph.remove_node(id);
584 bm25.remove_document(id);
585 }
586 }
587
588 drop(graph);
590 drop(vector);
591 drop(bm25);
592
593 self.save_index();
595
596 Ok(deleted)
597 }
598}
599
600pub(crate) fn is_edge_active(
605 edge: &codemem_core::Edge,
606 now: chrono::DateTime<chrono::Utc>,
607) -> bool {
608 if let Some(valid_to) = edge.valid_to {
609 if valid_to < now {
610 return false;
611 }
612 }
613 if let Some(valid_from) = edge.valid_from {
614 if valid_from > now {
615 return false;
616 }
617 }
618 true
619}