Skip to main content

ferray_stats/
parallel.rs

1// ferray-stats: Rayon threshold dispatch for reductions and sorting (REQ-19, REQ-20)
2
3// `with_simd`, `simd_pairwise_f64`, `simd_base_sum_f64`, and `base_sum`
4// are pulp dispatch entry points where `#[inline(always)]` is required —
5// without it the SIMD instruction set selected at the call site fails to
6// inline into the kernel and dispatch becomes a no-op.
7#![allow(clippy::inline_always)]
8
9use pulp::Arch;
10use rayon::prelude::*;
11
/// Minimum slice length at which [`parallel_sum`] and [`parallel_prod`] hand the
/// reduction to rayon; shorter slices are reduced on the calling thread.
///
/// Callers compare with `len >= PARALLEL_REDUCTION_THRESHOLD`, so a slice of
/// exactly this length already takes the parallel path.
///
/// The value is deliberately high. Dispatching to rayon has a fixed cost: thread-pool
/// start-up on first use was measured at roughly 1-3 ms. That cost would dominate for
/// arrays below ~1M elements, where the sequential pairwise sum finishes in under a
/// millisecond.
pub const PARALLEL_REDUCTION_THRESHOLD: usize = 1_000_000;
18
/// Minimum slice length at which [`parallel_sort`] and [`parallel_sort_stable`]
/// switch from the standard-library sorts to rayon's parallel sorts
/// (`par_sort_unstable_by` / `par_sort_by`).
///
/// Callers compare with `len >= PARALLEL_SORT_THRESHOLD`.
pub const PARALLEL_SORT_THRESHOLD: usize = 100_000;
21
/// Base-chunk size for pairwise summation.
///
/// [`pairwise_sum`] and the SIMD `f64` path split their input into chunks of this
/// many elements. Each chunk is summed with a flat, multi-accumulator inner loop:
/// 8 scalar accumulators in the generic path and 4 SIMD accumulators in the `f64`
/// path. The pairwise carry-merge tree only operates *across* chunks. This constant
/// therefore trades off:
///
/// * horizontal-reduction / merge overhead, paid once per chunk, against
/// * rounding error. Inside a chunk it grows roughly linearly with the chunk length;
///   across the tree it grows as O(ε · log(N / PAIRWISE_BASE)).
///
/// Compared with the earlier value of 256, 4096 cuts the per-chunk overhead 16×.
/// For N = 1M that is about 245 chunks instead of about 3900. In the original
/// benchmarks this was the difference between being add-throughput-bound
/// (~242 µs at 1M) and approaching NumPy's memory-bandwidth floor (~154 µs at 1M).
const PAIRWISE_BASE: usize = 4096;
41
/// Sums `data` starting from `identity`, spreading the work over eight
/// independent running totals so the compiler can vectorize the loop.
///
/// Element `j` of every full group of eight is added to total `j`. Leftover
/// elements (fewer than eight) are folded into total 0. The eight totals are then
/// combined as a balanced tree: `((t0+t1)+(t2+t3)) + ((t4+t5)+(t6+t7))`.
#[inline(always)]
fn base_sum<T: Copy + std::ops::Add<Output = T>>(data: &[T], identity: T) -> T {
    let groups = data.chunks_exact(8);
    let leftovers = groups.remainder();

    let mut totals = [identity; 8];
    for group in groups {
        for (total, &value) in totals.iter_mut().zip(group) {
            *total = *total + value;
        }
    }
    for &value in leftovers {
        totals[0] = totals[0] + value;
    }

    let [t0, t1, t2, t3, t4, t5, t6, t7] = totals;
    (t0 + t1) + (t2 + t3) + ((t4 + t5) + (t6 + t7))
}
72
73/// Pairwise summation of a slice.
74///
75/// Uses an iterative carry-merge algorithm with O(ε log N) error bound,
76/// matching `NumPy`'s summation accuracy. Data is processed in 128-element
77/// base chunks, then merged pairwise using a small stack (like binary
78/// carry addition). Zero recursion overhead — stack depth is at most
79/// ceil(log2(N/128)) + 1 ≈ 20 entries.
80pub fn pairwise_sum<T>(data: &[T], identity: T) -> T
81where
82    T: Copy + std::ops::Add<Output = T>,
83{
84    let n = data.len();
85    if n == 0 {
86        return identity;
87    }
88    if n <= PAIRWISE_BASE {
89        return base_sum(data, identity);
90    }
91
92    // Iterative pairwise summation using a carry-merge stack.
93    // Each entry holds (partial_sum, level) where level is how many base
94    // chunks it represents. Same-level adjacent entries merge on push,
95    // exactly like binary carry addition. Max depth: ~24 for any realistic size.
96    let mut stack_val: [T; 24] = [identity; 24];
97    let mut stack_lvl: [usize; 24] = [0; 24];
98    let mut depth = 0usize;
99
100    let mut offset = 0;
101    while offset < n {
102        let end = (offset + PAIRWISE_BASE).min(n);
103        let mut current = base_sum(&data[offset..end], identity);
104        offset = end;
105
106        // Push and carry-merge: merge with stack top while levels match
107        let mut level = 1usize;
108        while depth > 0 && stack_lvl[depth - 1] == level {
109            depth -= 1;
110            current = stack_val[depth] + current;
111            level += 1;
112        }
113        stack_val[depth] = current;
114        stack_lvl[depth] = level;
115        depth += 1;
116    }
117
118    // Merge remaining stack entries
119    let mut result = stack_val[depth - 1];
120    for i in (0..depth - 1).rev() {
121        result = stack_val[i] + result;
122    }
123    result
124}
125
126// ---------------------------------------------------------------------------
127// SIMD-accelerated pairwise sum for f64
128// ---------------------------------------------------------------------------
129
130/// SIMD-accelerated pairwise summation for f64 slices.
131///
132/// Uses pulp for hardware SIMD dispatch (AVX2/SSE2/NEON) in the 128-element
133/// base case, with the same iterative carry-merge structure for O(ε log N)
134/// accuracy.
135#[must_use]
136pub fn pairwise_sum_f64(data: &[f64]) -> f64 {
137    Arch::new().dispatch(PairwiseSumF64Op { data })
138}
139
140struct PairwiseSumF64Op<'a> {
141    data: &'a [f64],
142}
143
144impl pulp::WithSimd for PairwiseSumF64Op<'_> {
145    type Output = f64;
146
147    #[inline(always)]
148    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
149        simd_pairwise_f64(simd, self.data)
150    }
151}
152
153#[inline(always)]
154fn simd_base_sum_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
155    let n = data.len();
156    let lane_count = size_of::<S::f64s>() / size_of::<f64>();
157    let simd_end = n - (n % lane_count);
158
159    let zero = simd.splat_f64s(0.0);
160    let mut acc0 = zero;
161    let mut acc1 = zero;
162    let mut acc2 = zero;
163    let mut acc3 = zero;
164
165    // 4-accumulator unroll to saturate FPU throughput
166    let stride = lane_count * 4;
167    let unrolled_end = n - (n % stride);
168    let mut i = 0;
169    while i < unrolled_end {
170        let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
171        let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
172        let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
173        let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
174        acc0 = simd.add_f64s(acc0, v0);
175        acc1 = simd.add_f64s(acc1, v1);
176        acc2 = simd.add_f64s(acc2, v2);
177        acc3 = simd.add_f64s(acc3, v3);
178        i += stride;
179    }
180    while i + lane_count <= simd_end {
181        let v = simd.partial_load_f64s(&data[i..i + lane_count]);
182        acc0 = simd.add_f64s(acc0, v);
183        i += lane_count;
184    }
185    acc0 = simd.add_f64s(acc0, acc1);
186    acc2 = simd.add_f64s(acc2, acc3);
187    acc0 = simd.add_f64s(acc0, acc2);
188
189    // Horizontal sum: store SIMD register to temp array
190    let mut temp = [0.0f64; 8]; // max 8 lanes (AVX-512)
191    simd.partial_store_f64s(&mut temp[..lane_count], acc0);
192    let mut sum = 0.0f64;
193    for t in temp.iter().take(lane_count) {
194        sum += t;
195    }
196    // Scalar remainder
197    for &val in &data[simd_end..n] {
198        sum += val;
199    }
200    sum
201}
202
203#[inline(always)]
204fn simd_pairwise_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
205    let n = data.len();
206    if n == 0 {
207        return 0.0;
208    }
209    if n <= PAIRWISE_BASE {
210        return simd_base_sum_f64(simd, data);
211    }
212
213    let mut stack_val = [0.0f64; 24];
214    let mut stack_lvl = [0usize; 24];
215    let mut depth = 0usize;
216
217    let mut offset = 0;
218    while offset < n {
219        let end = (offset + PAIRWISE_BASE).min(n);
220        let mut current = simd_base_sum_f64(simd, &data[offset..end]);
221        offset = end;
222
223        let mut level = 1usize;
224        while depth > 0 && stack_lvl[depth - 1] == level {
225            depth -= 1;
226            current += stack_val[depth];
227            level += 1;
228        }
229        stack_val[depth] = current;
230        stack_lvl[depth] = level;
231        depth += 1;
232    }
233
234    let mut result = stack_val[depth - 1];
235    for i in (0..depth - 1).rev() {
236        result += stack_val[i];
237    }
238    result
239}
240
241// ---------------------------------------------------------------------------
242// SIMD-accelerated fused sum of squared differences for variance
243// ---------------------------------------------------------------------------
244
245/// SIMD-accelerated computation of sum((x - mean)²) for f64 slices.
246///
247/// Computes the sum of squared differences from the mean in a single pass
248/// without allocating an intermediate Vec. Uses 4 SIMD accumulators for ILP.
249#[must_use]
250pub fn simd_sum_sq_diff_f64(data: &[f64], mean: f64) -> f64 {
251    Arch::new().dispatch(SumSqDiffF64Op { data, mean })
252}
253
254struct SumSqDiffF64Op<'a> {
255    data: &'a [f64],
256    mean: f64,
257}
258
259impl pulp::WithSimd for SumSqDiffF64Op<'_> {
260    type Output = f64;
261
262    #[inline(always)]
263    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
264        let data = self.data;
265        let n = data.len();
266        let lane_count = size_of::<S::f64s>() / size_of::<f64>();
267        let simd_end = n - (n % lane_count);
268
269        let zero = simd.splat_f64s(0.0);
270        let mean_v = simd.splat_f64s(self.mean);
271        let mut acc0 = zero;
272        let mut acc1 = zero;
273        let mut acc2 = zero;
274        let mut acc3 = zero;
275
276        let stride = lane_count * 4;
277        let unrolled_end = n - (n % stride);
278        let mut i = 0;
279        while i < unrolled_end {
280            let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
281            let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
282            let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
283            let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
284            let d0 = simd.sub_f64s(v0, mean_v);
285            let d1 = simd.sub_f64s(v1, mean_v);
286            let d2 = simd.sub_f64s(v2, mean_v);
287            let d3 = simd.sub_f64s(v3, mean_v);
288            acc0 = simd.mul_add_f64s(d0, d0, acc0);
289            acc1 = simd.mul_add_f64s(d1, d1, acc1);
290            acc2 = simd.mul_add_f64s(d2, d2, acc2);
291            acc3 = simd.mul_add_f64s(d3, d3, acc3);
292            i += stride;
293        }
294        while i + lane_count <= simd_end {
295            let v = simd.partial_load_f64s(&data[i..i + lane_count]);
296            let d = simd.sub_f64s(v, mean_v);
297            acc0 = simd.mul_add_f64s(d, d, acc0);
298            i += lane_count;
299        }
300        acc0 = simd.add_f64s(acc0, acc1);
301        acc2 = simd.add_f64s(acc2, acc3);
302        acc0 = simd.add_f64s(acc0, acc2);
303
304        let mut temp = [0.0f64; 8];
305        simd.partial_store_f64s(&mut temp[..lane_count], acc0);
306        let mut sum = 0.0f64;
307        for t in temp.iter().take(lane_count) {
308            sum += t;
309        }
310        for &val in &data[simd_end..n] {
311            let d = val - self.mean;
312            sum += d * d;
313        }
314        sum
315    }
316}
317
318// ---------------------------------------------------------------------------
319// Parallel dispatch
320// ---------------------------------------------------------------------------
321
322/// Perform a parallel pairwise sum on a slice.
323///
324/// Uses rayon tree-reduce for large slices (which is inherently pairwise),
325/// and iterative pairwise summation for smaller slices.
326pub fn parallel_sum<T>(data: &[T], identity: T) -> T
327where
328    T: Copy + Send + Sync + std::ops::Add<Output = T>,
329{
330    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
331        // Rayon's reduce is a tree-reduce, which gives pairwise-like accuracy
332        data.par_iter().copied().reduce(|| identity, |a, b| a + b)
333    } else {
334        pairwise_sum(data, identity)
335    }
336}
337
338/// Perform a parallel product on a slice of `Copy + Send + Sync` values.
339pub fn parallel_prod<T>(data: &[T], identity: T) -> T
340where
341    T: Copy + Send + Sync + std::ops::Mul<Output = T>,
342{
343    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
344        data.par_iter().copied().reduce(|| identity, |a, b| a * b)
345    } else {
346        data.iter().copied().fold(identity, |a, b| a * b)
347    }
348}
349
350/// Parallel sort (unstable) for large slices. Returns a sorted copy.
351pub fn parallel_sort<T>(data: &mut [T])
352where
353    T: Copy + Send + Sync + PartialOrd,
354{
355    if data.len() >= PARALLEL_SORT_THRESHOLD {
356        data.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
357    } else {
358        data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
359    }
360}
361
362/// Parallel stable sort for large slices.
363pub fn parallel_sort_stable<T>(data: &mut [T])
364where
365    T: Copy + Send + Sync + PartialOrd,
366{
367    if data.len() >= PARALLEL_SORT_THRESHOLD {
368        data.par_sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
369    } else {
370        data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
371    }
372}