//! Memory recall with hybrid scoring.

use crate::scoring::compute_score;
use crate::CodememEngine;
use chrono::Utc;
use codemem_core::{CodememError, MemoryNode, MemoryType, NodeKind, SearchResult};
use std::collections::{HashMap, HashSet};
/// A recall result that includes the expansion path taken to reach the memory.
#[derive(Debug, Clone)]
pub struct ExpandedResult {
    /// The recalled memory together with its hybrid score and score breakdown.
    pub result: SearchResult,
    /// How this memory was reached: `"direct"` for seed hits, or a
    /// `"via <relationship> from <node-id>"` description for graph-expanded hits.
    pub expansion_path: String,
}

/// Aggregated stats for a single namespace.
#[derive(Debug, Clone)]
pub struct NamespaceStats {
    /// The namespace these statistics describe.
    pub namespace: String,
    /// Number of memories found in the namespace.
    pub count: usize,
    /// Mean importance across all memories (0.0 when the namespace is empty).
    pub avg_importance: f64,
    /// Mean confidence across all memories (0.0 when the namespace is empty).
    pub avg_confidence: f64,
    /// Count of memories per memory type (keyed by the type's string form).
    pub type_distribution: HashMap<String, usize>,
    /// Count of occurrences per tag across all memories.
    pub tag_frequency: HashMap<String, usize>,
    /// Creation time of the earliest memory, if any.
    pub oldest: Option<chrono::DateTime<chrono::Utc>>,
    /// Creation time of the latest memory, if any.
    pub newest: Option<chrono::DateTime<chrono::Utc>>,
}

/// Parameters for the recall query.
#[derive(Debug, Clone)]
pub struct RecallQuery<'a> {
    /// Free-text search string.
    pub query: &'a str,
    /// Maximum number of results to return.
    pub k: usize,
    /// If set, only memories of this type are returned.
    pub memory_type_filter: Option<MemoryType>,
    /// If set, only memories in this namespace are returned.
    pub namespace_filter: Option<&'a str>,
    /// Memories carrying any of these tags are excluded.
    pub exclude_tags: &'a [String],
    /// If set, memories with lower importance are excluded.
    pub min_importance: Option<f64>,
    /// If set, memories with lower confidence are excluded.
    pub min_confidence: Option<f64>,
    /// Filter results to memories with this git ref (branch/tag).
    pub git_ref_filter: Option<&'a str>,
}

43impl<'a> RecallQuery<'a> {
44    /// Create a minimal recall query with just the search text and result limit.
45    pub fn new(query: &'a str, k: usize) -> Self {
46        Self {
47            query,
48            k,
49            memory_type_filter: None,
50            namespace_filter: None,
51            exclude_tags: &[],
52            min_importance: None,
53            min_confidence: None,
54            git_ref_filter: None,
55        }
56    }
57}
58
impl CodememEngine {
    /// Core recall logic: search storage with hybrid scoring and return ranked results.
    ///
    /// Combines vector search (if embeddings available), BM25, graph strength,
    /// temporal, tag matching, importance, confidence, and recency into a
    /// 9-component hybrid score. Supports filtering by memory type, namespace,
    /// tag exclusion, and minimum importance/confidence thresholds.
    ///
    /// Returns at most `q.k` results, sorted by descending hybrid score.
    ///
    /// # Errors
    ///
    /// Propagates lock-acquisition and storage errors. A failed query embedding
    /// is logged and degrades gracefully to the BM25 full-scan path.
    pub fn recall(&self, q: &RecallQuery<'_>) -> Result<Vec<SearchResult>, CodememError> {
        // Opportunistic cleanup of expired memories (rate-limited to once per 60s)
        self.sweep_expired_memories();

        // Try vector search first (if embeddings available)
        let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
            match emb_guard.embed(q.query) {
                Ok(query_embedding) => {
                    // Release the embeddings lock before taking the vector lock.
                    drop(emb_guard);
                    let vec = self.lock_vector()?;
                    vec.search(&query_embedding, q.k * 2) // over-fetch for re-ranking
                        .unwrap_or_default()
                }
                Err(e) => {
                    tracing::warn!("Query embedding failed: {e}");
                    vec![]
                }
            }
        } else {
            vec![]
        };

        // H1: Use code-aware tokenizer for query tokens so that compound identifiers
        // like "parseFunction" are split into ["parse", "function"] — matching the
        // tokenization used when documents were added to the BM25 index.
        let query_tokens: Vec<String> = crate::bm25::tokenize(q.query);
        let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();

        // Graph and BM25 intentionally load different data: graph stores structural relationships
        // (nodes/edges), while BM25 indexes memory content for text search. This is by design,
        // not duplication.
        let mut graph = self.lock_graph()?;
        // C1: Lazily compute betweenness centrality before scoring so the
        // betweenness component (30% of graph_strength) is not permanently zero.
        graph.ensure_betweenness_computed();
        let bm25 = self.lock_bm25()?;
        let now = Utc::now();

        // Entity expansion: find memories connected to code entities mentioned in the query.
        // This ensures that structurally related memories are candidates even when they are
        // semantically distant from the query text.
        let entity_memory_ids = self.resolve_entity_memories(q.query, &**graph, now);

        let mut results: Vec<SearchResult> = Vec::new();
        // Component weights used to collapse the score breakdown into one total.
        let weights = self.scoring_weights()?;

        if !vector_results.is_empty() {
            // Vector search path: batch-fetch all candidate memories + entity-connected memories
            let mut all_candidate_ids: HashSet<&str> =
                vector_results.iter().map(|(id, _)| id.as_str()).collect();

            // Merge entity-connected memory IDs into the candidate pool
            for eid in &entity_memory_ids {
                all_candidate_ids.insert(eid.as_str());
            }

            let candidate_id_vec: Vec<&str> = all_candidate_ids.into_iter().collect();
            let candidate_memories = self.storage.get_memories_batch(&candidate_id_vec)?;

            // Build similarity lookup (entity memories will get 0.0 similarity)
            let sim_map: HashMap<&str, f64> = vector_results
                .iter()
                .map(|(id, sim)| (id.as_str(), *sim as f64))
                .collect();

            for memory in candidate_memories {
                // Apply memory_type filter
                if let Some(ref filter_type) = q.memory_type_filter {
                    if memory.memory_type != *filter_type {
                        continue;
                    }
                }
                // Apply namespace filter
                if let Some(ns) = q.namespace_filter {
                    if memory.namespace.as_deref() != Some(ns) {
                        continue;
                    }
                }
                if !Self::passes_quality_filters(&memory, q) {
                    continue;
                }

                let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
                let breakdown =
                    compute_score(&memory, &query_token_refs, similarity, &**graph, &bm25, now);
                let score = breakdown.total_with_weights(&weights);
                // Drop near-zero matches so noise never reaches the ranked list.
                if score > 0.01 {
                    results.push(SearchResult {
                        memory,
                        score,
                        score_breakdown: breakdown,
                    });
                }
            }
        }

        // If vector search produced no results (either unavailable or all
        // candidates filtered out by namespace/type), fall back to BM25
        // full-scan so we still return matches.
        if results.is_empty() {
            let type_str = q.memory_type_filter.as_ref().map(|t| t.to_string());
            let all_memories = self
                .storage
                .list_memories_filtered(q.namespace_filter, type_str.as_deref())?;

            for memory in all_memories {
                if !Self::passes_quality_filters(&memory, q) {
                    continue;
                }

                // No vector similarity on this path — score with 0.0 for that component.
                let breakdown =
                    compute_score(&memory, &query_token_refs, 0.0, &**graph, &bm25, now);
                let score = breakdown.total_with_weights(&weights);
                if score > 0.01 {
                    results.push(SearchResult {
                        memory,
                        score,
                        score_breakdown: breakdown,
                    });
                }
            }
        }

        // Sort by score descending, take top k
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(q.k);

        Ok(results)
    }

200    /// Check expiry, exclude_tags, min_importance, min_confidence, and git_ref filters.
201    fn passes_quality_filters(memory: &MemoryNode, q: &RecallQuery<'_>) -> bool {
202        // Skip expired memories (their embeddings may linger in HNSW until next sweep)
203        if memory.expires_at.is_some_and(|dt| dt <= Utc::now()) {
204            return false;
205        }
206        if !q.exclude_tags.is_empty() && memory.tags.iter().any(|t| q.exclude_tags.contains(t)) {
207            return false;
208        }
209        if let Some(min) = q.min_importance {
210            if memory.importance < min {
211                return false;
212            }
213        }
214        if let Some(min) = q.min_confidence {
215            if memory.confidence < min {
216                return false;
217            }
218        }
219        if let Some(ref_filter) = q.git_ref_filter {
220            if memory.git_ref.as_deref() != Some(ref_filter) {
221                return false;
222            }
223        }
224        true
225    }
226
    /// Recall with graph expansion: vector search (or BM25 fallback) for seed
    /// memories, then BFS expansion from each seed through the graph, scoring
    /// all candidates with the 9-component hybrid scorer.
    ///
    /// * `query` — free-text search string.
    /// * `k` — maximum number of results to return.
    /// * `expansion_depth` — BFS depth limit for graph expansion from each seed.
    /// * `namespace_filter` — if set, only memories in this namespace are returned.
    ///
    /// # Errors
    ///
    /// Propagates lock-acquisition and storage errors. A failed query embedding
    /// is logged and degrades gracefully to the full-scan fallback path.
    pub fn recall_with_expansion(
        &self,
        query: &str,
        k: usize,
        expansion_depth: usize,
        namespace_filter: Option<&str>,
    ) -> Result<Vec<ExpandedResult>, CodememError> {
        // Opportunistic cleanup of expired memories (rate-limited to once per 60s)
        self.sweep_expired_memories();

        // H1: Code-aware tokenization for consistent BM25 scoring
        let query_tokens: Vec<String> = crate::bm25::tokenize(query);
        let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();

        // Step 1: Run normal vector search (or text fallback)
        let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
            match emb_guard.embed(query) {
                Ok(query_embedding) => {
                    // Release the embeddings lock before taking the vector lock.
                    drop(emb_guard);
                    let vec = self.lock_vector()?;
                    vec.search(&query_embedding, k * 2).unwrap_or_default()
                }
                Err(e) => {
                    tracing::warn!("Query embedding failed: {e}");
                    vec![]
                }
            }
        } else {
            vec![]
        };

        let mut graph = self.lock_graph()?;
        // C1: Lazily compute betweenness centrality before scoring
        graph.ensure_betweenness_computed();
        let bm25 = self.lock_bm25()?;
        let now = Utc::now();

        // Collect initial seed memories with their vector similarity
        struct ScoredMemory {
            memory: MemoryNode,
            vector_sim: f64,
            expansion_path: String,
        }

        let mut all_memories: Vec<ScoredMemory> = Vec::new();
        // Memory IDs already collected, so graph expansion never duplicates a seed.
        let mut seen_ids: HashSet<String> = HashSet::new();

        if vector_results.is_empty() {
            // Fallback: batch-load all memories matching namespace in one query
            let all = self
                .storage
                .list_memories_filtered(namespace_filter, None)?;
            let weights = self.scoring_weights()?;

            for memory in all {
                if memory.expires_at.is_some_and(|dt| dt <= now) {
                    continue;
                }
                let breakdown =
                    compute_score(&memory, &query_token_refs, 0.0, &**graph, &bm25, now);
                let score = breakdown.total_with_weights(&weights);
                // Pre-filter near-zero matches so they are not used as BFS seeds.
                if score > 0.01 {
                    seen_ids.insert(memory.id.clone());
                    all_memories.push(ScoredMemory {
                        memory,
                        vector_sim: 0.0,
                        expansion_path: "direct".to_string(),
                    });
                }
            }
        } else {
            // Vector search path: batch-fetch all candidate memories
            let candidate_ids: Vec<&str> =
                vector_results.iter().map(|(id, _)| id.as_str()).collect();
            let candidate_memories = self.storage.get_memories_batch(&candidate_ids)?;

            let sim_map: HashMap<&str, f64> = vector_results
                .iter()
                .map(|(id, sim)| (id.as_str(), *sim as f64))
                .collect();

            for memory in candidate_memories {
                if memory.expires_at.is_some_and(|dt| dt <= now) {
                    continue;
                }
                if let Some(ns) = namespace_filter {
                    if memory.namespace.as_deref() != Some(ns) {
                        continue;
                    }
                }
                let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
                seen_ids.insert(memory.id.clone());
                all_memories.push(ScoredMemory {
                    memory,
                    vector_sim: similarity,
                    expansion_path: "direct".to_string(),
                });
            }
        }

        // Step 2-4: Graph expansion from each direct result
        // A7: BFS traverses through ALL node kinds (including code nodes like
        // File, Function, etc.) as intermediaries, but only COLLECTS Memory nodes.
        // A6: Apply temporal edge filtering — skip edges whose valid_to < now.
        let direct_ids: Vec<String> = all_memories.iter().map(|m| m.memory.id.clone()).collect();

        for direct_id in &direct_ids {
            // Cache edges for this direct node outside the inner loop,
            // filtering out expired temporal edges (A6)
            let direct_edges: Vec<_> = graph
                .get_edges(direct_id)
                .unwrap_or_default()
                .into_iter()
                .filter(|e| is_edge_active(e, now))
                .collect();

            // A7: Only exclude Chunk from BFS traversal (noisy), but allow
            // File, Function, Class, etc. as intermediaries to reach more Memory nodes
            if let Ok(expanded_nodes) =
                graph.bfs_filtered(direct_id, expansion_depth, &[NodeKind::Chunk], None)
            {
                for expanded_node in &expanded_nodes {
                    // Skip the start node itself (already in results)
                    if expanded_node.id == *direct_id {
                        continue;
                    }

                    // A7: Only COLLECT Memory nodes in results, but we
                    // traversed through all other node kinds to reach them
                    if expanded_node.kind != NodeKind::Memory {
                        continue;
                    }

                    // Get the memory_id from the graph node
                    let memory_id = expanded_node
                        .memory_id
                        .as_deref()
                        .unwrap_or(&expanded_node.id);

                    // Skip if already seen
                    if seen_ids.contains(memory_id) {
                        continue;
                    }

                    // Fetch the memory (no-touch to avoid inflating access_count)
                    if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
                        if memory.expires_at.is_some_and(|dt| dt <= now) {
                            continue;
                        }
                        if let Some(ns) = namespace_filter {
                            if memory.namespace.as_deref() != Some(ns) {
                                continue;
                            }
                        }

                        // Build expansion path description using cached edges
                        let expansion_path = direct_edges
                            .iter()
                            .find(|e| e.dst == expanded_node.id || e.src == expanded_node.id)
                            .map(|e| format!("via {} from {}", e.relationship, direct_id))
                            .unwrap_or_else(|| format!("via graph from {direct_id}"));

                        seen_ids.insert(memory_id.to_string());
                        all_memories.push(ScoredMemory {
                            memory,
                            vector_sim: 0.0,
                            expansion_path,
                        });
                    }
                }
            }
        }

        // Step 5-6: Score all memories and sort
        let weights = self.scoring_weights()?;
        let mut scored_results: Vec<ExpandedResult> = all_memories
            .into_iter()
            .map(|sm| {
                let breakdown = compute_score(
                    &sm.memory,
                    &query_token_refs,
                    sm.vector_sim,
                    &**graph,
                    &bm25,
                    now,
                );
                let score = breakdown.total_with_weights(&weights);
                ExpandedResult {
                    result: SearchResult {
                        memory: sm.memory,
                        score,
                        score_breakdown: breakdown,
                    },
                    expansion_path: sm.expansion_path,
                }
            })
            .collect();

        scored_results.sort_by(|a, b| {
            b.result
                .score
                .partial_cmp(&a.result.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        scored_results.truncate(k);

        Ok(scored_results)
    }

439    /// Resolve entity references from a query to memory IDs connected to those entities.
440    ///
441    /// Extracts code references (CamelCase identifiers, qualified paths, file paths,
442    /// backtick-wrapped code) from the query, matches them to graph nodes, and returns
443    /// the IDs of Memory nodes within one hop of each matched entity. This ensures
444    /// structurally related memories are recall candidates even when semantically distant.
445    pub(crate) fn resolve_entity_memories(
446        &self,
447        query: &str,
448        graph: &dyn codemem_core::GraphBackend,
449        now: chrono::DateTime<chrono::Utc>,
450    ) -> HashSet<String> {
451        let entity_refs = crate::search::extract_code_references(query);
452        let mut memory_ids: HashSet<String> = HashSet::new();
453
454        for entity_ref in &entity_refs {
455            // Try common ID patterns: sym:Name, file:path, or direct ID match
456            let candidate_ids = [
457                format!("sym:{entity_ref}"),
458                format!("file:{entity_ref}"),
459                entity_ref.clone(),
460            ];
461
462            for candidate_id in &candidate_ids {
463                if graph.get_node_ref(candidate_id).is_none() {
464                    continue;
465                }
466                // Found a matching node — collect one-hop Memory neighbors
467                for edge in graph.get_edges_ref(candidate_id) {
468                    if !is_edge_active(edge, now) {
469                        continue;
470                    }
471                    let neighbor_id = if edge.src == *candidate_id {
472                        &edge.dst
473                    } else {
474                        &edge.src
475                    };
476                    if let Some(node) = graph.get_node_ref(neighbor_id) {
477                        if node.kind == NodeKind::Memory {
478                            let mem_id = node.memory_id.as_deref().unwrap_or(&node.id);
479                            memory_ids.insert(mem_id.to_string());
480                        }
481                    }
482                }
483                break; // Found the node, no need to try other ID patterns
484            }
485        }
486
487        memory_ids
488    }
489
490    /// Compute detailed stats for a single namespace: count, averages,
491    /// type distribution, tag frequency, and date range.
492    pub fn namespace_stats(&self, namespace: &str) -> Result<NamespaceStats, CodememError> {
493        let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
494
495        if ids.is_empty() {
496            return Ok(NamespaceStats {
497                namespace: namespace.to_string(),
498                count: 0,
499                avg_importance: 0.0,
500                avg_confidence: 0.0,
501                type_distribution: HashMap::new(),
502                tag_frequency: HashMap::new(),
503                oldest: None,
504                newest: None,
505            });
506        }
507
508        let mut total_importance = 0.0;
509        let mut total_confidence = 0.0;
510        let mut type_distribution: HashMap<String, usize> = HashMap::new();
511        let mut tag_frequency: HashMap<String, usize> = HashMap::new();
512        let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
513        let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
514        let mut count = 0usize;
515
516        // M2: Batch-fetch all memories in one query instead of per-ID get_memory_no_touch.
517        // get_memories_batch does not increment access_count (pure SELECT), so it is
518        // equivalent to get_memory_no_touch for stats purposes.
519        let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
520        let memories = self.storage.get_memories_batch(&id_refs)?;
521
522        for memory in &memories {
523            count += 1;
524            total_importance += memory.importance;
525            total_confidence += memory.confidence;
526
527            *type_distribution
528                .entry(memory.memory_type.to_string())
529                .or_insert(0) += 1;
530
531            for tag in &memory.tags {
532                *tag_frequency.entry(tag.clone()).or_insert(0) += 1;
533            }
534
535            match oldest {
536                None => oldest = Some(memory.created_at),
537                Some(ref o) if memory.created_at < *o => oldest = Some(memory.created_at),
538                _ => {}
539            }
540            match newest {
541                None => newest = Some(memory.created_at),
542                Some(ref n) if memory.created_at > *n => newest = Some(memory.created_at),
543                _ => {}
544            }
545        }
546
547        let avg_importance = if count > 0 {
548            total_importance / count as f64
549        } else {
550            0.0
551        };
552        let avg_confidence = if count > 0 {
553            total_confidence / count as f64
554        } else {
555            0.0
556        };
557
558        Ok(NamespaceStats {
559            namespace: namespace.to_string(),
560            count,
561            avg_importance,
562            avg_confidence,
563            type_distribution,
564            tag_frequency,
565            oldest,
566            newest,
567        })
568    }
569
570    /// Delete all memories in a namespace from all subsystems (storage, vector,
571    /// graph, BM25). Returns the number of memories deleted.
572    pub fn delete_namespace(&self, namespace: &str) -> Result<usize, CodememError> {
573        let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
574
575        let mut deleted = 0usize;
576        let mut graph = self.lock_graph()?;
577        let mut vector = self.lock_vector()?;
578        let mut bm25 = self.lock_bm25()?;
579
580        for id in &ids {
581            // Use cascade delete: atomic transaction deleting memory + graph + embedding from SQLite
582            if let Ok(true) = self.storage.delete_memory_cascade(id) {
583                deleted += 1;
584
585                // Remove from in-memory indexes
586                let _ = vector.remove(id);
587                let _ = graph.remove_node(id);
588                bm25.remove_document(id);
589            }
590        }
591
592        // Drop locks before calling save_index (which acquires vector lock)
593        drop(graph);
594        drop(vector);
595        drop(bm25);
596
597        // Persist vector index to disk
598        self.save_index();
599
600        Ok(deleted)
601    }
}

604/// Check if an edge is currently active based on its temporal bounds.
605/// An edge is active if:
606/// - `valid_from` is None or <= `now`
607/// - `valid_to` is None or > `now`
608pub(crate) fn is_edge_active(
609    edge: &codemem_core::Edge,
610    now: chrono::DateTime<chrono::Utc>,
611) -> bool {
612    if let Some(valid_to) = edge.valid_to {
613        if valid_to < now {
614            return false;
615        }
616    }
617    if let Some(valid_from) = edge.valid_from {
618        if valid_from > now {
619            return false;
620        }
621    }
622    true
623}