1use chrono::NaiveDate;
8use datasynth_config::schema::PayrollConfig;
9use datasynth_core::country::schema::TaxBracket;
10use datasynth_core::models::{PayrollLineItem, PayrollRun, PayrollRunStatus};
11use datasynth_core::utils::{sample_decimal_range, seeded_rng};
12use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
13use datasynth_core::CountryPack;
14use rand::prelude::*;
15use rand_chacha::ChaCha8Rng;
16use rust_decimal::Decimal;
17use tracing::debug;
18
19#[derive(Debug, Clone)]
21struct PayrollRates {
22 income_tax_rate: Decimal,
24 income_tax_brackets: Vec<TaxBracket>,
26 fica_rate: Decimal,
28 health_rate: Decimal,
30 retirement_rate: Decimal,
32 employer_fica_rate: Decimal,
34}
35
/// Display labels for each deduction bucket, taken from a country pack.
///
/// A `None` field means no label is available for that bucket (the `Default`);
/// a `Some` field may hold several deduction names joined with `"; "`.
#[derive(Clone, Debug, Default)]
struct DeductionLabels {
    /// Label(s) for income-tax withholding.
    tax_withholding: Option<String>,
    /// Label(s) for social-security style contributions.
    social_security: Option<String>,
    /// Label(s) for employee health-insurance contributions.
    health_insurance: Option<String>,
    /// Label(s) for employee retirement / pension contributions.
    retirement_contribution: Option<String>,
    /// Label(s) for employer-side contributions.
    employer_contribution: Option<String>,
}
45
46pub struct PayrollGenerator {
48 rng: ChaCha8Rng,
49 uuid_factory: DeterministicUuidFactory,
50 line_uuid_factory: DeterministicUuidFactory,
51 config: PayrollConfig,
52 country_pack: Option<CountryPack>,
53 employee_ids_pool: Vec<String>,
55 cost_center_ids_pool: Vec<String>,
58}
59
60impl PayrollGenerator {
61 pub fn new(seed: u64) -> Self {
63 Self {
64 rng: seeded_rng(seed, 0),
65 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
66 line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
67 seed,
68 GeneratorType::PayrollRun,
69 1,
70 ),
71 config: PayrollConfig::default(),
72 country_pack: None,
73 employee_ids_pool: Vec::new(),
74 cost_center_ids_pool: Vec::new(),
75 }
76 }
77
78 pub fn with_config(seed: u64, config: PayrollConfig) -> Self {
80 Self {
81 rng: seeded_rng(seed, 0),
82 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
83 line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
84 seed,
85 GeneratorType::PayrollRun,
86 1,
87 ),
88 config,
89 country_pack: None,
90 employee_ids_pool: Vec::new(),
91 cost_center_ids_pool: Vec::new(),
92 }
93 }
94
95 pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
100 self.employee_ids_pool = employee_ids;
101 self.cost_center_ids_pool = cost_center_ids;
102 self
103 }
104
105 pub fn set_country_pack(&mut self, pack: CountryPack) {
113 self.country_pack = Some(pack);
114 }
115
116 pub fn generate(
131 &mut self,
132 company_code: &str,
133 employees: &[(String, Decimal, Option<String>, Option<String>)],
134 period_start: NaiveDate,
135 period_end: NaiveDate,
136 currency: &str,
137 ) -> (PayrollRun, Vec<PayrollLineItem>) {
138 debug!(company_code, employee_count = employees.len(), %period_start, %period_end, currency, "Generating payroll run");
139 if let Some(pack) = self.country_pack.as_ref() {
140 let rates = self.rates_from_country_pack(pack);
141 let labels = Self::labels_from_country_pack(pack);
142 self.generate_with_rates_and_labels(
143 company_code,
144 employees,
145 period_start,
146 period_end,
147 currency,
148 &rates,
149 &labels,
150 )
151 } else {
152 let rates = self.rates_from_config();
153 self.generate_with_rates_and_labels(
154 company_code,
155 employees,
156 period_start,
157 period_end,
158 currency,
159 &rates,
160 &DeductionLabels::default(),
161 )
162 }
163 }
164
165 pub fn generate_with_country_pack(
186 &mut self,
187 company_code: &str,
188 employees: &[(String, Decimal, Option<String>, Option<String>)],
189 period_start: NaiveDate,
190 period_end: NaiveDate,
191 currency: &str,
192 pack: &CountryPack,
193 ) -> (PayrollRun, Vec<PayrollLineItem>) {
194 let rates = self.rates_from_country_pack(pack);
195 let labels = Self::labels_from_country_pack(pack);
196 self.generate_with_rates_and_labels(
197 company_code,
198 employees,
199 period_start,
200 period_end,
201 currency,
202 &rates,
203 &labels,
204 )
205 }
206
207 fn rates_from_config(&self) -> PayrollRates {
213 let federal_rate = Decimal::from_f64_retain(self.config.tax_rates.federal_effective)
214 .unwrap_or(Decimal::ZERO);
215 let state_rate = Decimal::from_f64_retain(self.config.tax_rates.state_effective)
216 .unwrap_or(Decimal::ZERO);
217 let fica_rate =
218 Decimal::from_f64_retain(self.config.tax_rates.fica).unwrap_or(Decimal::ZERO);
219
220 PayrollRates {
221 income_tax_rate: federal_rate + state_rate,
222 income_tax_brackets: Vec::new(),
223 fica_rate,
224 health_rate: Decimal::from_f64_retain(0.03).unwrap_or(Decimal::ZERO),
225 retirement_rate: Decimal::from_f64_retain(0.05).unwrap_or(Decimal::ZERO),
226 employer_fica_rate: fica_rate,
227 }
228 }
229
230 fn compute_progressive_tax(annual_income: Decimal, brackets: &[TaxBracket]) -> Decimal {
236 let mut total_tax = Decimal::ZERO;
237 let mut taxed_up_to = Decimal::ZERO;
238
239 for bracket in brackets {
240 let bracket_floor = bracket
241 .above
242 .and_then(Decimal::from_f64_retain)
243 .unwrap_or(taxed_up_to);
244 let bracket_rate = Decimal::from_f64_retain(bracket.rate).unwrap_or(Decimal::ZERO);
245
246 if annual_income <= bracket_floor {
247 break;
248 }
249
250 let taxable_in_bracket = if let Some(ceiling) = bracket.up_to {
251 let ceiling = Decimal::from_f64_retain(ceiling).unwrap_or(Decimal::ZERO);
252 (annual_income.min(ceiling) - bracket_floor).max(Decimal::ZERO)
253 } else {
254 (annual_income - bracket_floor).max(Decimal::ZERO)
256 };
257
258 total_tax += (taxable_in_bracket * bracket_rate).round_dp(2);
259 taxed_up_to = bracket
260 .up_to
261 .and_then(Decimal::from_f64_retain)
262 .unwrap_or(annual_income);
263 }
264
265 total_tax.round_dp(2)
266 }
267
268 fn rates_from_country_pack(&self, pack: &CountryPack) -> PayrollRates {
271 let fallback = self.rates_from_config();
272
273 let mut federal_tax = Decimal::ZERO;
276 let mut state_tax = Decimal::ZERO;
277 let mut fica = Decimal::ZERO;
278 let mut health = Decimal::ZERO;
279 let mut retirement = Decimal::ZERO;
280
281 let mut found_federal = false;
283 let mut found_state = false;
284 let mut found_fica = false;
285 let mut found_health = false;
286 let mut found_retirement = false;
287
288 for ded in &pack.payroll.statutory_deductions {
289 let code_upper = ded.code.to_uppercase();
290 let name_en_lower = ded.name_en.to_lowercase();
291 let rate = Decimal::from_f64_retain(ded.rate).unwrap_or(Decimal::ZERO);
292
293 if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
298 && ded.rate == 0.0
299 {
300 if code_upper == "FIT"
301 || code_upper == "LOHNST"
302 || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
303 {
304 found_federal = true;
305 }
306 continue;
307 }
308
309 if code_upper == "FIT"
310 || code_upper == "LOHNST"
311 || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
312 {
313 federal_tax += rate;
314 found_federal = true;
315 } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
316 state_tax += rate;
317 found_state = true;
318 } else if code_upper == "FICA" || name_en_lower.contains("social security") {
319 fica += rate;
320 found_fica = true;
321 } else if name_en_lower.contains("health insurance") {
322 health += rate;
323 found_health = true;
324 } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
325 retirement += rate;
326 found_retirement = true;
327 } else {
328 fica += rate;
333 found_fica = true;
334 }
335 }
336
337 PayrollRates {
338 income_tax_rate: if found_federal || found_state {
339 let f = if found_federal {
340 federal_tax
341 } else {
342 fallback.income_tax_rate
343 - Decimal::from_f64_retain(self.config.tax_rates.state_effective)
344 .unwrap_or(Decimal::ZERO)
345 };
346 let s = if found_state {
347 state_tax
348 } else {
349 Decimal::from_f64_retain(self.config.tax_rates.state_effective)
350 .unwrap_or(Decimal::ZERO)
351 };
352 f + s
353 } else {
354 fallback.income_tax_rate
355 },
356 income_tax_brackets: pack.tax.payroll_tax.income_tax_brackets.clone(),
357 fica_rate: if found_fica { fica } else { fallback.fica_rate },
358 health_rate: if found_health {
359 health
360 } else {
361 fallback.health_rate
362 },
363 retirement_rate: if found_retirement {
364 retirement
365 } else {
366 fallback.retirement_rate
367 },
368 employer_fica_rate: if found_fica {
369 fica
370 } else {
371 fallback.employer_fica_rate
372 },
373 }
374 }
375
376 fn labels_from_country_pack(pack: &CountryPack) -> DeductionLabels {
383 let mut labels = DeductionLabels::default();
384
385 for ded in &pack.payroll.statutory_deductions {
386 let code_upper = ded.code.to_uppercase();
387 let name_en_lower = ded.name_en.to_lowercase();
388
389 let label = if ded.name.is_empty() {
392 ded.name_en.clone()
393 } else {
394 ded.name.clone()
395 };
396 if label.is_empty() {
397 continue;
398 }
399
400 if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
403 && ded.rate == 0.0
404 {
405 if code_upper == "FIT"
406 || code_upper == "LOHNST"
407 || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
408 {
409 if labels.tax_withholding.is_none() {
410 labels.tax_withholding = Some(label);
411 }
412 } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
413 labels.tax_withholding = Some(match labels.tax_withholding.take() {
414 Some(existing) => format!("{existing}; {label}"),
415 None => label,
416 });
417 }
418 continue;
419 }
420
421 if code_upper == "FIT"
422 || code_upper == "LOHNST"
423 || code_upper == "SIT"
424 || name_en_lower.contains("income tax")
425 || name_en_lower.contains("state income tax")
426 {
427 labels.tax_withholding = Some(match labels.tax_withholding.take() {
430 Some(existing) => format!("{existing}; {label}"),
431 None => label,
432 });
433 } else if code_upper == "FICA" || name_en_lower.contains("social security") {
434 labels.social_security = Some(match labels.social_security.take() {
435 Some(existing) => format!("{existing}; {label}"),
436 None => label,
437 });
438 } else if name_en_lower.contains("health insurance") {
439 if labels.health_insurance.is_none() {
440 labels.health_insurance = Some(label);
441 }
442 } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
443 if labels.retirement_contribution.is_none() {
444 labels.retirement_contribution = Some(label);
445 }
446 } else {
447 labels.social_security = Some(match labels.social_security.take() {
450 Some(existing) => format!("{existing}; {label}"),
451 None => label,
452 });
453 }
454 }
455
456 let emp_labels: Vec<String> = pack
458 .payroll
459 .employer_contributions
460 .iter()
461 .filter_map(|c| {
462 let l = if c.name.is_empty() {
463 c.name_en.clone()
464 } else {
465 c.name.clone()
466 };
467 if l.is_empty() {
468 None
469 } else {
470 Some(l)
471 }
472 })
473 .collect();
474 if !emp_labels.is_empty() {
475 labels.employer_contribution = Some(emp_labels.join("; "));
476 }
477
478 labels
479 }
480
481 fn generate_with_rates_and_labels(
483 &mut self,
484 company_code: &str,
485 employees: &[(String, Decimal, Option<String>, Option<String>)],
486 period_start: NaiveDate,
487 period_end: NaiveDate,
488 currency: &str,
489 rates: &PayrollRates,
490 labels: &DeductionLabels,
491 ) -> (PayrollRun, Vec<PayrollLineItem>) {
492 let payroll_id = self.uuid_factory.next().to_string();
493
494 let mut line_items = Vec::with_capacity(employees.len());
495 let mut total_gross = Decimal::ZERO;
496 let mut total_deductions = Decimal::ZERO;
497 let mut total_net = Decimal::ZERO;
498 let mut total_employer_cost = Decimal::ZERO;
499
500 let benefits_enrolled = self.config.benefits_enrollment_rate;
501 let retirement_participating = self.config.retirement_participation_rate;
502
503 for (employee_id, base_salary, cost_center, department) in employees {
504 let line_id = self.line_uuid_factory.next().to_string();
505
506 let monthly_base = (*base_salary / Decimal::from(12)).round_dp(2);
508
509 let (overtime_pay, overtime_hours) = if self.rng.random_bool(0.10) {
511 let ot_hours = self.rng.random_range(1.0..=20.0);
512 let hourly_rate = *base_salary / Decimal::from(2080);
514 let ot_rate = hourly_rate * Decimal::from_f64_retain(1.5).unwrap_or(Decimal::ONE);
515 let ot_pay = (ot_rate
516 * Decimal::from_f64_retain(ot_hours).unwrap_or(Decimal::ZERO))
517 .round_dp(2);
518 (ot_pay, ot_hours)
519 } else {
520 (Decimal::ZERO, 0.0)
521 };
522
523 let bonus = if self.rng.random_bool(0.05) {
525 let pct = self.rng.random_range(0.01..=0.10);
526 (monthly_base * Decimal::from_f64_retain(pct).unwrap_or(Decimal::ZERO)).round_dp(2)
527 } else {
528 Decimal::ZERO
529 };
530
531 let gross_pay = monthly_base + overtime_pay + bonus;
532
533 let tax_withholding = if !rates.income_tax_brackets.is_empty() {
535 let annual = gross_pay * Decimal::from(12);
536 Self::compute_progressive_tax(annual, &rates.income_tax_brackets)
537 / Decimal::from(12)
538 } else {
539 (gross_pay * rates.income_tax_rate).round_dp(2)
540 };
541 let social_security = (gross_pay * rates.fica_rate).round_dp(2);
542
543 let health_insurance = if self.rng.random_bool(benefits_enrolled) {
544 (gross_pay * rates.health_rate).round_dp(2)
545 } else {
546 Decimal::ZERO
547 };
548
549 let retirement_contribution = if self.rng.random_bool(retirement_participating) {
550 (gross_pay * rates.retirement_rate).round_dp(2)
551 } else {
552 Decimal::ZERO
553 };
554
555 let other_deductions = if self.rng.random_bool(0.03) {
557 sample_decimal_range(&mut self.rng, Decimal::from(50), Decimal::from(500))
558 .round_dp(2)
559 } else {
560 Decimal::ZERO
561 };
562
563 let total_ded = tax_withholding
564 + social_security
565 + health_insurance
566 + retirement_contribution
567 + other_deductions;
568 let net_pay = gross_pay - total_ded;
569
570 let hours_worked = 160.0;
572
573 let employer_contrib = (gross_pay * rates.employer_fica_rate).round_dp(2);
575 let employer_cost = gross_pay + employer_contrib;
576
577 total_gross += gross_pay;
578 total_deductions += total_ded;
579 total_net += net_pay;
580 total_employer_cost += employer_cost;
581
582 line_items.push(PayrollLineItem {
583 payroll_id: payroll_id.clone(),
584 employee_id: employee_id.clone(),
585 line_id,
586 gross_pay,
587 base_salary: monthly_base,
588 overtime_pay,
589 bonus,
590 tax_withholding,
591 social_security,
592 health_insurance,
593 retirement_contribution,
594 other_deductions,
595 net_pay,
596 hours_worked,
597 overtime_hours,
598 pay_date: period_end,
599 cost_center: cost_center.clone(),
600 department: department.clone(),
601 tax_withholding_label: labels.tax_withholding.clone(),
602 social_security_label: labels.social_security.clone(),
603 health_insurance_label: labels.health_insurance.clone(),
604 retirement_contribution_label: labels.retirement_contribution.clone(),
605 employer_contribution_label: labels.employer_contribution.clone(),
606 });
607 }
608
609 let status_roll: f64 = self.rng.random();
611 let status = if status_roll < 0.60 {
612 PayrollRunStatus::Posted
613 } else if status_roll < 0.85 {
614 PayrollRunStatus::Approved
615 } else if status_roll < 0.95 {
616 PayrollRunStatus::Calculated
617 } else {
618 PayrollRunStatus::Draft
619 };
620
621 let approved_by = if matches!(
622 status,
623 PayrollRunStatus::Approved | PayrollRunStatus::Posted
624 ) {
625 if !self.employee_ids_pool.is_empty() {
626 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
627 Some(self.employee_ids_pool[idx].clone())
628 } else {
629 Some(format!("USR-{:04}", self.rng.random_range(201..=400)))
630 }
631 } else {
632 None
633 };
634
635 let posted_by = if status == PayrollRunStatus::Posted {
636 if !self.employee_ids_pool.is_empty() {
637 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
638 Some(self.employee_ids_pool[idx].clone())
639 } else {
640 Some(format!("USR-{:04}", self.rng.random_range(401..=500)))
641 }
642 } else {
643 None
644 };
645
646 let run = PayrollRun {
647 company_code: company_code.to_string(),
648 payroll_id: payroll_id.clone(),
649 pay_period_start: period_start,
650 pay_period_end: period_end,
651 run_date: period_end,
652 status,
653 total_gross,
654 total_deductions,
655 total_net,
656 total_employer_cost,
657 employee_count: employees.len() as u32,
658 currency: currency.to_string(),
659 posted_by,
660 approved_by,
661 };
662
663 (run, line_items)
664 }
665
666 pub fn generate_with_changes(
683 &mut self,
684 company_code: &str,
685 employees: &[(String, Decimal, Option<String>, Option<String>)],
686 period_start: NaiveDate,
687 period_end: NaiveDate,
688 currency: &str,
689 changes: &[datasynth_core::models::EmployeeChangeEvent],
690 ) -> (PayrollRun, Vec<PayrollLineItem>) {
691 let adjusted: Vec<(String, Decimal, Option<String>, Option<String>)> = employees
692 .iter()
693 .map(|(id, salary, cc, dept)| {
694 let adjusted_salary =
695 Self::apply_salary_changes(id, *salary, period_start, period_end, changes);
696 (id.clone(), adjusted_salary, cc.clone(), dept.clone())
697 })
698 .collect();
699 self.generate(company_code, &adjusted, period_start, period_end, currency)
700 }
701
702 fn apply_salary_changes(
709 employee_id: &str,
710 base_annual_salary: Decimal,
711 period_start: NaiveDate,
712 period_end: NaiveDate,
713 changes: &[datasynth_core::models::EmployeeChangeEvent],
714 ) -> Decimal {
715 use datasynth_core::models::EmployeeEventType;
716
717 let relevant: Vec<&datasynth_core::models::EmployeeChangeEvent> = changes
720 .iter()
721 .filter(|c| {
722 c.employee_id == employee_id
723 && c.event_type == EmployeeEventType::SalaryAdjustment
724 && c.effective_date <= period_end
725 })
726 .collect();
727
728 if relevant.is_empty() {
729 return base_annual_salary;
730 }
731
732 let latest = relevant
734 .iter()
735 .max_by_key(|c| c.effective_date)
736 .expect("non-empty slice always has a max");
737
738 let new_salary = match latest
740 .new_value
741 .as_deref()
742 .and_then(|v| v.parse::<Decimal>().ok())
743 {
744 Some(s) => s,
745 None => return base_annual_salary,
746 };
747
748 let effective = latest.effective_date;
749
750 if effective <= period_start {
751 new_salary
753 } else {
754 let total_days = (period_end - period_start).num_days() + 1;
756 let days_at_old = (effective - period_start).num_days();
757 let days_at_new = total_days - days_at_old;
758
759 let total = Decimal::from(total_days);
760 let old_fraction = Decimal::from(days_at_old) / total;
761 let new_fraction = Decimal::from(days_at_new) / total;
762
763 (base_annual_salary * old_fraction + new_salary * new_fraction).round_dp(2)
764 }
765 }
766}
767
#[cfg(test)]
mod tests {
    use super::*;

    /// Three employees at 60k / 85k / 120k annual salary; EMP-003 has no cost center.
    fn test_employees() -> Vec<(String, Decimal, Option<String>, Option<String>)> {
        vec![
            (
                "EMP-001".to_string(),
                Decimal::from(60_000),
                Some("CC-100".to_string()),
                Some("Engineering".to_string()),
            ),
            (
                "EMP-002".to_string(),
                Decimal::from(85_000),
                Some("CC-200".to_string()),
                Some("Finance".to_string()),
            ),
            (
                "EMP-003".to_string(),
                Decimal::from(120_000),
                None,
                Some("Sales".to_string()),
            ),
        ]
    }

    /// Config-only generation: run totals are consistent, every line item belongs to
    /// the run, and no country-pack labels are attached.
    #[test]
    fn test_basic_payroll_generation() {
        let mut gen = PayrollGenerator::new(42);
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();

        let (run, items) = gen.generate("C001", &employees, period_start, period_end, "USD");

        assert_eq!(run.company_code, "C001");
        assert_eq!(run.currency, "USD");
        assert_eq!(run.employee_count, 3);
        assert_eq!(items.len(), 3);
        assert!(run.total_gross > Decimal::ZERO);
        assert!(run.total_deductions > Decimal::ZERO);
        assert!(run.total_net > Decimal::ZERO);
        assert!(run.total_employer_cost > run.total_gross);
        // Net must reconcile exactly with gross minus deductions.
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);

        for item in &items {
            assert_eq!(item.payroll_id, run.payroll_id);
            assert!(item.gross_pay > Decimal::ZERO);
            assert!(item.net_pay > Decimal::ZERO);
            assert!(item.net_pay < item.gross_pay);
            assert!(item.base_salary > Decimal::ZERO);
            assert_eq!(item.pay_date, period_end);
            // Without a country pack no localized labels are produced.
            assert!(item.tax_withholding_label.is_none());
            assert!(item.social_security_label.is_none());
        }
    }

    /// Same seed and inputs must yield identical IDs, amounts and status.
    #[test]
    fn test_deterministic_payroll() {
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();

        let mut gen1 = PayrollGenerator::new(42);
        let (run1, items1) = gen1.generate("C001", &employees, period_start, period_end, "USD");

        let mut gen2 = PayrollGenerator::new(42);
        let (run2, items2) = gen2.generate("C001", &employees, period_start, period_end, "USD");

        assert_eq!(run1.payroll_id, run2.payroll_id);
        assert_eq!(run1.total_gross, run2.total_gross);
        assert_eq!(run1.total_net, run2.total_net);
        assert_eq!(run1.status, run2.status);
        assert_eq!(items1.len(), items2.len());
        for (a, b) in items1.iter().zip(items2.iter()) {
            assert_eq!(a.line_id, b.line_id);
            assert_eq!(a.gross_pay, b.gross_pay);
            assert_eq!(a.net_pay, b.net_pay);
        }
    }

    /// A single line item's net pay equals gross minus the sum of its five
    /// deduction components, and its base salary is the annual salary / 12.
    #[test]
    fn test_payroll_deduction_components() {
        let mut gen = PayrollGenerator::new(99);
        let employees = vec![(
            "EMP-010".to_string(),
            Decimal::from(100_000),
            Some("CC-300".to_string()),
            Some("HR".to_string()),
        )];
        let period_start = NaiveDate::from_ymd_opt(2024, 6, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();

        let (_run, items) = gen.generate("C001", &employees, period_start, period_end, "USD");
        assert_eq!(items.len(), 1);

        let item = &items[0];
        // Monthly base is annual / 12, rounded to cents.
        let expected_monthly = (Decimal::from(100_000) / Decimal::from(12)).round_dp(2);
        assert_eq!(item.base_salary, expected_monthly);

        let deduction_sum = item.tax_withholding
            + item.social_security
            + item.health_insurance
            + item.retirement_contribution
            + item.other_deductions;
        let expected_net = item.gross_pay - deduction_sum;
        assert_eq!(item.net_pay, expected_net);

        // Income tax and social security are always charged.
        assert!(item.tax_withholding > Decimal::ZERO);
        assert!(item.social_security > Decimal::ZERO);
    }

    /// Minimal US pack: flat FICA, bracket-driven (rate 0) federal income tax, and a
    /// flat 5 % state income tax. It carries no income-tax brackets.
    fn us_country_pack() -> CountryPack {
        use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
        CountryPack {
            country_code: "US".to_string(),
            payroll: PayrollCountryConfig {
                statutory_deductions: vec![
                    PayrollDeduction {
                        code: "FICA".to_string(),
                        name_en: "Federal Insurance Contributions Act".to_string(),
                        deduction_type: "percentage".to_string(),
                        rate: 0.0765,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "FIT".to_string(),
                        name_en: "Federal Income Tax".to_string(),
                        deduction_type: "progressive".to_string(),
                        rate: 0.0,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "SIT".to_string(),
                        name_en: "State Income Tax".to_string(),
                        deduction_type: "percentage".to_string(),
                        rate: 0.05,
                        ..Default::default()
                    },
                ],
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// German-style pack built with `type_field` rather than `deduction_type`. It has
    /// bracket-driven income tax, several flat social contributions (one marked
    /// optional) and two employer contributions.
    fn de_country_pack() -> CountryPack {
        use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
        CountryPack {
            country_code: "DE".to_string(),
            payroll: PayrollCountryConfig {
                pay_frequency: "monthly".to_string(),
                currency: "EUR".to_string(),
                statutory_deductions: vec![
                    PayrollDeduction {
                        code: "LOHNST".to_string(),
                        name_en: "Income Tax".to_string(),
                        type_field: "progressive".to_string(),
                        rate: 0.0,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "SOLI".to_string(),
                        name_en: "Solidarity Surcharge".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.055,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "KiSt".to_string(),
                        name_en: "Church Tax".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.08,
                        optional: true,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "RV".to_string(),
                        name_en: "Pension Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.093,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "KV".to_string(),
                        name_en: "Health Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.073,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "AV".to_string(),
                        name_en: "Unemployment Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.013,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "PV".to_string(),
                        name_en: "Long-Term Care Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.017,
                        ..Default::default()
                    },
                ],
                employer_contributions: vec![
                    PayrollDeduction {
                        code: "AG-RV".to_string(),
                        name_en: "Employer Pension Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.093,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "AG-KV".to_string(),
                        name_en: "Employer Health Insurance".to_string(),
                        type_field: "percentage".to_string(),
                        rate: 0.073,
                        ..Default::default()
                    },
                ],
                ..Default::default()
            },
            ..Default::default()
        }
    }

    /// Generation with the US pack reconciles totals and attaches tax and
    /// social-security labels.
    #[test]
    fn test_generate_with_us_country_pack() {
        let mut gen = PayrollGenerator::new(42);
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let pack = us_country_pack();

        let (run, items) = gen.generate_with_country_pack(
            "C001",
            &employees,
            period_start,
            period_end,
            "USD",
            &pack,
        );

        assert_eq!(run.company_code, "C001");
        assert_eq!(run.employee_count, 3);
        assert_eq!(items.len(), 3);
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);

        for item in &items {
            assert!(item.gross_pay > Decimal::ZERO);
            assert!(item.net_pay > Decimal::ZERO);
            assert!(item.net_pay < item.gross_pay);
            assert!(item.social_security > Decimal::ZERO);
            // Labels come from the pack's deduction names.
            assert!(item.tax_withholding_label.is_some());
            assert!(item.social_security_label.is_some());
        }
    }

    /// The DE pack's pension and health rates are taken verbatim, and its
    /// labels (including the employer contribution names) are attached.
    #[test]
    fn test_generate_with_de_country_pack() {
        let mut gen = PayrollGenerator::new(42);
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let pack = de_country_pack();

        let (run, items) = gen.generate_with_country_pack(
            "DE01",
            &employees,
            period_start,
            period_end,
            "EUR",
            &pack,
        );

        assert_eq!(run.company_code, "DE01");
        assert_eq!(items.len(), 3);
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);

        // Rates are compared via from_f64_retain to match how they are built.
        let rates = gen.rates_from_country_pack(&pack);
        assert_eq!(
            rates.retirement_rate,
            Decimal::from_f64_retain(0.093).unwrap()
        );
        assert_eq!(rates.health_rate, Decimal::from_f64_retain(0.073).unwrap());

        let item = &items[0];
        assert_eq!(
            item.health_insurance_label.as_deref(),
            Some("Health Insurance")
        );
        assert_eq!(
            item.retirement_contribution_label.as_deref(),
            Some("Pension Insurance")
        );
        assert!(item.employer_contribution_label.is_some());
        // Employer contribution label lists every employer contribution name.
        let ec = item.employer_contribution_label.as_ref().unwrap();
        assert!(ec.contains("Employer Pension Insurance"));
        assert!(ec.contains("Employer Health Insurance"));
    }

    /// An empty pack covers no category, so every rate must equal the config rate.
    #[test]
    fn test_country_pack_falls_back_to_config_for_missing_categories() {
        let pack = CountryPack::default();
        let gen = PayrollGenerator::new(42);
        let rates_pack = gen.rates_from_country_pack(&pack);
        let rates_cfg = gen.rates_from_config();

        assert_eq!(rates_pack.income_tax_rate, rates_cfg.income_tax_rate);
        assert_eq!(rates_pack.fica_rate, rates_cfg.fica_rate);
        assert_eq!(rates_pack.health_rate, rates_cfg.health_rate);
        assert_eq!(rates_pack.retirement_rate, rates_cfg.retirement_rate);
    }

    /// Country-pack generation is deterministic for a fixed seed.
    #[test]
    fn test_country_pack_deterministic() {
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
        let pack = de_country_pack();

        let mut gen1 = PayrollGenerator::new(42);
        let (run1, items1) = gen1.generate_with_country_pack(
            "DE01",
            &employees,
            period_start,
            period_end,
            "EUR",
            &pack,
        );

        let mut gen2 = PayrollGenerator::new(42);
        let (run2, items2) = gen2.generate_with_country_pack(
            "DE01",
            &employees,
            period_start,
            period_end,
            "EUR",
            &pack,
        );

        assert_eq!(run1.payroll_id, run2.payroll_id);
        assert_eq!(run1.total_gross, run2.total_gross);
        assert_eq!(run1.total_net, run2.total_net);
        for (a, b) in items1.iter().zip(items2.iter()) {
            assert_eq!(a.net_pay, b.net_pay);
        }
    }

    /// The DE pack's health and pension rates differ from the config defaults,
    /// proving the pack values are actually used.
    #[test]
    fn test_de_rates_differ_from_default() {
        let gen = PayrollGenerator::new(42);
        let pack = de_country_pack();
        let rates_cfg = gen.rates_from_config();
        let rates_de = gen.rates_from_country_pack(&pack);

        assert_ne!(rates_de.health_rate, rates_cfg.health_rate);
        assert_ne!(rates_de.retirement_rate, rates_cfg.retirement_rate);
    }

    /// A pack installed via `set_country_pack` is picked up by plain `generate`.
    #[test]
    fn test_set_country_pack_uses_labels() {
        let mut gen = PayrollGenerator::new(42);
        let pack = de_country_pack();
        gen.set_country_pack(pack);

        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();

        let (_run, items) = gen.generate("DE01", &employees, period_start, period_end, "EUR");

        let item = &items[0];
        assert!(item.tax_withholding_label.is_some());
        assert!(item.health_insurance_label.is_some());
        assert!(item.retirement_contribution_label.is_some());
        assert!(item.employer_contribution_label.is_some());
    }

    /// Hand-checked marginal-tax values for US-style brackets.
    #[test]
    fn test_compute_progressive_tax_us_brackets() {
        let brackets = vec![
            TaxBracket {
                above: Some(0.0),
                up_to: Some(11_000.0),
                rate: 0.10,
            },
            TaxBracket {
                above: Some(11_000.0),
                up_to: Some(44_725.0),
                rate: 0.12,
            },
            TaxBracket {
                above: Some(44_725.0),
                up_to: Some(95_375.0),
                rate: 0.22,
            },
            TaxBracket {
                above: Some(95_375.0),
                up_to: None,
                rate: 0.24,
            },
        ];

        // 11_000 × 10 % + 33_725 × 12 % + 15_275 × 22 % = 1_100 + 4_047 + 3_360.50.
        let tax = PayrollGenerator::compute_progressive_tax(Decimal::from(60_000), &brackets);
        assert_eq!(tax, Decimal::from_f64_retain(8507.50).unwrap());

        // Income exactly at the first ceiling is taxed only in the first bracket.
        let tax = PayrollGenerator::compute_progressive_tax(Decimal::from(11_000), &brackets);
        assert_eq!(tax, Decimal::from_f64_retain(1100.0).unwrap());
    }

    /// Zero income yields zero tax.
    #[test]
    fn test_progressive_tax_zero_income() {
        let brackets = vec![TaxBracket {
            above: Some(0.0),
            up_to: Some(10_000.0),
            rate: 0.10,
        }];
        let tax = PayrollGenerator::compute_progressive_tax(Decimal::ZERO, &brackets);
        assert_eq!(tax, Decimal::ZERO);
    }

    /// With brackets in the pack, a high earner's effective withholding rate must
    /// exceed a low earner's. Two identically configured generators are used so
    /// each employee sees the same seeded random stream.
    #[test]
    fn test_us_pack_employees_have_varying_rates() {
        use datasynth_core::country::schema::{
            CountryTaxConfig, PayrollCountryConfig, PayrollDeduction, PayrollTaxBracketsConfig,
        };

        let brackets = vec![
            TaxBracket {
                above: Some(0.0),
                up_to: Some(11_000.0),
                rate: 0.10,
            },
            TaxBracket {
                above: Some(11_000.0),
                up_to: Some(44_725.0),
                rate: 0.12,
            },
            TaxBracket {
                above: Some(44_725.0),
                up_to: None,
                rate: 0.22,
            },
        ];
        let pack = CountryPack {
            country_code: "US".to_string(),
            payroll: PayrollCountryConfig {
                statutory_deductions: vec![
                    PayrollDeduction {
                        code: "FIT".to_string(),
                        name_en: "Federal Income Tax".to_string(),
                        deduction_type: "progressive".to_string(),
                        rate: 0.0,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "FICA".to_string(),
                        name_en: "Social Security".to_string(),
                        deduction_type: "percentage".to_string(),
                        rate: 0.0765,
                        ..Default::default()
                    },
                ],
                ..Default::default()
            },
            tax: CountryTaxConfig {
                payroll_tax: PayrollTaxBracketsConfig {
                    income_tax_brackets: brackets,
                    ..Default::default()
                },
                ..Default::default()
            },
            ..Default::default()
        };

        let mut gen = PayrollGenerator::new(42);
        gen.set_country_pack(pack);

        let low_earner = vec![("LOW".to_string(), Decimal::from(30_000), None, None)];
        let high_earner = vec![("HIGH".to_string(), Decimal::from(200_000), None, None)];

        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();

        let (_, low_items) = gen.generate("C001", &low_earner, start, end, "USD");
        // Second generator with an identical pack, so RNG state matches the first.
        let mut gen2 = PayrollGenerator::new(42);
        gen2.set_country_pack(CountryPack {
            country_code: "US".to_string(),
            payroll: PayrollCountryConfig {
                statutory_deductions: vec![
                    PayrollDeduction {
                        code: "FIT".to_string(),
                        name_en: "Federal Income Tax".to_string(),
                        deduction_type: "progressive".to_string(),
                        rate: 0.0,
                        ..Default::default()
                    },
                    PayrollDeduction {
                        code: "FICA".to_string(),
                        name_en: "Social Security".to_string(),
                        deduction_type: "percentage".to_string(),
                        rate: 0.0765,
                        ..Default::default()
                    },
                ],
                ..Default::default()
            },
            tax: CountryTaxConfig {
                payroll_tax: PayrollTaxBracketsConfig {
                    income_tax_brackets: vec![
                        TaxBracket {
                            above: Some(0.0),
                            up_to: Some(11_000.0),
                            rate: 0.10,
                        },
                        TaxBracket {
                            above: Some(11_000.0),
                            up_to: Some(44_725.0),
                            rate: 0.12,
                        },
                        TaxBracket {
                            above: Some(44_725.0),
                            up_to: None,
                            rate: 0.22,
                        },
                    ],
                    ..Default::default()
                },
                ..Default::default()
            },
            ..Default::default()
        });
        let (_, high_items) = gen2.generate("C001", &high_earner, start, end, "USD");

        let low_eff = low_items[0].tax_withholding / low_items[0].gross_pay;
        let high_eff = high_items[0].tax_withholding / high_items[0].gross_pay;

        assert!(
            high_eff > low_eff,
            "High earner effective rate ({high_eff}) should exceed low earner ({low_eff})"
        );
    }

    /// A default (empty) pack produces no labels at all.
    #[test]
    fn test_empty_pack_labels_are_none() {
        let pack = CountryPack::default();
        let labels = PayrollGenerator::labels_from_country_pack(&pack);
        assert!(labels.tax_withholding.is_none());
        assert!(labels.social_security.is_none());
        assert!(labels.health_insurance.is_none());
        assert!(labels.retirement_contribution.is_none());
        assert!(labels.employer_contribution.is_none());
    }

    /// US pack: federal and state income-tax names share the withholding label,
    /// and FICA's long name becomes the social-security label.
    #[test]
    fn test_us_pack_labels() {
        let pack = us_country_pack();
        let labels = PayrollGenerator::labels_from_country_pack(&pack);
        assert!(labels.tax_withholding.is_some());
        let tw = labels.tax_withholding.unwrap();
        assert!(tw.contains("Federal Income Tax"));
        assert!(tw.contains("State Income Tax"));
        assert!(labels.social_security.is_some());
        assert!(labels
            .social_security
            .unwrap()
            .contains("Federal Insurance Contributions Act"));
    }
}
1374}