use crate::quantization::simd::portable::hamming_distance_slice as hamming_distance_portable_generic;
use std::arch::aarch64::{
vaddlvq_u8, vaddvq_f32, vcntq_u8, vdupq_n_f32, veorq_u8, vfmaq_f32, vld1q_f32, vld1q_u8,
vsubq_f32,
};
/// Returns the number of differing bits between two byte slices, computed
/// with NEON.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
#[inline]
#[must_use]
pub fn hamming_distance_slice(a: &[u8], b: &[u8]) -> u32 {
    assert_eq!(a.len(), b.len(), "Slice lengths must match");
    // SAFETY: equal lengths were just asserted. NEON availability is assumed;
    // it is enabled by default on standard AArch64 targets.
    unsafe {
        hamming_distance_neon_unchecked(a, b)
    }
}
/// NEON kernel behind [`hamming_distance_slice`].
///
/// Inputs are processed in 16-byte chunks. Each chunk is XORed, popcounted
/// per byte (`vcntq_u8`) and horizontally summed (`vaddlvq_u8`). The trailing
/// `len % 16` bytes are handled with scalar `count_ones`.
///
/// The running total is kept in a `u64`. It is saturated to `u32::MAX` on
/// conversion, so very large inputs (more than 2^32 differing bits) report
/// `u32::MAX` instead of a silently wrapped value.
///
/// # Safety
///
/// The CPU must support NEON. `a` and `b` are expected to have equal length.
/// If they do not, the chunk iterators stop at the shorter slice, so no
/// out-of-bounds loads occur, but the result is meaningless.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn hamming_distance_neon_unchecked(a: &[u8], b: &[u8]) -> u32 {
    debug_assert_eq!(a.len(), b.len(), "Slices must have equal length");
    let a_chunks = a.chunks_exact(16);
    let b_chunks = b.chunks_exact(16);
    // `remainder` borrows the underlying slice, not the iterator, so it can be
    // taken before the iterators are consumed below.
    let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());

    let mut count: u64 = 0;
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // Each chunk is exactly 16 bytes, i.e. one full 128-bit load.
        let xor = veorq_u8(vld1q_u8(ca.as_ptr()), vld1q_u8(cb.as_ptr()));
        // Per-byte popcounts are <= 8, so the widening sum of 16 lanes
        // (<= 128) cannot overflow its u16 result.
        count += u64::from(vaddlvq_u8(vcntq_u8(xor)));
    }
    count += a_tail
        .iter()
        .zip(b_tail)
        .map(|(x, y)| u64::from((x ^ y).count_ones()))
        .sum::<u64>();

    // Saturate rather than truncate: a wrapped count would look like a small
    // (i.e. "close") distance.
    u32::try_from(count).unwrap_or(u32::MAX)
}
/// Hamming distance between two fixed-size 96-byte (768-bit) codes.
///
/// Forwards to [`hamming_distance_slice`]. The array types guarantee equal
/// lengths, so its length assertion cannot fire here.
#[inline]
#[must_use]
pub fn hamming_distance(a: &[u8; 96], b: &[u8; 96]) -> u32 {
    let (lhs, rhs): (&[u8], &[u8]) = (a, b);
    hamming_distance_slice(lhs, rhs)
}
/// Exposes the scalar `portable` Hamming implementation under this module.
/// The tests use it as the reference when cross-checking the NEON kernel.
#[inline]
#[must_use]
pub fn hamming_distance_portable_ref(a: &[u8], b: &[u8]) -> u32 {
    hamming_distance_portable_generic(&a[..], &b[..])
}
/// Dot product of two `f32` slices, computed with NEON fused multiply-add.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
#[inline]
#[must_use]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Slice lengths must match");
    // SAFETY: equal lengths were just asserted. NEON availability is assumed;
    // it is enabled by default on standard AArch64 targets.
    unsafe {
        dot_product_neon_unchecked(a, b)
    }
}
/// NEON kernel behind [`dot_product`].
///
/// Products of 4-lane chunks are accumulated with `vfmaq_f32` into one
/// vector. That vector is reduced with `vaddvq_f32`, and the trailing
/// `len % 4` elements are added sequentially afterwards.
///
/// # Safety
///
/// The CPU must support NEON, and `a` and `b` must have equal length.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn dot_product_neon_unchecked(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Slices must have equal length");
    let lhs = a.chunks_exact(4);
    let rhs = b.chunks_exact(4);
    let (lhs_tail, rhs_tail) = (lhs.remainder(), rhs.remainder());

    let mut acc = vdupq_n_f32(0.0);
    for (x, y) in lhs.zip(rhs) {
        // Each chunk holds exactly four f32s, i.e. one 128-bit load.
        acc = vfmaq_f32(acc, vld1q_f32(x.as_ptr()), vld1q_f32(y.as_ptr()));
    }

    lhs_tail
        .iter()
        .zip(rhs_tail)
        .fold(vaddvq_f32(acc), |total, (x, y)| total + x * y)
}
/// Scalar dot product of two `f32` slices, summed sequentially.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length. This matches
/// [`dot_product`]. Previously this was only a `debug_assert!`, so in release
/// builds `zip` silently truncated to the shorter slice and returned a wrong
/// result.
#[inline]
#[must_use]
pub fn dot_product_portable(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
/// Euclidean (L2) distance between two `f32` slices, computed with NEON.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length.
#[inline]
#[must_use]
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Slice lengths must match");
    // SAFETY: equal lengths were just asserted. NEON availability is assumed;
    // it is enabled by default on standard AArch64 targets.
    unsafe {
        euclidean_distance_neon_unchecked(a, b)
    }
}
/// NEON kernel behind [`euclidean_distance`].
///
/// Squared differences of 4-lane chunks are accumulated with `vfmaq_f32`.
/// The vector is reduced with `vaddvq_f32`, the trailing `len % 4` squared
/// differences are added sequentially, and the square root of the total is
/// returned.
///
/// # Safety
///
/// The CPU must support NEON, and `a` and `b` must have equal length.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn euclidean_distance_neon_unchecked(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Slices must have equal length");
    let lhs = a.chunks_exact(4);
    let rhs = b.chunks_exact(4);
    let (lhs_tail, rhs_tail) = (lhs.remainder(), rhs.remainder());

    let mut squares = vdupq_n_f32(0.0);
    for (x, y) in lhs.zip(rhs) {
        let delta = vsubq_f32(vld1q_f32(x.as_ptr()), vld1q_f32(y.as_ptr()));
        squares = vfmaq_f32(squares, delta, delta);
    }

    let total = lhs_tail
        .iter()
        .zip(rhs_tail)
        .fold(vaddvq_f32(squares), |acc, (x, y)| {
            let delta = x - y;
            acc + delta * delta
        });
    total.sqrt()
}
/// Scalar Euclidean (L2) distance between two `f32` slices.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same length. This matches
/// [`euclidean_distance`]. Previously this was only a `debug_assert!`, so in
/// release builds `zip` silently ignored the excess elements of the longer
/// slice.
#[inline]
#[must_use]
pub fn euclidean_distance_portable(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Slice lengths must match");
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
        .sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for NEON-vs-portable float comparisons.
    ///
    /// It is relative (1e-4 of the larger magnitude) once either value exceeds
    /// 1.0, and an absolute 1e-4 otherwise. Lane-wise SIMD accumulation
    /// reorders additions, so bit-exact agreement with the sequential portable
    /// sum is not expected.
    fn tolerance(x: f32, y: f32) -> f32 {
        x.abs().max(y.abs()).max(1.0) * 1e-4
    }

    #[test]
    fn test_hamming_identical() {
        let a = [0xAA; 96];
        let b = [0xAA; 96];
        assert_eq!(hamming_distance(&a, &b), 0);
    }

    #[test]
    fn test_hamming_opposite() {
        let a = [0x00; 96];
        let b = [0xFF; 96];
        assert_eq!(hamming_distance(&a, &b), 768);
    }

    #[test]
    fn test_hamming_alternating() {
        let a = [0xAA; 96];
        let b = [0x55; 96];
        assert_eq!(hamming_distance(&a, &b), 768);
    }

    #[test]
    fn test_hamming_single_bit() {
        let mut a = [0x00; 96];
        let b = [0x00; 96];
        a[0] = 0x01;
        assert_eq!(hamming_distance(&a, &b), 1);
    }

    #[test]
    fn test_slice_empty() {
        let a: Vec<u8> = vec![];
        let b: Vec<u8> = vec![];
        assert_eq!(hamming_distance_slice(&a, &b), 0);
    }

    #[test]
    fn test_slice_single_byte() {
        let a = vec![0xFF];
        let b = vec![0x00];
        assert_eq!(hamming_distance_slice(&a, &b), 8);
    }

    #[test]
    fn test_slice_15_bytes_tail_only() {
        let a = vec![0xFF; 15];
        let b = vec![0x00; 15];
        assert_eq!(hamming_distance_slice(&a, &b), 120);
    }

    #[test]
    fn test_slice_16_bytes_exact_chunk() {
        let a = vec![0xFF; 16];
        let b = vec![0x00; 16];
        assert_eq!(hamming_distance_slice(&a, &b), 128);
    }

    #[test]
    fn test_slice_17_bytes_with_tail() {
        let a = vec![0xFF; 17];
        let b = vec![0x00; 17];
        assert_eq!(hamming_distance_slice(&a, &b), 136);
    }

    #[test]
    fn test_slice_32_bytes_two_chunks() {
        let a = vec![0xFF; 32];
        let b = vec![0x00; 32];
        assert_eq!(hamming_distance_slice(&a, &b), 256);
    }

    #[test]
    fn test_slice_100_bytes() {
        let a = vec![0xAA; 100];
        let b = vec![0x55; 100];
        assert_eq!(hamming_distance_slice(&a, &b), 800);
    }

    #[test]
    fn test_slice_identical() {
        let a = vec![42u8; 1000];
        let b = a.clone();
        assert_eq!(hamming_distance_slice(&a, &b), 0);
    }

    #[test]
    #[should_panic(expected = "Slice lengths must match")]
    fn test_slice_length_mismatch_panics() {
        let _ = hamming_distance_slice(&[0u8; 4], &[0u8; 3]);
    }

    #[test]
    fn test_slice_matches_portable() {
        for size in [0, 1, 15, 16, 17, 31, 32, 33, 64, 96, 100, 128, 1000] {
            let a: Vec<u8> = (0..size).map(|i| i as u8).collect();
            let b: Vec<u8> = (0..size).map(|i| (i + 1) as u8).collect();
            let neon_result = hamming_distance_slice(&a, &b);
            let portable_result = hamming_distance_portable_ref(&a, &b);
            assert_eq!(
                neon_result, portable_result,
                "NEON != Portable for size={}: {} != {}",
                size, neon_result, portable_result
            );
        }
    }

    #[test]
    fn test_slice_matches_fixed_96() {
        let a = [0xAA; 96];
        let b = [0x55; 96];
        assert_eq!(hamming_distance(&a, &b), hamming_distance_slice(&a, &b));
    }

    #[test]
    fn test_dot_product_basic() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [1.0, 1.0, 1.0, 1.0];
        let result = dot_product(&a, &b);
        assert!((result - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_zero() {
        let a = [1.0, 0.0, 1.0, 0.0];
        let b = [0.0, 1.0, 0.0, 1.0];
        let result = dot_product(&a, &b);
        assert!((result - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_empty() {
        let a: [f32; 0] = [];
        let b: [f32; 0] = [];
        let result = dot_product(&a, &b);
        assert!((result - 0.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "Slice lengths must match")]
    fn test_dot_product_length_mismatch_panics() {
        let _ = dot_product(&[1.0f32, 2.0, 3.0], &[1.0f32, 2.0]);
    }

    #[test]
    fn test_euclidean_identical() {
        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [1.0, 2.0, 3.0, 4.0];
        let result = euclidean_distance(&a, &b);
        assert!((result - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_unit() {
        let a = [0.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        let result = euclidean_distance(&a, &b);
        assert!((result - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_pythagoras() {
        let a = [0.0, 0.0];
        let b = [3.0, 4.0];
        let result = euclidean_distance(&a, &b);
        assert!((result - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_empty() {
        let a: [f32; 0] = [];
        let b: [f32; 0] = [];
        let result = euclidean_distance(&a, &b);
        assert!((result - 0.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "Slice lengths must match")]
    fn test_euclidean_length_mismatch_panics() {
        let _ = euclidean_distance(&[1.0f32, 2.0, 3.0], &[1.0f32, 2.0]);
    }

    #[test]
    fn test_dot_product_single_element() {
        let a = [5.0f32];
        let b = [3.0f32];
        let result = dot_product(&a, &b);
        assert!((result - 15.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_three_elements_tail() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let result = dot_product(&a, &b);
        assert!((result - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_four_elements_exact_chunk() {
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [4.0f32, 3.0, 2.0, 1.0];
        let result = dot_product(&a, &b);
        assert!((result - 20.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_five_elements_with_tail() {
        let a = [1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = [1.0f32, 1.0, 1.0, 1.0, 1.0];
        let result = dot_product(&a, &b);
        assert!((result - 15.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_large_768() {
        let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.001).collect();
        let b: Vec<f32> = vec![1.0; 768];
        let result = dot_product(&a, &b);
        let expected: f32 = (0..768).map(|i| (i as f32) * 0.001).sum();
        assert!(
            (result - expected).abs() < 0.01,
            "result={}, expected={}",
            result,
            expected
        );
    }

    #[test]
    fn test_dot_product_matches_portable() {
        for size in [0, 1, 3, 4, 5, 7, 8, 9, 100, 768, 1024] {
            let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
            let b: Vec<f32> = (0..size).map(|i| ((size - i) as f32) * 0.1).collect();
            let neon_result = dot_product(&a, &b);
            let portable_result = dot_product_portable(&a, &b);
            let abs_diff = (neon_result - portable_result).abs();
            let tol = tolerance(neon_result, portable_result);
            assert!(
                abs_diff < tol,
                "NEON != Portable for size={}: {} != {} (diff={}, tol={})",
                size,
                neon_result,
                portable_result,
                abs_diff,
                tol
            );
        }
    }

    #[test]
    fn test_euclidean_single_element() {
        let a = [5.0f32];
        let b = [3.0f32];
        let result = euclidean_distance(&a, &b);
        assert!((result - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_three_elements_tail() {
        let a = [0.0f32, 0.0, 0.0];
        let b = [1.0f32, 2.0, 2.0];
        let result = euclidean_distance(&a, &b);
        assert!((result - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_four_elements_exact_chunk() {
        let a = [0.0f32, 0.0, 0.0, 0.0];
        let b = [1.0f32, 1.0, 1.0, 1.0];
        let result = euclidean_distance(&a, &b);
        assert!((result - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_five_elements_with_tail() {
        let a = [0.0f32, 0.0, 0.0, 0.0, 0.0];
        let b = [1.0f32, 1.0, 1.0, 1.0, 1.0];
        let result = euclidean_distance(&a, &b);
        assert!((result - 5.0f32.sqrt()).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_large_768() {
        let a: Vec<f32> = vec![0.5; 768];
        let b = a.clone();
        let result = euclidean_distance(&a, &b);
        assert!(result < 1e-6, "Distance to self should be ~0");
    }

    #[test]
    fn test_euclidean_matches_portable() {
        for size in [0, 1, 3, 4, 5, 7, 8, 9, 100, 768, 1024] {
            let a: Vec<f32> = (0..size).map(|i| (i as f32) * 0.1).collect();
            let b: Vec<f32> = (0..size).map(|i| ((size - i) as f32) * 0.1).collect();
            let neon_result = euclidean_distance(&a, &b);
            let portable_result = euclidean_distance_portable(&a, &b);
            let abs_diff = (neon_result - portable_result).abs();
            let tol = tolerance(neon_result, portable_result);
            assert!(
                abs_diff < tol,
                "NEON != Portable for size={}: {} != {} (diff={}, tol={})",
                size,
                neon_result,
                portable_result,
                abs_diff,
                tol
            );
        }
    }
}