/// Portable dot product: the sum of `a[i] * b[i]`, accumulated left to right.
///
/// Only the common prefix of the two slices (`min(a.len(), b.len())`
/// elements) contributes; surplus elements of the longer slice are ignored.
pub(crate) fn dot_f32_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut total = 0.0_f32;
    for (x, y) in a.iter().zip(b) {
        total += x * y;
    }
    total
}
/// Portable sum of every element of `a`, using the standard `Sum` impl for `f32`.
pub(crate) fn sum_f32_scalar(a: &[f32]) -> f32 {
    a.iter().sum::<f32>()
}
/// Portable element-wise addition: `out[i] = a[i] + b[i]` for `i in 0..a.len()`.
///
/// Elements of `b`/`out` past `a.len()` are not touched.
///
/// # Panics
///
/// Panics if `b` or `out` is shorter than `a`.
pub(crate) fn add_f32_scalar(a: &[f32], b: &[f32], out: &mut [f32]) {
    let n = a.len();
    // One up-front length check (panicking on short inputs) instead of a
    // bounds check per element; the zipped loop below can then vectorize.
    let (b, out) = (&b[..n], &mut out[..n]);
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        *o = x + y;
    }
}
/// Portable element-wise product: `out[i] = a[i] * b[i]` for `i in 0..a.len()`.
///
/// Elements of `b`/`out` past `a.len()` are not touched.
///
/// # Panics
///
/// Panics if `b` or `out` is shorter than `a`.
pub(crate) fn mul_f32_scalar(a: &[f32], b: &[f32], out: &mut [f32]) {
    let n = a.len();
    // One up-front length check (panicking on short inputs) instead of a
    // bounds check per element; the zipped loop below can then vectorize.
    let (b, out) = (&b[..n], &mut out[..n]);
    for ((o, &x), &y) in out.iter_mut().zip(a).zip(b) {
        *o = x * y;
    }
}
/// Portable BLAS-style AXPY: `y[i] += alpha * x[i]` for `i in 0..x.len()`.
///
/// Elements of `y` past `x.len()` are not touched.
///
/// # Panics
///
/// Panics if `y` is shorter than `x`.
pub(crate) fn axpy_f32_scalar(alpha: f32, x: &[f32], y: &mut [f32]) {
    // One up-front length check instead of a bounds check per element.
    let y = &mut y[..x.len()];
    for (yi, &xi) in y.iter_mut().zip(x) {
        *yi += alpha * xi;
    }
}
/// Portable in-place scaling: multiplies every element of `x` by `alpha`.
pub(crate) fn scal_f32_scalar(alpha: f32, x: &mut [f32]) {
    x.iter_mut().for_each(|elem| *elem *= alpha);
}
/// Portable sum of squares, `Σ a[i]²`, accumulated left to right from `0.0`.
pub(crate) fn sum_sq_f32_scalar(a: &[f32]) -> f32 {
    a.iter()
        .map(|&v| v * v)
        .fold(0.0, |running, square| running + square)
}
/// Portable sum of absolute values (BLAS `asum`), `Σ |a[i]|`, accumulated
/// left to right from `0.0`.
pub(crate) fn asum_f32_scalar(a: &[f32]) -> f32 {
    let mut total = 0.0;
    for v in a {
        total += v.abs();
    }
    total
}
/// Portable minimum of `a`.
///
/// Uses `f32::min`, which ignores NaN operands. Returns `f32::INFINITY` for
/// an empty slice (or a slice containing only NaNs).
pub(crate) fn min_f32_scalar(a: &[f32]) -> f32 {
    let mut smallest = f32::INFINITY;
    for &v in a {
        smallest = smallest.min(v);
    }
    smallest
}
/// Portable maximum of `a`.
///
/// Uses `f32::max`, which ignores NaN operands. Returns `f32::NEG_INFINITY`
/// for an empty slice (or a slice containing only NaNs).
pub(crate) fn max_f32_scalar(a: &[f32]) -> f32 {
    let mut largest = f32::NEG_INFINITY;
    for &v in a {
        largest = largest.max(v);
    }
    largest
}
/// Portable arithmetic mean of `a`.
///
/// An empty slice yields `0.0` rather than the NaN that `0.0 / 0.0` would give.
pub(crate) fn mean_f32_scalar(a: &[f32]) -> f32 {
    match a.len() {
        0 => 0.0,
        len => sum_f32_scalar(a) / len as f32,
    }
}
#[cfg(target_arch = "x86_64")]
mod avx {
    //! AVX (256-bit, 8 lanes of `f32`) versions of the slice kernels.
    //!
    //! Every kernel walks full 8-element chunks with unaligned vector
    //! loads/stores (`chunks_exact`, so each load is in bounds by construction)
    //! and finishes the `len % 8` remainder with scalar code.
    //!
    //! # Safety (applies to every `unsafe fn` in this module)
    //!
    //! The caller must guarantee that the running CPU supports AVX. Slice
    //! lengths are validated inside each kernel, so mismatched lengths panic
    //! (or, for `dot`, truncate) instead of touching memory out of bounds.
    use core::arch::x86_64::*;

    /// Horizontal sum of the eight lanes of `v`.
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn hsum_256_f32(v: __m256) -> f32 {
        // Fold 256 -> 128 bits, then reduce pairwise inside the 128-bit register.
        let lo = _mm256_castps256_ps128(v);
        let hi = _mm256_extractf128_ps(v, 1);
        let s = _mm_add_ps(lo, hi);
        let s = _mm_add_ps(s, _mm_movehdup_ps(s));
        let s = _mm_add_ss(s, _mm_movehl_ps(s, s));
        _mm_cvtss_f32(s)
    }

    /// Horizontal minimum of the eight lanes of `v` (lanes assumed NaN-free).
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn hmin_256_f32(v: __m256) -> f32 {
        let lo = _mm256_castps256_ps128(v);
        let hi = _mm256_extractf128_ps(v, 1);
        let m = _mm_min_ps(lo, hi);
        let m = _mm_min_ps(m, _mm_movehdup_ps(m));
        let m = _mm_min_ss(m, _mm_movehl_ps(m, m));
        _mm_cvtss_f32(m)
    }

    /// Horizontal maximum of the eight lanes of `v` (lanes assumed NaN-free).
    #[inline]
    #[target_feature(enable = "avx")]
    unsafe fn hmax_256_f32(v: __m256) -> f32 {
        let lo = _mm256_castps256_ps128(v);
        let hi = _mm256_extractf128_ps(v, 1);
        let m = _mm_max_ps(lo, hi);
        let m = _mm_max_ps(m, _mm_movehdup_ps(m));
        let m = _mm_max_ss(m, _mm_movehl_ps(m, m));
        _mm_cvtss_f32(m)
    }

    /// Dot product over the common prefix of `a` and `b`
    /// (`min(a.len(), b.len())` elements), matching the scalar `zip` version.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn dot_f32_avx(a: &[f32], b: &[f32]) -> f32 {
        // Never size the loop from one slice alone: a shorter `b` would be
        // read out of bounds.
        let n = a.len().min(b.len());
        let (a, b) = (&a[..n], &b[..n]);
        let a_chunks = a.chunks_exact(8);
        let b_chunks = b.chunks_exact(8);
        let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());
        let mut acc = _mm256_setzero_ps();
        for (ca, cb) in a_chunks.zip(b_chunks) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());
            acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
        }
        let mut result = hsum_256_f32(acc);
        for (&x, &y) in a_tail.iter().zip(b_tail) {
            result += x * y;
        }
        result
    }

    /// Sum of all elements of `a`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn sum_f32_avx(a: &[f32]) -> f32 {
        let chunks = a.chunks_exact(8);
        let tail = chunks.remainder();
        let mut acc = _mm256_setzero_ps();
        for c in chunks {
            acc = _mm256_add_ps(acc, _mm256_loadu_ps(c.as_ptr()));
        }
        let mut result = hsum_256_f32(acc);
        for &v in tail {
            result += v;
        }
        result
    }

    /// Element-wise `out[i] = a[i] + b[i]` for `i in 0..a.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `b` or `out` is shorter than `a` (like the scalar kernel).
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn add_f32_avx(a: &[f32], b: &[f32], out: &mut [f32]) {
        let n = a.len();
        let (b, out) = (&b[..n], &mut out[..n]);
        let a_chunks = a.chunks_exact(8);
        let b_chunks = b.chunks_exact(8);
        let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());
        let mut o_chunks = out.chunks_exact_mut(8);
        for ((co, ca), cb) in (&mut o_chunks).zip(a_chunks).zip(b_chunks) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());
            _mm256_storeu_ps(co.as_mut_ptr(), _mm256_add_ps(va, vb));
        }
        for ((o, &x), &y) in o_chunks.into_remainder().iter_mut().zip(a_tail).zip(b_tail) {
            *o = x + y;
        }
    }

    /// Element-wise `out[i] = a[i] * b[i]` for `i in 0..a.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `b` or `out` is shorter than `a` (like the scalar kernel).
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn mul_f32_avx(a: &[f32], b: &[f32], out: &mut [f32]) {
        let n = a.len();
        let (b, out) = (&b[..n], &mut out[..n]);
        let a_chunks = a.chunks_exact(8);
        let b_chunks = b.chunks_exact(8);
        let (a_tail, b_tail) = (a_chunks.remainder(), b_chunks.remainder());
        let mut o_chunks = out.chunks_exact_mut(8);
        for ((co, ca), cb) in (&mut o_chunks).zip(a_chunks).zip(b_chunks) {
            let va = _mm256_loadu_ps(ca.as_ptr());
            let vb = _mm256_loadu_ps(cb.as_ptr());
            _mm256_storeu_ps(co.as_mut_ptr(), _mm256_mul_ps(va, vb));
        }
        for ((o, &x), &y) in o_chunks.into_remainder().iter_mut().zip(a_tail).zip(b_tail) {
            *o = x * y;
        }
    }

    /// BLAS-style AXPY: `y[i] += alpha * x[i]` for `i in 0..x.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `y` is shorter than `x` (like the scalar kernel).
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn axpy_f32_avx(alpha: f32, x: &[f32], y: &mut [f32]) {
        let y = &mut y[..x.len()];
        let valpha = _mm256_set1_ps(alpha);
        let x_chunks = x.chunks_exact(8);
        let x_tail = x_chunks.remainder();
        let mut y_chunks = y.chunks_exact_mut(8);
        for (cy, cx) in (&mut y_chunks).zip(x_chunks) {
            let vx = _mm256_loadu_ps(cx.as_ptr());
            let vy = _mm256_loadu_ps(cy.as_ptr());
            _mm256_storeu_ps(cy.as_mut_ptr(), _mm256_add_ps(vy, _mm256_mul_ps(valpha, vx)));
        }
        for (yi, &xi) in y_chunks.into_remainder().iter_mut().zip(x_tail) {
            *yi += alpha * xi;
        }
    }

    /// In-place scaling: `x[i] *= alpha`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn scal_f32_avx(alpha: f32, x: &mut [f32]) {
        let valpha = _mm256_set1_ps(alpha);
        let mut chunks = x.chunks_exact_mut(8);
        for c in &mut chunks {
            let v = _mm256_loadu_ps(c.as_ptr());
            _mm256_storeu_ps(c.as_mut_ptr(), _mm256_mul_ps(valpha, v));
        }
        for v in chunks.into_remainder() {
            *v *= alpha;
        }
    }

    /// Sum of squares, `Σ a[i]²`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn sum_sq_f32_avx(a: &[f32]) -> f32 {
        let chunks = a.chunks_exact(8);
        let tail = chunks.remainder();
        let mut acc = _mm256_setzero_ps();
        for c in chunks {
            let va = _mm256_loadu_ps(c.as_ptr());
            acc = _mm256_add_ps(acc, _mm256_mul_ps(va, va));
        }
        let mut result = hsum_256_f32(acc);
        for &v in tail {
            result += v * v;
        }
        result
    }

    /// Sum of absolute values, `Σ |a[i]|`.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn asum_f32_avx(a: &[f32]) -> f32 {
        // Clearing the IEEE-754 sign bit of each lane yields its absolute value.
        let abs_mask = _mm256_set1_ps(f32::from_bits(0x7FFF_FFFF));
        let chunks = a.chunks_exact(8);
        let tail = chunks.remainder();
        let mut acc = _mm256_setzero_ps();
        for c in chunks {
            let va = _mm256_loadu_ps(c.as_ptr());
            acc = _mm256_add_ps(acc, _mm256_and_ps(va, abs_mask));
        }
        let mut result = hsum_256_f32(acc);
        for &v in tail {
            result += v.abs();
        }
        result
    }

    /// Minimum of `a`, ignoring NaNs like `f32::min`; `f32::INFINITY` if empty.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn min_f32_avx(a: &[f32]) -> f32 {
        let chunks = a.chunks_exact(8);
        let tail = chunks.remainder();
        let mut vmin = _mm256_set1_ps(f32::INFINITY);
        for c in chunks {
            // VMINPS returns its SECOND operand when either input is NaN.
            // Keeping the accumulator second means NaN inputs are skipped and
            // the accumulator lanes never become NaN (matching `f32::min`).
            vmin = _mm256_min_ps(_mm256_loadu_ps(c.as_ptr()), vmin);
        }
        let mut result = hmin_256_f32(vmin);
        for &v in tail {
            result = result.min(v);
        }
        result
    }

    /// Arithmetic mean of `a`; `0.0` for an empty slice (matching the scalar
    /// kernel instead of producing `0.0 / 0.0 = NaN`).
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn mean_f32_avx(a: &[f32]) -> f32 {
        if a.is_empty() {
            return 0.0;
        }
        sum_f32_avx(a) / a.len() as f32
    }

    /// Maximum of `a`, ignoring NaNs like `f32::max`; `f32::NEG_INFINITY` if empty.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    pub(crate) unsafe fn max_f32_avx(a: &[f32]) -> f32 {
        let chunks = a.chunks_exact(8);
        let tail = chunks.remainder();
        let mut vmax = _mm256_set1_ps(f32::NEG_INFINITY);
        for c in chunks {
            // VMAXPS returns its SECOND operand when either input is NaN; keep
            // the accumulator second so NaN inputs are skipped (like `f32::max`).
            vmax = _mm256_max_ps(_mm256_loadu_ps(c.as_ptr()), vmax);
        }
        let mut result = hmax_256_f32(vmax);
        for &v in tail {
            result = result.max(v);
        }
        result
    }
}
/// Dot product of `a` and `b` using the kernel chosen by `dispatch_f32!`
/// (AVX, NEON, or the scalar fallback [`dot_f32_scalar`]).
///
/// Only the common prefix (`min(a.len(), b.len())` elements) contributes,
/// matching the scalar kernel. Both slices are trimmed to that length before
/// dispatch, so this safe function never hands mismatched lengths to an
/// `unsafe` SIMD kernel.
pub(crate) fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    super::dispatch_f32!(
        avx::dot_f32_avx,
        super::neon_f32_ops::dot_f32_neon,
        dot_f32_scalar,
        a,
        b
    )
}
/// Sum of all elements of `a`.
///
/// Forwards to one of `avx::sum_f32_avx`, `neon_f32_ops::sum_f32_neon`, or
/// [`sum_f32_scalar`] via `dispatch_f32!`. The selection is presumably by
/// target arch / runtime CPU features; see the macro to confirm. The AVX
/// kernel accumulates in eight lanes before a horizontal add, so its result
/// can differ from the scalar one in the last bits.
pub(crate) fn sum_f32(a: &[f32]) -> f32 {
    super::dispatch_f32!(
        avx::sum_f32_avx,
        super::neon_f32_ops::sum_f32_neon,
        sum_f32_scalar,
        a
    )
}
/// Element-wise `out[i] = a[i] + b[i]` for `i in 0..a.len()`, using the kernel
/// chosen by `dispatch_f32!` (AVX, NEON, or [`add_f32_scalar`]).
///
/// Elements of `b`/`out` past `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `b` or `out` is shorter than `a`. The check is done here, before
/// dispatch, so an `unsafe` SIMD kernel never sees undersized buffers.
pub(crate) fn add_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    let n = a.len();
    let (b, out) = (&b[..n], &mut out[..n]);
    super::dispatch_f32!(
        avx::add_f32_avx,
        super::neon_f32_ops::add_f32_neon,
        add_f32_scalar,
        a,
        b,
        out
    );
}
/// Element-wise `out[i] = a[i] * b[i]` for `i in 0..a.len()`, using the kernel
/// chosen by `dispatch_f32!` (AVX, NEON, or [`mul_f32_scalar`]).
///
/// Elements of `b`/`out` past `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `b` or `out` is shorter than `a`. The check is done here, before
/// dispatch, so an `unsafe` SIMD kernel never sees undersized buffers.
pub(crate) fn mul_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
    let n = a.len();
    let (b, out) = (&b[..n], &mut out[..n]);
    super::dispatch_f32!(
        avx::mul_f32_avx,
        super::neon_f32_ops::mul_f32_neon,
        mul_f32_scalar,
        a,
        b,
        out
    );
}
/// BLAS-style AXPY, `y[i] += alpha * x[i]` for `i in 0..x.len()`, using the
/// kernel chosen by `dispatch_f32!` (AVX, NEON, or [`axpy_f32_scalar`]).
///
/// Elements of `y` past `x.len()` are left untouched.
///
/// # Panics
///
/// Panics if `y` is shorter than `x`. The check is done here, before dispatch,
/// so an `unsafe` SIMD kernel never sees an undersized `y`.
pub(crate) fn axpy_f32(alpha: f32, x: &[f32], y: &mut [f32]) {
    let y = &mut y[..x.len()];
    super::dispatch_f32!(
        avx::axpy_f32_avx,
        super::neon_f32_ops::axpy_f32_neon,
        axpy_f32_scalar,
        alpha,
        x,
        y
    );
}
/// Scales `x` in place: `x[i] *= alpha` for every element.
///
/// Forwards to `avx::scal_f32_avx`, `neon_f32_ops::scal_f32_neon`, or
/// [`scal_f32_scalar`] via `dispatch_f32!` (selection logic lives in the macro).
pub(crate) fn scal_f32(alpha: f32, x: &mut [f32]) {
    super::dispatch_f32!(
        avx::scal_f32_avx,
        super::neon_f32_ops::scal_f32_neon,
        scal_f32_scalar,
        alpha,
        x
    );
}
/// Sum of squares, `Σ a[i]²` (the squared Euclidean norm of `a`).
///
/// Forwards to `avx::sum_sq_f32_avx`, `neon_f32_ops::sum_sq_f32_neon`, or
/// [`sum_sq_f32_scalar`] via `dispatch_f32!`. SIMD kernels accumulate in a
/// different order than the scalar one, so results may differ in the last bits.
pub(crate) fn sum_sq_f32(a: &[f32]) -> f32 {
    super::dispatch_f32!(
        avx::sum_sq_f32_avx,
        super::neon_f32_ops::sum_sq_f32_neon,
        sum_sq_f32_scalar,
        a
    )
}
/// Sum of absolute values (BLAS `asum`), `Σ |a[i]|`.
///
/// Forwards to `avx::asum_f32_avx`, `neon_f32_ops::asum_f32_neon`, or
/// [`asum_f32_scalar`] via `dispatch_f32!` (selection logic lives in the macro).
pub(crate) fn asum_f32(a: &[f32]) -> f32 {
    super::dispatch_f32!(
        avx::asum_f32_avx,
        super::neon_f32_ops::asum_f32_neon,
        asum_f32_scalar,
        a
    )
}
/// Minimum element of `a`.
///
/// Forwards to `avx::min_f32_avx`, `neon_f32_ops::min_f32_neon`, or
/// [`min_f32_scalar`] via `dispatch_f32!`. The scalar and AVX kernels start
/// from `f32::INFINITY`, so an empty slice yields `f32::INFINITY`. NOTE(review):
/// confirm the NEON kernel handles empty and NaN inputs the same way as the
/// scalar `f32::min` fold.
pub(crate) fn min_f32(a: &[f32]) -> f32 {
    super::dispatch_f32!(
        avx::min_f32_avx,
        super::neon_f32_ops::min_f32_neon,
        min_f32_scalar,
        a
    )
}
/// Maximum element of `a`.
///
/// Forwards to `avx::max_f32_avx`, `neon_f32_ops::max_f32_neon`, or
/// [`max_f32_scalar`] via `dispatch_f32!`. The scalar and AVX kernels start
/// from `f32::NEG_INFINITY`, so an empty slice yields `f32::NEG_INFINITY`.
/// NOTE(review): confirm the NEON kernel handles empty and NaN inputs the same
/// way as the scalar `f32::max` fold.
pub(crate) fn max_f32(a: &[f32]) -> f32 {
    super::dispatch_f32!(
        avx::max_f32_avx,
        super::neon_f32_ops::max_f32_neon,
        max_f32_scalar,
        a
    )
}
/// Arithmetic mean of `a`, using the kernel chosen by `dispatch_f32!`
/// (AVX, NEON, or [`mean_f32_scalar`]).
///
/// An empty slice returns `0.0`, the scalar kernel's contract. That case is
/// handled here, before dispatch, so a SIMD kernel can never turn it into
/// `0.0 / 0.0 = NaN`.
pub(crate) fn mean_f32(a: &[f32]) -> f32 {
    if a.is_empty() {
        return 0.0;
    }
    super::dispatch_f32!(
        avx::mean_f32_avx,
        super::neon_f32_ops::mean_f32_neon,
        mean_f32_scalar,
        a
    )
}
#[cfg(test)]
// Exact float equality is intentional in the `assert_eq!` checks below: every
// such value is built from small integers, so the results are exactly
// representable regardless of which kernel runs.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    // 9 elements = one full 8-lane AVX chunk plus a 1-element remainder, so
    // both the vector loop and the scalar tail of the AVX kernels run.
    #[test]
    fn test_dot_f32() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = vec![9.0_f32, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let expected = dot_f32_scalar(&a, &b);
        assert!((dot_f32(&a, &b) - expected).abs() < 1e-4);
    }
    // Large input compared with a relative tolerance: SIMD kernels sum in a
    // different order than the scalar reference, so bits may differ.
    #[test]
    fn test_dot_f32_large() {
        let n = 1024;
        let a: Vec<f32> = (0..n).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..n).map(|i| (n - i) as f32).collect();
        let expected = dot_f32_scalar(&a, &b);
        assert!((dot_f32(&a, &b) - expected).abs() / expected.abs() < 1e-4);
    }
    // 1 + 2 + ... + 100 = 5050.
    #[test]
    fn test_sum_f32() {
        let a: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        assert!((sum_f32(&a) - 5050.0).abs() < 1.0);
    }
    #[test]
    fn test_add_f32() {
        let a: Vec<f32> = (0..10).map(|i| i as f32).collect();
        let b: Vec<f32> = (10..20).map(|i| i as f32).collect();
        let mut out = vec![0.0_f32; 10];
        add_f32(&a, &b, &mut out);
        for i in 0..10 {
            assert_eq!(out[i], a[i] + b[i]);
        }
    }
    #[test]
    fn test_mul_f32() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let b = vec![2.0_f32; 9];
        let mut out = vec![0.0_f32; 9];
        mul_f32(&a, &b, &mut out);
        for i in 0..9 {
            assert_eq!(out[i], a[i] * 2.0);
        }
    }
    #[test]
    fn test_axpy_f32() {
        let x = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut y = vec![10.0_f32; 9];
        axpy_f32(2.0, &x, &mut y);
        for i in 0..9 {
            assert_eq!(y[i], 10.0 + 2.0 * x[i]);
        }
    }
    #[test]
    fn test_scal_f32() {
        let mut x = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let orig: Vec<f32> = x.clone();
        scal_f32(3.0, &mut x);
        for i in 0..9 {
            assert_eq!(x[i], orig[i] * 3.0);
        }
    }
    // 3² + 4² = 25; length 2 only exercises the scalar remainder path.
    #[test]
    fn test_sum_sq_f32() {
        let a = vec![3.0_f32, 4.0];
        assert!((sum_sq_f32(&a) - 25.0).abs() < 1e-4);
    }
    // |±1| + ... + |±9| = 45.
    #[test]
    fn test_asum_f32() {
        let a = vec![-1.0_f32, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0, -9.0];
        assert!((asum_f32(&a) - 45.0).abs() < 1e-4);
    }
    #[test]
    fn test_min_f32() {
        let a = vec![3.0_f32, 1.0, 4.0, 1.5, 9.0, 2.6, 5.3, 0.5, 7.0];
        assert!((min_f32(&a) - 0.5).abs() < 1e-6);
    }
    #[test]
    fn test_max_f32() {
        let a = vec![3.0_f32, 1.0, 4.0, 1.5, 9.0, 2.6, 5.3, 0.5, 7.0];
        assert!((max_f32(&a) - 9.0).abs() < 1e-6);
    }
    #[test]
    fn test_mean_f32() {
        let a: Vec<f32> = (1..=100).map(|i| i as f32).collect();
        assert!((mean_f32(&a) - 50.5).abs() < 1e-2);
    }
    // Empty inputs: sums are 0.0, and the mean is defined as 0.0 (see
    // `mean_f32_scalar`) rather than NaN.
    #[test]
    fn test_empty_slices_f32() {
        assert_eq!(dot_f32(&[], &[]), 0.0);
        assert_eq!(sum_f32(&[]), 0.0);
        assert_eq!(mean_f32(&[]), 0.0);
    }
}