Skip to main content

codemem_engine/
recall.rs

1//! Memory recall with hybrid scoring.
2
3use crate::scoring::compute_score;
4use crate::CodememEngine;
5use chrono::Utc;
6use codemem_core::{
7    CodememError, GraphBackend, MemoryNode, MemoryType, NodeKind, SearchResult, VectorBackend,
8};
9use std::collections::{HashMap, HashSet};
10
11/// A recall result that includes the expansion path taken to reach the memory.
12#[derive(Debug, Clone)]
13pub struct ExpandedResult {
14    pub result: SearchResult,
15    pub expansion_path: String,
16}
17
18/// Aggregated stats for a single namespace.
19#[derive(Debug, Clone)]
20pub struct NamespaceStats {
21    pub namespace: String,
22    pub count: usize,
23    pub avg_importance: f64,
24    pub avg_confidence: f64,
25    pub type_distribution: HashMap<String, usize>,
26    pub tag_frequency: HashMap<String, usize>,
27    pub oldest: Option<chrono::DateTime<chrono::Utc>>,
28    pub newest: Option<chrono::DateTime<chrono::Utc>>,
29}
30
31/// Parameters for the recall query.
32#[derive(Debug, Clone)]
33pub struct RecallQuery<'a> {
34    pub query: &'a str,
35    pub k: usize,
36    pub memory_type_filter: Option<MemoryType>,
37    pub namespace_filter: Option<&'a str>,
38    pub exclude_tags: &'a [String],
39    pub min_importance: Option<f64>,
40    pub min_confidence: Option<f64>,
41}
42
43impl<'a> RecallQuery<'a> {
44    /// Create a minimal recall query with just the search text and result limit.
45    pub fn new(query: &'a str, k: usize) -> Self {
46        Self {
47            query,
48            k,
49            memory_type_filter: None,
50            namespace_filter: None,
51            exclude_tags: &[],
52            min_importance: None,
53            min_confidence: None,
54        }
55    }
56}
57
58impl CodememEngine {
59    /// Core recall logic: search storage with hybrid scoring and return ranked results.
60    ///
61    /// Combines vector search (if embeddings available), BM25, graph strength,
62    /// temporal, tag matching, importance, confidence, and recency into a
63    /// 9-component hybrid score. Supports filtering by memory type, namespace,
64    /// tag exclusion, and minimum importance/confidence thresholds.
65    pub fn recall(&self, q: &RecallQuery<'_>) -> Result<Vec<SearchResult>, CodememError> {
66        // Try vector search first (if embeddings available)
67        let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
68            match emb_guard.embed(q.query) {
69                Ok(query_embedding) => {
70                    drop(emb_guard);
71                    let vec = self.lock_vector()?;
72                    vec.search(&query_embedding, q.k * 2) // over-fetch for re-ranking
73                        .unwrap_or_default()
74                }
75                Err(e) => {
76                    tracing::warn!("Query embedding failed: {e}");
77                    vec![]
78                }
79            }
80        } else {
81            vec![]
82        };
83
84        // H1: Use code-aware tokenizer for query tokens so that compound identifiers
85        // like "parseFunction" are split into ["parse", "function"] — matching the
86        // tokenization used when documents were added to the BM25 index.
87        let query_tokens: Vec<String> = crate::bm25::tokenize(q.query);
88        let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
89
90        // Graph and BM25 intentionally load different data: graph stores structural relationships
91        // (nodes/edges), while BM25 indexes memory content for text search. This is by design,
92        // not duplication.
93        let mut graph = self.lock_graph()?;
94        // C1: Lazily compute betweenness centrality before scoring so the
95        // betweenness component (30% of graph_strength) is not permanently zero.
96        graph.ensure_betweenness_computed();
97        let bm25 = self.lock_bm25()?;
98        let now = Utc::now();
99
100        // Entity expansion: find memories connected to code entities mentioned in the query.
101        // This ensures that structurally related memories are candidates even when they are
102        // semantically distant from the query text.
103        let entity_memory_ids = self.resolve_entity_memories(q.query, &graph, now);
104
105        let mut results: Vec<SearchResult> = Vec::new();
106        let weights = self.scoring_weights()?;
107
108        if vector_results.is_empty() {
109            // Fallback: batch-load all memories matching filters in one query
110            let type_str = q.memory_type_filter.as_ref().map(|t| t.to_string());
111            let all_memories = self
112                .storage
113                .list_memories_filtered(q.namespace_filter, type_str.as_deref())?;
114
115            for memory in all_memories {
116                if !Self::passes_quality_filters(&memory, q) {
117                    continue;
118                }
119
120                let breakdown = compute_score(&memory, &query_token_refs, 0.0, &graph, &bm25, now);
121                let score = breakdown.total_with_weights(&weights);
122                if score > 0.01 {
123                    results.push(SearchResult {
124                        memory,
125                        score,
126                        score_breakdown: breakdown,
127                    });
128                }
129            }
130        } else {
131            // Vector search path: batch-fetch all candidate memories + entity-connected memories
132            let mut all_candidate_ids: HashSet<&str> =
133                vector_results.iter().map(|(id, _)| id.as_str()).collect();
134
135            // Merge entity-connected memory IDs into the candidate pool
136            for eid in &entity_memory_ids {
137                all_candidate_ids.insert(eid.as_str());
138            }
139
140            let candidate_id_vec: Vec<&str> = all_candidate_ids.into_iter().collect();
141            let candidate_memories = self.storage.get_memories_batch(&candidate_id_vec)?;
142
143            // Build similarity lookup (entity memories will get 0.0 similarity)
144            let sim_map: HashMap<&str, f64> = vector_results
145                .iter()
146                .map(|(id, sim)| (id.as_str(), *sim as f64))
147                .collect();
148
149            for memory in candidate_memories {
150                // Apply memory_type filter
151                if let Some(ref filter_type) = q.memory_type_filter {
152                    if memory.memory_type != *filter_type {
153                        continue;
154                    }
155                }
156                // Apply namespace filter
157                if let Some(ns) = q.namespace_filter {
158                    if memory.namespace.as_deref() != Some(ns) {
159                        continue;
160                    }
161                }
162                if !Self::passes_quality_filters(&memory, q) {
163                    continue;
164                }
165
166                let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
167                let breakdown =
168                    compute_score(&memory, &query_token_refs, similarity, &graph, &bm25, now);
169                let score = breakdown.total_with_weights(&weights);
170                if score > 0.01 {
171                    results.push(SearchResult {
172                        memory,
173                        score,
174                        score_breakdown: breakdown,
175                    });
176                }
177            }
178        }
179
180        // Sort by score descending, take top k
181        results.sort_by(|a, b| {
182            b.score
183                .partial_cmp(&a.score)
184                .unwrap_or(std::cmp::Ordering::Equal)
185        });
186        results.truncate(q.k);
187
188        Ok(results)
189    }
190
191    /// Check exclude_tags, min_importance, and min_confidence filters.
192    fn passes_quality_filters(memory: &MemoryNode, q: &RecallQuery<'_>) -> bool {
193        if !q.exclude_tags.is_empty() && memory.tags.iter().any(|t| q.exclude_tags.contains(t)) {
194            return false;
195        }
196        if let Some(min) = q.min_importance {
197            if memory.importance < min {
198                return false;
199            }
200        }
201        if let Some(min) = q.min_confidence {
202            if memory.confidence < min {
203                return false;
204            }
205        }
206        true
207    }
208
209    /// Recall with graph expansion: vector search (or BM25 fallback) for seed
210    /// memories, then BFS expansion from each seed through the graph, scoring
211    /// all candidates with the 9-component hybrid scorer.
212    pub fn recall_with_expansion(
213        &self,
214        query: &str,
215        k: usize,
216        expansion_depth: usize,
217        namespace_filter: Option<&str>,
218    ) -> Result<Vec<ExpandedResult>, CodememError> {
219        // H1: Code-aware tokenization for consistent BM25 scoring
220        let query_tokens: Vec<String> = crate::bm25::tokenize(query);
221        let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
222
223        // Step 1: Run normal vector search (or text fallback)
224        let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
225            match emb_guard.embed(query) {
226                Ok(query_embedding) => {
227                    drop(emb_guard);
228                    let vec = self.lock_vector()?;
229                    vec.search(&query_embedding, k * 2).unwrap_or_default()
230                }
231                Err(e) => {
232                    tracing::warn!("Query embedding failed: {e}");
233                    vec![]
234                }
235            }
236        } else {
237            vec![]
238        };
239
240        let mut graph = self.lock_graph()?;
241        // C1: Lazily compute betweenness centrality before scoring
242        graph.ensure_betweenness_computed();
243        let bm25 = self.lock_bm25()?;
244        let now = Utc::now();
245
246        // Collect initial seed memories with their vector similarity
247        struct ScoredMemory {
248            memory: MemoryNode,
249            vector_sim: f64,
250            expansion_path: String,
251        }
252
253        let mut all_memories: Vec<ScoredMemory> = Vec::new();
254        let mut seen_ids: HashSet<String> = HashSet::new();
255
256        if vector_results.is_empty() {
257            // Fallback: batch-load all memories matching namespace in one query
258            let all = self
259                .storage
260                .list_memories_filtered(namespace_filter, None)?;
261            let weights = self.scoring_weights()?;
262
263            for memory in all {
264                let breakdown = compute_score(&memory, &query_token_refs, 0.0, &graph, &bm25, now);
265                let score = breakdown.total_with_weights(&weights);
266                if score > 0.01 {
267                    seen_ids.insert(memory.id.clone());
268                    all_memories.push(ScoredMemory {
269                        memory,
270                        vector_sim: 0.0,
271                        expansion_path: "direct".to_string(),
272                    });
273                }
274            }
275        } else {
276            // Vector search path: batch-fetch all candidate memories
277            let candidate_ids: Vec<&str> =
278                vector_results.iter().map(|(id, _)| id.as_str()).collect();
279            let candidate_memories = self.storage.get_memories_batch(&candidate_ids)?;
280
281            let sim_map: HashMap<&str, f64> = vector_results
282                .iter()
283                .map(|(id, sim)| (id.as_str(), *sim as f64))
284                .collect();
285
286            for memory in candidate_memories {
287                if let Some(ns) = namespace_filter {
288                    if memory.namespace.as_deref() != Some(ns) {
289                        continue;
290                    }
291                }
292                let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
293                seen_ids.insert(memory.id.clone());
294                all_memories.push(ScoredMemory {
295                    memory,
296                    vector_sim: similarity,
297                    expansion_path: "direct".to_string(),
298                });
299            }
300        }
301
302        // Step 2-4: Graph expansion from each direct result
303        // A7: BFS traverses through ALL node kinds (including code nodes like
304        // File, Function, etc.) as intermediaries, but only COLLECTS Memory nodes.
305        // A6: Apply temporal edge filtering — skip edges whose valid_to < now.
306        let direct_ids: Vec<String> = all_memories.iter().map(|m| m.memory.id.clone()).collect();
307
308        for direct_id in &direct_ids {
309            // Cache edges for this direct node outside the inner loop,
310            // filtering out expired temporal edges (A6)
311            let direct_edges: Vec<_> = graph
312                .get_edges(direct_id)
313                .unwrap_or_default()
314                .into_iter()
315                .filter(|e| is_edge_active(e, now))
316                .collect();
317
318            // A7: Only exclude Chunk from BFS traversal (noisy), but allow
319            // File, Function, Class, etc. as intermediaries to reach more Memory nodes
320            if let Ok(expanded_nodes) =
321                graph.bfs_filtered(direct_id, expansion_depth, &[NodeKind::Chunk], None)
322            {
323                for expanded_node in &expanded_nodes {
324                    // Skip the start node itself (already in results)
325                    if expanded_node.id == *direct_id {
326                        continue;
327                    }
328
329                    // A7: Only COLLECT Memory nodes in results, but we
330                    // traversed through all other node kinds to reach them
331                    if expanded_node.kind != NodeKind::Memory {
332                        continue;
333                    }
334
335                    // Get the memory_id from the graph node
336                    let memory_id = expanded_node
337                        .memory_id
338                        .as_deref()
339                        .unwrap_or(&expanded_node.id);
340
341                    // Skip if already seen
342                    if seen_ids.contains(memory_id) {
343                        continue;
344                    }
345
346                    // Fetch the memory (no-touch to avoid inflating access_count)
347                    if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
348                        if let Some(ns) = namespace_filter {
349                            if memory.namespace.as_deref() != Some(ns) {
350                                continue;
351                            }
352                        }
353
354                        // Build expansion path description using cached edges
355                        let expansion_path = direct_edges
356                            .iter()
357                            .find(|e| e.dst == expanded_node.id || e.src == expanded_node.id)
358                            .map(|e| format!("via {} from {}", e.relationship, direct_id))
359                            .unwrap_or_else(|| format!("via graph from {direct_id}"));
360
361                        seen_ids.insert(memory_id.to_string());
362                        all_memories.push(ScoredMemory {
363                            memory,
364                            vector_sim: 0.0,
365                            expansion_path,
366                        });
367                    }
368                }
369            }
370        }
371
372        // Step 5-6: Score all memories and sort
373        let weights = self.scoring_weights()?;
374        let mut scored_results: Vec<ExpandedResult> = all_memories
375            .into_iter()
376            .map(|sm| {
377                let breakdown = compute_score(
378                    &sm.memory,
379                    &query_token_refs,
380                    sm.vector_sim,
381                    &graph,
382                    &bm25,
383                    now,
384                );
385                let score = breakdown.total_with_weights(&weights);
386                ExpandedResult {
387                    result: SearchResult {
388                        memory: sm.memory,
389                        score,
390                        score_breakdown: breakdown,
391                    },
392                    expansion_path: sm.expansion_path,
393                }
394            })
395            .collect();
396
397        scored_results.sort_by(|a, b| {
398            b.result
399                .score
400                .partial_cmp(&a.result.score)
401                .unwrap_or(std::cmp::Ordering::Equal)
402        });
403        scored_results.truncate(k);
404
405        Ok(scored_results)
406    }
407
408    /// Resolve entity references from a query to memory IDs connected to those entities.
409    ///
410    /// Extracts code references (CamelCase identifiers, qualified paths, file paths,
411    /// backtick-wrapped code) from the query, matches them to graph nodes, and returns
412    /// the IDs of Memory nodes within one hop of each matched entity. This ensures
413    /// structurally related memories are recall candidates even when semantically distant.
414    pub(crate) fn resolve_entity_memories(
415        &self,
416        query: &str,
417        graph: &codemem_storage::graph::GraphEngine,
418        now: chrono::DateTime<chrono::Utc>,
419    ) -> HashSet<String> {
420        let entity_refs = crate::search::extract_code_references(query);
421        let mut memory_ids: HashSet<String> = HashSet::new();
422
423        for entity_ref in &entity_refs {
424            // Try common ID patterns: sym:Name, file:path, or direct ID match
425            let candidate_ids = [
426                format!("sym:{entity_ref}"),
427                format!("file:{entity_ref}"),
428                entity_ref.clone(),
429            ];
430
431            for candidate_id in &candidate_ids {
432                if graph.get_node_ref(candidate_id).is_none() {
433                    continue;
434                }
435                // Found a matching node — collect one-hop Memory neighbors
436                for edge in graph.get_edges_ref(candidate_id) {
437                    if !is_edge_active(edge, now) {
438                        continue;
439                    }
440                    let neighbor_id = if edge.src == *candidate_id {
441                        &edge.dst
442                    } else {
443                        &edge.src
444                    };
445                    if let Some(node) = graph.get_node_ref(neighbor_id) {
446                        if node.kind == NodeKind::Memory {
447                            let mem_id = node.memory_id.as_deref().unwrap_or(&node.id);
448                            memory_ids.insert(mem_id.to_string());
449                        }
450                    }
451                }
452                break; // Found the node, no need to try other ID patterns
453            }
454        }
455
456        memory_ids
457    }
458
459    /// Compute detailed stats for a single namespace: count, averages,
460    /// type distribution, tag frequency, and date range.
461    pub fn namespace_stats(&self, namespace: &str) -> Result<NamespaceStats, CodememError> {
462        let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
463
464        if ids.is_empty() {
465            return Ok(NamespaceStats {
466                namespace: namespace.to_string(),
467                count: 0,
468                avg_importance: 0.0,
469                avg_confidence: 0.0,
470                type_distribution: HashMap::new(),
471                tag_frequency: HashMap::new(),
472                oldest: None,
473                newest: None,
474            });
475        }
476
477        let mut total_importance = 0.0;
478        let mut total_confidence = 0.0;
479        let mut type_distribution: HashMap<String, usize> = HashMap::new();
480        let mut tag_frequency: HashMap<String, usize> = HashMap::new();
481        let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
482        let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
483        let mut count = 0usize;
484
485        // M2: Batch-fetch all memories in one query instead of per-ID get_memory_no_touch.
486        // get_memories_batch does not increment access_count (pure SELECT), so it is
487        // equivalent to get_memory_no_touch for stats purposes.
488        let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
489        let memories = self.storage.get_memories_batch(&id_refs)?;
490
491        for memory in &memories {
492            count += 1;
493            total_importance += memory.importance;
494            total_confidence += memory.confidence;
495
496            *type_distribution
497                .entry(memory.memory_type.to_string())
498                .or_insert(0) += 1;
499
500            for tag in &memory.tags {
501                *tag_frequency.entry(tag.clone()).or_insert(0) += 1;
502            }
503
504            match oldest {
505                None => oldest = Some(memory.created_at),
506                Some(ref o) if memory.created_at < *o => oldest = Some(memory.created_at),
507                _ => {}
508            }
509            match newest {
510                None => newest = Some(memory.created_at),
511                Some(ref n) if memory.created_at > *n => newest = Some(memory.created_at),
512                _ => {}
513            }
514        }
515
516        let avg_importance = if count > 0 {
517            total_importance / count as f64
518        } else {
519            0.0
520        };
521        let avg_confidence = if count > 0 {
522            total_confidence / count as f64
523        } else {
524            0.0
525        };
526
527        Ok(NamespaceStats {
528            namespace: namespace.to_string(),
529            count,
530            avg_importance,
531            avg_confidence,
532            type_distribution,
533            tag_frequency,
534            oldest,
535            newest,
536        })
537    }
538
539    /// Delete all memories in a namespace from all subsystems (storage, vector,
540    /// graph, BM25). Returns the number of memories deleted.
541    pub fn delete_namespace(&self, namespace: &str) -> Result<usize, CodememError> {
542        let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
543
544        let mut deleted = 0usize;
545        let mut graph = self.lock_graph()?;
546        let mut vector = self.lock_vector()?;
547        let mut bm25 = self.lock_bm25()?;
548
549        for id in &ids {
550            // Use cascade delete: atomic transaction deleting memory + graph + embedding from SQLite
551            if let Ok(true) = self.storage.delete_memory_cascade(id) {
552                deleted += 1;
553
554                // Remove from in-memory indexes
555                let _ = vector.remove(id);
556                let _ = graph.remove_node(id);
557                bm25.remove_document(id);
558            }
559        }
560
561        // Drop locks before calling save_index (which acquires vector lock)
562        drop(graph);
563        drop(vector);
564        drop(bm25);
565
566        // Persist vector index to disk
567        self.save_index();
568
569        Ok(deleted)
570    }
571}
572
573/// Check if an edge is currently active based on its temporal bounds.
574/// An edge is active if:
575/// - `valid_from` is None or <= `now`
576/// - `valid_to` is None or > `now`
577pub(crate) fn is_edge_active(
578    edge: &codemem_core::Edge,
579    now: chrono::DateTime<chrono::Utc>,
580) -> bool {
581    if let Some(valid_to) = edge.valid_to {
582        if valid_to < now {
583            return false;
584        }
585    }
586    if let Some(valid_from) = edge.valid_from {
587        if valid_from > now {
588            return false;
589        }
590    }
591    true
592}