sklears_utils/
statistical.rs

1//! Statistical Utilities
2//!
3//! This module provides comprehensive statistical analysis utilities for machine learning,
4//! including statistical tests, confidence intervals, correlation analysis, hypothesis testing,
5//! and distribution fitting utilities.
6
7use crate::{math_utils::SpecialFunctions, UtilsError};
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::numeric::Float;
10
/// Outcome of a statistical hypothesis test.
#[derive(Debug, Clone)]
pub struct TestResult {
    /// Value of the test statistic.
    pub statistic: f64,
    /// p-value associated with `statistic`.
    pub p_value: f64,
    /// Optional critical value; `new` leaves it `None`, `with_critical_value` sets it.
    pub critical_value: Option<f64>,
    /// Name of the test that produced this result.
    pub test_name: String,
    /// Whether `p_value` was strictly below the `alpha` given to `new`.
    pub significant: bool,
}

impl TestResult {
    /// Builds a result and flags it significant when `p_value < alpha`.
    ///
    /// A NaN `p_value` or `alpha` never counts as significant, because the comparison is
    /// false.
    pub fn new(statistic: f64, p_value: f64, test_name: String, alpha: f64) -> Self {
        let significant = p_value < alpha;
        TestResult {
            test_name,
            statistic,
            p_value,
            significant,
            critical_value: None,
        }
    }

    /// Returns this result with `critical_value` set to `Some(critical_value)`.
    pub fn with_critical_value(self, critical_value: f64) -> Self {
        TestResult {
            critical_value: Some(critical_value),
            ..self
        }
    }
}
37
/// An interval estimate `[lower, upper]` for a named parameter.
#[derive(Debug, Clone)]
pub struct ConfidenceInterval {
    /// Lower bound of the interval.
    pub lower: f64,
    /// Upper bound of the interval.
    pub upper: f64,
    /// Confidence level the interval was built for (e.g. `0.95`).
    pub confidence_level: f64,
    /// Name of the estimated parameter (e.g. `"Mean"`).
    pub parameter: String,
}

impl ConfidenceInterval {
    /// Creates an interval from its bounds, confidence level and parameter name.
    pub fn new(lower: f64, upper: f64, confidence_level: f64, parameter: String) -> Self {
        ConfidenceInterval {
            parameter,
            confidence_level,
            upper,
            lower,
        }
    }

    /// Length of the interval, `upper - lower`.
    pub fn width(&self) -> f64 {
        let (low, high) = (self.lower, self.upper);
        high - low
    }

    /// Whether `value` lies in the closed interval `[lower, upper]` (false for NaN).
    pub fn contains(&self, value: f64) -> bool {
        (self.lower..=self.upper).contains(&value)
    }
}
65
66/// Statistical tests implementation
67pub struct StatisticalTests;
68
69impl StatisticalTests {
70    /// One-sample t-test
71    pub fn one_sample_ttest(
72        data: &Array1<f64>,
73        population_mean: f64,
74        alpha: f64,
75    ) -> Result<TestResult, UtilsError> {
76        if data.is_empty() {
77            return Err(UtilsError::EmptyInput);
78        }
79
80        let n = data.len() as f64;
81        let sample_mean = data.mean().unwrap();
82        let sample_std = Self::standard_deviation(data);
83
84        if sample_std < f64::EPSILON {
85            return Err(UtilsError::InvalidParameter(
86                "Standard deviation is zero".to_string(),
87            ));
88        }
89
90        let t_statistic = (sample_mean - population_mean) / (sample_std / n.sqrt());
91        let degrees_of_freedom = n - 1.0;
92
93        // Approximate p-value using t-distribution (simplified)
94        let p_value = Self::t_distribution_cdf(-t_statistic.abs(), degrees_of_freedom) * 2.0;
95
96        Ok(TestResult::new(
97            t_statistic,
98            p_value,
99            "One-sample t-test".to_string(),
100            alpha,
101        ))
102    }
103
104    /// Two-sample t-test (assuming equal variances)
105    pub fn two_sample_ttest(
106        data1: &Array1<f64>,
107        data2: &Array1<f64>,
108        alpha: f64,
109    ) -> Result<TestResult, UtilsError> {
110        if data1.is_empty() || data2.is_empty() {
111            return Err(UtilsError::EmptyInput);
112        }
113
114        let n1 = data1.len() as f64;
115        let n2 = data2.len() as f64;
116        let mean1 = data1.mean().unwrap();
117        let mean2 = data2.mean().unwrap();
118        let std1 = Self::standard_deviation(data1);
119        let std2 = Self::standard_deviation(data2);
120
121        // Pooled standard deviation
122        let pooled_std =
123            ((std1.powi(2) * (n1 - 1.0) + std2.powi(2) * (n2 - 1.0)) / (n1 + n2 - 2.0)).sqrt();
124
125        if pooled_std < f64::EPSILON {
126            return Err(UtilsError::InvalidParameter(
127                "Pooled standard deviation is zero".to_string(),
128            ));
129        }
130
131        let t_statistic = (mean1 - mean2) / (pooled_std * (1.0 / n1 + 1.0 / n2).sqrt());
132        let degrees_of_freedom = n1 + n2 - 2.0;
133
134        let p_value = Self::t_distribution_cdf(-t_statistic.abs(), degrees_of_freedom) * 2.0;
135
136        Ok(TestResult::new(
137            t_statistic,
138            p_value,
139            "Two-sample t-test".to_string(),
140            alpha,
141        ))
142    }
143
144    /// Welch's t-test (unequal variances)
145    pub fn welch_ttest(
146        data1: &Array1<f64>,
147        data2: &Array1<f64>,
148        alpha: f64,
149    ) -> Result<TestResult, UtilsError> {
150        if data1.is_empty() || data2.is_empty() {
151            return Err(UtilsError::EmptyInput);
152        }
153
154        let n1 = data1.len() as f64;
155        let n2 = data2.len() as f64;
156        let mean1 = data1.mean().unwrap();
157        let mean2 = data2.mean().unwrap();
158        let var1 = Self::variance(data1);
159        let var2 = Self::variance(data2);
160
161        let se = (var1 / n1 + var2 / n2).sqrt();
162        if se < f64::EPSILON {
163            return Err(UtilsError::InvalidParameter(
164                "Standard error is zero".to_string(),
165            ));
166        }
167
168        let t_statistic = (mean1 - mean2) / se;
169
170        // Welch-Satterthwaite degrees of freedom
171        let degrees_of_freedom = (var1 / n1 + var2 / n2).powi(2)
172            / ((var1 / n1).powi(2) / (n1 - 1.0) + (var2 / n2).powi(2) / (n2 - 1.0));
173
174        let p_value = Self::t_distribution_cdf(-t_statistic.abs(), degrees_of_freedom) * 2.0;
175
176        Ok(TestResult::new(
177            t_statistic,
178            p_value,
179            "Welch's t-test".to_string(),
180            alpha,
181        ))
182    }
183
184    /// Chi-square goodness of fit test
185    pub fn chi_square_goodness_of_fit(
186        observed: &Array1<f64>,
187        expected: &Array1<f64>,
188        alpha: f64,
189    ) -> Result<TestResult, UtilsError> {
190        if observed.len() != expected.len() {
191            return Err(UtilsError::ShapeMismatch {
192                expected: vec![expected.len()],
193                actual: vec![observed.len()],
194            });
195        }
196
197        if observed.is_empty() {
198            return Err(UtilsError::EmptyInput);
199        }
200
201        let mut chi_square = 0.0;
202        for (obs, exp) in observed.iter().zip(expected.iter()) {
203            if *exp <= 0.0 {
204                return Err(UtilsError::InvalidParameter(
205                    "Expected frequencies must be positive".to_string(),
206                ));
207            }
208            chi_square += (obs - exp).powi(2) / exp;
209        }
210
211        let degrees_of_freedom = (observed.len() - 1) as f64;
212        let p_value = 1.0 - Self::chi_square_cdf(chi_square, degrees_of_freedom);
213
214        Ok(TestResult::new(
215            chi_square,
216            p_value,
217            "Chi-square goodness of fit".to_string(),
218            alpha,
219        ))
220    }
221
222    /// Kolmogorov-Smirnov test for normality
223    pub fn ks_test_normality(data: &Array1<f64>, alpha: f64) -> Result<TestResult, UtilsError> {
224        if data.is_empty() {
225            return Err(UtilsError::EmptyInput);
226        }
227
228        let n = data.len() as f64;
229        let mean = data.mean().unwrap();
230        let std = Self::standard_deviation(data);
231
232        let mut sorted_data = data.to_vec();
233        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
234
235        let mut d_plus = 0.0;
236        let mut d_minus = 0.0;
237
238        for (i, &value) in sorted_data.iter().enumerate() {
239            let empirical_cdf = (i + 1) as f64 / n;
240            let theoretical_cdf = Self::normal_cdf((value - mean) / std);
241
242            d_plus = d_plus.max(empirical_cdf - theoretical_cdf);
243            d_minus = d_minus.max(theoretical_cdf - empirical_cdf);
244        }
245
246        let ks_statistic = d_plus.max(d_minus);
247
248        // Approximate p-value using Kolmogorov distribution
249        let p_value = Self::kolmogorov_smirnov_p_value(ks_statistic, n);
250
251        Ok(TestResult::new(
252            ks_statistic,
253            p_value,
254            "Kolmogorov-Smirnov normality test".to_string(),
255            alpha,
256        ))
257    }
258
259    /// Anderson-Darling test for normality
260    pub fn anderson_darling_test(data: &Array1<f64>, alpha: f64) -> Result<TestResult, UtilsError> {
261        if data.is_empty() {
262            return Err(UtilsError::EmptyInput);
263        }
264
265        let n = data.len() as f64;
266        let mean = data.mean().unwrap();
267        let std = Self::standard_deviation(data);
268
269        let mut sorted_data = data.to_vec();
270        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
271
272        let mut ad_statistic = 0.0;
273
274        for (i, &value) in sorted_data.iter().enumerate() {
275            let z = (value - mean) / std;
276            let phi = Self::normal_cdf(z);
277            let phi_complement = 1.0 - phi;
278
279            if phi > 0.0 && phi < 1.0 && phi_complement > 0.0 {
280                let j = i + 1;
281                ad_statistic +=
282                    ((2 * j - 1) as f64) * (phi.ln() + sorted_data[n as usize - j].ln());
283            }
284        }
285
286        ad_statistic = -n - ad_statistic / n;
287
288        // Adjust for finite sample size
289        ad_statistic *= 1.0 + 0.75 / n + 2.25 / n.powi(2);
290
291        // Approximate p-value
292        let p_value = if ad_statistic >= 0.6 {
293            (-1.2337141 / ad_statistic).exp()
294                * (2.00012
295                    + (ad_statistic
296                        * (-3.00021
297                            + ad_statistic
298                                * (12.24425
299                                    + ad_statistic
300                                        * (-17.2385
301                                            + ad_statistic * (12.79 - ad_statistic * 5.27)))))
302                        .exp())
303        } else if ad_statistic >= 0.34 {
304            (-0.9177 - 2.0637 * ad_statistic).exp()
305        } else if ad_statistic >= 0.2 {
306            1.0 - (-8.318 + 42.796 * ad_statistic - 59.938 * ad_statistic.powi(2)).exp()
307        } else {
308            1.0 - (-13.436 + 101.14 * ad_statistic - 223.73 * ad_statistic.powi(2)).exp()
309        };
310
311        Ok(TestResult::new(
312            ad_statistic,
313            p_value,
314            "Anderson-Darling normality test".to_string(),
315            alpha,
316        ))
317    }
318
319    // Helper functions for statistical distributions
320
321    fn standard_deviation(data: &Array1<f64>) -> f64 {
322        Self::variance(data).sqrt()
323    }
324
325    fn variance(data: &Array1<f64>) -> f64 {
326        if data.len() <= 1 {
327            return 0.0;
328        }
329
330        let mean = data.mean().unwrap();
331        let sum_squares = data.iter().map(|x| (x - mean).powi(2)).sum::<f64>();
332        sum_squares / (data.len() - 1) as f64
333    }
334
335    /// Approximate normal CDF using error function
336    fn normal_cdf(x: f64) -> f64 {
337        0.5 * (1.0 + SpecialFunctions::erf(x / 2.0_f64.sqrt()))
338    }
339
340    /// Approximate t-distribution CDF (simplified)
341    fn t_distribution_cdf(t: f64, df: f64) -> f64 {
342        if df >= 30.0 {
343            // For large df, t-distribution approaches normal
344            return Self::normal_cdf(t);
345        }
346
347        // Simplified approximation for t-distribution
348        let x = t / (t.powi(2) + df).sqrt();
349        0.5 + 0.5 * x * SpecialFunctions::gamma((df + 1.0) / 2.0)
350            / ((df * std::f64::consts::PI).sqrt() * SpecialFunctions::gamma(df / 2.0))
351    }
352
353    /// Approximate chi-square CDF
354    fn chi_square_cdf(x: f64, df: f64) -> f64 {
355        if x <= 0.0 {
356            return 0.0;
357        }
358
359        // Use incomplete gamma function
360        SpecialFunctions::gamma_inc(df / 2.0, x / 2.0) / SpecialFunctions::gamma(df / 2.0)
361    }
362
363    /// Approximate Kolmogorov-Smirnov p-value
364    fn kolmogorov_smirnov_p_value(d: f64, n: f64) -> f64 {
365        let lambda = d * n.sqrt();
366        let mut p_value = 0.0;
367
368        for i in 1..=10 {
369            let term = (-2.0 * (i as f64).powi(2) * lambda.powi(2)).exp();
370            if i % 2 == 1 {
371                p_value += term;
372            } else {
373                p_value -= term;
374            }
375        }
376
377        2.0 * p_value
378    }
379}
380
381/// Confidence interval computation utilities
382pub struct ConfidenceIntervals;
383
384impl ConfidenceIntervals {
385    /// Confidence interval for mean (t-distribution)
386    pub fn mean_ci(
387        data: &Array1<f64>,
388        confidence_level: f64,
389    ) -> Result<ConfidenceInterval, UtilsError> {
390        if data.is_empty() {
391            return Err(UtilsError::EmptyInput);
392        }
393
394        if !(0.0..1.0).contains(&confidence_level) {
395            return Err(UtilsError::InvalidParameter(
396                "Confidence level must be between 0 and 1".to_string(),
397            ));
398        }
399
400        let n = data.len() as f64;
401        let mean = data.mean().unwrap();
402        let std = StatisticalTests::standard_deviation(data);
403        let se = std / n.sqrt();
404
405        let alpha = 1.0 - confidence_level;
406        let df = n - 1.0;
407
408        // Approximate t-critical value (simplified)
409        let t_critical = Self::t_critical_value(alpha / 2.0, df);
410        let margin_of_error = t_critical * se;
411
412        Ok(ConfidenceInterval::new(
413            mean - margin_of_error,
414            mean + margin_of_error,
415            confidence_level,
416            "Mean".to_string(),
417        ))
418    }
419
420    /// Confidence interval for proportion
421    pub fn proportion_ci(
422        successes: usize,
423        trials: usize,
424        confidence_level: f64,
425    ) -> Result<ConfidenceInterval, UtilsError> {
426        if trials == 0 {
427            return Err(UtilsError::InvalidParameter(
428                "Number of trials must be positive".to_string(),
429            ));
430        }
431
432        if successes > trials {
433            return Err(UtilsError::InvalidParameter(
434                "Successes cannot exceed trials".to_string(),
435            ));
436        }
437
438        let p = successes as f64 / trials as f64;
439        let n = trials as f64;
440        let alpha = 1.0 - confidence_level;
441
442        // Use normal approximation for large samples
443        if n * p >= 5.0 && n * (1.0 - p) >= 5.0 {
444            let z_critical = Self::normal_critical_value(alpha / 2.0);
445            let se = (p * (1.0 - p) / n).sqrt();
446            let margin_of_error = z_critical * se;
447
448            Ok(ConfidenceInterval::new(
449                (p - margin_of_error).max(0.0),
450                (p + margin_of_error).min(1.0),
451                confidence_level,
452                "Proportion".to_string(),
453            ))
454        } else {
455            // Use Wilson score interval for small samples
456            let z = Self::normal_critical_value(alpha / 2.0);
457            let z2 = z * z;
458            let center = (p + z2 / (2.0 * n)) / (1.0 + z2 / n);
459            let width = z / (1.0 + z2 / n) * (p * (1.0 - p) / n + z2 / (4.0 * n * n)).sqrt();
460
461            Ok(ConfidenceInterval::new(
462                (center - width).max(0.0),
463                (center + width).min(1.0),
464                confidence_level,
465                "Proportion (Wilson)".to_string(),
466            ))
467        }
468    }
469
470    /// Confidence interval for variance
471    pub fn variance_ci(
472        data: &Array1<f64>,
473        confidence_level: f64,
474    ) -> Result<ConfidenceInterval, UtilsError> {
475        if data.len() <= 1 {
476            return Err(UtilsError::InsufficientData {
477                min: 2,
478                actual: data.len(),
479            });
480        }
481
482        let n = data.len() as f64;
483        let variance = StatisticalTests::variance(data);
484        let df = n - 1.0;
485        let alpha = 1.0 - confidence_level;
486
487        // Chi-square critical values (approximated)
488        let chi2_lower = Self::chi_square_critical_value(1.0 - alpha / 2.0, df);
489        let chi2_upper = Self::chi_square_critical_value(alpha / 2.0, df);
490
491        let lower = df * variance / chi2_upper;
492        let upper = df * variance / chi2_lower;
493
494        Ok(ConfidenceInterval::new(
495            lower,
496            upper,
497            confidence_level,
498            "Variance".to_string(),
499        ))
500    }
501
502    // Helper functions for critical values
503
504    fn normal_critical_value(alpha: f64) -> f64 {
505        // Approximation for normal critical values
506        if alpha <= 0.001 {
507            3.291
508        } else if alpha <= 0.005 {
509            2.807
510        } else if alpha <= 0.01 {
511            2.576
512        } else if alpha <= 0.025 {
513            1.960
514        } else if alpha <= 0.05 {
515            1.645
516        } else if alpha <= 0.1 {
517            1.282
518        } else {
519            0.674
520        }
521    }
522
523    fn t_critical_value(alpha: f64, df: f64) -> f64 {
524        if df >= 30.0 {
525            return Self::normal_critical_value(alpha);
526        }
527
528        // Simplified t-critical value approximation
529        let normal_val = Self::normal_critical_value(alpha);
530        let correction = (1.0 + (normal_val.powi(2) + 1.0) / (4.0 * df))
531            * (1.0
532                + (5.0 * normal_val.powi(4) + 16.0 * normal_val.powi(2) + 3.0)
533                    / (96.0 * df.powi(2)));
534        normal_val * correction
535    }
536
537    fn chi_square_critical_value(alpha: f64, df: f64) -> f64 {
538        // Simplified chi-square critical value approximation
539        let h = 2.0 / (9.0 * df);
540        let normal_val = Self::normal_critical_value(alpha);
541        df * (1.0 - h + normal_val * (h * 2.0).sqrt()).powi(3)
542    }
543}
544
545/// Correlation analysis utilities
546pub struct CorrelationAnalysis;
547
548impl CorrelationAnalysis {
549    /// Pearson correlation coefficient
550    pub fn pearson_correlation(x: &Array1<f64>, y: &Array1<f64>) -> Result<f64, UtilsError> {
551        if x.len() != y.len() {
552            return Err(UtilsError::ShapeMismatch {
553                expected: vec![x.len()],
554                actual: vec![y.len()],
555            });
556        }
557
558        if x.is_empty() {
559            return Err(UtilsError::EmptyInput);
560        }
561
562        let _n = x.len() as f64;
563        let mean_x = x.mean().unwrap();
564        let mean_y = y.mean().unwrap();
565
566        let mut numerator = 0.0;
567        let mut sum_sq_x = 0.0;
568        let mut sum_sq_y = 0.0;
569
570        for (xi, yi) in x.iter().zip(y.iter()) {
571            let dx = xi - mean_x;
572            let dy = yi - mean_y;
573            numerator += dx * dy;
574            sum_sq_x += dx * dx;
575            sum_sq_y += dy * dy;
576        }
577
578        let denominator = (sum_sq_x * sum_sq_y).sqrt();
579        if denominator < f64::EPSILON {
580            return Ok(0.0);
581        }
582
583        Ok(numerator / denominator)
584    }
585
586    /// Spearman rank correlation coefficient
587    pub fn spearman_correlation(x: &Array1<f64>, y: &Array1<f64>) -> Result<f64, UtilsError> {
588        if x.len() != y.len() {
589            return Err(UtilsError::ShapeMismatch {
590                expected: vec![x.len()],
591                actual: vec![y.len()],
592            });
593        }
594
595        let ranks_x = Self::compute_ranks(x);
596        let ranks_y = Self::compute_ranks(y);
597
598        Self::pearson_correlation(&ranks_x, &ranks_y)
599    }
600
601    /// Kendall's tau correlation coefficient
602    pub fn kendall_tau(x: &Array1<f64>, y: &Array1<f64>) -> Result<f64, UtilsError> {
603        if x.len() != y.len() {
604            return Err(UtilsError::ShapeMismatch {
605                expected: vec![x.len()],
606                actual: vec![y.len()],
607            });
608        }
609
610        if x.is_empty() {
611            return Err(UtilsError::EmptyInput);
612        }
613
614        let n = x.len();
615        let mut concordant = 0;
616        let mut discordant = 0;
617
618        for i in 0..n {
619            for j in (i + 1)..n {
620                let sign_x = (x[j] - x[i]).signum();
621                let sign_y = (y[j] - y[i]).signum();
622
623                if sign_x * sign_y > 0.0 {
624                    concordant += 1;
625                } else if sign_x * sign_y < 0.0 {
626                    discordant += 1;
627                }
628            }
629        }
630
631        let total_pairs = n * (n - 1) / 2;
632        Ok((concordant - discordant) as f64 / total_pairs as f64)
633    }
634
635    /// Correlation matrix for multiple variables
636    pub fn correlation_matrix(data: &Array2<f64>) -> Result<Array2<f64>, UtilsError> {
637        let (n_samples, n_features) = data.dim();
638        if n_samples == 0 || n_features == 0 {
639            return Err(UtilsError::EmptyInput);
640        }
641
642        let mut corr_matrix = Array2::zeros((n_features, n_features));
643
644        for i in 0..n_features {
645            for j in 0..n_features {
646                if i == j {
647                    corr_matrix[(i, j)] = 1.0;
648                } else {
649                    let col_i = data.column(i).to_owned();
650                    let col_j = data.column(j).to_owned();
651                    corr_matrix[(i, j)] = Self::pearson_correlation(&col_i, &col_j)?;
652                }
653            }
654        }
655
656        Ok(corr_matrix)
657    }
658
659    /// Test correlation significance
660    pub fn correlation_test(
661        correlation: f64,
662        n: usize,
663        alpha: f64,
664    ) -> Result<TestResult, UtilsError> {
665        if n < 3 {
666            return Err(UtilsError::InsufficientData { min: 3, actual: n });
667        }
668
669        let df = (n - 2) as f64;
670        let t_statistic = correlation * (df / (1.0 - correlation.powi(2))).sqrt();
671
672        let p_value = 2.0 * StatisticalTests::t_distribution_cdf(-t_statistic.abs(), df);
673
674        Ok(TestResult::new(
675            t_statistic,
676            p_value,
677            "Correlation significance test".to_string(),
678            alpha,
679        ))
680    }
681
682    fn compute_ranks(data: &Array1<f64>) -> Array1<f64> {
683        let mut indexed_data: Vec<(usize, f64)> =
684            data.iter().enumerate().map(|(i, &x)| (i, x)).collect();
685        indexed_data.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
686
687        let mut ranks = Array1::zeros(data.len());
688
689        for (rank, (original_index, _)) in indexed_data.iter().enumerate() {
690            ranks[*original_index] = (rank + 1) as f64;
691        }
692
693        // Handle ties by averaging ranks
694        let mut i = 0;
695        while i < indexed_data.len() {
696            let current_value = indexed_data[i].1;
697            let mut j = i + 1;
698
699            while j < indexed_data.len() && (indexed_data[j].1 - current_value).abs() < f64::EPSILON
700            {
701                j += 1;
702            }
703
704            if j > i + 1 {
705                // There are ties
706                let average_rank = ((i + 1) + j) as f64 / 2.0;
707                for k in i..j {
708                    ranks[indexed_data[k].0] = average_rank;
709                }
710            }
711
712            i = j;
713        }
714
715        ranks
716    }
717}
718
719/// Distribution fitting utilities
720pub struct DistributionFitting;
721
722impl DistributionFitting {
723    /// Fit normal distribution parameters
724    pub fn fit_normal(data: &Array1<f64>) -> Result<(f64, f64), UtilsError> {
725        if data.is_empty() {
726            return Err(UtilsError::EmptyInput);
727        }
728
729        let mean = data.mean().unwrap();
730        let std = StatisticalTests::standard_deviation(data);
731
732        Ok((mean, std))
733    }
734
735    /// Fit exponential distribution parameter
736    pub fn fit_exponential(data: &Array1<f64>) -> Result<f64, UtilsError> {
737        if data.is_empty() {
738            return Err(UtilsError::EmptyInput);
739        }
740
741        // Check for positive values
742        if data.iter().any(|&x| x <= 0.0) {
743            return Err(UtilsError::InvalidParameter(
744                "Exponential distribution requires positive values".to_string(),
745            ));
746        }
747
748        let mean = data.mean().unwrap();
749        Ok(1.0 / mean) // Lambda parameter
750    }
751
752    /// Fit uniform distribution parameters
753    pub fn fit_uniform(data: &Array1<f64>) -> Result<(f64, f64), UtilsError> {
754        if data.is_empty() {
755            return Err(UtilsError::EmptyInput);
756        }
757
758        let min = data.iter().fold(f64::INFINITY, |a, &b| a.min(b));
759        let max = data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
760
761        Ok((min, max))
762    }
763
764    /// Goodness of fit test using chi-square
765    pub fn goodness_of_fit_test(
766        data: &Array1<f64>,
767        expected_cdf: fn(f64, &[f64]) -> f64,
768        parameters: &[f64],
769        bins: usize,
770        alpha: f64,
771    ) -> Result<TestResult, UtilsError> {
772        if data.is_empty() {
773            return Err(UtilsError::EmptyInput);
774        }
775
776        if bins < 2 {
777            return Err(UtilsError::InvalidParameter(
778                "Number of bins must be at least 2".to_string(),
779            ));
780        }
781
782        let n = data.len() as f64;
783        let min_val = data.iter().fold(f64::INFINITY, |a, &b| a.min(b));
784        let max_val = data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
785
786        let bin_width = (max_val - min_val) / bins as f64;
787        let mut observed = Array1::zeros(bins);
788        let mut expected = Array1::zeros(bins);
789
790        // Count observed frequencies
791        for &value in data.iter() {
792            let bin_index = ((value - min_val) / bin_width).floor() as usize;
793            let bin_index = bin_index.min(bins - 1);
794            observed[bin_index] += 1.0;
795        }
796
797        // Calculate expected frequencies
798        for i in 0..bins {
799            let lower = min_val + i as f64 * bin_width;
800            let upper = min_val + (i + 1) as f64 * bin_width;
801            let prob = expected_cdf(upper, parameters) - expected_cdf(lower, parameters);
802            expected[i] = n * prob;
803        }
804
805        StatisticalTests::chi_square_goodness_of_fit(&observed, &expected, alpha)
806    }
807}
808
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Every p-value reported by a test must lie in `[0, 1]`.
    fn assert_probability(p_value: f64) {
        assert!(
            (0.0..=1.0).contains(&p_value),
            "p-value {p_value} is outside [0, 1]"
        );
    }

    #[test]
    fn test_one_sample_ttest() {
        let sample = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let outcome = StatisticalTests::one_sample_ttest(&sample, 3.0, 0.05).unwrap();

        assert_eq!(outcome.test_name, "One-sample t-test");
        assert!(!outcome.statistic.is_nan());
        assert_probability(outcome.p_value);
    }

    #[test]
    fn test_two_sample_ttest() {
        let left = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let right = array![2.0, 3.0, 4.0, 5.0, 6.0];
        let outcome = StatisticalTests::two_sample_ttest(&left, &right, 0.05).unwrap();

        assert_eq!(outcome.test_name, "Two-sample t-test");
        assert!(!outcome.statistic.is_nan());
        assert_probability(outcome.p_value);
    }

    #[test]
    fn test_pearson_correlation() {
        let xs = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = xs.mapv(|v| 2.0 * v);

        let r = CorrelationAnalysis::pearson_correlation(&xs, &ys).unwrap();
        assert_abs_diff_eq!(r, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_spearman_correlation() {
        let xs = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = xs.mapv(|v| v * v); // monotone transform: y = x²

        let rho = CorrelationAnalysis::spearman_correlation(&xs, &ys).unwrap();
        assert_abs_diff_eq!(rho, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_confidence_interval_mean() {
        let sample = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let interval = ConfidenceIntervals::mean_ci(&sample, 0.95).unwrap();

        assert_eq!(interval.parameter, "Mean");
        assert_eq!(interval.confidence_level, 0.95);
        assert!(interval.lower < interval.upper);
        assert!(interval.contains(3.0)); // the sample mean
    }

    #[test]
    fn test_confidence_interval_proportion() {
        let interval = ConfidenceIntervals::proportion_ci(30, 100, 0.95).unwrap();

        assert_eq!(interval.parameter, "Proportion");
        assert!(interval.lower >= 0.0 && interval.upper <= 1.0);
        assert!(interval.contains(0.3)); // the sample proportion
    }

    #[test]
    fn test_chi_square_goodness_of_fit() {
        let observed = array![10.0, 15.0, 8.0, 12.0];
        let expected = array![11.25, 11.25, 11.25, 11.25];
        let outcome =
            StatisticalTests::chi_square_goodness_of_fit(&observed, &expected, 0.05).unwrap();

        assert_eq!(outcome.test_name, "Chi-square goodness of fit");
        assert!(outcome.statistic >= 0.0);
        assert_probability(outcome.p_value);
    }

    #[test]
    fn test_correlation_matrix() {
        let table = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let corr = CorrelationAnalysis::correlation_matrix(&table).unwrap();

        assert_eq!(corr.shape(), &[3, 3]);

        for row in 0..3 {
            // Unit diagonal.
            assert_abs_diff_eq!(corr[(row, row)], 1.0, epsilon = 1e-10);
            // Symmetry.
            for col in 0..3 {
                assert_abs_diff_eq!(corr[(row, col)], corr[(col, row)], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_distribution_fitting() {
        let sample = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let (mu, sigma) = DistributionFitting::fit_normal(&sample).unwrap();
        assert_abs_diff_eq!(mu, 3.0, epsilon = 1e-10);
        assert!(sigma > 0.0);

        let (low, high) = DistributionFitting::fit_uniform(&sample).unwrap();
        assert_eq!(low, 1.0);
        assert_eq!(high, 5.0);
    }

    #[test]
    fn test_kendall_tau() {
        let xs = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = xs.clone();

        let tau = CorrelationAnalysis::kendall_tau(&xs, &ys).unwrap();
        assert_abs_diff_eq!(tau, 1.0, epsilon = 1e-10);
    }
}