use crate::reranking::types::{RerankingResult, ScoredCandidate};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Selects which algorithm `DiversityReranker::apply_diversity` uses to
/// reorder candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DiversityStrategy {
/// Greedy selection that trades each candidate's relevance against its
/// highest word-overlap similarity to the candidates already selected.
MaximalMarginalRelevance,
/// Groups candidates whose word-overlap similarity exceeds the configured
/// threshold and keeps the top-scoring members of each group.
ClusterBased,
/// Greedy selection that rewards candidates contributing topic words not
/// yet covered by the candidates already selected.
TopicBased,
/// Leaves the candidate order untouched.
None,
}
/// Reorders scored candidates so that near-duplicate content does not
/// dominate the top of a result list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiversityReranker {
// Diversity-vs-relevance trade-off; the constructors clamp it to `[0.0, 1.0]`.
// A value of exactly 0.0 makes `apply_diversity` return its input unchanged.
weight: f32,
// Algorithm that `apply_diversity` dispatches to.
strategy: DiversityStrategy,
// Similarity a pair must strictly exceed to share a cluster in the
// cluster-based strategy; the constructors initialise it to 0.85.
similarity_threshold: f32,
}
impl DiversityReranker {
pub fn new(weight: f32) -> Self {
Self {
weight: weight.clamp(0.0, 1.0),
strategy: DiversityStrategy::MaximalMarginalRelevance,
similarity_threshold: 0.85,
}
}
pub fn with_strategy(weight: f32, strategy: DiversityStrategy) -> Self {
Self {
weight: weight.clamp(0.0, 1.0),
strategy,
similarity_threshold: 0.85,
}
}
pub fn set_similarity_threshold(mut self, threshold: f32) -> Self {
self.similarity_threshold = threshold.clamp(0.0, 1.0);
self
}
pub fn apply_diversity(
&self,
candidates: &[ScoredCandidate],
) -> RerankingResult<Vec<ScoredCandidate>> {
if candidates.is_empty() || self.weight == 0.0 {
return Ok(candidates.to_vec());
}
match self.strategy {
DiversityStrategy::MaximalMarginalRelevance => self.mmr_rerank(candidates),
DiversityStrategy::ClusterBased => self.cluster_based_rerank(candidates),
DiversityStrategy::TopicBased => self.topic_based_rerank(candidates),
DiversityStrategy::None => Ok(candidates.to_vec()),
}
}
fn mmr_rerank(&self, candidates: &[ScoredCandidate]) -> RerankingResult<Vec<ScoredCandidate>> {
let lambda = 1.0 - self.weight; let mut selected = Vec::new();
let mut remaining: Vec<_> = candidates.to_vec();
if let Some(first) = remaining.first().cloned() {
selected.push(first);
remaining.remove(0);
}
while !remaining.is_empty() && selected.len() < candidates.len() {
let mut best_idx = 0;
let mut best_mmr = f32::NEG_INFINITY;
for (idx, candidate) in remaining.iter().enumerate() {
let relevance = candidate.effective_score();
let max_similarity = selected
.iter()
.map(|sel| self.compute_similarity(candidate, sel))
.fold(0.0f32, f32::max);
let mmr = lambda * relevance - (1.0 - lambda) * max_similarity;
if mmr > best_mmr {
best_mmr = mmr;
best_idx = idx;
}
}
if best_idx < remaining.len() {
selected.push(remaining.remove(best_idx));
} else {
break;
}
}
Ok(selected)
}
fn cluster_based_rerank(
&self,
candidates: &[ScoredCandidate],
) -> RerankingResult<Vec<ScoredCandidate>> {
if candidates.len() <= 2 {
return Ok(candidates.to_vec());
}
let mut clusters: Vec<Vec<ScoredCandidate>> = Vec::new();
let mut assigned = HashSet::new();
for (idx, candidate) in candidates.iter().enumerate() {
if assigned.contains(&idx) {
continue;
}
let mut cluster = vec![candidate.clone()];
assigned.insert(idx);
for (other_idx, other) in candidates.iter().enumerate() {
if assigned.contains(&other_idx) {
continue;
}
let similarity = self.compute_similarity(candidate, other);
if similarity > self.similarity_threshold {
cluster.push(other.clone());
assigned.insert(other_idx);
}
}
clusters.push(cluster);
}
let mut result = Vec::new();
let num_per_cluster = (candidates.len() / clusters.len().max(1)).max(1);
for cluster in clusters {
let mut sorted_cluster = cluster;
sorted_cluster.sort_by(|a, b| {
b.effective_score()
.partial_cmp(&a.effective_score())
.unwrap_or(std::cmp::Ordering::Equal)
});
result.extend(sorted_cluster.into_iter().take(num_per_cluster));
}
result.sort_by(|a, b| {
b.effective_score()
.partial_cmp(&a.effective_score())
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(result)
}
fn topic_based_rerank(
&self,
candidates: &[ScoredCandidate],
) -> RerankingResult<Vec<ScoredCandidate>> {
let mut doc_topics: Vec<HashSet<String>> = Vec::new();
for candidate in candidates {
let content = candidate.content.as_deref().unwrap_or("");
let topics = self.extract_topics(content);
doc_topics.push(topics);
}
let mut selected = Vec::new();
let mut covered_topics = HashSet::new();
let mut remaining_indices: Vec<usize> = (0..candidates.len()).collect();
while !remaining_indices.is_empty() && selected.len() < candidates.len() {
let mut best_idx = 0;
let mut best_score = f32::NEG_INFINITY;
for (list_idx, &doc_idx) in remaining_indices.iter().enumerate() {
let candidate = &candidates[doc_idx];
let topics = &doc_topics[doc_idx];
let relevance = candidate.effective_score();
let new_topics = topics.difference(&covered_topics).count() as f32;
let total_topics = topics.len().max(1) as f32;
let topic_novelty = new_topics / total_topics;
let score = (1.0 - self.weight) * relevance + self.weight * topic_novelty;
if score > best_score {
best_score = score;
best_idx = list_idx;
}
}
if best_idx < remaining_indices.len() {
let doc_idx = remaining_indices.remove(best_idx);
selected.push(candidates[doc_idx].clone());
for topic in &doc_topics[doc_idx] {
covered_topics.insert(topic.clone());
}
} else {
break;
}
}
Ok(selected)
}
fn compute_similarity(&self, a: &ScoredCandidate, b: &ScoredCandidate) -> f32 {
let a_content = a.content.as_deref().unwrap_or("");
let b_content = b.content.as_deref().unwrap_or("");
let a_words: HashSet<String> = a_content
.to_lowercase()
.split_whitespace()
.filter(|w| w.len() > 3) .map(|w| w.to_string())
.collect();
let b_words: HashSet<String> = b_content
.to_lowercase()
.split_whitespace()
.filter(|w| w.len() > 3)
.map(|w| w.to_string())
.collect();
if a_words.is_empty() || b_words.is_empty() {
return 0.0;
}
let intersection = a_words.intersection(&b_words).count() as f32;
let union = a_words.union(&b_words).count() as f32;
if union == 0.0 {
0.0
} else {
intersection / union
}
}
fn extract_topics(&self, document: &str) -> HashSet<String> {
document
.to_lowercase()
.split_whitespace()
.filter(|w| w.len() > 4) .map(|w| w.to_string())
.collect()
}
}
impl Default for DiversityReranker {
fn default() -> Self {
Self::new(0.3)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lets the tests use `?` on reranking results; any error fails the test.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// Four fixtures with descending scores: two machine-learning documents
    /// that share vocabulary, then a database and a web-development document.
    fn create_test_candidates() -> Vec<ScoredCandidate> {
        let ml_networks = ScoredCandidate::new("doc1", 0.9, 0)
            .with_content("machine learning deep neural networks")
            .with_reranking_score(0.85);
        let ml_algorithms = ScoredCandidate::new("doc2", 0.85, 1)
            .with_content("machine learning algorithms classification")
            .with_reranking_score(0.8);
        let databases = ScoredCandidate::new("doc3", 0.7, 2)
            .with_content("database management systems SQL queries")
            .with_reranking_score(0.75);
        let web = ScoredCandidate::new("doc4", 0.65, 3)
            .with_content("web development JavaScript frameworks")
            .with_reranking_score(0.7);
        vec![ml_networks, ml_algorithms, databases, web]
    }

    /// MMR keeps every candidate, leaves `doc1` first, and does not fill the
    /// top three exclusively with the two machine-learning documents.
    #[test]
    fn test_mmr_rerank() -> Result<()> {
        let fixtures = create_test_candidates();
        let reranked = DiversityReranker::new(0.5).mmr_rerank(&fixtures)?;

        assert_eq!(reranked.len(), fixtures.len());
        assert_eq!(reranked[0].id, "doc1");

        let only_ml = reranked.iter().take(3).all(|candidate| {
            let id = candidate.id.as_str();
            id.starts_with("doc1") || id.starts_with("doc2")
        });
        assert!(!only_ml, "MMR should diversify results");
        Ok(())
    }

    /// Cluster-based reranking returns something, and never more than it got.
    #[test]
    fn test_cluster_based_rerank() -> Result<()> {
        let fixtures = create_test_candidates();
        let reranked = DiversityReranker::with_strategy(0.5, DiversityStrategy::ClusterBased)
            .cluster_based_rerank(&fixtures)?;

        assert!(!reranked.is_empty());
        assert!(reranked.len() <= fixtures.len());
        Ok(())
    }

    /// Topic-based reranking keeps every candidate and its top two results
    /// are not near-duplicates of each other.
    #[test]
    fn test_topic_based_rerank() -> Result<()> {
        let fixtures = create_test_candidates();
        let reranker = DiversityReranker::with_strategy(0.6, DiversityStrategy::TopicBased);
        let reranked = reranker.topic_based_rerank(&fixtures)?;

        assert_eq!(reranked.len(), fixtures.len());

        let (lead, runner_up) = (&reranked[0], &reranked[1]);
        let similarity = reranker.compute_similarity(lead, runner_up);
        assert!(
            similarity < 0.8,
            "Topic-based reranking should increase diversity"
        );
        Ok(())
    }

    /// A weight of zero must hand the candidates back in their original order.
    #[test]
    fn test_no_diversity() -> Result<()> {
        let fixtures = create_test_candidates();
        let reranked = DiversityReranker::new(0.0).apply_diversity(&fixtures)?;

        assert_eq!(reranked.len(), fixtures.len());
        let original_ids = fixtures.iter().map(|candidate| candidate.id.as_str());
        let reranked_ids = reranked.iter().map(|candidate| candidate.id.as_str());
        assert!(original_ids.eq(reranked_ids));
        Ok(())
    }

    /// Documents that share vocabulary score higher than unrelated ones.
    #[test]
    fn test_similarity_computation() {
        let reranker = DiversityReranker::new(0.3);
        let networks =
            ScoredCandidate::new("a", 0.8, 0).with_content("machine learning neural networks");
        let algorithms =
            ScoredCandidate::new("b", 0.7, 1).with_content("machine learning algorithms");
        let databases = ScoredCandidate::new("c", 0.6, 2).with_content("database systems SQL");

        let related = reranker.compute_similarity(&networks, &algorithms);
        let unrelated = reranker.compute_similarity(&networks, &databases);
        assert!(related > unrelated);
    }

    /// Long words become topics; short connecting words do not.
    #[test]
    fn test_topic_extraction() {
        let topics = DiversityReranker::new(0.3)
            .extract_topics("machine learning and deep neural networks for classification");

        for expected in ["machine", "learning", "neural", "networks", "classification"] {
            assert!(topics.contains(expected));
        }
        for excluded in ["and", "for"] {
            assert!(!topics.contains(excluded));
        }
    }

    /// An empty input produces an empty output.
    #[test]
    fn test_empty_candidates() -> Result<()> {
        let nothing: Vec<ScoredCandidate> = Vec::new();
        let reranked = DiversityReranker::new(0.5).apply_diversity(&nothing)?;
        assert!(reranked.is_empty());
        Ok(())
    }

    /// A lone candidate comes back untouched.
    #[test]
    fn test_single_candidate() -> Result<()> {
        let lone = ScoredCandidate::new("doc1", 0.8, 0)
            .with_content("test")
            .with_reranking_score(0.85);
        let reranked = DiversityReranker::new(0.5).apply_diversity(&[lone])?;

        assert_eq!(reranked.len(), 1);
        assert_eq!(reranked[0].id, "doc1");
        Ok(())
    }
}