Skip to main content

ferray_stats/
parallel.rs

1// ferray-stats: Rayon threshold dispatch for reductions and sorting (REQ-19, REQ-20)
2
3// `with_simd`, `simd_pairwise_f64`, `simd_base_sum_f64`, and `base_sum`
4// are pulp dispatch entry points where `#[inline(always)]` is required —
5// without it the SIMD instruction set selected at the call site fails to
6// inline into the kernel and dispatch becomes a no-op.
7#![allow(clippy::inline_always)]
8
9use pulp::Arch;
10use rayon::prelude::*;
11
/// Threshold above which reductions use parallel tree-reduce.
///
/// Set high to avoid rayon thread-pool startup overhead dominating small workloads.
/// Rayon's thread pool initialization costs ~1-3ms on first use, which is catastrophic
/// for arrays under ~1M elements where sequential pairwise sum takes <1ms.
pub const PARALLEL_REDUCTION_THRESHOLD: usize = 1_000_000;

/// Threshold above which sorting uses parallel merge sort.
pub const PARALLEL_SORT_THRESHOLD: usize = 100_000;

/// Base-chunk size for pairwise summation. Each base chunk runs the
/// flat SIMD inner-loop (4-way unrolled f64 adds with one horizontal
/// reduction at the end); the pairwise tree only kicks in across base
/// chunks. So this knob is the trade-off between (a) tree-induced
/// horizontal-sum overhead — paid once per base chunk — and (b)
/// numerical precision, which scales as O(ε * base_size) within a
/// chunk and O(ε log N/base_size) across the tree.
///
/// 4096 puts the per-chunk error at ~12 ULP for f64 (well below f64's
/// 52-bit mantissa) while cutting the horizontal-sum overhead 16× vs
/// the previous 256. For N=1M that's 250 chunks instead of ~4000 — the
/// difference between memory-bandwidth-bound (NumPy's 154 µs floor at
/// 1M) and add-throughput-bound (242 µs at 1M).
const PAIRWISE_BASE: usize = 4096;

/// 8-wide unrolled base-case sum for auto-vectorization.
///
/// Eight independent accumulators break the serial add dependency chain
/// so the compiler can auto-vectorize and keep multiple adds in flight.
/// `identity` is the additive identity for `T`; the empty slice returns
/// the combined identities (i.e. `identity` for ordinary numeric zero).
#[inline(always)]
fn base_sum<T: Copy + std::ops::Add<Output = T>>(data: &[T], identity: T) -> T {
    let n = data.len();
    let mut acc0 = identity;
    let mut acc1 = identity;
    let mut acc2 = identity;
    let mut acc3 = identity;
    let mut acc4 = identity;
    let mut acc5 = identity;
    let mut acc6 = identity;
    let mut acc7 = identity;
    let chunks = n / 8;
    let rem = n % 8;
    for i in 0..chunks {
        let base = i * 8;
        acc0 = acc0 + data[base];
        acc1 = acc1 + data[base + 1];
        acc2 = acc2 + data[base + 2];
        acc3 = acc3 + data[base + 3];
        acc4 = acc4 + data[base + 4];
        acc5 = acc5 + data[base + 5];
        acc6 = acc6 + data[base + 6];
        acc7 = acc7 + data[base + 7];
    }
    // Tail (n % 8 elements) folds into acc0.
    for i in 0..rem {
        acc0 = acc0 + data[chunks * 8 + i];
    }
    // Fixed combine tree keeps the result deterministic for floats.
    (acc0 + acc1) + (acc2 + acc3) + ((acc4 + acc5) + (acc6 + acc7))
}

73/// Pairwise summation of a slice.
74///
75/// Uses an iterative carry-merge algorithm with O(ε log N) error bound,
76/// matching `NumPy`'s summation accuracy. Data is processed in 128-element
77/// base chunks, then merged pairwise using a small stack (like binary
78/// carry addition). Zero recursion overhead — stack depth is at most
79/// ceil(log2(N/128)) + 1 ≈ 20 entries.
80pub fn pairwise_sum<T>(data: &[T], identity: T) -> T
81where
82    T: Copy + std::ops::Add<Output = T>,
83{
84    let n = data.len();
85    if n == 0 {
86        return identity;
87    }
88    if n <= PAIRWISE_BASE {
89        return base_sum(data, identity);
90    }
91
92    // Iterative pairwise summation using a carry-merge stack.
93    // Each entry holds (partial_sum, level) where level is how many base
94    // chunks it represents. Same-level adjacent entries merge on push,
95    // exactly like binary carry addition. Max depth: ~24 for any realistic size.
96    let mut stack_val: [T; 24] = [identity; 24];
97    let mut stack_lvl: [usize; 24] = [0; 24];
98    let mut depth = 0usize;
99
100    let mut offset = 0;
101    while offset < n {
102        let end = (offset + PAIRWISE_BASE).min(n);
103        let mut current = base_sum(&data[offset..end], identity);
104        offset = end;
105
106        // Push and carry-merge: merge with stack top while levels match
107        let mut level = 1usize;
108        while depth > 0 && stack_lvl[depth - 1] == level {
109            depth -= 1;
110            current = stack_val[depth] + current;
111            level += 1;
112        }
113        stack_val[depth] = current;
114        stack_lvl[depth] = level;
115        depth += 1;
116    }
117
118    // Merge remaining stack entries
119    let mut result = stack_val[depth - 1];
120    for i in (0..depth - 1).rev() {
121        result = stack_val[i] + result;
122    }
123    result
124}
125
126// ---------------------------------------------------------------------------
127// SIMD-accelerated pairwise sum for f64
128// ---------------------------------------------------------------------------
129
130/// SIMD-accelerated pairwise summation for f64 slices.
131///
132/// Uses pulp for hardware SIMD dispatch (AVX2/SSE2/NEON) in the 128-element
133/// base case, with the same iterative carry-merge structure for O(ε log N)
134/// accuracy.
135#[must_use]
136pub fn pairwise_sum_f64(data: &[f64]) -> f64 {
137    Arch::new().dispatch(PairwiseSumF64Op { data })
138}
139
/// Dispatch payload for [`pairwise_sum_f64`]: borrows the input slice.
struct PairwiseSumF64Op<'a> {
    data: &'a [f64],
}

144impl pulp::WithSimd for PairwiseSumF64Op<'_> {
145    type Output = f64;
146
147    #[inline(always)]
148    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
149        simd_pairwise_f64(simd, self.data)
150    }
151}
152
153#[inline(always)]
154fn simd_base_sum_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
155    let n = data.len();
156    let lane_count = size_of::<S::f64s>() / size_of::<f64>();
157    let simd_end = n - (n % lane_count);
158
159    let zero = simd.splat_f64s(0.0);
160    let mut acc0 = zero;
161    let mut acc1 = zero;
162    let mut acc2 = zero;
163    let mut acc3 = zero;
164
165    // 4-accumulator unroll to saturate FPU throughput
166    let stride = lane_count * 4;
167    let unrolled_end = n - (n % stride);
168    let mut i = 0;
169    while i < unrolled_end {
170        let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
171        let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
172        let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
173        let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
174        acc0 = simd.add_f64s(acc0, v0);
175        acc1 = simd.add_f64s(acc1, v1);
176        acc2 = simd.add_f64s(acc2, v2);
177        acc3 = simd.add_f64s(acc3, v3);
178        i += stride;
179    }
180    while i + lane_count <= simd_end {
181        let v = simd.partial_load_f64s(&data[i..i + lane_count]);
182        acc0 = simd.add_f64s(acc0, v);
183        i += lane_count;
184    }
185    acc0 = simd.add_f64s(acc0, acc1);
186    acc2 = simd.add_f64s(acc2, acc3);
187    acc0 = simd.add_f64s(acc0, acc2);
188
189    // Horizontal sum: store SIMD register to temp array
190    let mut temp = [0.0f64; 8]; // max 8 lanes (AVX-512)
191    simd.partial_store_f64s(&mut temp[..lane_count], acc0);
192    let mut sum = 0.0f64;
193    for t in temp.iter().take(lane_count) {
194        sum += t;
195    }
196    // Scalar remainder
197    for &val in &data[simd_end..n] {
198        sum += val;
199    }
200    sum
201}
202
203#[inline(always)]
204fn simd_pairwise_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
205    let n = data.len();
206    if n == 0 {
207        return 0.0;
208    }
209    if n <= PAIRWISE_BASE {
210        return simd_base_sum_f64(simd, data);
211    }
212
213    let mut stack_val = [0.0f64; 24];
214    let mut stack_lvl = [0usize; 24];
215    let mut depth = 0usize;
216
217    let mut offset = 0;
218    while offset < n {
219        let end = (offset + PAIRWISE_BASE).min(n);
220        let mut current = simd_base_sum_f64(simd, &data[offset..end]);
221        offset = end;
222
223        let mut level = 1usize;
224        while depth > 0 && stack_lvl[depth - 1] == level {
225            depth -= 1;
226            current += stack_val[depth];
227            level += 1;
228        }
229        stack_val[depth] = current;
230        stack_lvl[depth] = level;
231        depth += 1;
232    }
233
234    let mut result = stack_val[depth - 1];
235    for i in (0..depth - 1).rev() {
236        result += stack_val[i];
237    }
238    result
239}
240
241// ---------------------------------------------------------------------------
242// SIMD-accelerated fused sum of squared differences for variance
243// ---------------------------------------------------------------------------
244
245/// SIMD-accelerated computation of sum((x - mean)²) for f64 slices.
246///
247/// Computes the sum of squared differences from the mean in a single pass
248/// without allocating an intermediate Vec. Uses 4 SIMD accumulators for ILP.
249#[must_use]
250pub fn simd_sum_sq_diff_f64(data: &[f64], mean: f64) -> f64 {
251    Arch::new().dispatch(SumSqDiffF64Op { data, mean })
252}
253
/// Dispatch payload for [`simd_sum_sq_diff_f64`]: input slice plus the mean.
struct SumSqDiffF64Op<'a> {
    data: &'a [f64],
    mean: f64,
}

259impl pulp::WithSimd for SumSqDiffF64Op<'_> {
260    type Output = f64;
261
262    #[inline(always)]
263    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
264        let data = self.data;
265        let n = data.len();
266        let lane_count = size_of::<S::f64s>() / size_of::<f64>();
267        let simd_end = n - (n % lane_count);
268
269        let zero = simd.splat_f64s(0.0);
270        let mean_v = simd.splat_f64s(self.mean);
271        let mut acc0 = zero;
272        let mut acc1 = zero;
273        let mut acc2 = zero;
274        let mut acc3 = zero;
275
276        let stride = lane_count * 4;
277        let unrolled_end = n - (n % stride);
278        let mut i = 0;
279        while i < unrolled_end {
280            let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
281            let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
282            let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
283            let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
284            let d0 = simd.sub_f64s(v0, mean_v);
285            let d1 = simd.sub_f64s(v1, mean_v);
286            let d2 = simd.sub_f64s(v2, mean_v);
287            let d3 = simd.sub_f64s(v3, mean_v);
288            acc0 = simd.mul_add_f64s(d0, d0, acc0);
289            acc1 = simd.mul_add_f64s(d1, d1, acc1);
290            acc2 = simd.mul_add_f64s(d2, d2, acc2);
291            acc3 = simd.mul_add_f64s(d3, d3, acc3);
292            i += stride;
293        }
294        while i + lane_count <= simd_end {
295            let v = simd.partial_load_f64s(&data[i..i + lane_count]);
296            let d = simd.sub_f64s(v, mean_v);
297            acc0 = simd.mul_add_f64s(d, d, acc0);
298            i += lane_count;
299        }
300        acc0 = simd.add_f64s(acc0, acc1);
301        acc2 = simd.add_f64s(acc2, acc3);
302        acc0 = simd.add_f64s(acc0, acc2);
303
304        let mut temp = [0.0f64; 8];
305        simd.partial_store_f64s(&mut temp[..lane_count], acc0);
306        let mut sum = 0.0f64;
307        for t in temp.iter().take(lane_count) {
308            sum += t;
309        }
310        for &val in &data[simd_end..n] {
311            let d = val - self.mean;
312            sum += d * d;
313        }
314        sum
315    }
316}
317
318// ---------------------------------------------------------------------------
319// SIMD-accelerated pairwise sum / sum-of-squared-differences for f32 (#173)
320//
321// Mirrors the f64 versions above with pulp's f32s SIMD type. Twice as
322// many lanes per register (e.g. AVX-512 f32 = 16 lanes vs f64 = 8), so
323// the constants for max-temp-buffer below need to match the widest f32
324// lane count. The base case and stack-based pairwise tree are
325// structurally identical.
326// ---------------------------------------------------------------------------
327
328/// SIMD-accelerated pairwise summation for f32 slices (#173).
329#[must_use]
330pub fn pairwise_sum_f32(data: &[f32]) -> f32 {
331    Arch::new().dispatch(PairwiseSumF32Op { data })
332}
333
/// Dispatch payload for [`pairwise_sum_f32`]: borrows the input slice.
struct PairwiseSumF32Op<'a> {
    data: &'a [f32],
}

338impl pulp::WithSimd for PairwiseSumF32Op<'_> {
339    type Output = f32;
340
341    #[inline(always)]
342    fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
343        simd_pairwise_f32(simd, self.data)
344    }
345}
346
347#[inline(always)]
348fn simd_base_sum_f32<S: pulp::Simd>(simd: S, data: &[f32]) -> f32 {
349    let n = data.len();
350    let lane_count = size_of::<S::f32s>() / size_of::<f32>();
351    let simd_end = n - (n % lane_count);
352
353    let zero = simd.splat_f32s(0.0);
354    let mut acc0 = zero;
355    let mut acc1 = zero;
356    let mut acc2 = zero;
357    let mut acc3 = zero;
358
359    let stride = lane_count * 4;
360    let unrolled_end = n - (n % stride);
361    let mut i = 0;
362    while i < unrolled_end {
363        let v0 = simd.partial_load_f32s(&data[i..i + lane_count]);
364        let v1 = simd.partial_load_f32s(&data[i + lane_count..i + lane_count * 2]);
365        let v2 = simd.partial_load_f32s(&data[i + lane_count * 2..i + lane_count * 3]);
366        let v3 = simd.partial_load_f32s(&data[i + lane_count * 3..i + stride]);
367        acc0 = simd.add_f32s(acc0, v0);
368        acc1 = simd.add_f32s(acc1, v1);
369        acc2 = simd.add_f32s(acc2, v2);
370        acc3 = simd.add_f32s(acc3, v3);
371        i += stride;
372    }
373    while i + lane_count <= simd_end {
374        let v = simd.partial_load_f32s(&data[i..i + lane_count]);
375        acc0 = simd.add_f32s(acc0, v);
376        i += lane_count;
377    }
378    acc0 = simd.add_f32s(acc0, acc1);
379    acc2 = simd.add_f32s(acc2, acc3);
380    acc0 = simd.add_f32s(acc0, acc2);
381
382    // Horizontal sum: store SIMD register to temp array. AVX-512 f32
383    // is 16 lanes wide, so size the temp accordingly.
384    let mut temp = [0.0f32; 16];
385    simd.partial_store_f32s(&mut temp[..lane_count], acc0);
386    let mut sum = 0.0f32;
387    for t in temp.iter().take(lane_count) {
388        sum += t;
389    }
390    for &val in &data[simd_end..n] {
391        sum += val;
392    }
393    sum
394}
395
396#[inline(always)]
397fn simd_pairwise_f32<S: pulp::Simd>(simd: S, data: &[f32]) -> f32 {
398    let n = data.len();
399    if n == 0 {
400        return 0.0;
401    }
402    if n <= PAIRWISE_BASE {
403        return simd_base_sum_f32(simd, data);
404    }
405
406    let mut stack_val = [0.0f32; 24];
407    let mut stack_lvl = [0usize; 24];
408    let mut depth = 0usize;
409
410    let mut offset = 0;
411    while offset < n {
412        let end = (offset + PAIRWISE_BASE).min(n);
413        let mut current = simd_base_sum_f32(simd, &data[offset..end]);
414        offset = end;
415
416        let mut level = 1usize;
417        while depth > 0 && stack_lvl[depth - 1] == level {
418            depth -= 1;
419            current += stack_val[depth];
420            level += 1;
421        }
422        stack_val[depth] = current;
423        stack_lvl[depth] = level;
424        depth += 1;
425    }
426
427    let mut result = stack_val[depth - 1];
428    for i in (0..depth - 1).rev() {
429        result += stack_val[i];
430    }
431    result
432}
433
434/// SIMD-accelerated computation of sum((x - mean)²) for f32 slices (#173).
435#[must_use]
436pub fn simd_sum_sq_diff_f32(data: &[f32], mean: f32) -> f32 {
437    Arch::new().dispatch(SumSqDiffF32Op { data, mean })
438}
439
/// Dispatch payload for [`simd_sum_sq_diff_f32`]: input slice plus the mean.
struct SumSqDiffF32Op<'a> {
    data: &'a [f32],
    mean: f32,
}

445impl pulp::WithSimd for SumSqDiffF32Op<'_> {
446    type Output = f32;
447
448    #[inline(always)]
449    fn with_simd<S: pulp::Simd>(self, simd: S) -> f32 {
450        let data = self.data;
451        let n = data.len();
452        let lane_count = size_of::<S::f32s>() / size_of::<f32>();
453        let simd_end = n - (n % lane_count);
454
455        let zero = simd.splat_f32s(0.0);
456        let mean_v = simd.splat_f32s(self.mean);
457        let mut acc0 = zero;
458        let mut acc1 = zero;
459        let mut acc2 = zero;
460        let mut acc3 = zero;
461
462        let stride = lane_count * 4;
463        let unrolled_end = n - (n % stride);
464        let mut i = 0;
465        while i < unrolled_end {
466            let v0 = simd.partial_load_f32s(&data[i..i + lane_count]);
467            let v1 = simd.partial_load_f32s(&data[i + lane_count..i + lane_count * 2]);
468            let v2 = simd.partial_load_f32s(&data[i + lane_count * 2..i + lane_count * 3]);
469            let v3 = simd.partial_load_f32s(&data[i + lane_count * 3..i + stride]);
470            let d0 = simd.sub_f32s(v0, mean_v);
471            let d1 = simd.sub_f32s(v1, mean_v);
472            let d2 = simd.sub_f32s(v2, mean_v);
473            let d3 = simd.sub_f32s(v3, mean_v);
474            acc0 = simd.mul_add_f32s(d0, d0, acc0);
475            acc1 = simd.mul_add_f32s(d1, d1, acc1);
476            acc2 = simd.mul_add_f32s(d2, d2, acc2);
477            acc3 = simd.mul_add_f32s(d3, d3, acc3);
478            i += stride;
479        }
480        while i + lane_count <= simd_end {
481            let v = simd.partial_load_f32s(&data[i..i + lane_count]);
482            let d = simd.sub_f32s(v, mean_v);
483            acc0 = simd.mul_add_f32s(d, d, acc0);
484            i += lane_count;
485        }
486        acc0 = simd.add_f32s(acc0, acc1);
487        acc2 = simd.add_f32s(acc2, acc3);
488        acc0 = simd.add_f32s(acc0, acc2);
489
490        let mut temp = [0.0f32; 16];
491        simd.partial_store_f32s(&mut temp[..lane_count], acc0);
492        let mut sum = 0.0f32;
493        for t in temp.iter().take(lane_count) {
494            sum += t;
495        }
496        for &val in &data[simd_end..n] {
497            let d = val - self.mean;
498            sum += d * d;
499        }
500        sum
501    }
502}
503
504// ---------------------------------------------------------------------------
505// Parallel dispatch
506// ---------------------------------------------------------------------------
507
508/// Perform a parallel pairwise sum on a slice.
509///
510/// Uses rayon tree-reduce for large slices (which is inherently pairwise),
511/// and iterative pairwise summation for smaller slices.
512pub fn parallel_sum<T>(data: &[T], identity: T) -> T
513where
514    T: Copy + Send + Sync + std::ops::Add<Output = T>,
515{
516    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
517        // Rayon's reduce is a tree-reduce, which gives pairwise-like accuracy
518        data.par_iter().copied().reduce(|| identity, |a, b| a + b)
519    } else {
520        pairwise_sum(data, identity)
521    }
522}
523
524/// Perform a parallel product on a slice of `Copy + Send + Sync` values.
525pub fn parallel_prod<T>(data: &[T], identity: T) -> T
526where
527    T: Copy + Send + Sync + std::ops::Mul<Output = T>,
528{
529    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
530        data.par_iter().copied().reduce(|| identity, |a, b| a * b)
531    } else {
532        data.iter().copied().fold(identity, |a, b| a * b)
533    }
534}
535
536/// Parallel sort (unstable) for large slices. Returns a sorted copy.
537pub fn parallel_sort<T>(data: &mut [T])
538where
539    T: Copy + Send + Sync + PartialOrd,
540{
541    if data.len() >= PARALLEL_SORT_THRESHOLD {
542        data.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
543    } else {
544        data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
545    }
546}
547
548/// Parallel stable sort for large slices.
549pub fn parallel_sort_stable<T>(data: &mut [T])
550where
551    T: Copy + Send + Sync + PartialOrd,
552{
553    if data.len() >= PARALLEL_SORT_THRESHOLD {
554        data.par_sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
555    } else {
556        data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
557    }
558}