Skip to main content

datasynth_eval/statistical/
chi_squared.rs

1//! Chi-squared goodness-of-fit test for distribution validation.
2//!
3//! Tests whether observed frequency distribution matches expected frequencies.
4//! Useful for categorical data and binned continuous data.
5
6use crate::error::{EvalError, EvalResult};
7use serde::{Deserialize, Serialize};
8
/// Binning strategy for continuous data.
///
/// Every strategy yields a list of bin edges. Bin `i` spans
/// `[edges[i], edges[i + 1])`; the last bin also includes its upper edge.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum BinningStrategy {
    /// Fixed number of equal-width bins spanning the data's minimum to maximum
    EqualWidth { num_bins: usize },
    /// Equal-frequency (quantile) bins: interior edges are values taken from the
    /// sorted data at evenly spaced ranks
    EqualFrequency { num_bins: usize },
    /// Custom bin edges, expected in ascending order
    Custom { edges: Vec<f64> },
    /// Automatic equal-width binning with the bin count from Sturges' rule,
    /// `ceil(1 + log2(n))`
    #[default]
    Auto,
}
22
/// Observed vs. expected frequency for a single bin (or category).
///
/// For categorical analyses the edges are pseudo-edges: category `i` is
/// reported with `lower = i` and `upper = i + 1`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BinFrequency {
    /// Bin index (0-based)
    pub index: usize,
    /// Bin lower edge (inclusive)
    pub lower: f64,
    /// Bin upper edge (exclusive, except for last bin)
    pub upper: f64,
    /// Observed count
    pub observed: usize,
    /// Expected count under the hypothesized distribution
    pub expected: f64,
    /// Contribution `(observed - expected)² / expected` to the chi-squared
    /// statistic (0 when `expected` is not positive)
    pub contribution: f64,
}
39
/// Chi-squared goodness-of-fit test results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChiSquaredAnalysis {
    /// Sample size: the sum of the observed counts that entered the test
    pub sample_size: usize,
    /// Number of bins (or categories)
    pub num_bins: usize,
    /// Degrees of freedom (`num_bins - 1`; no parameters are estimated)
    pub degrees_of_freedom: usize,
    /// Chi-squared test statistic `Σ (O - E)² / E`
    pub statistic: f64,
    /// Upper-tail p-value of `statistic` under a χ² distribution with
    /// `degrees_of_freedom`
    pub p_value: f64,
    /// Significance level (α) used for pass/fail
    pub significance_level: f64,
    /// Whether the test passes, i.e. `statistic <= critical_value`
    pub passes: bool,
    /// Critical value of the χ² distribution at `significance_level`
    pub critical_value: f64,
    /// Bin frequencies with observed vs expected
    pub bin_frequencies: Vec<BinFrequency>,
    /// Cramér's V–style effect size derived from `statistic` and the sample
    /// size; 0 means the observed counts match the expected counts exactly, and
    /// larger values indicate a stronger deviation
    pub cramers_v: f64,
    /// Human-readable issues found during analysis (unreliable bins, failure
    /// reason, ...)
    pub issues: Vec<String>,
}
66
/// Expected distribution the observed counts are compared against.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum ExpectedDistribution {
    /// Uniform distribution (equal expected count per bin)
    #[default]
    Uniform,
    /// Custom expected frequencies, one per bin (should sum to the sample size)
    Custom(Vec<f64>),
    /// Expected proportions, one per bin (must sum to 1.0 within ±0.01); they
    /// are multiplied by the sample size
    Proportions(Vec<f64>),
    /// Compare against another observed distribution; its counts are rescaled
    /// to the sample size
    Observed(Vec<usize>),
}
80
/// Analyzer for chi-squared goodness-of-fit tests.
///
/// Construct with [`ChiSquaredAnalyzer::new`] and configure with the `with_*`
/// builder methods.
pub struct ChiSquaredAnalyzer {
    /// Binning strategy for continuous data
    binning: BinningStrategy,
    /// Expected distribution the observed counts are tested against
    expected: ExpectedDistribution,
    /// Significance level (α) used to derive the critical value
    significance_level: f64,
    /// Minimum expected frequency per bin; bins below it are reported as an
    /// issue because the χ² approximation becomes unreliable
    min_expected: f64,
}
92
93impl ChiSquaredAnalyzer {
94    /// Create a new analyzer with default settings.
95    pub fn new() -> Self {
96        Self {
97            binning: BinningStrategy::Auto,
98            expected: ExpectedDistribution::Uniform,
99            significance_level: 0.05,
100            min_expected: 5.0,
101        }
102    }
103
104    /// Set the binning strategy.
105    pub fn with_binning(mut self, strategy: BinningStrategy) -> Self {
106        self.binning = strategy;
107        self
108    }
109
110    /// Set the expected distribution.
111    pub fn with_expected(mut self, expected: ExpectedDistribution) -> Self {
112        self.expected = expected;
113        self
114    }
115
116    /// Set the significance level.
117    pub fn with_significance_level(mut self, level: f64) -> Self {
118        self.significance_level = level;
119        self
120    }
121
122    /// Set minimum expected frequency per bin.
123    pub fn with_min_expected(mut self, min: f64) -> Self {
124        self.min_expected = min;
125        self
126    }
127
128    /// Analyze continuous data (will be binned).
129    pub fn analyze_continuous(&self, values: &[f64]) -> EvalResult<ChiSquaredAnalysis> {
130        let n = values.len();
131        if n < 10 {
132            return Err(EvalError::InsufficientData {
133                required: 10,
134                actual: n,
135            });
136        }
137
138        // Filter invalid values
139        let valid_values: Vec<f64> = values.iter().filter(|&&v| v.is_finite()).copied().collect();
140
141        if valid_values.len() < 10 {
142            return Err(EvalError::InsufficientData {
143                required: 10,
144                actual: valid_values.len(),
145            });
146        }
147
148        // Create bins
149        let (edges, observed) = self.bin_data(&valid_values)?;
150        let n_f = valid_values.len() as f64;
151
152        // Calculate expected frequencies
153        let expected = self.calculate_expected(&observed, n_f)?;
154
155        self.perform_test(&edges, &observed, &expected)
156    }
157
158    /// Analyze categorical/count data directly.
159    pub fn analyze_categorical(&self, observed: &[usize]) -> EvalResult<ChiSquaredAnalysis> {
160        if observed.is_empty() {
161            return Err(EvalError::InvalidParameter(
162                "Observed counts cannot be empty".to_string(),
163            ));
164        }
165
166        let total: usize = observed.iter().sum();
167        if total < 10 {
168            return Err(EvalError::InsufficientData {
169                required: 10,
170                actual: total,
171            });
172        }
173
174        let n_f = total as f64;
175
176        // Create pseudo-edges for categorical bins
177        let edges: Vec<f64> = (0..=observed.len()).map(|i| i as f64).collect();
178
179        // Calculate expected frequencies
180        let expected = self.calculate_expected(observed, n_f)?;
181
182        self.perform_test(&edges, observed, &expected)
183    }
184
185    /// Bin continuous data according to strategy.
186    fn bin_data(&self, values: &[f64]) -> EvalResult<(Vec<f64>, Vec<usize>)> {
187        let mut sorted: Vec<f64> = values.to_vec();
188        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
189
190        let min = sorted[0];
191        let max = sorted[sorted.len() - 1];
192
193        let edges = match &self.binning {
194            BinningStrategy::EqualWidth { num_bins } => {
195                let width = (max - min) / (*num_bins as f64);
196                (0..=*num_bins).map(|i| min + (i as f64) * width).collect()
197            }
198            BinningStrategy::EqualFrequency { num_bins } => {
199                let n = sorted.len();
200                let mut edges = vec![min];
201                for i in 1..*num_bins {
202                    let idx = (i * n) / *num_bins;
203                    edges.push(sorted[idx.min(n - 1)]);
204                }
205                edges.push(max);
206                edges
207            }
208            BinningStrategy::Custom { edges } => edges.clone(),
209            BinningStrategy::Auto => {
210                // Sturges' rule
211                let num_bins = (1.0 + (values.len() as f64).log2()).ceil() as usize;
212                let width = (max - min) / (num_bins as f64);
213                (0..=num_bins).map(|i| min + (i as f64) * width).collect()
214            }
215        };
216
217        if edges.len() < 2 {
218            return Err(EvalError::InvalidParameter(
219                "Need at least 2 bin edges".to_string(),
220            ));
221        }
222
223        // Count observations in each bin
224        let num_bins = edges.len() - 1;
225        let mut counts = vec![0usize; num_bins];
226
227        for &v in values {
228            for (i, window) in edges.windows(2).enumerate() {
229                let (lower, upper) = (window[0], window[1]);
230                if v >= lower && (v < upper || (i == num_bins - 1 && v <= upper)) {
231                    counts[i] += 1;
232                    break;
233                }
234            }
235        }
236
237        Ok((edges, counts))
238    }
239
240    /// Calculate expected frequencies based on distribution type.
241    fn calculate_expected(&self, observed: &[usize], total: f64) -> EvalResult<Vec<f64>> {
242        match &self.expected {
243            ExpectedDistribution::Uniform => {
244                let expected_per_bin = total / (observed.len() as f64);
245                Ok(vec![expected_per_bin; observed.len()])
246            }
247            ExpectedDistribution::Custom(expected) => {
248                if expected.len() != observed.len() {
249                    return Err(EvalError::InvalidParameter(format!(
250                        "Expected {} frequencies, got {}",
251                        observed.len(),
252                        expected.len()
253                    )));
254                }
255                Ok(expected.clone())
256            }
257            ExpectedDistribution::Proportions(props) => {
258                if props.len() != observed.len() {
259                    return Err(EvalError::InvalidParameter(format!(
260                        "Expected {} proportions, got {}",
261                        observed.len(),
262                        props.len()
263                    )));
264                }
265                let sum: f64 = props.iter().sum();
266                if (sum - 1.0).abs() > 0.01 {
267                    return Err(EvalError::InvalidParameter(format!(
268                        "Proportions must sum to 1.0, got {sum}"
269                    )));
270                }
271                Ok(props.iter().map(|&p| p * total).collect())
272            }
273            ExpectedDistribution::Observed(other) => {
274                if other.len() != observed.len() {
275                    return Err(EvalError::InvalidParameter(format!(
276                        "Expected {} categories, got {}",
277                        observed.len(),
278                        other.len()
279                    )));
280                }
281                let other_total: f64 = other.iter().sum::<usize>() as f64;
282                Ok(other
283                    .iter()
284                    .map(|&c| (c as f64) / other_total * total)
285                    .collect())
286            }
287        }
288    }
289
290    /// Perform the chi-squared test.
291    fn perform_test(
292        &self,
293        edges: &[f64],
294        observed: &[usize],
295        expected: &[f64],
296    ) -> EvalResult<ChiSquaredAnalysis> {
297        let n = observed.len();
298        let total: usize = observed.iter().sum();
299        let n_f = total as f64;
300
301        let mut issues = Vec::new();
302
303        // Check minimum expected frequency
304        let low_expected: Vec<_> = expected
305            .iter()
306            .enumerate()
307            .filter(|(_, &e)| e < self.min_expected)
308            .collect();
309        if !low_expected.is_empty() {
310            issues.push(format!(
311                "{} bins have expected frequency < {:.1}; results may be unreliable",
312                low_expected.len(),
313                self.min_expected
314            ));
315        }
316
317        // Calculate chi-squared statistic and bin details
318        let mut chi_squared = 0.0;
319        let mut bin_frequencies = Vec::new();
320
321        for (i, ((&obs, &exp), window)) in observed
322            .iter()
323            .zip(expected.iter())
324            .zip(edges.windows(2))
325            .enumerate()
326        {
327            let contribution = if exp > 0.0 {
328                let diff = obs as f64 - exp;
329                (diff * diff) / exp
330            } else {
331                0.0
332            };
333            chi_squared += contribution;
334
335            bin_frequencies.push(BinFrequency {
336                index: i,
337                lower: window[0],
338                upper: window[1],
339                observed: obs,
340                expected: exp,
341                contribution,
342            });
343        }
344
345        // Degrees of freedom
346        // For goodness-of-fit: df = num_bins - 1 - estimated_parameters
347        // For uniform: df = num_bins - 1 (no parameters estimated from data)
348        let df = n.saturating_sub(1);
349        if df == 0 {
350            return Err(EvalError::InvalidParameter(
351                "Need at least 2 bins for chi-squared test".to_string(),
352            ));
353        }
354
355        // Calculate p-value
356        let p_value = chi_squared_p_value(chi_squared, df);
357
358        // Critical value
359        let critical_value = chi_squared_critical(df, self.significance_level);
360
361        // Cramér's V effect size
362        let cramers_v = (chi_squared / n_f).sqrt();
363
364        let passes = chi_squared <= critical_value;
365
366        if !passes {
367            issues.push(format!(
368                "χ² = {:.4} exceeds critical value {:.4} at α = {:.2}",
369                chi_squared, critical_value, self.significance_level
370            ));
371        }
372
373        Ok(ChiSquaredAnalysis {
374            sample_size: total,
375            num_bins: n,
376            degrees_of_freedom: df,
377            statistic: chi_squared,
378            p_value,
379            significance_level: self.significance_level,
380            passes,
381            critical_value,
382            bin_frequencies,
383            cramers_v,
384            issues,
385        })
386    }
387}
388
389impl Default for ChiSquaredAnalyzer {
390    fn default() -> Self {
391        Self::new()
392    }
393}
394
/// Iteration cap for the incomplete-gamma series and continued fraction.
///
/// Near `x ≈ a` the series needs on the order of `sqrt(a)` terms, so the cap is
/// generous enough for very large degrees of freedom.
const GAMMA_MAX_ITER: usize = 10_000;
/// Relative convergence tolerance for the incomplete-gamma expansions.
const GAMMA_EPS: f64 = 1e-15;
/// Smallest magnitude allowed in Lentz's algorithm, to avoid division by zero.
const GAMMA_FPMIN: f64 = 1e-300;

/// Upper-tail p-value `P(X > chi_sq)` for `X ~ χ²(df)`.
///
/// Equal to the regularized upper incomplete gamma function
/// `Q(df / 2, chi_sq / 2)`. Returns 1.0 for `chi_sq <= 0`. The result lies in
/// `[0, 1]` unless `chi_sq` is NaN, in which case NaN is returned.
fn chi_squared_p_value(chi_sq: f64, df: usize) -> f64 {
    upper_incomplete_gamma(df as f64 / 2.0, chi_sq / 2.0)
}

/// Chi-squared critical value: the `x` with `P(X > x) = alpha` for `X ~ χ²(df)`.
///
/// A Wilson–Hilferty approximation seeds a bisection on [`chi_squared_p_value`].
/// The returned value is therefore accurate to about 1e-12 relative, and
/// `statistic <= critical` agrees with `p_value >= alpha`. Wilson–Hilferty alone
/// is off by ~2.5% for df = 1.
///
/// Edge cases: returns 0.0 for `df == 0`, `alpha >= 1` or NaN `alpha`, and
/// `f64::INFINITY` for `alpha <= 0`.
fn chi_squared_critical(df: usize, alpha: f64) -> f64 {
    if df == 0 || alpha.is_nan() || alpha >= 1.0 {
        return 0.0;
    }
    if alpha <= 0.0 {
        return f64::INFINITY;
    }

    let df_f = df as f64;

    // Wilson–Hilferty: χ² ≈ df · (1 − 2/(9df) + z·sqrt(2/(9df)))³.
    let z = normal_quantile(1.0 - alpha);
    let term = 2.0 / (9.0 * df_f);
    let estimate = df_f * (1.0 - term + z * term.sqrt()).powi(3).max(0.0);

    // The p-value decreases in x and p(0) = 1 > alpha. Grow `hi` until it
    // brackets the root, then bisect.
    let mut lo = 0.0_f64;
    let mut hi = estimate.max(1.0);
    while hi.is_finite() && chi_squared_p_value(hi, df) > alpha {
        lo = hi;
        hi *= 2.0;
    }

    for _ in 0..200 {
        let mid = 0.5 * (lo + hi);
        if chi_squared_p_value(mid, df) > alpha {
            lo = mid;
        } else {
            hi = mid;
        }
        if hi - lo <= 1e-12 * hi {
            break;
        }
    }

    0.5 * (lo + hi)
}

/// Regularized upper incomplete gamma function `Q(a, x) = Γ(a, x) / Γ(a)`.
///
/// Uses the power series for `P(a, x)` when `x < a + 1` and the continued
/// fraction for `Q(a, x)` otherwise; each converges quickly on its side.
/// Evaluating `Q` directly in the tail avoids the cancellation in `1 - P` that
/// would round very small p-values to zero.
fn upper_incomplete_gamma(a: f64, x: f64) -> f64 {
    if a.is_nan() || x.is_nan() {
        return f64::NAN;
    }
    if x <= 0.0 {
        return 1.0;
    }
    if a <= 0.0 || x == f64::INFINITY {
        return 0.0;
    }

    let q = if x < a + 1.0 {
        1.0 - lower_incomplete_gamma_series(a, x)
    } else {
        upper_incomplete_gamma_cf(a, x)
    };
    q.clamp(0.0, 1.0)
}

/// Regularized lower incomplete gamma `P(a, x)` via its power series.
///
/// `P(a, x) = x^a e^{-x} / Γ(a) · Σ_{n≥0} x^n / (a (a+1) ⋯ (a+n))`.
/// Requires `a > 0` and `x > 0`; intended for `x < a + 1`.
fn lower_incomplete_gamma_series(a: f64, x: f64) -> f64 {
    let mut denom = a;
    let mut term = 1.0 / a;
    let mut sum = term;

    for _ in 0..GAMMA_MAX_ITER {
        denom += 1.0;
        term *= x / denom;
        sum += term;
        if term.abs() < sum.abs() * GAMMA_EPS {
            break;
        }
    }

    // Evaluate x^a e^{-x} / Γ(a) in log space. The individual factors overflow
    // once `a` reaches the low hundreds (df ≈ 300), which produced NaN.
    sum * (a * x.ln() - x - ln_gamma(a)).exp()
}

/// Regularized upper incomplete gamma `Q(a, x)` via its continued fraction.
///
/// Evaluates
/// `Γ(a, x) = e^{-x} x^a / (x+1-a − 1·(1-a)/(x+3-a − 2·(2-a)/(x+5-a − ⋯)))`
/// with the modified Lentz algorithm, then divides by `Γ(a)`.
/// Requires `a > 0`; intended for `x >= a + 1`, which keeps the first
/// denominator positive.
fn upper_incomplete_gamma_cf(a: f64, x: f64) -> f64 {
    let mut b = x + 1.0 - a;
    let mut c = 1.0 / GAMMA_FPMIN;
    let mut d = 1.0 / b;
    let mut h = d;

    for i in 1..GAMMA_MAX_ITER {
        let k = i as f64;
        let an = -k * (k - a);
        b += 2.0;

        d = an * d + b;
        if d.abs() < GAMMA_FPMIN {
            d = GAMMA_FPMIN;
        }
        c = b + an / c;
        if c.abs() < GAMMA_FPMIN {
            c = GAMMA_FPMIN;
        }
        d = 1.0 / d;

        let delta = d * c;
        h *= delta;
        if (delta - 1.0).abs() < GAMMA_EPS {
            break;
        }
    }

    // Prefactor in log space for the same overflow reason as the series.
    h * (a * x.ln() - x - ln_gamma(a)).exp()
}

/// Natural log of the gamma function for `x > 0`; returns `+∞` for `x <= 0`.
///
/// Lanczos approximation using the Numerical Recipes `gammln` coefficients.
fn ln_gamma(x: f64) -> f64 {
    const COEFFS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    if x <= 0.0 {
        return f64::INFINITY;
    }

    let tmp = x + 5.5;
    let tmp = tmp - (x + 0.5) * tmp.ln();

    let ser = COEFFS
        .iter()
        .enumerate()
        .fold(1.000000000190015, |acc, (i, &c)| acc + c / (x + (i + 1) as f64));

    -tmp + (2.5066282746310005 * ser / x).ln()
}

/// Approximate standard normal quantile (inverse CDF).
///
/// Rational approximation from Abramowitz & Stegun 26.2.23 (absolute error
/// below about 4.5e-4). It only seeds the critical-value search above.
/// Returns ±∞ for `p` outside `(0, 1)`.
fn normal_quantile(p: f64) -> f64 {
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if p == 0.5 {
        return 0.0;
    }

    // Work with the smaller tail probability; the sign is restored at the end.
    let tail = if p < 0.5 { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let c0 = 2.515517;
    let c1 = 0.802853;
    let c2 = 0.010328;
    let d1 = 1.432788;
    let d2 = 0.189269;
    let d3 = 0.001308;

    let z = t - (c0 + c1 * t + c2 * t * t) / (1.0 + d1 * t + d2 * t * t + d3 * t * t * t);

    if p < 0.5 {
        -z
    } else {
        z
    }
}
555
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    //! Unit tests for the chi-squared analyzer and its numerical helpers.
    //! Random inputs use a fixed-seed ChaCha8 RNG so results are deterministic.

    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::{Distribution, Uniform};

    #[test]
    fn test_uniform_distribution() {
        // Generate uniform data and test against uniform expectation.
        // The fixed seed keeps the sample, and therefore the verdict, stable.
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let uniform = Uniform::new(0.0, 100.0).unwrap();
        let values: Vec<f64> = (0..1000).map(|_| uniform.sample(&mut rng)).collect();

        let analyzer = ChiSquaredAnalyzer::new()
            .with_binning(BinningStrategy::EqualWidth { num_bins: 10 })
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_continuous(&values).unwrap();
        assert!(
            result.passes,
            "Uniform data should pass uniform chi-squared test"
        );
        // The p-value should agree with the critical-value verdict.
        assert!(result.p_value > 0.05);
    }

    #[test]
    fn test_categorical_uniform() {
        // Equal counts across categories: expected 100 each, so χ² = 0.08.
        let observed = vec![100, 98, 102, 100, 100]; // Nearly equal

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(result.passes, "Nearly uniform counts should pass");
    }

    #[test]
    fn test_categorical_deviation() {
        // Clearly non-uniform distribution: χ² = 1134.5 against 100 per category.
        let observed = vec![400, 50, 25, 15, 10]; // Very skewed

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(
            !result.passes,
            "Highly skewed counts should fail uniform test"
        );
    }

    #[test]
    fn test_custom_proportions() {
        // Test against known proportions (expected counts 300 / 198 / 102).
        let observed = vec![300, 200, 100]; // 50%, 33%, 17%
        let expected_props = vec![0.50, 0.33, 0.17];

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Proportions(expected_props))
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        // Only the sample size is pinned here; pass/fail is not asserted even
        // though the observed split is very close to the expected one.
        assert!(result.sample_size == 600);
    }

    #[test]
    fn test_binning_strategies() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let uniform = Uniform::new(0.0, 100.0).unwrap();
        let values: Vec<f64> = (0..500).map(|_| uniform.sample(&mut rng)).collect();

        // Test equal-width
        let analyzer1 =
            ChiSquaredAnalyzer::new().with_binning(BinningStrategy::EqualWidth { num_bins: 10 });
        let result1 = analyzer1.analyze_continuous(&values).unwrap();
        assert_eq!(result1.num_bins, 10);

        // Test equal-frequency
        let analyzer2 =
            ChiSquaredAnalyzer::new().with_binning(BinningStrategy::EqualFrequency { num_bins: 5 });
        let result2 = analyzer2.analyze_continuous(&values).unwrap();
        assert_eq!(result2.num_bins, 5);

        // Test auto (Sturges: ceil(1 + log2(500)) = 10); only positivity is asserted.
        let analyzer3 = ChiSquaredAnalyzer::new().with_binning(BinningStrategy::Auto);
        let result3 = analyzer3.analyze_continuous(&values).unwrap();
        assert!(result3.num_bins > 0);
    }

    #[test]
    fn test_insufficient_data() {
        let values = vec![1.0, 2.0, 3.0]; // Too few (minimum is 10)

        let analyzer = ChiSquaredAnalyzer::new();
        let result = analyzer.analyze_continuous(&values);

        assert!(matches!(
            result,
            Err(EvalError::InsufficientData {
                required: 10,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_cramers_v() {
        // Maximal deviation should have high Cramér's V: all 500 observations in
        // one of five bins against a uniform expectation of 100 (χ² = 2000).
        let observed = vec![500, 0, 0, 0, 0]; // All in first bin

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(
            result.cramers_v > 0.5,
            "Strong deviation should have high V"
        );
    }

    #[test]
    fn test_bin_frequencies() {
        let observed = vec![50, 100, 50];

        let analyzer = ChiSquaredAnalyzer::new().with_expected(ExpectedDistribution::Uniform);

        let result = analyzer.analyze_categorical(&observed).unwrap();

        assert_eq!(result.bin_frequencies.len(), 3);

        // First bin: observed=50, expected=200/3≈66.67, contribution = (50-66.67)^2/66.67
        let first_bin = &result.bin_frequencies[0];
        assert_eq!(first_bin.observed, 50);
        assert!((first_bin.expected - 66.666).abs() < 0.01);
    }

    #[test]
    fn test_critical_value_ordering() {
        // Critical values should increase as alpha decreases
        let cv_10 = chi_squared_critical(10, 0.10);
        let cv_05 = chi_squared_critical(10, 0.05);
        let cv_01 = chi_squared_critical(10, 0.01);

        assert!(cv_10 < cv_05);
        assert!(cv_05 < cv_01);
    }

    #[test]
    fn test_p_value_range() {
        // P-value should be in [0, 1]; a statistic of 0 gives p = 1.
        let p1 = chi_squared_p_value(0.0, 5);
        let p2 = chi_squared_p_value(5.0, 5);
        let p3 = chi_squared_p_value(50.0, 5);

        assert!((0.0..=1.0).contains(&p1));
        assert!((0.0..=1.0).contains(&p2));
        assert!((0.0..=1.0).contains(&p3));

        // Higher chi-squared should have lower p-value
        assert!(p1 > p2);
        assert!(p2 > p3);
    }
}