datasynth_eval/coherence/
sampling_validation.rs1use datasynth_core::models::JournalEntry;
13use rust_decimal::Decimal;
14use serde::{Deserialize, Serialize};
15use std::collections::HashSet;
16
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
21pub enum Stratum {
22 AboveMateriality,
24 BetweenPerformanceAndOverall,
26 BelowPerformanceMateriality,
28 ClearlyTrivial,
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct StratumResult {
35 pub stratum: Stratum,
37 pub item_count: usize,
39 #[serde(with = "rust_decimal::serde::str")]
41 pub total_amount: Decimal,
42 pub anomaly_count: usize,
44 pub anomaly_rate: f64,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct SamplingValidationResult {
51 pub total_population: usize,
53 pub strata: Vec<StratumResult>,
55 pub above_materiality_coverage: f64,
59 pub anomaly_stratum_coverage: f64,
61 pub entity_coverage: f64,
63 pub temporal_coverage: f64,
65 pub passes: bool,
67}
68
69fn entry_amount(entry: &JournalEntry) -> Decimal {
73 entry.lines.iter().map(|l| l.debit_amount).sum()
74}
75
76fn is_anomalous(entry: &JournalEntry) -> bool {
78 entry.header.is_anomaly || entry.header.is_fraud
79}
80
81fn classify(amount: Decimal, materiality: Decimal, performance_materiality: Decimal) -> Stratum {
83 let clearly_trivial_threshold = materiality * Decimal::new(5, 2); if amount > materiality {
85 Stratum::AboveMateriality
86 } else if amount > performance_materiality {
87 Stratum::BetweenPerformanceAndOverall
88 } else if amount > clearly_trivial_threshold {
89 Stratum::BelowPerformanceMateriality
90 } else {
91 Stratum::ClearlyTrivial
92 }
93}
94
95pub fn validate_sampling(
108 entries: &[JournalEntry],
109 materiality: Decimal,
110 performance_materiality: Decimal,
111) -> SamplingValidationResult {
112 let total_population = entries.len();
113
114 let strata_order = [
116 Stratum::AboveMateriality,
117 Stratum::BetweenPerformanceAndOverall,
118 Stratum::BelowPerformanceMateriality,
119 Stratum::ClearlyTrivial,
120 ];
121
122 let mut counts = [0usize; 4];
123 let mut totals = [Decimal::ZERO; 4];
124 let mut anomaly_counts = [0usize; 4];
125
126 let mut all_entities: HashSet<String> = HashSet::new();
128 let mut anomaly_entities: HashSet<String> = HashSet::new();
129 let mut all_periods: HashSet<(u16, u8)> = HashSet::new();
131 let mut anomaly_periods: HashSet<(u16, u8)> = HashSet::new();
132
133 for entry in entries {
134 let amount = entry_amount(entry);
135 let stratum = classify(amount, materiality, performance_materiality);
136 let idx = match stratum {
137 Stratum::AboveMateriality => 0,
138 Stratum::BetweenPerformanceAndOverall => 1,
139 Stratum::BelowPerformanceMateriality => 2,
140 Stratum::ClearlyTrivial => 3,
141 };
142
143 counts[idx] += 1;
144 totals[idx] += amount;
145
146 let entity_key = entry.header.company_code.clone();
147 let period_key = (entry.header.fiscal_year, entry.header.fiscal_period);
148
149 all_entities.insert(entity_key.clone());
150 all_periods.insert(period_key);
151
152 if is_anomalous(entry) {
153 anomaly_counts[idx] += 1;
154 anomaly_entities.insert(entity_key);
155 anomaly_periods.insert(period_key);
156 }
157 }
158
159 let strata: Vec<StratumResult> = strata_order
161 .iter()
162 .enumerate()
163 .map(|(i, stratum)| {
164 let count = counts[i];
165 let anomaly_count = anomaly_counts[i];
166 let anomaly_rate = if count > 0 {
167 anomaly_count as f64 / count as f64
168 } else {
169 0.0
170 };
171 StratumResult {
172 stratum: stratum.clone(),
173 item_count: count,
174 total_amount: totals[i],
175 anomaly_count,
176 anomaly_rate,
177 }
178 })
179 .collect();
180
181 let above_mat_count = counts[0];
183 let above_mat_anomaly = anomaly_counts[0];
184 let above_materiality_coverage = if above_mat_count > 0 {
185 above_mat_anomaly as f64 / above_mat_count as f64
186 } else {
187 1.0
189 };
190
191 let non_trivial_strata = 3usize; let strata_with_anomalies = anomaly_counts[0..3].iter().filter(|&&c| c > 0).count();
195 let anomaly_stratum_coverage = if non_trivial_strata > 0 {
196 strata_with_anomalies as f64 / non_trivial_strata as f64
197 } else {
198 1.0
199 };
200
201 let entity_coverage = if all_entities.is_empty() {
203 1.0
204 } else {
205 anomaly_entities.len() as f64 / all_entities.len() as f64
206 };
207
208 let temporal_coverage = if all_periods.is_empty() {
210 1.0
211 } else {
212 anomaly_periods.len() as f64 / all_periods.len() as f64
213 };
214
215 let passes = above_materiality_coverage >= 0.95;
218
219 SamplingValidationResult {
220 total_population,
221 strata,
222 above_materiality_coverage,
223 anomaly_stratum_coverage,
224 entity_coverage,
225 temporal_coverage,
226 passes,
227 }
228}
229
#[cfg(test)]
// Tests unwrap freely: a failed unwrap simply fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::models::{JournalEntry, JournalEntryHeader, JournalEntryLine};
    use rust_decimal_macros::dec;

    /// Builds a calendar date, panicking on an invalid year/month/day.
    fn date(y: i32, m: u32, d: u32) -> chrono::NaiveDate {
        chrono::NaiveDate::from_ymd_opt(y, m, d).unwrap()
    }

    /// Builds a two-line entry (one debit, one credit of `amount`) for
    /// `company`, posted on the first day of month `period` in 2024.
    fn make_entry(amount: Decimal, anomaly: bool, company: &str, period: u8) -> JournalEntry {
        let posting_date = date(2024, u32::from(period), 1);
        let mut header = JournalEntryHeader::new(company.to_string(), posting_date);
        header.fiscal_period = period;
        header.is_anomaly = anomaly;
        let doc_id = header.document_id;
        let mut entry = JournalEntry::new(header);
        entry.add_line(JournalEntryLine::debit(
            doc_id,
            1,
            "6000".to_string(),
            amount,
        ));
        entry.add_line(JournalEntryLine::credit(
            doc_id,
            2,
            "2000".to_string(),
            amount,
        ));
        entry
    }

    /// Looks up the result row for `stratum`, panicking if it is missing.
    fn stratum_result(result: &SamplingValidationResult, stratum: Stratum) -> &StratumResult {
        result
            .strata
            .iter()
            .find(|row| row.stratum == stratum)
            .unwrap()
    }

    #[test]
    fn test_stratum_classification() {
        let mat = dec!(100_000);
        let perf = dec!(60_000);

        assert_eq!(
            classify(dec!(200_000), mat, perf),
            Stratum::AboveMateriality
        );
        assert_eq!(
            classify(dec!(100_001), mat, perf),
            Stratum::AboveMateriality
        );
        assert_eq!(
            classify(dec!(80_000), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(60_001), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(10_000), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        assert_eq!(classify(dec!(1_000), mat, perf), Stratum::ClearlyTrivial);
        assert_eq!(classify(dec!(0), mat, perf), Stratum::ClearlyTrivial);
    }

    #[test]
    fn test_stratum_classification_boundaries() {
        let mat = dec!(100_000);
        let perf = dec!(60_000);

        // Comparisons are strict, so an amount equal to a threshold falls
        // into the lower stratum.
        assert_eq!(
            classify(dec!(100_000), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(60_000), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        // The clearly-trivial threshold is 5% of materiality = 5_000.
        assert_eq!(
            classify(dec!(5_001), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        assert_eq!(classify(dec!(5_000), mat, perf), Stratum::ClearlyTrivial);
    }

    #[test]
    fn test_empty_entries() {
        let result = validate_sampling(&[], dec!(100_000), dec!(60_000));
        assert_eq!(result.total_population, 0);
        // No above-materiality entries: coverage is vacuously 1.0.
        assert!(result.passes);
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_above_materiality_coverage_full() {
        let entries: Vec<JournalEntry> = (0..5)
            .map(|_| make_entry(dec!(200_000), true, "C001", 1))
            .collect();
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_above_materiality_coverage_zero() {
        let entries: Vec<JournalEntry> = (0..5)
            .map(|_| make_entry(dec!(200_000), false, "C001", 1))
            .collect();
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.above_materiality_coverage - 0.0).abs() < 1e-9);
        assert!(!result.passes);
    }

    #[test]
    fn test_fraud_flag_counts_as_anomalous() {
        // `is_anomaly` is false; only the fraud flag marks the entry.
        let mut entry = make_entry(dec!(200_000), false, "C001", 1);
        entry.header.is_fraud = true;
        let result = validate_sampling(&[entry], dec!(100_000), dec!(60_000));
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_entity_coverage() {
        let mut entries = vec![
            make_entry(dec!(50_000), true, "C001", 1),
            make_entry(dec!(50_000), false, "C002", 1),
        ];
        // An anomalous above-materiality entry keeps `passes` true.
        entries.push(make_entry(dec!(200_000), true, "C001", 1));
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        // Only C001 of {C001, C002} has an anomalous entry.
        assert!((result.entity_coverage - 0.5).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_temporal_coverage() {
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1),
            make_entry(dec!(50_000), true, "C001", 2),
            make_entry(dec!(50_000), false, "C001", 3),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        // Periods 1 and 2 of the three periods have an anomalous entry.
        assert!((result.temporal_coverage - 2.0 / 3.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_anomaly_stratum_coverage() {
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1),
            make_entry(dec!(80_000), true, "C001", 2),
            make_entry(dec!(10_000), false, "C001", 3),
            // Anomalies in the clearly-trivial stratum do not count.
            make_entry(dec!(500), true, "C001", 4),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.anomaly_stratum_coverage - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_stratum_counts() {
        // One entry per stratum, from above materiality down to trivial.
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1),
            make_entry(dec!(80_000), false, "C001", 2),
            make_entry(dec!(10_000), false, "C001", 3),
            make_entry(dec!(500), false, "C001", 4),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert_eq!(result.total_population, 4);

        let above = stratum_result(&result, Stratum::AboveMateriality);
        let between = stratum_result(&result, Stratum::BetweenPerformanceAndOverall);
        let below = stratum_result(&result, Stratum::BelowPerformanceMateriality);
        let trivial = stratum_result(&result, Stratum::ClearlyTrivial);

        assert_eq!(above.item_count, 1);
        assert_eq!(between.item_count, 1);
        assert_eq!(below.item_count, 1);
        assert_eq!(trivial.item_count, 1);
    }
}