#![cfg(feature = "simd")]
use nodedb_vector::DistanceMetric;
use nodedb_vector::distance::distance;
/// Asserts that `distance(.., metric)` panics when its two inputs have
/// different lengths, checking BOTH argument orders.
///
/// A kernel that sized its loop from only one argument would pass a
/// one-order check yet still read past the shorter buffer in the other
/// order, so each metric is exercised as (long, short) and (short, long).
///
/// NOTE(review): the lengths 9 and 1 look chosen so that the long input
/// exceeds a typical f32 SIMD chunk while the short one is a lone tail
/// element — confirm against the kernel's actual lane width.
fn assert_rejects_mismatch(metric: DistanceMetric) {
    let long = vec![1.0f32; 9];
    let short = vec![1.0f32; 1];
    for (a, b) in [(&long[..], &short[..]), (&short[..], &long[..])] {
        let result = std::panic::catch_unwind(|| distance(a, b, metric));
        assert!(
            result.is_err(),
            "distance({metric:?}) must reject length mismatch (a.len()={a_len}, b.len()={b_len}) \
             instead of reading past the shorter buffer",
            a_len = a.len(),
            b_len = b.len(),
        );
    }
}
/// The `L2` metric must refuse inputs whose lengths differ.
#[test]
fn l2_rejects_length_mismatch() {
    let metric = DistanceMetric::L2;
    assert_rejects_mismatch(metric);
}
/// The `Cosine` metric must refuse inputs whose lengths differ.
#[test]
fn cosine_rejects_length_mismatch() {
    let metric = DistanceMetric::Cosine;
    assert_rejects_mismatch(metric);
}
/// The `InnerProduct` metric must refuse inputs whose lengths differ.
#[test]
fn inner_product_rejects_length_mismatch() {
    let metric = DistanceMetric::InnerProduct;
    assert_rejects_mismatch(metric);
}
/// The length check must not depend on which argument is longer: here the
/// FIRST slice is the short one (1 element) and the second is long (9).
#[test]
fn l2_rejects_swapped_mismatch() {
    let short = vec![1.0f32; 1];
    let long = vec![1.0f32; 9];
    let outcome = std::panic::catch_unwind(|| distance(&short, &long, DistanceMetric::L2));
    assert!(
        outcome.is_err(),
        "distance() must reject length mismatch in either argument order"
    );
}