#![allow(clippy::unwrap_used)]
use iqdb_flat::{FlatConfig, FlatIndex};
use iqdb_index::{Index, IndexCore};
use iqdb_types::{DistanceMetric, Hit, Metadata, SearchParams, VectorId};
use std::sync::Arc;
/// Copies a borrowed `f32` slice into a reference-counted slice.
fn arc(v: &[f32]) -> Arc<[f32]> {
    v.into()
}
/// Plain in-memory dataset used by the tests: one `(id, vector, optional metadata)`
/// tuple per stored row.
type Raw = Vec<(VectorId, Vec<f32>, Option<Metadata>)>;
/// Reference Euclidean (L2) distance: the square root of the summed squared
/// component differences.
///
/// Components are paired with `zip`, so only the first
/// `min(a.len(), b.len())` components contribute.
fn ref_euclidean(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        // Explicit left-to-right fold from +0.0 keeps the accumulation order fixed.
        .fold(0.0_f32, |sum_sq, (lhs, rhs)| {
            let delta = lhs - rhs;
            sum_sq + delta * delta
        })
        .sqrt()
}
/// Reference Manhattan (L1) distance: the sum of absolute component differences.
///
/// Components are paired with `zip`, so only the first
/// `min(a.len(), b.len())` components contribute.
fn ref_manhattan(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .fold(0.0_f32, |total, (lhs, rhs)| total + (lhs - rhs).abs())
}
/// Reference inner product: the sum of component-wise products.
///
/// Components are paired with `zip`, so only the first
/// `min(a.len(), b.len())` components contribute.
fn ref_dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .fold(0.0_f32, |total, (lhs, rhs)| total + lhs * rhs)
}
/// Reference cosine distance: `1 - dot(a, b) / sqrt(|a|^2 * |b|^2)`.
///
/// Returns `1.0` when the denominator is zero (at least one vector has zero
/// norm). Components are paired with `zip`, so only the first
/// `min(a.len(), b.len())` components contribute.
fn ref_cosine(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na, nb), (x, y)| (dot + x * y, na + x * x, nb + y * y),
    );
    let denom = (norm_a_sq * norm_b_sq).sqrt();
    if denom == 0.0 {
        1.0
    } else {
        1.0 - dot / denom
    }
}
/// Reference Hamming distance: the number of positions whose `f32` bit
/// patterns differ, returned as an `f32`.
///
/// The comparison is on raw bits, so `0.0` and `-0.0` count as different.
/// Components are paired with `zip`, so only the first
/// `min(a.len(), b.len())` positions are compared.
fn ref_hamming(a: &[f32], b: &[f32]) -> f32 {
    let differing: u64 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| u64::from(lhs.to_bits() != rhs.to_bits()))
        .sum();
    differing as f32
}
fn independent_distance(metric: DistanceMetric, a: &[f32], b: &[f32]) -> f32 {
match metric {
DistanceMetric::Cosine => ref_cosine(a, b),
DistanceMetric::DotProduct => ref_dot_product(a, b),
DistanceMetric::Euclidean => ref_euclidean(a, b),
DistanceMetric::Manhattan => ref_manhattan(a, b),
DistanceMetric::Hamming => ref_hamming(a, b),
other => panic!("no hand-coded reference distance for {other:?}"),
}
}
/// Builds `n` rows of dimension `dim` from a fixed closed-form formula, so
/// every run sees the same data.
///
/// Row `i` gets the id `i as u64`, component `j` equal to
/// `sin(i * 17 + j * 31) + 0.5`, and no metadata.
fn deterministic_dataset(n: usize, dim: usize) -> Raw {
    (0..n)
        .map(|row_idx| {
            let components: Vec<f32> = (0..dim)
                .map(|col_idx| ((row_idx * 17 + col_idx * 31) as f32).sin() + 0.5)
                .collect();
            (VectorId::from(row_idx as u64), components, None)
        })
        .collect()
}
fn build_index(metric: DistanceMetric, dim: usize, raw: &Raw) -> FlatIndex {
let mut idx = FlatIndex::new(dim, metric, FlatConfig).unwrap();
for (id, vector, metadata) in raw {
idx.insert(id.clone(), arc(vector), metadata.clone())
.unwrap();
}
idx
}
/// Brute-force oracle: scores every row of `raw` against `query` and returns
/// the `k` best hits in ascending distance order.
///
/// For `DistanceMetric::DotProduct` the inner product is negated so that a
/// larger product sorts first. Ties on distance are broken by the row's
/// position in `raw`. Returns an empty vector when `k == 0`.
fn naive_topk(metric: DistanceMetric, query: &[f32], raw: &Raw, k: usize) -> Vec<Hit> {
    if k == 0 {
        return Vec::new();
    }
    let negate = matches!(metric, DistanceMetric::DotProduct);
    let mut ranked: Vec<(usize, f32)> = Vec::with_capacity(raw.len());
    for (pos, (_, vector, _)) in raw.iter().enumerate() {
        let score = independent_distance(metric, query, vector);
        ranked.push((pos, if negate { -score } else { score }));
    }
    // `total_cmp` gives a total order over floats; position makes the order deterministic.
    ranked.sort_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1).then_with(|| lhs.0.cmp(&rhs.0)));
    ranked
        .into_iter()
        .take(k)
        .map(|(pos, distance)| {
            let (id, _, metadata) = &raw[pos];
            Hit {
                id: id.clone(),
                distance,
                metadata: metadata.clone(),
            }
        })
        .collect()
}
/// Absolute tolerance used by `close_enough`.
const EPS_ABS: f32 = 1e-3;
/// Relative tolerance used by `close_enough`, scaled by the larger magnitude.
const EPS_REL: f32 = 1e-4;

/// Approximate float equality for comparing distances.
///
/// Finite values match when their difference is within `EPS_ABS`, or within
/// `EPS_REL` times the larger magnitude. If either value is non-finite
/// (infinity or NaN), the two must have identical bit patterns.
fn close_enough(x: f32, y: f32) -> bool {
    if x.is_finite() && y.is_finite() {
        let gap = (x - y).abs();
        let scale = x.abs().max(y.abs());
        gap <= EPS_ABS || gap <= EPS_REL * scale
    } else {
        x.to_bits() == y.to_bits()
    }
}
/// Asserts that two hit lists agree pairwise.
///
/// Lengths must match. For each pair the ids and metadata must be equal, and
/// the distances must satisfy `close_enough`. `left` is reported as `sut` and
/// `right` as `ref` in the distance failure message.
fn assert_hits_equal(left: &[Hit], right: &[Hit]) {
    assert_eq!(left.len(), right.len(), "different hit counts");
    for (sut, reference) in left.iter().zip(right) {
        assert_eq!(sut.id, reference.id, "id mismatch");
        let (got, want) = (sut.distance, reference.distance);
        assert!(
            close_enough(got, want),
            "distances disagree: sut={} ref={}",
            got,
            want,
        );
        assert_eq!(sut.metadata, reference.metadata, "metadata mismatch");
    }
}
/// Cross-checks `FlatIndex` search against the brute-force oracle for one metric.
///
/// Builds a 50-row, 16-dimensional deterministic dataset, runs a top-7 search
/// with a fixed query, and asserts that the hits match `naive_topk`.
///
/// # Panics
///
/// Panics if the index cannot be built, the search returns an error, or the
/// hit lists disagree.
fn check_metric(metric: DistanceMetric) {
    const N: usize = 50;
    const DIM: usize = 16;
    const K: usize = 7;
    let raw = deterministic_dataset(N, DIM);
    let idx = build_index(metric, DIM, &raw);
    // Fixed query, so the comparison is reproducible.
    let query: Vec<f32> = (0..DIM).map(|j| ((j as f32) * 0.37).cos()).collect();
    let params = SearchParams::new(K, metric);
    // Fix: the original read `¶ms`, a mis-decoded `&params`, which does not compile.
    let actual = idx.search(&query, &params).unwrap();
    let expected = naive_topk(metric, &query, &raw, K);
    assert_hits_equal(&actual, &expected);
}
/// Index search must agree with the brute-force oracle under the cosine metric.
#[test]
fn matches_naive_for_cosine() {
    let metric = DistanceMetric::Cosine;
    check_metric(metric);
}
/// Index search must agree with the brute-force oracle under the dot-product metric.
#[test]
fn matches_naive_for_dot_product() {
    let metric = DistanceMetric::DotProduct;
    check_metric(metric);
}
/// Index search must agree with the brute-force oracle under the Euclidean metric.
#[test]
fn matches_naive_for_euclidean() {
    let metric = DistanceMetric::Euclidean;
    check_metric(metric);
}
/// Index search must agree with the brute-force oracle under the Manhattan metric.
#[test]
fn matches_naive_for_manhattan() {
    let metric = DistanceMetric::Manhattan;
    check_metric(metric);
}
/// Index search must agree with the brute-force oracle under the Hamming metric.
#[test]
fn matches_naive_for_hamming() {
    let metric = DistanceMetric::Hamming;
    check_metric(metric);
}
/// Pins the dot-product convention: the reported distance is the negated inner
/// product, so the largest product ranks first.
///
/// Distances are compared by bit pattern, which also pins the `-0.0` result
/// for the orthogonal vector.
#[test]
fn dot_product_distance_is_negated_inner_product() {
    let dim = 3;
    let raw: Raw = vec![
        (VectorId::from(1u64), vec![1.0, 0.0, 0.0], None),
        (VectorId::from(2u64), vec![0.0, 1.0, 0.0], None),
        (VectorId::from(3u64), vec![10.0, 0.0, 0.0], None),
    ];
    let idx = build_index(DistanceMetric::DotProduct, dim, &raw);
    let query: Vec<f32> = vec![1.0, 0.0, 0.0];
    let params = SearchParams::new(3, DistanceMetric::DotProduct);
    let hits = idx.search(&query, &params).unwrap();

    // Expected (id, distance) per rank, best first.
    let expected: [(u64, f32); 3] = [(3, -10.0), (1, -1.0), (2, -0.0)];
    for (rank, (want_id, want_distance)) in expected.into_iter().enumerate() {
        assert_eq!(hits[rank].id, VectorId::U64(want_id));
        assert_eq!(hits[rank].distance.to_bits(), want_distance.to_bits());
    }
}