#![cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
use std::arch::wasm32::*;
/// Squared Euclidean distance between `a` and `b`, computed with wasm `simd128`.
///
/// Elements are processed four at a time; any leftover elements (when the
/// length is not a multiple of four) are added with scalar arithmetic.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length.
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "wasm_simd128 l2: length mismatch");
    l2_impl(a, b)
}

/// SIMD kernel behind [`l2_squared`].
///
/// Both slices are walked in lock-step four-element chunks, so every vector
/// load is in bounds by construction and memory safety does not depend on the
/// caller. The result is only meaningful when the slices have equal length.
fn l2_impl(a: &[f32], b: &[f32]) -> f32 {
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    // Grab the tails up front; zipping below consumes the chunk iterators.
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    let mut acc = f32x4_splat(0.0);
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // SAFETY: `chunks_exact(4)` only yields slices of exactly four `f32`s,
        // so each pointer is valid for a 16-byte read, and `v128_load` places
        // no alignment requirement on its argument.
        let (va, vb) = unsafe {
            (
                v128_load(ca.as_ptr() as *const v128),
                v128_load(cb.as_ptr() as *const v128),
            )
        };
        let diff = f32x4_sub(va, vb);
        acc = f32x4_add(acc, f32x4_mul(diff, diff));
    }

    // Horizontal reduction, lanes summed in ascending order.
    let mut total = f32x4_extract_lane::<0>(acc)
        + f32x4_extract_lane::<1>(acc)
        + f32x4_extract_lane::<2>(acc)
        + f32x4_extract_lane::<3>(acc);

    // Scalar tail for the last `len % 4` elements.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let d = x - y;
        total += d * d;
    }
    total
}
pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "wasm_simd128 cosine: length mismatch");
unsafe { cosine_impl(a, b) }
}
unsafe fn cosine_impl(a: &[f32], b: &[f32]) -> f32 {
let n = a.len();
let chunks = n / 4;
let mut vdot = f32x4_splat(0.0);
let mut vna = f32x4_splat(0.0);
let mut vnb = f32x4_splat(0.0);
for i in 0..chunks {
let off = i * 4;
let va = unsafe { v128_load(a.as_ptr().add(off) as *const v128) };
let vb = unsafe { v128_load(b.as_ptr().add(off) as *const v128) };
vdot = f32x4_add(vdot, f32x4_mul(va, vb));
vna = f32x4_add(vna, f32x4_mul(va, va));
vnb = f32x4_add(vnb, f32x4_mul(vb, vb));
}
let mut dot = f32x4_extract_lane::<0>(vdot)
+ f32x4_extract_lane::<1>(vdot)
+ f32x4_extract_lane::<2>(vdot)
+ f32x4_extract_lane::<3>(vdot);
let mut na = f32x4_extract_lane::<0>(vna)
+ f32x4_extract_lane::<1>(vna)
+ f32x4_extract_lane::<2>(vna)
+ f32x4_extract_lane::<3>(vna);
let mut nb = f32x4_extract_lane::<0>(vnb)
+ f32x4_extract_lane::<1>(vnb)
+ f32x4_extract_lane::<2>(vnb)
+ f32x4_extract_lane::<3>(vnb);
for i in (chunks * 4)..n {
dot += a[i] * b[i];
na += a[i] * a[i];
nb += b[i] * b[i];
}
let denom = (na * nb).sqrt();
if denom < f32::EPSILON {
1.0
} else {
(1.0 - dot / denom).max(0.0)
}
}
/// Negated inner product `-(a · b)`, computed with wasm `simd128`.
///
/// Products are accumulated four elements at a time, with a scalar pass over
/// any leftover elements; the final sum is negated before being returned.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length.
pub fn neg_inner_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "wasm_simd128 ip: length mismatch");
    ip_impl(a, b)
}

/// SIMD kernel behind [`neg_inner_product`].
///
/// Both slices are walked in lock-step four-element chunks, so every vector
/// load is in bounds by construction and memory safety does not depend on the
/// caller. The result is only meaningful when the slices have equal length.
fn ip_impl(a: &[f32], b: &[f32]) -> f32 {
    let a_chunks = a.chunks_exact(4);
    let b_chunks = b.chunks_exact(4);
    // Grab the tails up front; zipping below consumes the chunk iterators.
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    let mut vdot = f32x4_splat(0.0);
    for (ca, cb) in a_chunks.zip(b_chunks) {
        // SAFETY: `chunks_exact(4)` only yields slices of exactly four `f32`s,
        // so each pointer is valid for a 16-byte read, and `v128_load` places
        // no alignment requirement on its argument.
        let (va, vb) = unsafe {
            (
                v128_load(ca.as_ptr() as *const v128),
                v128_load(cb.as_ptr() as *const v128),
            )
        };
        vdot = f32x4_add(vdot, f32x4_mul(va, vb));
    }

    // Horizontal reduction, lanes summed in ascending order.
    let mut dot = f32x4_extract_lane::<0>(vdot)
        + f32x4_extract_lane::<1>(vdot)
        + f32x4_extract_lane::<2>(vdot)
        + f32x4_extract_lane::<3>(vdot);

    // Scalar tail for the last `len % 4` elements.
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        dot += x * y;
    }
    -dot
}
#[cfg(target_arch = "wasm32")]
#[cfg(test)]
mod tests {
    // Compares the SIMD kernels against plain scalar references on inputs that
    // exercise full four-element chunks, a chunk plus a scalar tail, tail-only
    // inputs shorter than one vector, and empty slices.
    use super::*;

    /// Scalar reference for `l2_squared`.
    fn ref_l2(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
    }

    /// Scalar reference for `cosine_distance`, including its zero-norm fallback.
    fn ref_cosine(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let na: f32 = a.iter().map(|x| x * x).sum();
        let nb: f32 = b.iter().map(|x| x * x).sum();
        let denom = (na * nb).sqrt();
        if denom < f32::EPSILON {
            1.0
        } else {
            (1.0 - dot / denom).max(0.0)
        }
    }

    /// Scalar reference for `neg_inner_product`.
    fn ref_nip(a: &[f32], b: &[f32]) -> f32 {
        -(a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>())
    }

    /// Fails with a labelled message unless `got` is within `tol` of `want`.
    fn assert_close(got: f32, want: f32, tol: f32, label: &str) {
        assert!((got - want).abs() < tol, "{label}: got={got}, want={want}");
    }

    // Sixteen elements: four full SIMD chunks, no tail.
    const A16: [f32; 16] = [
        0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6,
    ];
    const B16: [f32; 16] = [
        1.6, 1.5, 1.4, 1.3, 1.2, 1.1, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1,
    ];

    #[test]
    fn l2_full_chunks() {
        assert_close(l2_squared(&A16, &B16), ref_l2(&A16, &B16), 1e-4, "l2 full");
    }

    #[test]
    fn cosine_full_chunks() {
        assert_close(
            cosine_distance(&A16, &B16),
            ref_cosine(&A16, &B16),
            1e-5,
            "cosine full",
        );
    }

    #[test]
    fn nip_full_chunks() {
        assert_close(
            neg_inner_product(&A16, &B16),
            ref_nip(&A16, &B16),
            1e-4,
            "nip full",
        );
    }

    // Seven elements: one SIMD chunk followed by a three-element scalar tail.
    const A7: [f32; 7] = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5];
    const B7: [f32; 7] = [3.5, 3.0, 2.5, 2.0, 1.5, 1.0, 0.5];

    #[test]
    fn l2_tail() {
        assert_close(l2_squared(&A7, &B7), ref_l2(&A7, &B7), 1e-4, "l2 tail");
    }

    #[test]
    fn cosine_tail() {
        assert_close(
            cosine_distance(&A7, &B7),
            ref_cosine(&A7, &B7),
            1e-5,
            "cosine tail",
        );
    }

    #[test]
    fn nip_tail() {
        assert_close(
            neg_inner_product(&A7, &B7),
            ref_nip(&A7, &B7),
            1e-4,
            "nip tail",
        );
    }

    #[test]
    fn cosine_zero_norm_returns_one() {
        let z = [0.0f32; 8];
        let a = [1.0f32, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        assert_eq!(cosine_distance(&z, &a), 1.0);
        assert_eq!(cosine_distance(&a, &z), 1.0);
    }

    // Inputs shorter than one SIMD vector: only the scalar tail runs. The
    // values are small integers, so the expected results are exact in `f32`.
    #[test]
    fn shorter_than_one_vector() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        assert_eq!(l2_squared(&a, &b), 27.0);
        assert_eq!(neg_inner_product(&a, &b), -32.0);

        let v = [3.0f32, 4.0];
        assert_eq!(cosine_distance(&v, &v), 0.0);
    }

    // Empty slices: neither the SIMD loop nor the tail loop executes.
    #[test]
    fn empty_inputs() {
        let e: [f32; 0] = [];
        assert_eq!(l2_squared(&e, &e), 0.0);
        assert_eq!(neg_inner_product(&e, &e), 0.0);
        assert_eq!(cosine_distance(&e, &e), 1.0);
    }
}