ipfrs_semantic/
reranking.rs

1//! Query result re-ranking
2//!
3//! This module provides re-ranking capabilities for search results based on
4//! multiple criteria including semantic similarity, metadata scores, recency,
5//! and custom scoring functions.
6
7use crate::hnsw::SearchResult;
8use crate::metadata::{Metadata, MetadataValue};
9use ipfrs_core::Cid;
10use std::collections::HashMap;
11
12/// Re-ranking strategy for search results
13#[derive(Debug, Clone)]
14pub enum ReRankingStrategy {
15    /// Weighted combination of multiple scores
16    WeightedCombination(Vec<(ScoreComponent, f32)>),
17    /// Reciprocal Rank Fusion (RRF)
18    ReciprocalRankFusion { k: f32 },
19    /// Learning to Rank (placeholder for custom models)
20    LearnToRank { model_name: String },
21    /// Custom scoring function
22    Custom,
23}
24
/// A single signal that can contribute to a result's final score.
///
/// Components are paired with a weight inside
/// `ReRankingStrategy::WeightedCombination`. The type implements
/// `PartialEq` so configurations can be compared and asserted on.
#[derive(Debug, Clone, PartialEq)]
pub enum ScoreComponent {
    /// Original vector similarity score carried by the search result.
    VectorSimilarity,
    /// Score read from the named metadata field (requires metadata lookup).
    MetadataScore { field: String },
    /// Recency score derived from timestamp metadata; `decay_factor`
    /// controls how quickly the score decays with age.
    Recency { decay_factor: f32 },
    /// Popularity score (requires popularity metadata).
    Popularity,
    /// Diversity score (penalize similar results).
    Diversity { threshold: f32 },
    /// Custom score (requires an external scoring function).
    Custom { name: String },
}
41
42/// Re-ranking configuration
43#[derive(Debug, Clone)]
44pub struct ReRankingConfig {
45    /// Strategy to use for re-ranking
46    pub strategy: ReRankingStrategy,
47    /// Whether to normalize scores before combining
48    pub normalize_scores: bool,
49    /// Maximum number of results to re-rank (re-rank top-k only)
50    pub top_k: Option<usize>,
51}
52
53impl Default for ReRankingConfig {
54    fn default() -> Self {
55        Self {
56            strategy: ReRankingStrategy::WeightedCombination(vec![(
57                ScoreComponent::VectorSimilarity,
58                1.0,
59            )]),
60            normalize_scores: true,
61            top_k: Some(100), // Re-rank top 100 only for efficiency
62        }
63    }
64}
65
66/// Result with multiple score components
67#[derive(Debug, Clone)]
68pub struct ScoredResult {
69    /// The search result
70    pub result: SearchResult,
71    /// Individual score components
72    pub score_components: HashMap<String, f32>,
73    /// Final combined score
74    pub final_score: f32,
75}
76
77/// Re-ranker for search results
78pub struct ReRanker {
79    config: ReRankingConfig,
80    metadata_cache: HashMap<Cid, Metadata>,
81}
82
83impl ReRanker {
84    /// Create a new re-ranker with the given configuration
85    pub fn new(config: ReRankingConfig) -> Self {
86        Self {
87            config,
88            metadata_cache: HashMap::new(),
89        }
90    }
91
92    /// Create a re-ranker with default configuration
93    pub fn with_defaults() -> Self {
94        Self::new(ReRankingConfig::default())
95    }
96
97    /// Add metadata for a CID (for metadata-based scoring)
98    pub fn add_metadata(&mut self, cid: Cid, metadata: Metadata) {
99        self.metadata_cache.insert(cid, metadata);
100    }
101
102    /// Re-rank search results
103    pub fn rerank(&self, results: Vec<SearchResult>) -> Vec<ScoredResult> {
104        let limit = self
105            .config
106            .top_k
107            .unwrap_or(results.len())
108            .min(results.len());
109        let mut to_rerank: Vec<SearchResult> = results.into_iter().take(limit).collect();
110
111        match &self.config.strategy {
112            ReRankingStrategy::WeightedCombination(weights) => {
113                self.rerank_weighted(&mut to_rerank, weights)
114            }
115            ReRankingStrategy::ReciprocalRankFusion { k } => self.rerank_rrf(&mut to_rerank, *k),
116            ReRankingStrategy::LearnToRank { model_name: _ } => {
117                // Placeholder - would integrate with external model
118                self.rerank_placeholder(&mut to_rerank)
119            }
120            ReRankingStrategy::Custom => self.rerank_placeholder(&mut to_rerank),
121        }
122    }
123
124    /// Re-rank using weighted combination of scores
125    fn rerank_weighted(
126        &self,
127        results: &mut [SearchResult],
128        weights: &[(ScoreComponent, f32)],
129    ) -> Vec<ScoredResult> {
130        let mut scored_results: Vec<ScoredResult> = results
131            .iter()
132            .map(|r| {
133                let mut score_components = HashMap::new();
134                let mut final_score = 0.0;
135
136                for (component, weight) in weights {
137                    let component_score = self.compute_component_score(r, component);
138                    let component_name = self.component_name(component);
139                    score_components.insert(component_name, component_score);
140                    final_score += component_score * weight;
141                }
142
143                ScoredResult {
144                    result: r.clone(),
145                    score_components,
146                    final_score,
147                }
148            })
149            .collect();
150
151        // Normalize if requested
152        if self.config.normalize_scores {
153            self.normalize_scores(&mut scored_results);
154        }
155
156        // Sort by final score (descending)
157        scored_results.sort_by(|a, b| {
158            b.final_score
159                .partial_cmp(&a.final_score)
160                .unwrap_or(std::cmp::Ordering::Equal)
161        });
162
163        scored_results
164    }
165
166    /// Re-rank using Reciprocal Rank Fusion
167    fn rerank_rrf(&self, results: &mut [SearchResult], k: f32) -> Vec<ScoredResult> {
168        let scored_results: Vec<ScoredResult> = results
169            .iter()
170            .enumerate()
171            .map(|(rank, r)| {
172                let rrf_score = 1.0 / (k + rank as f32 + 1.0);
173                let mut score_components = HashMap::new();
174                score_components.insert("vector_similarity".to_string(), r.score);
175                score_components.insert("rrf_score".to_string(), rrf_score);
176
177                ScoredResult {
178                    result: r.clone(),
179                    score_components,
180                    final_score: rrf_score,
181                }
182            })
183            .collect();
184
185        scored_results
186    }
187
188    /// Placeholder re-ranking (just return as-is)
189    fn rerank_placeholder(&self, results: &mut [SearchResult]) -> Vec<ScoredResult> {
190        results
191            .iter()
192            .map(|r| {
193                let mut score_components = HashMap::new();
194                score_components.insert("vector_similarity".to_string(), r.score);
195
196                ScoredResult {
197                    result: r.clone(),
198                    score_components,
199                    final_score: r.score,
200                }
201            })
202            .collect()
203    }
204
205    /// Compute score for a single component
206    fn compute_component_score(&self, result: &SearchResult, component: &ScoreComponent) -> f32 {
207        match component {
208            ScoreComponent::VectorSimilarity => result.score,
209            ScoreComponent::MetadataScore { field } => {
210                // Get metadata score from cached metadata
211                if let Some(metadata) = self.metadata_cache.get(&result.cid) {
212                    if let Some(value) = metadata.get(field) {
213                        return self.metadata_value_to_score(value);
214                    }
215                }
216                0.0
217            }
218            ScoreComponent::Recency { decay_factor } => {
219                // Compute recency score from timestamp
220                if let Some(metadata) = self.metadata_cache.get(&result.cid) {
221                    if let Some(MetadataValue::Integer(timestamp)) = metadata.get("timestamp") {
222                        // Simple exponential decay
223                        let age = Self::current_timestamp() - timestamp;
224                        return (-(age as f32) * decay_factor).exp();
225                    }
226                }
227                0.0
228            }
229            ScoreComponent::Popularity => {
230                // Get popularity from metadata
231                if let Some(metadata) = self.metadata_cache.get(&result.cid) {
232                    if let Some(value) = metadata.get("popularity") {
233                        return self.metadata_value_to_score(value);
234                    }
235                }
236                0.0
237            }
238            ScoreComponent::Diversity { threshold: _ } => {
239                // Diversity scoring requires comparing with other results
240                // Placeholder for now
241                0.0
242            }
243            ScoreComponent::Custom { name: _ } => {
244                // Custom scoring would call external function
245                0.0
246            }
247        }
248    }
249
250    /// Convert metadata value to a score
251    fn metadata_value_to_score(&self, value: &MetadataValue) -> f32 {
252        match value {
253            MetadataValue::Integer(i) => *i as f32,
254            MetadataValue::Float(f) => *f as f32,
255            MetadataValue::Boolean(b) => {
256                if *b {
257                    1.0
258                } else {
259                    0.0
260                }
261            }
262            MetadataValue::Timestamp(t) => *t as f32,
263            MetadataValue::String(_) | MetadataValue::StringArray(_) | MetadataValue::Null => 0.0,
264        }
265    }
266
267    /// Get component name for display
268    fn component_name(&self, component: &ScoreComponent) -> String {
269        match component {
270            ScoreComponent::VectorSimilarity => "vector_similarity".to_string(),
271            ScoreComponent::MetadataScore { field } => format!("metadata_{}", field),
272            ScoreComponent::Recency { .. } => "recency".to_string(),
273            ScoreComponent::Popularity => "popularity".to_string(),
274            ScoreComponent::Diversity { .. } => "diversity".to_string(),
275            ScoreComponent::Custom { name } => format!("custom_{}", name),
276        }
277    }
278
279    /// Normalize scores to [0, 1] range
280    fn normalize_scores(&self, results: &mut [ScoredResult]) {
281        if results.is_empty() {
282            return;
283        }
284
285        // Find min and max scores
286        let min_score = results
287            .iter()
288            .map(|r| r.final_score)
289            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
290            .unwrap_or(0.0);
291
292        let max_score = results
293            .iter()
294            .map(|r| r.final_score)
295            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
296            .unwrap_or(1.0);
297
298        let range = max_score - min_score;
299
300        if range > 0.0 {
301            for result in results.iter_mut() {
302                result.final_score = (result.final_score - min_score) / range;
303            }
304        }
305    }
306
307    /// Get current timestamp (seconds since epoch)
308    fn current_timestamp() -> i64 {
309        use std::time::{SystemTime, UNIX_EPOCH};
310        SystemTime::now()
311            .duration_since(UNIX_EPOCH)
312            .unwrap()
313            .as_secs() as i64
314    }
315
316    /// Create a weighted combination strategy
317    pub fn weighted(components: Vec<(ScoreComponent, f32)>) -> ReRankingConfig {
318        ReRankingConfig {
319            strategy: ReRankingStrategy::WeightedCombination(components),
320            normalize_scores: true,
321            top_k: Some(100),
322        }
323    }
324
325    /// Create a reciprocal rank fusion strategy
326    pub fn reciprocal_rank_fusion(k: f32) -> ReRankingConfig {
327        ReRankingConfig {
328            strategy: ReRankingStrategy::ReciprocalRankFusion { k },
329            normalize_scores: false,
330            top_k: Some(100),
331        }
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    // CID strings shared by the tests below.
    const CID_A: &str = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi";
    const CID_B: &str = "bafybeihpjhkeuiq3k6nqa3fkgeigeri7iebtrsuyuey5y6vy36n345xmbi";
    const CID_C: &str = "bafybeif2pall7dybz7vecqka3zo24irdwabwdi4wc55jznaq75q7eaavvu";

    /// Parse a CID literal, panicking on malformed input.
    fn parse_cid(text: &str) -> Cid {
        text.parse::<Cid>().unwrap()
    }

    /// Build a search result for `cid` with the given similarity score.
    fn hit(cid: Cid, score: f32) -> SearchResult {
        SearchResult { cid, score }
    }

    /// Build metadata holding a single key/value pair.
    fn metadata_with(key: &'static str, value: MetadataValue) -> Metadata {
        let mut metadata = Metadata::new();
        metadata.set(key, value);
        metadata
    }

    /// Configuration that ranks by vector similarity alone.
    fn similarity_only(normalize_scores: bool, top_k: Option<usize>) -> ReRankingConfig {
        ReRankingConfig {
            strategy: ReRankingStrategy::WeightedCombination(vec![(
                ScoreComponent::VectorSimilarity,
                1.0,
            )]),
            normalize_scores,
            top_k,
        }
    }

    #[test]
    fn test_reranker_creation() {
        let reranker = ReRanker::with_defaults();
        assert!(matches!(
            reranker.config.strategy,
            ReRankingStrategy::WeightedCombination(_)
        ));
    }

    #[test]
    fn test_weighted_reranking() {
        let config = ReRanker::weighted(vec![
            (ScoreComponent::VectorSimilarity, 0.7),
            (ScoreComponent::Popularity, 0.3),
        ]);
        let mut reranker = ReRanker::new(config);

        let cid1 = parse_cid(CID_A);
        let cid2 = parse_cid(CID_B);

        // Attach popularity metadata to both results.
        reranker.add_metadata(cid1, metadata_with("popularity", MetadataValue::Float(0.5)));
        reranker.add_metadata(cid2, metadata_with("popularity", MetadataValue::Float(0.9)));

        let reranked = reranker.rerank(vec![hit(cid1, 0.9), hit(cid2, 0.7)]);
        assert_eq!(reranked.len(), 2);

        // First result should still be cid1 (0.9*0.7 + 0.5*0.3 = 0.78)
        // vs cid2 (0.7*0.7 + 0.9*0.3 = 0.76)
        assert_eq!(reranked[0].result.cid, cid1);
    }

    #[test]
    fn test_rrf_reranking() {
        let reranker = ReRanker::new(ReRanker::reciprocal_rank_fusion(60.0));

        let cid1 = parse_cid(CID_A);
        let cid2 = parse_cid(CID_B);

        let reranked = reranker.rerank(vec![hit(cid1, 0.9), hit(cid2, 0.7)]);
        assert_eq!(reranked.len(), 2);

        // The earlier rank must receive the larger RRF score.
        assert!(reranked[0].final_score > reranked[1].final_score);
    }

    #[test]
    fn test_recency_scoring() {
        let config = ReRanker::weighted(vec![
            (ScoreComponent::VectorSimilarity, 0.5),
            (
                ScoreComponent::Recency {
                    decay_factor: 0.0001,
                },
                0.5,
            ),
        ]);
        let mut reranker = ReRanker::new(config);

        let cid1 = parse_cid(CID_A);

        // Timestamp 100 seconds in the past.
        let current_time = ReRanker::current_timestamp();
        reranker.add_metadata(
            cid1,
            metadata_with("timestamp", MetadataValue::Integer(current_time - 100)),
        );

        let reranked = reranker.rerank(vec![hit(cid1, 0.8)]);
        assert_eq!(reranked.len(), 1);
        assert!(reranked[0].score_components.contains_key("recency"));
    }

    #[test]
    fn test_normalize_scores() {
        let reranker = ReRanker::new(similarity_only(true, None));

        let cid1 = parse_cid(CID_A);
        let cid2 = parse_cid(CID_B);

        let reranked = reranker.rerank(vec![hit(cid1, 0.9), hit(cid2, 0.5)]);

        // Normalized scores should be in [0, 1]
        assert!(reranked[0].final_score >= 0.0 && reranked[0].final_score <= 1.0);
        assert!(reranked[1].final_score >= 0.0 && reranked[1].final_score <= 1.0);
    }

    #[test]
    fn test_top_k_reranking() {
        // Only rerank top 2
        let reranker = ReRanker::new(similarity_only(false, Some(2)));

        let cid1 = parse_cid(CID_A);
        let cid2 = parse_cid(CID_B);
        let cid3 = parse_cid(CID_C);

        let reranked = reranker.rerank(vec![hit(cid1, 0.9), hit(cid2, 0.7), hit(cid3, 0.5)]);

        // Should only return top 2
        assert_eq!(reranked.len(), 2);
    }
}