Skip to main content

datasynth_eval/coherence/
sampling_validation.rs

1//! Materiality-stratified sampling validation.
2//!
3//! Validates that generated journal entry populations are appropriately distributed
4//! across materiality strata and that anomaly coverage meets audit expectations.
5//!
6//! Stratification follows ISA 530 (Audit Sampling) conventions:
7//! - AboveMateriality: full population coverage expected
8//! - BetweenPerformanceAndOverall: judgmental sampling
9//! - BelowPerformanceMateriality: statistical sampling
10//! - ClearlyTrivial: excluded from scope
11
12use datasynth_core::models::JournalEntry;
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15use std::collections::HashSet;
16
17// ─── Result types ─────────────────────────────────────────────────────────────
18
19/// Audit materiality strata per ISA 530.
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21pub enum Stratum {
22    /// Amount > overall materiality: full coverage expected.
23    AboveMateriality,
24    /// Performance materiality < amount ≤ overall materiality.
25    BetweenPerformanceAndOverall,
26    /// Clearly trivial threshold < amount ≤ performance materiality.
27    BelowPerformanceMateriality,
28    /// Amount ≤ materiality × 5%: excluded from scope.
29    ClearlyTrivial,
30}
31
32/// Aggregated results for a single materiality stratum.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct StratumResult {
35    /// Which stratum this represents.
36    pub stratum: Stratum,
37    /// Number of journal entries in this stratum.
38    pub item_count: usize,
39    /// Sum of debit amounts across all entries in this stratum.
40    #[serde(with = "rust_decimal::serde::str")]
41    pub total_amount: Decimal,
42    /// Number of entries flagged as anomaly or fraud.
43    pub anomaly_count: usize,
44    /// Fraction of entries in this stratum that are anomaly-flagged (0.0–1.0).
45    pub anomaly_rate: f64,
46}
47
48/// Overall result of materiality-stratified sampling validation.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SamplingValidationResult {
51    /// Total number of journal entries supplied.
52    pub total_population: usize,
53    /// Per-stratum breakdown.
54    pub strata: Vec<StratumResult>,
55    /// Fraction of above-materiality items that are anomaly-flagged.
56    ///
57    /// Pass threshold: ≥ 0.95 (auditors expect near-complete coverage of large items).
58    pub above_materiality_coverage: f64,
59    /// Fraction of strata (excluding ClearlyTrivial) that contain at least one anomaly.
60    pub anomaly_stratum_coverage: f64,
61    /// Fraction of unique entity codes (company codes) with at least one anomaly.
62    pub entity_coverage: f64,
63    /// Fraction of unique fiscal periods with at least one anomaly.
64    pub temporal_coverage: f64,
65    /// True when above_materiality_coverage ≥ 0.95 (relaxed threshold for synthetic data).
66    pub passes: bool,
67}
68
69// ─── Helper ───────────────────────────────────────────────────────────────────
70
71/// Compute the sum of debit amounts for a journal entry (representative amount).
72fn entry_amount(entry: &JournalEntry) -> Decimal {
73    entry.lines.iter().map(|l| l.debit_amount).sum()
74}
75
76/// Return `true` if the entry is flagged as an anomaly or fraud.
77fn is_anomalous(entry: &JournalEntry) -> bool {
78    entry.header.is_anomaly || entry.header.is_fraud
79}
80
81/// Assign a stratum based on the entry's representative amount.
82fn classify(amount: Decimal, materiality: Decimal, performance_materiality: Decimal) -> Stratum {
83    let clearly_trivial_threshold = materiality * Decimal::new(5, 2); // 5% of materiality
84    if amount > materiality {
85        Stratum::AboveMateriality
86    } else if amount > performance_materiality {
87        Stratum::BetweenPerformanceAndOverall
88    } else if amount > clearly_trivial_threshold {
89        Stratum::BelowPerformanceMateriality
90    } else {
91        Stratum::ClearlyTrivial
92    }
93}
94
95// ─── Public API ───────────────────────────────────────────────────────────────
96
97/// Validate materiality-stratified sampling of journal entries.
98///
99/// # Arguments
100/// - `entries`: All journal entries in the population.
101/// - `materiality`: Overall materiality threshold.
102/// - `performance_materiality`: Performance (tolerable error) materiality threshold.
103///   Typically 50–75% of overall materiality.
104///
105/// # Returns
106/// A `SamplingValidationResult` describing stratum coverage and pass/fail status.
107pub fn validate_sampling(
108    entries: &[JournalEntry],
109    materiality: Decimal,
110    performance_materiality: Decimal,
111) -> SamplingValidationResult {
112    let total_population = entries.len();
113
114    // ─── Per-stratum accumulators ─────────────────────────────────────────────
115    let strata_order = [
116        Stratum::AboveMateriality,
117        Stratum::BetweenPerformanceAndOverall,
118        Stratum::BelowPerformanceMateriality,
119        Stratum::ClearlyTrivial,
120    ];
121
122    let mut counts = [0usize; 4];
123    let mut totals = [Decimal::ZERO; 4];
124    let mut anomaly_counts = [0usize; 4];
125
126    // ─── Entity / temporal coverage tracking ──────────────────────────────────
127    let mut all_entities: HashSet<String> = HashSet::new();
128    let mut anomaly_entities: HashSet<String> = HashSet::new();
129    // fiscal period key: (fiscal_year, fiscal_period)
130    let mut all_periods: HashSet<(u16, u8)> = HashSet::new();
131    let mut anomaly_periods: HashSet<(u16, u8)> = HashSet::new();
132
133    for entry in entries {
134        let amount = entry_amount(entry);
135        let stratum = classify(amount, materiality, performance_materiality);
136        let idx = match stratum {
137            Stratum::AboveMateriality => 0,
138            Stratum::BetweenPerformanceAndOverall => 1,
139            Stratum::BelowPerformanceMateriality => 2,
140            Stratum::ClearlyTrivial => 3,
141        };
142
143        counts[idx] += 1;
144        totals[idx] += amount;
145
146        let entity_key = entry.header.company_code.clone();
147        let period_key = (entry.header.fiscal_year, entry.header.fiscal_period);
148
149        all_entities.insert(entity_key.clone());
150        all_periods.insert(period_key);
151
152        if is_anomalous(entry) {
153            anomaly_counts[idx] += 1;
154            anomaly_entities.insert(entity_key);
155            anomaly_periods.insert(period_key);
156        }
157    }
158
159    // ─── Build stratum results ────────────────────────────────────────────────
160    let strata: Vec<StratumResult> = strata_order
161        .iter()
162        .enumerate()
163        .map(|(i, stratum)| {
164            let count = counts[i];
165            let anomaly_count = anomaly_counts[i];
166            let anomaly_rate = if count > 0 {
167                anomaly_count as f64 / count as f64
168            } else {
169                0.0
170            };
171            StratumResult {
172                stratum: stratum.clone(),
173                item_count: count,
174                total_amount: totals[i],
175                anomaly_count,
176                anomaly_rate,
177            }
178        })
179        .collect();
180
181    // ─── above_materiality_coverage ───────────────────────────────────────────
182    let above_mat_count = counts[0];
183    let above_mat_anomaly = anomaly_counts[0];
184    let above_materiality_coverage = if above_mat_count > 0 {
185        above_mat_anomaly as f64 / above_mat_count as f64
186    } else {
187        // No items above materiality → vacuously pass
188        1.0
189    };
190
191    // ─── anomaly_stratum_coverage ─────────────────────────────────────────────
192    // Count strata that contain anomalies (exclude ClearlyTrivial index=3).
193    let non_trivial_strata = 3usize; // AboveMateriality, Between, Below
194    let strata_with_anomalies = anomaly_counts[0..3].iter().filter(|&&c| c > 0).count();
195    let anomaly_stratum_coverage = if non_trivial_strata > 0 {
196        strata_with_anomalies as f64 / non_trivial_strata as f64
197    } else {
198        1.0
199    };
200
201    // ─── entity_coverage ──────────────────────────────────────────────────────
202    let entity_coverage = if all_entities.is_empty() {
203        1.0
204    } else {
205        anomaly_entities.len() as f64 / all_entities.len() as f64
206    };
207
208    // ─── temporal_coverage ────────────────────────────────────────────────────
209    let temporal_coverage = if all_periods.is_empty() {
210        1.0
211    } else {
212        anomaly_periods.len() as f64 / all_periods.len() as f64
213    };
214
215    // ─── Pass/fail ────────────────────────────────────────────────────────────
216    // Relaxed threshold: above-materiality items must have ≥ 95% anomaly coverage.
217    let passes = above_materiality_coverage >= 0.95;
218
219    SamplingValidationResult {
220        total_population,
221        strata,
222        above_materiality_coverage,
223        anomaly_stratum_coverage,
224        entity_coverage,
225        temporal_coverage,
226        passes,
227    }
228}
229
230// ─── Unit tests ───────────────────────────────────────────────────────────────
231
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Panicking on a broken fixture is the desired test failure.
mod tests {
    use super::*;
    use datasynth_core::models::{JournalEntry, JournalEntryHeader, JournalEntryLine};
    use rust_decimal_macros::dec;

    /// Build a calendar date, panicking on an invalid combination.
    fn date(y: i32, m: u32, d: u32) -> chrono::NaiveDate {
        chrono::NaiveDate::from_ymd_opt(y, m, d).unwrap()
    }

    /// Build a balanced two-line entry (one debit, one credit) of `amount`.
    ///
    /// `period` doubles as the posting month, so it must be in 1..=12.
    fn make_entry(amount: Decimal, anomaly: bool, company: &str, period: u8) -> JournalEntry {
        let posting_date = date(2024, period as u32, 1);
        let mut header = JournalEntryHeader::new(company.to_string(), posting_date);
        header.fiscal_period = period;
        header.is_anomaly = anomaly;
        let doc_id = header.document_id;
        let mut entry = JournalEntry::new(header);
        entry.add_line(JournalEntryLine::debit(
            doc_id,
            1,
            "6000".to_string(),
            amount,
        ));
        entry.add_line(JournalEntryLine::credit(
            doc_id,
            2,
            "2000".to_string(),
            amount,
        ));
        entry
    }

    /// Look up the result for one stratum, panicking if it is missing.
    fn stratum_result(result: &SamplingValidationResult, stratum: Stratum) -> &StratumResult {
        result
            .strata
            .iter()
            .find(|s| s.stratum == stratum)
            .unwrap()
    }

    #[test]
    fn test_stratum_classification() {
        // materiality = 100_000, performance_materiality = 60_000
        // clearly_trivial = 5_000
        let mat = dec!(100_000);
        let perf = dec!(60_000);

        assert_eq!(
            classify(dec!(200_000), mat, perf),
            Stratum::AboveMateriality
        );
        assert_eq!(
            classify(dec!(100_001), mat, perf),
            Stratum::AboveMateriality
        );
        assert_eq!(
            classify(dec!(80_000), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(60_001), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(10_000), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        assert_eq!(classify(dec!(1_000), mat, perf), Stratum::ClearlyTrivial);
        assert_eq!(classify(dec!(0), mat, perf), Stratum::ClearlyTrivial);
    }

    #[test]
    fn test_stratum_classification_boundaries() {
        // An amount exactly equal to a threshold belongs to the stratum below it.
        let mat = dec!(100_000);
        let perf = dec!(60_000);

        assert_eq!(
            classify(dec!(100_000), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(60_000), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        assert_eq!(
            classify(dec!(5_001), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        assert_eq!(classify(dec!(5_000), mat, perf), Stratum::ClearlyTrivial);
        // Negative amounts are below every threshold.
        assert_eq!(classify(dec!(-1), mat, perf), Stratum::ClearlyTrivial);
    }

    #[test]
    fn test_empty_entries() {
        let result = validate_sampling(&[], dec!(100_000), dec!(60_000));
        assert_eq!(result.total_population, 0);
        // Vacuously passes
        assert!(result.passes);
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
        // All four strata are still reported, each empty.
        assert_eq!(result.strata.len(), 4);
        assert!(result.strata.iter().all(|s| s.item_count == 0));
        assert!((result.entity_coverage - 1.0).abs() < 1e-9);
        assert!((result.temporal_coverage - 1.0).abs() < 1e-9);
        // No stratum contains an anomaly.
        assert!(result.anomaly_stratum_coverage.abs() < 1e-9);
    }

    #[test]
    fn test_above_materiality_coverage_full() {
        // All above-materiality items are anomalous → coverage = 1.0 → passes
        let entries: Vec<JournalEntry> = (0..5)
            .map(|_| make_entry(dec!(200_000), true, "C001", 1))
            .collect();
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_above_materiality_coverage_zero() {
        // No above-materiality items are anomalous → coverage = 0.0 → fails
        let entries: Vec<JournalEntry> = (0..5)
            .map(|_| make_entry(dec!(200_000), false, "C001", 1))
            .collect();
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.above_materiality_coverage - 0.0).abs() < 1e-9);
        assert!(!result.passes);
    }

    #[test]
    fn test_fraud_flag_counts_as_anomaly() {
        // The fraud flag alone (without is_anomaly) must count towards coverage.
        let mut entry = make_entry(dec!(200_000), false, "C001", 1);
        entry.header.is_fraud = true;
        let result = validate_sampling(&[entry], dec!(100_000), dec!(60_000));
        let above = stratum_result(&result, Stratum::AboveMateriality);
        assert_eq!(above.anomaly_count, 1);
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_entity_coverage() {
        // Two companies, one has anomaly, other does not
        let entries = vec![
            make_entry(dec!(50_000), true, "C001", 1),
            make_entry(dec!(50_000), false, "C002", 1),
            // Above-materiality anomaly so the main threshold passes
            make_entry(dec!(200_000), true, "C001", 1),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        // C001 has anomaly, C002 does not → 1/2 = 0.5
        assert!((result.entity_coverage - 0.5).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_temporal_coverage() {
        // 3 periods, anomalies only in 2
        let entries = vec![
            // period 1: anomaly (above materiality)
            make_entry(dec!(200_000), true, "C001", 1),
            // period 2: anomaly
            make_entry(dec!(50_000), true, "C001", 2),
            // period 3: no anomaly
            make_entry(dec!(50_000), false, "C001", 3),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        // 2 out of 3 periods have anomalies
        assert!((result.temporal_coverage - 2.0 / 3.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_stratum_counts() {
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1), // AboveMateriality
            make_entry(dec!(80_000), false, "C001", 2), // Between
            make_entry(dec!(10_000), false, "C001", 3), // Below
            make_entry(dec!(500), false, "C001", 4),    // ClearlyTrivial
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert_eq!(result.total_population, 4);

        let above = stratum_result(&result, Stratum::AboveMateriality);
        let between = stratum_result(&result, Stratum::BetweenPerformanceAndOverall);
        let below = stratum_result(&result, Stratum::BelowPerformanceMateriality);
        let trivial = stratum_result(&result, Stratum::ClearlyTrivial);

        assert_eq!(above.item_count, 1);
        assert_eq!(between.item_count, 1);
        assert_eq!(below.item_count, 1);
        assert_eq!(trivial.item_count, 1);

        // Only the above-materiality entry is flagged.
        assert_eq!(above.anomaly_count, 1);
        assert_eq!(between.anomaly_count, 0);
        assert_eq!(below.anomaly_count, 0);
        assert_eq!(trivial.anomaly_count, 0);
        assert!((above.anomaly_rate - 1.0).abs() < 1e-9);
        assert!(between.anomaly_rate.abs() < 1e-9);
    }

    #[test]
    fn test_anomaly_stratum_coverage() {
        // Anomalies in Above and Below, none in Between → 2 of 3 non-trivial strata.
        // The anomaly in ClearlyTrivial must not count.
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1), // AboveMateriality
            make_entry(dec!(80_000), false, "C001", 1), // Between
            make_entry(dec!(10_000), true, "C001", 1),  // Below
            make_entry(dec!(500), true, "C001", 1),     // ClearlyTrivial
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.anomaly_stratum_coverage - 2.0 / 3.0).abs() < 1e-9);
    }
}