//! Memory recall with hybrid scoring.
use chrono::Utc;
use codemem_core::{CodememError, MemoryNode, MemoryType, NodeKind, SearchResult};
use std::collections::{HashMap, HashSet};

use crate::scoring::compute_score;
use crate::CodememEngine;
9/// A recall result that includes the expansion path taken to reach the memory.
10#[derive(Debug, Clone)]
11pub struct ExpandedResult {
12    pub result: SearchResult,
13    pub expansion_path: String,
14}
15
16/// Aggregated stats for a single namespace.
17#[derive(Debug, Clone)]
18pub struct NamespaceStats {
19    pub namespace: String,
20    pub count: usize,
21    pub avg_importance: f64,
22    pub avg_confidence: f64,
23    pub type_distribution: HashMap<String, usize>,
24    pub tag_frequency: HashMap<String, usize>,
25    pub oldest: Option<chrono::DateTime<chrono::Utc>>,
26    pub newest: Option<chrono::DateTime<chrono::Utc>>,
27}
28
29/// Parameters for the recall query.
30#[derive(Debug, Clone)]
31pub struct RecallQuery<'a> {
32    pub query: &'a str,
33    pub k: usize,
34    pub memory_type_filter: Option<MemoryType>,
35    pub namespace_filter: Option<&'a str>,
36    pub exclude_tags: &'a [String],
37    pub min_importance: Option<f64>,
38    pub min_confidence: Option<f64>,
39    /// Filter results to memories with this git ref (branch/tag).
40    pub git_ref_filter: Option<&'a str>,
41}
42
43impl<'a> RecallQuery<'a> {
44    /// Create a minimal recall query with just the search text and result limit.
45    pub fn new(query: &'a str, k: usize) -> Self {
46        Self {
47            query,
48            k,
49            memory_type_filter: None,
50            namespace_filter: None,
51            exclude_tags: &[],
52            min_importance: None,
53            min_confidence: None,
54            git_ref_filter: None,
55        }
56    }
57}
58
59impl CodememEngine {
60    /// Core recall logic: search storage with hybrid scoring and return ranked results.
61    ///
62    /// Combines vector search (if embeddings available), BM25, graph strength,
63    /// temporal, tag matching, importance, confidence, and recency into a
64    /// 9-component hybrid score. Supports filtering by memory type, namespace,
65    /// tag exclusion, and minimum importance/confidence thresholds.
66    pub fn recall(&self, q: &RecallQuery<'_>) -> Result<Vec<SearchResult>, CodememError> {
67        // Opportunistic cleanup of expired memories (rate-limited to once per 60s)
68        self.sweep_expired_memories();
69
70        // Try vector search first (if embeddings available)
71        let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
72            match emb_guard.embed(q.query) {
73                Ok(query_embedding) => {
74                    drop(emb_guard);
75                    let vec = self.lock_vector()?;
76                    vec.search(&query_embedding, q.k * 2) // over-fetch for re-ranking
77                        .unwrap_or_default()
78                }
79                Err(e) => {
80                    tracing::warn!("Query embedding failed: {e}");
81                    vec![]
82                }
83            }
84        } else {
85            vec![]
86        };
87
88        // H1: Use code-aware tokenizer for query tokens so that compound identifiers
89        // like "parseFunction" are split into ["parse", "function"] — matching the
90        // tokenization used when documents were added to the BM25 index.
91        let query_tokens: Vec<String> = crate::bm25::tokenize(q.query);
92        let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
93
94        // Graph and BM25 intentionally load different data: graph stores structural relationships
95        // (nodes/edges), while BM25 indexes memory content for text search. This is by design,
96        // not duplication.
97        let mut graph = self.lock_graph()?;
98        // C1: Lazily compute betweenness centrality before scoring so the
99        // betweenness component (30% of graph_strength) is not permanently zero.
100        graph.ensure_betweenness_computed();
101        let bm25 = self.lock_bm25()?;
102        let now = Utc::now();
103
104        // Entity expansion: find memories connected to code entities mentioned in the query.
105        // This ensures that structurally related memories are candidates even when they are
106        // semantically distant from the query text.
107        let entity_memory_ids = self.resolve_entity_memories(q.query, &**graph, now);
108
109        let mut results: Vec<SearchResult> = Vec::new();
110        let weights = self.scoring_weights()?;
111
112        if vector_results.is_empty() {
113            // Fallback: batch-load all memories matching filters in one query
114            let type_str = q.memory_type_filter.as_ref().map(|t| t.to_string());
115            let all_memories = self
116                .storage
117                .list_memories_filtered(q.namespace_filter, type_str.as_deref())?;
118
119            for memory in all_memories {
120                if !Self::passes_quality_filters(&memory, q) {
121                    continue;
122                }
123
124                let breakdown =
125                    compute_score(&memory, &query_token_refs, 0.0, &**graph, &bm25, now);
126                let score = breakdown.total_with_weights(&weights);
127                if score > 0.01 {
128                    results.push(SearchResult {
129                        memory,
130                        score,
131                        score_breakdown: breakdown,
132                    });
133                }
134            }
135        } else {
136            // Vector search path: batch-fetch all candidate memories + entity-connected memories
137            let mut all_candidate_ids: HashSet<&str> =
138                vector_results.iter().map(|(id, _)| id.as_str()).collect();
139
140            // Merge entity-connected memory IDs into the candidate pool
141            for eid in &entity_memory_ids {
142                all_candidate_ids.insert(eid.as_str());
143            }
144
145            let candidate_id_vec: Vec<&str> = all_candidate_ids.into_iter().collect();
146            let candidate_memories = self.storage.get_memories_batch(&candidate_id_vec)?;
147
148            // Build similarity lookup (entity memories will get 0.0 similarity)
149            let sim_map: HashMap<&str, f64> = vector_results
150                .iter()
151                .map(|(id, sim)| (id.as_str(), *sim as f64))
152                .collect();
153
154            for memory in candidate_memories {
155                // Apply memory_type filter
156                if let Some(ref filter_type) = q.memory_type_filter {
157                    if memory.memory_type != *filter_type {
158                        continue;
159                    }
160                }
161                // Apply namespace filter
162                if let Some(ns) = q.namespace_filter {
163                    if memory.namespace.as_deref() != Some(ns) {
164                        continue;
165                    }
166                }
167                if !Self::passes_quality_filters(&memory, q) {
168                    continue;
169                }
170
171                let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
172                let breakdown =
173                    compute_score(&memory, &query_token_refs, similarity, &**graph, &bm25, now);
174                let score = breakdown.total_with_weights(&weights);
175                if score > 0.01 {
176                    results.push(SearchResult {
177                        memory,
178                        score,
179                        score_breakdown: breakdown,
180                    });
181                }
182            }
183        }
184
185        // Sort by score descending, take top k
186        results.sort_by(|a, b| {
187            b.score
188                .partial_cmp(&a.score)
189                .unwrap_or(std::cmp::Ordering::Equal)
190        });
191        results.truncate(q.k);
192
193        Ok(results)
194    }
195
196    /// Check expiry, exclude_tags, min_importance, min_confidence, and git_ref filters.
197    fn passes_quality_filters(memory: &MemoryNode, q: &RecallQuery<'_>) -> bool {
198        // Skip expired memories (their embeddings may linger in HNSW until next sweep)
199        if memory.expires_at.is_some_and(|dt| dt <= Utc::now()) {
200            return false;
201        }
202        if !q.exclude_tags.is_empty() && memory.tags.iter().any(|t| q.exclude_tags.contains(t)) {
203            return false;
204        }
205        if let Some(min) = q.min_importance {
206            if memory.importance < min {
207                return false;
208            }
209        }
210        if let Some(min) = q.min_confidence {
211            if memory.confidence < min {
212                return false;
213            }
214        }
215        if let Some(ref_filter) = q.git_ref_filter {
216            if memory.git_ref.as_deref() != Some(ref_filter) {
217                return false;
218            }
219        }
220        true
221    }
222
223    /// Recall with graph expansion: vector search (or BM25 fallback) for seed
224    /// memories, then BFS expansion from each seed through the graph, scoring
225    /// all candidates with the 9-component hybrid scorer.
226    pub fn recall_with_expansion(
227        &self,
228        query: &str,
229        k: usize,
230        expansion_depth: usize,
231        namespace_filter: Option<&str>,
232    ) -> Result<Vec<ExpandedResult>, CodememError> {
233        // Opportunistic cleanup of expired memories (rate-limited to once per 60s)
234        self.sweep_expired_memories();
235
236        // H1: Code-aware tokenization for consistent BM25 scoring
237        let query_tokens: Vec<String> = crate::bm25::tokenize(query);
238        let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
239
240        // Step 1: Run normal vector search (or text fallback)
241        let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
242            match emb_guard.embed(query) {
243                Ok(query_embedding) => {
244                    drop(emb_guard);
245                    let vec = self.lock_vector()?;
246                    vec.search(&query_embedding, k * 2).unwrap_or_default()
247                }
248                Err(e) => {
249                    tracing::warn!("Query embedding failed: {e}");
250                    vec![]
251                }
252            }
253        } else {
254            vec![]
255        };
256
257        let mut graph = self.lock_graph()?;
258        // C1: Lazily compute betweenness centrality before scoring
259        graph.ensure_betweenness_computed();
260        let bm25 = self.lock_bm25()?;
261        let now = Utc::now();
262
263        // Collect initial seed memories with their vector similarity
264        struct ScoredMemory {
265            memory: MemoryNode,
266            vector_sim: f64,
267            expansion_path: String,
268        }
269
270        let mut all_memories: Vec<ScoredMemory> = Vec::new();
271        let mut seen_ids: HashSet<String> = HashSet::new();
272
273        if vector_results.is_empty() {
274            // Fallback: batch-load all memories matching namespace in one query
275            let all = self
276                .storage
277                .list_memories_filtered(namespace_filter, None)?;
278            let weights = self.scoring_weights()?;
279
280            for memory in all {
281                if memory.expires_at.is_some_and(|dt| dt <= now) {
282                    continue;
283                }
284                let breakdown =
285                    compute_score(&memory, &query_token_refs, 0.0, &**graph, &bm25, now);
286                let score = breakdown.total_with_weights(&weights);
287                if score > 0.01 {
288                    seen_ids.insert(memory.id.clone());
289                    all_memories.push(ScoredMemory {
290                        memory,
291                        vector_sim: 0.0,
292                        expansion_path: "direct".to_string(),
293                    });
294                }
295            }
296        } else {
297            // Vector search path: batch-fetch all candidate memories
298            let candidate_ids: Vec<&str> =
299                vector_results.iter().map(|(id, _)| id.as_str()).collect();
300            let candidate_memories = self.storage.get_memories_batch(&candidate_ids)?;
301
302            let sim_map: HashMap<&str, f64> = vector_results
303                .iter()
304                .map(|(id, sim)| (id.as_str(), *sim as f64))
305                .collect();
306
307            for memory in candidate_memories {
308                if memory.expires_at.is_some_and(|dt| dt <= now) {
309                    continue;
310                }
311                if let Some(ns) = namespace_filter {
312                    if memory.namespace.as_deref() != Some(ns) {
313                        continue;
314                    }
315                }
316                let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
317                seen_ids.insert(memory.id.clone());
318                all_memories.push(ScoredMemory {
319                    memory,
320                    vector_sim: similarity,
321                    expansion_path: "direct".to_string(),
322                });
323            }
324        }
325
326        // Step 2-4: Graph expansion from each direct result
327        // A7: BFS traverses through ALL node kinds (including code nodes like
328        // File, Function, etc.) as intermediaries, but only COLLECTS Memory nodes.
329        // A6: Apply temporal edge filtering — skip edges whose valid_to < now.
330        let direct_ids: Vec<String> = all_memories.iter().map(|m| m.memory.id.clone()).collect();
331
332        for direct_id in &direct_ids {
333            // Cache edges for this direct node outside the inner loop,
334            // filtering out expired temporal edges (A6)
335            let direct_edges: Vec<_> = graph
336                .get_edges(direct_id)
337                .unwrap_or_default()
338                .into_iter()
339                .filter(|e| is_edge_active(e, now))
340                .collect();
341
342            // A7: Only exclude Chunk from BFS traversal (noisy), but allow
343            // File, Function, Class, etc. as intermediaries to reach more Memory nodes
344            if let Ok(expanded_nodes) =
345                graph.bfs_filtered(direct_id, expansion_depth, &[NodeKind::Chunk], None)
346            {
347                for expanded_node in &expanded_nodes {
348                    // Skip the start node itself (already in results)
349                    if expanded_node.id == *direct_id {
350                        continue;
351                    }
352
353                    // A7: Only COLLECT Memory nodes in results, but we
354                    // traversed through all other node kinds to reach them
355                    if expanded_node.kind != NodeKind::Memory {
356                        continue;
357                    }
358
359                    // Get the memory_id from the graph node
360                    let memory_id = expanded_node
361                        .memory_id
362                        .as_deref()
363                        .unwrap_or(&expanded_node.id);
364
365                    // Skip if already seen
366                    if seen_ids.contains(memory_id) {
367                        continue;
368                    }
369
370                    // Fetch the memory (no-touch to avoid inflating access_count)
371                    if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
372                        if memory.expires_at.is_some_and(|dt| dt <= now) {
373                            continue;
374                        }
375                        if let Some(ns) = namespace_filter {
376                            if memory.namespace.as_deref() != Some(ns) {
377                                continue;
378                            }
379                        }
380
381                        // Build expansion path description using cached edges
382                        let expansion_path = direct_edges
383                            .iter()
384                            .find(|e| e.dst == expanded_node.id || e.src == expanded_node.id)
385                            .map(|e| format!("via {} from {}", e.relationship, direct_id))
386                            .unwrap_or_else(|| format!("via graph from {direct_id}"));
387
388                        seen_ids.insert(memory_id.to_string());
389                        all_memories.push(ScoredMemory {
390                            memory,
391                            vector_sim: 0.0,
392                            expansion_path,
393                        });
394                    }
395                }
396            }
397        }
398
399        // Step 5-6: Score all memories and sort
400        let weights = self.scoring_weights()?;
401        let mut scored_results: Vec<ExpandedResult> = all_memories
402            .into_iter()
403            .map(|sm| {
404                let breakdown = compute_score(
405                    &sm.memory,
406                    &query_token_refs,
407                    sm.vector_sim,
408                    &**graph,
409                    &bm25,
410                    now,
411                );
412                let score = breakdown.total_with_weights(&weights);
413                ExpandedResult {
414                    result: SearchResult {
415                        memory: sm.memory,
416                        score,
417                        score_breakdown: breakdown,
418                    },
419                    expansion_path: sm.expansion_path,
420                }
421            })
422            .collect();
423
424        scored_results.sort_by(|a, b| {
425            b.result
426                .score
427                .partial_cmp(&a.result.score)
428                .unwrap_or(std::cmp::Ordering::Equal)
429        });
430        scored_results.truncate(k);
431
432        Ok(scored_results)
433    }
434
435    /// Resolve entity references from a query to memory IDs connected to those entities.
436    ///
437    /// Extracts code references (CamelCase identifiers, qualified paths, file paths,
438    /// backtick-wrapped code) from the query, matches them to graph nodes, and returns
439    /// the IDs of Memory nodes within one hop of each matched entity. This ensures
440    /// structurally related memories are recall candidates even when semantically distant.
441    pub(crate) fn resolve_entity_memories(
442        &self,
443        query: &str,
444        graph: &dyn codemem_core::GraphBackend,
445        now: chrono::DateTime<chrono::Utc>,
446    ) -> HashSet<String> {
447        let entity_refs = crate::search::extract_code_references(query);
448        let mut memory_ids: HashSet<String> = HashSet::new();
449
450        for entity_ref in &entity_refs {
451            // Try common ID patterns: sym:Name, file:path, or direct ID match
452            let candidate_ids = [
453                format!("sym:{entity_ref}"),
454                format!("file:{entity_ref}"),
455                entity_ref.clone(),
456            ];
457
458            for candidate_id in &candidate_ids {
459                if graph.get_node_ref(candidate_id).is_none() {
460                    continue;
461                }
462                // Found a matching node — collect one-hop Memory neighbors
463                for edge in graph.get_edges_ref(candidate_id) {
464                    if !is_edge_active(edge, now) {
465                        continue;
466                    }
467                    let neighbor_id = if edge.src == *candidate_id {
468                        &edge.dst
469                    } else {
470                        &edge.src
471                    };
472                    if let Some(node) = graph.get_node_ref(neighbor_id) {
473                        if node.kind == NodeKind::Memory {
474                            let mem_id = node.memory_id.as_deref().unwrap_or(&node.id);
475                            memory_ids.insert(mem_id.to_string());
476                        }
477                    }
478                }
479                break; // Found the node, no need to try other ID patterns
480            }
481        }
482
483        memory_ids
484    }
485
486    /// Compute detailed stats for a single namespace: count, averages,
487    /// type distribution, tag frequency, and date range.
488    pub fn namespace_stats(&self, namespace: &str) -> Result<NamespaceStats, CodememError> {
489        let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
490
491        if ids.is_empty() {
492            return Ok(NamespaceStats {
493                namespace: namespace.to_string(),
494                count: 0,
495                avg_importance: 0.0,
496                avg_confidence: 0.0,
497                type_distribution: HashMap::new(),
498                tag_frequency: HashMap::new(),
499                oldest: None,
500                newest: None,
501            });
502        }
503
504        let mut total_importance = 0.0;
505        let mut total_confidence = 0.0;
506        let mut type_distribution: HashMap<String, usize> = HashMap::new();
507        let mut tag_frequency: HashMap<String, usize> = HashMap::new();
508        let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
509        let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
510        let mut count = 0usize;
511
512        // M2: Batch-fetch all memories in one query instead of per-ID get_memory_no_touch.
513        // get_memories_batch does not increment access_count (pure SELECT), so it is
514        // equivalent to get_memory_no_touch for stats purposes.
515        let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
516        let memories = self.storage.get_memories_batch(&id_refs)?;
517
518        for memory in &memories {
519            count += 1;
520            total_importance += memory.importance;
521            total_confidence += memory.confidence;
522
523            *type_distribution
524                .entry(memory.memory_type.to_string())
525                .or_insert(0) += 1;
526
527            for tag in &memory.tags {
528                *tag_frequency.entry(tag.clone()).or_insert(0) += 1;
529            }
530
531            match oldest {
532                None => oldest = Some(memory.created_at),
533                Some(ref o) if memory.created_at < *o => oldest = Some(memory.created_at),
534                _ => {}
535            }
536            match newest {
537                None => newest = Some(memory.created_at),
538                Some(ref n) if memory.created_at > *n => newest = Some(memory.created_at),
539                _ => {}
540            }
541        }
542
543        let avg_importance = if count > 0 {
544            total_importance / count as f64
545        } else {
546            0.0
547        };
548        let avg_confidence = if count > 0 {
549            total_confidence / count as f64
550        } else {
551            0.0
552        };
553
554        Ok(NamespaceStats {
555            namespace: namespace.to_string(),
556            count,
557            avg_importance,
558            avg_confidence,
559            type_distribution,
560            tag_frequency,
561            oldest,
562            newest,
563        })
564    }
565
566    /// Delete all memories in a namespace from all subsystems (storage, vector,
567    /// graph, BM25). Returns the number of memories deleted.
568    pub fn delete_namespace(&self, namespace: &str) -> Result<usize, CodememError> {
569        let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
570
571        let mut deleted = 0usize;
572        let mut graph = self.lock_graph()?;
573        let mut vector = self.lock_vector()?;
574        let mut bm25 = self.lock_bm25()?;
575
576        for id in &ids {
577            // Use cascade delete: atomic transaction deleting memory + graph + embedding from SQLite
578            if let Ok(true) = self.storage.delete_memory_cascade(id) {
579                deleted += 1;
580
581                // Remove from in-memory indexes
582                let _ = vector.remove(id);
583                let _ = graph.remove_node(id);
584                bm25.remove_document(id);
585            }
586        }
587
588        // Drop locks before calling save_index (which acquires vector lock)
589        drop(graph);
590        drop(vector);
591        drop(bm25);
592
593        // Persist vector index to disk
594        self.save_index();
595
596        Ok(deleted)
597    }
598}
599
600/// Check if an edge is currently active based on its temporal bounds.
601/// An edge is active if:
602/// - `valid_from` is None or <= `now`
603/// - `valid_to` is None or > `now`
604pub(crate) fn is_edge_active(
605    edge: &codemem_core::Edge,
606    now: chrono::DateTime<chrono::Utc>,
607) -> bool {
608    if let Some(valid_to) = edge.valid_to {
609        if valid_to < now {
610            return false;
611        }
612    }
613    if let Some(valid_from) = edge.valid_from {
614        if valid_from > now {
615            return false;
616        }
617    }
618    true
619}