use std::collections::{HashMap, HashSet};
/// Sparse, symmetric table of pairwise similarity scores keyed by item id.
///
/// Every unordered pair is stored once, under the key `(smaller, larger)`,
/// so `(i, j)` and `(j, i)` always address the same entry.
#[derive(Debug, Default)]
#[allow(dead_code)]
pub struct SimilarityMatrix {
    /// Scores indexed by the ordered pair `(smaller id, larger id)`.
    data: HashMap<(u64, u64), f32>,
}

impl SimilarityMatrix {
    /// Creates an empty matrix.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the canonical storage key for a pair: smaller id first.
    fn ordered(i: u64, j: u64) -> (u64, u64) {
        (i.min(j), i.max(j))
    }

    /// Records the similarity between `i` and `j`, replacing any previous
    /// value for that pair.
    ///
    /// `sim` is clamped into `[0.0, 1.0]` before being stored. A NaN input
    /// stays NaN, because `f32::clamp` propagates it.
    pub fn set(&mut self, i: u64, j: u64, sim: f32) {
        let score = sim.clamp(0.0, 1.0);
        self.data.insert(Self::ordered(i, j), score);
    }

    /// Returns the stored similarity for the pair, in either argument order,
    /// or `0.0` when nothing has been recorded for it.
    #[must_use]
    pub fn get(&self, i: u64, j: u64) -> f32 {
        self.data
            .get(&Self::ordered(i, j))
            .copied()
            .unwrap_or(0.0)
    }

    /// Number of distinct pairs that have a stored score.
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when no pair has a stored score.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
/// Greedy Maximal Marginal Relevance (MMR) selector.
///
/// Each round scores every not-yet-selected candidate as
/// `lambda * relevance - (1 - lambda) * max_similarity_to_selected`
/// and picks the highest-scoring one.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct MaximalMarginalRelevance {
    /// Weight of relevance versus diversity in the MMR score.
    ///
    /// [`MaximalMarginalRelevance::new`] clamps this into `[0.0, 1.0]`; the
    /// field is public, so a value assigned directly is used as-is.
    pub lambda: f32,
}

impl MaximalMarginalRelevance {
    /// Creates a selector with `lambda` clamped into `[0.0, 1.0]`.
    #[must_use]
    pub fn new(lambda: f32) -> Self {
        Self {
            lambda: lambda.clamp(0.0, 1.0),
        }
    }

    /// Greedily picks up to `k` candidate ids, returned in selection order.
    ///
    /// * `candidates` - `(id, relevance)` pairs to choose from.
    /// * `similarity` - pairwise similarity lookup; pairs without a stored
    ///   score count as `0.0`.
    /// * `k` - maximum number of ids to return. Values larger than
    ///   `candidates.len()` simply return every candidate.
    ///
    /// Returns an empty vector when `candidates` is empty or `k == 0`.
    #[must_use]
    pub fn select(
        &self,
        candidates: &[(u64, f32)],
        similarity: &SimilarityMatrix,
        k: usize,
    ) -> Vec<u64> {
        if candidates.is_empty() || k == 0 {
            return Vec::new();
        }

        // No more than `candidates.len()` ids can ever be produced. Sizing the
        // buffer from `k` alone makes `Vec::with_capacity` panic ("capacity
        // overflow") or attempt a huge allocation when callers pass a very
        // large `k` to mean "everything".
        let limit = k.min(candidates.len());
        let mut selected: Vec<u64> = Vec::with_capacity(limit);
        let mut remaining: Vec<(u64, f32)> = candidates.to_vec();

        while selected.len() < limit && !remaining.is_empty() {
            let best_idx = remaining
                .iter()
                .enumerate()
                .map(|(i, &(id, rel))| {
                    // Redundancy penalty: highest similarity to anything
                    // already chosen. The fold starts at 0.0, so the first
                    // round is ranked purely by `lambda * rel`.
                    let max_sim = selected
                        .iter()
                        .map(|&sel| similarity.get(id, sel))
                        .fold(0.0_f32, f32::max);
                    let mmr = self.lambda * rel - (1.0 - self.lambda) * max_sim;
                    (i, mmr)
                })
                // Incomparable scores (NaN) are treated as equal; on ties
                // `max_by` yields the last of the equal elements.
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map_or(0, |(i, _)| i);

            // `remove` (not `swap_remove`) keeps the remaining candidates in
            // their input order, which the tie-breaking above depends on.
            let (chosen_id, _) = remaining.remove(best_idx);
            selected.push(chosen_id);
        }

        selected
    }
}

impl Default for MaximalMarginalRelevance {
    /// Uses `lambda = 0.7`.
    fn default() -> Self {
        Self::new(0.7)
    }
}
/// Near-duplicate finder based on Jaccard similarity of fingerprint sets.
#[derive(Debug, Default)]
#[allow(dead_code)]
pub struct DuplicateDetector {
    /// Fingerprint values per item id, stored as a set (repeats collapse).
    fingerprints: HashMap<u64, HashSet<u32>>,
}

impl DuplicateDetector {
    /// Creates a detector with no registered items.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `fingerprint` for `id`, replacing any fingerprint previously
    /// stored under the same id.
    pub fn add(&mut self, id: u64, fingerprint: Vec<u32>) {
        let as_set: HashSet<u32> = fingerprint.into_iter().collect();
        self.fingerprints.insert(id, as_set);
    }

    /// Returns every pair of registered ids whose Jaccard similarity is at
    /// least `threshold`.
    ///
    /// Each pair is reported once as `(smaller id, larger id)`, and the list
    /// is sorted in ascending order.
    #[must_use]
    pub fn find_duplicates(&self, threshold: f32) -> Vec<(u64, u64)> {
        // Sorting the ids first means the nested walk below emits pairs that
        // are already `(low, high)` and already in ascending order.
        let mut ids: Vec<u64> = self.fingerprints.keys().copied().collect();
        ids.sort_unstable();

        let mut pairs = Vec::new();
        for (pos, &low) in ids.iter().enumerate() {
            for &high in ids.iter().skip(pos + 1) {
                if self.jaccard(low, high) >= threshold {
                    pairs.push((low, high));
                }
            }
        }
        pairs
    }

    /// Jaccard similarity `|A ∩ B| / |A ∪ B|` of the fingerprints of `a` and `b`.
    ///
    /// Yields `0.0` when either id is unknown or when both sets are empty.
    #[must_use]
    fn jaccard(&self, a: u64, b: u64) -> f32 {
        match (self.fingerprints.get(&a), self.fingerprints.get(&b)) {
            (Some(left), Some(right)) => {
                let shared = left.intersection(right).count();
                // |A ∪ B| = |A| + |B| - |A ∩ B|
                let total = left.len() + right.len() - shared;
                if total == 0 {
                    0.0
                } else {
                    shared as f32 / total as f32
                }
            }
            _ => 0.0,
        }
    }
}
/// Caps how many items of any single category survive, preserving input order.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct CategoryDiversifier {
    /// Per-category cap stored at construction time.
    ///
    /// NOTE(review): `diversify` does not read this field; it uses its own
    /// `max_per_category` argument. Confirm whether that is intended.
    pub max_per_category: usize,
}

impl CategoryDiversifier {
    /// Creates a diversifier that stores `max_per_category`.
    #[must_use]
    pub fn new(max_per_category: usize) -> Self {
        Self { max_per_category }
    }

    /// Filters `items` so that at most `max_per_category` ids are kept for
    /// each category string.
    ///
    /// Items are visited in input order and the earliest ones in each
    /// category win. The cap comes from the `max_per_category` argument, not
    /// from the field on `self`.
    #[must_use]
    pub fn diversify(&self, items: &[(u64, &str)], max_per_category: usize) -> Vec<u64> {
        let mut kept_per_category: HashMap<&str, usize> = HashMap::new();
        items
            .iter()
            .filter_map(|&(id, category)| {
                let kept = kept_per_category.entry(category).or_default();
                if *kept < max_per_category {
                    *kept += 1;
                    Some(id)
                } else {
                    None
                }
            })
            .collect()
    }
}

impl Default for CategoryDiversifier {
    /// Stores a cap of 3 items per category.
    fn default() -> Self {
        Self::new(3)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing floating-point similarity scores.
    const EPS: f32 = 1e-6;

    // ---- SimilarityMatrix ----

    #[test]
    fn test_similarity_matrix_set_get() {
        let mut matrix = SimilarityMatrix::new();
        matrix.set(1, 2, 0.75);

        // The lookup must be symmetric in its arguments.
        let forward = matrix.get(1, 2);
        let backward = matrix.get(2, 1);
        assert!((forward - 0.75).abs() < EPS);
        assert!((backward - 0.75).abs() < EPS);
    }

    #[test]
    fn test_similarity_matrix_missing_is_zero() {
        let matrix = SimilarityMatrix::new();
        assert_eq!(matrix.get(10, 20), 0.0);
    }

    #[test]
    fn test_similarity_matrix_clamps() {
        let mut matrix = SimilarityMatrix::new();
        // Values above 1.0 are clamped down to 1.0 on insertion.
        matrix.set(1, 2, 1.5);
        assert!((matrix.get(1, 2) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_similarity_matrix_len() {
        let mut matrix = SimilarityMatrix::new();
        matrix.set(1, 2, 0.5);
        matrix.set(3, 4, 0.8);
        assert_eq!(matrix.len(), 2);
    }

    // ---- MaximalMarginalRelevance ----

    #[test]
    fn test_mmr_empty_candidates() {
        let selector = MaximalMarginalRelevance::new(0.7);
        let similarities = SimilarityMatrix::new();
        let picked = selector.select(&[], &similarities, 5);
        assert!(picked.is_empty());
    }

    #[test]
    fn test_mmr_selects_top_k() {
        let selector = MaximalMarginalRelevance::new(0.7);
        let similarities = SimilarityMatrix::new();
        let candidates = vec![(1u64, 0.9), (2u64, 0.8), (3u64, 0.7)];
        let picked = selector.select(&candidates, &similarities, 2);
        assert_eq!(picked.len(), 2);
    }

    #[test]
    fn test_mmr_diversity() {
        let selector = MaximalMarginalRelevance::new(0.5);

        // Items 1 and 2 are near-duplicates; item 3 is unlike both.
        let mut similarities = SimilarityMatrix::new();
        similarities.set(1, 2, 0.95);
        similarities.set(1, 3, 0.1);
        similarities.set(2, 3, 0.1);

        let candidates = vec![(1u64, 1.0), (2u64, 0.9), (3u64, 0.85)];
        let picked = selector.select(&candidates, &similarities, 2);

        // The less relevant but dissimilar item 3 should beat item 2.
        assert!(picked.contains(&1));
        assert!(picked.contains(&3));
    }

    // ---- DuplicateDetector ----

    #[test]
    fn test_duplicate_detector_no_duplicates() {
        let mut detector = DuplicateDetector::new();
        detector.add(1, vec![1, 2, 3]);
        detector.add(2, vec![4, 5, 6]);
        assert!(detector.find_duplicates(0.5).is_empty());
    }

    #[test]
    fn test_duplicate_detector_exact_duplicate() {
        let mut detector = DuplicateDetector::new();
        detector.add(1, vec![1, 2, 3, 4]);
        detector.add(2, vec![1, 2, 3, 4]);

        let found = detector.find_duplicates(0.9);
        assert!(!found.is_empty());
        assert!(found.contains(&(1, 2)));
    }

    #[test]
    fn test_duplicate_detector_partial_overlap() {
        let mut detector = DuplicateDetector::new();
        detector.add(10, vec![1, 2, 3, 4]);
        detector.add(20, vec![3, 4, 5, 6]);

        // Two shared values out of six distinct ones: Jaccard = 1/3.
        assert!(detector.find_duplicates(0.5).is_empty());
        assert_eq!(detector.find_duplicates(0.3).len(), 1);
    }

    #[test]
    fn test_duplicate_detector_threshold() {
        let mut detector = DuplicateDetector::new();
        detector.add(1, vec![1, 2, 3]);
        detector.add(2, vec![1, 2, 3]);

        // Identical sets score exactly 1.0, so a threshold above 1.0 finds nothing.
        assert_eq!(detector.find_duplicates(1.0).len(), 1);
        assert_eq!(detector.find_duplicates(1.1).len(), 0);
    }

    // ---- CategoryDiversifier ----

    #[test]
    fn test_category_diversifier_basic() {
        let diversifier = CategoryDiversifier::new(2);
        let items: Vec<(u64, &str)> = vec![
            (1, "action"),
            (2, "action"),
            (3, "action"),
            (4, "drama"),
        ];

        // The third "action" entry exceeds the cap of 2 and is dropped.
        let kept = diversifier.diversify(&items, 2);
        assert_eq!(kept, vec![1, 2, 4]);
    }

    #[test]
    fn test_category_diversifier_all_different() {
        let diversifier = CategoryDiversifier::default();
        let items: Vec<(u64, &str)> = vec![(1, "a"), (2, "b"), (3, "c")];
        let kept = diversifier.diversify(&items, 1);
        assert_eq!(kept, vec![1, 2, 3]);
    }

    #[test]
    fn test_category_diversifier_empty() {
        let diversifier = CategoryDiversifier::default();
        let kept = diversifier.diversify(&[], 3);
        assert!(kept.is_empty());
    }
}