Skip to main content

engram/search/
rerank.rs

1//! Search result reranking
2//!
3//! Provides post-search reranking to improve result quality through:
4//! - Query-document relevance scoring
5//! - Recency boosting
6//! - Importance weighting
7//! - Entity mention boosting
8//! - Context-aware scoring
9//!
10//! Supports pluggable reranking strategies with a default heuristic-based
11//! approach and optional integration with cross-encoder models.
12
13use serde::{Deserialize, Serialize};
14use std::collections::HashSet;
15
16use crate::types::{Memory, MemoryType, SearchResult};
17
18/// Configuration for the reranker
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RerankConfig {
21    /// Enable reranking
22    pub enabled: bool,
23    /// Reranking strategy to use
24    pub strategy: RerankStrategy,
25    /// Weight for original search score (0.0 - 1.0)
26    pub original_score_weight: f32,
27    /// Weight for rerank score (0.0 - 1.0)
28    pub rerank_score_weight: f32,
29    /// Boost for recent memories (per day, decays exponentially)
30    pub recency_boost: f32,
31    /// Half-life for recency boost in days
32    pub recency_half_life_days: f32,
33    /// Boost per importance point
34    pub importance_boost: f32,
35    /// Boost for memories with matching entities
36    pub entity_match_boost: f32,
37    /// Boost for exact phrase matches
38    pub exact_match_boost: f32,
39    /// Minimum number of results to consider for reranking
40    pub min_results: usize,
41    /// Maximum number of results to rerank (for performance)
42    pub max_rerank_candidates: usize,
43}
44
45impl Default for RerankConfig {
46    fn default() -> Self {
47        Self {
48            enabled: true,
49            strategy: RerankStrategy::Heuristic,
50            original_score_weight: 0.6,
51            rerank_score_weight: 0.4,
52            recency_boost: 0.05,
53            recency_half_life_days: 30.0,
54            importance_boost: 0.1,
55            entity_match_boost: 0.15,
56            exact_match_boost: 0.2,
57            min_results: 3,
58            max_rerank_candidates: 100,
59        }
60    }
61}
62
63/// Reranking strategy
64#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
65#[serde(rename_all = "snake_case")]
66pub enum RerankStrategy {
67    /// No reranking, keep original order
68    None,
69    /// Heuristic-based reranking using query features
70    #[default]
71    Heuristic,
72    /// Cross-encoder model (requires external API or local model)
73    CrossEncoder,
74    /// Reciprocal Rank Fusion with multiple signals
75    MultiSignal,
76}
77
78/// Reranking result with explanation
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct RerankResult {
81    /// Original search result
82    pub result: SearchResult,
83    /// Original rank (1-indexed)
84    pub original_rank: usize,
85    /// New rank after reranking (1-indexed)
86    pub new_rank: usize,
87    /// Rerank score details
88    pub rerank_info: RerankInfo,
89}
90
91/// Detailed reranking information for explainability
92#[derive(Debug, Clone, Serialize, Deserialize, Default)]
93pub struct RerankInfo {
94    /// Original search score
95    pub original_score: f32,
96    /// Final combined score
97    pub final_score: f32,
98    /// Rerank score before combination
99    pub rerank_score: f32,
100    /// Individual score components
101    pub components: RerankComponents,
102}
103
104/// Individual components of the rerank score
105#[derive(Debug, Clone, Serialize, Deserialize, Default)]
106pub struct RerankComponents {
107    /// Score from term overlap
108    pub term_overlap: f32,
109    /// Score from recency
110    pub recency: f32,
111    /// Score from memory importance
112    pub importance: f32,
113    /// Score from entity matches
114    pub entity_match: f32,
115    /// Score from exact phrase match
116    pub exact_match: f32,
117    /// Score from memory type relevance
118    pub type_relevance: f32,
119    /// Score from tag matches
120    pub tag_match: f32,
121}
122
123/// Reranker for search results
124pub struct Reranker {
125    config: RerankConfig,
126}
127
128impl Reranker {
129    /// Create a new reranker with default config
130    pub fn new() -> Self {
131        Self {
132            config: RerankConfig::default(),
133        }
134    }
135
136    /// Create a new reranker with custom config
137    pub fn with_config(config: RerankConfig) -> Self {
138        Self { config }
139    }
140
141    /// Rerank search results
142    pub fn rerank(
143        &self,
144        results: Vec<SearchResult>,
145        query: &str,
146        query_entities: Option<&[String]>,
147    ) -> Vec<RerankResult> {
148        if !self.config.enabled || results.len() < self.config.min_results {
149            // Return results unchanged but with rerank info
150            return results
151                .into_iter()
152                .enumerate()
153                .map(|(i, r)| RerankResult {
154                    rerank_info: RerankInfo {
155                        original_score: r.score,
156                        final_score: r.score,
157                        rerank_score: 0.0,
158                        components: RerankComponents::default(),
159                    },
160                    result: r,
161                    original_rank: i + 1,
162                    new_rank: i + 1,
163                })
164                .collect();
165        }
166
167        match self.config.strategy {
168            RerankStrategy::None => self.no_rerank(results),
169            RerankStrategy::Heuristic => self.heuristic_rerank(results, query, query_entities),
170            RerankStrategy::CrossEncoder => {
171                // Cross-encoder requires external model, fallback to heuristic
172                self.heuristic_rerank(results, query, query_entities)
173            }
174            RerankStrategy::MultiSignal => self.multi_signal_rerank(results, query, query_entities),
175        }
176    }
177
178    /// No reranking - just wrap results
179    fn no_rerank(&self, results: Vec<SearchResult>) -> Vec<RerankResult> {
180        results
181            .into_iter()
182            .enumerate()
183            .map(|(i, r)| RerankResult {
184                rerank_info: RerankInfo {
185                    original_score: r.score,
186                    final_score: r.score,
187                    rerank_score: 0.0,
188                    components: RerankComponents::default(),
189                },
190                result: r,
191                original_rank: i + 1,
192                new_rank: i + 1,
193            })
194            .collect()
195    }
196
197    /// Heuristic-based reranking
198    fn heuristic_rerank(
199        &self,
200        results: Vec<SearchResult>,
201        query: &str,
202        query_entities: Option<&[String]>,
203    ) -> Vec<RerankResult> {
204        let query_terms = extract_terms(query);
205        let query_lower = query.to_lowercase();
206
207        let mut rerank_results: Vec<RerankResult> = results
208            .into_iter()
209            .enumerate()
210            .take(self.config.max_rerank_candidates)
211            .map(|(i, r)| {
212                let components = self.compute_rerank_components(
213                    &r.memory,
214                    &query_terms,
215                    &query_lower,
216                    query_entities,
217                );
218
219                let rerank_score = self.combine_components(&components);
220                let final_score = self.config.original_score_weight * r.score
221                    + self.config.rerank_score_weight * rerank_score;
222
223                RerankResult {
224                    rerank_info: RerankInfo {
225                        original_score: r.score,
226                        final_score,
227                        rerank_score,
228                        components,
229                    },
230                    result: r,
231                    original_rank: i + 1,
232                    new_rank: 0, // Will be set after sorting
233                }
234            })
235            .collect();
236
237        // Sort by final score
238        rerank_results.sort_by(|a, b| {
239            b.rerank_info
240                .final_score
241                .partial_cmp(&a.rerank_info.final_score)
242                .unwrap_or(std::cmp::Ordering::Equal)
243        });
244
245        // Update new ranks
246        for (i, result) in rerank_results.iter_mut().enumerate() {
247            result.new_rank = i + 1;
248        }
249
250        rerank_results
251    }
252
253    /// Multi-signal reranking using RRF across multiple signals
254    fn multi_signal_rerank(
255        &self,
256        results: Vec<SearchResult>,
257        query: &str,
258        query_entities: Option<&[String]>,
259    ) -> Vec<RerankResult> {
260        let query_terms = extract_terms(query);
261        let query_lower = query.to_lowercase();
262
263        // Compute multiple rankings
264        let mut original_ranks: Vec<(usize, f32)> = results
265            .iter()
266            .enumerate()
267            .map(|(i, r)| (i, r.score))
268            .collect();
269        original_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
270
271        let mut recency_ranks: Vec<(usize, f32)> = results
272            .iter()
273            .enumerate()
274            .map(|(i, r)| (i, self.compute_recency_score(&r.memory)))
275            .collect();
276        recency_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
277
278        let mut term_ranks: Vec<(usize, f32)> = results
279            .iter()
280            .enumerate()
281            .map(|(i, r)| (i, compute_term_overlap(&r.memory.content, &query_terms)))
282            .collect();
283        term_ranks.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
284
285        // RRF fusion
286        let k = 60.0;
287        let mut rrf_scores: Vec<(usize, f32)> = vec![];
288
289        for i in 0..results.len() {
290            let orig_rank = original_ranks
291                .iter()
292                .position(|(idx, _)| *idx == i)
293                .unwrap()
294                + 1;
295            let rec_rank = recency_ranks.iter().position(|(idx, _)| *idx == i).unwrap() + 1;
296            let term_rank = term_ranks.iter().position(|(idx, _)| *idx == i).unwrap() + 1;
297
298            let rrf_score = 1.0 / (k + orig_rank as f32)
299                + 0.5 / (k + rec_rank as f32)
300                + 0.5 / (k + term_rank as f32);
301
302            rrf_scores.push((i, rrf_score));
303        }
304
305        rrf_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
306
307        // Build results
308        let mut rerank_results: Vec<RerankResult> = results
309            .into_iter()
310            .enumerate()
311            .map(|(i, r)| {
312                let components = self.compute_rerank_components(
313                    &r.memory,
314                    &query_terms,
315                    &query_lower,
316                    query_entities,
317                );
318                let rrf_score = rrf_scores
319                    .iter()
320                    .find(|(idx, _)| *idx == i)
321                    .map(|(_, s)| *s)
322                    .unwrap_or(0.0);
323                let new_rank = rrf_scores
324                    .iter()
325                    .position(|(idx, _)| *idx == i)
326                    .unwrap_or(i)
327                    + 1;
328
329                RerankResult {
330                    rerank_info: RerankInfo {
331                        original_score: r.score,
332                        final_score: rrf_score,
333                        rerank_score: rrf_score,
334                        components,
335                    },
336                    result: r,
337                    original_rank: i + 1,
338                    new_rank,
339                }
340            })
341            .collect();
342
343        rerank_results.sort_by_key(|r| r.new_rank);
344        rerank_results
345    }
346
347    /// Compute reranking score components for a memory
348    fn compute_rerank_components(
349        &self,
350        memory: &Memory,
351        query_terms: &HashSet<String>,
352        query_lower: &str,
353        query_entities: Option<&[String]>,
354    ) -> RerankComponents {
355        let content_lower = memory.content.to_lowercase();
356
357        RerankComponents {
358            term_overlap: compute_term_overlap(&memory.content, query_terms),
359            recency: self.compute_recency_score(memory),
360            importance: memory.importance * self.config.importance_boost,
361            entity_match: self.compute_entity_match_score(memory, query_entities),
362            exact_match: if content_lower.contains(query_lower) {
363                self.config.exact_match_boost
364            } else {
365                0.0
366            },
367            type_relevance: self.compute_type_relevance(memory),
368            tag_match: self.compute_tag_match_score(memory, query_terms),
369        }
370    }
371
372    /// Combine component scores into a single rerank score
373    fn combine_components(&self, components: &RerankComponents) -> f32 {
374        // Weighted combination of components
375        components.term_overlap * 0.25
376            + components.recency * 0.15
377            + components.importance * 0.15
378            + components.entity_match * 0.15
379            + components.exact_match * 0.15
380            + components.type_relevance * 0.05
381            + components.tag_match * 0.10
382    }
383
384    /// Compute recency score with exponential decay
385    fn compute_recency_score(&self, memory: &Memory) -> f32 {
386        let now = chrono::Utc::now();
387        let age_days = (now - memory.created_at).num_days() as f32;
388
389        // Exponential decay: score = boost * 0.5^(age/half_life)
390        let decay = 0.5_f32.powf(age_days / self.config.recency_half_life_days);
391        self.config.recency_boost * decay
392    }
393
394    /// Compute entity match score
395    fn compute_entity_match_score(
396        &self,
397        memory: &Memory,
398        query_entities: Option<&[String]>,
399    ) -> f32 {
400        let Some(entities) = query_entities else {
401            return 0.0;
402        };
403
404        if entities.is_empty() {
405            return 0.0;
406        }
407
408        let content_lower = memory.content.to_lowercase();
409        let matches = entities
410            .iter()
411            .filter(|e| content_lower.contains(&e.to_lowercase()))
412            .count();
413
414        if matches > 0 {
415            self.config.entity_match_boost * (matches as f32 / entities.len() as f32)
416        } else {
417            0.0
418        }
419    }
420
421    /// Compute type relevance (some types are generally more relevant)
422    fn compute_type_relevance(&self, memory: &Memory) -> f32 {
423        match memory.memory_type {
424            MemoryType::Decision => 0.1,
425            MemoryType::Preference => 0.08,
426            MemoryType::Learning => 0.06,
427            MemoryType::Context => 0.05,
428            MemoryType::Note => 0.04,
429            MemoryType::Todo => 0.03,
430            MemoryType::Issue => 0.03,
431            MemoryType::Credential => 0.02,
432            MemoryType::Custom => 0.04,
433            MemoryType::TranscriptChunk => 0.02, // Lower relevance for transcript chunks
434            MemoryType::Episodic => 0.07,
435            MemoryType::Procedural => 0.06,
436            MemoryType::Summary => 0.05,
437            MemoryType::Checkpoint => 0.04,
438            MemoryType::Image | MemoryType::Audio | MemoryType::Video => 0.05,
439        }
440    }
441
442    /// Compute tag match score
443    fn compute_tag_match_score(&self, memory: &Memory, query_terms: &HashSet<String>) -> f32 {
444        if memory.tags.is_empty() || query_terms.is_empty() {
445            return 0.0;
446        }
447
448        let tag_set: HashSet<String> = memory.tags.iter().map(|t| t.to_lowercase()).collect();
449        let matches = query_terms.intersection(&tag_set).count();
450
451        if matches > 0 {
452            0.1 * (matches as f32 / query_terms.len().min(memory.tags.len()) as f32)
453        } else {
454            0.0
455        }
456    }
457}
458
459impl Default for Reranker {
460    fn default() -> Self {
461        Self::new()
462    }
463}
464
/// Splits `text` into a set of normalized terms.
///
/// The text is lowercased and split on every non-alphanumeric character;
/// pieces of two bytes or fewer are discarded.
fn extract_terms(text: &str) -> HashSet<String> {
    let lowered = text.to_lowercase();
    let mut terms = HashSet::new();

    for token in lowered.split(|ch: char| !ch.is_alphanumeric()) {
        if token.len() > 2 {
            terms.insert(token.to_owned());
        }
    }

    terms
}
473
474/// Compute term overlap score between content and query terms
475fn compute_term_overlap(content: &str, query_terms: &HashSet<String>) -> f32 {
476    if query_terms.is_empty() {
477        return 0.0;
478    }
479
480    let content_terms = extract_terms(content);
481    let matches = query_terms.intersection(&content_terms).count();
482
483    matches as f32 / query_terms.len() as f32
484}
485
486#[cfg(test)]
487mod tests {
488    use super::*;
489    use crate::types::{MatchInfo, MemoryScope, SearchStrategy, Visibility};
490    use chrono::Utc;
491    use std::collections::HashMap;
492
493    fn create_test_memory(content: &str, importance: f32) -> Memory {
494        Memory {
495            id: 1,
496            content: content.to_string(),
497            memory_type: MemoryType::Note,
498            importance,
499            tags: vec![],
500            access_count: 0,
501            created_at: Utc::now(),
502            updated_at: Utc::now(),
503            last_accessed_at: None,
504            owner_id: None,
505            visibility: Visibility::Private,
506            version: 1,
507            has_embedding: false,
508            metadata: HashMap::new(),
509            scope: MemoryScope::Global,
510            workspace: "default".to_string(),
511            tier: crate::types::MemoryTier::Permanent,
512            expires_at: None,
513            content_hash: None,
514            event_time: None,
515            event_duration_seconds: None,
516            trigger_pattern: None,
517            procedure_success_count: 0,
518            procedure_failure_count: 0,
519            summary_of_id: None,
520            lifecycle_state: crate::types::LifecycleState::Active,
521            media_url: None,
522        }
523    }
524
525    fn create_test_result(memory: Memory, score: f32) -> SearchResult {
526        SearchResult {
527            memory,
528            score,
529            match_info: MatchInfo {
530                strategy: SearchStrategy::Hybrid,
531                matched_terms: vec![],
532                highlights: vec![],
533                semantic_score: None,
534                keyword_score: Some(score),
535            },
536        }
537    }
538
539    #[test]
540    fn test_reranker_preserves_order_when_disabled() {
541        let config = RerankConfig {
542            enabled: false,
543            ..Default::default()
544        };
545        let reranker = Reranker::with_config(config);
546
547        let results = vec![
548            create_test_result(create_test_memory("First result", 0.5), 0.9),
549            create_test_result(create_test_memory("Second result", 0.5), 0.8),
550            create_test_result(create_test_memory("Third result", 0.5), 0.7),
551        ];
552
553        let reranked = reranker.rerank(results, "test query", None);
554
555        assert_eq!(reranked[0].new_rank, 1);
556        assert_eq!(reranked[1].new_rank, 2);
557        assert_eq!(reranked[2].new_rank, 3);
558    }
559
560    #[test]
561    fn test_exact_match_boost() {
562        let reranker = Reranker::new();
563
564        let results = vec![
565            create_test_result(create_test_memory("Some unrelated content", 0.5), 0.9),
566            create_test_result(
567                create_test_memory("This contains test query exactly", 0.5),
568                0.7,
569            ),
570            create_test_result(create_test_memory("Another unrelated text", 0.5), 0.8),
571        ];
572
573        let reranked = reranker.rerank(results, "test query", None);
574
575        // The result with exact match should be boosted
576        let exact_match_result = reranked
577            .iter()
578            .find(|r| r.result.memory.content.contains("test query"))
579            .unwrap();
580        assert!(exact_match_result.rerank_info.components.exact_match > 0.0);
581    }
582
583    #[test]
584    fn test_importance_boost() {
585        let config = RerankConfig {
586            min_results: 2, // Allow testing with 2 results
587            ..Default::default()
588        };
589        let reranker = Reranker::with_config(config);
590
591        let mut low_importance = create_test_memory("Test content low", 0.2);
592        let mut high_importance = create_test_memory("Test content high", 0.9);
593
594        low_importance.id = 1;
595        high_importance.id = 2;
596
597        let results = vec![
598            create_test_result(low_importance, 0.8),
599            create_test_result(high_importance, 0.75),
600        ];
601
602        let reranked = reranker.rerank(results, "test", None);
603
604        // High importance memory should have higher importance component
605        let high_result = reranked.iter().find(|r| r.result.memory.id == 2).unwrap();
606        let low_result = reranked.iter().find(|r| r.result.memory.id == 1).unwrap();
607
608        assert!(
609            high_result.rerank_info.components.importance
610                > low_result.rerank_info.components.importance
611        );
612    }
613
614    #[test]
615    fn test_entity_match_boost() {
616        let config = RerankConfig {
617            min_results: 2, // Allow testing with 2 results
618            ..Default::default()
619        };
620        let reranker = Reranker::with_config(config);
621
622        let results = vec![
623            create_test_result(
624                create_test_memory("Content about Python programming", 0.5),
625                0.8,
626            ),
627            create_test_result(
628                create_test_memory("Content about Rust and systems", 0.5),
629                0.75,
630            ),
631        ];
632
633        let entities = vec!["Rust".to_string(), "systems".to_string()];
634        let reranked = reranker.rerank(results, "programming language", Some(&entities));
635
636        // Result mentioning entities should have entity_match boost
637        let rust_result = reranked
638            .iter()
639            .find(|r| r.result.memory.content.contains("Rust"))
640            .unwrap();
641        assert!(rust_result.rerank_info.components.entity_match > 0.0);
642    }
643
644    #[test]
645    fn test_term_overlap() {
646        let terms: HashSet<String> = ["rust", "programming", "memory"]
647            .iter()
648            .map(|s| s.to_string())
649            .collect();
650
651        let high_overlap = compute_term_overlap("Rust programming with memory management", &terms);
652        let low_overlap = compute_term_overlap("Python web development", &terms);
653
654        assert!(high_overlap > low_overlap);
655        assert!(high_overlap > 0.5); // At least 2 of 3 terms match
656    }
657
658    #[test]
659    fn test_multi_signal_rerank() {
660        let config = RerankConfig {
661            strategy: RerankStrategy::MultiSignal,
662            ..Default::default()
663        };
664        let reranker = Reranker::with_config(config);
665
666        let results = vec![
667            create_test_result(create_test_memory("First memory", 0.5), 0.9),
668            create_test_result(create_test_memory("Second memory", 0.5), 0.8),
669            create_test_result(
670                create_test_memory("Third memory with exact query", 0.5),
671                0.7,
672            ),
673        ];
674
675        let reranked = reranker.rerank(results, "exact query", None);
676
677        // Results should be reranked
678        assert_eq!(reranked.len(), 3);
679        // All should have rerank info
680        for r in &reranked {
681            assert!(r.rerank_info.final_score > 0.0);
682        }
683    }
684}