use chrono::{Datelike, NaiveDate};
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use datasynth_core::accounts::{expense_accounts, tax_accounts};
use datasynth_core::models::deferred_tax::{
DeferredTaxRollforward, DeferredTaxType, PermanentDifference, TaxRateReconciliation,
TemporaryDifference,
};
use datasynth_core::models::journal_entry::{
BusinessProcess, JournalEntry, JournalEntryHeader, JournalEntryLine, TransactionSource,
};
use datasynth_core::utils::seeded_rng;
/// Aggregated output of one [`DeferredTaxGenerator::generate`] run.
///
/// Each collection accumulates the records produced for every company passed
/// to `generate`, in the order the companies were supplied.
#[derive(Debug, Clone, Default)]
pub struct DeferredTaxSnapshot {
/// Temporary (book vs. tax basis) differences for all companies.
pub temporary_differences: Vec<TemporaryDifference>,
/// One effective-tax-rate reconciliation per company.
pub etr_reconciliations: Vec<TaxRateReconciliation>,
/// One deferred tax rollforward per company.
pub rollforwards: Vec<DeferredTaxRollforward>,
/// Deferred tax journal entries (at most one DTL and one DTA entry per company).
pub journal_entries: Vec<JournalEntry>,
}
/// Seeded generator of deferred tax data: temporary differences, ETR
/// reconciliations, rollforwards and the related journal entries.
pub struct DeferredTaxGenerator {
/// Random source for all sampled amounts and template selections.
rng: ChaCha8Rng,
/// Running counter; bumped for every temporary difference (used in its id)
/// and for every journal entry that is built.
counter: u64,
}
impl DeferredTaxGenerator {
/// Creates a generator whose random stream is derived from `seed`.
///
/// The RNG comes from `seeded_rng(seed, 88)`; the second argument is
/// presumably a per-generator stream discriminator (confirm against
/// `seeded_rng`). The id counter starts at zero.
pub fn new(seed: u64) -> Self {
    let rng = seeded_rng(seed, 88);
    Self { rng, counter: 0 }
}
/// Generates deferred tax data for every `(company_code, country_code)` pair.
///
/// For each company the statutory rate is looked up from the country code,
/// pre-tax income, total assets and revenue are derived from the company's
/// lines in `journal_entries` (assets are floored at 1,000,000 and revenue at
/// 500,000), and then temporary differences, an ETR reconciliation, a
/// rollforward and the deferred tax journal entries are produced.
///
/// Every country code other than `US` (case-insensitive) is treated as an
/// IFRS reporter when labelling the originating standards. The period label
/// is `FY<year of posting_date>`.
pub fn generate(
    &mut self,
    companies: &[(&str, &str)],
    posting_date: NaiveDate,
    journal_entries: &[JournalEntry],
) -> DeferredTaxSnapshot {
    let asset_floor = dec!(1_000_000);
    let revenue_floor = dec!(500_000);
    // Identical for every company, so it is built once up front.
    let period_label = format!("FY{}", posting_date.year());

    let mut out = DeferredTaxSnapshot::default();
    for (company_code, country_code) in companies.iter().copied() {
        let rate = Self::statutory_rate(country_code);
        let pre_tax_income = self.estimate_pre_tax_income(company_code, journal_entries);
        let asset_base =
            Self::compute_total_assets(company_code, journal_entries).max(asset_floor);
        let revenue_base =
            Self::compute_total_revenue(company_code, journal_entries).max(revenue_floor);
        let is_ifrs = country_code.to_uppercase() != "US";

        // The order of the RNG-consuming calls below is part of the
        // deterministic output and must not be rearranged.
        let temp_diffs = self.generate_temp_diffs_with_framework(
            company_code,
            pre_tax_income,
            asset_base,
            revenue_base,
            is_ifrs,
        );
        let (dta, dtl) = compute_dta_dtl(&temp_diffs, rate);
        let reconciliation =
            self.build_etr_reconciliation(company_code, &period_label, pre_tax_income, rate);
        let rollforward = self.build_rollforward(company_code, &period_label, dta, dtl);
        let entries = self.build_journal_entries(company_code, posting_date, dta, dtl);

        out.temporary_differences.extend(temp_diffs);
        out.etr_reconciliations.push(reconciliation);
        out.rollforwards.push(rollforward);
        out.journal_entries.extend(entries);
    }
    out
}
fn generate_temp_diffs_with_framework(
&mut self,
entity_code: &str,
_pre_tax_income: Decimal,
total_assets: Decimal,
revenue_proxy: Decimal,
is_ifrs: bool,
) -> Vec<TemporaryDifference> {
let std_depreciation = if is_ifrs { "IAS 16" } else { "ASC 360" };
let std_accruals = if is_ifrs { "IAS 37" } else { "ASC 450" };
let std_receivables = if is_ifrs { "IFRS 9" } else { "ASC 310" };
let std_inventory = if is_ifrs { "IAS 2" } else { "ASC 330" };
let std_leases = if is_ifrs { "IFRS 16" } else { "ASC 842" };
let std_warranty = if is_ifrs { "IAS 37" } else { "ASC 460" };
let std_sbc = if is_ifrs { "IFRS 2" } else { "ASC 718" };
let std_rd = if is_ifrs { "IAS 38" } else { "ASC 730" };
let templates: Vec<(&str, &str, DeferredTaxType, Option<&str>, Decimal, Decimal)> = vec![
(
"Accelerated depreciation (MACRS / capital allowances)",
datasynth_core::accounts::control_accounts::FIXED_ASSETS,
DeferredTaxType::Liability,
Some(std_depreciation),
total_assets * self.rand_decimal(dec!(0.08), dec!(0.12)),
total_assets * self.rand_decimal(dec!(0.05), dec!(0.09)),
),
(
"Accrued expenses (deductible when paid)",
datasynth_core::accounts::liability_accounts::ACCRUED_EXPENSES,
DeferredTaxType::Asset,
Some(std_accruals),
revenue_proxy * self.rand_decimal(dec!(0.01), dec!(0.03)),
Decimal::ZERO,
),
(
"Allowance for doubtful accounts",
datasynth_core::accounts::control_accounts::AR_CONTROL,
DeferredTaxType::Asset,
Some(std_receivables),
revenue_proxy * self.rand_decimal(dec!(0.005), dec!(0.015)),
Decimal::ZERO,
),
(
"Inventory write-down (LCM / NRV)",
datasynth_core::accounts::control_accounts::INVENTORY,
DeferredTaxType::Asset,
Some(std_inventory),
total_assets * self.rand_decimal(dec!(0.005), dec!(0.015)),
Decimal::ZERO,
),
(
"Right-of-use asset – operating lease (tax: rental deduction)",
expense_accounts::RENT,
DeferredTaxType::Liability,
Some(std_leases),
total_assets * self.rand_decimal(dec!(0.02), dec!(0.06)),
Decimal::ZERO,
),
(
"Warranty provision (deductible when paid)",
datasynth_core::accounts::liability_accounts::ACCRUED_EXPENSES,
DeferredTaxType::Asset,
Some(std_warranty),
revenue_proxy * self.rand_decimal(dec!(0.003), dec!(0.010)),
Decimal::ZERO,
),
(
"Share-based compensation (book > tax until exercise)",
datasynth_core::accounts::expense_accounts::BENEFITS,
DeferredTaxType::Asset,
Some(std_sbc),
revenue_proxy * self.rand_decimal(dec!(0.005), dec!(0.012)),
Decimal::ZERO,
),
(
"Capitalised development costs (expensed for tax)",
datasynth_core::accounts::control_accounts::FIXED_ASSETS,
DeferredTaxType::Liability,
Some(std_rd),
total_assets * self.rand_decimal(dec!(0.01), dec!(0.04)),
Decimal::ZERO,
),
];
let n = self.rng.random_range(5usize..=8);
let mut indices: Vec<usize> = (0..templates.len()).collect();
indices.shuffle(&mut self.rng);
indices.truncate(n);
indices.sort();
indices
.iter()
.map(|&i| {
let (desc, account, dtype, standard, book, tax) = &templates[i];
let book = book.round_dp(2);
let tax = tax.round_dp(2);
let difference = (book - tax).round_dp(2);
self.counter += 1;
TemporaryDifference {
id: format!("TDIFF-{entity_code}-{:05}", self.counter),
entity_code: entity_code.to_string(),
account: account.to_string(),
description: desc.to_string(),
book_basis: book,
tax_basis: tax,
difference,
deferred_type: *dtype,
originating_standard: standard.map(|s| s.to_string()),
}
})
.collect()
}
fn build_etr_reconciliation(
&mut self,
entity_code: &str,
period: &str,
pre_tax_income: Decimal,
statutory_rate: Decimal,
) -> TaxRateReconciliation {
let expected_tax = (pre_tax_income * statutory_rate).round_dp(2);
let perm_diff_templates: Vec<(&str, Decimal, Decimal)> = vec![
(
"Meals & entertainment (50% non-deductible)",
pre_tax_income * self.rand_decimal(dec!(0.002), dec!(0.006)),
pre_tax_income * self.rand_decimal(dec!(0.001), dec!(0.003)) * statutory_rate,
),
(
"Tax-exempt municipal bond interest",
-(pre_tax_income * self.rand_decimal(dec!(0.005), dec!(0.015))),
-(pre_tax_income * self.rand_decimal(dec!(0.005), dec!(0.015)) * statutory_rate),
),
(
"Non-deductible fines & penalties",
pre_tax_income * self.rand_decimal(dec!(0.001), dec!(0.003)),
pre_tax_income * self.rand_decimal(dec!(0.001), dec!(0.003)) * statutory_rate,
),
(
"Research & development tax credits",
-(pre_tax_income * self.rand_decimal(dec!(0.005), dec!(0.020))),
-(pre_tax_income * self.rand_decimal(dec!(0.005), dec!(0.020))),
),
(
"Stock-based compensation – excess tax benefit",
-(pre_tax_income * self.rand_decimal(dec!(0.002), dec!(0.008))),
-(pre_tax_income * self.rand_decimal(dec!(0.002), dec!(0.008)) * statutory_rate),
),
(
"Foreign-derived intangible income (FDII) deduction",
-(pre_tax_income * self.rand_decimal(dec!(0.003), dec!(0.010))),
-(pre_tax_income * self.rand_decimal(dec!(0.003), dec!(0.010)) * statutory_rate),
),
(
"Officer compensation in excess of §162(m) limit",
pre_tax_income * self.rand_decimal(dec!(0.001), dec!(0.004)),
pre_tax_income * self.rand_decimal(dec!(0.001), dec!(0.004)) * statutory_rate,
),
];
let n = self.rng.random_range(3usize..=5);
let mut indices: Vec<usize> = (0..perm_diff_templates.len()).collect();
indices.shuffle(&mut self.rng);
indices.truncate(n);
indices.sort();
let permanent_differences: Vec<PermanentDifference> = indices
.iter()
.map(|&i| {
let (desc, amount, tax_effect) = &perm_diff_templates[i];
PermanentDifference {
description: desc.to_string(),
amount: amount.round_dp(2),
tax_effect: tax_effect.round_dp(2),
}
})
.collect();
let total_perm_effect: Decimal = permanent_differences.iter().map(|p| p.tax_effect).sum();
let actual_tax = (expected_tax + total_perm_effect).round_dp(2);
let effective_rate = if pre_tax_income != Decimal::ZERO {
(actual_tax / pre_tax_income).round_dp(6)
} else {
statutory_rate
};
TaxRateReconciliation {
entity_code: entity_code.to_string(),
period: period.to_string(),
pre_tax_income: pre_tax_income.round_dp(2),
statutory_rate,
expected_tax,
permanent_differences,
effective_rate,
actual_tax,
}
}
/// Builds a rollforward for an entity with no prior-period balances.
///
/// Delegates to `build_rollforward_with_prior` with both prior closing
/// balances absent.
fn build_rollforward(
    &mut self,
    entity_code: &str,
    period: &str,
    closing_dta: Decimal,
    closing_dtl: Decimal,
) -> DeferredTaxRollforward {
    let no_prior_balance: Option<Decimal> = None;
    self.build_rollforward_with_prior(
        entity_code,
        period,
        closing_dta,
        closing_dtl,
        no_prior_balance,
        no_prior_balance,
    )
}
/// Builds a deferred tax rollforward from closing balances and optional
/// prior-period closing balances.
///
/// Missing prior balances are treated as zero openings. All four balances are
/// rounded to two decimals *before* the movement is derived, so the stored
/// fields always satisfy
/// `current_year_movement == (closing_dta - opening_dta) - (closing_dtl - opening_dtl)`
/// exactly, even when the supplied balances carry more than two decimals.
fn build_rollforward_with_prior(
    &mut self,
    entity_code: &str,
    period: &str,
    closing_dta: Decimal,
    closing_dtl: Decimal,
    prior_closing_dta: Option<Decimal>,
    prior_closing_dtl: Option<Decimal>,
) -> DeferredTaxRollforward {
    let opening_dta = prior_closing_dta.unwrap_or(Decimal::ZERO).round_dp(2);
    let opening_dtl = prior_closing_dtl.unwrap_or(Decimal::ZERO).round_dp(2);
    let closing_dta = closing_dta.round_dp(2);
    let closing_dtl = closing_dtl.round_dp(2);

    // Net movement = change in DTA less change in DTL. Inputs are already at
    // two decimals, so the result needs no further rounding.
    let current_year_movement = (closing_dta - opening_dta) - (closing_dtl - opening_dtl);

    DeferredTaxRollforward {
        entity_code: entity_code.to_string(),
        period: period.to_string(),
        opening_dta,
        opening_dtl,
        current_year_movement,
        closing_dta,
        closing_dtl,
    }
}
/// Builds the deferred tax journal entries for one company.
///
/// * If `dtl` is positive: debit tax expense, credit deferred tax liability.
/// * If `dta` is positive: debit deferred tax asset, credit tax expense.
///
/// The liability entry (when present) precedes the asset entry.
/// `self.counter` is incremented once per entry built, which keeps the
/// generator's id sequence unchanged for subsequent records.
fn build_journal_entries(
    &mut self,
    company_code: &str,
    posting_date: NaiveDate,
    dta: Decimal,
    dtl: Decimal,
) -> Vec<JournalEntry> {
    /// Builds one balanced two-line `TAX_DEFERRED` entry for `amount`.
    fn two_line_entry(
        company_code: &str,
        posting_date: NaiveDate,
        balance_kind: &str,
        debit_account: String,
        credit_account: String,
        amount: Decimal,
    ) -> JournalEntry {
        let mut header = JournalEntryHeader::new(company_code.to_string(), posting_date);
        header.document_type = "TAX_DEFERRED".to_string();
        header.created_by = "DEFERRED_TAX_ENGINE".to_string();
        header.source = TransactionSource::Automated;
        header.business_process = Some(BusinessProcess::R2R);
        header.header_text = Some(format!(
            "Deferred tax {balance_kind} – period {}",
            posting_date.format("%Y-%m")
        ));
        // Copy the id out before the header is moved into the entry.
        let doc_id = header.document_id;
        let mut je = JournalEntry::new(header);
        je.add_line(JournalEntryLine::debit(doc_id, 1, debit_account, amount));
        je.add_line(JournalEntryLine::credit(doc_id, 2, credit_account, amount));
        je
    }

    let mut jes = Vec::with_capacity(2);
    if dtl > Decimal::ZERO {
        self.counter += 1;
        jes.push(two_line_entry(
            company_code,
            posting_date,
            "liability",
            tax_accounts::TAX_EXPENSE.to_string(),
            tax_accounts::DEFERRED_TAX_LIABILITY.to_string(),
            dtl,
        ));
    }
    if dta > Decimal::ZERO {
        self.counter += 1;
        jes.push(two_line_entry(
            company_code,
            posting_date,
            "asset",
            tax_accounts::DEFERRED_TAX_ASSET.to_string(),
            tax_accounts::TAX_EXPENSE.to_string(),
            dta,
        ));
    }
    jes
}
/// Returns the absolute net debit balance of the company's asset lines.
///
/// Only entries whose header company code equals `company_code` are
/// considered, and only lines whose GL account is categorised as
/// `AccountCategory::Asset`. Returns zero when nothing matches.
fn compute_total_assets(company_code: &str, journal_entries: &[JournalEntry]) -> Decimal {
    use datasynth_core::accounts::AccountCategory;

    let net_asset_movement = journal_entries
        .iter()
        .filter(|entry| entry.header.company_code == company_code)
        .flat_map(|entry| entry.lines.iter())
        .filter(|line| {
            matches!(
                AccountCategory::from_account(&line.gl_account),
                AccountCategory::Asset
            )
        })
        .fold(Decimal::ZERO, |running, line| {
            running + (line.debit_amount - line.credit_amount)
        });

    net_asset_movement.abs()
}
/// Returns the company's net credit balance on revenue lines, floored at zero.
///
/// Only entries whose header company code equals `company_code` are
/// considered, and only lines whose GL account is categorised as
/// `AccountCategory::Revenue`.
fn compute_total_revenue(company_code: &str, journal_entries: &[JournalEntry]) -> Decimal {
    use datasynth_core::accounts::AccountCategory;

    let net_revenue = journal_entries
        .iter()
        .filter(|entry| entry.header.company_code == company_code)
        .flat_map(|entry| entry.lines.iter())
        .filter(|line| {
            matches!(
                AccountCategory::from_account(&line.gl_account),
                AccountCategory::Revenue
            )
        })
        .fold(Decimal::ZERO, |running, line| {
            running + (line.credit_amount - line.debit_amount)
        });

    net_revenue.max(Decimal::ZERO)
}
/// Looks up the statutory tax rate used for a country code.
///
/// The code is upper-cased before the lookup; unknown codes fall back to
/// 0.21, the same rate used for `US`.
fn statutory_rate(country_code: &str) -> Decimal {
    let normalized = country_code.to_uppercase();
    let rate_table = [
        ("US", dec!(0.21)),
        ("DE", dec!(0.30)),
        ("GB", dec!(0.25)),
        ("UK", dec!(0.25)),
        ("FR", dec!(0.25)),
        ("NL", dec!(0.258)),
        ("IE", dec!(0.125)),
        ("CH", dec!(0.15)),
        ("CA", dec!(0.265)),
        ("AU", dec!(0.30)),
        ("JP", dec!(0.2928)),
        ("SG", dec!(0.17)),
        ("CN", dec!(0.25)),
        ("IN", dec!(0.2517)),
        ("BR", dec!(0.34)),
    ];
    rate_table
        .iter()
        .find(|(code, _)| *code == normalized.as_str())
        .map(|&(_, rate)| rate)
        .unwrap_or(dec!(0.21))
}
/// Estimates pre-tax income for a company from its journal entry lines.
///
/// Revenue lines contribute credits less debits; COGS, operating expense and
/// other income/expense lines contribute debits less credits as expenses.
/// Lines in any other category are ignored. The result is rounded to two
/// decimals; if it is exactly zero (for example when there are no matching
/// entries) a placeholder of 1,000,000 is returned instead.
fn estimate_pre_tax_income(
    &self,
    company_code: &str,
    journal_entries: &[JournalEntry],
) -> Decimal {
    use datasynth_core::accounts::AccountCategory;

    let (mut income, mut costs) = (Decimal::ZERO, Decimal::ZERO);
    let company_lines = journal_entries
        .iter()
        .filter(|entry| entry.header.company_code == company_code)
        .flat_map(|entry| entry.lines.iter());

    for line in company_lines {
        match AccountCategory::from_account(&line.gl_account) {
            AccountCategory::Revenue => {
                income += line.credit_amount;
                income -= line.debit_amount;
            }
            AccountCategory::Cogs
            | AccountCategory::OperatingExpense
            | AccountCategory::OtherIncomeExpense => {
                costs += line.debit_amount;
                costs -= line.credit_amount;
            }
            _ => {}
        }
    }

    let pre_tax = (income - costs).round_dp(2);
    if pre_tax.is_zero() {
        dec!(1_000_000)
    } else {
        pre_tax
    }
}
/// Draws a value between `min` and `max`, rounded to six decimals.
///
/// The bounds are converted to `f64` through their string form (0.0 if that
/// parse fails), one `f64` is drawn from the RNG to interpolate between them,
/// and the result is converted back to `Decimal`. If that conversion fails,
/// `min` is used instead.
fn rand_decimal(&mut self, min: Decimal, max: Decimal) -> Decimal {
    let as_f64 = |value: Decimal| value.to_string().parse::<f64>().unwrap_or(0.0);
    let span = as_f64(max - min);
    let floor = as_f64(min);
    let unit: f64 = self.rng.random();
    let sample = floor + unit * span;
    match Decimal::try_from(sample) {
        Ok(drawn) => drawn.round_dp(6),
        Err(_) => min.round_dp(6),
    }
}
}
/// Computes total deferred tax asset and liability balances.
///
/// For each temporary difference the tax effect is
/// `|difference| * statutory_rate`, rounded to two decimals, and is added to
/// the asset or liability total according to the item's `deferred_type`.
///
/// Returns `(dta, dtl)`, each rounded to two decimals; both are zero for an
/// empty slice.
pub fn compute_dta_dtl(
    diffs: &[TemporaryDifference],
    statutory_rate: Decimal,
) -> (Decimal, Decimal) {
    let (asset_total, liability_total) = diffs.iter().fold(
        (Decimal::ZERO, Decimal::ZERO),
        |(assets, liabilities), diff| {
            let tax_effect = (diff.difference.abs() * statutory_rate).round_dp(2);
            match diff.deferred_type {
                DeferredTaxType::Asset => (assets + tax_effect, liabilities),
                DeferredTaxType::Liability => (assets, liabilities + tax_effect),
            }
        },
    );
    (asset_total.round_dp(2), liability_total.round_dp(2))
}
// Unit tests for the deferred tax generator. Every generator test passes an
// empty journal entry slice, so only the generator's own invariants are
// checked, not any derivation from real ledger data.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
use rust_decimal_macros::dec;
/// Fixed posting date shared by all tests (2024-12-31).
fn sample_date() -> NaiveDate {
NaiveDate::from_ymd_opt(2024, 12, 31).unwrap()
}
/// Each company yields at least five temporary differences plus exactly one
/// ETR reconciliation and one rollforward.
#[test]
fn test_generate_returns_data_for_each_company() {
let mut gen = DeferredTaxGenerator::new(42);
let companies = vec![("C001", "US"), ("C002", "DE")];
let snapshot = gen.generate(&companies, sample_date(), &[]);
assert!(
snapshot.temporary_differences.len() >= 2 * 5,
"Expected ≥10 temp diffs, got {}",
snapshot.temporary_differences.len()
);
assert_eq!(snapshot.etr_reconciliations.len(), 2);
assert_eq!(snapshot.rollforwards.len(), 2);
}
/// `compute_dta_dtl` applies the rate to each difference and buckets the
/// result by deferred type.
#[test]
fn test_dta_dtl_computation() {
let diffs = vec![
TemporaryDifference {
id: "T1".into(),
entity_code: "C001".into(),
account: "1500".into(),
description: "Depreciation".into(),
book_basis: dec!(100_000),
tax_basis: dec!(80_000),
difference: dec!(20_000),
deferred_type: DeferredTaxType::Liability,
originating_standard: None,
},
TemporaryDifference {
id: "T2".into(),
entity_code: "C001".into(),
account: "2200".into(),
description: "Accruals".into(),
book_basis: dec!(50_000),
tax_basis: Decimal::ZERO,
difference: dec!(50_000),
deferred_type: DeferredTaxType::Asset,
originating_standard: None,
},
];
let (dta, dtl) = compute_dta_dtl(&diffs, dec!(0.21));
assert_eq!(dtl, (dec!(20_000) * dec!(0.21)).round_dp(2));
assert_eq!(dta, (dec!(50_000) * dec!(0.21)).round_dp(2));
}
/// The stored ETR fields can be re-derived from one another:
/// expected tax, actual tax and effective rate.
#[test]
fn test_etr_reconciliation_math() {
let mut gen = DeferredTaxGenerator::new(7);
let companies = vec![("C001", "US")];
let snap = gen.generate(&companies, sample_date(), &[]);
let etr = &snap.etr_reconciliations[0];
let expected_tax = (etr.pre_tax_income * etr.statutory_rate).round_dp(2);
assert_eq!(etr.expected_tax, expected_tax, "expected_tax mismatch");
let total_perm: Decimal = etr.permanent_differences.iter().map(|p| p.tax_effect).sum();
let expected_actual = (expected_tax + total_perm).round_dp(2);
assert_eq!(etr.actual_tax, expected_actual, "actual_tax mismatch");
if etr.pre_tax_income != Decimal::ZERO {
let expected_etr = (etr.actual_tax / etr.pre_tax_income).round_dp(6);
assert_eq!(etr.effective_rate, expected_etr, "effective_rate mismatch");
}
}
/// The stored movement equals the change in DTA less the change in DTL.
#[test]
fn test_rollforward_opening_plus_movement_equals_closing() {
let mut gen = DeferredTaxGenerator::new(13);
let snap = gen.generate(&[("C001", "GB")], sample_date(), &[]);
let rf = &snap.rollforwards[0];
let implied_movement =
(rf.closing_dta - rf.opening_dta) - (rf.closing_dtl - rf.opening_dtl);
assert_eq!(
rf.current_year_movement, implied_movement,
"Rollforward movement check failed"
);
}
/// Every generated journal entry has equal total debits and credits.
#[test]
fn test_journal_entries_are_balanced() {
let mut gen = DeferredTaxGenerator::new(42);
let snap = gen.generate(&[("C001", "US")], sample_date(), &[]);
for je in &snap.journal_entries {
let total_debit: Decimal = je.lines.iter().map(|l| l.debit_amount).sum();
let total_credit: Decimal = je.lines.iter().map(|l| l.credit_amount).sum();
assert_eq!(
total_debit, total_credit,
"JE {} is not balanced: debits={}, credits={}",
je.header.document_id, total_debit, total_credit
);
}
}
/// Every generated journal entry carries a document type containing "TAX".
#[test]
fn test_journal_entries_have_tax_document_type() {
let mut gen = DeferredTaxGenerator::new(42);
let snap = gen.generate(&[("C001", "US"), ("C002", "DE")], sample_date(), &[]);
for je in &snap.journal_entries {
assert!(
je.header.document_type.contains("TAX"),
"Expected document_type to contain 'TAX', got '{}'",
je.header.document_type
);
}
}
/// Two generators built from the same seed produce matching output
/// (spot-checked on count, actual tax and closing DTA).
#[test]
fn test_deterministic() {
let companies = vec![("C001", "US")];
let mut gen1 = DeferredTaxGenerator::new(99);
let snap1 = gen1.generate(&companies, sample_date(), &[]);
let mut gen2 = DeferredTaxGenerator::new(99);
let snap2 = gen2.generate(&companies, sample_date(), &[]);
assert_eq!(
snap1.temporary_differences.len(),
snap2.temporary_differences.len()
);
assert_eq!(
snap1.etr_reconciliations[0].actual_tax,
snap2.etr_reconciliations[0].actual_tax
);
assert_eq!(
snap1.rollforwards[0].closing_dta,
snap2.rollforwards[0].closing_dta
);
}
}