1use crate::error::{EvalError, EvalResult};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub enum BinningStrategy {
12 EqualWidth { num_bins: usize },
14 EqualFrequency { num_bins: usize },
16 Custom { edges: Vec<f64> },
18 #[default]
20 Auto,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct BinFrequency {
26 pub index: usize,
28 pub lower: f64,
30 pub upper: f64,
32 pub observed: usize,
34 pub expected: f64,
36 pub contribution: f64,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ChiSquaredAnalysis {
43 pub sample_size: usize,
45 pub num_bins: usize,
47 pub degrees_of_freedom: usize,
49 pub statistic: f64,
51 pub p_value: f64,
53 pub significance_level: f64,
55 pub passes: bool,
57 pub critical_value: f64,
59 pub bin_frequencies: Vec<BinFrequency>,
61 pub cramers_v: f64,
63 pub issues: Vec<String>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize, Default)]
69pub enum ExpectedDistribution {
70 #[default]
72 Uniform,
73 Custom(Vec<f64>),
75 Proportions(Vec<f64>),
77 Observed(Vec<usize>),
79}
80
81pub struct ChiSquaredAnalyzer {
83 binning: BinningStrategy,
85 expected: ExpectedDistribution,
87 significance_level: f64,
89 min_expected: f64,
91}
92
93impl ChiSquaredAnalyzer {
94 pub fn new() -> Self {
96 Self {
97 binning: BinningStrategy::Auto,
98 expected: ExpectedDistribution::Uniform,
99 significance_level: 0.05,
100 min_expected: 5.0,
101 }
102 }
103
104 pub fn with_binning(mut self, strategy: BinningStrategy) -> Self {
106 self.binning = strategy;
107 self
108 }
109
110 pub fn with_expected(mut self, expected: ExpectedDistribution) -> Self {
112 self.expected = expected;
113 self
114 }
115
116 pub fn with_significance_level(mut self, level: f64) -> Self {
118 self.significance_level = level;
119 self
120 }
121
122 pub fn with_min_expected(mut self, min: f64) -> Self {
124 self.min_expected = min;
125 self
126 }
127
128 pub fn analyze_continuous(&self, values: &[f64]) -> EvalResult<ChiSquaredAnalysis> {
130 let n = values.len();
131 if n < 10 {
132 return Err(EvalError::InsufficientData {
133 required: 10,
134 actual: n,
135 });
136 }
137
138 let valid_values: Vec<f64> = values.iter().filter(|&&v| v.is_finite()).copied().collect();
140
141 if valid_values.len() < 10 {
142 return Err(EvalError::InsufficientData {
143 required: 10,
144 actual: valid_values.len(),
145 });
146 }
147
148 let (edges, observed) = self.bin_data(&valid_values)?;
150 let n_f = valid_values.len() as f64;
151
152 let expected = self.calculate_expected(&observed, n_f)?;
154
155 self.perform_test(&edges, &observed, &expected)
156 }
157
158 pub fn analyze_categorical(&self, observed: &[usize]) -> EvalResult<ChiSquaredAnalysis> {
160 if observed.is_empty() {
161 return Err(EvalError::InvalidParameter(
162 "Observed counts cannot be empty".to_string(),
163 ));
164 }
165
166 let total: usize = observed.iter().sum();
167 if total < 10 {
168 return Err(EvalError::InsufficientData {
169 required: 10,
170 actual: total,
171 });
172 }
173
174 let n_f = total as f64;
175
176 let edges: Vec<f64> = (0..=observed.len()).map(|i| i as f64).collect();
178
179 let expected = self.calculate_expected(observed, n_f)?;
181
182 self.perform_test(&edges, observed, &expected)
183 }
184
185 fn bin_data(&self, values: &[f64]) -> EvalResult<(Vec<f64>, Vec<usize>)> {
187 let mut sorted: Vec<f64> = values.to_vec();
188 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
189
190 let min = sorted[0];
191 let max = sorted[sorted.len() - 1];
192
193 let edges = match &self.binning {
194 BinningStrategy::EqualWidth { num_bins } => {
195 let width = (max - min) / (*num_bins as f64);
196 (0..=*num_bins).map(|i| min + (i as f64) * width).collect()
197 }
198 BinningStrategy::EqualFrequency { num_bins } => {
199 let n = sorted.len();
200 let mut edges = vec![min];
201 for i in 1..*num_bins {
202 let idx = (i * n) / *num_bins;
203 edges.push(sorted[idx.min(n - 1)]);
204 }
205 edges.push(max);
206 edges
207 }
208 BinningStrategy::Custom { edges } => edges.clone(),
209 BinningStrategy::Auto => {
210 let num_bins = (1.0 + (values.len() as f64).log2()).ceil() as usize;
212 let width = (max - min) / (num_bins as f64);
213 (0..=num_bins).map(|i| min + (i as f64) * width).collect()
214 }
215 };
216
217 if edges.len() < 2 {
218 return Err(EvalError::InvalidParameter(
219 "Need at least 2 bin edges".to_string(),
220 ));
221 }
222
223 let num_bins = edges.len() - 1;
225 let mut counts = vec![0usize; num_bins];
226
227 for &v in values {
228 for (i, window) in edges.windows(2).enumerate() {
229 let (lower, upper) = (window[0], window[1]);
230 if v >= lower && (v < upper || (i == num_bins - 1 && v <= upper)) {
231 counts[i] += 1;
232 break;
233 }
234 }
235 }
236
237 Ok((edges, counts))
238 }
239
240 fn calculate_expected(&self, observed: &[usize], total: f64) -> EvalResult<Vec<f64>> {
242 match &self.expected {
243 ExpectedDistribution::Uniform => {
244 let expected_per_bin = total / (observed.len() as f64);
245 Ok(vec![expected_per_bin; observed.len()])
246 }
247 ExpectedDistribution::Custom(expected) => {
248 if expected.len() != observed.len() {
249 return Err(EvalError::InvalidParameter(format!(
250 "Expected {} frequencies, got {}",
251 observed.len(),
252 expected.len()
253 )));
254 }
255 Ok(expected.clone())
256 }
257 ExpectedDistribution::Proportions(props) => {
258 if props.len() != observed.len() {
259 return Err(EvalError::InvalidParameter(format!(
260 "Expected {} proportions, got {}",
261 observed.len(),
262 props.len()
263 )));
264 }
265 let sum: f64 = props.iter().sum();
266 if (sum - 1.0).abs() > 0.01 {
267 return Err(EvalError::InvalidParameter(format!(
268 "Proportions must sum to 1.0, got {}",
269 sum
270 )));
271 }
272 Ok(props.iter().map(|&p| p * total).collect())
273 }
274 ExpectedDistribution::Observed(other) => {
275 if other.len() != observed.len() {
276 return Err(EvalError::InvalidParameter(format!(
277 "Expected {} categories, got {}",
278 observed.len(),
279 other.len()
280 )));
281 }
282 let other_total: f64 = other.iter().sum::<usize>() as f64;
283 Ok(other
284 .iter()
285 .map(|&c| (c as f64) / other_total * total)
286 .collect())
287 }
288 }
289 }
290
291 fn perform_test(
293 &self,
294 edges: &[f64],
295 observed: &[usize],
296 expected: &[f64],
297 ) -> EvalResult<ChiSquaredAnalysis> {
298 let n = observed.len();
299 let total: usize = observed.iter().sum();
300 let n_f = total as f64;
301
302 let mut issues = Vec::new();
303
304 let low_expected: Vec<_> = expected
306 .iter()
307 .enumerate()
308 .filter(|(_, &e)| e < self.min_expected)
309 .collect();
310 if !low_expected.is_empty() {
311 issues.push(format!(
312 "{} bins have expected frequency < {:.1}; results may be unreliable",
313 low_expected.len(),
314 self.min_expected
315 ));
316 }
317
318 let mut chi_squared = 0.0;
320 let mut bin_frequencies = Vec::new();
321
322 for (i, ((&obs, &exp), window)) in observed
323 .iter()
324 .zip(expected.iter())
325 .zip(edges.windows(2))
326 .enumerate()
327 {
328 let contribution = if exp > 0.0 {
329 let diff = obs as f64 - exp;
330 (diff * diff) / exp
331 } else {
332 0.0
333 };
334 chi_squared += contribution;
335
336 bin_frequencies.push(BinFrequency {
337 index: i,
338 lower: window[0],
339 upper: window[1],
340 observed: obs,
341 expected: exp,
342 contribution,
343 });
344 }
345
346 let df = n.saturating_sub(1);
350 if df == 0 {
351 return Err(EvalError::InvalidParameter(
352 "Need at least 2 bins for chi-squared test".to_string(),
353 ));
354 }
355
356 let p_value = chi_squared_p_value(chi_squared, df);
358
359 let critical_value = chi_squared_critical(df, self.significance_level);
361
362 let cramers_v = (chi_squared / n_f).sqrt();
364
365 let passes = chi_squared <= critical_value;
366
367 if !passes {
368 issues.push(format!(
369 "χ² = {:.4} exceeds critical value {:.4} at α = {:.2}",
370 chi_squared, critical_value, self.significance_level
371 ));
372 }
373
374 Ok(ChiSquaredAnalysis {
375 sample_size: total,
376 num_bins: n,
377 degrees_of_freedom: df,
378 statistic: chi_squared,
379 p_value,
380 significance_level: self.significance_level,
381 passes,
382 critical_value,
383 bin_frequencies,
384 cramers_v,
385 issues,
386 })
387 }
388}
389
390impl Default for ChiSquaredAnalyzer {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
396fn chi_squared_p_value(chi_sq: f64, df: usize) -> f64 {
398 1.0 - lower_incomplete_gamma(df as f64 / 2.0, chi_sq / 2.0)
401}
402
403fn chi_squared_critical(df: usize, alpha: f64) -> f64 {
405 if df == 0 {
410 return 0.0;
411 }
412
413 let df_f = df as f64;
414
415 let z = normal_quantile(1.0 - alpha);
417
418 let term = 2.0 / (9.0 * df_f);
420 let inner = 1.0 - term + z * term.sqrt();
421
422 df_f * inner.powi(3).max(0.0)
423}
424
425fn lower_incomplete_gamma(a: f64, x: f64) -> f64 {
427 if x <= 0.0 {
428 return 0.0;
429 }
430 if x >= a + 1.0 {
431 1.0 - upper_incomplete_gamma_cf(a, x)
433 } else {
434 lower_incomplete_gamma_series(a, x)
436 }
437}
438
439fn lower_incomplete_gamma_series(a: f64, x: f64) -> f64 {
441 let ln_gamma_a = ln_gamma(a);
442 let mut sum = 1.0 / a;
443 let mut term = 1.0 / a;
444
445 for n in 1..200 {
446 term *= x / (a + n as f64);
447 sum += term;
448 if term.abs() < 1e-10 * sum.abs() {
449 break;
450 }
451 }
452
453 sum * x.powf(a) * (-x).exp() / ln_gamma_a.exp()
454}
455
456fn upper_incomplete_gamma_cf(a: f64, x: f64) -> f64 {
458 let ln_gamma_a = ln_gamma(a);
459
460 let mut f = 1e-30_f64;
462 let mut c = 1e-30_f64;
463 let mut d = 0.0_f64;
464
465 for i in 1..200 {
466 let i_f = i as f64;
467 let an = if i == 1 {
468 1.0
469 } else if i % 2 == 0 {
470 (i_f / 2.0 - 1.0) - a + 1.0
471 } else {
472 (i_f - 1.0) / 2.0
473 };
474 let bn = if i == 1 { x - a + 1.0 } else { x - a + i_f };
475
476 d = bn + an * d;
477 if d.abs() < 1e-30 {
478 d = 1e-30;
479 }
480 c = bn + an / c;
481 if c.abs() < 1e-30 {
482 c = 1e-30;
483 }
484 d = 1.0 / d;
485 let delta = c * d;
486 f *= delta;
487
488 if (delta - 1.0).abs() < 1e-10 {
489 break;
490 }
491 }
492
493 f * x.powf(a) * (-x).exp() / ln_gamma_a.exp()
494}
495
/// Natural logarithm of the gamma function for positive `x`, computed from a
/// six-coefficient series approximation.
///
/// Returns `f64::INFINITY` for `x <= 0`.
fn ln_gamma(x: f64) -> f64 {
    const COEFFICIENTS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];
    const SERIES_START: f64 = 1.000000000190015;
    const SCALE: f64 = 2.5066282746310005;

    if x <= 0.0 {
        return f64::INFINITY;
    }

    let shifted = x + 5.5;
    let offset = shifted - (x + 0.5) * shifted.ln();

    // Coefficient k (1-based) is divided by x + k, accumulated in order.
    let series = COEFFICIENTS
        .iter()
        .zip(1u32..)
        .fold(SERIES_START, |acc, (&coeff, k)| {
            acc + coeff / (x + f64::from(k))
        });

    (SCALE * series / x).ln() - offset
}
521
/// Approximate quantile (inverse CDF) of the standard normal distribution,
/// computed from a rational function of `t = sqrt(-2 ln(tail))`, where `tail`
/// is the smaller of `p` and `1 - p`.
///
/// Returns negative infinity for `p <= 0`, positive infinity for `p >= 1` and
/// exactly `0.0` for `p == 0.5`.
fn normal_quantile(p: f64) -> f64 {
    const NUMERATOR: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const DENOMINATOR: [f64; 3] = [1.432788, 0.189269, 0.001308];

    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if p == 0.5 {
        return 0.0;
    }

    // The approximation is defined on the tail probability; the sign is
    // restored afterwards for the lower half.
    let lower_half = p < 0.5;
    let tail = if lower_half { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let top = NUMERATOR[0] + NUMERATOR[1] * t + NUMERATOR[2] * t * t;
    let bottom =
        1.0 + DENOMINATOR[0] * t + DENOMINATOR[1] * t * t + DENOMINATOR[2] * t * t * t;
    let magnitude = t - top / bottom;

    if lower_half {
        -magnitude
    } else {
        magnitude
    }
}
556
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::{Distribution, Uniform};

    /// Draws `count` values from U(0, 100) using a fixed seed so that every
    /// run sees the same sample.
    fn seeded_uniform_sample(count: usize) -> Vec<f64> {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let dist = Uniform::new(0.0, 100.0);
        (0..count).map(|_| dist.sample(&mut rng)).collect()
    }

    /// Analyzer comparing against a uniform expectation at alpha = 0.05.
    fn uniform_analyzer() -> ChiSquaredAnalyzer {
        ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05)
    }

    #[test]
    fn test_uniform_distribution() {
        let sample = seeded_uniform_sample(1000);

        let analysis = uniform_analyzer()
            .with_binning(BinningStrategy::EqualWidth { num_bins: 10 })
            .analyze_continuous(&sample)
            .unwrap();

        assert!(
            analysis.passes,
            "Uniform data should pass uniform chi-squared test"
        );
        assert!(analysis.p_value > 0.05);
    }

    #[test]
    fn test_categorical_uniform() {
        let counts = [100usize, 98, 102, 100, 100];

        let analysis = uniform_analyzer().analyze_categorical(&counts).unwrap();

        assert!(analysis.passes, "Nearly uniform counts should pass");
    }

    #[test]
    fn test_categorical_deviation() {
        let counts = [400usize, 50, 25, 15, 10];

        let analysis = uniform_analyzer().analyze_categorical(&counts).unwrap();

        assert!(
            !analysis.passes,
            "Highly skewed counts should fail uniform test"
        );
    }

    #[test]
    fn test_custom_proportions() {
        let counts = [300usize, 200, 100];
        let proportions = vec![0.50, 0.33, 0.17];

        let analysis = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Proportions(proportions))
            .with_significance_level(0.05)
            .analyze_categorical(&counts)
            .unwrap();

        assert_eq!(analysis.sample_size, 600);
    }

    #[test]
    fn test_binning_strategies() {
        let sample = seeded_uniform_sample(500);

        let equal_width = ChiSquaredAnalyzer::new()
            .with_binning(BinningStrategy::EqualWidth { num_bins: 10 })
            .analyze_continuous(&sample)
            .unwrap();
        assert_eq!(equal_width.num_bins, 10);

        let equal_frequency = ChiSquaredAnalyzer::new()
            .with_binning(BinningStrategy::EqualFrequency { num_bins: 5 })
            .analyze_continuous(&sample)
            .unwrap();
        assert_eq!(equal_frequency.num_bins, 5);

        let automatic = ChiSquaredAnalyzer::new()
            .with_binning(BinningStrategy::Auto)
            .analyze_continuous(&sample)
            .unwrap();
        assert!(automatic.num_bins > 0);
    }

    #[test]
    fn test_insufficient_data() {
        let too_few = [1.0, 2.0, 3.0];

        let outcome = ChiSquaredAnalyzer::new().analyze_continuous(&too_few);

        assert!(matches!(
            outcome,
            Err(EvalError::InsufficientData {
                required: 10,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_cramers_v() {
        let counts = [500usize, 0, 0, 0, 0];

        let analysis = uniform_analyzer().analyze_categorical(&counts).unwrap();

        assert!(
            analysis.cramers_v > 0.5,
            "Strong deviation should have high V"
        );
    }

    #[test]
    fn test_bin_frequencies() {
        let counts = [50usize, 100, 50];

        let analysis = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .analyze_categorical(&counts)
            .unwrap();

        assert_eq!(analysis.bin_frequencies.len(), 3);

        let head = &analysis.bin_frequencies[0];
        assert_eq!(head.observed, 50);
        assert!((head.expected - 66.666).abs() < 0.01);
    }

    #[test]
    fn test_critical_value_ordering() {
        // A stricter significance level must demand a larger critical value.
        let at_10_percent = chi_squared_critical(10, 0.10);
        let at_5_percent = chi_squared_critical(10, 0.05);
        let at_1_percent = chi_squared_critical(10, 0.01);

        assert!(at_10_percent < at_5_percent);
        assert!(at_5_percent < at_1_percent);
    }

    #[test]
    fn test_p_value_range() {
        let p_values = [
            chi_squared_p_value(0.0, 5),
            chi_squared_p_value(5.0, 5),
            chi_squared_p_value(50.0, 5),
        ];

        for p in &p_values {
            assert!((0.0..=1.0).contains(p));
        }

        // The p-value must shrink as the statistic grows.
        assert!(p_values[0] > p_values[1]);
        assert!(p_values[1] > p_values[2]);
    }
}