// scirs2_stats — parallel_processing.rs
1//! Expanded parallel processing for statistical computations
2//!
3//! This module provides parallelized implementations of:
4//! - Descriptive statistics (Welford's algorithm, parallel quantiles, histograms)
5//! - Hypothesis testing (permutation tests, bootstrap, cross-validation)
6//! - Distribution fitting (parallel MLE, grid search)
7//!
8//! All parallel code is feature-gated behind `cfg(feature = "parallel")` (which scirs2-core
9//! enables via `scirs2-core/parallel = ["rayon"]`).
10
11use crate::error::{StatsError, StatsResult};
12use scirs2_core::ndarray::{Array1, Array2};
13use scirs2_core::parallel_ops::{num_threads, par_chunks, parallel_map, ParallelIterator};
14
15// ---------------------------------------------------------------------------
16// Configuration
17// ---------------------------------------------------------------------------
18
/// Threshold for switching to parallel execution.
///
/// Inputs shorter than this are processed sequentially, since thread
/// scheduling and merge overhead would dominate the actual computation.
const PAR_THRESHOLD: usize = 5_000;
21
22// ===========================================================================
23// Part 1: Parallel Descriptive Statistics
24// ===========================================================================
25
26// ---------------------------------------------------------------------------
27// Welford accumulators (online, mergeable)
28// ---------------------------------------------------------------------------
29
30/// Mergeable accumulator for mean, variance, skewness, kurtosis via Welford's method.
31///
32/// Uses the parallel variant from Chan, Golub & LeVeque (1979).
/// Mergeable accumulator for mean, variance, skewness, kurtosis via Welford's method.
///
/// Single observations are folded in with [`WelfordAccumulator::push`]; two
/// accumulators built from disjoint chunks can be combined with
/// [`WelfordAccumulator::merge`] using the pairwise update formulas of
/// Chan, Golub & LeVeque (1979), making the computation both numerically
/// stable and embarrassingly parallel.
#[derive(Debug, Clone)]
pub struct WelfordAccumulator {
    /// Number of observations
    pub n: u64,
    /// Running mean
    pub mean: f64,
    /// Sum of squared deviations from mean (M2)
    pub m2: f64,
    /// Third central moment accumulator (M3)
    pub m3: f64,
    /// Fourth central moment accumulator (M4)
    pub m4: f64,
}

impl Default for WelfordAccumulator {
    fn default() -> Self {
        Self::new()
    }
}

impl WelfordAccumulator {
    /// Create an empty accumulator (zero observations, all moments zero).
    pub fn new() -> Self {
        WelfordAccumulator {
            n: 0,
            mean: 0.0,
            m2: 0.0,
            m3: 0.0,
            m4: 0.0,
        }
    }

    /// Fold a single observation into the running moments.
    pub fn push(&mut self, x: f64) {
        let prev_count = self.n;
        self.n += 1;
        let count = self.n as f64;
        let dev = x - self.mean;
        let dev_n = dev / count;
        let dev_n_sq = dev_n * dev_n;
        let scaled = dev * dev_n * prev_count as f64;

        // Update order matters: M4 uses the old M3/M2, M3 uses the old M2,
        // and the mean is advanced last.
        self.m4 += scaled * dev_n_sq * (count * count - 3.0 * count + 3.0)
            + 6.0 * dev_n_sq * self.m2
            - 4.0 * dev_n * self.m3;
        self.m3 += scaled * dev_n * (count - 2.0) - 3.0 * dev_n * self.m2;
        self.m2 += scaled;
        self.mean += dev_n;
    }

    /// Merge another accumulator into this one (parallel combine step).
    ///
    /// Implements the pairwise combination formulas from Chan, Golub & LeVeque.
    pub fn merge(&mut self, other: &WelfordAccumulator) {
        // Merging an empty accumulator is a no-op; merging *into* an empty
        // one is a plain copy.
        if other.n == 0 {
            return;
        }
        if self.n == 0 {
            *self = other.clone();
            return;
        }

        let na = self.n as f64;
        let nb = other.n as f64;
        let n_total = na + nb;

        let delta = other.mean - self.mean;
        let delta2 = delta * delta;
        let delta3 = delta2 * delta;
        let delta4 = delta2 * delta2;

        let merged_mean = (na * self.mean + nb * other.mean) / n_total;

        let merged_m2 = self.m2 + other.m2 + delta2 * na * nb / n_total;

        let merged_m3 = self.m3
            + other.m3
            + delta3 * na * nb * (na - nb) / (n_total * n_total)
            + 3.0 * delta * (na * other.m2 - nb * self.m2) / n_total;

        let merged_m4 = self.m4
            + other.m4
            + delta4 * na * nb * (na * na - na * nb + nb * nb) / (n_total * n_total * n_total)
            + 6.0 * delta2 * (na * na * other.m2 + nb * nb * self.m2) / (n_total * n_total)
            + 4.0 * delta * (na * other.m3 - nb * self.m3) / n_total;

        self.n += other.n;
        self.mean = merged_mean;
        self.m2 = merged_m2;
        self.m3 = merged_m3;
        self.m4 = merged_m4;
    }

    /// Population variance (M2 / n); zero for fewer than two observations.
    pub fn variance(&self) -> f64 {
        if self.n >= 2 {
            self.m2 / self.n as f64
        } else {
            0.0
        }
    }

    /// Sample variance (ddof = 1); zero for fewer than two observations.
    pub fn sample_variance(&self) -> f64 {
        if self.n >= 2 {
            self.m2 / (self.n - 1) as f64
        } else {
            0.0
        }
    }

    /// Population skewness (Fisher's definition).
    ///
    /// Returns 0.0 when undefined: fewer than 3 observations or
    /// essentially zero variance.
    pub fn skewness(&self) -> f64 {
        if self.n < 3 || self.m2.abs() < 1e-300 {
            return 0.0;
        }
        let count = self.n as f64;
        count.sqrt() * self.m3 / self.m2.powf(1.5)
    }

    /// Excess kurtosis (Fisher's definition).
    ///
    /// Returns 0.0 when undefined: fewer than 4 observations or
    /// essentially zero variance.
    pub fn kurtosis(&self) -> f64 {
        if self.n < 4 || self.m2.abs() < 1e-300 {
            return 0.0;
        }
        let count = self.n as f64;
        count * self.m4 / (self.m2 * self.m2) - 3.0
    }
}
159
160/// Compute mean, variance, skewness, and kurtosis in parallel via Welford's method.
161///
162/// Uses the parallel merge variant of Welford's algorithm that is numerically
163/// stable even for large datasets.
164///
165/// # Arguments
166///
167/// * `data` - Input data slice
168///
169/// # Returns
170///
171/// A `WelfordAccumulator` containing all four statistics.
172pub fn parallel_moments(data: &[f64]) -> WelfordAccumulator {
173    if data.len() < PAR_THRESHOLD {
174        // Sequential path
175        let mut acc = WelfordAccumulator::new();
176        for &x in data {
177            acc.push(x);
178        }
179        return acc;
180    }
181
182    let chunk_size = (data.len() / num_threads()).max(1000);
183    par_chunks(data, chunk_size)
184        .map(|chunk| {
185            let mut acc = WelfordAccumulator::new();
186            for &x in chunk {
187                acc.push(x);
188            }
189            acc
190        })
191        .reduce(WelfordAccumulator::new, |mut a, b| {
192            a.merge(&b);
193            a
194        })
195}
196
197/// Compute mean in parallel via Welford's online algorithm.
198pub fn parallel_welford_mean(data: &[f64]) -> StatsResult<f64> {
199    if data.is_empty() {
200        return Err(StatsError::InvalidArgument(
201            "Cannot compute mean of empty array".to_string(),
202        ));
203    }
204    Ok(parallel_moments(data).mean)
205}
206
207/// Compute variance in parallel via Welford's online algorithm.
208///
209/// # Arguments
210///
211/// * `data` - Input data
212/// * `ddof` - Delta degrees of freedom (0 = population, 1 = sample)
213pub fn parallel_welford_variance(data: &[f64], ddof: usize) -> StatsResult<f64> {
214    if data.len() <= ddof {
215        return Err(StatsError::InvalidArgument(
216            "Not enough data for given ddof".to_string(),
217        ));
218    }
219    let acc = parallel_moments(data);
220    if ddof == 0 {
221        Ok(acc.variance())
222    } else {
223        Ok(acc.m2 / (acc.n as f64 - ddof as f64))
224    }
225}
226
227/// Compute skewness in parallel.
228pub fn parallel_welford_skewness(data: &[f64]) -> StatsResult<f64> {
229    if data.len() < 3 {
230        return Err(StatsError::InvalidArgument(
231            "Need at least 3 observations for skewness".to_string(),
232        ));
233    }
234    Ok(parallel_moments(data).skewness())
235}
236
237/// Compute excess kurtosis in parallel.
238pub fn parallel_welford_kurtosis(data: &[f64]) -> StatsResult<f64> {
239    if data.len() < 4 {
240        return Err(StatsError::InvalidArgument(
241            "Need at least 4 observations for kurtosis".to_string(),
242        ));
243    }
244    Ok(parallel_moments(data).kurtosis())
245}
246
247// ---------------------------------------------------------------------------
248// Parallel quantile estimation (parallel-sort approach)
249// ---------------------------------------------------------------------------
250
251/// Compute a single quantile using parallel sort.
252///
253/// For large arrays this is faster than a sequential sort because the
254/// merge phase can be overlapped with computation.
255///
256/// # Arguments
257///
258/// * `data` - Input data (will be sorted internally)
259/// * `q` - Quantile in [0, 1]
260pub fn parallel_quantile(data: &[f64], q: f64) -> StatsResult<f64> {
261    if data.is_empty() {
262        return Err(StatsError::InvalidArgument(
263            "Cannot compute quantile of empty array".to_string(),
264        ));
265    }
266    if q < 0.0 || q > 1.0 {
267        return Err(StatsError::InvalidArgument(
268            "Quantile must be in [0, 1]".to_string(),
269        ));
270    }
271
272    let mut sorted = data.to_vec();
273    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
274
275    let n = sorted.len();
276    if n == 1 {
277        return Ok(sorted[0]);
278    }
279
280    let pos = q * (n - 1) as f64;
281    let lo = pos.floor() as usize;
282    let hi = pos.ceil() as usize;
283    let frac = pos - lo as f64;
284
285    if lo == hi || hi >= n {
286        Ok(sorted[lo.min(n - 1)])
287    } else {
288        Ok(sorted[lo] * (1.0 - frac) + sorted[hi] * frac)
289    }
290}
291
/// Compute the median in parallel.
///
/// Thin wrapper over [`parallel_quantile`] with `q = 0.5`; inherits its
/// error behavior (errors on empty input).
pub fn parallel_median(data: &[f64]) -> StatsResult<f64> {
    parallel_quantile(data, 0.5)
}
296
297// ---------------------------------------------------------------------------
298// Parallel histogram
299// ---------------------------------------------------------------------------
300
/// Result of a parallel histogram computation.
#[derive(Debug, Clone)]
pub struct ParallelHistogramResult {
    /// Bin edges (length = n_bins + 1), uniformly spaced from data min to max
    pub edges: Vec<f64>,
    /// Counts per bin (length = n_bins)
    pub counts: Vec<u64>,
    /// Total number of observations (equals the sum of `counts`)
    pub total: u64,
}
311
312/// Compute a histogram in parallel.
313///
314/// # Arguments
315///
316/// * `data` - Input data
317/// * `n_bins` - Number of bins
318///
319/// # Returns
320///
321/// `ParallelHistogramResult` with edges, counts, and total.
322pub fn parallel_histogram(data: &[f64], n_bins: usize) -> StatsResult<ParallelHistogramResult> {
323    if data.is_empty() {
324        return Err(StatsError::InvalidArgument(
325            "Cannot compute histogram of empty array".to_string(),
326        ));
327    }
328    if n_bins == 0 {
329        return Err(StatsError::InvalidArgument(
330            "Number of bins must be > 0".to_string(),
331        ));
332    }
333
334    // Find min/max
335    let (min_val, max_val) = data
336        .iter()
337        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &x| {
338            (lo.min(x), hi.max(x))
339        });
340
341    if !min_val.is_finite() || !max_val.is_finite() {
342        return Err(StatsError::InvalidArgument(
343            "Data contains non-finite values".to_string(),
344        ));
345    }
346
347    let range = max_val - min_val;
348    let bin_width = if range < 1e-300 {
349        1.0 // all same value
350    } else {
351        range / n_bins as f64
352    };
353
354    // Build edges
355    let edges: Vec<f64> = (0..=n_bins)
356        .map(|i| min_val + i as f64 * bin_width)
357        .collect();
358
359    if data.len() < PAR_THRESHOLD {
360        // Sequential
361        let mut counts = vec![0u64; n_bins];
362        for &x in data {
363            let bin = if bin_width < 1e-300 {
364                0
365            } else {
366                ((x - min_val) / bin_width).floor() as usize
367            };
368            let bin = bin.min(n_bins - 1);
369            counts[bin] += 1;
370        }
371        return Ok(ParallelHistogramResult {
372            edges,
373            counts,
374            total: data.len() as u64,
375        });
376    }
377
378    // Parallel: each thread builds a partial histogram, then merge
379    let chunk_size = (data.len() / num_threads()).max(1000);
380    let partial_counts: Vec<Vec<u64>> = par_chunks(data, chunk_size)
381        .map(|chunk| {
382            let mut counts = vec![0u64; n_bins];
383            for &x in chunk {
384                let bin = if bin_width < 1e-300 {
385                    0
386                } else {
387                    ((x - min_val) / bin_width).floor() as usize
388                };
389                let bin = bin.min(n_bins - 1);
390                counts[bin] += 1;
391            }
392            counts
393        })
394        .collect();
395
396    // Merge partial histograms
397    let mut counts = vec![0u64; n_bins];
398    for partial in &partial_counts {
399        for (i, &c) in partial.iter().enumerate() {
400            counts[i] += c;
401        }
402    }
403
404    Ok(ParallelHistogramResult {
405        edges,
406        counts,
407        total: data.len() as u64,
408    })
409}
410
411// ===========================================================================
412// Part 2: Parallel Hypothesis Testing
413// ===========================================================================
414
415// ---------------------------------------------------------------------------
416// Parallel permutation test
417// ---------------------------------------------------------------------------
418
/// Result of a permutation test.
#[derive(Debug, Clone)]
pub struct PermutationTestResult {
    /// Observed test statistic (for [`parallel_permutation_test`]: the
    /// absolute difference in group means)
    pub observed: f64,
    /// Two-sided p-value, computed with the add-one correction
    /// (n_extreme + 1) / (n_permutations + 1)
    pub p_value: f64,
    /// Number of permutations performed
    pub n_permutations: usize,
    /// Count of permutations with statistic >= |observed|
    pub n_extreme: usize,
}
431
432/// Parallel permutation test for the difference in means between two groups.
433///
434/// Randomly permutes group labels and recomputes the test statistic
435/// to build a null distribution.
436///
437/// # Arguments
438///
439/// * `group1` - First group's data
440/// * `group2` - Second group's data
441/// * `n_permutations` - Number of random permutations
442/// * `seed` - Optional random seed for reproducibility
443///
444/// # Returns
445///
446/// `PermutationTestResult` with observed statistic and p-value.
447pub fn parallel_permutation_test(
448    group1: &[f64],
449    group2: &[f64],
450    n_permutations: usize,
451    seed: Option<u64>,
452) -> StatsResult<PermutationTestResult> {
453    if group1.is_empty() || group2.is_empty() {
454        return Err(StatsError::InvalidArgument(
455            "Both groups must be non-empty".to_string(),
456        ));
457    }
458
459    let n1 = group1.len();
460    let combined: Vec<f64> = group1.iter().chain(group2.iter()).copied().collect();
461    let n_total = combined.len();
462
463    // Observed statistic: difference in means
464    let mean1: f64 = group1.iter().sum::<f64>() / n1 as f64;
465    let mean2: f64 = group2.iter().sum::<f64>() / group2.len() as f64;
466    let observed = (mean1 - mean2).abs();
467
468    // Generate seeds for each permutation
469    let base_seed = seed.unwrap_or(42);
470    let perm_seeds: Vec<u64> = (0..n_permutations)
471        .map(|i| {
472            // Simple hash to generate diverse seeds
473            base_seed
474                .wrapping_mul(6_364_136_223_846_793_005)
475                .wrapping_add(i as u64)
476        })
477        .collect();
478
479    // Parallel permutation computation
480    let extreme_counts: Vec<usize> = parallel_map(&perm_seeds, |&s| {
481        // Simple Fisher-Yates shuffle using LCG
482        let mut shuffled = combined.clone();
483        let mut state = s;
484        for i in (1..n_total).rev() {
485            state = state
486                .wrapping_mul(6_364_136_223_846_793_005)
487                .wrapping_add(1_442_695_040_888_963_407);
488            let j = (state >> 1) as usize % (i + 1);
489            shuffled.swap(i, j);
490        }
491
492        let perm_mean1: f64 = shuffled[..n1].iter().sum::<f64>() / n1 as f64;
493        let perm_mean2: f64 = shuffled[n1..].iter().sum::<f64>() / (n_total - n1) as f64;
494        let perm_stat = (perm_mean1 - perm_mean2).abs();
495
496        if perm_stat >= observed - 1e-12 {
497            1
498        } else {
499            0
500        }
501    });
502
503    let n_extreme: usize = extreme_counts.iter().sum();
504    let p_value = (n_extreme as f64 + 1.0) / (n_permutations as f64 + 1.0);
505
506    Ok(PermutationTestResult {
507        observed,
508        p_value,
509        n_permutations,
510        n_extreme,
511    })
512}
513
514// ---------------------------------------------------------------------------
515// Parallel bootstrap
516// ---------------------------------------------------------------------------
517
/// Result of a parallel bootstrap procedure.
#[derive(Debug, Clone)]
pub struct ParallelBootstrapResult {
    /// Point estimate computed from the original (un-resampled) data
    pub estimate: f64,
    /// Bootstrap standard error (sample std. dev. of the replicates)
    pub standard_error: f64,
    /// Lower CI bound (percentile method)
    pub ci_lower: f64,
    /// Upper CI bound (percentile method)
    pub ci_upper: f64,
    /// Confidence level the CI bounds correspond to (e.g. 0.95)
    pub confidence_level: f64,
    /// All bootstrap replicates (unsorted, in generation order)
    pub replicates: Vec<f64>,
}
534
535/// Run a parallel bootstrap procedure.
536///
537/// Distributes bootstrap resampling across threads, each with an independent
538/// pseudo-random number generator seed.
539///
540/// # Arguments
541///
542/// * `data` - Input data
543/// * `statistic` - Function computing the statistic from a sample slice
544/// * `n_bootstrap` - Number of bootstrap samples
545/// * `confidence_level` - Confidence level for CI (e.g. 0.95)
546/// * `seed` - Optional random seed
547pub fn parallel_bootstrap(
548    data: &[f64],
549    statistic: &(dyn Fn(&[f64]) -> f64 + Send + Sync),
550    n_bootstrap: usize,
551    confidence_level: f64,
552    seed: Option<u64>,
553) -> StatsResult<ParallelBootstrapResult> {
554    if data.is_empty() {
555        return Err(StatsError::InvalidArgument(
556            "Cannot bootstrap empty data".to_string(),
557        ));
558    }
559    if confidence_level <= 0.0 || confidence_level >= 1.0 {
560        return Err(StatsError::InvalidArgument(
561            "Confidence level must be in (0, 1)".to_string(),
562        ));
563    }
564
565    let estimate = statistic(data);
566    let n = data.len();
567    let base_seed = seed.unwrap_or(42);
568
569    // Generate seeds for each bootstrap replicate
570    let seeds: Vec<u64> = (0..n_bootstrap)
571        .map(|i| {
572            base_seed
573                .wrapping_mul(6_364_136_223_846_793_005)
574                .wrapping_add(i as u64)
575        })
576        .collect();
577
578    // Parallel bootstrap
579    let replicates: Vec<f64> = parallel_map(&seeds, |&s| {
580        let mut state = s;
581        let mut sample = Vec::with_capacity(n);
582        for _ in 0..n {
583            state = state
584                .wrapping_mul(6_364_136_223_846_793_005)
585                .wrapping_add(1_442_695_040_888_963_407);
586            let idx = (state >> 1) as usize % n;
587            sample.push(data[idx]);
588        }
589        statistic(&sample)
590    });
591
592    // Compute standard error
593    let boot_mean: f64 = replicates.iter().sum::<f64>() / replicates.len() as f64;
594    let boot_var: f64 = replicates
595        .iter()
596        .map(|&x| (x - boot_mean) * (x - boot_mean))
597        .sum::<f64>()
598        / (replicates.len() as f64 - 1.0).max(1.0);
599    let standard_error = boot_var.sqrt();
600
601    // Percentile CI
602    let alpha = 1.0 - confidence_level;
603    let mut sorted = replicates.clone();
604    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
605
606    let lo_idx = (alpha / 2.0 * sorted.len() as f64).floor() as usize;
607    let hi_idx = ((1.0 - alpha / 2.0) * sorted.len() as f64).ceil() as usize;
608    let ci_lower = sorted[lo_idx.min(sorted.len() - 1)];
609    let ci_upper = sorted[hi_idx.min(sorted.len() - 1)];
610
611    Ok(ParallelBootstrapResult {
612        estimate,
613        standard_error,
614        ci_lower,
615        ci_upper,
616        confidence_level,
617        replicates,
618    })
619}
620
621// ---------------------------------------------------------------------------
622// Parallel k-fold cross-validation
623// ---------------------------------------------------------------------------
624
/// Result of parallel cross-validation.
#[derive(Debug, Clone)]
pub struct CrossValidationResult {
    /// Mean score across folds
    pub mean_score: f64,
    /// Standard deviation of scores across folds (sample std, ddof = 1;
    /// 0 when there is only one fold)
    pub std_score: f64,
    /// Individual fold scores, in fold order
    pub fold_scores: Vec<f64>,
    /// Number of folds
    pub n_folds: usize,
}
637
638/// Run parallel k-fold cross-validation.
639///
640/// Splits data into k folds and evaluates a scoring function on each
641/// train/test split in parallel.
642///
643/// # Arguments
644///
645/// * `data` - Input features (rows = samples)
646/// * `targets` - Target values
647/// * `n_folds` - Number of folds
648/// * `scorer` - Function(train_X, train_y, test_X, test_y) -> score
649/// * `seed` - Optional seed for fold assignment shuffling
650pub fn parallel_cross_validation(
651    data: &Array2<f64>,
652    targets: &Array1<f64>,
653    n_folds: usize,
654    scorer: &(dyn Fn(&Array2<f64>, &Array1<f64>, &Array2<f64>, &Array1<f64>) -> StatsResult<f64>
655          + Send
656          + Sync),
657    seed: Option<u64>,
658) -> StatsResult<CrossValidationResult> {
659    let n_samples = data.nrows();
660    if n_samples < n_folds {
661        return Err(StatsError::InvalidArgument(format!(
662            "Need at least {} samples for {}-fold CV, got {}",
663            n_folds, n_folds, n_samples
664        )));
665    }
666    if data.nrows() != targets.len() {
667        return Err(StatsError::InvalidArgument(
668            "data rows and targets length must match".to_string(),
669        ));
670    }
671
672    // Create fold indices (simple sequential assignment with optional shuffle)
673    let mut indices: Vec<usize> = (0..n_samples).collect();
674    if let Some(s) = seed {
675        // Simple Fisher-Yates shuffle
676        let mut state = s;
677        for i in (1..n_samples).rev() {
678            state = state
679                .wrapping_mul(6_364_136_223_846_793_005)
680                .wrapping_add(1_442_695_040_888_963_407);
681            let j = (state >> 1) as usize % (i + 1);
682            indices.swap(i, j);
683        }
684    }
685
686    let fold_size = n_samples / n_folds;
687    let folds: Vec<usize> = (0..n_folds).collect();
688
689    // Parallel fold evaluation
690    let fold_scores: Vec<f64> = parallel_map(&folds, |&fold_idx| {
691        let test_start = fold_idx * fold_size;
692        let test_end = if fold_idx == n_folds - 1 {
693            n_samples
694        } else {
695            (fold_idx + 1) * fold_size
696        };
697        let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
698        let train_indices: Vec<usize> = indices
699            .iter()
700            .enumerate()
701            .filter(|(i, _)| *i < test_start || *i >= test_end)
702            .map(|(_, &idx)| idx)
703            .collect();
704
705        let n_train = train_indices.len();
706        let n_test = test_indices.len();
707        let n_features = data.ncols();
708
709        // Build train/test arrays
710        let mut train_x = Array2::zeros((n_train, n_features));
711        let mut train_y = Array1::zeros(n_train);
712        for (row, &idx) in train_indices.iter().enumerate() {
713            for col in 0..n_features {
714                train_x[(row, col)] = data[(idx, col)];
715            }
716            train_y[row] = targets[idx];
717        }
718
719        let mut test_x = Array2::zeros((n_test, n_features));
720        let mut test_y = Array1::zeros(n_test);
721        for (row, &idx) in test_indices.iter().enumerate() {
722            for col in 0..n_features {
723                test_x[(row, col)] = data[(idx, col)];
724            }
725            test_y[row] = targets[idx];
726        }
727
728        scorer(&train_x, &train_y, &test_x, &test_y)
729    })
730    .into_iter()
731    .collect::<StatsResult<Vec<_>>>()?;
732
733    let mean_score = fold_scores.iter().sum::<f64>() / fold_scores.len() as f64;
734    let std_score = if fold_scores.len() > 1 {
735        let var = fold_scores
736            .iter()
737            .map(|&s| (s - mean_score) * (s - mean_score))
738            .sum::<f64>()
739            / (fold_scores.len() as f64 - 1.0);
740        var.sqrt()
741    } else {
742        0.0
743    };
744
745    Ok(CrossValidationResult {
746        mean_score,
747        std_score,
748        fold_scores,
749        n_folds,
750    })
751}
752
753// ===========================================================================
754// Part 3: Parallel Distribution Fitting
755// ===========================================================================
756
757// ---------------------------------------------------------------------------
758// Parallel MLE
759// ---------------------------------------------------------------------------
760
/// Result of parallel maximum-likelihood estimation for one candidate distribution.
#[derive(Debug, Clone)]
pub struct ParallelMLEResult {
    /// Name of the fitted distribution (e.g. "normal", "exponential", "uniform")
    pub distribution: String,
    /// Estimated parameters (distribution-specific ordering)
    pub parameters: Vec<f64>,
    /// Log-likelihood of the fit
    pub log_likelihood: f64,
    /// AIC (Akaike Information Criterion); lower is better
    pub aic: f64,
    /// BIC (Bayesian Information Criterion); lower is better
    pub bic: f64,
}
775
776/// Fit multiple distributions in parallel and return the best by AIC.
777///
778/// Currently supported distributions: Normal, Exponential, Uniform.
779///
780/// # Arguments
781///
782/// * `data` - Observed data
783///
784/// # Returns
785///
786/// The `ParallelMLEResult` for the best-fitting distribution.
787pub fn parallel_mle_fit(data: &[f64]) -> StatsResult<Vec<ParallelMLEResult>> {
788    if data.is_empty() {
789        return Err(StatsError::InvalidArgument(
790            "Cannot fit distributions to empty data".to_string(),
791        ));
792    }
793
794    let n = data.len() as f64;
795    let dist_names: Vec<&str> = vec!["normal", "exponential", "uniform"];
796
797    let results: Vec<ParallelMLEResult> = parallel_map(&dist_names, |&name| {
798        match name {
799            "normal" => {
800                // MLE for Normal: mu = mean, sigma^2 = (1/n)*sum((x-mu)^2)
801                let mu = data.iter().sum::<f64>() / n;
802                let sigma2 = data.iter().map(|&x| (x - mu) * (x - mu)).sum::<f64>() / n;
803                let sigma = sigma2.max(1e-300).sqrt();
804
805                let ll: f64 = data
806                    .iter()
807                    .map(|&x| {
808                        let z = (x - mu) / sigma;
809                        -0.5 * z * z - sigma.ln() - 0.5 * (2.0 * std::f64::consts::PI).ln()
810                    })
811                    .sum();
812
813                let k = 2.0; // number of parameters
814                let aic = 2.0 * k - 2.0 * ll;
815                let bic = k * n.ln() - 2.0 * ll;
816
817                ParallelMLEResult {
818                    distribution: "normal".to_string(),
819                    parameters: vec![mu, sigma],
820                    log_likelihood: ll,
821                    aic,
822                    bic,
823                }
824            }
825            "exponential" => {
826                // MLE for Exponential: rate = 1 / mean (only positive data)
827                let min_val = data.iter().copied().fold(f64::INFINITY, f64::min);
828                let shifted: Vec<f64> = if min_val <= 0.0 {
829                    data.iter().map(|&x| x - min_val + 1e-10).collect()
830                } else {
831                    data.to_vec()
832                };
833                let mean_val = shifted.iter().sum::<f64>() / n;
834                let rate = (1.0 / mean_val).max(1e-300);
835
836                let ll: f64 = shifted.iter().map(|&x| rate.ln() - rate * x).sum();
837
838                let k = 1.0;
839                let aic = 2.0 * k - 2.0 * ll;
840                let bic = k * n.ln() - 2.0 * ll;
841
842                ParallelMLEResult {
843                    distribution: "exponential".to_string(),
844                    parameters: vec![rate],
845                    log_likelihood: ll,
846                    aic,
847                    bic,
848                }
849            }
850            "uniform" => {
851                // MLE for Uniform: a = min, b = max
852                let a = data.iter().copied().fold(f64::INFINITY, f64::min);
853                let b = data.iter().copied().fold(f64::NEG_INFINITY, f64::max);
854                let range = (b - a).max(1e-300);
855                let ll = -n * range.ln();
856
857                let k = 2.0;
858                let aic = 2.0 * k - 2.0 * ll;
859                let bic = k * n.ln() - 2.0 * ll;
860
861                ParallelMLEResult {
862                    distribution: "uniform".to_string(),
863                    parameters: vec![a, b],
864                    log_likelihood: ll,
865                    aic,
866                    bic,
867                }
868            }
869            _ => ParallelMLEResult {
870                distribution: name.to_string(),
871                parameters: vec![],
872                log_likelihood: f64::NEG_INFINITY,
873                aic: f64::INFINITY,
874                bic: f64::INFINITY,
875            },
876        }
877    });
878
879    // Sort by AIC (best first)
880    let mut sorted_results = results;
881    sorted_results.sort_by(|a, b| {
882        a.aic
883            .partial_cmp(&b.aic)
884            .unwrap_or(std::cmp::Ordering::Equal)
885    });
886
887    Ok(sorted_results)
888}
889
890// ---------------------------------------------------------------------------
891// Parallel grid search for distribution parameters
892// ---------------------------------------------------------------------------
893
/// Result of a parallel parameter grid search.
#[derive(Debug, Clone)]
pub struct GridSearchResult {
    /// Best parameters found (one value per supplied grid, in grid order)
    pub best_params: Vec<f64>,
    /// Log-likelihood at best parameters
    pub best_log_likelihood: f64,
    /// All parameter combinations evaluated (sorted by log-likelihood descending)
    pub all_results: Vec<(Vec<f64>, f64)>,
}
904
905/// Parallel grid search over distribution parameters to maximize log-likelihood.
906///
907/// # Arguments
908///
909/// * `data` - Observed data
910/// * `log_likelihood_fn` - Function(data, params) -> log-likelihood
911/// * `param_grids` - For each parameter, a vector of candidate values
912///
913/// # Returns
914///
915/// `GridSearchResult` with the best parameters.
916pub fn parallel_grid_search(
917    data: &[f64],
918    log_likelihood_fn: &(dyn Fn(&[f64], &[f64]) -> f64 + Send + Sync),
919    param_grids: &[Vec<f64>],
920) -> StatsResult<GridSearchResult> {
921    if data.is_empty() {
922        return Err(StatsError::InvalidArgument(
923            "Data cannot be empty".to_string(),
924        ));
925    }
926    if param_grids.is_empty() {
927        return Err(StatsError::InvalidArgument(
928            "Must specify at least one parameter grid".to_string(),
929        ));
930    }
931
932    // Build Cartesian product of all parameter grids
933    let mut combinations: Vec<Vec<f64>> = vec![vec![]];
934    for grid in param_grids {
935        let mut new_combos = Vec::new();
936        for combo in &combinations {
937            for &val in grid {
938                let mut extended = combo.clone();
939                extended.push(val);
940                new_combos.push(extended);
941            }
942        }
943        combinations = new_combos;
944    }
945
946    // Evaluate log-likelihood for each combination in parallel
947    let results: Vec<(Vec<f64>, f64)> = parallel_map(&combinations, |params| {
948        let ll = log_likelihood_fn(data, params);
949        (params.clone(), ll)
950    });
951
952    // Find best
953    let mut sorted = results;
954    sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
955
956    let (best_params, best_ll) = sorted
957        .first()
958        .map(|(p, ll)| (p.clone(), *ll))
959        .unwrap_or_else(|| (vec![], f64::NEG_INFINITY));
960
961    Ok(GridSearchResult {
962        best_params,
963        best_log_likelihood: best_ll,
964        all_results: sorted,
965    })
966}
967
968// ===========================================================================
969// Tests
970// ===========================================================================
971
#[cfg(test)]
mod tests {
    use super::*;

    /// Push every value of `values` into a fresh accumulator and return it.
    fn accumulate(values: &[f64]) -> WelfordAccumulator {
        let mut acc = WelfordAccumulator::new();
        values.iter().for_each(|&v| acc.push(v));
        acc
    }

    // -----------------------------------------------------------------------
    // Welford accumulator tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_welford_empty() {
        let fresh = WelfordAccumulator::new();
        assert_eq!(fresh.n, 0);
        assert_eq!(fresh.mean, 0.0);
    }

    #[test]
    fn test_welford_single_value() {
        let acc = accumulate(&[5.0]);
        assert!((acc.mean - 5.0).abs() < 1e-12);
        assert_eq!(acc.n, 1);
    }

    #[test]
    fn test_welford_mean() {
        let acc = accumulate(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!((acc.mean - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_welford_variance() {
        let acc = accumulate(&[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
        // Population variance = 4.0
        assert!((acc.variance() - 4.0).abs() < 1e-10);
        // Sample variance = 4.571428...
        assert!((acc.sample_variance() - 4.571_428_571_428_571).abs() < 1e-8);
    }

    #[test]
    fn test_welford_merge_equals_sequential() {
        let all_data: Vec<f64> = (0..1000).map(|i| (i as f64 * 0.37).sin() * 100.0).collect();

        // One sequential pass over the full dataset.
        let seq = accumulate(&all_data);

        // Merge four independently accumulated quarters.
        let mut merged = WelfordAccumulator::new();
        for chunk in all_data.chunks(all_data.len() / 4) {
            merged.merge(&accumulate(chunk));
        }

        assert!((seq.mean - merged.mean).abs() < 1e-8, "means differ");
        assert!(
            (seq.variance() - merged.variance()).abs() < 1e-6,
            "variances differ"
        );
        assert!(
            (seq.skewness() - merged.skewness()).abs() < 0.01,
            "skewness differs"
        );
        assert!(
            (seq.kurtosis() - merged.kurtosis()).abs() < 0.1,
            "kurtosis differs"
        );
    }

    #[test]
    fn test_parallel_moments_small() {
        let values: Vec<f64> = (1..=10).map(f64::from).collect();
        let acc = parallel_moments(&values);
        assert!((acc.mean - 5.5).abs() < 1e-10);
        assert_eq!(acc.n, 10);
    }

    #[test]
    fn test_parallel_welford_mean() {
        let values: Vec<f64> = (1..=100).map(f64::from).collect();
        let mean = parallel_welford_mean(&values).expect("mean failed");
        assert!((mean - 50.5).abs() < 1e-10);
    }

    #[test]
    fn test_parallel_welford_variance() {
        let values = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let var = parallel_welford_variance(&values, 0).expect("var failed");
        assert!((var - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_parallel_welford_mean_empty() {
        assert!(parallel_welford_mean(&[]).is_err());
    }

    // -----------------------------------------------------------------------
    // Parallel quantile tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parallel_quantile_median() {
        let values: Vec<f64> = (1..=99).map(f64::from).collect();
        let median = parallel_quantile(&values, 0.5).expect("median failed");
        assert!((median - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_parallel_quantile_extremes() {
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];
        let lo = parallel_quantile(&values, 0.0).expect("q0 failed");
        let hi = parallel_quantile(&values, 1.0).expect("q1 failed");
        assert!((lo - 1.0).abs() < 1e-10);
        assert!((hi - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_parallel_quantile_empty() {
        assert!(parallel_quantile(&[], 0.5).is_err());
    }

    #[test]
    fn test_parallel_quantile_invalid() {
        let values = [1.0, 2.0, 3.0];
        // Both below 0 and above 1 must be rejected.
        for bad_q in [-0.1, 1.1] {
            assert!(parallel_quantile(&values, bad_q).is_err());
        }
    }

    // -----------------------------------------------------------------------
    // Parallel histogram tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parallel_histogram_basic() {
        let values: Vec<f64> = (0..100).map(f64::from).collect();
        let hist = parallel_histogram(&values, 10).expect("hist failed");
        assert_eq!(hist.edges.len(), 11);
        assert_eq!(hist.counts.len(), 10);
        assert_eq!(hist.counts.iter().sum::<u64>(), 100);
    }

    #[test]
    fn test_parallel_histogram_single_value() {
        let values = [5.0; 50];
        let hist = parallel_histogram(&values, 5).expect("hist failed");
        assert_eq!(hist.counts.iter().sum::<u64>(), 50);
    }

    #[test]
    fn test_parallel_histogram_empty() {
        assert!(parallel_histogram(&[], 10).is_err());
    }

    #[test]
    fn test_parallel_histogram_zero_bins() {
        assert!(parallel_histogram(&[1.0, 2.0], 0).is_err());
    }

    // -----------------------------------------------------------------------
    // Parallel permutation test
    // -----------------------------------------------------------------------

    #[test]
    fn test_permutation_test_identical_groups() {
        let sample = [1.0, 2.0, 3.0, 4.0, 5.0];
        let result =
            parallel_permutation_test(&sample, &sample, 999, Some(42)).expect("perm test failed");
        // Identical groups => p-value should be large
        assert!(
            result.p_value > 0.1,
            "p-value {} should be > 0.1 for identical groups",
            result.p_value
        );
    }

    #[test]
    fn test_permutation_test_different_groups() {
        let low = [1.0, 2.0, 3.0, 4.0, 5.0];
        let high = [100.0, 200.0, 300.0, 400.0, 500.0];
        let result =
            parallel_permutation_test(&low, &high, 999, Some(42)).expect("perm test failed");
        // Very different groups => p-value should be small
        assert!(
            result.p_value < 0.05,
            "p-value {} should be < 0.05 for very different groups",
            result.p_value
        );
    }

    #[test]
    fn test_permutation_test_empty_group() {
        assert!(parallel_permutation_test(&[], &[1.0], 100, None).is_err());
    }

    // -----------------------------------------------------------------------
    // Parallel bootstrap tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parallel_bootstrap_mean() {
        let values: Vec<f64> = (1..=10).map(f64::from).collect();
        let mean_of = |s: &[f64]| s.iter().sum::<f64>() / s.len() as f64;
        let result =
            parallel_bootstrap(&values, &mean_of, 2000, 0.95, Some(42)).expect("bootstrap failed");

        // Point estimate should be the mean of the data
        assert!((result.estimate - 5.5).abs() < 1e-10);
        // CI should bracket the estimate
        assert!(result.ci_lower <= result.estimate && result.estimate <= result.ci_upper);
        assert_eq!(result.replicates.len(), 2000);
    }

    #[test]
    fn test_parallel_bootstrap_empty() {
        let mean_of = |s: &[f64]| s.iter().sum::<f64>() / s.len().max(1) as f64;
        assert!(parallel_bootstrap(&[], &mean_of, 100, 0.95, None).is_err());
    }

    // -----------------------------------------------------------------------
    // Parallel cross-validation tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parallel_cv_basic() {
        // Simple regression target: y = 2*x + 1
        let n = 100;
        let data = Array2::from_shape_fn((n, 1), |(i, _)| i as f64);
        let targets = Array1::from_shape_fn(n, |i| 2.0 * i as f64 + 1.0);

        // Baseline scorer: predict the training-set mean, score with R^2.
        let scorer = |_train_x: &Array2<f64>,
                      train_y: &Array1<f64>,
                      _test_x: &Array2<f64>,
                      test_y: &Array1<f64>|
         -> StatsResult<f64> {
            let baseline = train_y.iter().sum::<f64>() / train_y.len() as f64;
            let ss_res: f64 = test_y.iter().map(|&y| (y - baseline) * (y - baseline)).sum();
            let test_mean = test_y.iter().sum::<f64>() / test_y.len() as f64;
            let ss_tot: f64 = test_y.iter().map(|&y| (y - test_mean) * (y - test_mean)).sum();
            if ss_tot.abs() < 1e-12 {
                Ok(0.0)
            } else {
                Ok(1.0 - ss_res / ss_tot) // R^2
            }
        };

        let result =
            parallel_cross_validation(&data, &targets, 5, &scorer, Some(42)).expect("CV failed");
        assert_eq!(result.n_folds, 5);
        assert_eq!(result.fold_scores.len(), 5);
    }

    // -----------------------------------------------------------------------
    // Parallel MLE tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parallel_mle_normal_data() {
        // Deterministic pseudo-random data centered at 5.0
        let data: Vec<f64> = (0..500)
            .map(|i| (i as f64 * 0.13).sin() * 2.0 + 5.0)
            .collect();

        let results = parallel_mle_fit(&data).expect("MLE fit failed");
        assert!(!results.is_empty());
        // Results are sorted by AIC, so the first (best) must be finite
        assert!(results[0].aic.is_finite());
    }

    #[test]
    fn test_parallel_mle_empty() {
        assert!(parallel_mle_fit(&[]).is_err());
    }

    // -----------------------------------------------------------------------
    // Parallel grid search tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parallel_grid_search_normal() {
        let data = [4.5, 5.0, 5.5, 5.0, 4.8, 5.2, 5.1, 4.9];

        // Gaussian log-likelihood for params = [mu, sigma].
        let normal_ll = |data: &[f64], params: &[f64]| -> f64 {
            if params.len() < 2 || params[1] <= 0.0 {
                return f64::NEG_INFINITY;
            }
            let (mu, sigma) = (params[0], params[1]);
            data.iter()
                .map(|&x| {
                    let z = (x - mu) / sigma;
                    -0.5 * z * z - sigma.ln() - 0.5 * (2.0 * std::f64::consts::PI).ln()
                })
                .sum()
        };

        let mu_grid: Vec<f64> = (40..=60).map(|i| i as f64 * 0.1).collect();
        let sigma_grid: Vec<f64> = (1..=10).map(|i| i as f64 * 0.1).collect();

        let result = parallel_grid_search(&data, &normal_ll, &[mu_grid, sigma_grid])
            .expect("grid search failed");

        // Best mu should be near the data mean (~5.0)
        assert!(
            (result.best_params[0] - 5.0).abs() < 0.2,
            "best mu = {}",
            result.best_params[0]
        );
    }

    #[test]
    fn test_parallel_grid_search_empty() {
        let flat_ll = |_: &[f64], _: &[f64]| -> f64 { 0.0 };
        assert!(parallel_grid_search(&[], &flat_ll, &[vec![1.0]]).is_err());
    }

    // -----------------------------------------------------------------------
    // Skewness / kurtosis tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parallel_skewness_symmetric() {
        // Symmetric data => skewness should be near 0
        let values: Vec<f64> = (-50..=50).map(f64::from).collect();
        let skew = parallel_welford_skewness(&values).expect("skewness failed");
        assert!(
            skew.abs() < 0.01,
            "skewness of symmetric data should be ~0, got {}",
            skew
        );
    }

    #[test]
    fn test_parallel_kurtosis_uniform() {
        // Uniform data has excess kurtosis of -1.2
        let values: Vec<f64> = (0..10000).map(|x| x as f64 / 10000.0).collect();
        let kurt = parallel_welford_kurtosis(&values).expect("kurtosis failed");
        assert!(
            (kurt - (-1.2)).abs() < 0.1,
            "kurtosis of uniform data should be ~-1.2, got {}",
            kurt
        );
    }

    #[test]
    fn test_parallel_skewness_insufficient_data() {
        assert!(parallel_welford_skewness(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_parallel_kurtosis_insufficient_data() {
        assert!(parallel_welford_kurtosis(&[1.0, 2.0, 3.0]).is_err());
    }
}