// ferray_stats/parallel.rs
// ferray-stats: Rayon threshold dispatch for reductions and sorting (REQ-19, REQ-20)

use pulp::Arch;
use rayon::prelude::*;
/// Threshold above which reductions use parallel tree-reduce.
///
/// Set high to avoid rayon thread-pool startup overhead dominating small workloads.
/// Rayon's thread pool initialization costs ~1-3ms on first use, which is catastrophic
/// for arrays under ~1M elements where sequential pairwise sum takes <1ms.
pub const PARALLEL_REDUCTION_THRESHOLD: usize = 1_000_000;
/// Threshold above which sorting uses parallel merge sort.
pub const PARALLEL_SORT_THRESHOLD: usize = 100_000;
/// Pairwise summation base case threshold.
///
/// Below this size, we use an unrolled sequential sum. 256 elements gives a
/// good balance: halves carry-merge overhead vs 128 while keeping the same
/// O(ε log N) accuracy bound (just one fewer merge level).
const PAIRWISE_BASE: usize = 256;
/// 8-wide unrolled base-case sum for auto-vectorization.
///
/// Eight independent accumulators break the add dependency chain so the
/// compiler can vectorize/pipeline the loop. `chunks_exact(8)` gives the
/// optimizer a fixed-size window, eliminating per-element bounds checks.
/// The trailing `n % 8` elements fold into lane 0, preserving input order.
#[inline(always)]
fn base_sum<T: Copy + std::ops::Add<Output = T>>(data: &[T], identity: T) -> T {
    let mut acc = [identity; 8];
    let mut blocks = data.chunks_exact(8);
    for block in blocks.by_ref() {
        // 8 independent adds per iteration; lane i always sees element i of
        // each block, matching the original unrolled accumulation order.
        for (lane, &x) in acc.iter_mut().zip(block) {
            *lane = *lane + x;
        }
    }
    // Remainder accumulates into lane 0, in order.
    for &x in blocks.remainder() {
        acc[0] = acc[0] + x;
    }
    // Balanced pairwise combine of the eight lanes (better float accuracy
    // than a linear fold).
    (acc[0] + acc[1]) + (acc[2] + acc[3]) + ((acc[4] + acc[5]) + (acc[6] + acc[7]))
}
54/// Pairwise summation of a slice.
55///
56/// Uses an iterative carry-merge algorithm with O(ε log N) error bound,
57/// matching NumPy's summation accuracy. Data is processed in 128-element
58/// base chunks, then merged pairwise using a small stack (like binary
59/// carry addition). Zero recursion overhead — stack depth is at most
60/// ceil(log2(N/128)) + 1 ≈ 20 entries.
61pub fn pairwise_sum<T>(data: &[T], identity: T) -> T
62where
63    T: Copy + std::ops::Add<Output = T>,
64{
65    let n = data.len();
66    if n == 0 {
67        return identity;
68    }
69    if n <= PAIRWISE_BASE {
70        return base_sum(data, identity);
71    }
72
73    // Iterative pairwise summation using a carry-merge stack.
74    // Each entry holds (partial_sum, level) where level is how many base
75    // chunks it represents. Same-level adjacent entries merge on push,
76    // exactly like binary carry addition. Max depth: ~24 for any realistic size.
77    let mut stack_val: [T; 24] = [identity; 24];
78    let mut stack_lvl: [usize; 24] = [0; 24];
79    let mut depth = 0usize;
80
81    let mut offset = 0;
82    while offset < n {
83        let end = (offset + PAIRWISE_BASE).min(n);
84        let mut current = base_sum(&data[offset..end], identity);
85        offset = end;
86
87        // Push and carry-merge: merge with stack top while levels match
88        let mut level = 1usize;
89        while depth > 0 && stack_lvl[depth - 1] == level {
90            depth -= 1;
91            current = stack_val[depth] + current;
92            level += 1;
93        }
94        stack_val[depth] = current;
95        stack_lvl[depth] = level;
96        depth += 1;
97    }
98
99    // Merge remaining stack entries
100    let mut result = stack_val[depth - 1];
101    for i in (0..depth - 1).rev() {
102        result = stack_val[i] + result;
103    }
104    result
105}
106
107// ---------------------------------------------------------------------------
108// SIMD-accelerated pairwise sum for f64
109// ---------------------------------------------------------------------------
110
111/// SIMD-accelerated pairwise summation for f64 slices.
112///
113/// Uses pulp for hardware SIMD dispatch (AVX2/SSE2/NEON) in the 128-element
114/// base case, with the same iterative carry-merge structure for O(ε log N)
115/// accuracy.
116pub fn pairwise_sum_f64(data: &[f64]) -> f64 {
117    Arch::new().dispatch(PairwiseSumF64Op { data })
118}
119
/// Operand capture for `pairwise_sum_f64`'s pulp dispatch; the actual work
/// happens in this type's `WithSimd` implementation.
struct PairwiseSumF64Op<'a> {
    data: &'a [f64],
}
124impl pulp::WithSimd for PairwiseSumF64Op<'_> {
125    type Output = f64;
126
127    #[inline(always)]
128    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
129        simd_pairwise_f64(simd, self.data)
130    }
131}
132
133#[inline(always)]
134fn simd_base_sum_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
135    let n = data.len();
136    let lane_count = size_of::<S::f64s>() / size_of::<f64>();
137    let simd_end = n - (n % lane_count);
138
139    let zero = simd.splat_f64s(0.0);
140    let mut acc0 = zero;
141    let mut acc1 = zero;
142    let mut acc2 = zero;
143    let mut acc3 = zero;
144
145    // 4-accumulator unroll to saturate FPU throughput
146    let stride = lane_count * 4;
147    let unrolled_end = n - (n % stride);
148    let mut i = 0;
149    while i < unrolled_end {
150        let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
151        let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
152        let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
153        let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
154        acc0 = simd.add_f64s(acc0, v0);
155        acc1 = simd.add_f64s(acc1, v1);
156        acc2 = simd.add_f64s(acc2, v2);
157        acc3 = simd.add_f64s(acc3, v3);
158        i += stride;
159    }
160    while i + lane_count <= simd_end {
161        let v = simd.partial_load_f64s(&data[i..i + lane_count]);
162        acc0 = simd.add_f64s(acc0, v);
163        i += lane_count;
164    }
165    acc0 = simd.add_f64s(acc0, acc1);
166    acc2 = simd.add_f64s(acc2, acc3);
167    acc0 = simd.add_f64s(acc0, acc2);
168
169    // Horizontal sum: store SIMD register to temp array
170    let mut temp = [0.0f64; 8]; // max 8 lanes (AVX-512)
171    simd.partial_store_f64s(&mut temp[..lane_count], acc0);
172    let mut sum = 0.0f64;
173    for t in temp.iter().take(lane_count) {
174        sum += t;
175    }
176    // Scalar remainder
177    for &val in &data[simd_end..n] {
178        sum += val;
179    }
180    sum
181}
182
183#[inline(always)]
184fn simd_pairwise_f64<S: pulp::Simd>(simd: S, data: &[f64]) -> f64 {
185    let n = data.len();
186    if n == 0 {
187        return 0.0;
188    }
189    if n <= PAIRWISE_BASE {
190        return simd_base_sum_f64(simd, data);
191    }
192
193    let mut stack_val = [0.0f64; 24];
194    let mut stack_lvl = [0usize; 24];
195    let mut depth = 0usize;
196
197    let mut offset = 0;
198    while offset < n {
199        let end = (offset + PAIRWISE_BASE).min(n);
200        let mut current = simd_base_sum_f64(simd, &data[offset..end]);
201        offset = end;
202
203        let mut level = 1usize;
204        while depth > 0 && stack_lvl[depth - 1] == level {
205            depth -= 1;
206            current += stack_val[depth];
207            level += 1;
208        }
209        stack_val[depth] = current;
210        stack_lvl[depth] = level;
211        depth += 1;
212    }
213
214    let mut result = stack_val[depth - 1];
215    for i in (0..depth - 1).rev() {
216        result += stack_val[i];
217    }
218    result
219}
220
221// ---------------------------------------------------------------------------
222// SIMD-accelerated fused sum of squared differences for variance
223// ---------------------------------------------------------------------------
224
225/// SIMD-accelerated computation of sum((x - mean)²) for f64 slices.
226///
227/// Computes the sum of squared differences from the mean in a single pass
228/// without allocating an intermediate Vec. Uses 4 SIMD accumulators for ILP.
229pub fn simd_sum_sq_diff_f64(data: &[f64], mean: f64) -> f64 {
230    Arch::new().dispatch(SumSqDiffF64Op { data, mean })
231}
232
/// Operand capture for `simd_sum_sq_diff_f64`'s pulp dispatch; the work
/// happens in this type's `WithSimd` implementation.
struct SumSqDiffF64Op<'a> {
    data: &'a [f64],
    mean: f64,
}
238impl pulp::WithSimd for SumSqDiffF64Op<'_> {
239    type Output = f64;
240
241    #[inline(always)]
242    fn with_simd<S: pulp::Simd>(self, simd: S) -> f64 {
243        let data = self.data;
244        let n = data.len();
245        let lane_count = size_of::<S::f64s>() / size_of::<f64>();
246        let simd_end = n - (n % lane_count);
247
248        let zero = simd.splat_f64s(0.0);
249        let mean_v = simd.splat_f64s(self.mean);
250        let mut acc0 = zero;
251        let mut acc1 = zero;
252        let mut acc2 = zero;
253        let mut acc3 = zero;
254
255        let stride = lane_count * 4;
256        let unrolled_end = n - (n % stride);
257        let mut i = 0;
258        while i < unrolled_end {
259            let v0 = simd.partial_load_f64s(&data[i..i + lane_count]);
260            let v1 = simd.partial_load_f64s(&data[i + lane_count..i + lane_count * 2]);
261            let v2 = simd.partial_load_f64s(&data[i + lane_count * 2..i + lane_count * 3]);
262            let v3 = simd.partial_load_f64s(&data[i + lane_count * 3..i + stride]);
263            let d0 = simd.sub_f64s(v0, mean_v);
264            let d1 = simd.sub_f64s(v1, mean_v);
265            let d2 = simd.sub_f64s(v2, mean_v);
266            let d3 = simd.sub_f64s(v3, mean_v);
267            acc0 = simd.mul_add_f64s(d0, d0, acc0);
268            acc1 = simd.mul_add_f64s(d1, d1, acc1);
269            acc2 = simd.mul_add_f64s(d2, d2, acc2);
270            acc3 = simd.mul_add_f64s(d3, d3, acc3);
271            i += stride;
272        }
273        while i + lane_count <= simd_end {
274            let v = simd.partial_load_f64s(&data[i..i + lane_count]);
275            let d = simd.sub_f64s(v, mean_v);
276            acc0 = simd.mul_add_f64s(d, d, acc0);
277            i += lane_count;
278        }
279        acc0 = simd.add_f64s(acc0, acc1);
280        acc2 = simd.add_f64s(acc2, acc3);
281        acc0 = simd.add_f64s(acc0, acc2);
282
283        let mut temp = [0.0f64; 8];
284        simd.partial_store_f64s(&mut temp[..lane_count], acc0);
285        let mut sum = 0.0f64;
286        for t in temp.iter().take(lane_count) {
287            sum += t;
288        }
289        for &val in &data[simd_end..n] {
290            let d = val - self.mean;
291            sum += d * d;
292        }
293        sum
294    }
295}
296
297// ---------------------------------------------------------------------------
298// Parallel dispatch
299// ---------------------------------------------------------------------------
300
301/// Perform a parallel pairwise sum on a slice.
302///
303/// Uses rayon tree-reduce for large slices (which is inherently pairwise),
304/// and iterative pairwise summation for smaller slices.
305pub fn parallel_sum<T>(data: &[T], identity: T) -> T
306where
307    T: Copy + Send + Sync + std::ops::Add<Output = T>,
308{
309    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
310        // Rayon's reduce is a tree-reduce, which gives pairwise-like accuracy
311        data.par_iter().copied().reduce(|| identity, |a, b| a + b)
312    } else {
313        pairwise_sum(data, identity)
314    }
315}
316
317/// Perform a parallel product on a slice of `Copy + Send + Sync` values.
318pub fn parallel_prod<T>(data: &[T], identity: T) -> T
319where
320    T: Copy + Send + Sync + std::ops::Mul<Output = T>,
321{
322    if data.len() >= PARALLEL_REDUCTION_THRESHOLD {
323        data.par_iter().copied().reduce(|| identity, |a, b| a * b)
324    } else {
325        data.iter().copied().fold(identity, |a, b| a * b)
326    }
327}
328
329/// Parallel sort (unstable) for large slices. Returns a sorted copy.
330pub fn parallel_sort<T>(data: &mut [T])
331where
332    T: Copy + Send + Sync + PartialOrd,
333{
334    if data.len() >= PARALLEL_SORT_THRESHOLD {
335        data.par_sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
336    } else {
337        data.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
338    }
339}
340
341/// Parallel stable sort for large slices.
342pub fn parallel_sort_stable<T>(data: &mut [T])
343where
344    T: Copy + Send + Sync + PartialOrd,
345{
346    if data.len() >= PARALLEL_SORT_THRESHOLD {
347        data.par_sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
348    } else {
349        data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
350    }
351}