Skip to main content

datasynth_eval/coherence/
sampling_validation.rs

1//! Materiality-stratified sampling validation.
2//!
3//! Validates that generated journal entry populations are appropriately distributed
4//! across materiality strata and that anomaly coverage meets audit expectations.
5//!
6//! Stratification follows ISA 530 (Audit Sampling) conventions:
7//! - AboveMateriality: full population coverage expected
8//! - BetweenPerformanceAndOverall: judgmental sampling
9//! - BelowPerformanceMateriality: statistical sampling
10//! - ClearlyTrivial: excluded from scope
11
12use datasynth_core::models::JournalEntry;
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15use std::collections::HashSet;
16
17// ─── Result types ─────────────────────────────────────────────────────────────
18
19/// Audit materiality strata per ISA 530.
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21pub enum Stratum {
22    /// Amount > overall materiality: full coverage expected.
23    AboveMateriality,
24    /// Performance materiality < amount ≤ overall materiality.
25    BetweenPerformanceAndOverall,
26    /// Clearly trivial threshold < amount ≤ performance materiality.
27    BelowPerformanceMateriality,
28    /// Amount ≤ materiality × 5%: excluded from scope.
29    ClearlyTrivial,
30}
31
32/// Aggregated results for a single materiality stratum.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct StratumResult {
35    /// Which stratum this represents.
36    pub stratum: Stratum,
37    /// Number of journal entries in this stratum.
38    pub item_count: usize,
39    /// Sum of debit amounts across all entries in this stratum.
40    #[serde(with = "datasynth_core::serde_decimal")]
41    pub total_amount: Decimal,
42    /// Number of entries flagged as anomaly or fraud.
43    pub anomaly_count: usize,
44    /// Fraction of entries in this stratum that are anomaly-flagged (0.0–1.0).
45    pub anomaly_rate: f64,
46}
47
48/// Overall result of materiality-stratified sampling validation.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SamplingValidationResult {
51    /// Total number of journal entries supplied.
52    pub total_population: usize,
53    /// Per-stratum breakdown.
54    pub strata: Vec<StratumResult>,
55    /// Fraction of above-materiality items that are anomaly-flagged.
56    ///
57    /// Pass threshold: ≥ 0.95 (auditors expect near-complete coverage of large items).
58    pub above_materiality_coverage: f64,
59    /// Fraction of strata (excluding ClearlyTrivial) that contain at least one anomaly.
60    pub anomaly_stratum_coverage: f64,
61    /// Fraction of unique entity codes (company codes) with at least one anomaly.
62    pub entity_coverage: f64,
63    /// Fraction of unique fiscal periods with at least one anomaly.
64    pub temporal_coverage: f64,
65    /// True when above_materiality_coverage ≥ 0.95 (relaxed threshold for synthetic data).
66    pub passes: bool,
67}
68
// ─── Helpers ──────────────────────────────────────────────────────────────────
70
71/// Compute the sum of debit amounts for a journal entry (representative amount).
72fn entry_amount(entry: &JournalEntry) -> Decimal {
73    entry.lines.iter().map(|l| l.debit_amount).sum()
74}
75
76/// Return `true` if the entry is flagged as an anomaly or fraud.
77fn is_anomalous(entry: &JournalEntry) -> bool {
78    entry.header.is_anomaly || entry.header.is_fraud
79}
80
81/// Assign a stratum based on the entry's representative amount.
82fn classify(amount: Decimal, materiality: Decimal, performance_materiality: Decimal) -> Stratum {
83    let clearly_trivial_threshold = materiality * Decimal::new(5, 2); // 5% of materiality
84    if amount > materiality {
85        Stratum::AboveMateriality
86    } else if amount > performance_materiality {
87        Stratum::BetweenPerformanceAndOverall
88    } else if amount > clearly_trivial_threshold {
89        Stratum::BelowPerformanceMateriality
90    } else {
91        Stratum::ClearlyTrivial
92    }
93}
94
95// ─── Public API ───────────────────────────────────────────────────────────────
96
97/// Validate materiality-stratified sampling of journal entries.
98///
99/// # Arguments
100/// - `entries`: All journal entries in the population.
101/// - `materiality`: Overall materiality threshold.
102/// - `performance_materiality`: Performance (tolerable error) materiality threshold.
103///   Typically 50–75% of overall materiality.
104///
105/// # Returns
106/// A `SamplingValidationResult` describing stratum coverage and pass/fail status.
107pub fn validate_sampling(
108    entries: &[JournalEntry],
109    materiality: Decimal,
110    performance_materiality: Decimal,
111) -> SamplingValidationResult {
112    let total_population = entries.len();
113
114    // ─── Per-stratum accumulators ─────────────────────────────────────────────
115    let strata_order = [
116        Stratum::AboveMateriality,
117        Stratum::BetweenPerformanceAndOverall,
118        Stratum::BelowPerformanceMateriality,
119        Stratum::ClearlyTrivial,
120    ];
121
122    let mut counts = [0usize; 4];
123    let mut totals = [Decimal::ZERO; 4];
124    let mut anomaly_counts = [0usize; 4];
125
126    // ─── Entity / temporal coverage tracking ──────────────────────────────────
127    let mut all_entities: HashSet<String> = HashSet::new();
128    let mut anomaly_entities: HashSet<String> = HashSet::new();
129    // fiscal period key: (fiscal_year, fiscal_period)
130    let mut all_periods: HashSet<(u16, u8)> = HashSet::new();
131    let mut anomaly_periods: HashSet<(u16, u8)> = HashSet::new();
132
133    for entry in entries {
134        let amount = entry_amount(entry);
135        let stratum = classify(amount, materiality, performance_materiality);
136        let idx = match stratum {
137            Stratum::AboveMateriality => 0,
138            Stratum::BetweenPerformanceAndOverall => 1,
139            Stratum::BelowPerformanceMateriality => 2,
140            Stratum::ClearlyTrivial => 3,
141        };
142
143        counts[idx] += 1;
144        totals[idx] += amount;
145
146        let entity_key = entry.header.company_code.clone();
147        let period_key = (entry.header.fiscal_year, entry.header.fiscal_period);
148
149        all_entities.insert(entity_key.clone());
150        all_periods.insert(period_key);
151
152        if is_anomalous(entry) {
153            anomaly_counts[idx] += 1;
154            anomaly_entities.insert(entity_key);
155            anomaly_periods.insert(period_key);
156        }
157    }
158
159    // ─── Build stratum results ────────────────────────────────────────────────
160    let strata: Vec<StratumResult> = strata_order
161        .iter()
162        .enumerate()
163        .map(|(i, stratum)| {
164            let count = counts[i];
165            let anomaly_count = anomaly_counts[i];
166            let anomaly_rate = if count > 0 {
167                anomaly_count as f64 / count as f64
168            } else {
169                0.0
170            };
171            StratumResult {
172                stratum: stratum.clone(),
173                item_count: count,
174                total_amount: totals[i],
175                anomaly_count,
176                anomaly_rate,
177            }
178        })
179        .collect();
180
181    // ─── above_materiality_coverage ───────────────────────────────────────────
182    let above_mat_count = counts[0];
183    let above_mat_anomaly = anomaly_counts[0];
184    let above_materiality_coverage = if above_mat_count > 0 {
185        above_mat_anomaly as f64 / above_mat_count as f64
186    } else {
187        // No items above materiality → vacuously pass
188        1.0
189    };
190
191    // ─── anomaly_stratum_coverage ─────────────────────────────────────────────
192    // Count strata that contain anomalies (exclude ClearlyTrivial index=3).
193    let non_trivial_strata = 3usize; // AboveMateriality, Between, Below
194    let strata_with_anomalies = anomaly_counts[0..3].iter().filter(|&&c| c > 0).count();
195    let anomaly_stratum_coverage = if non_trivial_strata > 0 {
196        strata_with_anomalies as f64 / non_trivial_strata as f64
197    } else {
198        1.0
199    };
200
201    // ─── entity_coverage ──────────────────────────────────────────────────────
202    let entity_coverage = if all_entities.is_empty() {
203        1.0
204    } else {
205        anomaly_entities.len() as f64 / all_entities.len() as f64
206    };
207
208    // ─── temporal_coverage ────────────────────────────────────────────────────
209    let temporal_coverage = if all_periods.is_empty() {
210        1.0
211    } else {
212        anomaly_periods.len() as f64 / all_periods.len() as f64
213    };
214
215    // ─── Pass/fail ────────────────────────────────────────────────────────────
216    // Relaxed threshold: above-materiality items must have ≥ 95% anomaly coverage.
217    let passes = above_materiality_coverage >= 0.95;
218
219    SamplingValidationResult {
220        total_population,
221        strata,
222        above_materiality_coverage,
223        anomaly_stratum_coverage,
224        entity_coverage,
225        temporal_coverage,
226        passes,
227    }
228}
229
230// ─── Unit tests ───────────────────────────────────────────────────────────────
231
#[cfg(test)]
mod tests {
    use super::*;
    use datasynth_core::models::{JournalEntry, JournalEntryHeader, JournalEntryLine};
    use rust_decimal_macros::dec;

    /// Build a calendar date, panicking on an invalid year/month/day combination.
    fn date(y: i32, m: u32, d: u32) -> chrono::NaiveDate {
        chrono::NaiveDate::from_ymd_opt(y, m, d).unwrap()
    }

    /// Build a balanced two-line entry (one debit, one credit) of `amount`.
    ///
    /// `period` is used both as the fiscal period and as the calendar month of the
    /// posting date, so it must be in `1..=12`.
    fn make_entry(amount: Decimal, anomaly: bool, company: &str, period: u8) -> JournalEntry {
        let posting_date = date(2024, period as u32, 1);
        let mut header = JournalEntryHeader::new(company.to_string(), posting_date);
        header.fiscal_period = period;
        header.is_anomaly = anomaly;
        let doc_id = header.document_id;
        let mut entry = JournalEntry::new(header);
        entry.add_line(JournalEntryLine::debit(
            doc_id,
            1,
            "6000".to_string(),
            amount,
        ));
        entry.add_line(JournalEntryLine::credit(
            doc_id,
            2,
            "2000".to_string(),
            amount,
        ));
        entry
    }

    /// Look up the result row for `which`; every stratum is always reported.
    fn stratum_of(result: &SamplingValidationResult, which: Stratum) -> &StratumResult {
        result
            .strata
            .iter()
            .find(|s| s.stratum == which)
            .expect("validate_sampling reports every stratum")
    }

    #[test]
    fn test_stratum_classification() {
        // materiality = 100_000, performance_materiality = 60_000
        // clearly_trivial = 5_000
        let mat = dec!(100_000);
        let perf = dec!(60_000);

        assert_eq!(
            classify(dec!(200_000), mat, perf),
            Stratum::AboveMateriality
        );
        assert_eq!(
            classify(dec!(100_001), mat, perf),
            Stratum::AboveMateriality
        );
        assert_eq!(
            classify(dec!(80_000), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(60_001), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(10_000), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        assert_eq!(classify(dec!(1_000), mat, perf), Stratum::ClearlyTrivial);
        assert_eq!(classify(dec!(0), mat, perf), Stratum::ClearlyTrivial);
    }

    #[test]
    fn test_stratum_boundaries_fall_into_lower_stratum() {
        // An amount exactly on a threshold belongs to the stratum below it.
        let mat = dec!(100_000);
        let perf = dec!(60_000);

        assert_eq!(
            classify(dec!(100_000), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(60_000), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        // 5% of 100_000 = 5_000 is the clearly-trivial ceiling.
        assert_eq!(classify(dec!(5_000), mat, perf), Stratum::ClearlyTrivial);
        assert_eq!(
            classify(dec!(5_001), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
    }

    #[test]
    fn test_empty_entries() {
        let result = validate_sampling(&[], dec!(100_000), dec!(60_000));
        assert_eq!(result.total_population, 0);
        // Vacuously passes
        assert!(result.passes);
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_above_materiality_coverage_full() {
        // All above-materiality items are anomalous → coverage = 1.0 → passes
        let entries: Vec<JournalEntry> = (0..5)
            .map(|_| make_entry(dec!(200_000), true, "C001", 1))
            .collect();
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_above_materiality_coverage_zero() {
        // No above-materiality items are anomalous → coverage = 0.0 → fails
        let entries: Vec<JournalEntry> = (0..5)
            .map(|_| make_entry(dec!(200_000), false, "C001", 1))
            .collect();
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.above_materiality_coverage - 0.0).abs() < 1e-9);
        assert!(!result.passes);
    }

    #[test]
    fn test_fraud_flag_counts_as_anomalous() {
        // is_fraud alone (without is_anomaly) must be enough to flag an entry.
        let mut entry = make_entry(dec!(200_000), false, "C001", 1);
        entry.header.is_fraud = true;
        let result = validate_sampling(&[entry], dec!(100_000), dec!(60_000));
        assert_eq!(
            stratum_of(&result, Stratum::AboveMateriality).anomaly_count,
            1
        );
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_anomaly_stratum_coverage() {
        // Anomalies in Above and Below, none in Between; the ClearlyTrivial anomaly
        // is out of scope and must not count → 2 of 3 in-scope strata covered.
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1),
            make_entry(dec!(80_000), false, "C001", 1),
            make_entry(dec!(10_000), true, "C001", 1),
            make_entry(dec!(500), true, "C001", 1),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.anomaly_stratum_coverage - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_entity_coverage() {
        // Two companies, one has anomaly, other does not
        let mut entries = vec![
            make_entry(dec!(50_000), true, "C001", 1),
            make_entry(dec!(50_000), false, "C002", 1),
        ];
        // Add above-materiality anomaly to pass the main threshold
        entries.push(make_entry(dec!(200_000), true, "C001", 1));
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        // C001 has anomaly, C002 does not → 1/2 = 0.5
        assert!((result.entity_coverage - 0.5).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_temporal_coverage() {
        // 3 periods, anomalies only in 2
        let entries: Vec<JournalEntry> = vec![
            // period 1: anomaly (above materiality)
            make_entry(dec!(200_000), true, "C001", 1),
            // period 2: anomaly
            make_entry(dec!(50_000), true, "C001", 2),
            // period 3: no anomaly
            make_entry(dec!(50_000), false, "C001", 3),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        // 2 out of 3 periods have anomalies
        assert!((result.temporal_coverage - 2.0 / 3.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_stratum_counts() {
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1), // AboveMateriality
            make_entry(dec!(80_000), false, "C001", 2), // Between
            make_entry(dec!(10_000), false, "C001", 3), // Below
            make_entry(dec!(500), false, "C001", 4),    // ClearlyTrivial
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert_eq!(result.total_population, 4);
        assert_eq!(result.strata.len(), 4);

        let above = stratum_of(&result, Stratum::AboveMateriality);
        let between = stratum_of(&result, Stratum::BetweenPerformanceAndOverall);
        let below = stratum_of(&result, Stratum::BelowPerformanceMateriality);
        let trivial = stratum_of(&result, Stratum::ClearlyTrivial);

        assert_eq!(above.item_count, 1);
        assert_eq!(between.item_count, 1);
        assert_eq!(below.item_count, 1);
        assert_eq!(trivial.item_count, 1);
    }

    #[test]
    fn test_stratum_totals_and_rates() {
        // Two above-materiality entries, one flagged; nothing in any other stratum.
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1),
            make_entry(dec!(300_000), false, "C001", 1),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));

        let above = stratum_of(&result, Stratum::AboveMateriality);
        assert_eq!(above.item_count, 2);
        assert_eq!(above.anomaly_count, 1);
        // Only the debit side of each entry contributes to the total.
        assert_eq!(above.total_amount, dec!(500_000));
        assert!((above.anomaly_rate - 0.5).abs() < 1e-9);

        // An empty stratum reports zero amount and a 0.0 anomaly rate.
        let between = stratum_of(&result, Stratum::BetweenPerformanceAndOverall);
        assert_eq!(between.item_count, 0);
        assert_eq!(between.total_amount, Decimal::ZERO);
        assert!((between.anomaly_rate - 0.0).abs() < 1e-9);

        // 50% coverage of above-materiality items is below the 95% threshold.
        assert!(!result.passes);
    }
}
393}