#![allow(clippy::unwrap_used)]
use std::sync::Arc;
use iqdb_distance::{compute, compute_batch, cosine_normalized, normalize};
use iqdb_types::{DistanceMetric, Hit, IqdbError, VectorId};
/// Minimal brute-force vector index used to drive the distance kernels end to end.
///
/// `ids` and `vectors` are parallel: entry `i` of one describes the same stored
/// item as entry `i` of the other.
struct SimIndex {
    /// Metric handed to `compute_batch` when scoring a query.
    metric: DistanceMetric,
    /// Length every inserted vector is required to have.
    dim: usize,
    /// Stored vectors, in insertion order.
    vectors: Vec<Arc<[f32]>>,
    /// Identifier of each stored vector, aligned with `vectors`.
    ids: Vec<VectorId>,
}
impl SimIndex {
    /// Creates an empty index that accepts vectors of length `dim` and scores
    /// them with `metric`.
    fn new(dim: usize, metric: DistanceMetric) -> Self {
        Self {
            dim,
            metric,
            ids: Vec::new(),
            vectors: Vec::new(),
        }
    }

    /// Stores `vector` under `VectorId::from(id)`.
    ///
    /// # Errors
    ///
    /// Returns `IqdbError::DimensionMismatch` when `vector.len()` differs from
    /// the index dimension; nothing is stored in that case.
    fn insert(&mut self, id: u64, vector: Arc<[f32]>) -> Result<(), IqdbError> {
        if vector.len() != self.dim {
            return Err(IqdbError::DimensionMismatch {
                expected: self.dim,
                found: vector.len(),
            });
        }
        self.ids.push(VectorId::from(id));
        self.vectors.push(vector);
        Ok(())
    }

    /// Brute-force k-nearest-neighbour search.
    ///
    /// Scores `query` against every stored vector with `compute_batch`. It then
    /// returns up to `k` hits ordered by ascending score, with ties broken by
    /// insertion order. NaN scores sort after every non-NaN score.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `compute_batch`, e.g. a query whose
    /// length does not match the stored vectors.
    fn search(&self, query: &[f32], k: usize) -> Result<Vec<Hit>, IqdbError> {
        let candidates: Vec<&[f32]> = self.vectors.iter().map(|a| &a[..]).collect();
        let mut distances = vec![0.0_f32; candidates.len()];
        compute_batch(self.metric, query, &candidates, &mut distances)?;
        // Negate DotProduct scores so larger dot products rank first under the
        // single ascending sort below. This presumes compute_batch returns the
        // raw dot product for this metric; TODO confirm against the crate.
        if matches!(self.metric, DistanceMetric::DotProduct) {
            for d in distances.iter_mut() {
                *d = -*d;
            }
        }
        let mut order: Vec<usize> = (0..candidates.len()).collect();
        // `partial_cmp(..).unwrap_or(Equal)` alone treats NaN as equal to
        // everything. That is not a total order: `sort_by` may panic on it
        // (Rust >= 1.81) or yield an arbitrary ranking. Send NaN to the back
        // explicitly, then break remaining ties by insertion index.
        order.sort_by(|&i, &j| {
            let (a, b) = (distances[i], distances[j]);
            let by_distance = match (a.is_nan(), b.is_nan()) {
                (false, false) => a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal),
                (false, true) => std::cmp::Ordering::Less,
                (true, false) => std::cmp::Ordering::Greater,
                (true, true) => std::cmp::Ordering::Equal,
            };
            by_distance.then(i.cmp(&j))
        });
        Ok(order
            .into_iter()
            .take(k)
            .map(|i| Hit::new(self.ids[i].clone(), distances[i]))
            .collect())
    }
}
/// Collects the identifier of each hit, keeping the order of `hits`.
fn ids(hits: &[Hit]) -> Vec<VectorId> {
    let mut out = Vec::with_capacity(hits.len());
    for hit in hits {
        out.push(hit.id.clone());
    }
    out
}
/// Euclidean search must list stored points by increasing distance from the query.
#[test]
fn euclidean_returns_nearest_first() {
    let mut index = SimIndex::new(2, DistanceMetric::Euclidean);
    let points: [(u64, [f32; 2]); 3] = [(1, [0.0, 0.0]), (2, [1.0, 1.0]), (3, [5.0, 5.0])];
    for (id, point) in points {
        index.insert(id, Arc::from(point.as_slice())).unwrap();
    }
    let hits = index.search(&[0.1, 0.1], 3).unwrap();
    assert_eq!(
        ids(&hits),
        [VectorId::U64(1), VectorId::U64(2), VectorId::U64(3)],
    );
    // Reported distances must be non-decreasing down the result list.
    for pair in hits.windows(2) {
        assert!(pair[0].distance <= pair[1].distance);
    }
}
/// With DotProduct the largest dot product must rank first. The scores
/// `SimIndex` reports must be non-decreasing down the whole result list.
#[test]
fn dot_product_returns_most_similar_first() {
    let mut index = SimIndex::new(2, DistanceMetric::DotProduct);
    index
        .insert(1, Arc::from([1.0_f32, 0.0].as_slice()))
        .unwrap();
    index
        .insert(2, Arc::from([10.0_f32, 0.0].as_slice()))
        .unwrap();
    index
        .insert(3, Arc::from([-5.0_f32, 0.0].as_slice()))
        .unwrap();
    let hits = index.search(&[1.0, 0.0], 3).unwrap();
    assert_eq!(
        ids(&hits),
        [VectorId::U64(2), VectorId::U64(1), VectorId::U64(3)],
    );
    // Check every adjacent pair, including the last one.
    assert!(hits[0].distance <= hits[1].distance);
    assert!(hits[1].distance <= hits[2].distance);
}
/// Cosine scoring depends only on direction. Two stored vectors pointing the
/// same way as the query but with very different lengths must score (almost)
/// identically and outrank an orthogonal vector.
#[test]
fn cosine_ranks_by_angle_not_magnitude() {
    let mut index = SimIndex::new(2, DistanceMetric::Cosine);
    index
        .insert(1, Arc::from([1.0_f32, 0.0].as_slice()))
        .unwrap();
    index
        .insert(2, Arc::from([100.0_f32, 0.0].as_slice()))
        .unwrap();
    index
        .insert(3, Arc::from([0.0_f32, 1.0].as_slice()))
        .unwrap();
    let hits = index.search(&[2.0, 0.0], 3).unwrap();
    assert_eq!(hits[2].id, VectorId::U64(3));
    assert!(hits[0].distance < hits[2].distance);
    // Magnitude invariance: ids 1 and 2 share the top two slots (in either
    // order, since their scores may tie) and their scores agree.
    let top_two = ids(&hits[..2]);
    assert!(top_two.contains(&VectorId::U64(1)));
    assert!(top_two.contains(&VectorId::U64(2)));
    assert!(
        (hits[0].distance - hits[1].distance).abs() <= 1e-5,
        "same-direction vectors scored {} vs {}",
        hits[0].distance,
        hits[1].distance,
    );
}
/// Smoke test: every supported metric can score a small index and return both hits.
#[test]
fn every_implemented_metric_runs_through_the_index() {
    let metrics = [
        DistanceMetric::Cosine,
        DistanceMetric::DotProduct,
        DistanceMetric::Euclidean,
        DistanceMetric::Manhattan,
        DistanceMetric::Hamming,
    ];
    let stored: [(u64, [f32; 3]); 2] = [(1, [1.0, 0.0, 0.0]), (2, [0.0, 1.0, 0.0])];
    for metric in metrics {
        let mut index = SimIndex::new(3, metric);
        for (id, vector) in stored {
            index.insert(id, Arc::from(vector.as_slice())).unwrap();
        }
        let hit_count = index.search(&[1.0, 0.0, 0.0], 2).unwrap().len();
        assert_eq!(hit_count, 2, "metric {metric:?}");
    }
}
/// `compute_batch` must produce exactly the same value as calling `compute` on
/// each row separately.
#[test]
fn batch_scoring_matches_per_pair_compute() {
    let metric = DistanceMetric::Euclidean;
    let query = [0.2_f32, 0.4, 0.6];
    let rows: [&[f32]; 3] = [&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0], &[0.5, 0.5, 0.5]];
    let mut batched = [0.0_f32; 3];
    compute_batch(metric, &query, &rows, &mut batched).unwrap();
    for (i, row) in rows.iter().enumerate() {
        let expected = compute(metric, &query, row).unwrap();
        // Bit-level comparison: the batch path must be identical, not merely close.
        assert_eq!(expected.to_bits(), batched[i].to_bits());
    }
}
/// A query whose length differs from the stored vectors must surface as a
/// `DimensionMismatch` error from `search`.
#[test]
fn dimension_mismatch_propagates_from_search() {
    let mut index = SimIndex::new(3, DistanceMetric::Euclidean);
    let stored: Arc<[f32]> = Arc::from([1.0_f32, 0.0, 0.0].as_slice());
    index.insert(1, stored).unwrap();
    // Two-component query against three-component stored vectors.
    let outcome = index.search(&[1.0, 0.0], 1);
    assert!(matches!(outcome, Err(IqdbError::DimensionMismatch { .. })));
}
/// Inserting a vector of the wrong length must fail with the exact expected and
/// found dimensions.
#[test]
fn insert_rejects_wrong_dimension() {
    let mut index = SimIndex::new(3, DistanceMetric::Cosine);
    let too_short: Arc<[f32]> = Arc::from([1.0_f32, 0.0].as_slice());
    let err = index.insert(1, too_short).unwrap_err();
    let expected = IqdbError::DimensionMismatch {
        expected: 3,
        found: 2,
    };
    assert_eq!(err, expected);
}
/// The ranking from a Cosine index over raw vectors must match the ranking from
/// scoring pre-normalised vectors with `cosine_normalized`.
#[test]
fn normalized_index_matches_cosine_ranking() {
    let raw: [&[f32]; 3] = [&[1.0, 2.0, 3.0], &[-3.0, 1.0, 0.5], &[2.0, 2.0, 2.0]];
    let query = [1.0_f32, 2.0, 2.5];

    // Ranking A: Cosine index over the raw vectors.
    let mut cosine = SimIndex::new(3, DistanceMetric::Cosine);
    for (id, row) in (0_u64..).zip(raw) {
        cosine.insert(id, Arc::from(row)).unwrap();
    }
    let cosine_order = ids(&cosine.search(&query, 3).unwrap());

    // Ranking B: normalise up front, score with cosine_normalized, and sort by
    // ascending score with ties broken by id.
    let units: Vec<(u64, Vec<f32>)> = (0_u64..)
        .zip(raw.iter())
        .map(|(id, row)| (id, normalize(row).unwrap()))
        .collect();
    let unit_query = normalize(&query).unwrap();
    let mut scored: Vec<(u64, f32)> = Vec::with_capacity(units.len());
    for (id, unit) in &units {
        scored.push((*id, cosine_normalized(unit, &unit_query).unwrap()));
    }
    scored.sort_by(|(id_a, score_a), (id_b, score_b)| {
        score_a.partial_cmp(score_b).unwrap().then(id_a.cmp(id_b))
    });
    let normalized_order: Vec<VectorId> =
        scored.iter().map(|&(id, _)| VectorId::U64(id)).collect();

    assert_eq!(cosine_order, normalized_order);
}