use diskann::neighbor::Neighbor;
use diskann_utils::views::MatrixView;
/// Computes brute-force groundtruth for `query` against every row of `data`.
///
/// Each row is scored with `f(row, query)` and wrapped in a [`Neighbor`] whose id is the
/// row's index. The returned vector is sorted in *descending* order of `Neighbor`'s `Ord`
/// implementation. Presumably `Neighbor` orders by distance, which would leave the closest
/// row at the back of the vector; confirm against the `Ord` impl.
///
/// Note: the row index is converted with `as u32`, so indices above `u32::MAX` would be
/// truncated.
pub fn groundtruth<T, F>(data: MatrixView<T>, query: &[T], f: F) -> Vec<Neighbor<u32>>
where
    F: Fn(&[T], &[T]) -> f32,
{
    let mut scored = Vec::new();
    for (row_index, row) in data.row_iter().enumerate() {
        let distance = f(row, query);
        scored.push(Neighbor::new(row_index as u32, distance));
    }

    // Reverse the natural ordering so the "smallest" neighbor ends up last.
    scored.sort_unstable_by(|lhs, rhs| lhs.cmp(rhs).reverse());
    scored
}
#[cfg(test)]
/// Searches `groundtruth` from the back for an entry with the same id as `neighbor`.
///
/// Entries are visited from the last index toward the first.
///
/// # Returns
///
/// * `Some(index)` for the first visited entry whose `id` equals `neighbor.id`.
/// * `None` as soon as a visited entry's distance differs from `neighbor.distance` by more
///   than `margin`. This happens before that entry's id is checked.
///
/// # Panics
///
/// Panics if every entry is visited without returning. This includes an empty
/// `groundtruth`.
pub fn is_match(
    groundtruth: &[Neighbor<u32>],
    neighbor: Neighbor<u32>,
    margin: f32,
) -> Option<usize> {
    for (position, candidate) in groundtruth.iter().enumerate().rev() {
        let outside_margin = (candidate.distance - neighbor.distance).abs() > margin;
        if outside_margin {
            return None;
        }
        if candidate.id == neighbor.id {
            return Some(position);
        }
    }

    panic!(
        "could not find neighbor {:?}. Remaining: {:?}",
        neighbor, groundtruth
    );
}
/// Asserts that the first `top_k` search results exactly equal the last `top_k` entries
/// of `gt`.
///
/// `gt` is read from the back. Result `i` (that is, `ids[i]` / `distances[i]`) is compared
/// against `gt[gt.len() - 1 - i]`. For each result, the distance is checked first with
/// exact `f32` equality, then the id.
///
/// # Panics
///
/// * If `top_k` exceeds `gt.len()`, `ids.len()`, or `distances.len()`. The panic message is
///   explicit rather than an index-out-of-bounds or subtraction-overflow panic.
/// * If any compared distance or id differs. The message names `query_id` and the result
///   index.
pub fn assert_top_k_exactly_match(
    query_id: usize,
    gt: &[Neighbor<u32>],
    ids: &[u32],
    distances: &[f32],
    top_k: usize,
) {
    // Validate up front: otherwise `gt.len() - 1 - i` underflows or the slices are indexed
    // out of bounds, producing a panic with no indication of which query failed.
    assert!(
        top_k <= gt.len(),
        "query {}: requested top {} results but groundtruth only has {} entries",
        query_id,
        top_k,
        gt.len()
    );
    assert!(
        ids.len() >= top_k && distances.len() >= top_k,
        "query {}: requested top {} results but got {} ids and {} distances",
        query_id,
        top_k,
        ids.len(),
        distances.len()
    );

    let expected = gt.iter().rev();
    let actual = ids.iter().zip(distances);
    for (i, (neighbor, (&id, &distance))) in expected.zip(actual).take(top_k).enumerate() {
        assert_eq!(
            neighbor.distance, distance,
            "failed on query {} for result {}",
            query_id, i
        );
        assert_eq!(
            neighbor.id, id,
            "failed on query {} for result {}",
            query_id, i
        );
    }
}
#[cfg(test)]
/// Asserts that every id returned by a range search lies within the groundtruth range.
///
/// The allowed set is every `gt` entry whose distance `d` satisfies `d <= radius`. When
/// `inner_radius` is `Some(lo)`, it must also satisfy `d >= lo`.
///
/// Note: despite the name, this only checks that `ids` is a subset of that set. Groundtruth
/// entries missing from `ids` are not reported.
///
/// # Panics
///
/// Panics if any id in `ids` is outside the allowed set. When `inner_radius` is `None`, the
/// message prints `f32::MIN` as the inner radius.
pub fn assert_range_results_exactly_match(
    query_id: usize,
    gt: &[Neighbor<u32>],
    ids: &[u32],
    radius: f32,
    inner_radius: Option<f32>,
) {
    let in_range = |distance: f32| {
        let above_inner = match inner_radius {
            Some(lower) => distance >= lower,
            None => true,
        };
        above_inner && distance <= radius
    };

    let gt_ids: Vec<_> = gt
        .iter()
        .filter(|candidate| in_range(candidate.distance))
        .map(|candidate| candidate.id)
        .collect();

    let all_within_range = ids.iter().all(|id| gt_ids.contains(id));
    if !all_within_range {
        panic!(
            "query {}: found ids {:?} in range search with radius {}, inner radius {}, but expected {:?}",
            query_id,
            ids,
            radius,
            inner_radius.unwrap_or(f32::MIN),
            gt_ids
        );
    }
}