Skip to main content

ferray_stats/
parallel.rs

1// ferray-stats: Rayon threshold dispatch for reductions and sorting (REQ-19, REQ-20)
2
3// `with_simd`, `simd_pairwise_f64`, `simd_base_sum_f64`, and `base_sum`
4// are pulp dispatch entry points where `#[inline(always)]` is required —
5// without it the SIMD instruction set selected at the call site fails to
6// inline into the kernel and dispatch becomes a no-op.
7#![allow(clippy::inline_always)]
8
9use pulp::Arch;
10use rayon::prelude::*;
11
/// Slice length at or above which reductions take the rayon parallel path.
///
/// `parallel_sum` and `parallel_prod` compare with `>=`, so a slice of exactly
/// this many elements is already reduced in parallel; shorter slices stay on
/// the sequential path.
///
/// Set high to avoid rayon thread-pool startup overhead dominating small workloads.
/// Rayon's thread pool initialization costs ~1-3ms on first use, which is catastrophic
/// for arrays under ~1M elements where sequential pairwise sum takes <1ms.
pub const PARALLEL_REDUCTION_THRESHOLD: usize = 1_000_000;
18
/// Slice length at or above which sorting is handed to rayon's parallel sorts.
///
/// Both `parallel_sort` (unstable) and `parallel_sort_stable` compare with
/// `>=` against this value; shorter slices use the sequential slice sorts.
pub const PARALLEL_SORT_THRESHOLD: usize = 100_000;
21
/// Pairwise summation base case threshold.
///
/// Below this size, we use an unrolled sequential sum. 256 elements gives a
/// good balance: halves carry-merge overhead vs 128 while keeping the same
/// O(ε log N) accuracy bound (just one fewer merge level).
///
/// This is also the chunk length used by `pairwise_sum` and
/// `simd_pairwise_f64` when they split a longer slice into base-case pieces.
const PAIRWISE_BASE: usize = 256;
28
/// Base-case sum with eight independent accumulators (auto-vectorization friendly).
///
/// Element `k` of every complete group of eight is added to accumulator `k`;
/// the trailing `len % 8` elements are added, in order, to accumulator 0.
/// The accumulators are then combined as a fixed balanced tree.
///
/// `identity` seeds all eight accumulators, so it must be the additive
/// identity of `T` for the result to equal the plain sum of `data`.
#[inline(always)]
fn base_sum<T: Copy + std::ops::Add<Output = T>>(data: &[T], identity: T) -> T {
    let mut lanes = [identity; 8];

    let mut groups = data.chunks_exact(8);
    for group in groups.by_ref() {
        for (lane, &value) in lanes.iter_mut().zip(group) {
            *lane = *lane + value;
        }
    }

    // Leftover elements (fewer than eight) all go to the first accumulator.
    for &value in groups.remainder() {
        lanes[0] = lanes[0] + value;
    }

    let [a0, a1, a2, a3, a4, a5, a6, a7] = lanes;
    (a0 + a1) + (a2 + a3) + ((a4 + a5) + (a6 + a7))
}
59
60/// Pairwise summation of a slice.
61///
62/// Uses an iterative carry-merge algorithm with O(ε log N) error bound,
63/// matching `NumPy`'s summation accuracy. Data is processed in 128-element
64/// base chunks, then merged pairwise using a small stack (like binary
65/// carry addition). Zero recursion overhead — stack depth is at most
66/// ceil(log2(N/128)) + 1 ≈ 20 entries.
67pub fn pairwise_sum<T>(data: &[T], identity: T) -> T
68where
69    T: Copy + std::ops::Add<Output = T>,
70{
71    let n = data.len();
72    if n == 0 {
73        return identity;
74    }
75    if n <= PAIRWISE_BASE {
76        return base_sum(data, identity);
77    }
78
79    // Iterative pairwise summation using a carry-merge stack.
80    // Each entry holds (partial_sum, level) where level is how many base
81    // chunks it represents. Same-level adjacent entries merge on push,
82    // exactly like binary carry addition. Max depth: ~24 for any realistic size.
83    let mut stack_val: [T; 24] = [identity; 24];
84    let mut stack_lvl: [usize; 24] = [0; 24];
85    let mut depth = 0usize;
86
87    let mut offset = 0;
88    while offset < n {
89        let end = (offset + PAIRWISE_BASE).min(n);
90        let mut current = base_sum(&data[offset..end], identity);
91        offset = end;
92
93        // Push and carry-merge: merge with stack top while levels match
94        let mut level = 1usize;
95        while depth > 0 && stack_lvl[depth - 1] == level {
96            depth -= 1;
97            current = stack_val[depth] + current;
98            level += 1;
99        }
100        stack_val[depth] = current;
101        stack_lvl[depth] = level;
102        depth += 1;
103    }
104
105    // Merge remaining stack entries
106    let mut result = stack_val[depth - 1];
107    for i in (0..depth - 1).rev() {
108        result = stack_val[i] + result;
109    }
110    result
111}
112
113// ---------------------------------------------------------------------------
114// SIMD-accelerated pairwise sum for f64
115// ---------------------------------------------------------------------------
116
117/// SIMD-accelerated pairwise summation for f64 slices.
118///
119/// Uses pulp for hardware SIMD dispatch (AVX2/SSE2/NEON) in the 128-element
120/// base case, with the same iterative carry-merge structure for O(ε log N)
121/// accuracy.
122#[must_use]
123pub fn pairwise_sum_f64(data: &[f64]) -> f64 {
124    Arch::new().dispatch(PairwiseSumF64Op { data })
125}
126
127struct PairwiseSumF64Op<'a> {
128    data: &'a [f64],
129}
130
131impl pulp::WithSimd for PairwiseSumF64Op<'_> {
132    type Output = f64;
133
134    #[inline(always)]
135    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
136        simd_pairwise_f64(simd, self.data)
137    }
138}
139
140#[inline(always)]
141fn simd_base_sum_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
142    let n = data.len();
143    let lane_count = size_of::<S::f64s>() / size_of::<f64>();
144    let simd_end = n - (n % lane_count);
145
146    let zero = simd.splat_f64s(0.0);
147    let mut acc0 = zero;
148    let mut acc1 = zero;
149    let mut acc2 = zero;
150    let mut acc3 = zero;
151
152    // 4-accumulator unroll to saturate FPU throughput
153    let stride = lane_count * 4;
154    let unrolled_end = n - (n % stride);
155    let mut i = 0;
156    while i < unrolled_end {
157        let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
158        let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
159        let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
160        let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
161        acc0 = simd.add_f64s(acc0, v0);
162        acc1 = simd.add_f64s(acc1, v1);
163        acc2 = simd.add_f64s(acc2, v2);
164        acc3 = simd.add_f64s(acc3, v3);
165        i += stride;
166    }
167    while i + lane_count <= simd_end {
168        let v = simd.partial_load_f64s(&data[i..i + lane_count]);
169        acc0 = simd.add_f64s(acc0, v);
170        i += lane_count;
171    }
172    acc0 = simd.add_f64s(acc0, acc1);
173    acc2 = simd.add_f64s(acc2, acc3);
174    acc0 = simd.add_f64s(acc0, acc2);
175
176    // Horizontal sum: store SIMD register to temp array
177    let mut temp = [0.0f64; 8]; // max 8 lanes (AVX-512)
178    simd.partial_store_f64s(&mut temp[..lane_count], acc0);
179    let mut sum = 0.0f64;
180    for t in temp.iter().take(lane_count) {
181        sum += t;
182    }
183    // Scalar remainder
184    for &val in &data[simd_end..n] {
185        sum += val;
186    }
187    sum
188}
189
190#[inline(always)]
191fn simd_pairwise_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
192    let n = data.len();
193    if n == 0 {
194        return 0.0;
195    }
196    if n <= PAIRWISE_BASE {
197        return simd_base_sum_f64(simd, data);
198    }
199
200    let mut stack_val = [0.0f64; 24];
201    let mut stack_lvl = [0usize; 24];
202    let mut depth = 0usize;
203
204    let mut offset = 0;
205    while offset < n {
206        let end = (offset + PAIRWISE_BASE).min(n);
207        let mut current = simd_base_sum_f64(simd, &data[offset..end]);
208        offset = end;
209
210        let mut level = 1usize;
211        while depth > 0 && stack_lvl[depth - 1] == level {
212            depth -= 1;
213            current += stack_val[depth];
214            level += 1;
215        }
216        stack_val[depth] = current;
217        stack_lvl[depth] = level;
218        depth += 1;
219    }
220
221    let mut result = stack_val[depth - 1];
222    for i in (0..depth - 1).rev() {
223        result += stack_val[i];
224    }
225    result
226}
227
228// ---------------------------------------------------------------------------
229// SIMD-accelerated fused sum of squared differences for variance
230// ---------------------------------------------------------------------------
231
232/// SIMD-accelerated computation of sum((x - mean)²) for f64 slices.
233///
234/// Computes the sum of squared differences from the mean in a single pass
235/// without allocating an intermediate Vec. Uses 4 SIMD accumulators for ILP.
236#[must_use]
237pub fn simd_sum_sq_diff_f64(data: &[f64], mean: f64) -> f64 {
238    Arch::new().dispatch(SumSqDiffF64Op { data, mean })
239}
240
241struct SumSqDiffF64Op<'a> {
242    data: &'a [f64],
243    mean: f64,
244}
245
246impl pulp::WithSimd for SumSqDiffF64Op<'_> {
247    type Output = f64;
248
249    #[inline(always)]
250    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
251        let data = self.data;
252        let n = data.len();
253        let lane_count = size_of::<S::f64s>() / size_of::<f64>();
254        let simd_end = n - (n % lane_count);
255
256        let zero = simd.splat_f64s(0.0);
257        let mean_v = simd.splat_f64s(self.mean);
258        let mut acc0 = zero;
259        let mut acc1 = zero;
260        let mut acc2 = zero;
261        let mut acc3 = zero;
262
263        let stride = lane_count * 4;
264        let unrolled_end = n - (n % stride);
265        let mut i = 0;
266        while i < unrolled_end {
267            let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
268            let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
269            let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
270            let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
271            let d0 = simd.sub_f64s(v0, mean_v);
272            let d1 = simd.sub_f64s(v1, mean_v);
273            let d2 = simd.sub_f64s(v2, mean_v);
274            let d3 = simd.sub_f64s(v3, mean_v);
275            acc0 = simd.mul_add_f64s(d0, d0, acc0);
276            acc1 = simd.mul_add_f64s(d1, d1, acc1);
277            acc2 = simd.mul_add_f64s(d2, d2, acc2);
278            acc3 = simd.mul_add_f64s(d3, d3, acc3);
279            i += stride;
280        }
281        while i + lane_count <= simd_end {
282            let v = simd.partial_load_f64s(&data[i..i + lane_count]);
283            let d = simd.sub_f64s(v, mean_v);
284            acc0 = simd.mul_add_f64s(d, d, acc0);
285            i += lane_count;
286        }
287        acc0 = simd.add_f64s(acc0, acc1);
288        acc2 = simd.add_f64s(acc2, acc3);
289        acc0 = simd.add_f64s(acc0, acc2);
290
291        let mut temp = [0.0f64; 8];
292        simd.partial_store_f64s(&mut temp[..lane_count], acc0);
293        let mut sum = 0.0f64;
294        for t in temp.iter().take(lane_count) {
295            sum += t;
296        }
297        for &val in &data[simd_end..n] {
298            let d = val - self.mean;
299            sum += d * d;
300        }
301        sum
302    }
303}
304
305// ---------------------------------------------------------------------------
306// Parallel dispatch
307// ---------------------------------------------------------------------------
308
309/// Perform a parallel pairwise sum on a slice.
310///
311/// Uses rayon tree-reduce for large slices (which is inherently pairwise),
312/// and iterative pairwise summation for smaller slices.
313pub fn parallel_sum<T>(data: &[T], identity: T) -> T
314where
315    T: Copy + Send + Sync + std::ops::Add<Output = T>,
316{
317    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
318        // Rayon's reduce is a tree-reduce, which gives pairwise-like accuracy
319        data.par_iter().copied().reduce(|| identity, |a, b| a + b)
320    } else {
321        pairwise_sum(data, identity)
322    }
323}
324
325/// Perform a parallel product on a slice of `Copy + Send + Sync` values.
326pub fn parallel_prod<T>(data: &[T], identity: T) -> T
327where
328    T: Copy + Send + Sync + std::ops::Mul<Output = T>,
329{
330    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
331        data.par_iter().copied().reduce(|| identity, |a, b| a * b)
332    } else {
333        data.iter().copied().fold(identity, |a, b| a * b)
334    }
335}
336
337/// Parallel sort (unstable) for large slices. Returns a sorted copy.
338pub fn parallel_sort<T>(data: &mut [T])
339where
340    T: Copy + Send + Sync + PartialOrd,
341{
342    if data.len() >= PARALLEL_SORT_THRESHOLD {
343        data.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
344    } else {
345        data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
346    }
347}
348
349/// Parallel stable sort for large slices.
350pub fn parallel_sort_stable<T>(data: &mut [T])
351where
352    T: Copy + Send + Sync + PartialOrd,
353{
354    if data.len() >= PARALLEL_SORT_THRESHOLD {
355        data.par_sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
356    } else {
357        data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
358    }
359}