use super::Metric;
/// Squared Euclidean (L2) distance metric over `f32` vectors.
///
/// The distance is the sum of squared element-wise differences, with no final
/// square root. Because the square root is monotonic on non-negative values,
/// comparing squared distances orders points the same way as comparing true
/// Euclidean distances.
///
/// This is a stateless, zero-sized marker type; the computation lives in its
/// [`Metric<f32>`](Metric) implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct L2Squared;
impl Metric<f32> for L2Squared {
    /// Returns the squared Euclidean distance between `a` and `b`: the sum of
    /// `(a[i] - b[i])^2` over all indices, with no square root taken.
    ///
    /// The implementation is selected at compile time:
    /// - `wasm32` with `simd128`: inputs of length >= 16 go to
    ///   `super::simd::wasm::l2_squared`.
    /// - `x86_64` with `avx2`: inputs of length >= 256 go to
    ///   `super::simd::x86::l2_squared`.
    /// - Shorter inputs, and all inputs on other targets, use a scalar loop
    ///   that sums left to right starting from `0.0`.
    ///
    /// # Panics
    ///
    /// - If `a.len() != b.len()`.
    /// - If the computed distance is NaN. This covers a NaN in either input,
    ///   and also a pair where `a[i]` and `b[i]` are the same infinity
    ///   (`inf - inf` is NaN). The check runs on the result of every code
    ///   path, so whether an input panics does not depend on the target or
    ///   on the input length.
    #[inline]
    fn distance(a: &[f32], b: &[f32]) -> f32 {
        // Scalar fallback shared by every dispatch branch. Accumulates in
        // index order from +0.0, so an empty input yields +0.0.
        fn scalar_l2_squared(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b).fold(0.0_f32, |acc, (x, y)| {
                let diff = x - y;
                acc + diff * diff
            })
        }

        assert_eq!(
            a.len(),
            b.len(),
            "dimension mismatch: {} != {}",
            a.len(),
            b.len()
        );

        // Exactly one of the three `let result` bindings below is compiled:
        // the two SIMD predicates name different `target_arch` values, and
        // the last one is the negation of both.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        let result = {
            // Inputs shorter than this stay on the scalar loop.
            // NOTE(review): cutoff presumably benchmark-derived — confirm.
            const MIN_SIMD_LEN: usize = 16;
            if a.len() < MIN_SIMD_LEN {
                scalar_l2_squared(a, b)
            } else {
                super::simd::wasm::l2_squared(a, b)
            }
        };

        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        let result = {
            // Inputs shorter than this stay on the scalar loop.
            // NOTE(review): cutoff presumably benchmark-derived — confirm.
            const MIN_SIMD_LEN: usize = 256;
            if a.len() < MIN_SIMD_LEN {
                scalar_l2_squared(a, b)
            } else {
                super::simd::x86::l2_squared(a, b)
            }
        };

        #[cfg(not(any(
            all(target_arch = "wasm32", target_feature = "simd128"),
            all(target_arch = "x86_64", target_feature = "avx2")
        )))]
        let result = scalar_l2_squared(a, b);

        // One uniform check instead of per-element input checks: NaN
        // propagates through `-`, `*` and `+`, so any NaN input makes the sum
        // NaN. This assumes the SIMD kernels propagate NaN the same way; the
        // previous code relied on that too.
        assert!(!result.is_nan(), "NaN detected in input");
        result
    }
}