use datasynth_core::models::JournalEntry;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Materiality band that a journal entry's amount falls into.
///
/// Bands are assigned by `classify` using strict `>` comparisons, so an
/// amount exactly equal to a threshold lands in the band below it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Stratum {
/// Amount is strictly greater than overall materiality.
AboveMateriality,
/// Amount is strictly greater than performance materiality but not above
/// overall materiality.
BetweenPerformanceAndOverall,
/// Amount is strictly greater than the clearly-trivial threshold (5% of
/// overall materiality) but not above performance materiality.
BelowPerformanceMateriality,
/// Amount is at or below the clearly-trivial threshold (5% of overall
/// materiality).
ClearlyTrivial,
}
/// Aggregated statistics for the journal entries that fall into one stratum.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StratumResult {
/// The stratum these statistics describe.
pub stratum: Stratum,
/// Number of entries classified into this stratum.
pub item_count: usize,
/// Sum of the entry amounts in this stratum; (de)serialized as a string
/// via `rust_decimal::serde::str`.
#[serde(with = "rust_decimal::serde::str")]
pub total_amount: Decimal,
/// Number of entries in this stratum flagged as anomaly or fraud.
pub anomaly_count: usize,
/// `anomaly_count / item_count`, or `0.0` when the stratum is empty.
pub anomaly_rate: f64,
}
/// Outcome of `validate_sampling`: per-stratum statistics plus coverage ratios.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingValidationResult {
/// Number of journal entries that were examined.
pub total_population: usize,
/// One result per stratum, ordered from `AboveMateriality` down to
/// `ClearlyTrivial`; empty strata are included.
pub strata: Vec<StratumResult>,
/// Fraction of above-materiality entries flagged as anomaly or fraud;
/// `1.0` when there are no above-materiality entries.
pub above_materiality_coverage: f64,
/// Fraction of the three non-trivial strata that contain at least one
/// flagged entry.
pub anomaly_stratum_coverage: f64,
/// Fraction of distinct company codes that have at least one flagged
/// entry; `1.0` when there are no entries.
pub entity_coverage: f64,
/// Fraction of distinct (fiscal year, fiscal period) pairs that have at
/// least one flagged entry; `1.0` when there are no entries.
pub temporal_coverage: f64,
/// `true` when `above_materiality_coverage` is at least `0.95`.
pub passes: bool,
}
fn entry_amount(entry: &JournalEntry) -> Decimal {
entry.lines.iter().map(|l| l.debit_amount).sum()
}
/// Reports whether the entry's header is flagged as an anomaly or as fraud.
fn is_anomalous(entry: &JournalEntry) -> bool {
    let header = &entry.header;
    if header.is_anomaly {
        return true;
    }
    header.is_fraud
}
/// Assigns `amount` to a materiality stratum.
///
/// Thresholds are checked from highest to lowest with strict `>`
/// comparisons, so an amount exactly equal to a threshold falls into the
/// stratum below it. The clearly-trivial threshold is 5% of `materiality`.
fn classify(amount: Decimal, materiality: Decimal, performance_materiality: Decimal) -> Stratum {
    // `Decimal::new(5, 2)` is 0.05, i.e. 5% of overall materiality.
    let trivial_cutoff = materiality * Decimal::new(5, 2);

    if amount > materiality {
        return Stratum::AboveMateriality;
    }
    if amount > performance_materiality {
        return Stratum::BetweenPerformanceAndOverall;
    }
    if amount > trivial_cutoff {
        return Stratum::BelowPerformanceMateriality;
    }
    Stratum::ClearlyTrivial
}
/// Stratifies `entries` by materiality and measures how well flagged
/// (anomaly or fraud) entries cover the population.
///
/// Each entry is classified by its amount against `materiality` and
/// `performance_materiality`. The result contains:
///
/// * one [`StratumResult`] per stratum, always in the order
///   `AboveMateriality`, `BetweenPerformanceAndOverall`,
///   `BelowPerformanceMateriality`, `ClearlyTrivial` (empty strata included);
/// * `above_materiality_coverage`: flagged share of above-materiality entries;
/// * `anomaly_stratum_coverage`: share of the three non-trivial strata that
///   contain at least one flagged entry;
/// * `entity_coverage` / `temporal_coverage`: share of distinct company codes
///   / (fiscal year, fiscal period) pairs with at least one flagged entry.
///
/// Coverage ratios whose denominator is zero are reported as `1.0`, so an
/// empty `entries` slice passes. `passes` is `true` when
/// `above_materiality_coverage >= 0.95`.
pub fn validate_sampling(
    entries: &[JournalEntry],
    materiality: Decimal,
    performance_materiality: Decimal,
) -> SamplingValidationResult {
    /// Number of leading strata (in report order) that are not `ClearlyTrivial`.
    const NON_TRIVIAL_STRATA: usize = 3;

    /// `part / whole` as `f64`, or `when_empty` if `whole` is zero.
    fn ratio(part: usize, whole: usize, when_empty: f64) -> f64 {
        if whole == 0 {
            when_empty
        } else {
            part as f64 / whole as f64
        }
    }

    // Report order; the index chosen in the loop below must agree with it.
    let strata_order = [
        Stratum::AboveMateriality,
        Stratum::BetweenPerformanceAndOverall,
        Stratum::BelowPerformanceMateriality,
        Stratum::ClearlyTrivial,
    ];

    let mut counts = [0usize; 4];
    let mut totals = [Decimal::ZERO; 4];
    let mut anomaly_counts = [0usize; 4];

    // Company codes are borrowed from `entries`, which outlives these sets,
    // so no per-entry `String` allocation is needed.
    let mut all_entities: HashSet<&str> = HashSet::new();
    let mut anomaly_entities: HashSet<&str> = HashSet::new();
    let mut all_periods: HashSet<(u16, u8)> = HashSet::new();
    let mut anomaly_periods: HashSet<(u16, u8)> = HashSet::new();

    for entry in entries {
        let amount = entry_amount(entry);
        let idx = match classify(amount, materiality, performance_materiality) {
            Stratum::AboveMateriality => 0,
            Stratum::BetweenPerformanceAndOverall => 1,
            Stratum::BelowPerformanceMateriality => 2,
            Stratum::ClearlyTrivial => 3,
        };
        counts[idx] += 1;
        totals[idx] += amount;

        let entity_key = entry.header.company_code.as_str();
        let period_key = (entry.header.fiscal_year, entry.header.fiscal_period);
        all_entities.insert(entity_key);
        all_periods.insert(period_key);

        if is_anomalous(entry) {
            anomaly_counts[idx] += 1;
            anomaly_entities.insert(entity_key);
            anomaly_periods.insert(period_key);
        }
    }

    let strata: Vec<StratumResult> = strata_order
        .iter()
        .enumerate()
        .map(|(i, stratum)| StratumResult {
            stratum: stratum.clone(),
            item_count: counts[i],
            total_amount: totals[i],
            anomaly_count: anomaly_counts[i],
            // An empty stratum has no anomalies, hence a rate of 0.
            anomaly_rate: ratio(anomaly_counts[i], counts[i], 0.0),
        })
        .collect();

    // With nothing above materiality there is nothing to miss: full coverage.
    let above_materiality_coverage = ratio(anomaly_counts[0], counts[0], 1.0);

    let strata_with_anomalies = anomaly_counts[..NON_TRIVIAL_STRATA]
        .iter()
        .filter(|&&flagged| flagged > 0)
        .count();
    let anomaly_stratum_coverage = ratio(strata_with_anomalies, NON_TRIVIAL_STRATA, 1.0);

    let entity_coverage = ratio(anomaly_entities.len(), all_entities.len(), 1.0);
    let temporal_coverage = ratio(anomaly_periods.len(), all_periods.len(), 1.0);

    // Only above-materiality coverage gates the overall verdict.
    let passes = above_materiality_coverage >= 0.95;

    SamplingValidationResult {
        total_population: entries.len(),
        strata,
        above_materiality_coverage,
        anomaly_stratum_coverage,
        entity_coverage,
        temporal_coverage,
        passes,
    }
}
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use datasynth_core::models::{JournalEntry, JournalEntryHeader, JournalEntryLine};
    use rust_decimal_macros::dec;

    /// Builds a calendar date, panicking if the components are invalid.
    fn date(y: i32, m: u32, d: u32) -> chrono::NaiveDate {
        chrono::NaiveDate::from_ymd_opt(y, m, d).unwrap()
    }

    /// Builds a two-line entry (one debit, one credit, both of `amount`) for
    /// `company`. `period` is used both as the posting month in 2024 and as
    /// the header's fiscal period, so it must be in `1..=12`.
    fn make_entry(amount: Decimal, anomaly: bool, company: &str, period: u8) -> JournalEntry {
        let posting_date = date(2024, u32::from(period), 1);
        let mut header = JournalEntryHeader::new(company.to_string(), posting_date);
        header.fiscal_period = period;
        header.is_anomaly = anomaly;
        let doc_id = header.document_id;

        let mut entry = JournalEntry::new(header);
        entry.add_line(JournalEntryLine::debit(
            doc_id,
            1,
            "6000".to_string(),
            amount,
        ));
        entry.add_line(JournalEntryLine::credit(
            doc_id,
            2,
            "2000".to_string(),
            amount,
        ));
        entry
    }

    /// Looks up the result for `wanted`, panicking if it is missing.
    fn stratum_result<'a>(
        result: &'a SamplingValidationResult,
        wanted: &Stratum,
    ) -> &'a StratumResult {
        result
            .strata
            .iter()
            .find(|s| &s.stratum == wanted)
            .unwrap()
    }

    #[test]
    fn test_stratum_classification() {
        let mat = dec!(100_000);
        let perf = dec!(60_000);
        let cases = [
            (dec!(200_000), Stratum::AboveMateriality),
            (dec!(100_001), Stratum::AboveMateriality),
            (dec!(80_000), Stratum::BetweenPerformanceAndOverall),
            (dec!(60_001), Stratum::BetweenPerformanceAndOverall),
            (dec!(10_000), Stratum::BelowPerformanceMateriality),
            (dec!(1_000), Stratum::ClearlyTrivial),
            (dec!(0), Stratum::ClearlyTrivial),
        ];
        for (amount, expected) in cases {
            assert_eq!(classify(amount, mat, perf), expected, "amount = {}", amount);
        }
    }

    #[test]
    fn test_stratum_boundaries_are_exclusive() {
        // Thresholds use strict `>`: an amount equal to a threshold belongs
        // to the stratum below it. The trivial threshold is 5% of 100_000.
        let mat = dec!(100_000);
        let perf = dec!(60_000);
        assert_eq!(
            classify(dec!(100_000), mat, perf),
            Stratum::BetweenPerformanceAndOverall
        );
        assert_eq!(
            classify(dec!(60_000), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        assert_eq!(
            classify(dec!(5_001), mat, perf),
            Stratum::BelowPerformanceMateriality
        );
        assert_eq!(classify(dec!(5_000), mat, perf), Stratum::ClearlyTrivial);
    }

    #[test]
    fn test_empty_entries() {
        let result = validate_sampling(&[], dec!(100_000), dec!(60_000));
        assert_eq!(result.total_population, 0);
        assert!(result.passes);
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
        // Every stratum is still reported, just empty.
        assert_eq!(result.strata.len(), 4);
        assert!(result.strata.iter().all(|s| s.item_count == 0));
    }

    #[test]
    fn test_above_materiality_coverage_full() {
        let entries: Vec<JournalEntry> = (0..5)
            .map(|_| make_entry(dec!(200_000), true, "C001", 1))
            .collect();
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.above_materiality_coverage - 1.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_above_materiality_coverage_zero() {
        let entries: Vec<JournalEntry> = (0..5)
            .map(|_| make_entry(dec!(200_000), false, "C001", 1))
            .collect();
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.above_materiality_coverage - 0.0).abs() < 1e-9);
        assert!(!result.passes);
    }

    #[test]
    fn test_entity_coverage() {
        // C001 has flagged entries, C002 has none: one of two entities covered.
        let entries = vec![
            make_entry(dec!(50_000), true, "C001", 1),
            make_entry(dec!(50_000), false, "C002", 1),
            make_entry(dec!(200_000), true, "C001", 1),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.entity_coverage - 0.5).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_temporal_coverage() {
        // Periods 1 and 2 contain flagged entries, period 3 does not.
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1),
            make_entry(dec!(50_000), true, "C001", 2),
            make_entry(dec!(50_000), false, "C001", 3),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert!((result.temporal_coverage - 2.0 / 3.0).abs() < 1e-9);
        assert!(result.passes);
    }

    #[test]
    fn test_stratum_counts() {
        // One entry per stratum; only the above-materiality one is flagged.
        let entries = vec![
            make_entry(dec!(200_000), true, "C001", 1),
            make_entry(dec!(80_000), false, "C001", 2),
            make_entry(dec!(10_000), false, "C001", 3),
            make_entry(dec!(500), false, "C001", 4),
        ];
        let result = validate_sampling(&entries, dec!(100_000), dec!(60_000));
        assert_eq!(result.total_population, 4);

        let all_strata = [
            Stratum::AboveMateriality,
            Stratum::BetweenPerformanceAndOverall,
            Stratum::BelowPerformanceMateriality,
            Stratum::ClearlyTrivial,
        ];
        for wanted in &all_strata {
            assert_eq!(
                stratum_result(&result, wanted).item_count,
                1,
                "stratum = {:?}",
                wanted
            );
        }

        // Exactly one of the three non-trivial strata holds a flagged entry.
        assert_eq!(
            stratum_result(&result, &Stratum::AboveMateriality).anomaly_count,
            1
        );
        assert!((result.anomaly_stratum_coverage - 1.0 / 3.0).abs() < 1e-9);
    }
}