/// Computes the Euclidean (L2) distance between two `f32` slices.
///
/// The squared element-wise differences are accumulated from left to right in
/// a single `f32`, and the square root of that total is returned. Two empty
/// slices produce `0.0`.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
#[inline]
#[must_use]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    // An explicit fold seeded with +0.0 keeps the exact accumulation order of
    // a plain running-sum loop.
    let squared_total = a.iter().zip(b.iter()).fold(0.0_f32, |acc, (&x, &y)| {
        let delta = x - y;
        acc + delta * delta
    });
    squared_total.sqrt()
}
/// Computes the squared Euclidean (L2) distance between two `f32` slices.
///
/// This is the sum of squared element-wise differences, accumulated from left
/// to right in a single `f32`. No square root is taken. Two empty slices
/// produce `0.0`.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
#[inline]
#[must_use]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .map(|(&x, &y)| {
            let delta = x - y;
            delta * delta
        })
        // Seeded with +0.0 and added in order, like a plain running sum.
        .fold(0.0_f32, |total, squared| total + squared)
}
/// Computes the squared Euclidean (L2) distance between two `u8` slices.
///
/// Each element pair contributes the square of its absolute difference, and
/// the contributions are accumulated in a `u32`. Two empty slices produce `0`.
///
/// A single contribution is at most `255 * 255 = 65_025`, so the `u32`
/// accumulator cannot overflow for inputs of up to 66_051 elements. Longer
/// inputs may overflow the running sum.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length. Also panics if the
/// running sum overflows `u32` in a build with overflow checks enabled; with
/// overflow checks disabled the sum wraps instead.
#[inline]
#[must_use]
pub fn l2_squared_u8(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len());
    let mut sum: u32 = 0;
    for (&x, &y) in a.iter().zip(b) {
        // `abs_diff` keeps the arithmetic unsigned, so no signed intermediate
        // or sign-losing `as` cast is needed before squaring.
        let gap = u32::from(x.abs_diff(y));
        sum += gap * gap;
    }
    sum
}
/// Computes the dot product of two `u8` slices, accumulated in a `u32`.
///
/// Each pair of elements is widened to `u32` before being multiplied, so an
/// individual product (at most `255 * 255 = 65_025`) cannot overflow. Two
/// empty slices produce `0`.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
#[inline]
#[must_use]
pub fn dot_product_u8(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .fold(0_u32, |acc, (&x, &y)| acc + u32::from(x) * u32::from(y))
}
/// Unit tests for the scalar distance and dot-product functions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_squared_u8_scalar() {
        // (1-4)^2 + (2-2)^2 + (3-1)^2 = 9 + 0 + 4
        let lhs: [u8; 3] = [1, 2, 3];
        let rhs: [u8; 3] = [4, 2, 1];
        assert_eq!(l2_squared_u8(&lhs, &rhs), 13);
    }

    #[test]
    fn test_dot_product_u8_scalar() {
        // 1*4 + 2*2 + 3*1 = 4 + 4 + 3
        let lhs: [u8; 3] = [1, 2, 3];
        let rhs: [u8; 3] = [4, 2, 1];
        assert_eq!(dot_product_u8(&lhs, &rhs), 11);
    }

    #[test]
    fn test_overflow_protection() {
        // 1000 elements at the largest per-element difference: 1000 * 255^2.
        const LEN: usize = 1000;
        let all_max = [u8::MAX; LEN];
        let all_zero = [0_u8; LEN];
        assert_eq!(l2_squared_u8(&all_max, &all_zero), 65_025_000);
    }

    #[test]
    fn test_euclidean_distance_scalar() {
        // 3-4-5 right triangle.
        let origin = [0.0_f32, 0.0, 0.0];
        let point = [3.0_f32, 4.0, 0.0];
        let dist = euclidean_distance(&origin, &point);
        assert!((dist - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_single_element() {
        let dist = euclidean_distance(&[5.0_f32], &[3.0_f32]);
        assert!((dist - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_identical() {
        // Two separate inputs holding the same values.
        let first = [1.0_f32, 2.0, 3.0, 4.0];
        let second = [1.0_f32, 2.0, 3.0, 4.0];
        let dist = euclidean_distance(&first, &second);
        assert!(dist.abs() < 1e-6);
    }

    #[test]
    fn test_l2_squared_f32_scalar() {
        // Same values as the u8 case: 9 + 0 + 4.
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 2.0, 1.0];
        let dist = l2_squared(&lhs, &rhs);
        assert!((dist - 13.0).abs() < 1e-6);
    }
}