1use rust_decimal::prelude::ToPrimitive;
18use rust_decimal::Decimal;
19use serde::{Deserialize, Serialize};
20
21use super::benford::get_first_digit;
22
/// Outcome of a single statistical test.
///
/// Serialized with snake_case variant names (`"passed"`, `"warning"`,
/// `"failed"`, `"skipped"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TestOutcome {
    /// The statistic was within its acceptance threshold.
    Passed,
    /// The statistic exceeded the warning level but not the failure threshold
    /// (emitted by the Benford test in this module).
    Warning,
    /// The statistic exceeded its failure threshold / critical value.
    Failed,
    /// The test was not run, e.g. too few usable samples or a degenerate range.
    Skipped,
}
36
/// Result of running one statistical test over a set of amounts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatisticalTestResult {
    /// Test identifier, e.g. `"benford_first_digit"`, `"chi_squared"`, `"ks_uniform_log"`.
    pub name: String,
    /// Pass / warning / fail / skip verdict.
    pub outcome: TestOutcome,
    /// Computed test statistic (MAD, χ², or KS `D`); `0.0` when skipped.
    pub statistic: f64,
    /// Value the statistic was compared against (failure MAD threshold or
    /// critical value); may be `0.0` for skipped tests.
    pub threshold: f64,
    /// Human-readable summary of the computation or the reason for skipping.
    pub message: String,
}
51
/// Collection of statistical test results; `Default` yields an empty report.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StatisticalValidationReport {
    // NOTE(review): presumably the number of input amounts the tests ran on —
    // not populated by code in this module; confirm with the producer.
    pub sample_count: usize,
    /// Individual test results, in the order they were added.
    pub results: Vec<StatisticalTestResult>,
}
60
61impl StatisticalValidationReport {
62 pub fn all_passed(&self) -> bool {
64 self.results
65 .iter()
66 .all(|r| !matches!(r.outcome, TestOutcome::Failed))
67 }
68
69 pub fn has_warnings(&self) -> bool {
71 self.results
72 .iter()
73 .any(|r| matches!(r.outcome, TestOutcome::Warning))
74 }
75
76 pub fn failed_names(&self) -> Vec<String> {
78 self.results
79 .iter()
80 .filter(|r| matches!(r.outcome, TestOutcome::Failed))
81 .map(|r| r.name.clone())
82 .collect()
83 }
84}
85
86pub fn run_benford_first_digit(
93 amounts: &[Decimal],
94 threshold_mad: f64,
95 warning_mad: f64,
96) -> StatisticalTestResult {
97 let mut counts = [0u32; 10]; let mut total = 0u32;
99 for amount in amounts {
100 if let Some(d) = get_first_digit(*amount) {
101 counts[d as usize] += 1;
102 total += 1;
103 }
104 }
105
106 if total < 100 {
107 return StatisticalTestResult {
108 name: "benford_first_digit".to_string(),
109 outcome: TestOutcome::Skipped,
110 statistic: 0.0,
111 threshold: threshold_mad,
112 message: format!("only {total} samples with valid first digit; need ≥100"),
113 };
114 }
115
116 const EXPECTED: [f64; 10] = [
119 0.0,
120 std::f64::consts::LOG10_2, 0.17609125905568124, 0.12493873660829995,
123 0.09691001300805642,
124 0.07918124604762482,
125 0.06694678963061322,
126 0.057991946977686726,
127 0.05115252244738129,
128 0.04575749056067514,
129 ];
130
131 let total_f = total as f64;
132 let mad: f64 = (1..=9)
133 .map(|d| (counts[d] as f64 / total_f - EXPECTED[d]).abs())
134 .sum::<f64>()
135 / 9.0;
136
137 let outcome = if mad > threshold_mad {
138 TestOutcome::Failed
139 } else if mad > warning_mad {
140 TestOutcome::Warning
141 } else {
142 TestOutcome::Passed
143 };
144
145 StatisticalTestResult {
146 name: "benford_first_digit".to_string(),
147 outcome,
148 statistic: mad,
149 threshold: threshold_mad,
150 message: format!(
151 "MAD={mad:.4} over {total} first digits (threshold={threshold_mad:.4}, warn={warning_mad:.4})"
152 ),
153 }
154}
155
156pub fn run_chi_squared(
164 amounts: &[Decimal],
165 bins: usize,
166 significance: f64,
167) -> StatisticalTestResult {
168 if amounts.len() < 100 {
169 return StatisticalTestResult {
170 name: "chi_squared".to_string(),
171 outcome: TestOutcome::Skipped,
172 statistic: 0.0,
173 threshold: 0.0,
174 message: format!("only {} samples; need ≥100", amounts.len()),
175 };
176 }
177
178 let bins = bins.max(2);
179 let positives: Vec<f64> = amounts
180 .iter()
181 .filter_map(|a| a.to_f64())
182 .filter(|v| *v > 0.0)
183 .collect();
184 if positives.len() < 100 {
185 return StatisticalTestResult {
186 name: "chi_squared".to_string(),
187 outcome: TestOutcome::Skipped,
188 statistic: 0.0,
189 threshold: 0.0,
190 message: format!("only {} positive samples; need ≥100", positives.len()),
191 };
192 }
193
194 let logs: Vec<f64> = positives.iter().map(|v| v.ln()).collect();
195 let min = logs.iter().cloned().fold(f64::INFINITY, f64::min);
196 let max = logs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
197 if !min.is_finite() || !max.is_finite() || max <= min {
198 return StatisticalTestResult {
199 name: "chi_squared".to_string(),
200 outcome: TestOutcome::Skipped,
201 statistic: 0.0,
202 threshold: 0.0,
203 message: "degenerate log-range".to_string(),
204 };
205 }
206
207 let bin_width = (max - min) / bins as f64;
208 let mut observed = vec![0u32; bins];
209 for v in &logs {
210 let idx = (((v - min) / bin_width) as usize).min(bins - 1);
211 observed[idx] += 1;
212 }
213
214 let n = logs.len() as f64;
215 let expected_per_bin = n / bins as f64;
216 let chi_sq: f64 = observed
217 .iter()
218 .map(|o| {
219 let diff = *o as f64 - expected_per_bin;
220 diff * diff / expected_per_bin
221 })
222 .sum();
223
224 let df = bins - 1;
229 let critical = chi_sq_critical(df, significance);
230
231 let outcome = if chi_sq > critical {
232 TestOutcome::Failed
233 } else {
234 TestOutcome::Passed
235 };
236
237 StatisticalTestResult {
238 name: "chi_squared".to_string(),
239 outcome,
240 statistic: chi_sq,
241 threshold: critical,
242 message: format!(
243 "χ²={chi_sq:.2} over {bins} log-bins ({n} samples), critical={critical:.2} at α={significance}"
244 ),
245 }
246}
247
248pub fn run_ks_uniform_log(amounts: &[Decimal], significance: f64) -> StatisticalTestResult {
256 let positives: Vec<f64> = amounts
257 .iter()
258 .filter_map(|a| a.to_f64())
259 .filter(|v| *v > 0.0)
260 .collect();
261 if positives.len() < 100 {
262 return StatisticalTestResult {
263 name: "ks_uniform_log".to_string(),
264 outcome: TestOutcome::Skipped,
265 statistic: 0.0,
266 threshold: 0.0,
267 message: format!("only {} positive samples; need ≥100", positives.len()),
268 };
269 }
270
271 let mut logs: Vec<f64> = positives.iter().map(|v| v.ln()).collect();
272 logs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
273 let min = logs[0];
274 let max = logs[logs.len() - 1];
275 if max <= min {
276 return StatisticalTestResult {
277 name: "ks_uniform_log".to_string(),
278 outcome: TestOutcome::Skipped,
279 statistic: 0.0,
280 threshold: 0.0,
281 message: "degenerate log-range".to_string(),
282 };
283 }
284
285 let n = logs.len() as f64;
286 let mut max_diff: f64 = 0.0;
287 for (i, v) in logs.iter().enumerate() {
288 let empirical = (i as f64 + 1.0) / n;
289 let uniform = (v - min) / (max - min);
290 let diff = (empirical - uniform).abs();
291 if diff > max_diff {
292 max_diff = diff;
293 }
294 }
295
296 let c = if significance <= 0.011 {
300 1.628
301 } else if significance <= 0.051 {
302 1.358
303 } else {
304 1.224
305 };
306 let critical = c / n.sqrt();
307
308 let outcome = if max_diff > critical {
309 TestOutcome::Failed
310 } else {
311 TestOutcome::Passed
312 };
313
314 StatisticalTestResult {
315 name: "ks_uniform_log".to_string(),
316 outcome,
317 statistic: max_diff,
318 threshold: critical,
319 message: format!(
320 "D={max_diff:.4} over {n} samples, critical={critical:.4} at α={significance}"
321 ),
322 }
323}
324
/// Upper-tail critical value of the χ² distribution with `df` degrees of freedom.
///
/// `alpha` is bucketed to a supported level: `alpha <= 0.011` → 0.01,
/// `alpha <= 0.051` → 0.05, otherwise 0.10.
///
/// Tabulated `df` values return the standard table entry. Any other `df` uses
/// the Wilson–Hilferty cube-root approximation,
/// `k · (1 − 2/(9k) + z_α·√(2/(9k)))³`, which stays within roughly 0.1% of
/// the exact quantile for the df values reached here. Previously this function
/// snapped to the *nearest* tabulated row; for example, df = 99 used the
/// df = 29 value (42.6 instead of about 123.2), which made every large-bin
/// χ² test fail.
///
/// `df == 0` is treated as `df == 1`, as before.
fn chi_sq_critical(df: usize, alpha: f64) -> f64 {
    // (df, [α = 0.10, α = 0.05, α = 0.01])
    const TABLE: &[(usize, [f64; 3])] = &[
        (1, [2.706, 3.841, 6.635]),
        (2, [4.605, 5.991, 9.210]),
        (3, [6.251, 7.815, 11.345]),
        (4, [7.779, 9.488, 13.277]),
        (5, [9.236, 11.070, 15.086]),
        (6, [10.645, 12.592, 16.812]),
        (7, [12.017, 14.067, 18.475]),
        (8, [13.362, 15.507, 20.090]),
        (9, [14.684, 16.919, 21.666]),
        (10, [15.987, 18.307, 23.209]),
        (14, [21.064, 23.685, 29.141]),
        (19, [27.204, 30.144, 36.191]),
        (24, [33.196, 36.415, 42.980]),
        (29, [39.087, 42.557, 49.588]),
    ];

    // Column into each row, plus the matching upper-tail standard-normal
    // quantile z_α used by the approximation.
    let (column, z) = if alpha <= 0.011 {
        (2, 2.326_347_874)
    } else if alpha <= 0.051 {
        (1, 1.644_853_627)
    } else {
        (0, 1.281_551_566)
    };

    let df = df.max(1);
    if let Some((_, values)) = TABLE.iter().find(|(d, _)| *d == df) {
        return values[column];
    }

    // Wilson–Hilferty approximation for untabulated degrees of freedom.
    let k = df as f64;
    let h = 2.0 / (9.0 * k);
    k * (1.0 - h + z * h.sqrt()).powi(3)
}
363
#[cfg(test)]
// Unwrapping is fine in tests: a panic is exactly the failure signal we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::{Distribution, LogNormal};

    /// `n` deterministic log-normal samples (seeded) converted to `Decimal`.
    fn lognormal_samples(n: usize, mu: f64, sigma: f64, seed: u64) -> Vec<Decimal> {
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let ln = LogNormal::new(mu, sigma).unwrap();
        (0..n)
            .map(|_| Decimal::from_f64_retain(ln.sample(&mut rng)).unwrap_or(Decimal::ONE))
            .collect()
    }

    /// `n` amounts whose natural logs are evenly spaced over [0, 10),
    /// i.e. a deterministic log-uniform sample.
    fn log_uniform_samples(n: usize) -> Vec<Decimal> {
        (0..n)
            .map(|i| {
                let log_val = (i as f64 / n as f64) * 10.0;
                Decimal::from_f64_retain(log_val.exp()).unwrap_or(Decimal::ONE)
            })
            .collect()
    }

    /// 450 amounts at 1 000 and 50 at 1 000 000: two spikes in log-space.
    fn bimodal_samples() -> Vec<Decimal> {
        let mut samples: Vec<Decimal> = (0..450).map(|_| Decimal::from(1000)).collect();
        samples.extend((0..50).map(|_| Decimal::from(1_000_000)));
        samples
    }

    #[test]
    fn benford_passes_for_lognormal() {
        let samples = lognormal_samples(2000, 7.0, 2.0, 42);
        let r = run_benford_first_digit(&samples, 0.015, 0.010);
        assert!(
            !matches!(r.outcome, TestOutcome::Failed),
            "expected pass/warning, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn benford_fails_for_concentrated_single_digit() {
        // Every amount starts with 5, far from Benford's distribution.
        let samples: Vec<Decimal> = (0..500).map(|i| Decimal::from(5000 + i)).collect();
        let r = run_benford_first_digit(&samples, 0.015, 0.010);
        assert!(matches!(r.outcome, TestOutcome::Failed));
    }

    #[test]
    fn benford_skipped_below_100_samples() {
        let samples: Vec<Decimal> = (0..50).map(Decimal::from).collect();
        let r = run_benford_first_digit(&samples, 0.015, 0.010);
        assert!(matches!(r.outcome, TestOutcome::Skipped));
    }

    #[test]
    fn chi_squared_passes_for_log_uniform() {
        let samples = log_uniform_samples(1000);
        let r = run_chi_squared(&samples, 10, 0.05);
        assert!(
            !matches!(r.outcome, TestOutcome::Failed),
            "expected pass, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn chi_squared_fails_for_bimodal_concentration() {
        let samples = bimodal_samples();
        let r = run_chi_squared(&samples, 10, 0.05);
        assert!(
            matches!(r.outcome, TestOutcome::Failed),
            "expected Failed for bimodal, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn chi_squared_skipped_below_100_samples() {
        let samples: Vec<Decimal> = (0..50).map(Decimal::from).collect();
        let r = run_chi_squared(&samples, 10, 0.05);
        assert!(matches!(r.outcome, TestOutcome::Skipped));
    }

    #[test]
    fn chi_squared_skipped_for_constant_amounts() {
        let samples: Vec<Decimal> = (0..200).map(|_| Decimal::from(42)).collect();
        let r = run_chi_squared(&samples, 10, 0.05);
        assert!(matches!(r.outcome, TestOutcome::Skipped), "{}", r.message);
    }

    #[test]
    fn ks_passes_for_log_uniform() {
        let samples = log_uniform_samples(1000);
        let r = run_ks_uniform_log(&samples, 0.05);
        assert!(
            matches!(r.outcome, TestOutcome::Passed),
            "expected pass, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn ks_fails_for_bimodal_concentration() {
        let samples = bimodal_samples();
        let r = run_ks_uniform_log(&samples, 0.05);
        assert!(
            matches!(r.outcome, TestOutcome::Failed),
            "expected Failed for bimodal, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn ks_skipped_below_100_samples() {
        let samples: Vec<Decimal> = (0..50).map(Decimal::from).collect();
        let r = run_ks_uniform_log(&samples, 0.05);
        assert!(matches!(r.outcome, TestOutcome::Skipped));
    }

    #[test]
    fn ks_skipped_for_constant_amounts() {
        let samples: Vec<Decimal> = (0..200).map(|_| Decimal::from(42)).collect();
        let r = run_ks_uniform_log(&samples, 0.05);
        assert!(matches!(r.outcome, TestOutcome::Skipped), "{}", r.message);
    }

    #[test]
    fn report_all_passed_tracks_failures() {
        let rep = StatisticalValidationReport {
            sample_count: 100,
            results: vec![
                StatisticalTestResult {
                    name: "a".into(),
                    outcome: TestOutcome::Passed,
                    statistic: 0.0,
                    threshold: 1.0,
                    message: "".into(),
                },
                StatisticalTestResult {
                    name: "b".into(),
                    outcome: TestOutcome::Warning,
                    statistic: 0.0,
                    threshold: 1.0,
                    message: "".into(),
                },
            ],
        };
        // Warnings do not count as failures.
        assert!(rep.all_passed());
        assert!(rep.has_warnings());
        assert!(rep.failed_names().is_empty());

        let rep_failed = StatisticalValidationReport {
            sample_count: 100,
            results: vec![StatisticalTestResult {
                name: "c".into(),
                outcome: TestOutcome::Failed,
                statistic: 2.0,
                threshold: 1.0,
                message: "".into(),
            }],
        };
        assert!(!rep_failed.all_passed());
        assert_eq!(rep_failed.failed_names(), vec!["c".to_string()]);
    }
}