#![allow(clippy::cast_precision_loss)]
use crate::hnsw::VectorId;
use crate::metric::{L2Squared, Metric};
use crate::storage::VectorStorage;
/// Re-ranks `candidates` by their exact distance to `query`.
///
/// The approximate score attached to each candidate is ignored. Every
/// candidate whose id is neither [`VectorId::INVALID`] nor marked deleted in
/// `storage` has its stored vector fetched and its [`L2Squared`] distance to
/// `query` computed.
///
/// # Returns
///
/// The surviving `(id, distance)` pairs sorted by ascending distance.
/// Candidates with equal distances keep their input order (the sort is
/// stable). A NaN distance is ordered after every comparable distance, so it
/// can never outrank a real match.
#[must_use]
pub fn rescore(
    candidates: &[(VectorId, f32)],
    query: &[f32],
    storage: &VectorStorage,
) -> Vec<(VectorId, f32)> {
    let mut rescored: Vec<(VectorId, f32)> = candidates
        .iter()
        .filter_map(|(id, _approx_score)| {
            // Sentinel ids and tombstoned vectors never reach the result.
            if *id == VectorId::INVALID || storage.is_deleted(*id) {
                return None;
            }
            let vector = storage.get_vector(*id);
            let distance = L2Squared::distance(query, &vector);
            Some((*id, distance))
        })
        .collect();

    // `partial_cmp` yields `None` only when at least one side is NaN. Mapping
    // that case to `Equal` would make NaN "equal" to every value, which is not
    // a consistent ordering for `sort_by`. Instead rank NaN after all real
    // distances (and equal to other NaNs), which is a proper total preorder.
    rescored.sort_by(|a, b| {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
    });
    rescored
}
/// Rescores `candidates` against `query` and keeps only the best `k`.
///
/// Delegates the filtering, exact-distance computation and ordering to
/// [`rescore`], then returns at most the first `k` entries of that result.
/// If fewer than `k` candidates survive, all of them are returned; `k == 0`
/// yields an empty vector.
#[must_use]
pub fn rescore_top_k(
    candidates: &[(VectorId, f32)],
    query: &[f32],
    storage: &VectorStorage,
    k: usize,
) -> Vec<(VectorId, f32)> {
    rescore(candidates, query, storage)
        .into_iter()
        .take(k)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::HnswConfig;

    /// An empty candidate list produces an empty result.
    #[test]
    fn test_rescore_empty() {
        let config = HnswConfig::new(4);
        let storage = VectorStorage::new(&config, None);
        let query = vec![1.0, 2.0, 3.0, 4.0];

        let result = rescore(&[], &query, &storage);

        assert!(result.is_empty());
    }

    /// A single candidate identical to the query is returned with distance ~0.
    #[test]
    fn test_rescore_single_vector() {
        let config = HnswConfig::new(4);
        let mut storage = VectorStorage::new(&config, None);
        let stored = vec![1.0, 2.0, 3.0, 4.0];
        let id = storage.insert(&stored).unwrap();

        // The approximate score (0.9) is deliberately unrelated to the true distance.
        let candidates = vec![(id, 0.9)];
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let result = rescore(&candidates, &query, &storage);

        assert_eq!(result.len(), 1);
        let (found_id, found_distance) = result[0];
        assert_eq!(found_id, id);
        assert!((found_distance - 0.0).abs() < 1e-6);
    }

    /// Results are ordered by exact distance, not by the incoming approximate scores.
    #[test]
    fn test_rescore_sorts_correctly() {
        let config = HnswConfig::new(4);
        let mut storage = VectorStorage::new(&config, None);

        let far = vec![10.0, 10.0, 10.0, 10.0];
        let exact = vec![1.0, 2.0, 3.0, 4.0];
        let near = vec![2.0, 3.0, 4.0, 5.0];
        let far_id = storage.insert(&far).unwrap();
        let exact_id = storage.insert(&exact).unwrap();
        let near_id = storage.insert(&near).unwrap();

        // Approximate scores rank the candidates in the opposite order of true proximity.
        let candidates = vec![(far_id, 0.9), (near_id, 0.85), (exact_id, 0.8)];
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let result = rescore(&candidates, &query, &storage);

        assert_eq!(result.len(), 3);
        let order: Vec<VectorId> = result.iter().map(|(id, _)| *id).collect();
        assert_eq!(order, vec![exact_id, near_id, far_id]);
    }

    /// `rescore_top_k` returns only the `k` closest candidates, nearest first.
    #[test]
    fn test_rescore_top_k() {
        let config = HnswConfig::new(4);
        let mut storage = VectorStorage::new(&config, None);

        // Vector `i` is `[i, i, i, i]`, so distance to the origin grows with `i`.
        for i in 0..5 {
            let constant = vec![i as f32; 4];
            storage.insert(&constant).unwrap();
        }

        // NOTE(review): hard-codes ids 1..=5, i.e. presumes `insert` hands out
        // sequential ids starting at 1 — confirm against `VectorStorage`.
        let candidates: Vec<_> = (1..=5).map(|raw| (VectorId(raw), 0.5)).collect();
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let top = rescore_top_k(&candidates, &query, &storage, 3);

        assert_eq!(top.len(), 3);
        let raw_ids: Vec<_> = top.iter().map(|(id, _)| id.0).collect();
        assert_eq!(raw_ids, vec![1, 2, 3]);
    }

    /// Candidates carrying `VectorId::INVALID` are dropped.
    #[test]
    fn test_rescore_skips_invalid_ids() {
        let config = HnswConfig::new(4);
        let mut storage = VectorStorage::new(&config, None);
        let stored = vec![1.0, 2.0, 3.0, 4.0];
        let valid_id = storage.insert(&stored).unwrap();

        let candidates = vec![(VectorId::INVALID, 0.9), (valid_id, 0.8)];
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let result = rescore(&candidates, &query, &storage);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, valid_id);
    }

    /// Candidates marked deleted in storage are dropped.
    #[test]
    fn test_rescore_skips_deleted_vectors() {
        let config = HnswConfig::new(4);
        let mut storage = VectorStorage::new(&config, None);
        let first = vec![1.0, 2.0, 3.0, 4.0];
        let second = vec![2.0, 3.0, 4.0, 5.0];
        let deleted_id = storage.insert(&first).unwrap();
        let live_id = storage.insert(&second).unwrap();
        storage.mark_deleted(deleted_id);

        let candidates = vec![(deleted_id, 0.9), (live_id, 0.8)];
        let query = vec![1.0, 2.0, 3.0, 4.0];
        let result = rescore(&candidates, &query, &storage);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, live_id);
    }
}