use datasynth_config::schema::ExternalExpectationsConfig;
use datasynth_core::models::{AccountType, ExpectationDriver, ExternalExpectation};
use datasynth_core::utils::seeded_rng;
use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, Normal};
use rust_decimal::prelude::{FromPrimitive, ToPrimitive};
use rust_decimal::Decimal;
/// Per-account totals handed to [`ExternalExpectationsGenerator::generate`].
#[derive(Debug, Clone)]
pub struct AccountActuals {
/// Account identifier; copied verbatim into each generated expectation.
pub account_code: String,
/// Human-readable account name; copied verbatim into each generated expectation.
pub account_description: String,
/// Account classification; copied into each generated expectation.
pub account_type: AccountType,
/// Recorded total for the account; reported as the actual value and compared
/// against the expectation band.
pub actual_total: Decimal,
/// Baseline total the expectation is built from; `actual_total - legit_total`
/// is reported as the fraud inflation of the account.
pub legit_total: Decimal,
}
/// Generates [`ExternalExpectation`] records from per-account totals.
pub struct ExternalExpectationsGenerator {
/// Random source used to sample the forecast error of each account.
rng: ChaCha8Rng,
/// Produces the `expectation_id` of each generated record.
uuid_factory: DeterministicUuidFactory,
}
impl ExternalExpectationsGenerator {
pub fn new(seed: u64) -> Self {
Self {
rng: seeded_rng(seed, 0),
uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExternalExpectation),
}
}
pub fn generate(
&mut self,
company_code: &str,
fiscal_year: i32,
accounts: &[AccountActuals],
config: &ExternalExpectationsConfig,
) -> Vec<ExternalExpectation> {
let grand_legit: Decimal = accounts.iter().map(|a| a.legit_total).sum();
if grand_legit <= Decimal::ZERO {
return Vec::new();
}
let min_legit = grand_legit
* Decimal::from_f64(config.min_materiality_share.max(0.0)).unwrap_or(Decimal::ZERO);
let noise = Normal::new(0.0, config.forecast_noise.max(1e-9)).expect("valid normal params");
let tol = config.tolerance_pct.max(1e-9);
let mut out = Vec::new();
for a in accounts {
if a.legit_total < min_legit {
continue;
}
let legit = a.legit_total.to_f64().unwrap_or(0.0);
let forecast_err: f64 = noise.sample(&mut self.rng);
let expected_f = (legit * (1.0 + forecast_err)).max(0.0);
let expected = Decimal::from_f64(expected_f)
.unwrap_or(Decimal::ZERO)
.round_dp(2);
let band = (expected_f * tol).max(1.0);
let actual = a.actual_total;
let actual_f = actual.to_f64().unwrap_or(0.0);
let deviation_f = actual_f - expected_f;
let fraud_inflation = a.actual_total - a.legit_total;
let (driver_value, basis) =
driver_view(config.driver, legit, config.growth_rate, expected);
out.push(ExternalExpectation {
expectation_id: self.uuid_factory.next().to_string(),
company_code: company_code.to_string(),
account_code: a.account_code.clone(),
account_description: a.account_description.clone(),
account_type: a.account_type,
fiscal_year,
driver: config.driver,
basis,
driver_value,
expected_value: expected,
tolerance_pct: config.tolerance_pct,
lower_bound: Decimal::from_f64(expected_f - band)
.unwrap_or(Decimal::ZERO)
.round_dp(2),
upper_bound: Decimal::from_f64(expected_f + band)
.unwrap_or(Decimal::ZERO)
.round_dp(2),
actual_value: actual,
deviation: Decimal::from_f64(deviation_f)
.unwrap_or(Decimal::ZERO)
.round_dp(2),
deviation_ratio: deviation_f / band,
exceeds_band: deviation_f.abs() > band,
fraud_inflation,
is_fraud_inflated: fraud_inflation > Decimal::ZERO,
});
}
out
}
}
/// Produces the driver value and the textual basis recorded on an expectation.
///
/// For [`ExpectationDriver::PriorYear`] the driver value is the legitimate level
/// divided by `1 + growth`, rounded to two decimals; when `1 + growth` is not more
/// than `1e-9` away from zero the legitimate level is used undivided. Every other
/// driver reports `expected` as its value together with a fixed description.
fn driver_view(
    driver: ExpectationDriver,
    legit: f64,
    growth: f64,
    expected: Decimal,
) -> (Decimal, String) {
    let description = match driver {
        ExpectationDriver::PriorYear => {
            let divisor = 1.0 + growth;
            // Only divide when the divisor is safely away from zero.
            let prior = if divisor.abs() > 1e-9 {
                legit / divisor
            } else {
                legit
            };
            let pv = Decimal::from_f64(prior)
                .unwrap_or(Decimal::ZERO)
                .round_dp(2);
            let basis = format!("prior-year actual {pv} grown at {:.1}%", growth * 100.0);
            return (pv, basis);
        }
        ExpectationDriver::MarketIndex => {
            "market/industry index, sensitivity calibrated to the legitimate level"
        }
        ExpectationDriver::MacroSeries => {
            "macroeconomic series, sensitivity calibrated to the legitimate level"
        }
        ExpectationDriver::Budget => "budgeted amount for the account",
    };
    (expected, description.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration shared by all tests: prior-year driver, 10% tolerance,
    /// 4% forecast noise, 5% growth and a 0.5% materiality share.
    fn cfg() -> ExternalExpectationsConfig {
        ExternalExpectationsConfig {
            enabled: true,
            driver: ExpectationDriver::PriorYear,
            tolerance_pct: 0.10,
            forecast_noise: 0.04,
            growth_rate: 0.05,
            min_materiality_share: 0.005,
        }
    }

    /// Builds a revenue account with the given actual and legitimate totals.
    fn acct(code: &str, actual: i64, legit: i64) -> AccountActuals {
        let account_code = code.to_string();
        let account_description = format!("Account {code}");
        AccountActuals {
            account_code,
            account_description,
            account_type: AccountType::Revenue,
            actual_total: Decimal::from(actual),
            legit_total: Decimal::from(legit),
        }
    }

    /// Looks up the expectation generated for `code`, panicking if there is none.
    fn by_code<'a>(exps: &'a [ExternalExpectation], code: &str) -> &'a ExternalExpectation {
        exps.iter().find(|e| e.account_code == code).unwrap()
    }

    #[test]
    fn fraud_inflation_breaches_band_clean_does_not() {
        let mut generator = ExternalExpectationsGenerator::new(7);
        // "4000" has actual == legit; "4010" has an actual three times its legit total.
        let accounts = [
            acct("4000", 1_000_000, 1_000_000),
            acct("4010", 3_000_000, 1_000_000),
        ];
        let exps = generator.generate("1000", 2024, &accounts, &cfg());
        assert_eq!(exps.len(), 2);

        let clean = by_code(&exps, "4000");
        let fraud = by_code(&exps, "4010");

        assert!(!clean.is_fraud_inflated);
        assert!(
            !clean.exceeds_band,
            "clean account should sit in band, dev_ratio={}",
            clean.deviation_ratio
        );

        assert!(fraud.is_fraud_inflated);
        assert_eq!(fraud.fraud_inflation, Decimal::from(2_000_000));
        assert!(
            fraud.exceeds_band,
            "fraud-inflated account must breach the band"
        );
        assert!(fraud.deviation_ratio > 1.0);
        assert!(fraud.expected_value > Decimal::ZERO);
        assert!(fraud.lower_bound < fraud.upper_bound);
    }

    #[test]
    fn immaterial_accounts_are_skipped() {
        let mut generator = ExternalExpectationsGenerator::new(1);
        // "9999" is far below 0.5% of the combined legitimate total.
        let accounts = [
            acct("4000", 10_000_000, 10_000_000),
            acct("9999", 100, 100),
        ];
        let exps = generator.generate("1000", 2024, &accounts, &cfg());
        let codes: Vec<&str> = exps.iter().map(|e| e.account_code.as_str()).collect();
        assert!(!codes.contains(&"9999"));
        assert!(codes.contains(&"4000"));
    }

    #[test]
    fn deterministic() {
        let accounts = [
            acct("4000", 2_000_000, 1_000_000),
            acct("4010", 900_000, 900_000),
        ];
        // Two fresh generators with the same seed must agree record by record.
        let run = || ExternalExpectationsGenerator::new(42).generate("1000", 2024, &accounts, &cfg());
        let first = run();
        let second = run();
        assert_eq!(first.len(), second.len());
        for (x, y) in first.iter().zip(&second) {
            assert_eq!(x.expected_value, y.expected_value);
            assert_eq!(x.deviation, y.deviation);
            assert_eq!(x.exceeds_band, y.exceeds_band);
        }
    }

    #[test]
    fn empty_is_safe() {
        let mut generator = ExternalExpectationsGenerator::new(3);
        // No accounts at all.
        assert!(generator.generate("1000", 2024, &[], &cfg()).is_empty());
        // A single account whose legitimate total is zero.
        let zero_total = [acct("4000", 0, 0)];
        assert!(generator
            .generate("1000", 2024, &zero_total, &cfg())
            .is_empty());
    }
}