use crate::types::RowId;
use std::cmp::Ordering;
/// A candidate neighbour: the row it refers to plus its distance.
///
/// Smaller `distance` means nearer. The comparison impls for this type look
/// only at `distance`; `id` never takes part in equality or ordering.
#[derive(Clone, Debug)]
pub struct Candidate {
    /// Row this candidate refers to.
    pub id: RowId,
    /// Distance used to rank the candidate (smaller is nearer).
    pub distance: f32,
}
/// Two candidates are equal when their distances are identical under
/// [`f32::total_cmp`]; `id` is ignored.
///
/// Defining equality through [`Ord::cmp`] keeps `==` consistent with the
/// ordering, as the traits require. Unlike `f32 == f32`, it is also reflexive
/// for NaN, which is what makes the `Eq` impl below sound. Note that
/// `total_cmp` distinguishes `-0.0` from `0.0`.
impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Canonical form: the single source of truth is `Ord::cmp`.
        Some(self.cmp(other))
    }
}

/// Reverse ordering on `distance`: the candidate with the *smaller* distance
/// compares as *greater*. For example, a `BinaryHeap<Candidate>` (a max-heap)
/// pops the nearest candidate first.
///
/// [`f32::total_cmp`] makes this a true total order even when a distance is
/// NaN. The previous `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to
/// every distance, which is not transitive. A positive NaN ranks below
/// (i.e. farther than) every finite distance.
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        other.distance.total_cmp(&self.distance)
    }
}
pub fn robust_prune<F>(
candidates: Vec<Candidate>,
max_degree: usize,
alpha: f32,
distance_fn: F,
) -> Vec<RowId>
where
F: Fn(RowId, RowId) -> f32,
{
if candidates.len() <= max_degree {
return candidates.into_iter().map(|c| c.id).collect();
}
let mut sorted_candidates: Vec<_> = candidates.into_iter().collect();
sorted_candidates.sort_by(|a, b| {
a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal)
});
let mut pruned = Vec::with_capacity(max_degree);
for candidate in sorted_candidates {
if pruned.len() >= max_degree {
break;
}
let mut should_add = true;
for &selected_id in &pruned {
let dist_to_selected = distance_fn(candidate.id, selected_id);
if dist_to_selected < alpha * candidate.distance {
should_add = false;
break;
}
}
if should_add {
pruned.push(candidate.id);
}
}
pruned
}
pub fn simple_prune(candidates: Vec<Candidate>, max_degree: usize) -> Vec<RowId> {
let mut sorted: Vec<_> = candidates.into_iter().collect();
sorted.sort_by(|a, b| {
a.distance.partial_cmp(&b.distance).unwrap_or(Ordering::Equal)
});
sorted
.into_iter()
.take(max_degree)
.map(|c| c.id)
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_candidate_ordering() {
        // The ordering is reversed: the nearer candidate compares as greater.
        let c1 = Candidate { id: 1, distance: 1.0 };
        let c2 = Candidate { id: 2, distance: 2.0 };
        assert!(c1 > c2);
    }

    #[test]
    fn test_candidate_equality_ignores_id() {
        let a = Candidate { id: 1, distance: 1.0 };
        let b = Candidate { id: 2, distance: 1.0 };
        assert_eq!(a, b);
    }

    #[test]
    fn test_simple_prune_basic() {
        let candidates = vec![
            Candidate { id: 1, distance: 1.0 },
            Candidate { id: 2, distance: 3.0 },
            Candidate { id: 3, distance: 2.0 },
            Candidate { id: 4, distance: 4.0 },
        ];
        let pruned = simple_prune(candidates, 2);
        assert_eq!(pruned.len(), 2);
        // Nearest first: distance 1.0 (id 1), then distance 2.0 (id 3).
        assert_eq!(pruned[0], 1);
        assert_eq!(pruned[1], 3);
    }

    #[test]
    fn test_simple_prune_no_pruning_needed() {
        let candidates = vec![
            Candidate { id: 1, distance: 1.0 },
            Candidate { id: 2, distance: 2.0 },
        ];
        let pruned = simple_prune(candidates, 5);
        assert_eq!(pruned.len(), 2);
    }

    #[test]
    fn test_simple_prune_empty() {
        assert!(simple_prune(Vec::new(), 3).is_empty());
    }

    #[test]
    fn test_robust_prune_basic() {
        let candidates = vec![
            Candidate { id: 1, distance: 1.0 },
            Candidate { id: 2, distance: 2.0 },
            Candidate { id: 3, distance: 3.0 },
        ];
        // Every pair is far apart, so no candidate is dominated and the two
        // nearest are kept.
        let dist_fn = |_a: RowId, _b: RowId| 10.0;
        let pruned = robust_prune(candidates, 2, 1.2, dist_fn);
        assert_eq!(pruned, vec![1, 2]);
    }

    #[test]
    fn test_robust_prune_diversity() {
        let candidates = vec![
            Candidate { id: 1, distance: 1.0 },
            // Nearly as close as id 1, but very close to id 1 itself.
            Candidate { id: 2, distance: 1.1 },
            Candidate { id: 3, distance: 5.0 },
        ];
        let dist_fn = |a: RowId, b: RowId| {
            if (a == 1 && b == 2) || (a == 2 && b == 1) {
                // ids 1 and 2 are near-duplicates.
                0.5
            } else {
                // Everything else is far apart.
                10.0
            }
        };
        let pruned = robust_prune(candidates, 2, 1.2, dist_fn);
        assert_eq!(pruned, vec![1, 3]);
    }

    #[test]
    fn test_robust_prune_small_input_is_returned_as_is() {
        let candidates = vec![
            Candidate { id: 5, distance: 3.0 },
            Candidate { id: 4, distance: 1.0 },
        ];
        let dist_fn = |_a: RowId, _b: RowId| -> f32 {
            panic!("distance_fn must not be called when no pruning is needed")
        };
        let pruned = robust_prune(candidates, 2, 1.2, dist_fn);
        assert_eq!(pruned, vec![5, 4]);
    }

    #[test]
    fn test_robust_prune_zero_degree() {
        let candidates = vec![Candidate { id: 1, distance: 1.0 }];
        let dist_fn = |_a: RowId, _b: RowId| 1.0;
        assert!(robust_prune(candidates, 0, 1.2, dist_fn).is_empty());
    }
}