use chrono::NaiveDate;
use datasynth_config::schema::PayrollConfig;
use datasynth_core::country::schema::TaxBracket;
use datasynth_core::models::{PayrollLineItem, PayrollRun, PayrollRunStatus};
use datasynth_core::utils::{sample_decimal_range, seeded_rng};
use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
use datasynth_core::CountryPack;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rust_decimal::Decimal;
use tracing::debug;
/// Effective deduction rates used while generating one payroll run.
///
/// Built either from the generator's [`PayrollConfig`] (`rates_from_config`) or
/// from a [`CountryPack`] (`rates_from_country_pack`).
#[derive(Debug, Clone)]
struct PayrollRates {
    /// Flat income-tax rate applied to monthly gross pay when
    /// `income_tax_brackets` is empty.
    income_tax_rate: Decimal,
    /// Annual progressive brackets. When non-empty, monthly tax is derived by
    /// annualising gross pay (x12), applying the brackets, and dividing by 12.
    income_tax_brackets: Vec<TaxBracket>,
    /// Employee social-security rate applied to gross pay.
    fica_rate: Decimal,
    /// Health-insurance rate, applied only to employees drawn as benefits-enrolled.
    health_rate: Decimal,
    /// Retirement rate, applied only to employees drawn as retirement participants.
    retirement_rate: Decimal,
    /// Employer-side rate applied to gross pay; employer cost is gross plus this share.
    employer_fica_rate: Decimal,
}
/// Optional human-readable names for each deduction category.
///
/// Copied verbatim onto every generated line item. All fields are `None` when
/// no country pack is in use (the `Default` value).
#[derive(Debug, Clone, Default)]
struct DeductionLabels {
    /// Label(s) for income-tax withholding; several are joined with "; ".
    tax_withholding: Option<String>,
    /// Label(s) for social-security style deductions; several are joined with "; ".
    social_security: Option<String>,
    /// Label for the health-insurance deduction.
    health_insurance: Option<String>,
    /// Label for the retirement / pension deduction.
    retirement_contribution: Option<String>,
    /// Joined labels of the pack's employer contributions.
    employer_contribution: Option<String>,
}
/// Deterministic generator of payroll runs and their per-employee line items.
///
/// All randomness comes from a seeded RNG and deterministic UUID factories, so
/// the same seed and inputs reproduce the same output.
pub struct PayrollGenerator {
    /// Seeded random source for overtime, bonus, enrollment and run-status draws.
    rng: ChaCha8Rng,
    /// Produces the payroll run id.
    uuid_factory: DeterministicUuidFactory,
    /// Produces line-item ids (configured with sub-discriminator 1).
    line_uuid_factory: DeterministicUuidFactory,
    /// Tax-rate and enrollment/participation settings.
    config: PayrollConfig,
    /// When set, `generate` derives rates and labels from this pack.
    country_pack: Option<CountryPack>,
    /// When non-empty, approver/poster ids are drawn from this pool.
    employee_ids_pool: Vec<String>,
    /// NOTE(review): stored by `with_pools` but not read by any method visible
    /// in this file — confirm whether it is still needed.
    cost_center_ids_pool: Vec<String>,
}
impl PayrollGenerator {
pub fn new(seed: u64) -> Self {
Self {
rng: seeded_rng(seed, 0),
uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
seed,
GeneratorType::PayrollRun,
1,
),
config: PayrollConfig::default(),
country_pack: None,
employee_ids_pool: Vec::new(),
cost_center_ids_pool: Vec::new(),
}
}
/// Creates a generator seeded with `seed` that uses the supplied `config`.
///
/// No country pack is installed and both id pools start empty.
pub fn with_config(seed: u64, config: PayrollConfig) -> Self {
    let rng = seeded_rng(seed, 0);
    let uuid_factory = DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun);
    // Line items get their own factory of the same generator type,
    // configured with sub-discriminator 1.
    let line_uuid_factory =
        DeterministicUuidFactory::with_sub_discriminator(seed, GeneratorType::PayrollRun, 1);
    Self {
        rng,
        uuid_factory,
        line_uuid_factory,
        config,
        country_pack: None,
        employee_ids_pool: Vec::new(),
        cost_center_ids_pool: Vec::new(),
    }
}
pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
self.employee_ids_pool = employee_ids;
self.cost_center_ids_pool = cost_center_ids;
self
}
/// Installs `pack` so that later [`Self::generate`] calls derive rates and
/// labels from it. Any previously installed pack is discarded.
pub fn set_country_pack(&mut self, pack: CountryPack) {
    // The previous pack (if any) is dropped at the end of this scope.
    let _replaced = self.country_pack.replace(pack);
}
/// Generates one payroll run for `employees` over `period_start..=period_end`.
///
/// Each employee tuple is `(employee_id, annual_base_salary, cost_center, department)`.
/// Rates and labels come from the installed country pack when one was set via
/// [`Self::set_country_pack`]. Otherwise rates come from the config and no
/// labels are attached.
pub fn generate(
    &mut self,
    company_code: &str,
    employees: &[(String, Decimal, Option<String>, Option<String>)],
    period_start: NaiveDate,
    period_end: NaiveDate,
    currency: &str,
) -> (PayrollRun, Vec<PayrollLineItem>) {
    debug!(company_code, employee_count = employees.len(), %period_start, %period_end, currency, "Generating payroll run");
    let (rates, labels) = match self.country_pack.as_ref() {
        Some(pack) => (
            self.rates_from_country_pack(pack),
            Self::labels_from_country_pack(pack),
        ),
        None => (self.rates_from_config(), DeductionLabels::default()),
    };
    self.generate_with_rates_and_labels(
        company_code,
        employees,
        period_start,
        period_end,
        currency,
        &rates,
        &labels,
    )
}
/// Generates one payroll run using rates and labels from an explicitly passed `pack`.
///
/// Any pack installed via [`Self::set_country_pack`] is ignored for this call.
/// Unlike [`Self::generate`], no debug log line is emitted.
pub fn generate_with_country_pack(
    &mut self,
    company_code: &str,
    employees: &[(String, Decimal, Option<String>, Option<String>)],
    period_start: NaiveDate,
    period_end: NaiveDate,
    currency: &str,
    pack: &CountryPack,
) -> (PayrollRun, Vec<PayrollLineItem>) {
    let pack_labels = Self::labels_from_country_pack(pack);
    let pack_rates = self.rates_from_country_pack(pack);
    self.generate_with_rates_and_labels(
        company_code,
        employees,
        period_start,
        period_end,
        currency,
        &pack_rates,
        &pack_labels,
    )
}
/// Builds flat rates from the generator's config, with no progressive brackets.
///
/// Income tax is federal plus state effective rate. Health and retirement use
/// fixed 3% and 5% rates. The employer share mirrors the employee FICA rate.
/// Non-finite config values fall back to zero.
fn rates_from_config(&self) -> PayrollRates {
    let as_decimal = |value: f64| Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO);
    let tax = &self.config.tax_rates;
    let fica_rate = as_decimal(tax.fica);
    PayrollRates {
        income_tax_rate: as_decimal(tax.federal_effective) + as_decimal(tax.state_effective),
        income_tax_brackets: Vec::new(),
        fica_rate,
        health_rate: as_decimal(0.03),
        retirement_rate: as_decimal(0.05),
        employer_fica_rate: fica_rate,
    }
}
/// Computes the annual tax owed on `annual_income` under marginal `brackets`.
///
/// Each bracket taxes the slice of income between its floor and its top:
/// - The floor is `above`, or the previous bracket's ceiling when `above` is absent.
/// - The top is `up_to`, or unbounded when `up_to` is `None`.
///
/// Each slice's tax is rounded to cents before summing. Brackets are assumed
/// to be sorted ascending: iteration stops at the first bracket whose floor is
/// at or above the income.
fn compute_progressive_tax(annual_income: Decimal, brackets: &[TaxBracket]) -> Decimal {
    let mut accumulated = Decimal::ZERO;
    let mut previous_ceiling = Decimal::ZERO;
    for bracket in brackets {
        let floor = match bracket.above.and_then(Decimal::from_f64_retain) {
            Some(explicit_floor) => explicit_floor,
            None => previous_ceiling,
        };
        if annual_income <= floor {
            break;
        }
        let marginal_rate = Decimal::from_f64_retain(bracket.rate).unwrap_or(Decimal::ZERO);
        let top = match bracket.up_to {
            Some(limit) => {
                annual_income.min(Decimal::from_f64_retain(limit).unwrap_or(Decimal::ZERO))
            }
            None => annual_income,
        };
        let slice = (top - floor).max(Decimal::ZERO);
        accumulated += (slice * marginal_rate).round_dp(2);
        previous_ceiling = match bracket.up_to.and_then(Decimal::from_f64_retain) {
            Some(ceiling) => ceiling,
            None => annual_income,
        };
    }
    accumulated.round_dp(2)
}
fn rates_from_country_pack(&self, pack: &CountryPack) -> PayrollRates {
let fallback = self.rates_from_config();
let mut federal_tax = Decimal::ZERO;
let mut state_tax = Decimal::ZERO;
let mut fica = Decimal::ZERO;
let mut health = Decimal::ZERO;
let mut retirement = Decimal::ZERO;
let mut found_federal = false;
let mut found_state = false;
let mut found_fica = false;
let mut found_health = false;
let mut found_retirement = false;
for ded in &pack.payroll.statutory_deductions {
let code_upper = ded.code.to_uppercase();
let name_en_lower = ded.name_en.to_lowercase();
let rate = Decimal::from_f64_retain(ded.rate).unwrap_or(Decimal::ZERO);
if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
&& ded.rate == 0.0
{
if code_upper == "FIT"
|| code_upper == "LOHNST"
|| (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
{
found_federal = true;
}
continue;
}
if code_upper == "FIT"
|| code_upper == "LOHNST"
|| (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
{
federal_tax += rate;
found_federal = true;
} else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
state_tax += rate;
found_state = true;
} else if code_upper == "FICA" || name_en_lower.contains("social security") {
fica += rate;
found_fica = true;
} else if name_en_lower.contains("health insurance") {
health += rate;
found_health = true;
} else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
retirement += rate;
found_retirement = true;
} else {
fica += rate;
found_fica = true;
}
}
PayrollRates {
income_tax_rate: if found_federal || found_state {
let f = if found_federal {
federal_tax
} else {
fallback.income_tax_rate
- Decimal::from_f64_retain(self.config.tax_rates.state_effective)
.unwrap_or(Decimal::ZERO)
};
let s = if found_state {
state_tax
} else {
Decimal::from_f64_retain(self.config.tax_rates.state_effective)
.unwrap_or(Decimal::ZERO)
};
f + s
} else {
fallback.income_tax_rate
},
income_tax_brackets: pack.tax.payroll_tax.income_tax_brackets.clone(),
fica_rate: if found_fica { fica } else { fallback.fica_rate },
health_rate: if found_health {
health
} else {
fallback.health_rate
},
retirement_rate: if found_retirement {
retirement
} else {
fallback.retirement_rate
},
employer_fica_rate: if found_fica {
fica
} else {
fallback.employer_fica_rate
},
}
}
/// Builds human-readable deduction labels from a country pack.
///
/// Each deduction's label is its localized `name`, or `name_en` when `name` is
/// empty. Deductions with no label at all are skipped. Classification mirrors
/// `rates_from_country_pack`, and every category appends its labels joined with
/// "; ". This matches the rate side, which sums all deductions in a category.
///
/// Previously:
/// - The federal progressive label was dropped if a state label was seen first.
/// - Health and retirement kept only the first label.
///
/// Progressive rate-0 markers are labelled only when they are income taxes.
/// Employer contribution labels are joined in pack order.
fn labels_from_country_pack(pack: &CountryPack) -> DeductionLabels {
    /// Appends `label` to `slot`, separating from an existing label with "; ".
    fn push_label(slot: &mut Option<String>, label: String) {
        *slot = Some(match slot.take() {
            Some(existing) => format!("{existing}; {label}"),
            None => label,
        });
    }

    let mut labels = DeductionLabels::default();
    for ded in &pack.payroll.statutory_deductions {
        let label = if ded.name.is_empty() {
            ded.name_en.clone()
        } else {
            ded.name.clone()
        };
        if label.is_empty() {
            continue;
        }
        let code_upper = ded.code.to_uppercase();
        let name_en_lower = ded.name_en.to_lowercase();
        // Federal and state income taxes both land in the withholding label.
        let is_income_tax = code_upper == "FIT"
            || code_upper == "LOHNST"
            || code_upper == "SIT"
            || name_en_lower.contains("income tax");
        let is_progressive_marker = (ded.deduction_type == "progressive"
            || ded.type_field == "progressive")
            && ded.rate == 0.0;
        if is_progressive_marker {
            if is_income_tax {
                push_label(&mut labels.tax_withholding, label);
            }
            continue;
        }
        let slot = if is_income_tax {
            &mut labels.tax_withholding
        } else if code_upper == "FICA" || name_en_lower.contains("social security") {
            &mut labels.social_security
        } else if name_en_lower.contains("health insurance") {
            &mut labels.health_insurance
        } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
            &mut labels.retirement_contribution
        } else {
            // Unclassified deductions are grouped with social security, as on the rate side.
            &mut labels.social_security
        };
        push_label(slot, label);
    }

    let employer_names: Vec<String> = pack
        .payroll
        .employer_contributions
        .iter()
        .map(|c| {
            if c.name.is_empty() {
                c.name_en.clone()
            } else {
                c.name.clone()
            }
        })
        .filter(|name| !name.is_empty())
        .collect();
    if !employer_names.is_empty() {
        labels.employer_contribution = Some(employer_names.join("; "));
    }
    labels
}
fn generate_with_rates_and_labels(
&mut self,
company_code: &str,
employees: &[(String, Decimal, Option<String>, Option<String>)],
period_start: NaiveDate,
period_end: NaiveDate,
currency: &str,
rates: &PayrollRates,
labels: &DeductionLabels,
) -> (PayrollRun, Vec<PayrollLineItem>) {
let payroll_id = self.uuid_factory.next().to_string();
let mut line_items = Vec::with_capacity(employees.len());
let mut total_gross = Decimal::ZERO;
let mut total_deductions = Decimal::ZERO;
let mut total_net = Decimal::ZERO;
let mut total_employer_cost = Decimal::ZERO;
let benefits_enrolled = self.config.benefits_enrollment_rate;
let retirement_participating = self.config.retirement_participation_rate;
for (employee_id, base_salary, cost_center, department) in employees {
let line_id = self.line_uuid_factory.next().to_string();
let monthly_base = (*base_salary / Decimal::from(12)).round_dp(2);
let (overtime_pay, overtime_hours) = if self.rng.random_bool(0.10) {
let ot_hours = self.rng.random_range(1.0..=20.0);
let hourly_rate = *base_salary / Decimal::from(2080);
let ot_rate = hourly_rate * Decimal::from_f64_retain(1.5).unwrap_or(Decimal::ONE);
let ot_pay = (ot_rate
* Decimal::from_f64_retain(ot_hours).unwrap_or(Decimal::ZERO))
.round_dp(2);
(ot_pay, ot_hours)
} else {
(Decimal::ZERO, 0.0)
};
let bonus = if self.rng.random_bool(0.05) {
let pct = self.rng.random_range(0.01..=0.10);
(monthly_base * Decimal::from_f64_retain(pct).unwrap_or(Decimal::ZERO)).round_dp(2)
} else {
Decimal::ZERO
};
let gross_pay = monthly_base + overtime_pay + bonus;
let tax_withholding = if !rates.income_tax_brackets.is_empty() {
let annual = gross_pay * Decimal::from(12);
Self::compute_progressive_tax(annual, &rates.income_tax_brackets)
/ Decimal::from(12)
} else {
(gross_pay * rates.income_tax_rate).round_dp(2)
};
let social_security = (gross_pay * rates.fica_rate).round_dp(2);
let health_insurance = if self.rng.random_bool(benefits_enrolled) {
(gross_pay * rates.health_rate).round_dp(2)
} else {
Decimal::ZERO
};
let retirement_contribution = if self.rng.random_bool(retirement_participating) {
(gross_pay * rates.retirement_rate).round_dp(2)
} else {
Decimal::ZERO
};
let other_deductions = if self.rng.random_bool(0.03) {
sample_decimal_range(&mut self.rng, Decimal::from(50), Decimal::from(500))
.round_dp(2)
} else {
Decimal::ZERO
};
let total_ded = tax_withholding
+ social_security
+ health_insurance
+ retirement_contribution
+ other_deductions;
let net_pay = gross_pay - total_ded;
let hours_worked = 160.0;
let employer_contrib = (gross_pay * rates.employer_fica_rate).round_dp(2);
let employer_cost = gross_pay + employer_contrib;
total_gross += gross_pay;
total_deductions += total_ded;
total_net += net_pay;
total_employer_cost += employer_cost;
line_items.push(PayrollLineItem {
payroll_id: payroll_id.clone(),
employee_id: employee_id.clone(),
line_id,
gross_pay,
base_salary: monthly_base,
overtime_pay,
bonus,
tax_withholding,
social_security,
health_insurance,
retirement_contribution,
other_deductions,
net_pay,
hours_worked,
overtime_hours,
pay_date: period_end,
cost_center: cost_center.clone(),
department: department.clone(),
tax_withholding_label: labels.tax_withholding.clone(),
social_security_label: labels.social_security.clone(),
health_insurance_label: labels.health_insurance.clone(),
retirement_contribution_label: labels.retirement_contribution.clone(),
employer_contribution_label: labels.employer_contribution.clone(),
});
}
let status_roll: f64 = self.rng.random();
let status = if status_roll < 0.60 {
PayrollRunStatus::Posted
} else if status_roll < 0.85 {
PayrollRunStatus::Approved
} else if status_roll < 0.95 {
PayrollRunStatus::Calculated
} else {
PayrollRunStatus::Draft
};
let approved_by = if matches!(
status,
PayrollRunStatus::Approved | PayrollRunStatus::Posted
) {
if !self.employee_ids_pool.is_empty() {
let idx = self.rng.random_range(0..self.employee_ids_pool.len());
Some(self.employee_ids_pool[idx].clone())
} else {
Some(format!("USR-{:04}", self.rng.random_range(201..=400)))
}
} else {
None
};
let posted_by = if status == PayrollRunStatus::Posted {
if !self.employee_ids_pool.is_empty() {
let idx = self.rng.random_range(0..self.employee_ids_pool.len());
Some(self.employee_ids_pool[idx].clone())
} else {
Some(format!("USR-{:04}", self.rng.random_range(401..=500)))
}
} else {
None
};
let run = PayrollRun {
company_code: company_code.to_string(),
payroll_id: payroll_id.clone(),
pay_period_start: period_start,
pay_period_end: period_end,
run_date: period_end,
status,
total_gross,
total_deductions,
total_net,
total_employer_cost,
employee_count: employees.len() as u32,
currency: currency.to_string(),
posted_by,
approved_by,
};
(run, line_items)
}
/// Like [`Self::generate`], but first adjusts each employee's annual salary for
/// the salary-adjustment events in `changes` that affect this period.
pub fn generate_with_changes(
    &mut self,
    company_code: &str,
    employees: &[(String, Decimal, Option<String>, Option<String>)],
    period_start: NaiveDate,
    period_end: NaiveDate,
    currency: &str,
    changes: &[datasynth_core::models::EmployeeChangeEvent],
) -> (PayrollRun, Vec<PayrollLineItem>) {
    let mut adjusted = Vec::with_capacity(employees.len());
    for (id, salary, cc, dept) in employees {
        let effective_salary =
            Self::apply_salary_changes(id, *salary, period_start, period_end, changes);
        adjusted.push((id.clone(), effective_salary, cc.clone(), dept.clone()));
    }
    self.generate(company_code, &adjusted, period_start, period_end, currency)
}
/// Returns the annual salary to use for `employee_id` over
/// `period_start..=period_end`, applying `SalaryAdjustment` events from `changes`.
///
/// Events are applied in `effective_date` order, starting from `base_annual_salary`.
/// - Ignored: events effective after `period_end`, and events whose
///   `new_value` does not parse as a `Decimal`.
/// - The salary in force on `period_start` comes from the latest event dated
///   on or before it.
/// - Each event dated inside the period starts a new segment. The result is
///   the day-weighted average of the segments, rounded to cents.
/// - For events sharing a date, the one appearing later in `changes` wins.
///
/// Previously only the single latest event was considered. An earlier
/// adjustment that was already in force when a mid-period change landed was
/// ignored, and the pre-change days were prorated at `base_annual_salary`.
fn apply_salary_changes(
    employee_id: &str,
    base_annual_salary: Decimal,
    period_start: NaiveDate,
    period_end: NaiveDate,
    changes: &[datasynth_core::models::EmployeeChangeEvent],
) -> Decimal {
    use datasynth_core::models::EmployeeEventType;
    let mut timeline: Vec<(NaiveDate, Decimal)> = changes
        .iter()
        .filter(|c| {
            c.employee_id == employee_id
                && c.event_type == EmployeeEventType::SalaryAdjustment
                && c.effective_date <= period_end
        })
        .filter_map(|c| {
            c.new_value
                .as_deref()
                .and_then(|v| v.parse::<Decimal>().ok())
                .map(|salary| (c.effective_date, salary))
        })
        .collect();
    if timeline.is_empty() {
        return base_annual_salary;
    }
    // Stable sort: same-date events keep their relative order.
    timeline.sort_by_key(|(date, _)| *date);

    let mut current = base_annual_salary;
    let mut segment_start = period_start;
    let mut weighted = Decimal::ZERO;
    let mut prorated = false;
    for (effective, salary) in timeline {
        if effective <= period_start {
            // Already in force when the period begins.
            current = salary;
            continue;
        }
        // Mid-period change: the previous salary applies up to the day before.
        let days = (effective - segment_start).num_days();
        weighted += current * Decimal::from(days);
        segment_start = effective;
        current = salary;
        prorated = true;
    }
    if !prorated {
        return current;
    }
    // A mid-period event implies period_end > period_start, so total_days >= 2.
    let tail_days = (period_end - segment_start).num_days() + 1;
    weighted += current * Decimal::from(tail_days);
    let total_days = (period_end - period_start).num_days() + 1;
    (weighted / Decimal::from(total_days)).round_dp(2)
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Three employees with annual base salaries of 60k, 85k and 120k;
    /// EMP-003 has no cost center.
    fn test_employees() -> Vec<(String, Decimal, Option<String>, Option<String>)> {
        vec![
            (
                "EMP-001".to_string(),
                Decimal::from(60_000),
                Some("CC-100".to_string()),
                Some("Engineering".to_string()),
            ),
            (
                "EMP-002".to_string(),
                Decimal::from(85_000),
                Some("CC-200".to_string()),
                Some("Finance".to_string()),
            ),
            (
                "EMP-003".to_string(),
                Decimal::from(120_000),
                None,
                Some("Sales".to_string()),
            ),
        ]
    }

    /// Default config with no country pack. Checks header fields, the totals
    /// identity (net = gross - deductions), per-line sanity, and that no
    /// deduction labels are attached.
    #[test]
    fn test_basic_payroll_generation() {
        let mut gen = PayrollGenerator::new(42);
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let (run, items) = gen.generate("C001", &employees, period_start, period_end, "USD");
        assert_eq!(run.company_code, "C001");
        assert_eq!(run.currency, "USD");
        assert_eq!(run.employee_count, 3);
        assert_eq!(items.len(), 3);
        assert!(run.total_gross > Decimal::ZERO);
        assert!(run.total_deductions > Decimal::ZERO);
        assert!(run.total_net > Decimal::ZERO);
        assert!(run.total_employer_cost > run.total_gross);
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);
        for item in &items {
            assert_eq!(item.payroll_id, run.payroll_id);
            assert!(item.gross_pay > Decimal::ZERO);
            assert!(item.net_pay > Decimal::ZERO);
            assert!(item.net_pay < item.gross_pay);
            assert!(item.base_salary > Decimal::ZERO);
            assert_eq!(item.pay_date, period_end);
            // Labels only come from a country pack.
            assert!(item.tax_withholding_label.is_none());
            assert!(item.social_security_label.is_none());
        }
    }

    /// Two generators with the same seed produce identical ids, amounts and status.
    #[test]
    fn test_deterministic_payroll() {
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
        let mut gen1 = PayrollGenerator::new(42);
        let (run1, items1) = gen1.generate("C001", &employees, period_start, period_end, "USD");
        let mut gen2 = PayrollGenerator::new(42);
        let (run2, items2) = gen2.generate("C001", &employees, period_start, period_end, "USD");
        assert_eq!(run1.payroll_id, run2.payroll_id);
        assert_eq!(run1.total_gross, run2.total_gross);
        assert_eq!(run1.total_net, run2.total_net);
        assert_eq!(run1.status, run2.status);
        assert_eq!(items1.len(), items2.len());
        for (a, b) in items1.iter().zip(items2.iter()) {
            assert_eq!(a.line_id, b.line_id);
            assert_eq!(a.gross_pay, b.gross_pay);
            assert_eq!(a.net_pay, b.net_pay);
        }
    }

    /// Single employee. The monthly base is annual/12 rounded to cents, and net
    /// pay equals gross minus the five deduction components.
    #[test]
    fn test_payroll_deduction_components() {
        let mut gen = PayrollGenerator::new(99);
        let employees = vec![(
            "EMP-010".to_string(),
            Decimal::from(100_000),
            Some("CC-300".to_string()),
            Some("HR".to_string()),
        )];
        let period_start = NaiveDate::from_ymd_opt(2024, 6, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
        let (_run, items) = gen.generate("C001", &employees, period_start, period_end, "USD");
        assert_eq!(items.len(), 1);
        let item = &items[0];
        let expected_monthly = (Decimal::from(100_000) / Decimal::from(12)).round_dp(2);
        assert_eq!(item.base_salary, expected_monthly);
        let deduction_sum = item.tax_withholding
            + item.social_security
            + item.health_insurance
            + item.retirement_contribution
            + item.other_deductions;
        let expected_net = item.gross_pay - deduction_sum;
        assert_eq!(item.net_pay, expected_net);
        assert!(item.tax_withholding > Decimal::ZERO);
        assert!(item.social_security > Decimal::ZERO);
    }

    /// Minimal US pack with three statutory deductions:
    /// - flat FICA (7.65%);
    /// - a progressive FIT marker (rate 0);
    /// - flat SIT (5%).
    /// The `tax` section is left at `Default`.
    fn us_country_pack() -> CountryPack {
        use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
        CountryPack {
            country_code: "US".to_string(),
            payroll: PayrollCountryConfig {
                statutory_deductions: vec![
                    PayrollDeduction {
                        code: "FICA".to_string(),
                        name_en: "Federal Insurance Contributions Act".to_string(),
                        deduction_type: "percentage".to_string(),
                        rate: 0.0765,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "FIT".to_string(),
                        name_en: "Federal Income Tax".to_string(),
                        deduction_type: "progressive".to_string(),
                        rate: 0.0,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "SIT".to_string(),
                        name_en: "State Income Tax".to_string(),
                        deduction_type: "percentage".to_string(),
                        rate: 0.05,
                        ..Default::default()
                    },
                ],
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// German-style pack using the `type_field` spelling.
    /// - Statutory: a progressive LOHNST marker plus flat percentage deductions
    ///   (one flagged `optional`).
    /// - Employer: two contributions.
    fn de_country_pack() -> CountryPack {
        use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
        CountryPack {
            country_code: "DE".to_string(),
            payroll: PayrollCountryConfig {
                pay_frequency: "monthly".to_string(),
                currency: "EUR".to_string(),
                statutory_deductions: vec![
                    PayrollDeduction {
                        code: "LOHNST".to_string(),
                        name_en: "Income Tax".to_string(),
                        type_field: "progressive".to_string(),
                        rate: 0.0,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "SOLI".to_string(),
                        name_en: "Solidarity Surcharge".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.055,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "KiSt".to_string(),
                        name_en: "Church Tax".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.08,
                        optional: true,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "RV".to_string(),
                        name_en: "Pension Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.093,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "KV".to_string(),
                        name_en: "Health Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.073,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "AV".to_string(),
                        name_en: "Unemployment Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.013,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "PV".to_string(),
                        name_en: "Long-Term Care Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.017,
                        ..Default::default()
                    },
                ],
                employer_contributions: vec![
                    PayrollDeduction {
                        code: "AG-RV".to_string(),
                        name_en: "Employer Pension Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.093,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "AG-KV".to_string(),
                        name_en: "Employer Health Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.073,
                        ..Default::default()
                    },
                ],
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// US pack passed explicitly: totals identity holds and tax /
    /// social-security labels are attached to every line.
    #[test]
    fn test_generate_with_us_country_pack() {
        let mut gen = PayrollGenerator::new(42);
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let pack = us_country_pack();
        let (run, items) = gen.generate_with_country_pack(
            "C001",
            &employees,
            period_start,
            period_end,
            "USD",
            &pack,
        );
        assert_eq!(run.company_code, "C001");
        assert_eq!(run.employee_count, 3);
        assert_eq!(items.len(), 3);
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);
        for item in &items {
            assert!(item.gross_pay > Decimal::ZERO);
            assert!(item.net_pay > Decimal::ZERO);
            assert!(item.net_pay < item.gross_pay);
            assert!(item.social_security > Decimal::ZERO);
            assert!(item.tax_withholding_label.is_some());
            assert!(item.social_security_label.is_some());
        }
    }

    /// DE pack passed explicitly. Pension and health rates come from the pack
    /// (checked via the private `rates_from_country_pack`), and labels use the
    /// pack's English names, including both employer contributions.
    #[test]
    fn test_generate_with_de_country_pack() {
        let mut gen = PayrollGenerator::new(42);
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let pack = de_country_pack();
        let (run, items) = gen.generate_with_country_pack(
            "DE01",
            &employees,
            period_start,
            period_end,
            "EUR",
            &pack,
        );
        assert_eq!(run.company_code, "DE01");
        assert_eq!(items.len(), 3);
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);
        let rates = gen.rates_from_country_pack(&pack);
        assert_eq!(
            rates.retirement_rate,
            Decimal::from_f64_retain(0.093).unwrap()
        );
        assert_eq!(rates.health_rate, Decimal::from_f64_retain(0.073).unwrap());
        let item = &items[0];
        assert_eq!(
            item.health_insurance_label.as_deref(),
            Some("Health Insurance")
        );
        assert_eq!(
            item.retirement_contribution_label.as_deref(),
            Some("Pension Insurance")
        );
        assert!(item.employer_contribution_label.is_some());
        let ec = item.employer_contribution_label.as_ref().unwrap();
        assert!(ec.contains("Employer Pension Insurance"));
        assert!(ec.contains("Employer Health Insurance"));
    }

    /// An empty pack defines no deduction category, so every rate falls back to
    /// the config-derived value.
    #[test]
    fn test_country_pack_falls_back_to_config_for_missing_categories() {
        let pack = CountryPack::default();
        let gen = PayrollGenerator::new(42);
        let rates_pack = gen.rates_from_country_pack(&pack);
        let rates_cfg = gen.rates_from_config();
        assert_eq!(rates_pack.income_tax_rate, rates_cfg.income_tax_rate);
        assert_eq!(rates_pack.fica_rate, rates_cfg.fica_rate);
        assert_eq!(rates_pack.health_rate, rates_cfg.health_rate);
        assert_eq!(rates_pack.retirement_rate, rates_cfg.retirement_rate);
    }

    /// Same seed and same DE pack produce identical runs and net pay.
    #[test]
    fn test_country_pack_deterministic() {
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
        let pack = de_country_pack();
        let mut gen1 = PayrollGenerator::new(42);
        let (run1, items1) = gen1.generate_with_country_pack(
            "DE01",
            &employees,
            period_start,
            period_end,
            "EUR",
            &pack,
        );
        let mut gen2 = PayrollGenerator::new(42);
        let (run2, items2) = gen2.generate_with_country_pack(
            "DE01",
            &employees,
            period_start,
            period_end,
            "EUR",
            &pack,
        );
        assert_eq!(run1.payroll_id, run2.payroll_id);
        assert_eq!(run1.total_gross, run2.total_gross);
        assert_eq!(run1.total_net, run2.total_net);
        for (a, b) in items1.iter().zip(items2.iter()) {
            assert_eq!(a.net_pay, b.net_pay);
        }
    }

    /// DE health (7.3%) and pension (9.3%) rates override the config's fixed
    /// 3% / 5% defaults.
    #[test]
    fn test_de_rates_differ_from_default() {
        let gen = PayrollGenerator::new(42);
        let pack = de_country_pack();
        let rates_cfg = gen.rates_from_config();
        let rates_de = gen.rates_from_country_pack(&pack);
        assert_ne!(rates_de.health_rate, rates_cfg.health_rate);
        assert_ne!(rates_de.retirement_rate, rates_cfg.retirement_rate);
    }

    /// A pack installed via `set_country_pack` is picked up by plain `generate`.
    #[test]
    fn test_set_country_pack_uses_labels() {
        let mut gen = PayrollGenerator::new(42);
        let pack = de_country_pack();
        gen.set_country_pack(pack);
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let (_run, items) = gen.generate("DE01", &employees, period_start, period_end, "EUR");
        let item = &items[0];
        assert!(item.tax_withholding_label.is_some());
        assert!(item.health_insurance_label.is_some());
        assert!(item.retirement_contribution_label.is_some());
        assert!(item.employer_contribution_label.is_some());
    }

    /// Marginal brackets.
    /// - 60k: 10% of 11,000 + 12% of 33,725 + 22% of 15,275
    ///   = 1,100 + 4,047 + 3,360.50 = 8,507.50.
    /// - 11k: the first bracket only = 1,100.
    #[test]
    fn test_compute_progressive_tax_us_brackets() {
        let brackets = vec![
            TaxBracket {
                above: Some(0.0),
                up_to: Some(11_000.0),
                rate: 0.10,
            },
            TaxBracket {
                above: Some(11_000.0),
                up_to: Some(44_725.0),
                rate: 0.12,
            },
            TaxBracket {
                above: Some(44_725.0),
                up_to: Some(95_375.0),
                rate: 0.22,
            },
            TaxBracket {
                above: Some(95_375.0),
                up_to: None,
                rate: 0.24,
            },
        ];
        let tax = PayrollGenerator::compute_progressive_tax(Decimal::from(60_000), &brackets);
        assert_eq!(tax, Decimal::from_f64_retain(8507.50).unwrap());
        let tax = PayrollGenerator::compute_progressive_tax(Decimal::from(11_000), &brackets);
        assert_eq!(tax, Decimal::from_f64_retain(1100.0).unwrap());
    }

    /// Zero income is at or below the first floor, so no tax is owed.
    #[test]
    fn test_progressive_tax_zero_income() {
        let brackets = vec![TaxBracket {
            above: Some(0.0),
            up_to: Some(10_000.0),
            rate: 0.10,
        }];
        let tax = PayrollGenerator::compute_progressive_tax(Decimal::ZERO, &brackets);
        assert_eq!(tax, Decimal::ZERO);
    }

    /// With brackets present, a 200k earner pays a higher effective rate than a
    /// 30k earner.
    ///
    /// The pack literal is built twice because `set_country_pack` takes it by
    /// value. Both generators use seed 42 with one employee each, so their
    /// random draws presumably line up.
    #[test]
    fn test_us_pack_employees_have_varying_rates() {
        use datasynth_core::country::schema::{
            CountryTaxConfig, PayrollCountryConfig, PayrollDeduction, PayrollTaxBracketsConfig,
        };
        let brackets = vec![
            TaxBracket {
                above: Some(0.0),
                up_to: Some(11_000.0),
                rate: 0.10,
            },
            TaxBracket {
                above: Some(11_000.0),
                up_to: Some(44_725.0),
                rate: 0.12,
            },
            TaxBracket {
                above: Some(44_725.0),
                up_to: None,
                rate: 0.22,
            },
        ];
        let pack = CountryPack {
            country_code: "US".to_string(),
            payroll: PayrollCountryConfig {
                statutory_deductions: vec![
                    PayrollDeduction {
                        code: "FIT".to_string(),
                        name_en: "Federal Income Tax".to_string(),
                        deduction_type: "progressive".to_string(),
                        rate: 0.0,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "FICA".to_string(),
                        name_en: "Social Security".to_string(),
                        deduction_type: "percentage".to_string(),
                        rate: 0.0765,
                        ..Default::default()
                    },
                ],
                ..Default::default()
            },
            tax: CountryTaxConfig {
                payroll_tax: PayrollTaxBracketsConfig {
                    income_tax_brackets: brackets,
                    ..Default::default()
                },
                ..Default::default()
            },
            ..Default::default()
        };
        let mut gen = PayrollGenerator::new(42);
        gen.set_country_pack(pack);
        let low_earner = vec![("LOW".to_string(), Decimal::from(30_000), None, None)];
        let high_earner = vec![("HIGH".to_string(), Decimal::from(200_000), None, None)];
        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let (_, low_items) = gen.generate("C001", &low_earner, start, end, "USD");
        let mut gen2 = PayrollGenerator::new(42);
        gen2.set_country_pack(CountryPack {
            country_code: "US".to_string(),
            payroll: PayrollCountryConfig {
                statutory_deductions: vec![
                    PayrollDeduction {
                        code: "FIT".to_string(),
                        name_en: "Federal Income Tax".to_string(),
                        deduction_type: "progressive".to_string(),
                        rate: 0.0,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "FICA".to_string(),
                        name_en: "Social Security".to_string(),
                        deduction_type: "percentage".to_string(),
                        rate: 0.0765,
                        ..Default::default()
                    },
                ],
                ..Default::default()
            },
            tax: CountryTaxConfig {
                payroll_tax: PayrollTaxBracketsConfig {
                    income_tax_brackets: vec![
                        TaxBracket {
                            above: Some(0.0),
                            up_to: Some(11_000.0),
                            rate: 0.10,
                        },
                        TaxBracket {
                            above: Some(11_000.0),
                            up_to: Some(44_725.0),
                            rate: 0.12,
                        },
                        TaxBracket {
                            above: Some(44_725.0),
                            up_to: None,
                            rate: 0.22,
                        },
                    ],
                    ..Default::default()
                },
                ..Default::default()
            },
            ..Default::default()
        });
        let (_, high_items) = gen2.generate("C001", &high_earner, start, end, "USD");
        let low_eff = low_items[0].tax_withholding / low_items[0].gross_pay;
        let high_eff = high_items[0].tax_withholding / high_items[0].gross_pay;
        assert!(
            high_eff > low_eff,
            "High earner effective rate ({high_eff}) should exceed low earner ({low_eff})"
        );
    }

    /// An empty pack yields no labels in any category.
    #[test]
    fn test_empty_pack_labels_are_none() {
        let pack = CountryPack::default();
        let labels = PayrollGenerator::labels_from_country_pack(&pack);
        assert!(labels.tax_withholding.is_none());
        assert!(labels.social_security.is_none());
        assert!(labels.health_insurance.is_none());
        assert!(labels.retirement_contribution.is_none());
        assert!(labels.employer_contribution.is_none());
    }

    /// US pack: federal (progressive marker) and state income-tax labels are
    /// combined into the withholding label, and FICA becomes the
    /// social-security label.
    #[test]
    fn test_us_pack_labels() {
        let pack = us_country_pack();
        let labels = PayrollGenerator::labels_from_country_pack(&pack);
        assert!(labels.tax_withholding.is_some());
        let tw = labels.tax_withholding.unwrap();
        assert!(tw.contains("Federal Income Tax"));
        assert!(tw.contains("State Income Tax"));
        assert!(labels.social_security.is_some());
        assert!(labels
            .social_security
            .unwrap()
            .contains("Federal Insurance Contributions Act"));
    }
}