1use crate::scoring::compute_score;
4use crate::CodememEngine;
5use chrono::Utc;
6use codemem_core::{
7 CodememError, GraphBackend, MemoryNode, MemoryType, NodeKind, SearchResult, VectorBackend,
8};
9use std::collections::{HashMap, HashSet};
10
11#[derive(Debug, Clone)]
13pub struct ExpandedResult {
14 pub result: SearchResult,
15 pub expansion_path: String,
16}
17
18#[derive(Debug, Clone)]
20pub struct NamespaceStats {
21 pub namespace: String,
22 pub count: usize,
23 pub avg_importance: f64,
24 pub avg_confidence: f64,
25 pub type_distribution: HashMap<String, usize>,
26 pub tag_frequency: HashMap<String, usize>,
27 pub oldest: Option<chrono::DateTime<chrono::Utc>>,
28 pub newest: Option<chrono::DateTime<chrono::Utc>>,
29}
30
31#[derive(Debug, Clone)]
33pub struct RecallQuery<'a> {
34 pub query: &'a str,
35 pub k: usize,
36 pub memory_type_filter: Option<MemoryType>,
37 pub namespace_filter: Option<&'a str>,
38 pub exclude_tags: &'a [String],
39 pub min_importance: Option<f64>,
40 pub min_confidence: Option<f64>,
41}
42
43impl<'a> RecallQuery<'a> {
44 pub fn new(query: &'a str, k: usize) -> Self {
46 Self {
47 query,
48 k,
49 memory_type_filter: None,
50 namespace_filter: None,
51 exclude_tags: &[],
52 min_importance: None,
53 min_confidence: None,
54 }
55 }
56}
57
58impl CodememEngine {
59 pub fn recall(&self, q: &RecallQuery<'_>) -> Result<Vec<SearchResult>, CodememError> {
66 let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
68 match emb_guard.embed(q.query) {
69 Ok(query_embedding) => {
70 drop(emb_guard);
71 let vec = self.lock_vector()?;
72 vec.search(&query_embedding, q.k * 2) .unwrap_or_default()
74 }
75 Err(e) => {
76 tracing::warn!("Query embedding failed: {e}");
77 vec![]
78 }
79 }
80 } else {
81 vec![]
82 };
83
84 let query_tokens: Vec<String> = crate::bm25::tokenize(q.query);
88 let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
89
90 let mut graph = self.lock_graph()?;
94 graph.ensure_betweenness_computed();
97 let bm25 = self.lock_bm25()?;
98 let now = Utc::now();
99
100 let entity_memory_ids = self.resolve_entity_memories(q.query, &graph, now);
104
105 let mut results: Vec<SearchResult> = Vec::new();
106 let weights = self.scoring_weights()?;
107
108 if vector_results.is_empty() {
109 let type_str = q.memory_type_filter.as_ref().map(|t| t.to_string());
111 let all_memories = self
112 .storage
113 .list_memories_filtered(q.namespace_filter, type_str.as_deref())?;
114
115 for memory in all_memories {
116 if !Self::passes_quality_filters(&memory, q) {
117 continue;
118 }
119
120 let breakdown = compute_score(&memory, &query_token_refs, 0.0, &graph, &bm25, now);
121 let score = breakdown.total_with_weights(&weights);
122 if score > 0.01 {
123 results.push(SearchResult {
124 memory,
125 score,
126 score_breakdown: breakdown,
127 });
128 }
129 }
130 } else {
131 let mut all_candidate_ids: HashSet<&str> =
133 vector_results.iter().map(|(id, _)| id.as_str()).collect();
134
135 for eid in &entity_memory_ids {
137 all_candidate_ids.insert(eid.as_str());
138 }
139
140 let candidate_id_vec: Vec<&str> = all_candidate_ids.into_iter().collect();
141 let candidate_memories = self.storage.get_memories_batch(&candidate_id_vec)?;
142
143 let sim_map: HashMap<&str, f64> = vector_results
145 .iter()
146 .map(|(id, sim)| (id.as_str(), *sim as f64))
147 .collect();
148
149 for memory in candidate_memories {
150 if let Some(ref filter_type) = q.memory_type_filter {
152 if memory.memory_type != *filter_type {
153 continue;
154 }
155 }
156 if let Some(ns) = q.namespace_filter {
158 if memory.namespace.as_deref() != Some(ns) {
159 continue;
160 }
161 }
162 if !Self::passes_quality_filters(&memory, q) {
163 continue;
164 }
165
166 let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
167 let breakdown =
168 compute_score(&memory, &query_token_refs, similarity, &graph, &bm25, now);
169 let score = breakdown.total_with_weights(&weights);
170 if score > 0.01 {
171 results.push(SearchResult {
172 memory,
173 score,
174 score_breakdown: breakdown,
175 });
176 }
177 }
178 }
179
180 results.sort_by(|a, b| {
182 b.score
183 .partial_cmp(&a.score)
184 .unwrap_or(std::cmp::Ordering::Equal)
185 });
186 results.truncate(q.k);
187
188 Ok(results)
189 }
190
191 fn passes_quality_filters(memory: &MemoryNode, q: &RecallQuery<'_>) -> bool {
193 if !q.exclude_tags.is_empty() && memory.tags.iter().any(|t| q.exclude_tags.contains(t)) {
194 return false;
195 }
196 if let Some(min) = q.min_importance {
197 if memory.importance < min {
198 return false;
199 }
200 }
201 if let Some(min) = q.min_confidence {
202 if memory.confidence < min {
203 return false;
204 }
205 }
206 true
207 }
208
209 pub fn recall_with_expansion(
213 &self,
214 query: &str,
215 k: usize,
216 expansion_depth: usize,
217 namespace_filter: Option<&str>,
218 ) -> Result<Vec<ExpandedResult>, CodememError> {
219 let query_tokens: Vec<String> = crate::bm25::tokenize(query);
221 let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
222
223 let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
225 match emb_guard.embed(query) {
226 Ok(query_embedding) => {
227 drop(emb_guard);
228 let vec = self.lock_vector()?;
229 vec.search(&query_embedding, k * 2).unwrap_or_default()
230 }
231 Err(e) => {
232 tracing::warn!("Query embedding failed: {e}");
233 vec![]
234 }
235 }
236 } else {
237 vec![]
238 };
239
240 let mut graph = self.lock_graph()?;
241 graph.ensure_betweenness_computed();
243 let bm25 = self.lock_bm25()?;
244 let now = Utc::now();
245
246 struct ScoredMemory {
248 memory: MemoryNode,
249 vector_sim: f64,
250 expansion_path: String,
251 }
252
253 let mut all_memories: Vec<ScoredMemory> = Vec::new();
254 let mut seen_ids: HashSet<String> = HashSet::new();
255
256 if vector_results.is_empty() {
257 let all = self
259 .storage
260 .list_memories_filtered(namespace_filter, None)?;
261 let weights = self.scoring_weights()?;
262
263 for memory in all {
264 let breakdown = compute_score(&memory, &query_token_refs, 0.0, &graph, &bm25, now);
265 let score = breakdown.total_with_weights(&weights);
266 if score > 0.01 {
267 seen_ids.insert(memory.id.clone());
268 all_memories.push(ScoredMemory {
269 memory,
270 vector_sim: 0.0,
271 expansion_path: "direct".to_string(),
272 });
273 }
274 }
275 } else {
276 let candidate_ids: Vec<&str> =
278 vector_results.iter().map(|(id, _)| id.as_str()).collect();
279 let candidate_memories = self.storage.get_memories_batch(&candidate_ids)?;
280
281 let sim_map: HashMap<&str, f64> = vector_results
282 .iter()
283 .map(|(id, sim)| (id.as_str(), *sim as f64))
284 .collect();
285
286 for memory in candidate_memories {
287 if let Some(ns) = namespace_filter {
288 if memory.namespace.as_deref() != Some(ns) {
289 continue;
290 }
291 }
292 let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
293 seen_ids.insert(memory.id.clone());
294 all_memories.push(ScoredMemory {
295 memory,
296 vector_sim: similarity,
297 expansion_path: "direct".to_string(),
298 });
299 }
300 }
301
302 let direct_ids: Vec<String> = all_memories.iter().map(|m| m.memory.id.clone()).collect();
307
308 for direct_id in &direct_ids {
309 let direct_edges: Vec<_> = graph
312 .get_edges(direct_id)
313 .unwrap_or_default()
314 .into_iter()
315 .filter(|e| is_edge_active(e, now))
316 .collect();
317
318 if let Ok(expanded_nodes) =
321 graph.bfs_filtered(direct_id, expansion_depth, &[NodeKind::Chunk], None)
322 {
323 for expanded_node in &expanded_nodes {
324 if expanded_node.id == *direct_id {
326 continue;
327 }
328
329 if expanded_node.kind != NodeKind::Memory {
332 continue;
333 }
334
335 let memory_id = expanded_node
337 .memory_id
338 .as_deref()
339 .unwrap_or(&expanded_node.id);
340
341 if seen_ids.contains(memory_id) {
343 continue;
344 }
345
346 if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
348 if let Some(ns) = namespace_filter {
349 if memory.namespace.as_deref() != Some(ns) {
350 continue;
351 }
352 }
353
354 let expansion_path = direct_edges
356 .iter()
357 .find(|e| e.dst == expanded_node.id || e.src == expanded_node.id)
358 .map(|e| format!("via {} from {}", e.relationship, direct_id))
359 .unwrap_or_else(|| format!("via graph from {direct_id}"));
360
361 seen_ids.insert(memory_id.to_string());
362 all_memories.push(ScoredMemory {
363 memory,
364 vector_sim: 0.0,
365 expansion_path,
366 });
367 }
368 }
369 }
370 }
371
372 let weights = self.scoring_weights()?;
374 let mut scored_results: Vec<ExpandedResult> = all_memories
375 .into_iter()
376 .map(|sm| {
377 let breakdown = compute_score(
378 &sm.memory,
379 &query_token_refs,
380 sm.vector_sim,
381 &graph,
382 &bm25,
383 now,
384 );
385 let score = breakdown.total_with_weights(&weights);
386 ExpandedResult {
387 result: SearchResult {
388 memory: sm.memory,
389 score,
390 score_breakdown: breakdown,
391 },
392 expansion_path: sm.expansion_path,
393 }
394 })
395 .collect();
396
397 scored_results.sort_by(|a, b| {
398 b.result
399 .score
400 .partial_cmp(&a.result.score)
401 .unwrap_or(std::cmp::Ordering::Equal)
402 });
403 scored_results.truncate(k);
404
405 Ok(scored_results)
406 }
407
408 pub(crate) fn resolve_entity_memories(
415 &self,
416 query: &str,
417 graph: &codemem_storage::graph::GraphEngine,
418 now: chrono::DateTime<chrono::Utc>,
419 ) -> HashSet<String> {
420 let entity_refs = crate::search::extract_code_references(query);
421 let mut memory_ids: HashSet<String> = HashSet::new();
422
423 for entity_ref in &entity_refs {
424 let candidate_ids = [
426 format!("sym:{entity_ref}"),
427 format!("file:{entity_ref}"),
428 entity_ref.clone(),
429 ];
430
431 for candidate_id in &candidate_ids {
432 if graph.get_node_ref(candidate_id).is_none() {
433 continue;
434 }
435 for edge in graph.get_edges_ref(candidate_id) {
437 if !is_edge_active(edge, now) {
438 continue;
439 }
440 let neighbor_id = if edge.src == *candidate_id {
441 &edge.dst
442 } else {
443 &edge.src
444 };
445 if let Some(node) = graph.get_node_ref(neighbor_id) {
446 if node.kind == NodeKind::Memory {
447 let mem_id = node.memory_id.as_deref().unwrap_or(&node.id);
448 memory_ids.insert(mem_id.to_string());
449 }
450 }
451 }
452 break; }
454 }
455
456 memory_ids
457 }
458
459 pub fn namespace_stats(&self, namespace: &str) -> Result<NamespaceStats, CodememError> {
462 let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
463
464 if ids.is_empty() {
465 return Ok(NamespaceStats {
466 namespace: namespace.to_string(),
467 count: 0,
468 avg_importance: 0.0,
469 avg_confidence: 0.0,
470 type_distribution: HashMap::new(),
471 tag_frequency: HashMap::new(),
472 oldest: None,
473 newest: None,
474 });
475 }
476
477 let mut total_importance = 0.0;
478 let mut total_confidence = 0.0;
479 let mut type_distribution: HashMap<String, usize> = HashMap::new();
480 let mut tag_frequency: HashMap<String, usize> = HashMap::new();
481 let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
482 let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
483 let mut count = 0usize;
484
485 let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
489 let memories = self.storage.get_memories_batch(&id_refs)?;
490
491 for memory in &memories {
492 count += 1;
493 total_importance += memory.importance;
494 total_confidence += memory.confidence;
495
496 *type_distribution
497 .entry(memory.memory_type.to_string())
498 .or_insert(0) += 1;
499
500 for tag in &memory.tags {
501 *tag_frequency.entry(tag.clone()).or_insert(0) += 1;
502 }
503
504 match oldest {
505 None => oldest = Some(memory.created_at),
506 Some(ref o) if memory.created_at < *o => oldest = Some(memory.created_at),
507 _ => {}
508 }
509 match newest {
510 None => newest = Some(memory.created_at),
511 Some(ref n) if memory.created_at > *n => newest = Some(memory.created_at),
512 _ => {}
513 }
514 }
515
516 let avg_importance = if count > 0 {
517 total_importance / count as f64
518 } else {
519 0.0
520 };
521 let avg_confidence = if count > 0 {
522 total_confidence / count as f64
523 } else {
524 0.0
525 };
526
527 Ok(NamespaceStats {
528 namespace: namespace.to_string(),
529 count,
530 avg_importance,
531 avg_confidence,
532 type_distribution,
533 tag_frequency,
534 oldest,
535 newest,
536 })
537 }
538
539 pub fn delete_namespace(&self, namespace: &str) -> Result<usize, CodememError> {
542 let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
543
544 let mut deleted = 0usize;
545 let mut graph = self.lock_graph()?;
546 let mut vector = self.lock_vector()?;
547 let mut bm25 = self.lock_bm25()?;
548
549 for id in &ids {
550 if let Ok(true) = self.storage.delete_memory_cascade(id) {
552 deleted += 1;
553
554 let _ = vector.remove(id);
556 let _ = graph.remove_node(id);
557 bm25.remove_document(id);
558 }
559 }
560
561 drop(graph);
563 drop(vector);
564 drop(bm25);
565
566 self.save_index();
568
569 Ok(deleted)
570 }
571}
572
573pub(crate) fn is_edge_active(
578 edge: &codemem_core::Edge,
579 now: chrono::DateTime<chrono::Utc>,
580) -> bool {
581 if let Some(valid_to) = edge.valid_to {
582 if valid_to < now {
583 return false;
584 }
585 }
586 if let Some(valid_from) = edge.valid_from {
587 if valid_from > now {
588 return false;
589 }
590 }
591 true
592}