oxirs_vec/reranking/
diversity.rs

1//! Diversity-aware re-ranking strategies
2//!
3//! Diversity re-ranking ensures that search results cover different aspects
4//! of the query rather than redundant similar documents.
5//!
6//! ## Strategies
7//! - **MMR (Maximal Marginal Relevance)**: Balances relevance and diversity
8//! - **Cluster-based**: Groups similar documents and selects from each cluster
9//! - **Topic-based**: Ensures topical diversity across results
10
11use crate::reranking::types::{RerankingResult, ScoredCandidate};
12use serde::{Deserialize, Serialize};
13use std::collections::HashSet;
14
15/// Diversity strategy for re-ranking
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17pub enum DiversityStrategy {
18    /// Maximal Marginal Relevance (MMR)
19    MaximalMarginalRelevance,
20    /// Cluster-based diversity
21    ClusterBased,
22    /// Topic-based diversity
23    TopicBased,
24    /// No diversity (baseline)
25    None,
26}
27
28/// Diversity re-ranker
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct DiversityReranker {
31    /// Diversity weight (0.0 = relevance only, 1.0 = diversity only)
32    weight: f32,
33    /// Strategy to use
34    strategy: DiversityStrategy,
35    /// Similarity threshold for considering documents as similar
36    similarity_threshold: f32,
37}
38
39impl DiversityReranker {
40    /// Create new diversity re-ranker with default strategy
41    pub fn new(weight: f32) -> Self {
42        Self {
43            weight: weight.clamp(0.0, 1.0),
44            strategy: DiversityStrategy::MaximalMarginalRelevance,
45            similarity_threshold: 0.85,
46        }
47    }
48
49    /// Create with specific strategy
50    pub fn with_strategy(weight: f32, strategy: DiversityStrategy) -> Self {
51        Self {
52            weight: weight.clamp(0.0, 1.0),
53            strategy,
54            similarity_threshold: 0.85,
55        }
56    }
57
58    /// Set similarity threshold
59    pub fn set_similarity_threshold(mut self, threshold: f32) -> Self {
60        self.similarity_threshold = threshold.clamp(0.0, 1.0);
61        self
62    }
63
64    /// Apply diversity re-ranking to candidates
65    pub fn apply_diversity(
66        &self,
67        candidates: &[ScoredCandidate],
68    ) -> RerankingResult<Vec<ScoredCandidate>> {
69        if candidates.is_empty() || self.weight == 0.0 {
70            return Ok(candidates.to_vec());
71        }
72
73        match self.strategy {
74            DiversityStrategy::MaximalMarginalRelevance => self.mmr_rerank(candidates),
75            DiversityStrategy::ClusterBased => self.cluster_based_rerank(candidates),
76            DiversityStrategy::TopicBased => self.topic_based_rerank(candidates),
77            DiversityStrategy::None => Ok(candidates.to_vec()),
78        }
79    }
80
81    /// MMR (Maximal Marginal Relevance) re-ranking
82    ///
83    /// Selects documents that maximize:
84    /// MMR = λ * Relevance(d) - (1-λ) * max Similarity(d, already_selected)
85    fn mmr_rerank(&self, candidates: &[ScoredCandidate]) -> RerankingResult<Vec<ScoredCandidate>> {
86        let lambda = 1.0 - self.weight; // Convert weight to λ parameter
87        let mut selected = Vec::new();
88        let mut remaining: Vec<_> = candidates.to_vec();
89
90        // Select first candidate (highest relevance)
91        if let Some(first) = remaining.first().cloned() {
92            selected.push(first);
93            remaining.remove(0);
94        }
95
96        // Iteratively select documents maximizing MMR
97        while !remaining.is_empty() && selected.len() < candidates.len() {
98            let mut best_idx = 0;
99            let mut best_mmr = f32::NEG_INFINITY;
100
101            for (idx, candidate) in remaining.iter().enumerate() {
102                // Relevance component
103                let relevance = candidate.effective_score();
104
105                // Diversity component: max similarity to already selected
106                let max_similarity = selected
107                    .iter()
108                    .map(|sel| self.compute_similarity(candidate, sel))
109                    .fold(0.0f32, f32::max);
110
111                // MMR score
112                let mmr = lambda * relevance - (1.0 - lambda) * max_similarity;
113
114                if mmr > best_mmr {
115                    best_mmr = mmr;
116                    best_idx = idx;
117                }
118            }
119
120            // Select best MMR candidate
121            if best_idx < remaining.len() {
122                selected.push(remaining.remove(best_idx));
123            } else {
124                break;
125            }
126        }
127
128        Ok(selected)
129    }
130
131    /// Cluster-based diversity re-ranking
132    ///
133    /// Groups similar documents into clusters and selects
134    /// representatives from each cluster to ensure diversity.
135    fn cluster_based_rerank(
136        &self,
137        candidates: &[ScoredCandidate],
138    ) -> RerankingResult<Vec<ScoredCandidate>> {
139        if candidates.len() <= 2 {
140            return Ok(candidates.to_vec());
141        }
142
143        // Simple greedy clustering
144        let mut clusters: Vec<Vec<ScoredCandidate>> = Vec::new();
145        let mut assigned = HashSet::new();
146
147        for (idx, candidate) in candidates.iter().enumerate() {
148            if assigned.contains(&idx) {
149                continue;
150            }
151
152            // Start new cluster
153            let mut cluster = vec![candidate.clone()];
154            assigned.insert(idx);
155
156            // Find similar candidates
157            for (other_idx, other) in candidates.iter().enumerate() {
158                if assigned.contains(&other_idx) {
159                    continue;
160                }
161
162                let similarity = self.compute_similarity(candidate, other);
163                if similarity > self.similarity_threshold {
164                    cluster.push(other.clone());
165                    assigned.insert(other_idx);
166                }
167            }
168
169            clusters.push(cluster);
170        }
171
172        // Select best candidate from each cluster
173        let mut result = Vec::new();
174        let num_per_cluster = (candidates.len() / clusters.len().max(1)).max(1);
175
176        for cluster in clusters {
177            // Sort cluster by score
178            let mut sorted_cluster = cluster;
179            sorted_cluster.sort_by(|a, b| {
180                b.effective_score()
181                    .partial_cmp(&a.effective_score())
182                    .unwrap_or(std::cmp::Ordering::Equal)
183            });
184
185            // Take top candidates from this cluster
186            result.extend(sorted_cluster.into_iter().take(num_per_cluster));
187        }
188
189        // Sort final result by score
190        result.sort_by(|a, b| {
191            b.effective_score()
192                .partial_cmp(&a.effective_score())
193                .unwrap_or(std::cmp::Ordering::Equal)
194        });
195
196        Ok(result)
197    }
198
199    /// Topic-based diversity re-ranking
200    ///
201    /// Ensures results cover different topics by analyzing
202    /// keyword distributions and selecting diverse documents.
203    fn topic_based_rerank(
204        &self,
205        candidates: &[ScoredCandidate],
206    ) -> RerankingResult<Vec<ScoredCandidate>> {
207        // Extract topics (keywords) from each candidate
208        let mut doc_topics: Vec<HashSet<String>> = Vec::new();
209
210        for candidate in candidates {
211            let content = candidate.content.as_deref().unwrap_or("");
212            let topics = self.extract_topics(content);
213            doc_topics.push(topics);
214        }
215
216        // Select documents to maximize topic coverage
217        let mut selected = Vec::new();
218        let mut covered_topics = HashSet::new();
219        let mut remaining_indices: Vec<usize> = (0..candidates.len()).collect();
220
221        while !remaining_indices.is_empty() && selected.len() < candidates.len() {
222            let mut best_idx = 0;
223            let mut best_score = f32::NEG_INFINITY;
224
225            for (list_idx, &doc_idx) in remaining_indices.iter().enumerate() {
226                let candidate = &candidates[doc_idx];
227                let topics = &doc_topics[doc_idx];
228
229                // Relevance component
230                let relevance = candidate.effective_score();
231
232                // Diversity component: number of new topics
233                let new_topics = topics.difference(&covered_topics).count() as f32;
234                let total_topics = topics.len().max(1) as f32;
235                let topic_novelty = new_topics / total_topics;
236
237                // Combined score
238                let score = (1.0 - self.weight) * relevance + self.weight * topic_novelty;
239
240                if score > best_score {
241                    best_score = score;
242                    best_idx = list_idx;
243                }
244            }
245
246            // Select best candidate
247            if best_idx < remaining_indices.len() {
248                let doc_idx = remaining_indices.remove(best_idx);
249                selected.push(candidates[doc_idx].clone());
250
251                // Update covered topics
252                for topic in &doc_topics[doc_idx] {
253                    covered_topics.insert(topic.clone());
254                }
255            } else {
256                break;
257            }
258        }
259
260        Ok(selected)
261    }
262
263    /// Compute similarity between two candidates
264    fn compute_similarity(&self, a: &ScoredCandidate, b: &ScoredCandidate) -> f32 {
265        // Extract keywords from both documents
266        let a_content = a.content.as_deref().unwrap_or("");
267        let b_content = b.content.as_deref().unwrap_or("");
268
269        let a_words: HashSet<String> = a_content
270            .to_lowercase()
271            .split_whitespace()
272            .filter(|w| w.len() > 3) // Filter short words
273            .map(|w| w.to_string())
274            .collect();
275
276        let b_words: HashSet<String> = b_content
277            .to_lowercase()
278            .split_whitespace()
279            .filter(|w| w.len() > 3)
280            .map(|w| w.to_string())
281            .collect();
282
283        if a_words.is_empty() || b_words.is_empty() {
284            return 0.0;
285        }
286
287        // Jaccard similarity
288        let intersection = a_words.intersection(&b_words).count() as f32;
289        let union = a_words.union(&b_words).count() as f32;
290
291        if union == 0.0 {
292            0.0
293        } else {
294            intersection / union
295        }
296    }
297
298    /// Extract topics (keywords) from document
299    fn extract_topics(&self, document: &str) -> HashSet<String> {
300        // Simple keyword extraction
301        // In production: use TF-IDF, NER, or topic modeling
302        document
303            .to_lowercase()
304            .split_whitespace()
305            .filter(|w| w.len() > 4) // Only meaningful words
306            .map(|w| w.to_string())
307            .collect()
308    }
309}
310
311impl Default for DiversityReranker {
312    fn default() -> Self {
313        Self::new(0.3)
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture set: two overlapping machine-learning documents followed by two
    /// unrelated ones (databases, web development), in descending score order.
    fn create_test_candidates() -> Vec<ScoredCandidate> {
        let rows = [
            ("doc1", 0.9, 0, "machine learning deep neural networks", 0.85),
            ("doc2", 0.85, 1, "machine learning algorithms classification", 0.8),
            ("doc3", 0.7, 2, "database management systems SQL queries", 0.75),
            ("doc4", 0.65, 3, "web development JavaScript frameworks", 0.7),
        ];
        rows.iter()
            .map(|&(id, score, rank, content, rerank)| {
                ScoredCandidate::new(id, score, rank)
                    .with_content(content)
                    .with_reranking_score(rerank)
            })
            .collect()
    }

    /// Ids of `candidates`, in order.
    fn ids(candidates: &[ScoredCandidate]) -> Vec<&str> {
        candidates.iter().map(|c| c.id.as_str()).collect()
    }

    #[test]
    fn test_mmr_rerank() {
        let candidates = create_test_candidates();
        let ranked = DiversityReranker::new(0.5).mmr_rerank(&candidates).unwrap();

        // Nothing is dropped, and the most relevant document leads.
        assert_eq!(ranked.len(), candidates.len());
        assert_eq!(ranked[0].id, "doc1");

        // The two machine-learning documents must not monopolise the top three.
        let ml_only = ranked
            .iter()
            .take(3)
            .all(|c| c.id.starts_with("doc1") || c.id.starts_with("doc2"));
        assert!(!ml_only, "MMR should diversify results");
    }

    #[test]
    fn test_cluster_based_rerank() {
        let candidates = create_test_candidates();
        let reranker = DiversityReranker::with_strategy(0.5, DiversityStrategy::ClusterBased);

        let clustered = reranker.cluster_based_rerank(&candidates).unwrap();

        assert!(!clustered.is_empty());
        assert!(clustered.len() <= candidates.len());
    }

    #[test]
    fn test_topic_based_rerank() {
        let candidates = create_test_candidates();
        let reranker = DiversityReranker::with_strategy(0.6, DiversityStrategy::TopicBased);

        let ranked = reranker.topic_based_rerank(&candidates).unwrap();
        assert_eq!(ranked.len(), candidates.len());

        // The top two results should come from different topics.
        let top_pair_similarity = reranker.compute_similarity(&ranked[0], &ranked[1]);
        assert!(
            top_pair_similarity < 0.8,
            "Topic-based reranking should increase diversity"
        );
    }

    #[test]
    fn test_no_diversity() {
        let candidates = create_test_candidates();

        let unchanged = DiversityReranker::new(0.0)
            .apply_diversity(&candidates)
            .unwrap();

        // Zero weight means the input order is returned untouched.
        assert_eq!(unchanged.len(), candidates.len());
        assert_eq!(ids(&unchanged), ids(&candidates));
    }

    #[test]
    fn test_similarity_computation() {
        let reranker = DiversityReranker::new(0.3);
        let ml_neural =
            ScoredCandidate::new("a", 0.8, 0).with_content("machine learning neural networks");
        let ml_algos =
            ScoredCandidate::new("b", 0.7, 1).with_content("machine learning algorithms");
        let databases = ScoredCandidate::new("c", 0.6, 2).with_content("database systems SQL");

        // Two machine-learning texts overlap more than ML vs. databases.
        let related = reranker.compute_similarity(&ml_neural, &ml_algos);
        let unrelated = reranker.compute_similarity(&ml_neural, &databases);
        assert!(related > unrelated);
    }

    #[test]
    fn test_topic_extraction() {
        let topics = DiversityReranker::new(0.3)
            .extract_topics("machine learning and deep neural networks for classification");

        for kept in ["machine", "learning", "neural", "networks", "classification"] {
            assert!(topics.contains(kept), "missing topic {}", kept);
        }
        // Short words are filtered out.
        for dropped in ["and", "for"] {
            assert!(!topics.contains(dropped), "unexpected topic {}", dropped);
        }
    }

    #[test]
    fn test_empty_candidates() {
        let nothing: Vec<ScoredCandidate> = Vec::new();

        let result = DiversityReranker::new(0.5).apply_diversity(&nothing).unwrap();

        assert!(result.is_empty());
    }

    #[test]
    fn test_single_candidate() {
        let lone = [ScoredCandidate::new("doc1", 0.8, 0)
            .with_content("test")
            .with_reranking_score(0.85)];

        let result = DiversityReranker::new(0.5).apply_diversity(&lone).unwrap();

        assert_eq!(ids(&result), ["doc1"]);
    }
}
460}