Skip to main content

datasynth_core/distributions/
validation.rs

//! Statistical validation runner for generated amount distributions
//! (v3.5.1+).
//!
//! Executes the tests declared in `distributions.validation.tests`
//! (schema-side
//! [`StatisticalTestConfig`](../../../../datasynth-config/src/schema.rs))
//! against a slice of sampled amounts and emits a
//! [`StatisticalValidationReport`] summarising which tests passed,
//! warned, failed, or were skipped.
//!
//! This module deliberately keeps the surface minimal: the schema already
//! declares richer test types (Anderson-Darling, correlation check) whose
//! implementations will land in follow-up releases. v3.5.1 implements
//! Benford first-digit, chi-squared goodness-of-fit, and a lightweight
//! Kolmogorov-Smirnov distribution-fit check — enough to catch the most
//! common realism regressions without pulling in a heavyweight stats
//! dependency.
16
17use rust_decimal::prelude::ToPrimitive;
18use rust_decimal::Decimal;
19use serde::{Deserialize, Serialize};
20
21use super::benford::get_first_digit;
22
23/// Outcome of a single statistical test.
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
25#[serde(rename_all = "snake_case")]
26pub enum TestOutcome {
27    /// Test passed all thresholds.
28    Passed,
29    /// Test passed the hard threshold but exceeded a warning band.
30    Warning,
31    /// Test failed the hard threshold.
32    Failed,
33    /// Test was skipped (e.g. too few samples, not-yet-implemented variant).
34    Skipped,
35}
36
37/// Result of running a single statistical test.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct StatisticalTestResult {
40    /// Human-readable test name (e.g. "benford_first_digit").
41    pub name: String,
42    /// Test outcome.
43    pub outcome: TestOutcome,
44    /// Key measured statistic (e.g. MAD for Benford, chi-squared value).
45    pub statistic: f64,
46    /// Threshold compared against (typically the hard-fail threshold).
47    pub threshold: f64,
48    /// One-line human description of what happened.
49    pub message: String,
50}
51
52/// Aggregate report covering every test run.
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct StatisticalValidationReport {
55    /// Number of samples the report was computed over.
56    pub sample_count: usize,
57    /// Per-test results in input order.
58    pub results: Vec<StatisticalTestResult>,
59}
60
61impl StatisticalValidationReport {
62    /// Did every test pass? (Warnings do not count as failures.)
63    pub fn all_passed(&self) -> bool {
64        self.results
65            .iter()
66            .all(|r| !matches!(r.outcome, TestOutcome::Failed))
67    }
68
69    /// Is there at least one warning?
70    pub fn has_warnings(&self) -> bool {
71        self.results
72            .iter()
73            .any(|r| matches!(r.outcome, TestOutcome::Warning))
74    }
75
76    /// Collect all failed test names.
77    pub fn failed_names(&self) -> Vec<String> {
78        self.results
79            .iter()
80            .filter(|r| matches!(r.outcome, TestOutcome::Failed))
81            .map(|r| r.name.clone())
82            .collect()
83    }
84}
85
86/// Benford first-digit mean-absolute-deviation (MAD) test.
87///
88/// Returns a [`StatisticalTestResult`] where `statistic` is the MAD and
89/// `threshold` is the hard-fail threshold. `Warning` when MAD > warning
90/// threshold but <= hard threshold. `Skipped` when fewer than 100
91/// positive amounts are available (sample too small for stable MAD).
92pub fn run_benford_first_digit(
93    amounts: &[Decimal],
94    threshold_mad: f64,
95    warning_mad: f64,
96) -> StatisticalTestResult {
97    let mut counts = [0u32; 10]; // index 0 unused; 1..=9 used
98    let mut total = 0u32;
99    for amount in amounts {
100        if let Some(d) = get_first_digit(*amount) {
101            counts[d as usize] += 1;
102            total += 1;
103        }
104    }
105
106    if total < 100 {
107        return StatisticalTestResult {
108            name: "benford_first_digit".to_string(),
109            outcome: TestOutcome::Skipped,
110            statistic: 0.0,
111            threshold: threshold_mad,
112            message: format!("only {total} samples with valid first digit; need ≥100"),
113        };
114    }
115
116    // Expected Benford probability for digit d: log10(1 + 1/d).
117    // Index 0 is unused; values for d ∈ {1..=9}.
118    const EXPECTED: [f64; 10] = [
119        0.0,
120        std::f64::consts::LOG10_2, // log10(2)
121        0.17609125905568124,       // log10(3/2)
122        0.12493873660829995,
123        0.09691001300805642,
124        0.07918124604762482,
125        0.06694678963061322,
126        0.057991946977686726,
127        0.05115252244738129,
128        0.04575749056067514,
129    ];
130
131    let total_f = total as f64;
132    let mad: f64 = (1..=9)
133        .map(|d| (counts[d] as f64 / total_f - EXPECTED[d]).abs())
134        .sum::<f64>()
135        / 9.0;
136
137    let outcome = if mad > threshold_mad {
138        TestOutcome::Failed
139    } else if mad > warning_mad {
140        TestOutcome::Warning
141    } else {
142        TestOutcome::Passed
143    };
144
145    StatisticalTestResult {
146        name: "benford_first_digit".to_string(),
147        outcome,
148        statistic: mad,
149        threshold: threshold_mad,
150        message: format!(
151            "MAD={mad:.4} over {total} first digits (threshold={threshold_mad:.4}, warn={warning_mad:.4})"
152        ),
153    }
154}
155
156/// Chi-squared goodness-of-fit test against a uniform binning of
157/// log-scale amounts.
158///
159/// This is intentionally lightweight — it checks that amounts aren't
160/// concentrated in one log-bin (which would indicate a broken mixture
161/// or collapsed distribution). Hard fails when the chi-squared statistic
162/// exceeds the critical value at the configured significance.
163pub fn run_chi_squared(
164    amounts: &[Decimal],
165    bins: usize,
166    significance: f64,
167) -> StatisticalTestResult {
168    if amounts.len() < 100 {
169        return StatisticalTestResult {
170            name: "chi_squared".to_string(),
171            outcome: TestOutcome::Skipped,
172            statistic: 0.0,
173            threshold: 0.0,
174            message: format!("only {} samples; need ≥100", amounts.len()),
175        };
176    }
177
178    let bins = bins.max(2);
179    let positives: Vec<f64> = amounts
180        .iter()
181        .filter_map(|a| a.to_f64())
182        .filter(|v| *v > 0.0)
183        .collect();
184    if positives.len() < 100 {
185        return StatisticalTestResult {
186            name: "chi_squared".to_string(),
187            outcome: TestOutcome::Skipped,
188            statistic: 0.0,
189            threshold: 0.0,
190            message: format!("only {} positive samples; need ≥100", positives.len()),
191        };
192    }
193
194    let logs: Vec<f64> = positives.iter().map(|v| v.ln()).collect();
195    let min = logs.iter().cloned().fold(f64::INFINITY, f64::min);
196    let max = logs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
197    if !min.is_finite() || !max.is_finite() || max <= min {
198        return StatisticalTestResult {
199            name: "chi_squared".to_string(),
200            outcome: TestOutcome::Skipped,
201            statistic: 0.0,
202            threshold: 0.0,
203            message: "degenerate log-range".to_string(),
204        };
205    }
206
207    let bin_width = (max - min) / bins as f64;
208    let mut observed = vec![0u32; bins];
209    for v in &logs {
210        let idx = (((v - min) / bin_width) as usize).min(bins - 1);
211        observed[idx] += 1;
212    }
213
214    let n = logs.len() as f64;
215    let expected_per_bin = n / bins as f64;
216    let chi_sq: f64 = observed
217        .iter()
218        .map(|o| {
219            let diff = *o as f64 - expected_per_bin;
220            diff * diff / expected_per_bin
221        })
222        .sum();
223
224    // Approximate chi-squared critical value for df = bins - 1 at the
225    // configured significance. We ship hard-coded tables for the common
226    // cases (α ∈ {0.01, 0.05, 0.10}, df ∈ {4,5,6,7,8,9,10,14,19,24,29})
227    // and fall back to a generous ceiling otherwise.
228    let df = bins - 1;
229    let critical = chi_sq_critical(df, significance);
230
231    let outcome = if chi_sq > critical {
232        TestOutcome::Failed
233    } else {
234        TestOutcome::Passed
235    };
236
237    StatisticalTestResult {
238        name: "chi_squared".to_string(),
239        outcome,
240        statistic: chi_sq,
241        threshold: critical,
242        message: format!(
243            "χ²={chi_sq:.2} over {bins} log-bins ({n} samples), critical={critical:.2} at α={significance}"
244        ),
245    }
246}
247
248/// Kolmogorov-Smirnov goodness-of-fit against a uniform CDF on the
249/// log-scale of amounts.
250///
251/// This is the simplest version — compares the empirical log-amount CDF
252/// against a uniform CDF on `[min_log, max_log]`. Useful for detecting
253/// grossly skewed outputs; more sophisticated target-distribution fits
254/// (Normal/LogNormal/Exponential) ship in v3.5.2.
255pub fn run_ks_uniform_log(amounts: &[Decimal], significance: f64) -> StatisticalTestResult {
256    let positives: Vec<f64> = amounts
257        .iter()
258        .filter_map(|a| a.to_f64())
259        .filter(|v| *v > 0.0)
260        .collect();
261    if positives.len() < 100 {
262        return StatisticalTestResult {
263            name: "ks_uniform_log".to_string(),
264            outcome: TestOutcome::Skipped,
265            statistic: 0.0,
266            threshold: 0.0,
267            message: format!("only {} positive samples; need ≥100", positives.len()),
268        };
269    }
270
271    let mut logs: Vec<f64> = positives.iter().map(|v| v.ln()).collect();
272    logs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
273    let min = logs[0];
274    let max = logs[logs.len() - 1];
275    if max <= min {
276        return StatisticalTestResult {
277            name: "ks_uniform_log".to_string(),
278            outcome: TestOutcome::Skipped,
279            statistic: 0.0,
280            threshold: 0.0,
281            message: "degenerate log-range".to_string(),
282        };
283    }
284
285    let n = logs.len() as f64;
286    let mut max_diff: f64 = 0.0;
287    for (i, v) in logs.iter().enumerate() {
288        let empirical = (i as f64 + 1.0) / n;
289        let uniform = (v - min) / (max - min);
290        let diff = (empirical - uniform).abs();
291        if diff > max_diff {
292            max_diff = diff;
293        }
294    }
295
296    // Approximate KS critical value at large n (Kolmogorov):
297    //   D_α ≈ c(α) / sqrt(n)
298    // where c(0.05) ≈ 1.358, c(0.01) ≈ 1.628, c(0.10) ≈ 1.224.
299    let c = if significance <= 0.011 {
300        1.628
301    } else if significance <= 0.051 {
302        1.358
303    } else {
304        1.224
305    };
306    let critical = c / n.sqrt();
307
308    let outcome = if max_diff > critical {
309        TestOutcome::Failed
310    } else {
311        TestOutcome::Passed
312    };
313
314    StatisticalTestResult {
315        name: "ks_uniform_log".to_string(),
316        outcome,
317        statistic: max_diff,
318        threshold: critical,
319        message: format!(
320            "D={max_diff:.4} over {n} samples, critical={critical:.4} at α={significance}"
321        ),
322    }
323}
324
/// Upper-tail chi-squared critical value for `df` degrees of freedom.
///
/// `alpha` is bucketed into three significance levels: `≤ 0.011` is
/// treated as α = 0.01, `≤ 0.051` as α = 0.05, and anything else
/// (including NaN) as α = 0.10.
///
/// Degrees of freedom that appear in the built-in table
/// (1–10, 14, 19, 24, 29) return the tabulated value exactly. Any other
/// `df` — in particular every `df` above 29 — is computed with the
/// Wilson–Hilferty approximation
/// `df * (1 - 2/(9 df) + z * sqrt(2/(9 df)))^3`, where `z` is the
/// standard-normal quantile for the selected level. This avoids clamping
/// large `df` to the last table row, which would produce a critical value
/// far too small and make high-bin-count tests fail spuriously.
///
/// `df == 0` is treated as `df == 1`.
fn chi_sq_critical(df: usize, alpha: f64) -> f64 {
    // Rows: (df, α=0.10, α=0.05, α=0.01)
    const TABLE: &[(usize, f64, f64, f64)] = &[
        (1, 2.706, 3.841, 6.635),
        (2, 4.605, 5.991, 9.210),
        (3, 6.251, 7.815, 11.345),
        (4, 7.779, 9.488, 13.277),
        (5, 9.236, 11.070, 15.086),
        (6, 10.645, 12.592, 16.812),
        (7, 12.017, 14.067, 18.475),
        (8, 13.362, 15.507, 20.090),
        (9, 14.684, 16.919, 21.666),
        (10, 15.987, 18.307, 23.209),
        (14, 21.064, 23.685, 29.141),
        (19, 27.204, 30.144, 36.191),
        (24, 33.196, 36.415, 42.980),
        (29, 39.087, 42.557, 49.588),
    ];

    // Standard-normal upper quantiles z_{1-α}, in the same column order
    // as the table: α = 0.10, 0.05, 0.01.
    const Z_QUANTILES: [f64; 3] = [1.281_551_565_544_6, 1.644_853_626_951_5, 2.326_347_874_040_8];

    // Column index shared by the table and the quantile list. The small
    // slack on the cut-offs absorbs floating-point noise in configured
    // values such as 0.01 and 0.05.
    let column = if alpha <= 0.011 {
        2
    } else if alpha <= 0.051 {
        1
    } else {
        0
    };

    // A chi-squared distribution needs at least one degree of freedom.
    let df = df.max(1);

    if let Some(&(_, c_10, c_05, c_01)) = TABLE.iter().find(|(d, ..)| *d == df) {
        return [c_10, c_05, c_01][column];
    }

    // Wilson–Hilferty: (X/df)^(1/3) is approximately normal with mean
    // 1 - 2/(9 df) and variance 2/(9 df).
    let k = df as f64;
    let spread = 2.0 / (9.0 * k);
    k * (1.0 - spread + Z_QUANTILES[column] * spread.sqrt()).powi(3)
}
363
// Unit tests for the validation runners and the report helpers.
// `unwrap` is acceptable here: a panic is simply a test failure.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use rand_distr::{Distribution, LogNormal};

    /// Draws `n` log-normal amounts from a seeded (reproducible) RNG.
    fn lognormal_samples(n: usize, mu: f64, sigma: f64, seed: u64) -> Vec<Decimal> {
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let ln = LogNormal::new(mu, sigma).unwrap();
        (0..n)
            .map(|_| Decimal::from_f64_retain(ln.sample(&mut rng)).unwrap_or(Decimal::ONE))
            .collect()
    }

    /// Builds `n` amounts whose logs are evenly spaced on `[0, max_log)`,
    /// i.e. a sample that is exactly uniform on the log scale.
    fn log_uniform_samples(n: usize, max_log: f64) -> Vec<Decimal> {
        (0..n)
            .map(|i| {
                let log_val = (i as f64 / n as f64) * max_log;
                Decimal::from_f64_retain(log_val.exp()).unwrap_or(Decimal::ONE)
            })
            .collect()
    }

    /// 450 small values plus 50 huge ones: everything sits at the two
    /// ends of the log-range.
    fn bimodal_samples() -> Vec<Decimal> {
        let mut samples: Vec<Decimal> = (0..450).map(|_| Decimal::from(1000)).collect();
        samples.extend((0..50).map(|_| Decimal::from(1_000_000)));
        samples
    }

    #[test]
    fn benford_passes_for_lognormal() {
        let samples = lognormal_samples(2000, 7.0, 2.0, 42);
        let r = run_benford_first_digit(&samples, 0.015, 0.010);
        assert!(
            !matches!(r.outcome, TestOutcome::Failed),
            "expected pass/warning, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn benford_fails_for_concentrated_single_digit() {
        // All values start with 5 — catastrophic Benford violation.
        let samples: Vec<Decimal> = (0..500).map(|i| Decimal::from(5000 + i)).collect();
        let r = run_benford_first_digit(&samples, 0.015, 0.010);
        assert!(matches!(r.outcome, TestOutcome::Failed));
    }

    #[test]
    fn benford_skipped_below_100_samples() {
        let samples: Vec<Decimal> = (0..50).map(Decimal::from).collect();
        let r = run_benford_first_digit(&samples, 0.015, 0.010);
        assert!(matches!(r.outcome, TestOutcome::Skipped));
    }

    #[test]
    fn chi_squared_passes_for_log_uniform() {
        // chi_squared tests uniformity on log scale. Feed it data that is
        // uniform-on-log (i.e. log-uniform) to get the expected pass.
        // A log-normal would — correctly — fail uniformity.
        let samples = log_uniform_samples(1000, 10.0);
        let r = run_chi_squared(&samples, 10, 0.05);
        assert!(
            !matches!(r.outcome, TestOutcome::Failed),
            "expected pass, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn chi_squared_fails_for_bimodal_concentration() {
        // Bimodal: 450 small values, 50 huge values. Every mid bin empty.
        // Chi-squared against a uniform expectation will fail hard.
        let samples = bimodal_samples();
        let r = run_chi_squared(&samples, 10, 0.05);
        assert!(
            matches!(r.outcome, TestOutcome::Failed),
            "expected Failed for bimodal, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn ks_passes_for_log_uniform() {
        // Evenly spaced logs track the uniform CDF to within about 1/n,
        // far below the critical value 1.358 / sqrt(1000) ≈ 0.043.
        let samples = log_uniform_samples(1000, 10.0);
        let r = run_ks_uniform_log(&samples, 0.05);
        assert!(
            matches!(r.outcome, TestOutcome::Passed),
            "expected Passed for log-uniform, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn ks_fails_for_bimodal_concentration() {
        // 90% of the mass sits at the minimum, so the empirical CDF jumps
        // to 0.9 while the uniform CDF is still 0 — D is about 0.9.
        let samples = bimodal_samples();
        let r = run_ks_uniform_log(&samples, 0.05);
        assert!(
            matches!(r.outcome, TestOutcome::Failed),
            "expected Failed for bimodal, got {:?}: {}",
            r.outcome,
            r.message
        );
    }

    #[test]
    fn ks_skipped_below_100_positive_samples() {
        let samples: Vec<Decimal> = (1..=50).map(Decimal::from).collect();
        let r = run_ks_uniform_log(&samples, 0.05);
        assert!(matches!(r.outcome, TestOutcome::Skipped));
    }

    #[test]
    fn report_all_passed_tracks_failures() {
        let rep = StatisticalValidationReport {
            sample_count: 100,
            results: vec![
                StatisticalTestResult {
                    name: "a".into(),
                    outcome: TestOutcome::Passed,
                    statistic: 0.0,
                    threshold: 1.0,
                    message: "".into(),
                },
                StatisticalTestResult {
                    name: "b".into(),
                    outcome: TestOutcome::Warning,
                    statistic: 0.0,
                    threshold: 1.0,
                    message: "".into(),
                },
            ],
        };
        assert!(rep.all_passed()); // warnings don't count
        assert!(rep.has_warnings());

        let rep_failed = StatisticalValidationReport {
            sample_count: 100,
            results: vec![StatisticalTestResult {
                name: "c".into(),
                outcome: TestOutcome::Failed,
                statistic: 2.0,
                threshold: 1.0,
                message: "".into(),
            }],
        };
        assert!(!rep_failed.all_passed());
        assert_eq!(rep_failed.failed_names(), vec!["c".to_string()]);
    }
}