1use crate::error::{EvalError, EvalResult};
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub enum BinningStrategy {
12 EqualWidth { num_bins: usize },
14 EqualFrequency { num_bins: usize },
16 Custom { edges: Vec<f64> },
18 #[default]
20 Auto,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct BinFrequency {
26 pub index: usize,
28 pub lower: f64,
30 pub upper: f64,
32 pub observed: usize,
34 pub expected: f64,
36 pub contribution: f64,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct ChiSquaredAnalysis {
43 pub sample_size: usize,
45 pub num_bins: usize,
47 pub degrees_of_freedom: usize,
49 pub statistic: f64,
51 pub p_value: f64,
53 pub significance_level: f64,
55 pub passes: bool,
57 pub critical_value: f64,
59 pub bin_frequencies: Vec<BinFrequency>,
61 pub cramers_v: f64,
63 pub issues: Vec<String>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize, Default)]
69pub enum ExpectedDistribution {
70 #[default]
72 Uniform,
73 Custom(Vec<f64>),
75 Proportions(Vec<f64>),
77 Observed(Vec<usize>),
79}
80
81pub struct ChiSquaredAnalyzer {
83 binning: BinningStrategy,
85 expected: ExpectedDistribution,
87 significance_level: f64,
89 min_expected: f64,
91}
92
93impl ChiSquaredAnalyzer {
94 pub fn new() -> Self {
96 Self {
97 binning: BinningStrategy::Auto,
98 expected: ExpectedDistribution::Uniform,
99 significance_level: 0.05,
100 min_expected: 5.0,
101 }
102 }
103
104 pub fn with_binning(mut self, strategy: BinningStrategy) -> Self {
106 self.binning = strategy;
107 self
108 }
109
110 pub fn with_expected(mut self, expected: ExpectedDistribution) -> Self {
112 self.expected = expected;
113 self
114 }
115
116 pub fn with_significance_level(mut self, level: f64) -> Self {
118 self.significance_level = level;
119 self
120 }
121
122 pub fn with_min_expected(mut self, min: f64) -> Self {
124 self.min_expected = min;
125 self
126 }
127
128 pub fn analyze_continuous(&self, values: &[f64]) -> EvalResult<ChiSquaredAnalysis> {
130 let n = values.len();
131 if n < 10 {
132 return Err(EvalError::InsufficientData {
133 required: 10,
134 actual: n,
135 });
136 }
137
138 let valid_values: Vec<f64> = values.iter().filter(|&&v| v.is_finite()).copied().collect();
140
141 if valid_values.len() < 10 {
142 return Err(EvalError::InsufficientData {
143 required: 10,
144 actual: valid_values.len(),
145 });
146 }
147
148 let (edges, observed) = self.bin_data(&valid_values)?;
150 let n_f = valid_values.len() as f64;
151
152 let expected = self.calculate_expected(&observed, n_f)?;
154
155 self.perform_test(&edges, &observed, &expected)
156 }
157
158 pub fn analyze_categorical(&self, observed: &[usize]) -> EvalResult<ChiSquaredAnalysis> {
160 if observed.is_empty() {
161 return Err(EvalError::InvalidParameter(
162 "Observed counts cannot be empty".to_string(),
163 ));
164 }
165
166 let total: usize = observed.iter().sum();
167 if total < 10 {
168 return Err(EvalError::InsufficientData {
169 required: 10,
170 actual: total,
171 });
172 }
173
174 let n_f = total as f64;
175
176 let edges: Vec<f64> = (0..=observed.len()).map(|i| i as f64).collect();
178
179 let expected = self.calculate_expected(observed, n_f)?;
181
182 self.perform_test(&edges, observed, &expected)
183 }
184
185 fn bin_data(&self, values: &[f64]) -> EvalResult<(Vec<f64>, Vec<usize>)> {
187 let mut sorted: Vec<f64> = values.to_vec();
188 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
189
190 let min = sorted[0];
191 let max = sorted[sorted.len() - 1];
192
193 let edges = match &self.binning {
194 BinningStrategy::EqualWidth { num_bins } => {
195 let width = (max - min) / (*num_bins as f64);
196 (0..=*num_bins).map(|i| min + (i as f64) * width).collect()
197 }
198 BinningStrategy::EqualFrequency { num_bins } => {
199 let n = sorted.len();
200 let mut edges = vec![min];
201 for i in 1..*num_bins {
202 let idx = (i * n) / *num_bins;
203 edges.push(sorted[idx.min(n - 1)]);
204 }
205 edges.push(max);
206 edges
207 }
208 BinningStrategy::Custom { edges } => edges.clone(),
209 BinningStrategy::Auto => {
210 let num_bins = (1.0 + (values.len() as f64).log2()).ceil() as usize;
212 let width = (max - min) / (num_bins as f64);
213 (0..=num_bins).map(|i| min + (i as f64) * width).collect()
214 }
215 };
216
217 if edges.len() < 2 {
218 return Err(EvalError::InvalidParameter(
219 "Need at least 2 bin edges".to_string(),
220 ));
221 }
222
223 let num_bins = edges.len() - 1;
225 let mut counts = vec![0usize; num_bins];
226
227 for &v in values {
228 for (i, window) in edges.windows(2).enumerate() {
229 let (lower, upper) = (window[0], window[1]);
230 if v >= lower && (v < upper || (i == num_bins - 1 && v <= upper)) {
231 counts[i] += 1;
232 break;
233 }
234 }
235 }
236
237 Ok((edges, counts))
238 }
239
240 fn calculate_expected(&self, observed: &[usize], total: f64) -> EvalResult<Vec<f64>> {
242 match &self.expected {
243 ExpectedDistribution::Uniform => {
244 let expected_per_bin = total / (observed.len() as f64);
245 Ok(vec![expected_per_bin; observed.len()])
246 }
247 ExpectedDistribution::Custom(expected) => {
248 if expected.len() != observed.len() {
249 return Err(EvalError::InvalidParameter(format!(
250 "Expected {} frequencies, got {}",
251 observed.len(),
252 expected.len()
253 )));
254 }
255 Ok(expected.clone())
256 }
257 ExpectedDistribution::Proportions(props) => {
258 if props.len() != observed.len() {
259 return Err(EvalError::InvalidParameter(format!(
260 "Expected {} proportions, got {}",
261 observed.len(),
262 props.len()
263 )));
264 }
265 let sum: f64 = props.iter().sum();
266 if (sum - 1.0).abs() > 0.01 {
267 return Err(EvalError::InvalidParameter(format!(
268 "Proportions must sum to 1.0, got {}",
269 sum
270 )));
271 }
272 Ok(props.iter().map(|&p| p * total).collect())
273 }
274 ExpectedDistribution::Observed(other) => {
275 if other.len() != observed.len() {
276 return Err(EvalError::InvalidParameter(format!(
277 "Expected {} categories, got {}",
278 observed.len(),
279 other.len()
280 )));
281 }
282 let other_total: f64 = other.iter().sum::<usize>() as f64;
283 Ok(other
284 .iter()
285 .map(|&c| (c as f64) / other_total * total)
286 .collect())
287 }
288 }
289 }
290
291 fn perform_test(
293 &self,
294 edges: &[f64],
295 observed: &[usize],
296 expected: &[f64],
297 ) -> EvalResult<ChiSquaredAnalysis> {
298 let n = observed.len();
299 let total: usize = observed.iter().sum();
300 let n_f = total as f64;
301
302 let mut issues = Vec::new();
303
304 let low_expected: Vec<_> = expected
306 .iter()
307 .enumerate()
308 .filter(|(_, &e)| e < self.min_expected)
309 .collect();
310 if !low_expected.is_empty() {
311 issues.push(format!(
312 "{} bins have expected frequency < {:.1}; results may be unreliable",
313 low_expected.len(),
314 self.min_expected
315 ));
316 }
317
318 let mut chi_squared = 0.0;
320 let mut bin_frequencies = Vec::new();
321
322 for (i, ((&obs, &exp), window)) in observed
323 .iter()
324 .zip(expected.iter())
325 .zip(edges.windows(2))
326 .enumerate()
327 {
328 let contribution = if exp > 0.0 {
329 let diff = obs as f64 - exp;
330 (diff * diff) / exp
331 } else {
332 0.0
333 };
334 chi_squared += contribution;
335
336 bin_frequencies.push(BinFrequency {
337 index: i,
338 lower: window[0],
339 upper: window[1],
340 observed: obs,
341 expected: exp,
342 contribution,
343 });
344 }
345
346 let df = n.saturating_sub(1);
350 if df == 0 {
351 return Err(EvalError::InvalidParameter(
352 "Need at least 2 bins for chi-squared test".to_string(),
353 ));
354 }
355
356 let p_value = chi_squared_p_value(chi_squared, df);
358
359 let critical_value = chi_squared_critical(df, self.significance_level);
361
362 let cramers_v = (chi_squared / n_f).sqrt();
364
365 let passes = chi_squared <= critical_value;
366
367 if !passes {
368 issues.push(format!(
369 "χ² = {:.4} exceeds critical value {:.4} at α = {:.2}",
370 chi_squared, critical_value, self.significance_level
371 ));
372 }
373
374 Ok(ChiSquaredAnalysis {
375 sample_size: total,
376 num_bins: n,
377 degrees_of_freedom: df,
378 statistic: chi_squared,
379 p_value,
380 significance_level: self.significance_level,
381 passes,
382 critical_value,
383 bin_frequencies,
384 cramers_v,
385 issues,
386 })
387 }
388}
389
390impl Default for ChiSquaredAnalyzer {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
396fn chi_squared_p_value(chi_sq: f64, df: usize) -> f64 {
398 1.0 - lower_incomplete_gamma(df as f64 / 2.0, chi_sq / 2.0)
401}
402
403fn chi_squared_critical(df: usize, alpha: f64) -> f64 {
405 if df == 0 {
410 return 0.0;
411 }
412
413 let df_f = df as f64;
414
415 let z = normal_quantile(1.0 - alpha);
417
418 let term = 2.0 / (9.0 * df_f);
420 let inner = 1.0 - term + z * term.sqrt();
421
422 df_f * inner.powi(3).max(0.0)
423}
424
425fn lower_incomplete_gamma(a: f64, x: f64) -> f64 {
427 if x <= 0.0 {
428 return 0.0;
429 }
430 if x >= a + 1.0 {
431 1.0 - upper_incomplete_gamma_cf(a, x)
433 } else {
434 lower_incomplete_gamma_series(a, x)
436 }
437}
438
439fn lower_incomplete_gamma_series(a: f64, x: f64) -> f64 {
441 let ln_gamma_a = ln_gamma(a);
442 let mut sum = 1.0 / a;
443 let mut term = 1.0 / a;
444
445 for n in 1..200 {
446 term *= x / (a + n as f64);
447 sum += term;
448 if term.abs() < 1e-10 * sum.abs() {
449 break;
450 }
451 }
452
453 sum * x.powf(a) * (-x).exp() / ln_gamma_a.exp()
454}
455
456fn upper_incomplete_gamma_cf(a: f64, x: f64) -> f64 {
458 let ln_gamma_a = ln_gamma(a);
459
460 let mut f = 1e-30_f64;
462 let mut c = 1e-30_f64;
463 let mut d = 0.0_f64;
464
465 for i in 1..200 {
466 let i_f = i as f64;
467 let an = if i == 1 {
468 1.0
469 } else if i % 2 == 0 {
470 (i_f / 2.0 - 1.0) - a + 1.0
471 } else {
472 (i_f - 1.0) / 2.0
473 };
474 let bn = if i == 1 { x - a + 1.0 } else { x - a + i_f };
475
476 d = bn + an * d;
477 if d.abs() < 1e-30 {
478 d = 1e-30;
479 }
480 c = bn + an / c;
481 if c.abs() < 1e-30 {
482 c = 1e-30;
483 }
484 d = 1.0 / d;
485 let delta = c * d;
486 f *= delta;
487
488 if (delta - 1.0).abs() < 1e-10 {
489 break;
490 }
491 }
492
493 f * x.powf(a) * (-x).exp() / ln_gamma_a.exp()
494}
495
/// Natural logarithm of the gamma function, computed with a six-term
/// Lanczos-style series.
///
/// Returns positive infinity for `x <= 0`.
fn ln_gamma(x: f64) -> f64 {
    const SERIES_COEFFS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];
    const SERIES_START: f64 = 1.000000000190015;
    const SQRT_TWO_PI: f64 = 2.5066282746310005;

    if x <= 0.0 {
        return f64::INFINITY;
    }

    let shifted = x + 5.5;
    let tmp = shifted - (x + 0.5) * shifted.ln();

    // Accumulate in coefficient order so rounding matches term by term.
    let ser = SERIES_COEFFS
        .iter()
        .enumerate()
        .fold(SERIES_START, |acc, (i, &coeff)| {
            acc + coeff / (x + (i + 1) as f64)
        });

    -tmp + (SQRT_TWO_PI * ser / x).ln()
}
521
/// Approximate quantile (inverse CDF) of the standard normal distribution,
/// using a rational approximation in `t = sqrt(-2 ln(tail))`.
///
/// Returns negative infinity for `p <= 0`, positive infinity for `p >= 1`
/// and exactly `0.0` for `p == 0.5`.
fn normal_quantile(p: f64) -> f64 {
    const NUM: [f64; 3] = [2.515517, 0.802853, 0.010328];
    const DEN: [f64; 3] = [1.432788, 0.189269, 0.001308];

    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if p == 0.5 {
        return 0.0;
    }

    // Work with the smaller tail and restore the sign at the end.
    let in_lower_half = p < 0.5;
    let tail = if in_lower_half { p } else { 1.0 - p };
    let t = (-2.0 * tail.ln()).sqrt();

    let numerator = NUM[0] + NUM[1] * t + NUM[2] * t * t;
    let denominator = 1.0 + DEN[0] * t + DEN[1] * t * t + DEN[2] * t * t * t;
    let magnitude = t - numerator / denominator;

    if in_lower_half {
        -magnitude
    } else {
        magnitude
    }
}
556
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::{Distribution, Uniform};

    /// Draws `count` values from `Uniform::new(0.0, 100.0)` with a ChaCha8
    /// generator seeded with 42, so every run sees the same sample.
    fn seeded_uniform_sample(count: usize) -> Vec<f64> {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let uniform = Uniform::new(0.0, 100.0);
        (0..count).map(|_| uniform.sample(&mut rng)).collect()
    }

    /// Analyzer comparing counts with a uniform expectation at alpha = 0.05.
    fn uniform_analyzer() -> ChiSquaredAnalyzer {
        ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Uniform)
            .with_significance_level(0.05)
    }

    #[test]
    fn test_uniform_distribution() {
        let values = seeded_uniform_sample(1000);

        let analyzer = uniform_analyzer().with_binning(BinningStrategy::EqualWidth { num_bins: 10 });

        let result = analyzer.analyze_continuous(&values).unwrap();
        assert!(
            result.passes,
            "Uniform data should pass uniform chi-squared test"
        );
        assert!(result.p_value > 0.05);
    }

    #[test]
    fn test_categorical_uniform() {
        let observed = [100usize, 98, 102, 100, 100];

        let result = uniform_analyzer().analyze_categorical(&observed).unwrap();
        assert!(result.passes, "Nearly uniform counts should pass");
    }

    #[test]
    fn test_categorical_deviation() {
        let observed = [400usize, 50, 25, 15, 10];

        let result = uniform_analyzer().analyze_categorical(&observed).unwrap();
        assert!(
            !result.passes,
            "Highly skewed counts should fail uniform test"
        );
    }

    #[test]
    fn test_custom_proportions() {
        let observed = [300usize, 200, 100];
        let expected_props = vec![0.50, 0.33, 0.17];

        let analyzer = ChiSquaredAnalyzer::new()
            .with_expected(ExpectedDistribution::Proportions(expected_props))
            .with_significance_level(0.05);

        let result = analyzer.analyze_categorical(&observed).unwrap();
        assert!(result.sample_size == 600);
    }

    #[test]
    fn test_binning_strategies() {
        let values = seeded_uniform_sample(500);

        // Equal-width binning reports exactly the requested bin count.
        let equal_width = ChiSquaredAnalyzer::new()
            .with_binning(BinningStrategy::EqualWidth { num_bins: 10 })
            .analyze_continuous(&values)
            .unwrap();
        assert_eq!(equal_width.num_bins, 10);

        // So does equal-frequency binning.
        let equal_frequency = ChiSquaredAnalyzer::new()
            .with_binning(BinningStrategy::EqualFrequency { num_bins: 5 })
            .analyze_continuous(&values)
            .unwrap();
        assert_eq!(equal_frequency.num_bins, 5);

        // Automatic binning only has to produce at least one bin.
        let automatic = ChiSquaredAnalyzer::new()
            .with_binning(BinningStrategy::Auto)
            .analyze_continuous(&values)
            .unwrap();
        assert!(automatic.num_bins > 0);
    }

    #[test]
    fn test_insufficient_data() {
        let values = [1.0, 2.0, 3.0];

        let result = ChiSquaredAnalyzer::new().analyze_continuous(&values);

        assert!(matches!(
            result,
            Err(EvalError::InsufficientData {
                required: 10,
                actual: 3
            })
        ));
    }

    #[test]
    fn test_cramers_v() {
        let observed = [500usize, 0, 0, 0, 0];

        let result = uniform_analyzer().analyze_categorical(&observed).unwrap();
        assert!(
            result.cramers_v > 0.5,
            "Strong deviation should have high V"
        );
    }

    #[test]
    fn test_bin_frequencies() {
        let observed = [50usize, 100, 50];

        let analyzer = ChiSquaredAnalyzer::new().with_expected(ExpectedDistribution::Uniform);
        let result = analyzer.analyze_categorical(&observed).unwrap();

        assert_eq!(result.bin_frequencies.len(), 3);

        let first_bin = &result.bin_frequencies[0];
        assert_eq!(first_bin.observed, 50);
        assert!((first_bin.expected - 66.666).abs() < 0.01);
    }

    #[test]
    fn test_critical_value_ordering() {
        // A smaller alpha must demand a larger critical value.
        let cv_10 = chi_squared_critical(10, 0.10);
        let cv_05 = chi_squared_critical(10, 0.05);
        let cv_01 = chi_squared_critical(10, 0.01);

        assert!(cv_10 < cv_05);
        assert!(cv_05 < cv_01);
    }

    #[test]
    fn test_p_value_range() {
        let p1 = chi_squared_p_value(0.0, 5);
        let p2 = chi_squared_p_value(5.0, 5);
        let p3 = chi_squared_p_value(50.0, 5);

        for p in [p1, p2, p3] {
            assert!((0.0..=1.0).contains(&p));
        }

        // A larger statistic must give a smaller p-value.
        assert!(p1 > p2);
        assert!(p2 > p3);
    }
}