1use crate::reranking::types::{RerankingResult, ScoredCandidate};
12use serde::{Deserialize, Serialize};
13use std::collections::HashSet;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17pub enum DiversityStrategy {
18 MaximalMarginalRelevance,
20 ClusterBased,
22 TopicBased,
24 None,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct DiversityReranker {
31 weight: f32,
33 strategy: DiversityStrategy,
35 similarity_threshold: f32,
37}
38
39impl DiversityReranker {
40 pub fn new(weight: f32) -> Self {
42 Self {
43 weight: weight.clamp(0.0, 1.0),
44 strategy: DiversityStrategy::MaximalMarginalRelevance,
45 similarity_threshold: 0.85,
46 }
47 }
48
49 pub fn with_strategy(weight: f32, strategy: DiversityStrategy) -> Self {
51 Self {
52 weight: weight.clamp(0.0, 1.0),
53 strategy,
54 similarity_threshold: 0.85,
55 }
56 }
57
58 pub fn set_similarity_threshold(mut self, threshold: f32) -> Self {
60 self.similarity_threshold = threshold.clamp(0.0, 1.0);
61 self
62 }
63
64 pub fn apply_diversity(
66 &self,
67 candidates: &[ScoredCandidate],
68 ) -> RerankingResult<Vec<ScoredCandidate>> {
69 if candidates.is_empty() || self.weight == 0.0 {
70 return Ok(candidates.to_vec());
71 }
72
73 match self.strategy {
74 DiversityStrategy::MaximalMarginalRelevance => self.mmr_rerank(candidates),
75 DiversityStrategy::ClusterBased => self.cluster_based_rerank(candidates),
76 DiversityStrategy::TopicBased => self.topic_based_rerank(candidates),
77 DiversityStrategy::None => Ok(candidates.to_vec()),
78 }
79 }
80
81 fn mmr_rerank(&self, candidates: &[ScoredCandidate]) -> RerankingResult<Vec<ScoredCandidate>> {
86 let lambda = 1.0 - self.weight; let mut selected = Vec::new();
88 let mut remaining: Vec<_> = candidates.to_vec();
89
90 if let Some(first) = remaining.first().cloned() {
92 selected.push(first);
93 remaining.remove(0);
94 }
95
96 while !remaining.is_empty() && selected.len() < candidates.len() {
98 let mut best_idx = 0;
99 let mut best_mmr = f32::NEG_INFINITY;
100
101 for (idx, candidate) in remaining.iter().enumerate() {
102 let relevance = candidate.effective_score();
104
105 let max_similarity = selected
107 .iter()
108 .map(|sel| self.compute_similarity(candidate, sel))
109 .fold(0.0f32, f32::max);
110
111 let mmr = lambda * relevance - (1.0 - lambda) * max_similarity;
113
114 if mmr > best_mmr {
115 best_mmr = mmr;
116 best_idx = idx;
117 }
118 }
119
120 if best_idx < remaining.len() {
122 selected.push(remaining.remove(best_idx));
123 } else {
124 break;
125 }
126 }
127
128 Ok(selected)
129 }
130
131 fn cluster_based_rerank(
136 &self,
137 candidates: &[ScoredCandidate],
138 ) -> RerankingResult<Vec<ScoredCandidate>> {
139 if candidates.len() <= 2 {
140 return Ok(candidates.to_vec());
141 }
142
143 let mut clusters: Vec<Vec<ScoredCandidate>> = Vec::new();
145 let mut assigned = HashSet::new();
146
147 for (idx, candidate) in candidates.iter().enumerate() {
148 if assigned.contains(&idx) {
149 continue;
150 }
151
152 let mut cluster = vec![candidate.clone()];
154 assigned.insert(idx);
155
156 for (other_idx, other) in candidates.iter().enumerate() {
158 if assigned.contains(&other_idx) {
159 continue;
160 }
161
162 let similarity = self.compute_similarity(candidate, other);
163 if similarity > self.similarity_threshold {
164 cluster.push(other.clone());
165 assigned.insert(other_idx);
166 }
167 }
168
169 clusters.push(cluster);
170 }
171
172 let mut result = Vec::new();
174 let num_per_cluster = (candidates.len() / clusters.len().max(1)).max(1);
175
176 for cluster in clusters {
177 let mut sorted_cluster = cluster;
179 sorted_cluster.sort_by(|a, b| {
180 b.effective_score()
181 .partial_cmp(&a.effective_score())
182 .unwrap_or(std::cmp::Ordering::Equal)
183 });
184
185 result.extend(sorted_cluster.into_iter().take(num_per_cluster));
187 }
188
189 result.sort_by(|a, b| {
191 b.effective_score()
192 .partial_cmp(&a.effective_score())
193 .unwrap_or(std::cmp::Ordering::Equal)
194 });
195
196 Ok(result)
197 }
198
199 fn topic_based_rerank(
204 &self,
205 candidates: &[ScoredCandidate],
206 ) -> RerankingResult<Vec<ScoredCandidate>> {
207 let mut doc_topics: Vec<HashSet<String>> = Vec::new();
209
210 for candidate in candidates {
211 let content = candidate.content.as_deref().unwrap_or("");
212 let topics = self.extract_topics(content);
213 doc_topics.push(topics);
214 }
215
216 let mut selected = Vec::new();
218 let mut covered_topics = HashSet::new();
219 let mut remaining_indices: Vec<usize> = (0..candidates.len()).collect();
220
221 while !remaining_indices.is_empty() && selected.len() < candidates.len() {
222 let mut best_idx = 0;
223 let mut best_score = f32::NEG_INFINITY;
224
225 for (list_idx, &doc_idx) in remaining_indices.iter().enumerate() {
226 let candidate = &candidates[doc_idx];
227 let topics = &doc_topics[doc_idx];
228
229 let relevance = candidate.effective_score();
231
232 let new_topics = topics.difference(&covered_topics).count() as f32;
234 let total_topics = topics.len().max(1) as f32;
235 let topic_novelty = new_topics / total_topics;
236
237 let score = (1.0 - self.weight) * relevance + self.weight * topic_novelty;
239
240 if score > best_score {
241 best_score = score;
242 best_idx = list_idx;
243 }
244 }
245
246 if best_idx < remaining_indices.len() {
248 let doc_idx = remaining_indices.remove(best_idx);
249 selected.push(candidates[doc_idx].clone());
250
251 for topic in &doc_topics[doc_idx] {
253 covered_topics.insert(topic.clone());
254 }
255 } else {
256 break;
257 }
258 }
259
260 Ok(selected)
261 }
262
263 fn compute_similarity(&self, a: &ScoredCandidate, b: &ScoredCandidate) -> f32 {
265 let a_content = a.content.as_deref().unwrap_or("");
267 let b_content = b.content.as_deref().unwrap_or("");
268
269 let a_words: HashSet<String> = a_content
270 .to_lowercase()
271 .split_whitespace()
272 .filter(|w| w.len() > 3) .map(|w| w.to_string())
274 .collect();
275
276 let b_words: HashSet<String> = b_content
277 .to_lowercase()
278 .split_whitespace()
279 .filter(|w| w.len() > 3)
280 .map(|w| w.to_string())
281 .collect();
282
283 if a_words.is_empty() || b_words.is_empty() {
284 return 0.0;
285 }
286
287 let intersection = a_words.intersection(&b_words).count() as f32;
289 let union = a_words.union(&b_words).count() as f32;
290
291 if union == 0.0 {
292 0.0
293 } else {
294 intersection / union
295 }
296 }
297
298 fn extract_topics(&self, document: &str) -> HashSet<String> {
300 document
303 .to_lowercase()
304 .split_whitespace()
305 .filter(|w| w.len() > 4) .map(|w| w.to_string())
307 .collect()
308 }
309}
310
311impl Default for DiversityReranker {
312 fn default() -> Self {
313 Self::new(0.3)
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_candidates() -> Vec<ScoredCandidate> {
        vec![
            ScoredCandidate::new("doc1", 0.9, 0)
                .with_content("machine learning deep neural networks")
                .with_reranking_score(0.85),
            ScoredCandidate::new("doc2", 0.85, 1)
                .with_content("machine learning algorithms classification")
                .with_reranking_score(0.8),
            ScoredCandidate::new("doc3", 0.7, 2)
                .with_content("database management systems SQL queries")
                .with_reranking_score(0.75),
            ScoredCandidate::new("doc4", 0.65, 3)
                .with_content("web development JavaScript frameworks")
                .with_reranking_score(0.7),
        ]
    }

    /// Ids of `candidates`, sorted, for order-insensitive comparisons.
    fn sorted_ids(candidates: &[ScoredCandidate]) -> Vec<String> {
        let mut ids: Vec<String> = candidates.iter().map(|c| c.id.as_str().to_owned()).collect();
        ids.sort();
        ids
    }

    #[test]
    fn test_mmr_rerank() {
        let reranker = DiversityReranker::new(0.5);
        let candidates = create_test_candidates();

        let result = reranker.mmr_rerank(&candidates).unwrap();

        // MMR reorders but never drops or duplicates candidates.
        assert_eq!(result.len(), candidates.len());
        assert_eq!(sorted_ids(&result), sorted_ids(&candidates));

        assert_eq!(result[0].id, "doc1");

        // doc2 shares "machine learning" with doc1, so a diversity-aware
        // ranking should not place it directly after doc1.
        assert_ne!(result[1].id, "doc2", "MMR should diversify results");
    }

    #[test]
    fn test_cluster_based_rerank() {
        let reranker = DiversityReranker::with_strategy(0.5, DiversityStrategy::ClusterBased);
        let candidates = create_test_candidates();

        let result = reranker.cluster_based_rerank(&candidates).unwrap();

        // No pair of test documents exceeds the default 0.85 similarity
        // threshold, so every candidate is its own cluster and survives.
        assert_eq!(result.len(), candidates.len());
        assert!(result
            .windows(2)
            .all(|w| w[0].effective_score() >= w[1].effective_score()));
    }

    #[test]
    fn test_cluster_based_rerank_drops_near_duplicates() {
        let reranker = DiversityReranker::with_strategy(0.5, DiversityStrategy::ClusterBased);
        let candidates = vec![
            ScoredCandidate::new("dup_a", 0.9, 0)
                .with_content("machine learning neural networks")
                .with_reranking_score(0.9),
            ScoredCandidate::new("dup_b", 0.8, 1)
                .with_content("machine learning neural networks")
                .with_reranking_score(0.8),
            ScoredCandidate::new("db", 0.7, 2)
                .with_content("database management systems")
                .with_reranking_score(0.7),
            ScoredCandidate::new("web", 0.6, 3)
                .with_content("javascript frameworks browsers")
                .with_reranking_score(0.6),
        ];

        let result = reranker.cluster_based_rerank(&candidates).unwrap();
        let ids: Vec<&str> = result.iter().map(|c| c.id.as_str()).collect();

        // The identical texts form one cluster; with three clusters over four
        // candidates, each cluster keeps only its best-scoring member.
        assert_eq!(ids.len(), 3);
        assert!(ids.contains(&"dup_a"));
        assert!(!ids.contains(&"dup_b"));
    }

    #[test]
    fn test_topic_based_rerank() {
        let reranker = DiversityReranker::with_strategy(0.6, DiversityStrategy::TopicBased);
        let candidates = create_test_candidates();

        let result = reranker.topic_based_rerank(&candidates).unwrap();

        assert_eq!(result.len(), candidates.len());
        assert_eq!(sorted_ids(&result), sorted_ids(&candidates));

        let similarity = reranker.compute_similarity(&result[0], &result[1]);
        assert!(
            similarity < 0.8,
            "Topic-based reranking should increase diversity"
        );
    }

    #[test]
    fn test_no_diversity() {
        let reranker = DiversityReranker::new(0.0);
        let candidates = create_test_candidates();

        let result = reranker.apply_diversity(&candidates).unwrap();

        assert_eq!(result.len(), candidates.len());
        for (orig, res) in candidates.iter().zip(result.iter()) {
            assert_eq!(orig.id, res.id);
        }
    }

    #[test]
    fn test_similarity_computation() {
        let reranker = DiversityReranker::new(0.3);

        let a = ScoredCandidate::new("a", 0.8, 0).with_content("machine learning neural networks");
        let b = ScoredCandidate::new("b", 0.7, 1).with_content("machine learning algorithms");
        let c = ScoredCandidate::new("c", 0.6, 2).with_content("database systems SQL");

        let sim_ab = reranker.compute_similarity(&a, &b);
        let sim_ac = reranker.compute_similarity(&a, &c);

        assert!(sim_ab > sim_ac);
    }

    #[test]
    fn test_topic_extraction() {
        let reranker = DiversityReranker::new(0.3);
        let doc = "machine learning and deep neural networks for classification";

        let topics = reranker.extract_topics(doc);

        assert!(topics.contains("machine"));
        assert!(topics.contains("learning"));
        assert!(topics.contains("neural"));
        assert!(topics.contains("networks"));
        assert!(topics.contains("classification"));

        // Short function words are not topics.
        assert!(!topics.contains("and"));
        assert!(!topics.contains("for"));
    }

    #[test]
    fn test_empty_candidates() {
        let reranker = DiversityReranker::new(0.5);
        let candidates: Vec<ScoredCandidate> = Vec::new();

        let result = reranker.apply_diversity(&candidates).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_single_candidate() {
        let reranker = DiversityReranker::new(0.5);
        let candidates = vec![ScoredCandidate::new("doc1", 0.8, 0)
            .with_content("test")
            .with_reranking_score(0.85)];

        let result = reranker.apply_diversity(&candidates).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].id, "doc1");
    }
}