// datasynth_eval/statistical/chi_squared.rs
1//! Chi-squared goodness-of-fit test for distribution validation.
2//!
3//! Tests whether observed frequency distribution matches expected frequencies.
4//! Useful for categorical data and binned continuous data.
5
6use crate::error::{EvalError, EvalResult};
7use serde::{Deserialize, Serialize};
8
/// Binning strategy for continuous data.
///
/// Only consulted by `ChiSquaredAnalyzer::analyze_continuous`; categorical
/// inputs are never re-binned.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum BinningStrategy {
    /// Fixed number of equal-width bins spanning [min, max] of the data
    EqualWidth { num_bins: usize },
    /// Equal-frequency (quantile) bins, with edges taken from order statistics
    EqualFrequency { num_bins: usize },
    /// Custom bin edges (caller should supply ascending edges — not validated)
    Custom { edges: Vec<f64> },
    /// Automatic binning using Sturges' rule: ceil(1 + log2(n)) bins
    #[default]
    Auto,
}
22
/// Per-bin frequency information produced by a chi-squared test.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinFrequency {
    /// Bin index (0-based, in edge order)
    pub index: usize,
    /// Bin lower edge (inclusive)
    pub lower: f64,
    /// Bin upper edge (exclusive, except for the last bin, which is inclusive)
    pub upper: f64,
    /// Observed count in this bin
    pub observed: usize,
    /// Expected count under the null distribution
    pub expected: f64,
    /// This bin's contribution to the χ² statistic: (observed − expected)² / expected
    pub contribution: f64,
}
39
/// Chi-squared goodness-of-fit test results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChiSquaredAnalysis {
    /// Number of observations that entered the test
    pub sample_size: usize,
    /// Number of bins/categories compared
    pub num_bins: usize,
    /// Degrees of freedom (num_bins − 1; no parameters estimated from data)
    pub degrees_of_freedom: usize,
    /// Chi-squared test statistic
    pub statistic: f64,
    /// Upper-tail p-value of the statistic
    pub p_value: f64,
    /// Significance level (α) used for pass/fail
    pub significance_level: f64,
    /// Whether the test passes (statistic does not exceed the critical value)
    pub passes: bool,
    /// Critical value at the significance level
    pub critical_value: f64,
    /// Per-bin observed vs expected frequencies and χ² contributions
    pub bin_frequencies: Vec<BinFrequency>,
    /// Cramér's V effect size (0 = no association; larger values indicate
    /// stronger deviation from the expected distribution)
    pub cramers_v: f64,
    /// Issues found during analysis (validity warnings, failure details)
    pub issues: Vec<String>,
}
66
/// Expected (null) distribution to test the observed frequencies against.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum ExpectedDistribution {
    /// Uniform distribution (equal expected count per bin)
    #[default]
    Uniform,
    /// Custom expected frequencies, one per bin
    /// (documented as summing to the sample size — not validated)
    Custom(Vec<f64>),
    /// Expected proportions, one per bin (must sum to 1.0 within ±0.01)
    Proportions(Vec<f64>),
    /// Compare against another observed distribution; its counts are
    /// rescaled to this sample's total before comparison
    Observed(Vec<usize>),
}
80
/// Analyzer for chi-squared goodness-of-fit tests.
///
/// Configure via the builder-style `with_*` methods, then call
/// `analyze_continuous` (bins the data first) or `analyze_categorical`
/// (uses the counts as-is).
pub struct ChiSquaredAnalyzer {
    /// Binning strategy for continuous data (ignored for categorical input)
    binning: BinningStrategy,
    /// Expected (null) distribution to compare against
    expected: ExpectedDistribution,
    /// Significance level (α) for the pass/fail decision
    significance_level: f64,
    /// Minimum expected frequency per bin before a validity warning is issued
    min_expected: f64,
}
92
93impl ChiSquaredAnalyzer {
94    /// Create a new analyzer with default settings.
95    pub fn new() -> Self {
96        Self {
97            binning: BinningStrategy::Auto,
98            expected: ExpectedDistribution::Uniform,
99            significance_level: 0.05,
100            min_expected: 5.0,
101        }
102    }
103
104    /// Set the binning strategy.
105    pub fn with_binning(mut self, strategy: BinningStrategy) -> Self {
106        self.binning = strategy;
107        self
108    }
109
110    /// Set the expected distribution.
111    pub fn with_expected(mut self, expected: ExpectedDistribution) -> Self {
112        self.expected = expected;
113        self
114    }
115
116    /// Set the significance level.
117    pub fn with_significance_level(mut self, level: f64) -> Self {
118        self.significance_level = level;
119        self
120    }
121
122    /// Set minimum expected frequency per bin.
123    pub fn with_min_expected(mut self, min: f64) -> Self {
124        self.min_expected = min;
125        self
126    }
127
128    /// Analyze continuous data (will be binned).
129    pub fn analyze_continuous(&self, values: &[f64]) -> EvalResult<ChiSquaredAnalysis> {
130        let n = values.len();
131        if n < 10 {
132            return Err(EvalError::InsufficientData {
133                required: 10,
134                actual: n,
135            });
136        }
137
138        // Filter invalid values
139        let valid_values: Vec<f64> = values.iter().filter(|&&v| v.is_finite()).copied().collect();
140
141        if valid_values.len() < 10 {
142            return Err(EvalError::InsufficientData {
143                required: 10,
144                actual: valid_values.len(),
145            });
146        }
147
148        // Create bins
149        let (edges, observed) = self.bin_data(&valid_values)?;
150        let n_f = valid_values.len() as f64;
151
152        // Calculate expected frequencies
153        let expected = self.calculate_expected(&observed, n_f)?;
154
155        self.perform_test(&edges, &observed, &expected)
156    }
157
158    /// Analyze categorical/count data directly.
159    pub fn analyze_categorical(&self, observed: &[usize]) -> EvalResult<ChiSquaredAnalysis> {
160        if observed.is_empty() {
161            return Err(EvalError::InvalidParameter(
162                "Observed counts cannot be empty".to_string(),
163            ));
164        }
165
166        let total: usize = observed.iter().sum();
167        if total < 10 {
168            return Err(EvalError::InsufficientData {
169                required: 10,
170                actual: total,
171            });
172        }
173
174        let n_f = total as f64;
175
176        // Create pseudo-edges for categorical bins
177        let edges: Vec<f64> = (0..=observed.len()).map(|i| i as f64).collect();
178
179        // Calculate expected frequencies
180        let expected = self.calculate_expected(observed, n_f)?;
181
182        self.perform_test(&edges, observed, &expected)
183    }
184
185    /// Bin continuous data according to strategy.
186    fn bin_data(&self, values: &[f64]) -> EvalResult<(Vec<f64>, Vec<usize>)> {
187        let mut sorted: Vec<f64> = values.to_vec();
188        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
189
190        let min = sorted[0];
191        let max = sorted[sorted.len() - 1];
192
193        let edges = match &self.binning {
194            BinningStrategy::EqualWidth { num_bins } => {
195                let width = (max - min) / (*num_bins as f64);
196                (0..=*num_bins).map(|i| min + (i as f64) * width).collect()
197            }
198            BinningStrategy::EqualFrequency { num_bins } => {
199                let n = sorted.len();
200                let mut edges = vec![min];
201                for i in 1..*num_bins {
202                    let idx = (i * n) / *num_bins;
203                    edges.push(sorted[idx.min(n - 1)]);
204                }
205                edges.push(max);
206                edges
207            }
208            BinningStrategy::Custom { edges } => edges.clone(),
209            BinningStrategy::Auto => {
210                // Sturges' rule
211                let num_bins = (1.0 + (values.len() as f64).log2()).ceil() as usize;
212                let width = (max - min) / (num_bins as f64);
213                (0..=num_bins).map(|i| min + (i as f64) * width).collect()
214            }
215        };
216
217        if edges.len() < 2 {
218            return Err(EvalError::InvalidParameter(
219                "Need at least 2 bin edges".to_string(),
220            ));
221        }
222
223        // Count observations in each bin
224        let num_bins = edges.len() - 1;
225        let mut counts = vec![0usize; num_bins];
226
227        for &v in values {
228            for (i, window) in edges.windows(2).enumerate() {
229                let (lower, upper) = (window[0], window[1]);
230                if v >= lower && (v < upper || (i == num_bins - 1 && v <= upper)) {
231                    counts[i] += 1;
232                    break;
233                }
234            }
235        }
236
237        Ok((edges, counts))
238    }
239
240    /// Calculate expected frequencies based on distribution type.
241    fn calculate_expected(&self, observed: &[usize], total: f64) -> EvalResult<Vec<f64>> {
242        match &self.expected {
243            ExpectedDistribution::Uniform => {
244                let expected_per_bin = total / (observed.len() as f64);
245                Ok(vec![expected_per_bin; observed.len()])
246            }
247            ExpectedDistribution::Custom(expected) => {
248                if expected.len() != observed.len() {
249                    return Err(EvalError::InvalidParameter(format!(
250                        "Expected {} frequencies, got {}",
251                        observed.len(),
252                        expected.len()
253                    )));
254                }
255                Ok(expected.clone())
256            }
257            ExpectedDistribution::Proportions(props) => {
258                if props.len() != observed.len() {
259                    return Err(EvalError::InvalidParameter(format!(
260                        "Expected {} proportions, got {}",
261                        observed.len(),
262                        props.len()
263                    )));
264                }
265                let sum: f64 = props.iter().sum();
266                if (sum - 1.0).abs() > 0.01 {
267                    return Err(EvalError::InvalidParameter(format!(
268                        "Proportions must sum to 1.0, got {}",
269                        sum
270                    )));
271                }
272                Ok(props.iter().map(|&p| p * total).collect())
273            }
274            ExpectedDistribution::Observed(other) => {
275                if other.len() != observed.len() {
276                    return Err(EvalError::InvalidParameter(format!(
277                        "Expected {} categories, got {}",
278                        observed.len(),
279                        other.len()
280                    )));
281                }
282                let other_total: f64 = other.iter().sum::<usize>() as f64;
283                Ok(other
284                    .iter()
285                    .map(|&c| (c as f64) / other_total * total)
286                    .collect())
287            }
288        }
289    }
290
291    /// Perform the chi-squared test.
292    fn perform_test(
293        &self,
294        edges: &[f64],
295        observed: &[usize],
296        expected: &[f64],
297    ) -> EvalResult<ChiSquaredAnalysis> {
298        let n = observed.len();
299        let total: usize = observed.iter().sum();
300        let n_f = total as f64;
301
302        let mut issues = Vec::new();
303
304        // Check minimum expected frequency
305        let low_expected: Vec<_> = expected
306            .iter()
307            .enumerate()
308            .filter(|(_, &e)| e < self.min_expected)
309            .collect();
310        if !low_expected.is_empty() {
311            issues.push(format!(
312                "{} bins have expected frequency < {:.1}; results may be unreliable",
313                low_expected.len(),
314                self.min_expected
315            ));
316        }
317
318        // Calculate chi-squared statistic and bin details
319        let mut chi_squared = 0.0;
320        let mut bin_frequencies = Vec::new();
321
322        for (i, ((&obs, &exp), window)) in observed
323            .iter()
324            .zip(expected.iter())
325            .zip(edges.windows(2))
326            .enumerate()
327        {
328            let contribution = if exp > 0.0 {
329                let diff = obs as f64 - exp;
330                (diff * diff) / exp
331            } else {
332                0.0
333            };
334            chi_squared += contribution;
335
336            bin_frequencies.push(BinFrequency {
337                index: i,
338                lower: window[0],
339                upper: window[1],
340                observed: obs,
341                expected: exp,
342                contribution,
343            });
344        }
345
346        // Degrees of freedom
347        // For goodness-of-fit: df = num_bins - 1 - estimated_parameters
348        // For uniform: df = num_bins - 1 (no parameters estimated from data)
349        let df = n.saturating_sub(1);
350        if df == 0 {
351            return Err(EvalError::InvalidParameter(
352                "Need at least 2 bins for chi-squared test".to_string(),
353            ));
354        }
355
356        // Calculate p-value
357        let p_value = chi_squared_p_value(chi_squared, df);
358
359        // Critical value
360        let critical_value = chi_squared_critical(df, self.significance_level);
361
362        // Cramér's V effect size
363        let cramers_v = (chi_squared / n_f).sqrt();
364
365        let passes = chi_squared <= critical_value;
366
367        if !passes {
368            issues.push(format!(
369                "χ² = {:.4} exceeds critical value {:.4} at α = {:.2}",
370                chi_squared, critical_value, self.significance_level
371            ));
372        }
373
374        Ok(ChiSquaredAnalysis {
375            sample_size: total,
376            num_bins: n,
377            degrees_of_freedom: df,
378            statistic: chi_squared,
379            p_value,
380            significance_level: self.significance_level,
381            passes,
382            critical_value,
383            bin_frequencies,
384            cramers_v,
385            issues,
386        })
387    }
388}
389
390impl Default for ChiSquaredAnalyzer {
391    fn default() -> Self {
392        Self::new()
393    }
394}
395
396/// Calculate p-value for chi-squared statistic using incomplete gamma function.
397fn chi_squared_p_value(chi_sq: f64, df: usize) -> f64 {
398    // P(X > chi_sq) = 1 - P(X <= chi_sq) = 1 - gamma_cdf(chi_sq, df)
399    // Using upper incomplete gamma function
400    1.0 - lower_incomplete_gamma(df as f64 / 2.0, chi_sq / 2.0)
401}
402
403/// Calculate chi-squared critical value for given df and significance level.
404fn chi_squared_critical(df: usize, alpha: f64) -> f64 {
405    // Use Wilson-Hilferty approximation for chi-squared quantiles
406    // For df >= 2: chi_sq ≈ df * (1 - 2/(9*df) + z * sqrt(2/(9*df)))^3
407    // where z is the standard normal quantile
408
409    if df == 0 {
410        return 0.0;
411    }
412
413    let df_f = df as f64;
414
415    // Get z-score for 1-alpha quantile
416    let z = normal_quantile(1.0 - alpha);
417
418    // Wilson-Hilferty approximation
419    let term = 2.0 / (9.0 * df_f);
420    let inner = 1.0 - term + z * term.sqrt();
421
422    df_f * inner.powi(3).max(0.0)
423}
424
425/// Lower incomplete gamma function regularized.
426fn lower_incomplete_gamma(a: f64, x: f64) -> f64 {
427    if x <= 0.0 {
428        return 0.0;
429    }
430    if x >= a + 1.0 {
431        // Use continued fraction for large x
432        1.0 - upper_incomplete_gamma_cf(a, x)
433    } else {
434        // Use series expansion for small x
435        lower_incomplete_gamma_series(a, x)
436    }
437}
438
439/// Series expansion for lower incomplete gamma.
440fn lower_incomplete_gamma_series(a: f64, x: f64) -> f64 {
441    let ln_gamma_a = ln_gamma(a);
442    let mut sum = 1.0 / a;
443    let mut term = 1.0 / a;
444
445    for n in 1..200 {
446        term *= x / (a + n as f64);
447        sum += term;
448        if term.abs() < 1e-10 * sum.abs() {
449            break;
450        }
451    }
452
453    sum * x.powf(a) * (-x).exp() / ln_gamma_a.exp()
454}
455
456/// Continued fraction for upper incomplete gamma.
457fn upper_incomplete_gamma_cf(a: f64, x: f64) -> f64 {
458    let ln_gamma_a = ln_gamma(a);
459
460    // Lentz's algorithm
461    let mut f = 1e-30_f64;
462    let mut c = 1e-30_f64;
463    let mut d = 0.0_f64;
464
465    for i in 1..200 {
466        let i_f = i as f64;
467        let an = if i == 1 {
468            1.0
469        } else if i % 2 == 0 {
470            (i_f / 2.0 - 1.0) - a + 1.0
471        } else {
472            (i_f - 1.0) / 2.0
473        };
474        let bn = if i == 1 { x - a + 1.0 } else { x - a + i_f };
475
476        d = bn + an * d;
477        if d.abs() < 1e-30 {
478            d = 1e-30;
479        }
480        c = bn + an / c;
481        if c.abs() < 1e-30 {
482            c = 1e-30;
483        }
484        d = 1.0 / d;
485        let delta = c * d;
486        f *= delta;
487
488        if (delta - 1.0).abs() < 1e-10 {
489            break;
490        }
491    }
492
493    f * x.powf(a) * (-x).exp() / ln_gamma_a.exp()
494}
495
/// Natural log of the gamma function, ln Γ(x), for x > 0.
///
/// Uses the 6-coefficient Lanczos approximation (g = 5). Returns
/// `f64::INFINITY` for non-positive inputs, where Γ is undefined or
/// has poles.
fn ln_gamma(x: f64) -> f64 {
    if x <= 0.0 {
        return f64::INFINITY;
    }

    // Lanczos coefficients (g = 5, n = 6).
    const COEFFS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    let base = x + 5.5;
    let prefix = (x + 0.5) * base.ln() - base;

    // Partial-fraction series; sqrt(2π) appears as 2.5066282746310005.
    let series = COEFFS
        .iter()
        .enumerate()
        .fold(1.000000000190015, |acc, (i, &coeff)| {
            acc + coeff / (x + (i + 1) as f64)
        });

    prefix + (2.5066282746310005 * series / x).ln()
}
521
/// Inverse CDF (quantile function) of the standard normal distribution.
///
/// Uses the Abramowitz & Stegun 26.2.23 rational approximation
/// (absolute error < 4.5e-4), exploiting the symmetry
/// Φ⁻¹(p) = −Φ⁻¹(1 − p) to evaluate in the lower tail only.
fn normal_quantile(p: f64) -> f64 {
    // Boundary and degenerate inputs.
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if p == 0.5 {
        return 0.0;
    }

    const C0: f64 = 2.515517;
    const C1: f64 = 0.802853;
    const C2: f64 = 0.010328;
    const D1: f64 = 1.432788;
    const D2: f64 = 0.189269;
    const D3: f64 = 0.001308;

    // Work with the smaller tail probability; restore the sign at the end.
    let tail = if p < 0.5 { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let magnitude =
        t - (C0 + C1 * t + C2 * t * t) / (1.0 + D1 * t + D2 * t * t + D3 * t * t * t);

    if p < 0.5 {
        -magnitude
    } else {
        magnitude
    }
}
556
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::{Distribution, Uniform};

    // Uniform samples tested against a uniform null should pass.
    #[test]
    fn test_uniform_distribution() {
        // Generate uniform data and test against uniform expectation
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let uniform = Uniform::new(0.0, 100.0);
        let values: Vec<f64> = (0..1000).map(|_| uniform.sample(&mut rng)).collect();

        let analyzer = ChiSquaredAnalyzer::new()
            .with_binning(BinningStrategy::EqualWidth { num_bins: 10 })
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_continuous(&values).unwrap();
        assert!(
            result.passes,
            "Uniform data should pass uniform chi-squared test"
        );
        assert!(result.p_value > 0.05);
    }

    // Near-equal category counts should not be flagged as non-uniform.
    #[test]
    fn test_categorical_uniform() {
        // Equal counts across categories
        let observed = vec![100, 98, 102, 100, 100]; // Nearly equal

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(result.passes, "Nearly uniform counts should pass");
    }

    // A strongly skewed distribution must fail the uniform null.
    #[test]
    fn test_categorical_deviation() {
        // Clearly non-uniform distribution
        let observed = vec![400, 50, 25, 15, 10]; // Very skewed

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(
            !result.passes,
            "Highly skewed counts should fail uniform test"
        );
    }

    // Proportions variant: sanity-check the plumbing (sample size passthrough).
    #[test]
    fn test_custom_proportions() {
        // Test against known proportions
        let observed = vec![300, 200, 100]; // 50%, 33%, 17%
        let expected_props = vec![0.50, 0.33, 0.17];

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Proportions(expected_props))
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        // Should pass or be close to passing given the proportions match
        assert!(result.sample_size == 600);
    }

    // Each binning strategy should yield the requested/expected bin count.
    #[test]
    fn test_binning_strategies() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let uniform = Uniform::new(0.0, 100.0);
        let values: Vec<f64> = (0..500).map(|_| uniform.sample(&mut rng)).collect();

        // Test equal-width
        let analyzer1 =
            ChiSquaredAnalyzer::new().with_binning(BinningStrategy::EqualWidth { num_bins: 10 });
        let result1 = analyzer1.analyze_continuous(&values).unwrap();
        assert_eq!(result1.num_bins, 10);

        // Test equal-frequency
        let analyzer2 =
            ChiSquaredAnalyzer::new().with_binning(BinningStrategy::EqualFrequency { num_bins: 5 });
        let result2 = analyzer2.analyze_continuous(&values).unwrap();
        assert_eq!(result2.num_bins, 5);

        // Test auto
        let analyzer3 = ChiSquaredAnalyzer::new().with_binning(BinningStrategy::Auto);
        let result3 = analyzer3.analyze_continuous(&values).unwrap();
        assert!(result3.num_bins > 0);
    }

    // Fewer than 10 values must be rejected with InsufficientData.
    #[test]
    fn test_insufficient_data() {
        let values = vec![1.0, 2.0, 3.0]; // Too few

        let analyzer = ChiSquaredAnalyzer::new();
        let result = analyzer.analyze_continuous(&values);

        assert!(matches!(
            result,
            Err(EvalError::InsufficientData {
                required: 10,
                actual: 3
            })
        ));
    }

    // Effect size should be large when all mass sits in one category.
    #[test]
    fn test_cramers_v() {
        // Perfect deviation should have high Cramér's V
        let observed = vec![500, 0, 0, 0, 0]; // All in first bin

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(
            result.cramers_v > 0.5,
            "Strong deviation should have high V"
        );
    }

    // Per-bin bookkeeping: counts and uniform expectations line up.
    #[test]
    fn test_bin_frequencies() {
        let observed = vec![50, 100, 50];

        let analyzer = ChiSquaredAnalyzer::new().with_expected(ExpectedDistribution::Uniform);

        let result = analyzer.analyze_categorical(&observed).unwrap();

        assert_eq!(result.bin_frequencies.len(), 3);

        // First bin: observed=50, expected=66.67, contribution = (50-66.67)^2/66.67
        let first_bin = &result.bin_frequencies[0];
        assert_eq!(first_bin.observed, 50);
        assert!((first_bin.expected - 66.666).abs() < 0.01);
    }

    // Smaller alpha (more stringent) must give a larger critical value.
    #[test]
    fn test_critical_value_ordering() {
        // Critical values should increase as alpha decreases
        let cv_10 = chi_squared_critical(10, 0.10);
        let cv_05 = chi_squared_critical(10, 0.05);
        let cv_01 = chi_squared_critical(10, 0.01);

        assert!(cv_10 < cv_05);
        assert!(cv_05 < cv_01);
    }

    // P-values stay in [0, 1] and decrease as the statistic grows.
    #[test]
    fn test_p_value_range() {
        // P-value should be in [0, 1]
        let p1 = chi_squared_p_value(0.0, 5);
        let p2 = chi_squared_p_value(5.0, 5);
        let p3 = chi_squared_p_value(50.0, 5);

        assert!((0.0..=1.0).contains(&p1));
        assert!((0.0..=1.0).contains(&p2));
        assert!((0.0..=1.0).contains(&p3));

        // Higher chi-squared should have lower p-value
        assert!(p1 > p2);
        assert!(p2 > p3);
    }
}