1use chrono::NaiveDate;
8use datasynth_config::schema::PayrollConfig;
9use datasynth_core::country::schema::TaxBracket;
10use datasynth_core::models::{PayrollLineItem, PayrollRun, PayrollRunStatus};
11use datasynth_core::utils::{sample_decimal_range, seeded_rng};
12use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
13use datasynth_core::CountryPack;
14use rand::prelude::*;
15use rand_chacha::ChaCha8Rng;
16use rust_decimal::Decimal;
17use tracing::debug;
18
19#[derive(Debug, Clone)]
21struct PayrollRates {
22 income_tax_rate: Decimal,
24 income_tax_brackets: Vec<TaxBracket>,
26 fica_rate: Decimal,
28 health_rate: Decimal,
30 retirement_rate: Decimal,
32 employer_fica_rate: Decimal,
34}
35
/// Localized display names for each deduction bucket on a payroll line item.
///
/// A `None` entry means no label is available for that bucket (for example when
/// no country pack is in use). Some buckets combine several deduction names
/// separated by `"; "`.
#[derive(Default, Clone, Debug)]
struct DeductionLabels {
    /// Label(s) for income-tax withholding.
    tax_withholding: Option<String>,
    /// Label(s) for social-security style contributions.
    social_security: Option<String>,
    /// Label for the health-insurance deduction.
    health_insurance: Option<String>,
    /// Label for the retirement / pension deduction.
    retirement_contribution: Option<String>,
    /// Label(s) for employer-side contributions.
    employer_contribution: Option<String>,
}
45
46pub struct PayrollGenerator {
48 rng: ChaCha8Rng,
49 uuid_factory: DeterministicUuidFactory,
50 line_uuid_factory: DeterministicUuidFactory,
51 config: PayrollConfig,
52 country_pack: Option<CountryPack>,
53 employee_ids_pool: Vec<String>,
55 cost_center_ids_pool: Vec<String>,
58}
59
60impl PayrollGenerator {
61 pub fn new(seed: u64) -> Self {
63 Self {
64 rng: seeded_rng(seed, 0),
65 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
66 line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
67 seed,
68 GeneratorType::PayrollRun,
69 1,
70 ),
71 config: PayrollConfig::default(),
72 country_pack: None,
73 employee_ids_pool: Vec::new(),
74 cost_center_ids_pool: Vec::new(),
75 }
76 }
77
78 pub fn with_config(seed: u64, config: PayrollConfig) -> Self {
80 Self {
81 rng: seeded_rng(seed, 0),
82 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
83 line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
84 seed,
85 GeneratorType::PayrollRun,
86 1,
87 ),
88 config,
89 country_pack: None,
90 employee_ids_pool: Vec::new(),
91 cost_center_ids_pool: Vec::new(),
92 }
93 }
94
95 pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
100 self.employee_ids_pool = employee_ids;
101 self.cost_center_ids_pool = cost_center_ids;
102 self
103 }
104
105 pub fn set_country_pack(&mut self, pack: CountryPack) {
113 self.country_pack = Some(pack);
114 }
115
116 pub fn generate(
131 &mut self,
132 company_code: &str,
133 employees: &[(String, Decimal, Option<String>, Option<String>)],
134 period_start: NaiveDate,
135 period_end: NaiveDate,
136 currency: &str,
137 ) -> (PayrollRun, Vec<PayrollLineItem>) {
138 debug!(company_code, employee_count = employees.len(), %period_start, %period_end, currency, "Generating payroll run");
139 if let Some(pack) = self.country_pack.as_ref() {
140 let rates = self.rates_from_country_pack(pack);
141 let labels = Self::labels_from_country_pack(pack);
142 self.generate_with_rates_and_labels(
143 company_code,
144 employees,
145 period_start,
146 period_end,
147 currency,
148 &rates,
149 &labels,
150 )
151 } else {
152 let rates = self.rates_from_config();
153 self.generate_with_rates_and_labels(
154 company_code,
155 employees,
156 period_start,
157 period_end,
158 currency,
159 &rates,
160 &DeductionLabels::default(),
161 )
162 }
163 }
164
165 pub fn generate_with_country_pack(
186 &mut self,
187 company_code: &str,
188 employees: &[(String, Decimal, Option<String>, Option<String>)],
189 period_start: NaiveDate,
190 period_end: NaiveDate,
191 currency: &str,
192 pack: &CountryPack,
193 ) -> (PayrollRun, Vec<PayrollLineItem>) {
194 let rates = self.rates_from_country_pack(pack);
195 let labels = Self::labels_from_country_pack(pack);
196 self.generate_with_rates_and_labels(
197 company_code,
198 employees,
199 period_start,
200 period_end,
201 currency,
202 &rates,
203 &labels,
204 )
205 }
206
207 fn rates_from_config(&self) -> PayrollRates {
213 let federal_rate = Decimal::from_f64_retain(self.config.tax_rates.federal_effective)
214 .unwrap_or(Decimal::ZERO);
215 let state_rate = Decimal::from_f64_retain(self.config.tax_rates.state_effective)
216 .unwrap_or(Decimal::ZERO);
217 let fica_rate =
218 Decimal::from_f64_retain(self.config.tax_rates.fica).unwrap_or(Decimal::ZERO);
219
220 PayrollRates {
221 income_tax_rate: federal_rate + state_rate,
222 income_tax_brackets: Vec::new(),
223 fica_rate,
224 health_rate: Decimal::from_f64_retain(0.03).unwrap_or(Decimal::ZERO),
225 retirement_rate: Decimal::from_f64_retain(0.05).unwrap_or(Decimal::ZERO),
226 employer_fica_rate: fica_rate,
227 }
228 }
229
230 fn compute_progressive_tax(annual_income: Decimal, brackets: &[TaxBracket]) -> Decimal {
236 let mut total_tax = Decimal::ZERO;
237 let mut taxed_up_to = Decimal::ZERO;
238
239 for bracket in brackets {
240 let bracket_floor = bracket
241 .above
242 .and_then(Decimal::from_f64_retain)
243 .unwrap_or(taxed_up_to);
244 let bracket_rate = Decimal::from_f64_retain(bracket.rate).unwrap_or(Decimal::ZERO);
245
246 if annual_income <= bracket_floor {
247 break;
248 }
249
250 let taxable_in_bracket = if let Some(ceiling) = bracket.up_to {
251 let ceiling = Decimal::from_f64_retain(ceiling).unwrap_or(Decimal::ZERO);
252 (annual_income.min(ceiling) - bracket_floor).max(Decimal::ZERO)
253 } else {
254 (annual_income - bracket_floor).max(Decimal::ZERO)
256 };
257
258 total_tax += (taxable_in_bracket * bracket_rate).round_dp(2);
259 taxed_up_to = bracket
260 .up_to
261 .and_then(Decimal::from_f64_retain)
262 .unwrap_or(annual_income);
263 }
264
265 total_tax.round_dp(2)
266 }
267
268 fn rates_from_country_pack(&self, pack: &CountryPack) -> PayrollRates {
271 let fallback = self.rates_from_config();
272
273 let mut federal_tax = Decimal::ZERO;
276 let mut state_tax = Decimal::ZERO;
277 let mut fica = Decimal::ZERO;
278 let mut health = Decimal::ZERO;
279 let mut retirement = Decimal::ZERO;
280
281 let mut found_federal = false;
283 let mut found_state = false;
284 let mut found_fica = false;
285 let mut found_health = false;
286 let mut found_retirement = false;
287
288 for ded in &pack.payroll.statutory_deductions {
289 let code_upper = ded.code.to_uppercase();
290 let name_en_lower = ded.name_en.to_lowercase();
291 let rate = Decimal::from_f64_retain(ded.rate).unwrap_or(Decimal::ZERO);
292
293 if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
298 && ded.rate == 0.0
299 {
300 if code_upper == "FIT"
301 || code_upper == "LOHNST"
302 || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
303 {
304 found_federal = true;
305 }
306 continue;
307 }
308
309 if code_upper == "FIT"
310 || code_upper == "LOHNST"
311 || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
312 {
313 federal_tax += rate;
314 found_federal = true;
315 } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
316 state_tax += rate;
317 found_state = true;
318 } else if code_upper == "FICA" || name_en_lower.contains("social security") {
319 fica += rate;
320 found_fica = true;
321 } else if name_en_lower.contains("health insurance") {
322 health += rate;
323 found_health = true;
324 } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
325 retirement += rate;
326 found_retirement = true;
327 } else {
328 fica += rate;
333 found_fica = true;
334 }
335 }
336
337 PayrollRates {
338 income_tax_rate: if found_federal || found_state {
339 let f = if found_federal {
340 federal_tax
341 } else {
342 fallback.income_tax_rate
343 - Decimal::from_f64_retain(self.config.tax_rates.state_effective)
344 .unwrap_or(Decimal::ZERO)
345 };
346 let s = if found_state {
347 state_tax
348 } else {
349 Decimal::from_f64_retain(self.config.tax_rates.state_effective)
350 .unwrap_or(Decimal::ZERO)
351 };
352 f + s
353 } else {
354 fallback.income_tax_rate
355 },
356 income_tax_brackets: pack.tax.payroll_tax.income_tax_brackets.clone(),
357 fica_rate: if found_fica { fica } else { fallback.fica_rate },
358 health_rate: if found_health {
359 health
360 } else {
361 fallback.health_rate
362 },
363 retirement_rate: if found_retirement {
364 retirement
365 } else {
366 fallback.retirement_rate
367 },
368 employer_fica_rate: if found_fica {
369 fica
370 } else {
371 fallback.employer_fica_rate
372 },
373 }
374 }
375
376 fn labels_from_country_pack(pack: &CountryPack) -> DeductionLabels {
383 let mut labels = DeductionLabels::default();
384
385 for ded in &pack.payroll.statutory_deductions {
386 let code_upper = ded.code.to_uppercase();
387 let name_en_lower = ded.name_en.to_lowercase();
388
389 let label = if ded.name.is_empty() {
392 ded.name_en.clone()
393 } else {
394 ded.name.clone()
395 };
396 if label.is_empty() {
397 continue;
398 }
399
400 if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
403 && ded.rate == 0.0
404 {
405 if code_upper == "FIT"
406 || code_upper == "LOHNST"
407 || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
408 {
409 if labels.tax_withholding.is_none() {
410 labels.tax_withholding = Some(label);
411 }
412 } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
413 labels.tax_withholding = Some(match labels.tax_withholding.take() {
414 Some(existing) => format!("{existing}; {label}"),
415 None => label,
416 });
417 }
418 continue;
419 }
420
421 if code_upper == "FIT"
422 || code_upper == "LOHNST"
423 || code_upper == "SIT"
424 || name_en_lower.contains("income tax")
425 || name_en_lower.contains("state income tax")
426 {
427 labels.tax_withholding = Some(match labels.tax_withholding.take() {
430 Some(existing) => format!("{existing}; {label}"),
431 None => label,
432 });
433 } else if code_upper == "FICA" || name_en_lower.contains("social security") {
434 labels.social_security = Some(match labels.social_security.take() {
435 Some(existing) => format!("{existing}; {label}"),
436 None => label,
437 });
438 } else if name_en_lower.contains("health insurance") {
439 if labels.health_insurance.is_none() {
440 labels.health_insurance = Some(label);
441 }
442 } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
443 if labels.retirement_contribution.is_none() {
444 labels.retirement_contribution = Some(label);
445 }
446 } else {
447 labels.social_security = Some(match labels.social_security.take() {
450 Some(existing) => format!("{existing}; {label}"),
451 None => label,
452 });
453 }
454 }
455
456 let emp_labels: Vec<String> = pack
458 .payroll
459 .employer_contributions
460 .iter()
461 .filter_map(|c| {
462 let l = if c.name.is_empty() {
463 c.name_en.clone()
464 } else {
465 c.name.clone()
466 };
467 if l.is_empty() {
468 None
469 } else {
470 Some(l)
471 }
472 })
473 .collect();
474 if !emp_labels.is_empty() {
475 labels.employer_contribution = Some(emp_labels.join("; "));
476 }
477
478 labels
479 }
480
481 fn generate_with_rates_and_labels(
483 &mut self,
484 company_code: &str,
485 employees: &[(String, Decimal, Option<String>, Option<String>)],
486 period_start: NaiveDate,
487 period_end: NaiveDate,
488 currency: &str,
489 rates: &PayrollRates,
490 labels: &DeductionLabels,
491 ) -> (PayrollRun, Vec<PayrollLineItem>) {
492 let payroll_id = self.uuid_factory.next().to_string();
493
494 let mut line_items = Vec::with_capacity(employees.len());
495 let mut total_gross = Decimal::ZERO;
496 let mut total_deductions = Decimal::ZERO;
497 let mut total_net = Decimal::ZERO;
498 let mut total_employer_cost = Decimal::ZERO;
499
500 let benefits_enrolled = self.config.benefits_enrollment_rate;
501 let retirement_participating = self.config.retirement_participation_rate;
502
503 for (employee_id, base_salary, cost_center, department) in employees {
504 let line_id = self.line_uuid_factory.next().to_string();
505
506 let monthly_base = (*base_salary / Decimal::from(12)).round_dp(2);
508
509 let (overtime_pay, overtime_hours) = if self.rng.random_bool(0.10) {
511 let ot_hours = self.rng.random_range(1.0..=20.0);
512 let hourly_rate = *base_salary / Decimal::from(2080);
514 let ot_rate = hourly_rate * Decimal::from_f64_retain(1.5).unwrap_or(Decimal::ONE);
515 let ot_pay = (ot_rate
516 * Decimal::from_f64_retain(ot_hours).unwrap_or(Decimal::ZERO))
517 .round_dp(2);
518 (ot_pay, ot_hours)
519 } else {
520 (Decimal::ZERO, 0.0)
521 };
522
523 let bonus = if self.rng.random_bool(0.05) {
525 let pct = self.rng.random_range(0.01..=0.10);
526 (monthly_base * Decimal::from_f64_retain(pct).unwrap_or(Decimal::ZERO)).round_dp(2)
527 } else {
528 Decimal::ZERO
529 };
530
531 let gross_pay = monthly_base + overtime_pay + bonus;
532
533 let tax_withholding = if !rates.income_tax_brackets.is_empty() {
535 let annual = gross_pay * Decimal::from(12);
536 Self::compute_progressive_tax(annual, &rates.income_tax_brackets)
537 / Decimal::from(12)
538 } else {
539 (gross_pay * rates.income_tax_rate).round_dp(2)
540 };
541 let social_security = (gross_pay * rates.fica_rate).round_dp(2);
542
543 let health_insurance = if self.rng.random_bool(benefits_enrolled) {
544 (gross_pay * rates.health_rate).round_dp(2)
545 } else {
546 Decimal::ZERO
547 };
548
549 let retirement_contribution = if self.rng.random_bool(retirement_participating) {
550 (gross_pay * rates.retirement_rate).round_dp(2)
551 } else {
552 Decimal::ZERO
553 };
554
555 let other_deductions = if self.rng.random_bool(0.03) {
557 sample_decimal_range(&mut self.rng, Decimal::from(50), Decimal::from(500))
558 .round_dp(2)
559 } else {
560 Decimal::ZERO
561 };
562
563 let total_ded = tax_withholding
564 + social_security
565 + health_insurance
566 + retirement_contribution
567 + other_deductions;
568 let net_pay = gross_pay - total_ded;
569
570 let hours_worked = 160.0;
572
573 let employer_contrib = (gross_pay * rates.employer_fica_rate).round_dp(2);
575 let employer_cost = gross_pay + employer_contrib;
576
577 total_gross += gross_pay;
578 total_deductions += total_ded;
579 total_net += net_pay;
580 total_employer_cost += employer_cost;
581
582 line_items.push(PayrollLineItem {
583 payroll_id: payroll_id.clone(),
584 employee_id: employee_id.clone(),
585 line_id,
586 gross_pay,
587 base_salary: monthly_base,
588 overtime_pay,
589 bonus,
590 tax_withholding,
591 social_security,
592 health_insurance,
593 retirement_contribution,
594 other_deductions,
595 net_pay,
596 hours_worked,
597 overtime_hours,
598 pay_date: period_end,
599 cost_center: cost_center.clone(),
600 department: department.clone(),
601 tax_withholding_label: labels.tax_withholding.clone(),
602 social_security_label: labels.social_security.clone(),
603 health_insurance_label: labels.health_insurance.clone(),
604 retirement_contribution_label: labels.retirement_contribution.clone(),
605 employer_contribution_label: labels.employer_contribution.clone(),
606 });
607 }
608
609 let status_roll: f64 = self.rng.random();
611 let status = if status_roll < 0.60 {
612 PayrollRunStatus::Posted
613 } else if status_roll < 0.85 {
614 PayrollRunStatus::Approved
615 } else if status_roll < 0.95 {
616 PayrollRunStatus::Calculated
617 } else {
618 PayrollRunStatus::Draft
619 };
620
621 let approved_by = if matches!(
622 status,
623 PayrollRunStatus::Approved | PayrollRunStatus::Posted
624 ) {
625 if !self.employee_ids_pool.is_empty() {
626 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
627 Some(self.employee_ids_pool[idx].clone())
628 } else {
629 Some(format!("USR-{:04}", self.rng.random_range(201..=400)))
630 }
631 } else {
632 None
633 };
634
635 let posted_by = if status == PayrollRunStatus::Posted {
636 if !self.employee_ids_pool.is_empty() {
637 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
638 Some(self.employee_ids_pool[idx].clone())
639 } else {
640 Some(format!("USR-{:04}", self.rng.random_range(401..=500)))
641 }
642 } else {
643 None
644 };
645
646 let run = PayrollRun {
647 company_code: company_code.to_string(),
648 payroll_id: payroll_id.clone(),
649 pay_period_start: period_start,
650 pay_period_end: period_end,
651 run_date: period_end,
652 status,
653 total_gross,
654 total_deductions,
655 total_net,
656 total_employer_cost,
657 employee_count: employees.len() as u32,
658 currency: currency.to_string(),
659 posted_by,
660 approved_by,
661 };
662
663 (run, line_items)
664 }
665
666 pub fn generate_with_changes(
683 &mut self,
684 company_code: &str,
685 employees: &[(String, Decimal, Option<String>, Option<String>)],
686 period_start: NaiveDate,
687 period_end: NaiveDate,
688 currency: &str,
689 changes: &[datasynth_core::models::EmployeeChangeEvent],
690 ) -> (PayrollRun, Vec<PayrollLineItem>) {
691 let adjusted: Vec<(String, Decimal, Option<String>, Option<String>)> = employees
692 .iter()
693 .map(|(id, salary, cc, dept)| {
694 let adjusted_salary =
695 Self::apply_salary_changes(id, *salary, period_start, period_end, changes);
696 (id.clone(), adjusted_salary, cc.clone(), dept.clone())
697 })
698 .collect();
699 self.generate(company_code, &adjusted, period_start, period_end, currency)
700 }
701
702 fn apply_salary_changes(
709 employee_id: &str,
710 base_annual_salary: Decimal,
711 period_start: NaiveDate,
712 period_end: NaiveDate,
713 changes: &[datasynth_core::models::EmployeeChangeEvent],
714 ) -> Decimal {
715 use datasynth_core::models::EmployeeEventType;
716
717 let relevant: Vec<&datasynth_core::models::EmployeeChangeEvent> = changes
720 .iter()
721 .filter(|c| {
722 c.employee_id == employee_id
723 && c.event_type == EmployeeEventType::SalaryAdjustment
724 && c.effective_date <= period_end
725 })
726 .collect();
727
728 if relevant.is_empty() {
729 return base_annual_salary;
730 }
731
732 let latest = relevant
734 .iter()
735 .max_by_key(|c| c.effective_date)
736 .expect("non-empty slice always has a max");
737
738 let new_salary = match latest
740 .new_value
741 .as_deref()
742 .and_then(|v| v.parse::<Decimal>().ok())
743 {
744 Some(s) => s,
745 None => return base_annual_salary,
746 };
747
748 let effective = latest.effective_date;
749
750 if effective <= period_start {
751 new_salary
753 } else {
754 let total_days = (period_end - period_start).num_days() + 1;
756 let days_at_old = (effective - period_start).num_days();
757 let days_at_new = total_days - days_at_old;
758
759 let total = Decimal::from(total_days);
760 let old_fraction = Decimal::from(days_at_old) / total;
761 let new_fraction = Decimal::from(days_at_new) / total;
762
763 (base_annual_salary * old_fraction + new_salary * new_fraction).round_dp(2)
764 }
765 }
766}
767
#[cfg(test)]
#[allow(clippy::unwrap_used)] // fixtures are known-good; unwrapping keeps the tests terse
mod tests {
    use super::*;
    use datasynth_core::country::schema::{
        CountryTaxConfig, PayrollCountryConfig, PayrollDeduction, PayrollTaxBracketsConfig,
    };

    /// `(employee_id, annual_base_salary, cost_center, department)` as accepted by `generate`.
    type EmployeeRow = (String, Decimal, Option<String>, Option<String>);

    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// Inclusive `(first, last)` day of a month in 2024.
    fn month_2024(month: u32, last_day: u32) -> (NaiveDate, NaiveDate) {
        (ymd(2024, month, 1), ymd(2024, month, last_day))
    }

    fn test_employees() -> Vec<EmployeeRow> {
        [
            ("EMP-001", 60_000, Some("CC-100"), "Engineering"),
            ("EMP-002", 85_000, Some("CC-200"), "Finance"),
            ("EMP-003", 120_000, None, "Sales"),
        ]
        .iter()
        .map(|&(id, salary, cost_center, department)| {
            (
                id.to_string(),
                Decimal::from(salary),
                cost_center.map(str::to_string),
                Some(department.to_string()),
            )
        })
        .collect()
    }

    /// Deduction classified through `deduction_type` (as the US fixtures do).
    fn us_deduction(code: &str, name_en: &str, deduction_type: &str, rate: f64) -> PayrollDeduction {
        PayrollDeduction {
            code: code.to_string(),
            name_en: name_en.to_string(),
            deduction_type: deduction_type.to_string(),
            rate,
            ..Default::default()
        }
    }

    /// Deduction classified through `type_field` (as the DE fixtures do).
    fn de_deduction(code: &str, name_en: &str, type_field: &str, rate: f64) -> PayrollDeduction {
        PayrollDeduction {
            code: code.to_string(),
            name_en: name_en.to_string(),
            type_field: type_field.to_string(),
            rate,
            ..Default::default()
        }
    }

    fn bracket(above: f64, up_to: Option<f64>, rate: f64) -> TaxBracket {
        TaxBracket {
            above: Some(above),
            up_to,
            rate,
        }
    }

    fn us_country_pack() -> CountryPack {
        CountryPack {
            country_code: "US".to_string(),
            payroll: PayrollCountryConfig {
                statutory_deductions: vec![
                    us_deduction(
                        "FICA",
                        "Federal Insurance Contributions Act",
                        "percentage",
                        0.0765,
                    ),
                    us_deduction("FIT", "Federal Income Tax", "progressive", 0.0),
                    us_deduction("SIT", "State Income Tax", "percentage", 0.05),
                ],
                ..Default::default()
            },
            ..Default::default()
        }
    }

    fn de_country_pack() -> CountryPack {
        CountryPack {
            country_code: "DE".to_string(),
            payroll: PayrollCountryConfig {
                pay_frequency: "monthly".to_string(),
                currency: "EUR".to_string(),
                statutory_deductions: vec![
                    de_deduction("LOHNST", "Income Tax", "progressive", 0.0),
                    de_deduction("SOLI", "Solidarity Surcharge", "percentage", 0.055),
                    PayrollDeduction {
                        optional: true,
                        ..de_deduction("KiSt", "Church Tax", "percentage", 0.08)
                    },
                    de_deduction("RV", "Pension Insurance", "percentage", 0.093),
                    de_deduction("KV", "Health Insurance", "percentage", 0.073),
                    de_deduction("AV", "Unemployment Insurance", "percentage", 0.013),
                    de_deduction("PV", "Long-Term Care Insurance", "percentage", 0.017),
                ],
                employer_contributions: vec![
                    de_deduction("AG-RV", "Employer Pension Insurance", "percentage", 0.093),
                    de_deduction("AG-KV", "Employer Health Insurance", "percentage", 0.073),
                ],
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// US pack whose federal income tax is driven purely by three brackets.
    fn progressive_us_pack() -> CountryPack {
        CountryPack {
            country_code: "US".to_string(),
            payroll: PayrollCountryConfig {
                statutory_deductions: vec![
                    us_deduction("FIT", "Federal Income Tax", "progressive", 0.0),
                    us_deduction("FICA", "Social Security", "percentage", 0.0765),
                ],
                ..Default::default()
            },
            tax: CountryTaxConfig {
                payroll_tax: PayrollTaxBracketsConfig {
                    income_tax_brackets: vec![
                        bracket(0.0, Some(11_000.0), 0.10),
                        bracket(11_000.0, Some(44_725.0), 0.12),
                        bracket(44_725.0, None, 0.22),
                    ],
                    ..Default::default()
                },
                ..Default::default()
            },
            ..Default::default()
        }
    }

    #[test]
    fn test_basic_payroll_generation() {
        let mut generator = PayrollGenerator::new(42);
        let employees = test_employees();
        let (start, end) = month_2024(1, 31);

        let (run, items) = generator.generate("C001", &employees, start, end, "USD");

        assert_eq!(run.company_code, "C001");
        assert_eq!(run.currency, "USD");
        assert_eq!(run.employee_count, 3);
        assert_eq!(items.len(), 3);
        assert!(run.total_gross > Decimal::ZERO);
        assert!(run.total_deductions > Decimal::ZERO);
        assert!(run.total_net > Decimal::ZERO);
        assert!(run.total_employer_cost > run.total_gross);
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);

        for item in &items {
            assert_eq!(item.payroll_id, run.payroll_id);
            assert!(item.gross_pay > Decimal::ZERO);
            assert!(item.net_pay > Decimal::ZERO);
            assert!(item.net_pay < item.gross_pay);
            assert!(item.base_salary > Decimal::ZERO);
            assert_eq!(item.pay_date, end);
            assert!(item.tax_withholding_label.is_none());
            assert!(item.social_security_label.is_none());
        }
    }

    #[test]
    fn test_deterministic_payroll() {
        let employees = test_employees();
        let (start, end) = month_2024(3, 31);

        let (run1, items1) =
            PayrollGenerator::new(42).generate("C001", &employees, start, end, "USD");
        let (run2, items2) =
            PayrollGenerator::new(42).generate("C001", &employees, start, end, "USD");

        assert_eq!(run1.payroll_id, run2.payroll_id);
        assert_eq!(run1.total_gross, run2.total_gross);
        assert_eq!(run1.total_net, run2.total_net);
        assert_eq!(run1.status, run2.status);
        assert_eq!(items1.len(), items2.len());
        for (first, second) in items1.iter().zip(&items2) {
            assert_eq!(first.line_id, second.line_id);
            assert_eq!(first.gross_pay, second.gross_pay);
            assert_eq!(first.net_pay, second.net_pay);
        }
    }

    #[test]
    fn test_payroll_deduction_components() {
        let mut generator = PayrollGenerator::new(99);
        let employees: Vec<EmployeeRow> = vec![(
            "EMP-010".to_string(),
            Decimal::from(100_000),
            Some("CC-300".to_string()),
            Some("HR".to_string()),
        )];
        let (start, end) = month_2024(6, 30);

        let (_run, items) = generator.generate("C001", &employees, start, end, "USD");
        assert_eq!(items.len(), 1);

        let item = &items[0];
        let expected_monthly = (Decimal::from(100_000) / Decimal::from(12)).round_dp(2);
        assert_eq!(item.base_salary, expected_monthly);

        let deduction_sum = [
            item.tax_withholding,
            item.social_security,
            item.health_insurance,
            item.retirement_contribution,
            item.other_deductions,
        ]
        .iter()
        .fold(Decimal::ZERO, |acc, part| acc + *part);
        assert_eq!(item.net_pay, item.gross_pay - deduction_sum);

        assert!(item.tax_withholding > Decimal::ZERO);
        assert!(item.social_security > Decimal::ZERO);
    }

    #[test]
    fn test_generate_with_us_country_pack() {
        let mut generator = PayrollGenerator::new(42);
        let employees = test_employees();
        let (start, end) = month_2024(1, 31);
        let pack = us_country_pack();

        let (run, items) =
            generator.generate_with_country_pack("C001", &employees, start, end, "USD", &pack);

        assert_eq!(run.company_code, "C001");
        assert_eq!(run.employee_count, 3);
        assert_eq!(items.len(), 3);
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);

        for item in &items {
            assert!(item.gross_pay > Decimal::ZERO);
            assert!(item.net_pay > Decimal::ZERO);
            assert!(item.net_pay < item.gross_pay);
            assert!(item.social_security > Decimal::ZERO);
            assert!(item.tax_withholding_label.is_some());
            assert!(item.social_security_label.is_some());
        }
    }

    #[test]
    fn test_generate_with_de_country_pack() {
        let mut generator = PayrollGenerator::new(42);
        let employees = test_employees();
        let (start, end) = month_2024(1, 31);
        let pack = de_country_pack();

        let (run, items) =
            generator.generate_with_country_pack("DE01", &employees, start, end, "EUR", &pack);

        assert_eq!(run.company_code, "DE01");
        assert_eq!(items.len(), 3);
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);

        let rates = generator.rates_from_country_pack(&pack);
        assert_eq!(
            rates.retirement_rate,
            Decimal::from_f64_retain(0.093).unwrap()
        );
        assert_eq!(rates.health_rate, Decimal::from_f64_retain(0.073).unwrap());

        let first = &items[0];
        assert_eq!(
            first.health_insurance_label.as_deref(),
            Some("Health Insurance")
        );
        assert_eq!(
            first.retirement_contribution_label.as_deref(),
            Some("Pension Insurance")
        );
        let employer = first.employer_contribution_label.as_deref().unwrap();
        assert!(employer.contains("Employer Pension Insurance"));
        assert!(employer.contains("Employer Health Insurance"));
    }

    #[test]
    fn test_country_pack_falls_back_to_config_for_missing_categories() {
        let generator = PayrollGenerator::new(42);
        let from_pack = generator.rates_from_country_pack(&CountryPack::default());
        let from_config = generator.rates_from_config();

        assert_eq!(from_pack.income_tax_rate, from_config.income_tax_rate);
        assert_eq!(from_pack.fica_rate, from_config.fica_rate);
        assert_eq!(from_pack.health_rate, from_config.health_rate);
        assert_eq!(from_pack.retirement_rate, from_config.retirement_rate);
    }

    #[test]
    fn test_country_pack_deterministic() {
        let employees = test_employees();
        let (start, end) = month_2024(3, 31);
        let pack = de_country_pack();

        let run_once = || {
            PayrollGenerator::new(42)
                .generate_with_country_pack("DE01", &employees, start, end, "EUR", &pack)
        };
        let (run1, items1) = run_once();
        let (run2, items2) = run_once();

        assert_eq!(run1.payroll_id, run2.payroll_id);
        assert_eq!(run1.total_gross, run2.total_gross);
        assert_eq!(run1.total_net, run2.total_net);
        for (first, second) in items1.iter().zip(&items2) {
            assert_eq!(first.net_pay, second.net_pay);
        }
    }

    #[test]
    fn test_de_rates_differ_from_default() {
        let generator = PayrollGenerator::new(42);
        let from_config = generator.rates_from_config();
        let from_de = generator.rates_from_country_pack(&de_country_pack());

        assert_ne!(from_de.health_rate, from_config.health_rate);
        assert_ne!(from_de.retirement_rate, from_config.retirement_rate);
    }

    #[test]
    fn test_set_country_pack_uses_labels() {
        let mut generator = PayrollGenerator::new(42);
        generator.set_country_pack(de_country_pack());

        let employees = test_employees();
        let (start, end) = month_2024(1, 31);
        let (_run, items) = generator.generate("DE01", &employees, start, end, "EUR");

        let first = &items[0];
        assert!(first.tax_withholding_label.is_some());
        assert!(first.health_insurance_label.is_some());
        assert!(first.retirement_contribution_label.is_some());
        assert!(first.employer_contribution_label.is_some());
    }

    #[test]
    fn test_compute_progressive_tax_us_brackets() {
        let brackets = [
            bracket(0.0, Some(11_000.0), 0.10),
            bracket(11_000.0, Some(44_725.0), 0.12),
            bracket(44_725.0, Some(95_375.0), 0.22),
            bracket(95_375.0, None, 0.24),
        ];

        // 1,100 + 4,047 + 3,360.50
        let mid_income = PayrollGenerator::compute_progressive_tax(Decimal::from(60_000), &brackets);
        assert_eq!(mid_income, Decimal::from_f64_retain(8507.50).unwrap());

        let first_bracket_only =
            PayrollGenerator::compute_progressive_tax(Decimal::from(11_000), &brackets);
        assert_eq!(first_bracket_only, Decimal::from_f64_retain(1100.0).unwrap());
    }

    #[test]
    fn test_progressive_tax_zero_income() {
        let brackets = [bracket(0.0, Some(10_000.0), 0.10)];
        let tax = PayrollGenerator::compute_progressive_tax(Decimal::ZERO, &brackets);
        assert_eq!(tax, Decimal::ZERO);
    }

    #[test]
    fn test_us_pack_employees_have_varying_rates() {
        let (start, end) = month_2024(1, 31);
        let low_earner: Vec<EmployeeRow> =
            vec![("LOW".to_string(), Decimal::from(30_000), None, None)];
        let high_earner: Vec<EmployeeRow> =
            vec![("HIGH".to_string(), Decimal::from(200_000), None, None)];

        let mut low_generator = PayrollGenerator::new(42);
        low_generator.set_country_pack(progressive_us_pack());
        let (_, low_items) = low_generator.generate("C001", &low_earner, start, end, "USD");

        let mut high_generator = PayrollGenerator::new(42);
        high_generator.set_country_pack(progressive_us_pack());
        let (_, high_items) = high_generator.generate("C001", &high_earner, start, end, "USD");

        let low_eff = low_items[0].tax_withholding / low_items[0].gross_pay;
        let high_eff = high_items[0].tax_withholding / high_items[0].gross_pay;

        assert!(
            high_eff > low_eff,
            "High earner effective rate ({high_eff}) should exceed low earner ({low_eff})"
        );
    }

    #[test]
    fn test_empty_pack_labels_are_none() {
        let labels = PayrollGenerator::labels_from_country_pack(&CountryPack::default());
        assert!(labels.tax_withholding.is_none());
        assert!(labels.social_security.is_none());
        assert!(labels.health_insurance.is_none());
        assert!(labels.retirement_contribution.is_none());
        assert!(labels.employer_contribution.is_none());
    }

    #[test]
    fn test_us_pack_labels() {
        let labels = PayrollGenerator::labels_from_country_pack(&us_country_pack());

        let tax = labels.tax_withholding.unwrap();
        assert!(tax.contains("Federal Income Tax"));
        assert!(tax.contains("State Income Tax"));

        let social = labels.social_security.unwrap();
        assert!(social.contains("Federal Insurance Contributions Act"));
    }
}
1375}