/// Sums `slice`, clamping the total at `u32::MAX`, dispatching to the fastest
/// backend available for the compilation target.
///
/// Dispatch order (first matching cfg wins):
/// 1. aarch64 with compile-time NEON      -> `aarch64_neon`
/// 2. wasm32 with compile-time simd128    -> `wasm_simd128`
/// 3. x86_64 with `std`: runtime-detected AVX2, then SSE4.1
/// 4. portable scalar fallback
///
/// `colorthief_force_scalar` and `colorthief_disable_avx2` are custom cfg
/// flags that force the scalar path / skip the AVX2 path (e.g. for testing).
#[allow(unsafe_code)]
#[inline]
pub(crate) fn sum_u32_slice(slice: &[u32]) -> u32 {
    #[cfg(all(
        target_arch = "aarch64",
        target_feature = "neon",
        not(colorthief_force_scalar)
    ))]
    {
        return aarch64_neon::sum_u32_slice(slice);
    }
    #[cfg(all(
        target_arch = "wasm32",
        target_feature = "simd128",
        not(colorthief_force_scalar)
    ))]
    {
        return wasm_simd128::sum_u32_slice(slice);
    }
    // Runtime detection needs std; without the `std` feature we fall through
    // to the scalar path on x86_64.
    #[cfg(all(target_arch = "x86_64", feature = "std", not(colorthief_force_scalar)))]
    {
        // SAFETY: each unsafe call is guarded by the matching runtime CPU
        // feature check immediately preceding it.
        if !cfg!(colorthief_disable_avx2) && std::is_x86_feature_detected!("avx2") {
            return unsafe { x86_avx2::sum_u32_slice(slice) };
        }
        if std::is_x86_feature_detected!("sse4.1") {
            return unsafe { x86_sse41::sum_u32_slice(slice) };
        }
    }
    // Fallback; unreachable (hence the allow) on targets where a cfg block
    // above unconditionally returns.
    #[allow(unreachable_code)]
    scalar::sum_u32_slice(slice)
}
pub(crate) mod scalar {
    /// Portable reference implementation: sums the slice, clamping the
    /// running total at `u32::MAX` instead of wrapping on overflow. All
    /// SIMD backends must match this function's results exactly.
    #[allow(dead_code)]
    pub fn sum_u32_slice(slice: &[u32]) -> u32 {
        slice
            .iter()
            .copied()
            .fold(0u32, |acc, value| acc.saturating_add(value))
    }
}
// Compiled only when NEON is statically enabled, so no runtime detection is
// needed on aarch64.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
#[allow(unsafe_code, dead_code)]
pub(crate) mod aarch64_neon {
    use core::arch::aarch64::*;

    /// NEON-accelerated sum with the same saturating contract as `scalar`.
    pub fn sum_u32_slice(slice: &[u32]) -> u32 {
        // SAFETY: this module only compiles when `target_feature = "neon"`
        // is enabled, so the NEON instructions are guaranteed available.
        unsafe { sum_u32_slice_neon(slice) }
    }

    #[target_feature(enable = "neon")]
    unsafe fn sum_u32_slice_neon(slice: &[u32]) -> u32 {
        // Two u64 accumulator lanes; widening to u64 before accumulating
        // means the vector sums cannot overflow for any realistic length.
        let mut acc = vdupq_n_u64(0);
        let chunks = slice.chunks_exact(4);
        let remainder = chunks.remainder();
        for chunk in chunks {
            // SAFETY: `chunk` is exactly 4 u32s, so reading 16 bytes from
            // its pointer is in-bounds; vld1q_u32 allows element alignment.
            let v32 = unsafe { vld1q_u32(chunk.as_ptr()) };
            // Pairwise widen-and-add: 4 x u32 -> 2 x u64.
            let widened = vpaddlq_u32(v32);
            acc = vaddq_u64(acc, widened);
        }
        // Horizontal add of the two u64 accumulator lanes.
        let total64: u64 = vaddvq_u64(acc);
        let mut total64 = total64;
        // Fold in the <4 leftover elements on the scalar side.
        for &x in remainder {
            total64 = total64.saturating_add(x as u64);
        }
        // Clamp to match the scalar saturating-at-u32::MAX semantics.
        total64.min(u32::MAX as u64) as u32
    }
}
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_code, dead_code)]
pub(crate) mod x86_sse41 {
    use core::arch::x86_64::*;

    /// SSE4.1-accelerated sum with the same saturating contract as `scalar`.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports SSE4.1, e.g. via
    /// `is_x86_feature_detected!("sse4.1")`.
    #[target_feature(enable = "sse4.1")]
    pub unsafe fn sum_u32_slice(slice: &[u32]) -> u32 {
        // Two u64 accumulator lanes; widening to u64 before accumulating
        // avoids 32-bit overflow inside the vector loop.
        let mut acc = _mm_setzero_si128();
        let zero = _mm_setzero_si128();
        let chunks = slice.chunks_exact(4);
        let remainder = chunks.remainder();
        for chunk in chunks {
            // SAFETY: `chunk` is exactly 4 u32s (16 bytes) and loadu has no
            // alignment requirement.
            let v32 = unsafe { _mm_loadu_si128(chunk.as_ptr() as *const __m128i) };
            // Interleaving with zero zero-extends each u32 lane to u64.
            let lo64 = _mm_unpacklo_epi32(v32, zero);
            let hi64 = _mm_unpackhi_epi32(v32, zero);
            acc = _mm_add_epi64(acc, lo64);
            acc = _mm_add_epi64(acc, hi64);
        }
        // Spill the two u64 lanes and reduce them on the scalar side.
        let mut buf = [0u64; 2];
        // SAFETY: `buf` is 16 bytes and storeu has no alignment requirement.
        unsafe { _mm_storeu_si128(buf.as_mut_ptr() as *mut __m128i, acc) };
        let mut total64 = buf[0].saturating_add(buf[1]);
        // Fold in the <4 leftover elements on the scalar side.
        for &x in remainder {
            total64 = total64.saturating_add(x as u64);
        }
        // Clamp to match the scalar saturating-at-u32::MAX semantics.
        total64.min(u32::MAX as u64) as u32
    }
}
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_code, dead_code)]
pub(crate) mod x86_avx2 {
    use core::arch::x86_64::*;

    /// AVX2-accelerated sum with the same saturating contract as `scalar`.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2, e.g. via
    /// `is_x86_feature_detected!("avx2")`.
    #[target_feature(enable = "avx2")]
    pub unsafe fn sum_u32_slice(slice: &[u32]) -> u32 {
        // Four u64 accumulator lanes; widening to u64 before accumulating
        // avoids 32-bit overflow inside the vector loop.
        let mut acc = _mm256_setzero_si256();
        let zero = _mm256_setzero_si256();
        let chunks = slice.chunks_exact(8);
        let remainder = chunks.remainder();
        for chunk in chunks {
            // SAFETY: `chunk` is exactly 8 u32s (32 bytes) and loadu has no
            // alignment requirement.
            let v32 = unsafe { _mm256_loadu_si256(chunk.as_ptr() as *const __m256i) };
            // Interleaving with zero zero-extends each u32 lane to u64.
            // NOTE: AVX2 unpack operates per 128-bit lane, so the u64 lanes
            // come out in shuffled order — harmless here because every
            // element lands in exactly one lane and all lanes are reduced.
            let lo64 = _mm256_unpacklo_epi32(v32, zero);
            let hi64 = _mm256_unpackhi_epi32(v32, zero);
            acc = _mm256_add_epi64(acc, lo64);
            acc = _mm256_add_epi64(acc, hi64);
        }
        // Spill the four u64 lanes and reduce them on the scalar side.
        let mut buf = [0u64; 4];
        // SAFETY: `buf` is 32 bytes and storeu has no alignment requirement.
        unsafe { _mm256_storeu_si256(buf.as_mut_ptr() as *mut __m256i, acc) };
        let mut total64 = buf[0]
            .saturating_add(buf[1])
            .saturating_add(buf[2])
            .saturating_add(buf[3]);
        // Fold in the <8 leftover elements on the scalar side.
        for &x in remainder {
            total64 = total64.saturating_add(x as u64);
        }
        // Clamp to match the scalar saturating-at-u32::MAX semantics.
        total64.min(u32::MAX as u64) as u32
    }
}
// Compiled only when simd128 is statically enabled; wasm engines validate the
// whole module up front, so there is no runtime feature dispatch.
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
#[allow(unsafe_code, dead_code)]
pub(crate) mod wasm_simd128 {
    use core::arch::wasm32::*;

    /// simd128-accelerated sum with the same saturating contract as `scalar`.
    pub fn sum_u32_slice(slice: &[u32]) -> u32 {
        // SAFETY: this module only compiles when `target_feature = "simd128"`
        // is enabled, so the SIMD instructions are guaranteed available.
        unsafe { sum_u32_slice_simd128(slice) }
    }

    #[target_feature(enable = "simd128")]
    unsafe fn sum_u32_slice_simd128(slice: &[u32]) -> u32 {
        // Two u64 accumulator lanes; widening to u64 before accumulating
        // avoids 32-bit overflow inside the vector loop.
        let mut acc = u64x2_splat(0);
        let chunks = slice.chunks_exact(4);
        let remainder = chunks.remainder();
        for chunk in chunks {
            // SAFETY: `chunk` is exactly 4 u32s (16 bytes); v128_load is an
            // unaligned load.
            let v32 = unsafe { v128_load(chunk.as_ptr() as *const v128) };
            // Zero-extend the low / high u32 pairs to u64.
            let lo64 = u64x2_extend_low_u32x4(v32);
            let hi64 = u64x2_extend_high_u32x4(v32);
            acc = u64x2_add(acc, lo64);
            acc = u64x2_add(acc, hi64);
        }
        // Reduce the two u64 lanes on the scalar side.
        let lane0 = u64x2_extract_lane::<0>(acc);
        let lane1 = u64x2_extract_lane::<1>(acc);
        let mut total64 = lane0.saturating_add(lane1);
        // Fold in the <4 leftover elements on the scalar side.
        for &x in remainder {
            total64 = total64.saturating_add(x as u64);
        }
        // Clamp to match the scalar saturating-at-u32::MAX semantics.
        total64.min(u32::MAX as u64) as u32
    }
}
#[cfg(test)]
#[allow(unsafe_code)]
mod tests {
    use super::*;

    /// Shared inputs covering: empty, sub-chunk, exact-chunk, chunk plus
    /// remainder, multi-chunk, sparse non-zero values, and a saturation case
    /// whose running u32 sum crosses `u32::MAX`.
    fn parity_inputs() -> Vec<Vec<u32>> {
        vec![
            vec![],
            vec![1, 2, 3],
            vec![1, 2, 3, 4],
            vec![1, 2, 3, 4, 5],
            vec![10, 20, 30, 40, 50, 60, 70, 80, 90],
            {
                let mut v = vec![0u32; 31];
                v[0] = 100;
                v[5] = 200;
                v[20] = 50;
                v[30] = 1;
                v
            },
            {
                let mut v = vec![1u32; 5];
                v[0] = u32::MAX - 2;
                v
            },
        ]
    }

    #[test]
    fn scalar_matches_naive_fold() {
        for input in parity_inputs() {
            let naive: u32 = input.iter().fold(0u32, |a, b| a.saturating_add(*b));
            assert_eq!(scalar::sum_u32_slice(&input), naive);
        }
    }

    /// Exercises the public dispatcher itself, so whichever backend the
    /// current target/CPU selects is checked against the scalar reference.
    #[test]
    fn dispatcher_matches_scalar() {
        for input in parity_inputs() {
            let s = scalar::sum_u32_slice(&input);
            let d = sum_u32_slice(&input);
            assert_eq!(d, s, "dispatcher divergence on input of len {}", input.len());
        }
    }

    #[test]
    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
    fn neon_matches_scalar() {
        for input in parity_inputs() {
            let s = scalar::sum_u32_slice(&input);
            let n = aarch64_neon::sum_u32_slice(&input);
            assert_eq!(n, s, "NEON divergence on input of len {}", input.len());
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn sse41_matches_scalar() {
        if !std::is_x86_feature_detected!("sse4.1") {
            eprintln!("skipping: SSE4.1 not detected");
            return;
        }
        for input in parity_inputs() {
            let s = scalar::sum_u32_slice(&input);
            // SAFETY: SSE4.1 support was verified by the runtime check above.
            let v = unsafe { x86_sse41::sum_u32_slice(&input) };
            assert_eq!(v, s, "SSE4.1 divergence on input of len {}", input.len());
        }
    }

    #[test]
    #[cfg(target_arch = "x86_64")]
    fn avx2_matches_scalar() {
        if !std::is_x86_feature_detected!("avx2") {
            eprintln!("skipping: AVX2 not detected");
            return;
        }
        for input in parity_inputs() {
            let s = scalar::sum_u32_slice(&input);
            // SAFETY: AVX2 support was verified by the runtime check above.
            let v = unsafe { x86_avx2::sum_u32_slice(&input) };
            assert_eq!(v, s, "AVX2 divergence on input of len {}", input.len());
        }
    }

    #[test]
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    fn wasm_simd128_matches_scalar() {
        for input in parity_inputs() {
            let s = scalar::sum_u32_slice(&input);
            let v = wasm_simd128::sum_u32_slice(&input);
            assert_eq!(v, s, "simd128 divergence on input of len {}", input.len());
        }
    }
}