// codemem_engine/recall.rs
1//! Memory recall with hybrid scoring.
2
3use crate::scoring::compute_score;
4use crate::CodememEngine;
5use chrono::Utc;
6use codemem_core::{
7    CodememError, GraphBackend, MemoryNode, MemoryType, NodeKind, SearchResult, VectorBackend,
8};
9use std::collections::{HashMap, HashSet};
10
/// A recall result that includes the expansion path taken to reach the memory.
#[derive(Debug, Clone)]
pub struct ExpandedResult {
    /// The scored hit: memory, hybrid score, and per-component breakdown.
    pub result: SearchResult,
    /// Human-readable description of how the memory was reached:
    /// `"direct"` for seed hits, or `"via <relationship> from <node_id>"`
    /// (falling back to `"via graph from <node_id>"`) for graph-expanded hits.
    pub expansion_path: String,
}
17
/// Aggregated stats for a single namespace.
#[derive(Debug, Clone)]
pub struct NamespaceStats {
    /// The namespace these stats describe.
    pub namespace: String,
    /// Number of memories counted in the namespace.
    pub count: usize,
    /// Mean importance over the counted memories (0.0 when the namespace is empty).
    pub avg_importance: f64,
    /// Mean confidence over the counted memories (0.0 when the namespace is empty).
    pub avg_confidence: f64,
    /// Memory-type name -> number of memories of that type.
    pub type_distribution: HashMap<String, usize>,
    /// Tag -> number of memories carrying that tag.
    pub tag_frequency: HashMap<String, usize>,
    /// `created_at` of the earliest memory, if any.
    pub oldest: Option<chrono::DateTime<chrono::Utc>>,
    /// `created_at` of the latest memory, if any.
    pub newest: Option<chrono::DateTime<chrono::Utc>>,
}
30
31impl CodememEngine {
32    /// Core recall logic: search storage with hybrid scoring and return ranked results.
33    ///
34    /// Combines vector search (if embeddings available), BM25, graph strength,
35    /// temporal, tag matching, importance, confidence, and recency into a
36    /// 9-component hybrid score. Supports filtering by memory type, namespace,
37    /// tag exclusion, and minimum importance/confidence thresholds.
38    #[allow(clippy::too_many_arguments)]
39    pub fn recall(
40        &self,
41        query: &str,
42        k: usize,
43        memory_type_filter: Option<MemoryType>,
44        namespace_filter: Option<&str>,
45        exclude_tags: &[String],
46        min_importance: Option<f64>,
47        min_confidence: Option<f64>,
48    ) -> Result<Vec<SearchResult>, CodememError> {
49        // Try vector search first (if embeddings available)
50        let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
51            match emb_guard.embed(query) {
52                Ok(query_embedding) => {
53                    drop(emb_guard);
54                    let vec = self.lock_vector()?;
55                    vec.search(&query_embedding, k * 2) // over-fetch for re-ranking
56                        .unwrap_or_default()
57                }
58                Err(e) => {
59                    tracing::warn!("Query embedding failed: {e}");
60                    vec![]
61                }
62            }
63        } else {
64            vec![]
65        };
66
67        // H1: Use code-aware tokenizer for query tokens so that compound identifiers
68        // like "parseFunction" are split into ["parse", "function"] — matching the
69        // tokenization used when documents were added to the BM25 index.
70        let query_tokens: Vec<String> = crate::bm25::tokenize(query);
71        let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();
72
73        // Graph and BM25 intentionally load different data: graph stores structural relationships
74        // (nodes/edges), while BM25 indexes memory content for text search. This is by design,
75        // not duplication.
76        let mut graph = self.lock_graph()?;
77        // C1: Lazily compute betweenness centrality before scoring so the
78        // betweenness component (30% of graph_strength) is not permanently zero.
79        graph.ensure_betweenness_computed();
80        let bm25 = self.lock_bm25()?;
81        let now = Utc::now();
82
83        // Entity expansion: find memories connected to code entities mentioned in the query.
84        // This ensures that structurally related memories are candidates even when they are
85        // semantically distant from the query text.
86        let entity_memory_ids = self.resolve_entity_memories(query, &graph, now);
87
88        let mut results: Vec<SearchResult> = Vec::new();
89        let weights = self.scoring_weights()?;
90
91        if vector_results.is_empty() {
92            // Fallback: batch-load all memories matching filters in one query
93            let type_str = memory_type_filter.as_ref().map(|t| t.to_string());
94            let all_memories = self
95                .storage
96                .list_memories_filtered(namespace_filter, type_str.as_deref())?;
97
98            for memory in all_memories {
99                // Apply quality filters
100                if !exclude_tags.is_empty() && memory.tags.iter().any(|t| exclude_tags.contains(t))
101                {
102                    continue;
103                }
104                if let Some(min) = min_importance {
105                    if memory.importance < min {
106                        continue;
107                    }
108                }
109                if let Some(min) = min_confidence {
110                    if memory.confidence < min {
111                        continue;
112                    }
113                }
114
115                let breakdown = compute_score(&memory, &query_token_refs, 0.0, &graph, &bm25, now);
116                let score = breakdown.total_with_weights(&weights);
117                if score > 0.01 {
118                    results.push(SearchResult {
119                        memory,
120                        score,
121                        score_breakdown: breakdown,
122                    });
123                }
124            }
125        } else {
126            // Vector search path: batch-fetch all candidate memories + entity-connected memories
127            let mut all_candidate_ids: HashSet<&str> =
128                vector_results.iter().map(|(id, _)| id.as_str()).collect();
129
130            // Merge entity-connected memory IDs into the candidate pool
131            for eid in &entity_memory_ids {
132                all_candidate_ids.insert(eid.as_str());
133            }
134
135            let candidate_id_vec: Vec<&str> = all_candidate_ids.into_iter().collect();
136            let candidate_memories = self.storage.get_memories_batch(&candidate_id_vec)?;
137
138            // Build similarity lookup (entity memories will get 0.0 similarity)
139            let sim_map: HashMap<&str, f64> = vector_results
140                .iter()
141                .map(|(id, sim)| (id.as_str(), *sim as f64))
142                .collect();
143
144            for memory in candidate_memories {
145                // Apply memory_type filter
146                if let Some(ref filter_type) = memory_type_filter {
147                    if memory.memory_type != *filter_type {
148                        continue;
149                    }
150                }
151                // Apply namespace filter
152                if let Some(ns) = namespace_filter {
153                    if memory.namespace.as_deref() != Some(ns) {
154                        continue;
155                    }
156                }
157                // Apply quality filters
158                if !exclude_tags.is_empty() && memory.tags.iter().any(|t| exclude_tags.contains(t))
159                {
160                    continue;
161                }
162                if let Some(min) = min_importance {
163                    if memory.importance < min {
164                        continue;
165                    }
166                }
167                if let Some(min) = min_confidence {
168                    if memory.confidence < min {
169                        continue;
170                    }
171                }
172
173                let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
174                let breakdown =
175                    compute_score(&memory, &query_token_refs, similarity, &graph, &bm25, now);
176                let score = breakdown.total_with_weights(&weights);
177                if score > 0.01 {
178                    results.push(SearchResult {
179                        memory,
180                        score,
181                        score_breakdown: breakdown,
182                    });
183                }
184            }
185        }
186
187        // Sort by score descending, take top k
188        results.sort_by(|a, b| {
189            b.score
190                .partial_cmp(&a.score)
191                .unwrap_or(std::cmp::Ordering::Equal)
192        });
193        results.truncate(k);
194
195        Ok(results)
196    }
197
    /// Recall with graph expansion: vector search (or BM25 fallback) for seed
    /// memories, then BFS expansion from each seed through the graph, scoring
    /// all candidates with the 9-component hybrid scorer.
    ///
    /// Unlike [`CodememEngine::recall`], results carry an `expansion_path`
    /// describing whether they were direct seeds or reached via graph edges.
    pub fn recall_with_expansion(
        &self,
        query: &str,
        k: usize,
        expansion_depth: usize,
        namespace_filter: Option<&str>,
    ) -> Result<Vec<ExpandedResult>, CodememError> {
        // H1: Code-aware tokenization for consistent BM25 scoring
        let query_tokens: Vec<String> = crate::bm25::tokenize(query);
        let query_token_refs: Vec<&str> = query_tokens.iter().map(|s| s.as_str()).collect();

        // Step 1: Run normal vector search (or text fallback)
        let vector_results: Vec<(String, f32)> = if let Some(emb_guard) = self.lock_embeddings()? {
            match emb_guard.embed(query) {
                Ok(query_embedding) => {
                    // Release the embeddings lock before taking the vector lock.
                    drop(emb_guard);
                    let vec = self.lock_vector()?;
                    // Over-fetch (k * 2): seeds feed the graph expansion below.
                    vec.search(&query_embedding, k * 2).unwrap_or_default()
                }
                Err(e) => {
                    tracing::warn!("Query embedding failed: {e}");
                    vec![]
                }
            }
        } else {
            vec![]
        };

        let mut graph = self.lock_graph()?;
        // C1: Lazily compute betweenness centrality before scoring
        graph.ensure_betweenness_computed();
        let bm25 = self.lock_bm25()?;
        let now = Utc::now();

        // Collect initial seed memories with their vector similarity.
        // Local carrier type: memory + its vector similarity (0.0 for
        // non-vector candidates) + how it entered the candidate set.
        struct ScoredMemory {
            memory: MemoryNode,
            vector_sim: f64,
            expansion_path: String,
        }

        let mut all_memories: Vec<ScoredMemory> = Vec::new();
        // Guards against scoring the same memory twice (seed + expanded).
        let mut seen_ids: HashSet<String> = HashSet::new();

        if vector_results.is_empty() {
            // Fallback: batch-load all memories matching namespace in one query
            let all = self
                .storage
                .list_memories_filtered(namespace_filter, None)?;
            let weights = self.scoring_weights()?;

            for memory in all {
                // Pre-score here to keep the seed set small before expansion;
                // the same 0.01 floor is applied as in `recall`.
                let breakdown = compute_score(&memory, &query_token_refs, 0.0, &graph, &bm25, now);
                let score = breakdown.total_with_weights(&weights);
                if score > 0.01 {
                    seen_ids.insert(memory.id.clone());
                    all_memories.push(ScoredMemory {
                        memory,
                        vector_sim: 0.0,
                        expansion_path: "direct".to_string(),
                    });
                }
            }
        } else {
            // Vector search path: batch-fetch all candidate memories
            let candidate_ids: Vec<&str> =
                vector_results.iter().map(|(id, _)| id.as_str()).collect();
            let candidate_memories = self.storage.get_memories_batch(&candidate_ids)?;

            let sim_map: HashMap<&str, f64> = vector_results
                .iter()
                .map(|(id, sim)| (id.as_str(), *sim as f64))
                .collect();

            for memory in candidate_memories {
                // Namespace filtering must happen here: the batch fetch is by ID only.
                if let Some(ns) = namespace_filter {
                    if memory.namespace.as_deref() != Some(ns) {
                        continue;
                    }
                }
                let similarity = sim_map.get(memory.id.as_str()).copied().unwrap_or(0.0);
                seen_ids.insert(memory.id.clone());
                all_memories.push(ScoredMemory {
                    memory,
                    vector_sim: similarity,
                    expansion_path: "direct".to_string(),
                });
            }
        }

        // Step 2-4: Graph expansion from each direct result
        // A7: BFS traverses through ALL node kinds (including code nodes like
        // File, Function, etc.) as intermediaries, but only COLLECTS Memory nodes.
        // A6: Apply temporal edge filtering — skip edges whose valid_to < now.
        // Snapshot the seed IDs first: the loop below appends to `all_memories`.
        let direct_ids: Vec<String> = all_memories.iter().map(|m| m.memory.id.clone()).collect();

        for direct_id in &direct_ids {
            // Cache edges for this direct node outside the inner loop,
            // filtering out expired temporal edges (A6)
            let direct_edges: Vec<_> = graph
                .get_edges(direct_id)
                .unwrap_or_default()
                .into_iter()
                .filter(|e| is_edge_active(e, now))
                .collect();

            // A7: Only exclude Chunk from BFS traversal (noisy), but allow
            // File, Function, Class, etc. as intermediaries to reach more Memory nodes
            if let Ok(expanded_nodes) =
                graph.bfs_filtered(direct_id, expansion_depth, &[NodeKind::Chunk], None)
            {
                for expanded_node in &expanded_nodes {
                    // Skip the start node itself (already in results)
                    if expanded_node.id == *direct_id {
                        continue;
                    }

                    // A7: Only COLLECT Memory nodes in results, but we
                    // traversed through all other node kinds to reach them
                    if expanded_node.kind != NodeKind::Memory {
                        continue;
                    }

                    // Get the memory_id from the graph node; fall back to the
                    // node's own ID when no separate memory_id is stored.
                    let memory_id = expanded_node
                        .memory_id
                        .as_deref()
                        .unwrap_or(&expanded_node.id);

                    // Skip if already seen
                    if seen_ids.contains(memory_id) {
                        continue;
                    }

                    // Fetch the memory (no-touch to avoid inflating access_count)
                    if let Ok(Some(memory)) = self.storage.get_memory_no_touch(memory_id) {
                        if let Some(ns) = namespace_filter {
                            if memory.namespace.as_deref() != Some(ns) {
                                continue;
                            }
                        }

                        // Build expansion path description using cached edges.
                        // Only depth-1 expansions match a cached edge; deeper
                        // hits get the generic "via graph from …" label.
                        let expansion_path = direct_edges
                            .iter()
                            .find(|e| e.dst == expanded_node.id || e.src == expanded_node.id)
                            .map(|e| format!("via {} from {}", e.relationship, direct_id))
                            .unwrap_or_else(|| format!("via graph from {direct_id}"));

                        seen_ids.insert(memory_id.to_string());
                        all_memories.push(ScoredMemory {
                            memory,
                            vector_sim: 0.0,
                            expansion_path,
                        });
                    }
                }
            }
        }

        // Step 5-6: Score all memories and sort
        let weights = self.scoring_weights()?;
        let mut scored_results: Vec<ExpandedResult> = all_memories
            .into_iter()
            .map(|sm| {
                let breakdown = compute_score(
                    &sm.memory,
                    &query_token_refs,
                    sm.vector_sim,
                    &graph,
                    &bm25,
                    now,
                );
                let score = breakdown.total_with_weights(&weights);
                ExpandedResult {
                    result: SearchResult {
                        memory: sm.memory,
                        score,
                        score_breakdown: breakdown,
                    },
                    expansion_path: sm.expansion_path,
                }
            })
            .collect();

        // Descending by score; NaN-safe via partial_cmp fallback.
        scored_results.sort_by(|a, b| {
            b.result
                .score
                .partial_cmp(&a.result.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        scored_results.truncate(k);

        Ok(scored_results)
    }
396
397    /// Resolve entity references from a query to memory IDs connected to those entities.
398    ///
399    /// Extracts code references (CamelCase identifiers, qualified paths, file paths,
400    /// backtick-wrapped code) from the query, matches them to graph nodes, and returns
401    /// the IDs of Memory nodes within one hop of each matched entity. This ensures
402    /// structurally related memories are recall candidates even when semantically distant.
403    pub(crate) fn resolve_entity_memories(
404        &self,
405        query: &str,
406        graph: &codemem_storage::graph::GraphEngine,
407        now: chrono::DateTime<chrono::Utc>,
408    ) -> HashSet<String> {
409        let entity_refs = crate::search::extract_code_references(query);
410        let mut memory_ids: HashSet<String> = HashSet::new();
411
412        for entity_ref in &entity_refs {
413            // Try common ID patterns: sym:Name, file:path, or direct ID match
414            let candidate_ids = [
415                format!("sym:{entity_ref}"),
416                format!("file:{entity_ref}"),
417                entity_ref.clone(),
418            ];
419
420            for candidate_id in &candidate_ids {
421                if graph.get_node_ref(candidate_id).is_none() {
422                    continue;
423                }
424                // Found a matching node — collect one-hop Memory neighbors
425                for edge in graph.get_edges_ref(candidate_id) {
426                    if !is_edge_active(edge, now) {
427                        continue;
428                    }
429                    let neighbor_id = if edge.src == *candidate_id {
430                        &edge.dst
431                    } else {
432                        &edge.src
433                    };
434                    if let Some(node) = graph.get_node_ref(neighbor_id) {
435                        if node.kind == NodeKind::Memory {
436                            let mem_id = node.memory_id.as_deref().unwrap_or(&node.id);
437                            memory_ids.insert(mem_id.to_string());
438                        }
439                    }
440                }
441                break; // Found the node, no need to try other ID patterns
442            }
443        }
444
445        memory_ids
446    }
447
448    /// Compute detailed stats for a single namespace: count, averages,
449    /// type distribution, tag frequency, and date range.
450    pub fn namespace_stats(&self, namespace: &str) -> Result<NamespaceStats, CodememError> {
451        let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
452
453        if ids.is_empty() {
454            return Ok(NamespaceStats {
455                namespace: namespace.to_string(),
456                count: 0,
457                avg_importance: 0.0,
458                avg_confidence: 0.0,
459                type_distribution: HashMap::new(),
460                tag_frequency: HashMap::new(),
461                oldest: None,
462                newest: None,
463            });
464        }
465
466        let mut total_importance = 0.0;
467        let mut total_confidence = 0.0;
468        let mut type_distribution: HashMap<String, usize> = HashMap::new();
469        let mut tag_frequency: HashMap<String, usize> = HashMap::new();
470        let mut oldest: Option<chrono::DateTime<chrono::Utc>> = None;
471        let mut newest: Option<chrono::DateTime<chrono::Utc>> = None;
472        let mut count = 0usize;
473
474        // M2: Batch-fetch all memories in one query instead of per-ID get_memory_no_touch.
475        // get_memories_batch does not increment access_count (pure SELECT), so it is
476        // equivalent to get_memory_no_touch for stats purposes.
477        let id_refs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
478        let memories = self.storage.get_memories_batch(&id_refs)?;
479
480        for memory in &memories {
481            count += 1;
482            total_importance += memory.importance;
483            total_confidence += memory.confidence;
484
485            *type_distribution
486                .entry(memory.memory_type.to_string())
487                .or_insert(0) += 1;
488
489            for tag in &memory.tags {
490                *tag_frequency.entry(tag.clone()).or_insert(0) += 1;
491            }
492
493            match oldest {
494                None => oldest = Some(memory.created_at),
495                Some(ref o) if memory.created_at < *o => oldest = Some(memory.created_at),
496                _ => {}
497            }
498            match newest {
499                None => newest = Some(memory.created_at),
500                Some(ref n) if memory.created_at > *n => newest = Some(memory.created_at),
501                _ => {}
502            }
503        }
504
505        let avg_importance = if count > 0 {
506            total_importance / count as f64
507        } else {
508            0.0
509        };
510        let avg_confidence = if count > 0 {
511            total_confidence / count as f64
512        } else {
513            0.0
514        };
515
516        Ok(NamespaceStats {
517            namespace: namespace.to_string(),
518            count,
519            avg_importance,
520            avg_confidence,
521            type_distribution,
522            tag_frequency,
523            oldest,
524            newest,
525        })
526    }
527
528    /// Delete all memories in a namespace from all subsystems (storage, vector,
529    /// graph, BM25). Returns the number of memories deleted.
530    pub fn delete_namespace(&self, namespace: &str) -> Result<usize, CodememError> {
531        let ids = self.storage.list_memory_ids_for_namespace(namespace)?;
532
533        let mut deleted = 0usize;
534        let mut graph = self.lock_graph()?;
535        let mut vector = self.lock_vector()?;
536        let mut bm25 = self.lock_bm25()?;
537
538        for id in &ids {
539            // Delete memory from storage
540            if let Ok(true) = self.storage.delete_memory(id) {
541                deleted += 1;
542
543                // Remove from vector index
544                let _ = vector.remove(id);
545
546                // Remove from in-memory graph
547                let _ = graph.remove_node(id);
548
549                // Remove graph node and edges from SQLite
550                let _ = self.storage.delete_graph_edges_for_node(id);
551                let _ = self.storage.delete_graph_node(id);
552
553                // Remove embedding from SQLite
554                let _ = self.storage.delete_embedding(id);
555
556                // Remove from BM25 index
557                bm25.remove_document(id);
558            }
559        }
560
561        // Drop locks before calling save_index (which acquires vector lock)
562        drop(graph);
563        drop(vector);
564        drop(bm25);
565
566        // Persist vector index to disk
567        self.save_index();
568
569        Ok(deleted)
570    }
571}
572
573/// Check if an edge is currently active based on its temporal bounds.
574/// An edge is active if:
575/// - `valid_from` is None or <= `now`
576/// - `valid_to` is None or > `now`
577pub(crate) fn is_edge_active(
578    edge: &codemem_core::Edge,
579    now: chrono::DateTime<chrono::Utc>,
580) -> bool {
581    if let Some(valid_to) = edge.valid_to {
582        if valid_to < now {
583            return false;
584        }
585    }
586    if let Some(valid_from) = edge.valid_from {
587        if valid_from > now {
588            return false;
589        }
590    }
591    true
592}