/// AArch64 kernels built on the `std::arch::aarch64` NEON intrinsics.
#[cfg(target_arch = "aarch64")]
mod neon {
    use std::arch::aarch64::*;

    /// Fills `out` with one dot product per matrix row.
    ///
    /// Row `r` is the slice `matrix[r * dim..(r + 1) * dim]`; it is dotted with the
    /// first `dim` elements of `vector` and the result is stored in `out[r]`.
    ///
    /// # Panics
    /// Panics if a row slice would run past the end of `matrix`. Debug builds
    /// additionally assert both length preconditions up front.
    ///
    /// # Safety
    /// `vector` must contain at least `dim` elements: [`dot`] reads it through raw
    /// pointers and the length is only asserted in debug builds.
    #[inline]
    pub(super) unsafe fn matvec(matrix: &[f32], vector: &[f32], dim: usize, out: &mut [f32]) {
        debug_assert!(matrix.len() >= out.len() * dim, "matrix too short for matvec");
        debug_assert!(vector.len() >= dim, "vector too short for matvec");

        for (r, slot) in out.iter_mut().enumerate() {
            let row = &matrix[r * dim..(r + 1) * dim];
            *slot = dot(row, vector, dim);
        }
    }

    /// Returns the dot product of the first `dim` elements of `a` and `b`.
    ///
    /// Whole groups of four elements are accumulated in a 128-bit register and
    /// reduced with a horizontal add; the remaining `dim % 4` elements are summed
    /// separately in scalar code and added to the vector result at the end.
    ///
    /// # Safety
    /// Both `a` and `b` must contain at least `dim` elements. The vector loop loads
    /// through raw pointers without bounds checks, and the lengths are only
    /// asserted in debug builds.
    #[inline]
    pub(super) unsafe fn dot(a: &[f32], b: &[f32], dim: usize) -> f32 {
        debug_assert!(a.len() >= dim, "a too short for dot");
        debug_assert!(b.len() >= dim, "b too short for dot");

        let lhs = a.as_ptr();
        let rhs = b.as_ptr();
        // Number of complete 4-lane groups; everything past them is the tail.
        let full_blocks = dim / 4;

        let mut lanes = vdupq_n_f32(0.0);
        for block in 0..full_blocks {
            let offset = block * 4;
            let x = vld1q_f32(lhs.add(offset));
            let y = vld1q_f32(rhs.add(offset));
            lanes = vmlaq_f32(lanes, x, y);
        }
        let vector_part = vaddvq_f32(lanes);

        // The tail keeps its own accumulator so the summation order is fixed.
        let mut scalar_part = 0.0f32;
        for k in full_blocks * 4..dim {
            scalar_part += a[k] * b[k];
        }
        vector_part + scalar_part
    }
}
/// x86_64 kernels built on the AVX2 and FMA intrinsics from `std::arch::x86_64`.
#[cfg(target_arch = "x86_64")]
mod avx2 {
    use std::arch::x86_64::*;

    /// Fills `out` with one dot product per matrix row.
    ///
    /// Row `r` is the slice `matrix[r * dim..(r + 1) * dim]`; it is dotted with the
    /// first `dim` elements of `vector` and the result is stored in `out[r]`.
    ///
    /// # Panics
    /// Panics if a row slice would run past the end of `matrix`, or if `vector`
    /// holds fewer than `dim` elements.
    ///
    /// # Safety
    /// The running CPU must support both AVX2 and FMA.
    #[inline]
    #[target_feature(enable = "avx2,fma")]
    pub(super) unsafe fn matvec(matrix: &[f32], vector: &[f32], dim: usize, out: &mut [f32]) {
        debug_assert!(matrix.len() >= out.len() * dim, "matrix too short for matvec");
        debug_assert!(vector.len() >= dim, "vector too short for matvec");

        for (r, slot) in out.iter_mut().enumerate() {
            let row = &matrix[r * dim..(r + 1) * dim];
            *slot = dot(row, vector, dim);
        }
    }

    /// Returns the dot product of the first `dim` elements of `a` and `b`.
    ///
    /// Whole groups of eight elements go through a fused multiply-add into a
    /// 256-bit accumulator, which is then reduced horizontally; the remaining
    /// `dim % 8` elements are added one by one in scalar code.
    ///
    /// # Panics
    /// Panics if `a` or `b` holds fewer than `dim` elements.
    ///
    /// # Safety
    /// The running CPU must support both AVX2 and FMA. `fma` is listed in
    /// `target_feature` because `_mm256_fmadd_ps` requires it: without it the
    /// intrinsic cannot be inlined into this loop and the feature contract of
    /// this function would be incomplete.
    #[inline]
    #[target_feature(enable = "avx2,fma")]
    pub(super) unsafe fn dot(a: &[f32], b: &[f32], dim: usize) -> f32 {
        debug_assert!(a.len() >= dim, "a too short for dot");
        debug_assert!(b.len() >= dim, "b too short for dot");
        // Re-slice so the length requirement also holds in release builds, where
        // the debug assertions above are compiled out. From here on both slices
        // are exactly `dim` long.
        let (a, b) = (&a[..dim], &b[..dim]);

        let mut acc = _mm256_setzero_ps();
        let mut i = 0usize;
        while i + 8 <= dim {
            // SAFETY: `i + 8 <= dim == a.len() == b.len()`, so both unaligned
            // 8-float loads stay inside their slices.
            let va = _mm256_loadu_ps(a.as_ptr().add(i));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i));
            acc = _mm256_fmadd_ps(va, vb, acc);
            i += 8;
        }

        // Horizontal reduction: fold 8 lanes to 4, then 4 to 2, then 2 to 1.
        let lo = _mm256_castps256_ps128(acc);
        let hi = _mm256_extractf128_ps(acc, 1);
        let sum128 = _mm_add_ps(lo, hi);
        let shuf = _mm_movehdup_ps(sum128);
        let sums = _mm_add_ps(sum128, shuf);
        let shuf2 = _mm_movehl_ps(shuf, sums);
        let total = _mm_add_ss(sums, shuf2);

        // Scalar tail for the elements that did not fill a whole register.
        let mut scalar = _mm_cvtss_f32(total);
        for k in i..dim {
            scalar += a[k] * b[k];
        }
        scalar
    }
}
/// Portable fallback kernels that use no architecture-specific intrinsics.
mod scalar {
    /// Fills `out` with one dot product per matrix row.
    ///
    /// Row `r` is the slice `matrix[r * dim..(r + 1) * dim]`; it is passed to
    /// [`dot`] together with `vector` and the result is stored in `out[r]`.
    ///
    /// # Panics
    /// Panics if a row slice would run past the end of `matrix`.
    #[inline]
    pub(super) fn matvec(matrix: &[f32], vector: &[f32], dim: usize, out: &mut [f32]) {
        for (r, slot) in out.iter_mut().enumerate() {
            *slot = dot(&matrix[r * dim..(r + 1) * dim], vector);
        }
    }

    /// Returns the dot product of `a` and `b` over their common prefix.
    ///
    /// Only the first `min(a.len(), b.len())` elements take part. Products are
    /// spread over four independent running sums (element `i` goes to sum
    /// `i % 4` while whole groups of four remain); the four sums are then added
    /// left to right and the leftover elements are added last, in order.
    #[inline]
    pub(super) fn dot(a: &[f32], b: &[f32]) -> f32 {
        let shared = a.len().min(b.len());
        let (lhs, rhs) = (&a[..shared], &b[..shared]);

        let lhs_groups = lhs.chunks_exact(4);
        let rhs_groups = rhs.chunks_exact(4);
        // The remainders are fixed when the iterators are created, so grab them
        // before the iterators are consumed below.
        let lhs_tail = lhs_groups.remainder();
        let rhs_tail = rhs_groups.remainder();

        let mut lanes = [0.0f32; 4];
        for (xs, ys) in lhs_groups.zip(rhs_groups) {
            for (lane, (x, y)) in lanes.iter_mut().zip(xs.iter().zip(ys)) {
                *lane += x * y;
            }
        }

        let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];
        for (x, y) in lhs_tail.iter().zip(rhs_tail) {
            total += x * y;
        }
        total
    }
}
/// Computes the matrix–vector product of a row-major `f32` matrix and a vector.
///
/// `out.len()` is taken as the number of rows; row `r` of `matrix` occupies
/// `matrix[r * dim..(r + 1) * dim]` and `out[r]` receives its dot product with
/// `vector`. The work is dispatched to the NEON kernel on aarch64, to the AVX2
/// kernel on x86_64 when the CPU reports both `avx2` and `fma` at runtime, and
/// to the portable scalar kernel otherwise.
///
/// When `dim` is zero every row is empty, so every element of `out` is set to
/// `0.0` (the empty sum) instead of being left with whatever it held before.
///
/// # Panics
/// Panics if `matrix.len() != out.len() * dim` or if `vector.len() != dim`.
#[inline]
pub fn matvec_multiply(matrix: &[f32], vector: &[f32], dim: usize, out: &mut [f32]) {
    let rows = out.len();
    assert_eq!(
        matrix.len(),
        rows * dim,
        "matrix must be rows×dim = {rows}×{dim} = {}; got {}",
        rows * dim,
        matrix.len()
    );
    assert_eq!(
        vector.len(),
        dim,
        "vector length must be dim={dim}; got {}",
        vector.len()
    );

    // Degenerate shapes need no kernel. With `dim == 0` each result is an empty
    // sum, so the outputs must still be overwritten with zeros; with
    // `rows == 0` the fill is a no-op.
    if dim == 0 || rows == 0 {
        out.fill(0.0);
        return;
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: `vector.len() == dim` and `matrix.len() == rows * dim` were
        // asserted above, which is what the kernel's raw-pointer loads rely on.
        unsafe { neon::matvec(matrix, vector, dim, out) }
        return;
    }
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: the required CPU features were detected on the line above
            // and the slice lengths were asserted at the top of the function.
            unsafe { avx2::matvec(matrix, vector, dim, out) }
            return;
        }
    }
    scalar::matvec(matrix, vector, dim, out);
}
/// Returns the dot product of `a` and `b` over their common prefix.
///
/// Only the first `min(a.len(), b.len())` elements of each slice take part; if
/// either slice is empty the result is `0.0`. The work is dispatched to the
/// NEON kernel on aarch64, to the AVX2 kernel on x86_64 when the CPU reports
/// both `avx2` and `fma` at runtime, and to the portable scalar kernel
/// otherwise.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    // An empty operand means an empty overlap, hence an empty sum.
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }
    let overlap = std::cmp::min(a.len(), b.len());

    #[cfg(target_arch = "aarch64")]
    {
        return unsafe { neon::dot(a, b, overlap) };
    }
    #[cfg(target_arch = "x86_64")]
    {
        let simd_available = is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma");
        if simd_available {
            return unsafe { avx2::dot(a, b, overlap) };
        }
    }
    scalar::dot(a, b)
}
/// Unit tests comparing the dispatched kernels against straightforward
/// iterator-based reference implementations.
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference dot product: pairwise products summed in order.
    fn scalar_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Reference matrix–vector product for a row-major matrix with `dim` columns.
    fn scalar_matvec(matrix: &[f32], vector: &[f32], dim: usize) -> Vec<f32> {
        let rows = matrix.len() / dim;
        (0..rows)
            .map(|i| {
                let row = &matrix[i * dim..(i + 1) * dim];
                scalar_dot(row, vector)
            })
            .collect()
    }

    #[test]
    fn test_dot_product_matches_scalar() {
        let a: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (0..64).map(|i| (i as f32 * 0.7).sin()).collect();
        let expected = scalar_dot(&a, &b);
        let got = dot_product(&a, &b);
        assert!(
            (got - expected).abs() < 1e-3,
            "dot_product mismatch: expected={expected}, got={got}"
        );
    }

    #[test]
    fn test_dot_product_empty() {
        assert_eq!(dot_product(&[], &[]), 0.0);
    }

    #[test]
    fn test_dot_product_single() {
        assert!((dot_product(&[3.0], &[4.0]) - 12.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_all_ones() {
        let a = vec![1.0f32; 128];
        let b = vec![1.0f32; 128];
        let result = dot_product(&a, &b);
        assert!((result - 128.0).abs() < 1e-4);
    }

    #[test]
    fn test_dot_product_orthogonal() {
        // Alternating +1/-1 against all ones: the products cancel pairwise.
        let a: Vec<f32> = (0..64).map(|i| if i % 2 == 0 { 1.0 } else { -1.0 }).collect();
        let b = vec![1.0f32; 64];
        let expected = scalar_dot(&a, &b);
        let got = dot_product(&a, &b);
        assert!((got - expected).abs() < 1e-4);
    }

    #[test]
    fn test_dot_product_unaligned_length() {
        // 19 is not a multiple of 4 or 8, so every backend has to run its scalar
        // tail loop. All values are small integers, so the sum is exact in any
        // summation order: 2 * (0 + 1 + ... + 18) = 342.
        let a: Vec<f32> = (0..19).map(|i| i as f32).collect();
        let b = vec![2.0f32; 19];
        assert_eq!(dot_product(&a, &b), 342.0);
    }

    #[test]
    fn test_matvec_multiply_matches_scalar() {
        let dim = 16usize;
        let rows = 8usize;
        let matrix: Vec<f32> = (0..(rows * dim)).map(|i| (i as f32 * 0.1).sin()).collect();
        let vector: Vec<f32> = (0..dim).map(|i| i as f32 * 0.05).collect();
        let expected = scalar_matvec(&matrix, &vector, dim);
        let mut got = vec![0.0f32; rows];
        matvec_multiply(&matrix, &vector, dim, &mut got);
        for (i, (e, g)) in expected.iter().zip(got.iter()).enumerate() {
            assert!(
                (e - g).abs() < 1e-3,
                "matvec mismatch at row {i}: expected={e}, got={g}"
            );
        }
    }

    #[test]
    fn test_matvec_multiply_identity() {
        let dim = 4usize;
        #[rustfmt::skip]
        let matrix = vec![
            1.0_f32, 0.0, 0.0, 0.0,
            0.0, 1.0, 0.0, 0.0,
            0.0, 0.0, 1.0, 0.0,
            0.0, 0.0, 0.0, 1.0,
        ];
        let vector = vec![1.0_f32, 2.0, 3.0, 4.0];
        let mut out = vec![0.0f32; 4];
        matvec_multiply(&matrix, &vector, dim, &mut out);
        for (i, (&o, &v)) in out.iter().zip(vector.iter()).enumerate() {
            assert!((o - v).abs() < 1e-6, "identity multiply failed at index {i}");
        }
    }

    #[test]
    fn test_matvec_multiply_zero_vector() {
        let dim = 8usize;
        let matrix: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let vector = vec![0.0f32; dim];
        // Pre-fill with a sentinel so a kernel that skips the write is caught.
        let mut out = vec![99.0f32; 2];
        matvec_multiply(&matrix, &vector, dim, &mut out);
        for &v in &out {
            assert!(v.abs() < 1e-6);
        }
    }

    #[test]
    fn test_matvec_multiply_larger() {
        let dim = 128usize;
        let rows = 32usize;
        let matrix: Vec<f32> = (0..(rows * dim)).map(|i| ((i as f32) * 0.01).cos()).collect();
        let vector: Vec<f32> = (0..dim).map(|j| (j as f32 * 0.03).sin()).collect();
        let expected = scalar_matvec(&matrix, &vector, dim);
        let mut got = vec![0.0f32; rows];
        matvec_multiply(&matrix, &vector, dim, &mut got);
        for (i, (e, g)) in expected.iter().zip(got.iter()).enumerate() {
            assert!(
                (e - g).abs() < 1e-2,
                "large matvec mismatch at row {i}: expected={e}, got={g}"
            );
        }
    }

    #[test]
    #[cfg(target_arch = "aarch64")]
    fn test_neon_dot_matches_scalar() {
        let a: Vec<f32> = (1..=32).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (1..=32).map(|i| (i as f32).recip()).collect();
        let expected = scalar_dot(&a, &b);
        let got = unsafe { super::neon::dot(&a, &b, 32) };
        assert!((got - expected).abs() < 1e-4);
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_dot_matches_scalar_when_available() {
        // Skip silently on CPUs that cannot execute the kernel.
        if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
            return;
        }
        let a: Vec<f32> = (1..=32).map(|i| i as f32 * 0.1).collect();
        let b: Vec<f32> = (1..=32).map(|i| (i as f32).recip()).collect();
        let expected = scalar_dot(&a, &b);
        let got = unsafe { super::avx2::dot(&a, &b, 32) };
        assert!((got - expected).abs() < 1e-4);
    }

    #[test]
    fn test_scalar_dot_unrolled() {
        // 17 elements: four full groups of four plus a one-element tail.
        let a: Vec<f32> = (0..17).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..17).map(|i| i as f32 * 2.0).collect();
        let expected = scalar_dot(&a, &b);
        let got = scalar::dot(&a, &b);
        assert!((got - expected).abs() < 1e-4);
    }
}