1use crate::error::{EvalError, EvalResult};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub enum BinningStrategy {
12 EqualWidth { num_bins: usize },
14 EqualFrequency { num_bins: usize },
16 Custom { edges: Vec<f64> },
18 #[default]
20 Auto,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct BinFrequency {
26 pub index: usize,
28 pub lower: f64,
30 pub upper: f64,
32 pub observed: usize,
34 pub expected: f64,
36 pub contribution: f64,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ChiSquaredAnalysis {
43 pub sample_size: usize,
45 pub num_bins: usize,
47 pub degrees_of_freedom: usize,
49 pub statistic: f64,
51 pub p_value: f64,
53 pub significance_level: f64,
55 pub passes: bool,
57 pub critical_value: f64,
59 pub bin_frequencies: Vec<BinFrequency>,
61 pub cramers_v: f64,
63 pub issues: Vec<String>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize, Default)]
69pub enum ExpectedDistribution {
70 #[default]
72 Uniform,
73 Custom(Vec<f64>),
75 Proportions(Vec<f64>),
77 Observed(Vec<usize>),
79}
80
81pub struct ChiSquaredAnalyzer {
83 binning: BinningStrategy,
85 expected: ExpectedDistribution,
87 significance_level: f64,
89 min_expected: f64,
91}
92
93impl ChiSquaredAnalyzer {
94 pub fn new() -> Self {
96 Self {
97 binning: BinningStrategy::Auto,
98 expected: ExpectedDistribution::Uniform,
99 significance_level: 0.05,
100 min_expected: 5.0,
101 }
102 }
103
104 pub fn with_binning(mut self, strategy: BinningStrategy) -> Self {
106 self.binning = strategy;
107 self
108 }
109
110 pub fn with_expected(mut self, expected: ExpectedDistribution) -> Self {
112 self.expected = expected;
113 self
114 }
115
116 pub fn with_significance_level(mut self, level: f64) -> Self {
118 self.significance_level = level;
119 self
120 }
121
122 pub fn with_min_expected(mut self, min: f64) -> Self {
124 self.min_expected = min;
125 self
126 }
127
128 pub fn analyze_continuous(&self, values: &[f64]) -> EvalResult<ChiSquaredAnalysis> {
130 let n = values.len();
131 if n < 10 {
132 return Err(EvalError::InsufficientData {
133 required: 10,
134 actual: n,
135 });
136 }
137
138 let valid_values: Vec<f64> = values.iter().filter(|&&v| v.is_finite()).copied().collect();
140
141 if valid_values.len() < 10 {
142 return Err(EvalError::InsufficientData {
143 required: 10,
144 actual: valid_values.len(),
145 });
146 }
147
148 let (edges, observed) = self.bin_data(&valid_values)?;
150 let n_f = valid_values.len() as f64;
151
152 let expected = self.calculate_expected(&observed, n_f)?;
154
155 self.perform_test(&edges, &observed, &expected)
156 }
157
158 pub fn analyze_categorical(&self, observed: &[usize]) -> EvalResult<ChiSquaredAnalysis> {
160 if observed.is_empty() {
161 return Err(EvalError::InvalidParameter(
162 "Observed counts cannot be empty".to_string(),
163 ));
164 }
165
166 let total: usize = observed.iter().sum();
167 if total < 10 {
168 return Err(EvalError::InsufficientData {
169 required: 10,
170 actual: total,
171 });
172 }
173
174 let n_f = total as f64;
175
176 let edges: Vec<f64> = (0..=observed.len()).map(|i| i as f64).collect();
178
179 let expected = self.calculate_expected(observed, n_f)?;
181
182 self.perform_test(&edges, observed, &expected)
183 }
184
185 fn bin_data(&self, values: &[f64]) -> EvalResult<(Vec<f64>, Vec<usize>)> {
187 let mut sorted: Vec<f64> = values.to_vec();
188 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
189
190 let min = sorted[0];
191 let max = sorted[sorted.len() - 1];
192
193 let edges = match &self.binning {
194 BinningStrategy::EqualWidth { num_bins } => {
195 let width = (max - min) / (*num_bins as f64);
196 (0..=*num_bins).map(|i| min + (i as f64) * width).collect()
197 }
198 BinningStrategy::EqualFrequency { num_bins } => {
199 let n = sorted.len();
200 let mut edges = vec![min];
201 for i in 1..*num_bins {
202 let idx = (i * n) / *num_bins;
203 edges.push(sorted[idx.min(n - 1)]);
204 }
205 edges.push(max);
206 edges
207 }
208 BinningStrategy::Custom { edges } => edges.clone(),
209 BinningStrategy::Auto => {
210 let num_bins = (1.0 + (values.len() as f64).log2()).ceil() as usize;
212 let width = (max - min) / (num_bins as f64);
213 (0..=num_bins).map(|i| min + (i as f64) * width).collect()
214 }
215 };
216
217 if edges.len() < 2 {
218 return Err(EvalError::InvalidParameter(
219 "Need at least 2 bin edges".to_string(),
220 ));
221 }
222
223 let num_bins = edges.len() - 1;
225 let mut counts = vec![0usize; num_bins];
226
227 for &v in values {
228 for (i, window) in edges.windows(2).enumerate() {
229 let (lower, upper) = (window[0], window[1]);
230 if v >= lower && (v < upper || (i == num_bins - 1 && v <= upper)) {
231 counts[i] += 1;
232 break;
233 }
234 }
235 }
236
237 Ok((edges, counts))
238 }
239
240 fn calculate_expected(&self, observed: &[usize], total: f64) -> EvalResult<Vec<f64>> {
242 match &self.expected {
243 ExpectedDistribution::Uniform => {
244 let expected_per_bin = total / (observed.len() as f64);
245 Ok(vec![expected_per_bin; observed.len()])
246 }
247 ExpectedDistribution::Custom(expected) => {
248 if expected.len() != observed.len() {
249 return Err(EvalError::InvalidParameter(format!(
250 "Expected {} frequencies, got {}",
251 observed.len(),
252 expected.len()
253 )));
254 }
255 Ok(expected.clone())
256 }
257 ExpectedDistribution::Proportions(props) => {
258 if props.len() != observed.len() {
259 return Err(EvalError::InvalidParameter(format!(
260 "Expected {} proportions, got {}",
261 observed.len(),
262 props.len()
263 )));
264 }
265 let sum: f64 = props.iter().sum();
266 if (sum - 1.0).abs() > 0.01 {
267 return Err(EvalError::InvalidParameter(format!(
268 "Proportions must sum to 1.0, got {sum}"
269 )));
270 }
271 Ok(props.iter().map(|&p| p * total).collect())
272 }
273 ExpectedDistribution::Observed(other) => {
274 if other.len() != observed.len() {
275 return Err(EvalError::InvalidParameter(format!(
276 "Expected {} categories, got {}",
277 observed.len(),
278 other.len()
279 )));
280 }
281 let other_total: f64 = other.iter().sum::<usize>() as f64;
282 Ok(other
283 .iter()
284 .map(|&c| (c as f64) / other_total * total)
285 .collect())
286 }
287 }
288 }
289
290 fn perform_test(
292 &self,
293 edges: &[f64],
294 observed: &[usize],
295 expected: &[f64],
296 ) -> EvalResult<ChiSquaredAnalysis> {
297 let n = observed.len();
298 let total: usize = observed.iter().sum();
299 let n_f = total as f64;
300
301 let mut issues = Vec::new();
302
303 let low_expected: Vec<_> = expected
305 .iter()
306 .enumerate()
307 .filter(|(_, &e)| e < self.min_expected)
308 .collect();
309 if !low_expected.is_empty() {
310 issues.push(format!(
311 "{} bins have expected frequency < {:.1}; results may be unreliable",
312 low_expected.len(),
313 self.min_expected
314 ));
315 }
316
317 let mut chi_squared = 0.0;
319 let mut bin_frequencies = Vec::new();
320
321 for (i, ((&obs, &exp), window)) in observed
322 .iter()
323 .zip(expected.iter())
324 .zip(edges.windows(2))
325 .enumerate()
326 {
327 let contribution = if exp > 0.0 {
328 let diff = obs as f64 - exp;
329 (diff * diff) / exp
330 } else {
331 0.0
332 };
333 chi_squared += contribution;
334
335 bin_frequencies.push(BinFrequency {
336 index: i,
337 lower: window[0],
338 upper: window[1],
339 observed: obs,
340 expected: exp,
341 contribution,
342 });
343 }
344
345 let df = n.saturating_sub(1);
349 if df == 0 {
350 return Err(EvalError::InvalidParameter(
351 "Need at least 2 bins for chi-squared test".to_string(),
352 ));
353 }
354
355 let p_value = chi_squared_p_value(chi_squared, df);
357
358 let critical_value = chi_squared_critical(df, self.significance_level);
360
361 let cramers_v = (chi_squared / n_f).sqrt();
363
364 let passes = chi_squared <= critical_value;
365
366 if !passes {
367 issues.push(format!(
368 "χ² = {:.4} exceeds critical value {:.4} at α = {:.2}",
369 chi_squared, critical_value, self.significance_level
370 ));
371 }
372
373 Ok(ChiSquaredAnalysis {
374 sample_size: total,
375 num_bins: n,
376 degrees_of_freedom: df,
377 statistic: chi_squared,
378 p_value,
379 significance_level: self.significance_level,
380 passes,
381 critical_value,
382 bin_frequencies,
383 cramers_v,
384 issues,
385 })
386 }
387}
388
389impl Default for ChiSquaredAnalyzer {
390 fn default() -> Self {
391 Self::new()
392 }
393}
394
395fn chi_squared_p_value(chi_sq: f64, df: usize) -> f64 {
397 1.0 - lower_incomplete_gamma(df as f64 / 2.0, chi_sq / 2.0)
400}
401
402fn chi_squared_critical(df: usize, alpha: f64) -> f64 {
404 if df == 0 {
409 return 0.0;
410 }
411
412 let df_f = df as f64;
413
414 let z = normal_quantile(1.0 - alpha);
416
417 let term = 2.0 / (9.0 * df_f);
419 let inner = 1.0 - term + z * term.sqrt();
420
421 df_f * inner.powi(3).max(0.0)
422}
423
424fn lower_incomplete_gamma(a: f64, x: f64) -> f64 {
426 if x <= 0.0 {
427 return 0.0;
428 }
429 if x >= a + 1.0 {
430 1.0 - upper_incomplete_gamma_cf(a, x)
432 } else {
433 lower_incomplete_gamma_series(a, x)
435 }
436}
437
438fn lower_incomplete_gamma_series(a: f64, x: f64) -> f64 {
440 let ln_gamma_a = ln_gamma(a);
441 let mut sum = 1.0 / a;
442 let mut term = 1.0 / a;
443
444 for n in 1..200 {
445 term *= x / (a + n as f64);
446 sum += term;
447 if term.abs() < 1e-10 * sum.abs() {
448 break;
449 }
450 }
451
452 sum * x.powf(a) * (-x).exp() / ln_gamma_a.exp()
453}
454
455fn upper_incomplete_gamma_cf(a: f64, x: f64) -> f64 {
457 let ln_gamma_a = ln_gamma(a);
458
459 let mut f = 1e-30_f64;
461 let mut c = 1e-30_f64;
462 let mut d = 0.0_f64;
463
464 for i in 1..200 {
465 let i_f = i as f64;
466 let an = if i == 1 {
467 1.0
468 } else if i % 2 == 0 {
469 (i_f / 2.0 - 1.0) - a + 1.0
470 } else {
471 (i_f - 1.0) / 2.0
472 };
473 let bn = if i == 1 { x - a + 1.0 } else { x - a + i_f };
474
475 d = bn + an * d;
476 if d.abs() < 1e-30 {
477 d = 1e-30;
478 }
479 c = bn + an / c;
480 if c.abs() < 1e-30 {
481 c = 1e-30;
482 }
483 d = 1.0 / d;
484 let delta = c * d;
485 f *= delta;
486
487 if (delta - 1.0).abs() < 1e-10 {
488 break;
489 }
490 }
491
492 f * x.powf(a) * (-x).exp() / ln_gamma_a.exp()
493}
494
/// Natural logarithm of the gamma function for `x > 0`, computed from a
/// six-term Lanczos-style series.
///
/// Returns `f64::INFINITY` for `x <= 0`.
fn ln_gamma(x: f64) -> f64 {
    const SERIES_COEFFS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];
    const SERIES_SEED: f64 = 1.000000000190015;
    const SQRT_TWO_PI: f64 = 2.5066282746310005;

    if x <= 0.0 {
        return f64::INFINITY;
    }

    let shifted = x + 5.5;
    let log_part = shifted - (x + 0.5) * shifted.ln();

    // The k-th coefficient (counting from 1) is divided by x + k.
    let series = SERIES_COEFFS
        .iter()
        .zip(1u32..)
        .fold(SERIES_SEED, |acc, (&coeff, k)| acc + coeff / (x + f64::from(k)));

    -log_part + (SQRT_TWO_PI * series / x).ln()
}
520
/// Approximate quantile (inverse CDF) of the standard normal distribution.
///
/// Returns `-∞` for `p <= 0`, `+∞` for `p >= 1` and exactly `0.0` for
/// `p == 0.5`. Otherwise a rational approximation in
/// `t = sqrt(-2 ln(tail))` is evaluated on the smaller tail and the sign is
/// restored afterwards (the coefficients match Abramowitz & Stegun 26.2.23).
fn normal_quantile(p: f64) -> f64 {
    const C0: f64 = 2.515517;
    const C1: f64 = 0.802853;
    const C2: f64 = 0.010328;
    const D1: f64 = 1.432788;
    const D2: f64 = 0.189269;
    const D3: f64 = 0.001308;

    match p {
        p if p <= 0.0 => return f64::NEG_INFINITY,
        p if p >= 1.0 => return f64::INFINITY,
        p if p == 0.5 => return 0.0,
        _ => {}
    }

    let in_lower_half = p < 0.5;
    let tail = if in_lower_half { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let numerator = C0 + C1 * t + C2 * t * t;
    let denominator = 1.0 + D1 * t + D2 * t * t + D3 * t * t * t;
    let z = t - numerator / denominator;

    if in_lower_half {
        -z
    } else {
        z
    }
}
555
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is itself a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::{Distribution, Uniform};

    /// Draws `count` reproducible samples from a uniform distribution on
    /// `[0, 100)` using a fixed seed.
    fn uniform_samples(count: usize) -> Vec<f64> {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let uniform = Uniform::new(0.0, 100.0).unwrap();
        (0..count).map(|_| uniform.sample(&mut rng)).collect()
    }

    #[test]
    fn test_uniform_distribution() {
        let values = uniform_samples(1000);

        let analyzer = ChiSquaredAnalyzer::new()
            .with_binning(BinningStrategy::EqualWidth { num_bins: 10 })
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_continuous(&values).unwrap();
        assert!(
            result.passes,
            "Uniform data should pass uniform chi-squared test"
        );
        assert!(result.p_value > 0.05);
    }

    #[test]
    fn test_categorical_uniform() {
        let observed = vec![100, 98, 102, 100, 100];
        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(result.passes, "Nearly uniform counts should pass");
    }

    #[test]
    fn test_categorical_deviation() {
        let observed = vec![400, 50, 25, 15, 10];
        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(
            !result.passes,
            "Highly skewed counts should fail uniform test"
        );
    }

    #[test]
    fn test_custom_proportions() {
        let observed = vec![300, 200, 100];
        let expected_props = vec![0.50, 0.33, 0.17];

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Proportions(expected_props))
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert_eq!(result.sample_size, 600);
        assert_eq!(result.num_bins, 3);
        assert_eq!(result.degrees_of_freedom, 2);

        // Proportions are scaled by the total: 0.50 * 600 = 300.
        assert_eq!(result.bin_frequencies.len(), 3);
        assert!((result.bin_frequencies[0].expected - 300.0).abs() < 1e-9);
    }

    #[test]
    fn test_binning_strategies() {
        let values = uniform_samples(500);

        let equal_width =
            ChiSquaredAnalyzer::new().with_binning(BinningStrategy::EqualWidth { num_bins: 10 });
        let result = equal_width.analyze_continuous(&values).unwrap();
        assert_eq!(result.num_bins, 10);

        let equal_frequency =
            ChiSquaredAnalyzer::new().with_binning(BinningStrategy::EqualFrequency { num_bins: 5 });
        let result = equal_frequency.analyze_continuous(&values).unwrap();
        assert_eq!(result.num_bins, 5);

        let auto = ChiSquaredAnalyzer::new().with_binning(BinningStrategy::Auto);
        let result = auto.analyze_continuous(&values).unwrap();
        assert!(result.num_bins > 0);
    }

    #[test]
    fn test_insufficient_data() {
        let values = vec![1.0, 2.0, 3.0];
        let analyzer = ChiSquaredAnalyzer::new();
        let result = analyzer.analyze_continuous(&values);

        assert!(matches!(
            result,
            Err(EvalError::InsufficientData {
                required: 10,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_cramers_v() {
        let observed = vec![500, 0, 0, 0, 0];
        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(
            result.cramers_v > 0.5,
            "Strong deviation should have high V"
        );
    }

    #[test]
    fn test_bin_frequencies() {
        let observed = vec![50, 100, 50];

        let analyzer = ChiSquaredAnalyzer::new().with_expected(ExpectedDistribution::Uniform);
        let result = analyzer.analyze_categorical(&observed).unwrap();

        assert_eq!(result.bin_frequencies.len(), 3);

        let first_bin = &result.bin_frequencies[0];
        assert_eq!(first_bin.observed, 50);
        assert!((first_bin.expected - 66.666).abs() < 0.01);
    }

    #[test]
    fn test_critical_value_ordering() {
        // A stricter significance level must demand a larger statistic.
        let cv_10 = chi_squared_critical(10, 0.10);
        let cv_05 = chi_squared_critical(10, 0.05);
        let cv_01 = chi_squared_critical(10, 0.01);

        assert!(cv_10 < cv_05);
        assert!(cv_05 < cv_01);
    }

    #[test]
    fn test_p_value_range() {
        let p1 = chi_squared_p_value(0.0, 5);
        let p2 = chi_squared_p_value(5.0, 5);
        let p3 = chi_squared_p_value(50.0, 5);

        assert!((0.0..=1.0).contains(&p1));
        assert!((0.0..=1.0).contains(&p2));
        assert!((0.0..=1.0).contains(&p3));

        // The p-value shrinks as the statistic grows.
        assert!(p1 > p2);
        assert!(p2 > p3);
    }
}