#![allow(clippy::needless_range_loop, clippy::missing_safety_doc)]
/// Arithmetic mean of the non-NaN values in `data`, or `0.0` when there are none.
///
/// The widest implementation the running CPU can execute is chosen at runtime:
/// AVX-512F first, then AVX2, then SSE4.1, and finally the portable
/// [`mean_naive`] loop. Only the mean is returned; the non-NaN count produced
/// by the implementations is discarded.
#[inline(always)]
pub fn mean(data: &[f64]) -> f64 {
    // SAFETY (every `unsafe` arm): an implementation is only called after runtime
    // detection reports the CPU feature that implementation is written for.
    let (average, _non_nan) = if is_x86_feature_detected!("avx512f") {
        unsafe { mean_avx512(data) }
    } else if is_x86_feature_detected!("avx2") {
        unsafe { mean_avx2(data) }
    } else if is_x86_feature_detected!("sse4.1") {
        unsafe { mean_sse4(data) }
    } else {
        mean_naive(data)
    };
    average
}
/// Portable scalar reference implementation of the NaN-skipping mean.
///
/// Returns `(mean, non_nan)`:
/// - `mean` is the average of every element that is not NaN, or `0.0` when no
///   such element exists.
/// - `non_nan` is how many elements took part.
///
/// Values are summed strictly left to right.
#[inline(always)]
pub fn mean_naive(data: &[f64]) -> (f64, u64) {
    let (total, kept) = data
        .iter()
        .filter(|value| !value.is_nan())
        .fold((0.0_f64, 0_u64), |(acc, n), &value| (acc + value, n + 1));
    let average = match kept {
        0 => 0.0,
        n => total / n as f64,
    };
    (average, kept)
}
/// SSE4.1 implementation of the NaN-skipping mean.
///
/// Returns `(mean, non_nan)`:
/// - `mean` is the average of the non-NaN elements of `data`, or `0.0` when
///   there are none.
/// - `non_nan` is how many such elements there are.
///
/// Two elements are processed per loop iteration. An odd trailing element is
/// folded in afterwards in Rust.
///
/// # Safety
///
/// The executing CPU must support SSE4.1 (`blendvpd`) and SSE3 (`haddpd`,
/// `movddup`), e.g. as confirmed by `is_x86_feature_detected!("sse4.1")`.
/// The template names 64-bit registers (`rsi`, `rax`), so this only builds on
/// x86_64.
#[inline(always)]
pub unsafe fn mean_sse4(data: &[f64]) -> (f64, u64) {
    let mut sum: f64;
    // The assembly counts NaNs; this is turned into the non-NaN count below.
    let mut nan_count: u64;
    // SAFETY: the loop reads `data.len() / 2` pairs of f64 starting at
    // `data.as_ptr()`, all inside the slice, and writes no memory. Every
    // register the template modifies is declared as an operand.
    core::arch::asm! {
        "xorpd xmm1, xmm1",    // xmm1: per-lane running sums
        "pxor xmm2, xmm2",     // xmm2: per-lane NaN counters
        "movddup xmm3, xmm3",  // xmm3: [1, 1] as u64 lanes
        "xorpd xmm4, xmm4",    // xmm4: 0.0, substituted for NaN lanes
        "test rax, rax",
        "jz 3f",
        "2:",
        "movupd xmm0, [rsi]",
        "movapd xmm5, xmm0",
        "cmppd xmm0, xmm0, 3", // predicate 3 = UNORD: all-ones in NaN lanes
        "blendvpd xmm5, xmm4", // implicit mask xmm0: NaN lanes become 0.0
        "pand xmm0, xmm3",     // 1 in NaN lanes, 0 elsewhere
        "paddq xmm2, xmm0",
        "addpd xmm1, xmm5",
        "add rsi, 16",
        "dec rax",
        "jnz 2b",
        "3:",
        "haddpd xmm1, xmm1",   // total sum in the low lane
        "movapd xmm3, xmm2",
        "psrldq xmm2, 8",
        "paddq xmm2, xmm3",    // total NaN count in the low lane
        out("xmm0") _,
        out("xmm1") sum,
        out("xmm2") nan_count,
        // Must be a 64-bit integer: `movddup` copies the whole low quadword, and
        // the bits above a narrower (e.g. i32) input are undefined.
        inout("xmm3") 1u64 => _,
        out("xmm4") _,
        out("xmm5") _,
        inout("rsi") data.as_ptr() => _,
        inout("rax") data.len() / 2 => _,
        options(readonly, nostack),
    }
    // The loop consumed whole pairs; fold in the odd trailing element, if any.
    for &d in &data[data.len() - data.len() % 2..] {
        if d.is_nan() {
            nan_count += 1;
        } else {
            sum += d;
        }
    }
    let non_nan = data.len() as u64 - nan_count;
    (if non_nan == 0 { 0.0 } else { sum / non_nan as f64 }, non_nan)
}
/// AVX2 implementation of the NaN-skipping mean.
///
/// Returns `(mean, non_nan)`:
/// - `mean` is the average of the non-NaN elements of `data`, or `0.0` when
///   there are none.
/// - `non_nan` is how many such elements there are.
///
/// Four elements are processed per loop iteration. The remaining
/// `data.len() % 4` elements are folded in afterwards in Rust.
///
/// # Safety
///
/// The executing CPU must support AVX2 (the register form of `vbroadcastsd` and
/// `vextracti128` need it), e.g. as confirmed by
/// `is_x86_feature_detected!("avx2")`. The template names 64-bit registers, so
/// this only builds on x86_64.
#[inline(always)]
pub unsafe fn mean_avx2(data: &[f64]) -> (f64, u64) {
    let mut sum: f64;
    // The assembly counts NaNs; this is turned into the non-NaN count below.
    let mut nan_count: u64;
    // SAFETY: the loop reads `data.len() / 4` groups of four f64 starting at
    // `data.as_ptr()`, all inside the slice, and writes no memory. Every
    // register the template modifies is declared as an operand.
    core::arch::asm! {
        "vxorpd ymm0, ymm0, ymm0",          // ymm0: per-lane running sums
        "vpxor ymm1, ymm1, ymm1",           // ymm1: per-lane NaN counters
        "vbroadcastsd ymm2, xmm2",          // ymm2: [1; 4] as u64 lanes
        "vxorpd ymm3, ymm3, ymm3",          // ymm3: 0.0, substituted for NaN lanes
        "test rax, rax",
        "jz 3f",
        "2:",
        "vmovupd ymm5, [rsi]",
        "vcmppd ymm4, ymm5, ymm5, 3",       // predicate 3 = UNORD: all-ones in NaN lanes
        "vblendvpd ymm5, ymm5, ymm3, ymm4", // NaN lanes become 0.0
        "vpand ymm4, ymm4, ymm2",           // 1 in NaN lanes, 0 elsewhere
        "vpaddq ymm1, ymm1, ymm4",
        "vaddpd ymm0, ymm0, ymm5",
        "add rsi, 32",
        "dec rax",
        "jnz 2b",
        "3:",
        "vextractf128 xmm2, ymm0, 1",
        "vaddpd xmm0, xmm0, xmm2",
        "vhaddpd xmm0, xmm0, xmm0",         // total sum in the low lane
        "vextracti128 xmm2, ymm1, 1",
        "vpaddq xmm1, xmm1, xmm2",
        // VEX-encoded forms: legacy-SSE instructions here can incur an
        // SSE/AVX transition penalty while the upper ymm halves are dirty.
        "vpsrldq xmm2, xmm1, 8",
        "vpaddq xmm1, xmm1, xmm2",          // total NaN count in the low lane
        "vzeroupper",
        out("xmm0") sum,
        out("xmm1") nan_count,
        // Must be a 64-bit integer: `vbroadcastsd` copies the whole low quadword,
        // and the bits above a narrower (e.g. i32) input are undefined.
        inout("xmm2") 1u64 => _,
        out("ymm3") _,
        out("ymm4") _,
        out("ymm5") _,
        inout("rax") data.len() / 4 => _,
        inout("rsi") data.as_ptr() => _,
        options(readonly, nostack),
    }
    // The loop consumed whole groups of four; fold in the leftovers.
    for &d in &data[data.len() - data.len() % 4..] {
        if d.is_nan() {
            nan_count += 1;
        } else {
            sum += d;
        }
    }
    let non_nan = data.len() as u64 - nan_count;
    (if non_nan == 0 { 0.0 } else { sum / non_nan as f64 }, non_nan)
}
/// AVX-512F implementation of the NaN-skipping mean.
///
/// Returns `(mean, non_nan)`:
/// - `mean` is the average of the non-NaN elements of `data`, or `0.0` when
///   there are none.
/// - `non_nan` is how many such elements there are.
///
/// Eight elements are processed per loop iteration using a compare mask
/// (`k1`), so NaN lanes are skipped without blending. The remaining
/// `data.len() % 8` elements are folded in afterwards in Rust.
///
/// # Safety
///
/// The executing CPU must support AVX-512F, e.g. as confirmed by
/// `is_x86_feature_detected!("avx512f")`. The reduction also uses AVX/AVX2
/// instructions, which AVX-512F CPUs implement. No AVX-512DQ/VL instruction is
/// used. The template names 64-bit registers, so this only builds on x86_64.
#[inline(always)]
pub unsafe fn mean_avx512(data: &[f64]) -> (f64, u64) {
    let mut sum: f64;
    let mut non_nan: u64;
    // SAFETY: the loop reads `data.len() / 8` groups of eight f64 starting at
    // `data.as_ptr()`, all inside the slice, and writes no memory. Every
    // register the template modifies (including mask register k1) is declared.
    core::arch::asm! {
        // `vpxorq` rather than `vxorpd`: the EVEX 512-bit `vxorpd` needs AVX-512DQ.
        "vpxorq zmm0, zmm0, zmm0",       // zmm0: per-lane running sums (0.0)
        "vpxorq zmm1, zmm1, zmm1",       // zmm1: per-lane non-NaN counters
        "vbroadcastsd zmm2, xmm2",       // zmm2: [1; 8] as u64 lanes
        "test rax, rax",
        "jz 3f",
        "2:",
        "vmovupd zmm4, [rsi]",
        "vcmppd k1, zmm4, zmm4, 0",      // predicate 0 = EQ_OQ: k1 set for non-NaN lanes
        "vpaddq zmm1{{k1}}, zmm1, zmm2", // count non-NaN lanes
        "vaddpd zmm0{{k1}}, zmm0, zmm4", // accumulate non-NaN lanes only
        "add rsi, 64",
        "dec rax",
        "jnz 2b",
        "3:",
        // Reduction: 64x4 extracts are AVX-512F; the 128-bit extracts use the
        // AVX/AVX2 forms because `vextract[fi]64x2` require AVX-512DQ.
        "vextractf64x4 ymm2, zmm0, 1",
        "vaddpd ymm0, ymm0, ymm2",
        "vextractf128 xmm2, ymm0, 1",
        "vaddpd xmm0, xmm0, xmm2",
        "vhaddpd xmm0, xmm0, xmm0",      // total sum in the low lane
        "vextracti64x4 ymm2, zmm1, 1",
        "vpaddq ymm1, ymm1, ymm2",
        "vextracti128 xmm2, ymm1, 1",
        "vpaddq xmm1, xmm1, xmm2",
        // VEX-encoded forms: legacy-SSE instructions here can incur a
        // transition penalty while the upper zmm bits are dirty.
        "vpsrldq xmm2, xmm1, 8",
        "vpaddq xmm1, xmm1, xmm2",       // total non-NaN count in the low lane
        "vzeroupper",
        out("xmm0") sum,
        out("xmm1") non_nan,
        // Must be a 64-bit integer: `vbroadcastsd` copies the whole low quadword,
        // and the bits above a narrower (e.g. i32) input are undefined.
        inout("xmm2") 1u64 => _,
        out("zmm4") _,
        // Written by `vcmppd`; it must be declared or the compiler may keep a
        // live value there.
        out("k1") _,
        inout("rax") data.len() / 8 => _,
        inout("rsi") data.as_ptr() => _,
        options(readonly, nostack),
    }
    // The loop consumed whole groups of eight; fold in the leftovers.
    for &d in &data[data.len() - data.len() % 8..] {
        if !d.is_nan() {
            non_nan += 1;
            sum += d;
        }
    }
    (
        if non_nan == 0 {
            0.0
        } else {
            sum / non_nan as f64
        },
        non_nan,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `1.0..=8.0` repeated and cut one element short, so every SIMD width
    /// (2, 4 and 8 lanes) is left with a scalar tail.
    fn data() -> Vec<f64> {
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
            .iter()
            .cycle()
            .take(8 * 1000 - 1)
            .copied()
            .collect::<Vec<f64>>()
    }

    /// Same layout as `data()`, but the last two values of every cycle are NaN.
    fn data_nan() -> Vec<f64> {
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, f64::NAN, f64::NAN]
            .iter()
            .cycle()
            .take(8 * 1000 - 1)
            .copied()
            .collect::<Vec<f64>>()
    }

    /// Expected `(mean, non_nan)` for `data()`: 1000 cycles summing to 36 each,
    /// minus the dropped trailing 8.
    const MEAN: (f64, u64) = ((36000.0 - 8.0) / 7999.0, 7999);
    /// Expected `(mean, non_nan)` for `data_nan()`: the dropped element is a NaN.
    const MEAN_NAN: (f64, u64) = (21000.0 / 6000.0, 6000);

    /// Runs the shared checks against one implementation.
    ///
    /// All inputs are small integers, so every summation order is exact and the
    /// results can be compared with `assert_eq!`.
    fn check(imp: unsafe fn(&[f64]) -> (f64, u64)) {
        // SAFETY: callers only pass an implementation after detecting the CPU
        // feature it requires.
        let run = |d: &[f64]| unsafe { imp(d) };
        assert_eq!(run(&data()), MEAN);
        assert_eq!(run(&data_nan()), MEAN_NAN);
        let empty: [f64; 0] = [];
        assert_eq!(run(&empty), (0.0, 0));
        assert_eq!(run(&[f64::NAN; 17]), (0.0, 0));
        // Every length up to two full 8-lane chunks plus a tail element.
        for len in 1..=17u32 {
            let v: Vec<f64> = (1..=len).map(f64::from).collect();
            let expected = (f64::from(len + 1) / 2.0, u64::from(len));
            assert_eq!(run(&v), expected, "len = {len}");
        }
    }

    #[test]
    fn test_mean() {
        assert_eq!(mean(&data()), MEAN.0);
        assert_eq!(mean(&data_nan()), MEAN_NAN.0);
    }

    #[test]
    fn test_mean_naive() {
        check(mean_naive);
    }

    // The SIMD tests use runtime detection rather than `cfg(target_feature)`.
    // The cfg is fixed at compile time, so on a default x86_64 build (SSE2 only)
    // these tests would be ignored even on hardware that supports them.

    #[test]
    fn test_mean_sse4() {
        if is_x86_feature_detected!("sse4.1") {
            check(mean_sse4);
        }
    }

    #[test]
    fn test_mean_avx2() {
        if is_x86_feature_detected!("avx2") {
            check(mean_avx2);
        }
    }

    #[test]
    fn test_mean_avx512() {
        if is_x86_feature_detected!("avx512f") {
            check(mean_avx512);
        }
    }
}