use iqdb::{DistanceMetric, Iqdb, Record, RecordId, Vector};
use proptest::collection::vec as prop_vec;
use proptest::prelude::*;
/// Strategy yielding `f32` samples from the half-open range `[-100, 100)`,
/// filtered so that only finite values are ever produced.
fn finite_f32() -> impl Strategy<Value = f32> {
    let bounded = -100.0_f32..100.0_f32;
    bounded.prop_filter("must be finite", |candidate| candidate.is_finite())
}
/// Strategy yielding vectors of exactly `dim` components, each drawn from
/// [`finite_f32`].
fn vector_strategy(dim: usize) -> impl Strategy<Value = Vec<f32>> {
    // Inclusive range with equal bounds pins the length to `dim`.
    let exact_len = dim..=dim;
    prop_vec(finite_f32(), exact_len)
}
proptest! {
    /// Comparing a vector with an identical copy of itself must land at the
    /// "closest" end of every metric.
    ///
    /// L2 and cosine distance must be ~0 (within 1e-3). For `Dot` the test
    /// only requires a score no greater than 1e-3. NOTE(review): this
    /// presumably relies on the crate reporting negated dot products as
    /// distances — confirm against `DistanceMetric::Dot`.
    #[test]
    fn identity_yields_minimal_distance(components in vector_strategy(4)) {
        // Skip (near-)zero vectors, for which cosine distance is ill-defined.
        prop_assume!(components.iter().any(|x| x.abs() > 1e-3));
        let a = Vector::new(components.clone()).expect("finite");
        let b = Vector::new(components).expect("finite");
        let l2 = DistanceMetric::L2.distance(&a, &b).expect("ok");
        prop_assert!(l2.abs() < 1e-3, "l2={}", l2);
        let cos = DistanceMetric::Cosine.distance(&a, &b).expect("ok");
        prop_assert!(cos.abs() < 1e-3, "cos={}", cos);
        let dot = DistanceMetric::Dot.distance(&a, &b).expect("ok");
        prop_assert!(dot <= 1e-3, "dot={}", dot);
    }

    /// Every metric returns the same value regardless of argument order.
    ///
    /// If either direction yields NaN, both must; otherwise the two results
    /// must agree within 1e-3.
    #[test]
    fn distance_is_symmetric(
        a_components in vector_strategy(4),
        b_components in vector_strategy(4),
    ) {
        let a = Vector::new(a_components).expect("finite");
        let b = Vector::new(b_components).expect("finite");
        for metric in [DistanceMetric::L2, DistanceMetric::Cosine, DistanceMetric::Dot] {
            let forward = metric.distance(&a, &b).expect("ok");
            let backward = metric.distance(&b, &a).expect("ok");
            if forward.is_nan() || backward.is_nan() {
                prop_assert_eq!(forward.is_nan(), backward.is_nan(), "{:?}", metric);
            } else {
                prop_assert!(
                    (forward - backward).abs() < 1e-3,
                    "{:?}: forward={}, backward={}",
                    metric, forward, backward,
                );
            }
        }
    }

    /// For finite inputs, L2 distance is itself finite and never negative.
    #[test]
    fn l2_distance_is_non_negative(
        a_components in vector_strategy(4),
        b_components in vector_strategy(4),
    ) {
        let a = Vector::new(a_components).expect("finite");
        let b = Vector::new(b_components).expect("finite");
        let d = DistanceMetric::L2.distance(&a, &b).expect("ok");
        prop_assert!(d.is_finite(), "l2 must be finite for finite inputs: {}", d);
        prop_assert!(d >= 0.0, "l2 must be non-negative: {}", d);
    }

    /// For vectors with non-negligible norm, cosine distance lies in
    /// `[0, 2]`, allowing 1e-3 of slack at each end for rounding.
    #[test]
    fn cosine_distance_is_in_zero_two_range(
        a_components in vector_strategy(4),
        b_components in vector_strategy(4),
    ) {
        let a = Vector::new(a_components).expect("finite");
        let b = Vector::new(b_components).expect("finite");
        prop_assume!(a.norm() > 1e-3 && b.norm() > 1e-3);
        let d = DistanceMetric::Cosine.distance(&a, &b).expect("ok");
        prop_assert!((-1e-3..=2.0 + 1e-3).contains(&d), "cosine out of range: {}", d);
    }

    /// `search` never returns more than `k` hits, nor more hits than there
    /// are stored records.
    ///
    /// Both an empty database and `k == 0` are in the generated domain.
    #[test]
    fn search_length_is_bounded(
        records in prop_vec(vector_strategy(4), 0..=16),
        k in 0_usize..20,
    ) {
        let db = Iqdb::open_in_memory();
        for (id, components) in records.iter().enumerate() {
            let v = Vector::new(components.clone()).expect("finite");
            db.upsert(Record::new(RecordId::new(id as u64), v)).expect("ok");
        }
        let probe = Vector::new(vec![0.5, 0.5, 0.5, 0.5]).expect("finite");
        let hits = db.search(&probe, k, DistanceMetric::L2).expect("ok");
        prop_assert!(hits.len() <= k);
        prop_assert!(hits.len() <= records.len());
    }

    /// Hits come back in non-decreasing score order.
    ///
    /// Any NaN scores must trail every non-NaN score.
    #[test]
    fn search_results_are_sorted_ascending(
        records in prop_vec(vector_strategy(4), 1..=16),
        k in 1_usize..20,
    ) {
        let db = Iqdb::open_in_memory();
        for (id, components) in records.iter().enumerate() {
            let v = Vector::new(components.clone()).expect("finite");
            db.upsert(Record::new(RecordId::new(id as u64), v)).expect("ok");
        }
        let probe = Vector::new(vec![0.5, 0.5, 0.5, 0.5]).expect("finite");
        let hits = db.search(&probe, k, DistanceMetric::L2).expect("ok");
        for pair in hits.windows(2) {
            let (lo, hi) = (pair[0].score, pair[1].score);
            if lo.is_nan() {
                prop_assert!(hi.is_nan(), "non-NaN after NaN: lo={}, hi={}", lo, hi);
            } else {
                // Also fails when `hi` is NaN but `lo` is not.
                prop_assert!(
                    lo <= hi,
                    "results out of order: lo={}, hi={}",
                    lo, hi,
                );
            }
        }
    }

    /// A record identical to the probe must be the single nearest neighbour
    /// under L2.
    ///
    /// The search uses `k = 1` on purpose. With `k` equal to the total record
    /// count, every record is returned and the ranking is never exercised.
    /// Cases where a decoy lies within 1e-3 of the target are rejected,
    /// because the nearest neighbour would then be ambiguous.
    #[test]
    fn perfect_match_is_always_in_top_k(
        components in vector_strategy(4),
        decoys in prop_vec(vector_strategy(4), 0..=10),
    ) {
        let probe = Vector::new(components.clone()).expect("finite");
        let db = Iqdb::open_in_memory();
        for (i, decoy) in decoys.iter().enumerate() {
            let v = Vector::new(decoy.clone()).expect("finite");
            let gap = DistanceMetric::L2.distance(&probe, &v).expect("ok");
            // A coincident decoy would tie with the perfect match; skip the case.
            prop_assume!(gap > 1e-3);
            db.upsert(Record::new(RecordId::new(i as u64), v)).expect("ok");
        }
        // Decoys occupy ids 0..decoys.len(), so this id is guaranteed unused.
        let perfect_id = decoys.len() as u64;
        db.upsert(Record::new(
            RecordId::new(perfect_id),
            Vector::new(components).expect("finite"),
        )).expect("ok");
        let k = 1;
        let hits = db.search(&probe, k, DistanceMetric::L2).expect("ok");
        prop_assert!(
            hits.iter().any(|h| h.id == RecordId::new(perfect_id)),
            "perfect match id={} missing from top-{}",
            perfect_id, k,
        );
    }

    /// A filter that accepts every record must yield the same ids, in the
    /// same order, as the unfiltered search.
    #[test]
    fn unfiltered_matches_always_true_filter(
        records in prop_vec(vector_strategy(4), 1..=16),
        k in 1_usize..10,
    ) {
        let db = Iqdb::open_in_memory();
        for (id, components) in records.iter().enumerate() {
            let v = Vector::new(components.clone()).expect("finite");
            db.upsert(Record::new(RecordId::new(id as u64), v)).expect("ok");
        }
        let probe = Vector::new(vec![0.0, 0.0, 0.0, 0.0]).expect("finite");
        let unfiltered = db.search(&probe, k, DistanceMetric::L2).expect("ok");
        let always_true = db
            .search_with(&probe, k, DistanceMetric::L2, |_| true)
            .expect("ok");
        let ids_a: Vec<u64> = unfiltered.iter().map(|h| h.id.get()).collect();
        let ids_b: Vec<u64> = always_true.iter().map(|h| h.id.get()).collect();
        prop_assert_eq!(ids_a, ids_b);
    }
}