// engram/search/rerank.rs

//! Search result reranking
//!
//! Provides post-search reranking to improve result quality through:
//! - Query-document relevance scoring
//! - Recency boosting
//! - Importance weighting
//! - Entity mention boosting
//! - Memory-type and tag relevance scoring
//!
//! Supports pluggable reranking strategies with a default heuristic-based
//! approach. A cross-encoder strategy is reserved, but it currently falls back
//! to the heuristic scorer.

use std::cmp::Ordering;
use std::collections::HashSet;

use serde::{Deserialize, Serialize};

use crate::types::{Memory, MemoryType, SearchResult};

18/// Configuration for the reranker
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RerankConfig {
21    /// Enable reranking
22    pub enabled: bool,
23    /// Reranking strategy to use
24    pub strategy: RerankStrategy,
25    /// Weight for original search score (0.0 - 1.0)
26    pub original_score_weight: f32,
27    /// Weight for rerank score (0.0 - 1.0)
28    pub rerank_score_weight: f32,
29    /// Boost for recent memories (per day, decays exponentially)
30    pub recency_boost: f32,
31    /// Half-life for recency boost in days
32    pub recency_half_life_days: f32,
33    /// Boost per importance point
34    pub importance_boost: f32,
35    /// Boost for memories with matching entities
36    pub entity_match_boost: f32,
37    /// Boost for exact phrase matches
38    pub exact_match_boost: f32,
39    /// Minimum number of results to consider for reranking
40    pub min_results: usize,
41    /// Maximum number of results to rerank (for performance)
42    pub max_rerank_candidates: usize,
43}
44
45impl Default for RerankConfig {
46    fn default() -> Self {
47        Self {
48            enabled: true,
49            strategy: RerankStrategy::Heuristic,
50            original_score_weight: 0.6,
51            rerank_score_weight: 0.4,
52            recency_boost: 0.05,
53            recency_half_life_days: 30.0,
54            importance_boost: 0.1,
55            entity_match_boost: 0.15,
56            exact_match_boost: 0.2,
57            min_results: 3,
58            max_rerank_candidates: 100,
59        }
60    }
61}
62
63/// Reranking strategy
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
65#[serde(rename_all = "snake_case")]
66pub enum RerankStrategy {
67    /// No reranking, keep original order
68    None,
69    /// Heuristic-based reranking using query features
70    #[default]
71    Heuristic,
72    /// Cross-encoder model (requires external API or local model)
73    CrossEncoder,
74    /// Reciprocal Rank Fusion with multiple signals
75    MultiSignal,
76}
77
78/// Reranking result with explanation
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct RerankResult {
81    /// Original search result
82    pub result: SearchResult,
83    /// Original rank (1-indexed)
84    pub original_rank: usize,
85    /// New rank after reranking (1-indexed)
86    pub new_rank: usize,
87    /// Rerank score details
88    pub rerank_info: RerankInfo,
89}
90
91/// Detailed reranking information for explainability
92#[derive(Debug, Clone, Serialize, Deserialize, Default)]
93pub struct RerankInfo {
94    /// Original search score
95    pub original_score: f32,
96    /// Final combined score
97    pub final_score: f32,
98    /// Rerank score before combination
99    pub rerank_score: f32,
100    /// Individual score components
101    pub components: RerankComponents,
102}
103
104/// Individual components of the rerank score
105#[derive(Debug, Clone, Serialize, Deserialize, Default)]
106pub struct RerankComponents {
107    /// Score from term overlap
108    pub term_overlap: f32,
109    /// Score from recency
110    pub recency: f32,
111    /// Score from memory importance
112    pub importance: f32,
113    /// Score from entity matches
114    pub entity_match: f32,
115    /// Score from exact phrase match
116    pub exact_match: f32,
117    /// Score from memory type relevance
118    pub type_relevance: f32,
119    /// Score from tag matches
120    pub tag_match: f32,
121}
122
123/// Reranker for search results
124pub struct Reranker {
125    config: RerankConfig,
126}
127
128impl Reranker {
129    /// Create a new reranker with default config
130    pub fn new() -> Self {
131        Self {
132            config: RerankConfig::default(),
133        }
134    }
135
136    /// Create a new reranker with custom config
137    pub fn with_config(config: RerankConfig) -> Self {
138        Self { config }
139    }
140
141    /// Rerank search results
142    pub fn rerank(
143        &self,
144        results: Vec<SearchResult>,
145        query: &str,
146        query_entities: Option<&[String]>,
147    ) -> Vec<RerankResult> {
148        if !self.config.enabled || results.len() < self.config.min_results {
149            // Return results unchanged but with rerank info
150            return results
151                .into_iter()
152                .enumerate()
153                .map(|(i, r)| RerankResult {
154                    rerank_info: RerankInfo {
155                        original_score: r.score,
156                        final_score: r.score,
157                        rerank_score: 0.0,
158                        components: RerankComponents::default(),
159                    },
160                    result: r,
161                    original_rank: i + 1,
162                    new_rank: i + 1,
163                })
164                .collect();
165        }
166
167        match self.config.strategy {
168            RerankStrategy::None => self.no_rerank(results),
169            RerankStrategy::Heuristic => self.heuristic_rerank(results, query, query_entities),
170            RerankStrategy::CrossEncoder => {
171                // Cross-encoder requires external model, fallback to heuristic
172                self.heuristic_rerank(results, query, query_entities)
173            }
174            RerankStrategy::MultiSignal => self.multi_signal_rerank(results, query, query_entities),
175        }
176    }
177
178    /// No reranking - just wrap results
179    fn no_rerank(&self, results: Vec<SearchResult>) -> Vec<RerankResult> {
180        results
181            .into_iter()
182            .enumerate()
183            .map(|(i, r)| RerankResult {
184                rerank_info: RerankInfo {
185                    original_score: r.score,
186                    final_score: r.score,
187                    rerank_score: 0.0,
188                    components: RerankComponents::default(),
189                },
190                result: r,
191                original_rank: i + 1,
192                new_rank: i + 1,
193            })
194            .collect()
195    }
196
197    /// Heuristic-based reranking
198    fn heuristic_rerank(
199        &self,
200        results: Vec<SearchResult>,
201        query: &str,
202        query_entities: Option<&[String]>,
203    ) -> Vec<RerankResult> {
204        let query_terms = extract_terms(query);
205        let query_lower = query.to_lowercase();
206
207        let mut rerank_results: Vec<RerankResult> = results
208            .into_iter()
209            .enumerate()
210            .take(self.config.max_rerank_candidates)
211            .map(|(i, r)| {
212                let components = self.compute_rerank_components(
213                    &r.memory,
214                    &query_terms,
215                    &query_lower,
216                    query_entities,
217                );
218
219                let rerank_score = self.combine_components(&components);
220                let final_score = self.config.original_score_weight * r.score
221                    + self.config.rerank_score_weight * rerank_score;
222
223                RerankResult {
224                    rerank_info: RerankInfo {
225                        original_score: r.score,
226                        final_score,
227                        rerank_score,
228                        components,
229                    },
230                    result: r,
231                    original_rank: i + 1,
232                    new_rank: 0, // Will be set after sorting
233                }
234            })
235            .collect();
236
237        // Sort by final score
238        rerank_results.sort_by(|a, b| {
239            b.rerank_info
240                .final_score
241                .partial_cmp(&a.rerank_info.final_score)
242                .unwrap_or(std::cmp::Ordering::Equal)
243        });
244
245        // Update new ranks
246        for (i, result) in rerank_results.iter_mut().enumerate() {
247            result.new_rank = i + 1;
248        }
249
250        rerank_results
251    }
252
253    /// Multi-signal reranking using RRF across multiple signals
254    fn multi_signal_rerank(
255        &self,
256        results: Vec<SearchResult>,
257        query: &str,
258        query_entities: Option<&[String]>,
259    ) -> Vec<RerankResult> {
260        let query_terms = extract_terms(query);
261        let query_lower = query.to_lowercase();
262
263        // Compute multiple rankings
264        let mut original_ranks: Vec<(usize, f32)> = results
265            .iter()
266            .enumerate()
267            .map(|(i, r)| (i, r.score))
268            .collect();
269        original_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
270
271        let mut recency_ranks: Vec<(usize, f32)> = results
272            .iter()
273            .enumerate()
274            .map(|(i, r)| (i, self.compute_recency_score(&r.memory)))
275            .collect();
276        recency_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
277
278        let mut term_ranks: Vec<(usize, f32)> = results
279            .iter()
280            .enumerate()
281            .map(|(i, r)| (i, compute_term_overlap(&r.memory.content, &query_terms)))
282            .collect();
283        term_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
284
285        // RRF fusion
286        let k = 60.0;
287        let mut rrf_scores: Vec<(usize, f32)> = vec![];
288
289        for i in 0..results.len() {
290            let orig_rank = original_ranks
291                .iter()
292                .position(|(idx, _)| *idx == i)
293                .unwrap()
294                + 1;
295            let rec_rank = recency_ranks.iter().position(|(idx, _)| *idx == i).unwrap() + 1;
296            let term_rank = term_ranks.iter().position(|(idx, _)| *idx == i).unwrap() + 1;
297
298            let rrf_score = 1.0 / (k + orig_rank as f32)
299                + 0.5 / (k + rec_rank as f32)
300                + 0.5 / (k + term_rank as f32);
301
302            rrf_scores.push((i, rrf_score));
303        }
304
305        rrf_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
306
307        // Build results
308        let mut rerank_results: Vec<RerankResult> = results
309            .into_iter()
310            .enumerate()
311            .map(|(i, r)| {
312                let components = self.compute_rerank_components(
313                    &r.memory,
314                    &query_terms,
315                    &query_lower,
316                    query_entities,
317                );
318                let rrf_score = rrf_scores
319                    .iter()
320                    .find(|(idx, _)| *idx == i)
321                    .map(|(_, s)| *s)
322                    .unwrap_or(0.0);
323                let new_rank = rrf_scores
324                    .iter()
325                    .position(|(idx, _)| *idx == i)
326                    .unwrap_or(i)
327                    + 1;
328
329                RerankResult {
330                    rerank_info: RerankInfo {
331                        original_score: r.score,
332                        final_score: rrf_score,
333                        rerank_score: rrf_score,
334                        components,
335                    },
336                    result: r,
337                    original_rank: i + 1,
338                    new_rank,
339                }
340            })
341            .collect();
342
343        rerank_results.sort_by_key(|r| r.new_rank);
344        rerank_results
345    }
346
347    /// Compute reranking score components for a memory
348    fn compute_rerank_components(
349        &self,
350        memory: &Memory,
351        query_terms: &HashSet<String>,
352        query_lower: &str,
353        query_entities: Option<&[String]>,
354    ) -> RerankComponents {
355        let content_lower = memory.content.to_lowercase();
356
357        RerankComponents {
358            term_overlap: compute_term_overlap(&memory.content, query_terms),
359            recency: self.compute_recency_score(memory),
360            importance: memory.importance * self.config.importance_boost,
361            entity_match: self.compute_entity_match_score(memory, query_entities),
362            exact_match: if content_lower.contains(query_lower) {
363                self.config.exact_match_boost
364            } else {
365                0.0
366            },
367            type_relevance: self.compute_type_relevance(memory),
368            tag_match: self.compute_tag_match_score(memory, query_terms),
369        }
370    }
371
372    /// Combine component scores into a single rerank score
373    fn combine_components(&self, components: &RerankComponents) -> f32 {
374        // Weighted combination of components
375        components.term_overlap * 0.25
376            + components.recency * 0.15
377            + components.importance * 0.15
378            + components.entity_match * 0.15
379            + components.exact_match * 0.15
380            + components.type_relevance * 0.05
381            + components.tag_match * 0.10
382    }
383
384    /// Compute recency score with exponential decay
385    fn compute_recency_score(&self, memory: &Memory) -> f32 {
386        let now = chrono::Utc::now();
387        let age_days = (now - memory.created_at).num_days() as f32;
388
389        // Exponential decay: score = boost * 0.5^(age/half_life)
390        let decay = 0.5_f32.powf(age_days / self.config.recency_half_life_days);
391        self.config.recency_boost * decay
392    }
393
394    /// Compute entity match score
395    fn compute_entity_match_score(
396        &self,
397        memory: &Memory,
398        query_entities: Option<&[String]>,
399    ) -> f32 {
400        let Some(entities) = query_entities else {
401            return 0.0;
402        };
403
404        if entities.is_empty() {
405            return 0.0;
406        }
407
408        let content_lower = memory.content.to_lowercase();
409        let matches = entities
410            .iter()
411            .filter(|e| content_lower.contains(&e.to_lowercase()))
412            .count();
413
414        if matches > 0 {
415            self.config.entity_match_boost * (matches as f32 / entities.len() as f32)
416        } else {
417            0.0
418        }
419    }
420
421    /// Compute type relevance (some types are generally more relevant)
422    fn compute_type_relevance(&self, memory: &Memory) -> f32 {
423        match memory.memory_type {
424            MemoryType::Decision => 0.1,
425            MemoryType::Preference => 0.08,
426            MemoryType::Learning => 0.06,
427            MemoryType::Context => 0.05,
428            MemoryType::Note => 0.04,
429            MemoryType::Todo => 0.03,
430            MemoryType::Issue => 0.03,
431            MemoryType::Credential => 0.02,
432            MemoryType::Custom => 0.04,
433            MemoryType::TranscriptChunk => 0.02, // Lower relevance for transcript chunks
434            MemoryType::Episodic => 0.07,
435            MemoryType::Procedural => 0.06,
436            MemoryType::Summary => 0.05,
437            MemoryType::Checkpoint => 0.04,
438        }
439    }
440
441    /// Compute tag match score
442    fn compute_tag_match_score(&self, memory: &Memory, query_terms: &HashSet<String>) -> f32 {
443        if memory.tags.is_empty() || query_terms.is_empty() {
444            return 0.0;
445        }
446
447        let tag_set: HashSet<String> = memory.tags.iter().map(|t| t.to_lowercase()).collect();
448        let matches = query_terms.intersection(&tag_set).count();
449
450        if matches > 0 {
451            0.1 * (matches as f32 / query_terms.len().min(memory.tags.len()) as f32)
452        } else {
453            0.0
454        }
455    }
456}
457
458impl Default for Reranker {
459    fn default() -> Self {
460        Self::new()
461    }
462}
463
/// Extract normalized terms from `text`.
///
/// The text is lower-cased and split on every non-alphanumeric character. Tokens
/// shorter than three characters (e.g. "a", "of", "is") are discarded. Length is
/// counted in Unicode scalar values rather than bytes, so accented and non-Latin
/// words are filtered by the same rule as ASCII ones.
fn extract_terms(text: &str) -> HashSet<String> {
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| token.chars().count() > 2)
        .map(str::to_owned)
        .collect()
}

/// Fraction of `query_terms` that also occur among the terms of `content`, in `[0.0, 1.0]`.
///
/// Returns `0.0` when `query_terms` is empty. `query_terms` are expected to come from
/// `extract_terms`, so that both sides are normalized the same way.
fn compute_term_overlap(content: &str, query_terms: &HashSet<String>) -> f32 {
    if query_terms.is_empty() {
        return 0.0;
    }

    let content_terms = extract_terms(content);
    let matches = query_terms.intersection(&content_terms).count();

    matches as f32 / query_terms.len() as f32
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MatchInfo, MemoryScope, SearchStrategy, Visibility};
    use chrono::Utc;
    use std::collections::HashMap;

    /// Build a `Note` memory with the given content and importance, created now.
    fn create_test_memory(content: &str, importance: f32) -> Memory {
        Memory {
            id: 1,
            content: content.to_string(),
            memory_type: MemoryType::Note,
            importance,
            tags: vec![],
            access_count: 0,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            last_accessed_at: None,
            owner_id: None,
            visibility: Visibility::Private,
            version: 1,
            has_embedding: false,
            metadata: HashMap::new(),
            scope: MemoryScope::Global,
            workspace: "default".to_string(),
            tier: crate::types::MemoryTier::Permanent,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: crate::types::LifecycleState::Active,
        }
    }

    /// Wrap a memory in a hybrid-search result with the given score.
    fn create_test_result(memory: Memory, score: f32) -> SearchResult {
        SearchResult {
            memory,
            score,
            match_info: MatchInfo {
                strategy: SearchStrategy::Hybrid,
                matched_terms: vec![],
                highlights: vec![],
                semantic_score: None,
                keyword_score: Some(score),
            },
        }
    }

    #[test]
    fn test_reranker_preserves_order_when_disabled() {
        let config = RerankConfig {
            enabled: false,
            ..Default::default()
        };
        let reranker = Reranker::with_config(config);

        let results = vec![
            create_test_result(create_test_memory("First result", 0.5), 0.9),
            create_test_result(create_test_memory("Second result", 0.5), 0.8),
            create_test_result(create_test_memory("Third result", 0.5), 0.7),
        ];

        let reranked = reranker.rerank(results, "test query", None);

        // Order, ranks and scores must all pass through untouched.
        let contents: Vec<&str> = reranked
            .iter()
            .map(|r| r.result.memory.content.as_str())
            .collect();
        assert_eq!(contents, ["First result", "Second result", "Third result"]);
        for (i, r) in reranked.iter().enumerate() {
            assert_eq!(r.original_rank, i + 1);
            assert_eq!(r.new_rank, i + 1);
            assert_eq!(r.rerank_info.final_score, r.rerank_info.original_score);
        }
    }

    #[test]
    fn test_exact_match_boost() {
        let reranker = Reranker::new();

        let results = vec![
            create_test_result(create_test_memory("Some unrelated content", 0.5), 0.9),
            create_test_result(
                create_test_memory("This contains test query exactly", 0.5),
                0.7,
            ),
            create_test_result(create_test_memory("Another unrelated text", 0.5), 0.8),
        ];

        let reranked = reranker.rerank(results, "test query", None);

        // The result with the exact phrase should be boosted; the others should not.
        let exact_match_result = reranked
            .iter()
            .find(|r| r.result.memory.content.contains("test query"))
            .unwrap();
        assert!(exact_match_result.rerank_info.components.exact_match > 0.0);
        for r in reranked
            .iter()
            .filter(|r| !r.result.memory.content.contains("test query"))
        {
            assert_eq!(r.rerank_info.components.exact_match, 0.0);
        }
    }

    #[test]
    fn test_importance_boost() {
        let config = RerankConfig {
            min_results: 2, // Allow testing with 2 results
            ..Default::default()
        };
        let reranker = Reranker::with_config(config);

        let mut low_importance = create_test_memory("Test content low", 0.2);
        let mut high_importance = create_test_memory("Test content high", 0.9);

        low_importance.id = 1;
        high_importance.id = 2;

        let results = vec![
            create_test_result(low_importance, 0.8),
            create_test_result(high_importance, 0.75),
        ];

        let reranked = reranker.rerank(results, "test", None);

        // High importance memory should have higher importance component
        let high_result = reranked.iter().find(|r| r.result.memory.id == 2).unwrap();
        let low_result = reranked.iter().find(|r| r.result.memory.id == 1).unwrap();

        assert!(
            high_result.rerank_info.components.importance
                > low_result.rerank_info.components.importance
        );
    }

    #[test]
    fn test_entity_match_boost() {
        let config = RerankConfig {
            min_results: 2, // Allow testing with 2 results
            ..Default::default()
        };
        let reranker = Reranker::with_config(config);

        let results = vec![
            create_test_result(
                create_test_memory("Content about Python programming", 0.5),
                0.8,
            ),
            create_test_result(
                create_test_memory("Content about Rust and systems", 0.5),
                0.75,
            ),
        ];

        let entities = vec!["Rust".to_string(), "systems".to_string()];
        let reranked = reranker.rerank(results, "programming language", Some(&entities));

        // Only the result mentioning the entities should get the entity_match boost.
        let rust_result = reranked
            .iter()
            .find(|r| r.result.memory.content.contains("Rust"))
            .unwrap();
        assert!(rust_result.rerank_info.components.entity_match > 0.0);

        let python_result = reranked
            .iter()
            .find(|r| r.result.memory.content.contains("Python"))
            .unwrap();
        assert_eq!(python_result.rerank_info.components.entity_match, 0.0);
    }

    #[test]
    fn test_term_overlap() {
        let terms: HashSet<String> = ["rust", "programming", "memory"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let high_overlap = compute_term_overlap("Rust programming with memory management", &terms);
        let low_overlap = compute_term_overlap("Python web development", &terms);

        assert!(high_overlap > low_overlap);
        assert!(high_overlap > 0.5); // At least 2 of 3 terms match
        assert_eq!(low_overlap, 0.0);
    }

    #[test]
    fn test_multi_signal_rerank() {
        let config = RerankConfig {
            strategy: RerankStrategy::MultiSignal,
            ..Default::default()
        };
        let reranker = Reranker::with_config(config);

        let results = vec![
            create_test_result(create_test_memory("First memory", 0.5), 0.9),
            create_test_result(create_test_memory("Second memory", 0.5), 0.8),
            create_test_result(
                create_test_memory("Third memory with exact query", 0.5),
                0.7,
            ),
        ];

        let reranked = reranker.rerank(results, "exact query", None);

        // Results should be reranked
        assert_eq!(reranked.len(), 3);
        // All should have rerank info
        for r in &reranked {
            assert!(r.rerank_info.final_score > 0.0);
        }

        // Output is best-first with consecutive new ranks...
        let new_ranks: Vec<usize> = reranked.iter().map(|r| r.new_rank).collect();
        assert_eq!(new_ranks, vec![1, 2, 3]);
        assert!(reranked
            .windows(2)
            .all(|w| w[0].rerank_info.final_score >= w[1].rerank_info.final_score));

        // ...and every input appears exactly once.
        let mut original_ranks: Vec<usize> = reranked.iter().map(|r| r.original_rank).collect();
        original_ranks.sort_unstable();
        assert_eq!(original_ranks, vec![1, 2, 3]);
    }
}