use pulp::Arch;
use rayon::prelude::*;
pub const PARALLEL_REDUCTION_THRESHOLD: usize = 1_000_000;
pub const PARALLEL_SORT_THRESHOLD: usize = 100_000;
const PAIRWISE_BASE: usize = 256;
#[inline(always)]
fn base_sum<T: Copy + std::ops::Add<Output = T>>(data: &[T], identity: T) -> T {
let n = data.len();
let mut acc0 = identity;
let mut acc1 = identity;
let mut acc2 = identity;
let mut acc3 = identity;
let mut acc4 = identity;
let mut acc5 = identity;
let mut acc6 = identity;
let mut acc7 = identity;
let chunks = n / 8;
let rem = n % 8;
for i in 0..chunks {
let base = i * 8;
acc0 = acc0 + data[base];
acc1 = acc1 + data[base + 1];
acc2 = acc2 + data[base + 2];
acc3 = acc3 + data[base + 3];
acc4 = acc4 + data[base + 4];
acc5 = acc5 + data[base + 5];
acc6 = acc6 + data[base + 6];
acc7 = acc7 + data[base + 7];
}
for i in 0..rem {
acc0 = acc0 + data[chunks * 8 + i];
}
(acc0 + acc1) + (acc2 + acc3) + ((acc4 + acc5) + (acc6 + acc7))
}
pub fn pairwise_sum<T>(data: &[T], identity: T) -> T
where
T: Copy + std::ops::Add<Output = T>,
{
let n = data.len();
if n == 0 {
return identity;
}
if n <= PAIRWISE_BASE {
return base_sum(data, identity);
}
let mut stack_val: [T; 24] = [identity; 24];
let mut stack_lvl: [usize; 24] = [0; 24];
let mut depth = 0usize;
let mut offset = 0;
while offset < n {
let end = (offset + PAIRWISE_BASE).min(n);
let mut current = base_sum(&data[offset..end], identity);
offset = end;
let mut level = 1usize;
while depth > 0 && stack_lvl[depth - 1] == level {
depth -= 1;
current = stack_val[depth] + current;
level += 1;
}
stack_val[depth] = current;
stack_lvl[depth] = level;
depth += 1;
}
let mut result = stack_val[depth - 1];
for i in (0..depth - 1).rev() {
result = stack_val[i] + result;
}
result
}
pub fn pairwise_sum_f64(data: &[f64]) -> f64 {
Arch::new().dispatch(PairwiseSumF64Op { data })
}
struct PairwiseSumF64Op<'a> {
data: &'a [f64],
}
impl pulp::WithSimd for PairwiseSumF64Op<'_> {
type Output = f64;
#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
simd_pairwise_f64(simd, self.data)
}
}
#[inline(always)]
fn simd_base_sum_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
let n = data.len();
let lane_count = size_of::<S::f64s>() / size_of::<f64>();
let simd_end = n - (n % lane_count);
let zero = simd.splat_f64s(0.0);
let mut acc0 = zero;
let mut acc1 = zero;
let mut acc2 = zero;
let mut acc3 = zero;
let stride = lane_count * 4;
let unrolled_end = n - (n % stride);
let mut i = 0;
while i < unrolled_end {
let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
acc0 = simd.add_f64s(acc0, v0);
acc1 = simd.add_f64s(acc1, v1);
acc2 = simd.add_f64s(acc2, v2);
acc3 = simd.add_f64s(acc3, v3);
i += stride;
}
while i + lane_count <= simd_end {
let v = simd.partial_load_f64s(&data[i..i + lane_count]);
acc0 = simd.add_f64s(acc0, v);
i += lane_count;
}
acc0 = simd.add_f64s(acc0, acc1);
acc2 = simd.add_f64s(acc2, acc3);
acc0 = simd.add_f64s(acc0, acc2);
let mut temp = [0.0f64; 8]; simd.partial_store_f64s(&mut temp[..lane_count], acc0);
let mut sum = 0.0f64;
for t in temp.iter().take(lane_count) {
sum += t;
}
for &val in &data[simd_end..n] {
sum += val;
}
sum
}
#[inline(always)]
fn simd_pairwise_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
let n = data.len();
if n == 0 {
return 0.0;
}
if n <= PAIRWISE_BASE {
return simd_base_sum_f64(simd, data);
}
let mut stack_val = [0.0f64; 24];
let mut stack_lvl = [0usize; 24];
let mut depth = 0usize;
let mut offset = 0;
while offset < n {
let end = (offset + PAIRWISE_BASE).min(n);
let mut current = simd_base_sum_f64(simd, &data[offset..end]);
offset = end;
let mut level = 1usize;
while depth > 0 && stack_lvl[depth - 1] == level {
depth -= 1;
current += stack_val[depth];
level += 1;
}
stack_val[depth] = current;
stack_lvl[depth] = level;
depth += 1;
}
let mut result = stack_val[depth - 1];
for i in (0..depth - 1).rev() {
result += stack_val[i];
}
result
}
pub fn simd_sum_sq_diff_f64(data: &[f64], mean: f64) -> f64 {
Arch::new().dispatch(SumSqDiffF64Op { data, mean })
}
struct SumSqDiffF64Op<'a> {
data: &'a [f64],
mean: f64,
}
impl pulp::WithSimd for SumSqDiffF64Op<'_> {
type Output = f64;
#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
let data = self.data;
let n = data.len();
let lane_count = size_of::<S::f64s>() / size_of::<f64>();
let simd_end = n - (n % lane_count);
let zero = simd.splat_f64s(0.0);
let mean_v = simd.splat_f64s(self.mean);
let mut acc0 = zero;
let mut acc1 = zero;
let mut acc2 = zero;
let mut acc3 = zero;
let stride = lane_count * 4;
let unrolled_end = n - (n % stride);
let mut i = 0;
while i < unrolled_end {
let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
let d0 = simd.sub_f64s(v0, mean_v);
let d1 = simd.sub_f64s(v1, mean_v);
let d2 = simd.sub_f64s(v2, mean_v);
let d3 = simd.sub_f64s(v3, mean_v);
acc0 = simd.mul_add_f64s(d0, d0, acc0);
acc1 = simd.mul_add_f64s(d1, d1, acc1);
acc2 = simd.mul_add_f64s(d2, d2, acc2);
acc3 = simd.mul_add_f64s(d3, d3, acc3);
i += stride;
}
while i + lane_count <= simd_end {
let v = simd.partial_load_f64s(&data[i..i + lane_count]);
let d = simd.sub_f64s(v, mean_v);
acc0 = simd.mul_add_f64s(d, d, acc0);
i += lane_count;
}
acc0 = simd.add_f64s(acc0, acc1);
acc2 = simd.add_f64s(acc2, acc3);
acc0 = simd.add_f64s(acc0, acc2);
let mut temp = [0.0f64; 8];
simd.partial_store_f64s(&mut temp[..lane_count], acc0);
let mut sum = 0.0f64;
for t in temp.iter().take(lane_count) {
sum += t;
}
for &val in &data[simd_end..n] {
let d = val - self.mean;
sum += d * d;
}
sum
}
}
pub fn parallel_sum<T>(data: &[T], identity: T) -> T
where
T: Copy + Send + Sync + std::ops::Add<Output = T>,
{
if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
data.par_iter().copied().reduce(|| identity, |a, b| a + b)
} else {
pairwise_sum(data, identity)
}
}
pub fn parallel_prod<T>(data: &[T], identity: T) -> T
where
T: Copy + Send + Sync + std::ops::Mul<Output = T>,
{
if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
data.par_iter().copied().reduce(|| identity, |a, b| a * b)
} else {
data.iter().copied().fold(identity, |a, b| a * b)
}
}
pub fn parallel_sort<T>(data: &mut [T])
where
T: Copy + Send + Sync + PartialOrd,
{
if data.len() >= PARALLEL_SORT_THRESHOLD {
data.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
} else {
data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
}
}
pub fn parallel_sort_stable<T>(data: &mut [T])
where
T: Copy + Send + Sync + PartialOrd,
{
if data.len() >= PARALLEL_SORT_THRESHOLD {
data.par_sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
} else {
data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
}
}