1use chrono::NaiveDate;
8use datasynth_config::schema::PayrollConfig;
9use datasynth_core::country::schema::TaxBracket;
10use datasynth_core::models::{PayrollLineItem, PayrollRun, PayrollRunStatus};
11use datasynth_core::utils::{sample_decimal_range, seeded_rng};
12use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
13use datasynth_core::CountryPack;
14use rand::prelude::*;
15use rand_chacha::ChaCha8Rng;
16use rust_decimal::Decimal;
17use tracing::debug;
18
19#[derive(Debug, Clone)]
21struct PayrollRates {
22 income_tax_rate: Decimal,
24 income_tax_brackets: Vec<TaxBracket>,
26 fica_rate: Decimal,
28 health_rate: Decimal,
30 retirement_rate: Decimal,
32 employer_fica_rate: Decimal,
34}
35
/// Optional display labels attached to each payroll line item.
///
/// Every field defaults to `None`, which is what config-only generation uses.
#[derive(Clone, Debug, Default)]
struct DeductionLabels {
    /// Label for the income-tax withholding amount.
    tax_withholding: Option<String>,
    /// Label for the social-security amount.
    social_security: Option<String>,
    /// Label for the health-insurance amount.
    health_insurance: Option<String>,
    /// Label for the retirement-contribution amount.
    retirement_contribution: Option<String>,
    /// Label describing the employer-side contributions.
    employer_contribution: Option<String>,
}
45
46pub struct PayrollGenerator {
48 rng: ChaCha8Rng,
49 uuid_factory: DeterministicUuidFactory,
50 line_uuid_factory: DeterministicUuidFactory,
51 config: PayrollConfig,
52 country_pack: Option<CountryPack>,
53 employee_ids_pool: Vec<String>,
55 cost_center_ids_pool: Vec<String>,
58}
59
60impl PayrollGenerator {
61 pub fn new(seed: u64) -> Self {
63 Self {
64 rng: seeded_rng(seed, 0),
65 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
66 line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
67 seed,
68 GeneratorType::PayrollRun,
69 1,
70 ),
71 config: PayrollConfig::default(),
72 country_pack: None,
73 employee_ids_pool: Vec::new(),
74 cost_center_ids_pool: Vec::new(),
75 }
76 }
77
78 pub fn with_config(seed: u64, config: PayrollConfig) -> Self {
80 Self {
81 rng: seeded_rng(seed, 0),
82 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
83 line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
84 seed,
85 GeneratorType::PayrollRun,
86 1,
87 ),
88 config,
89 country_pack: None,
90 employee_ids_pool: Vec::new(),
91 cost_center_ids_pool: Vec::new(),
92 }
93 }
94
95 pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
100 self.employee_ids_pool = employee_ids;
101 self.cost_center_ids_pool = cost_center_ids;
102 self
103 }
104
105 pub fn set_country_pack(&mut self, pack: CountryPack) {
113 self.country_pack = Some(pack);
114 }
115
116 pub fn generate(
131 &mut self,
132 company_code: &str,
133 employees: &[(String, Decimal, Option<String>, Option<String>)],
134 period_start: NaiveDate,
135 period_end: NaiveDate,
136 currency: &str,
137 ) -> (PayrollRun, Vec<PayrollLineItem>) {
138 debug!(company_code, employee_count = employees.len(), %period_start, %period_end, currency, "Generating payroll run");
139 if let Some(pack) = self.country_pack.as_ref() {
140 let rates = self.rates_from_country_pack(pack);
141 let labels = Self::labels_from_country_pack(pack);
142 self.generate_with_rates_and_labels(
143 company_code,
144 employees,
145 period_start,
146 period_end,
147 currency,
148 &rates,
149 &labels,
150 )
151 } else {
152 let rates = self.rates_from_config();
153 self.generate_with_rates_and_labels(
154 company_code,
155 employees,
156 period_start,
157 period_end,
158 currency,
159 &rates,
160 &DeductionLabels::default(),
161 )
162 }
163 }
164
165 pub fn generate_with_country_pack(
186 &mut self,
187 company_code: &str,
188 employees: &[(String, Decimal, Option<String>, Option<String>)],
189 period_start: NaiveDate,
190 period_end: NaiveDate,
191 currency: &str,
192 pack: &CountryPack,
193 ) -> (PayrollRun, Vec<PayrollLineItem>) {
194 let rates = self.rates_from_country_pack(pack);
195 let labels = Self::labels_from_country_pack(pack);
196 self.generate_with_rates_and_labels(
197 company_code,
198 employees,
199 period_start,
200 period_end,
201 currency,
202 &rates,
203 &labels,
204 )
205 }
206
207 fn rates_from_config(&self) -> PayrollRates {
213 let federal_rate = Decimal::from_f64_retain(self.config.tax_rates.federal_effective)
214 .unwrap_or(Decimal::ZERO);
215 let state_rate = Decimal::from_f64_retain(self.config.tax_rates.state_effective)
216 .unwrap_or(Decimal::ZERO);
217 let fica_rate =
218 Decimal::from_f64_retain(self.config.tax_rates.fica).unwrap_or(Decimal::ZERO);
219
220 PayrollRates {
221 income_tax_rate: federal_rate + state_rate,
222 income_tax_brackets: Vec::new(),
223 fica_rate,
224 health_rate: Decimal::from_f64_retain(0.03).unwrap_or(Decimal::ZERO),
225 retirement_rate: Decimal::from_f64_retain(0.05).unwrap_or(Decimal::ZERO),
226 employer_fica_rate: fica_rate,
227 }
228 }
229
230 fn compute_progressive_tax(annual_income: Decimal, brackets: &[TaxBracket]) -> Decimal {
236 let mut total_tax = Decimal::ZERO;
237 let mut taxed_up_to = Decimal::ZERO;
238
239 for bracket in brackets {
240 let bracket_floor = bracket
241 .above
242 .and_then(Decimal::from_f64_retain)
243 .unwrap_or(taxed_up_to);
244 let bracket_rate = Decimal::from_f64_retain(bracket.rate).unwrap_or(Decimal::ZERO);
245
246 if annual_income <= bracket_floor {
247 break;
248 }
249
250 let taxable_in_bracket = if let Some(ceiling) = bracket.up_to {
251 let ceiling = Decimal::from_f64_retain(ceiling).unwrap_or(Decimal::ZERO);
252 (annual_income.min(ceiling) - bracket_floor).max(Decimal::ZERO)
253 } else {
254 (annual_income - bracket_floor).max(Decimal::ZERO)
256 };
257
258 total_tax += (taxable_in_bracket * bracket_rate).round_dp(2);
259 taxed_up_to = bracket
260 .up_to
261 .and_then(Decimal::from_f64_retain)
262 .unwrap_or(annual_income);
263 }
264
265 total_tax.round_dp(2)
266 }
267
268 fn rates_from_country_pack(&self, pack: &CountryPack) -> PayrollRates {
271 let fallback = self.rates_from_config();
272
273 let mut federal_tax = Decimal::ZERO;
276 let mut state_tax = Decimal::ZERO;
277 let mut fica = Decimal::ZERO;
278 let mut health = Decimal::ZERO;
279 let mut retirement = Decimal::ZERO;
280
281 let mut found_federal = false;
283 let mut found_state = false;
284 let mut found_fica = false;
285 let mut found_health = false;
286 let mut found_retirement = false;
287
288 for ded in &pack.payroll.statutory_deductions {
289 let code_upper = ded.code.to_uppercase();
290 let name_en_lower = ded.name_en.to_lowercase();
291 let rate = Decimal::from_f64_retain(ded.rate).unwrap_or(Decimal::ZERO);
292
293 if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
298 && ded.rate == 0.0
299 {
300 if code_upper == "FIT"
301 || code_upper == "LOHNST"
302 || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
303 {
304 found_federal = true;
305 }
306 continue;
307 }
308
309 if code_upper == "FIT"
310 || code_upper == "LOHNST"
311 || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
312 {
313 federal_tax += rate;
314 found_federal = true;
315 } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
316 state_tax += rate;
317 found_state = true;
318 } else if code_upper == "FICA" || name_en_lower.contains("social security") {
319 fica += rate;
320 found_fica = true;
321 } else if name_en_lower.contains("health insurance") {
322 health += rate;
323 found_health = true;
324 } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
325 retirement += rate;
326 found_retirement = true;
327 } else {
328 fica += rate;
333 found_fica = true;
334 }
335 }
336
337 PayrollRates {
338 income_tax_rate: if found_federal || found_state {
339 let f = if found_federal {
340 federal_tax
341 } else {
342 fallback.income_tax_rate
343 - Decimal::from_f64_retain(self.config.tax_rates.state_effective)
344 .unwrap_or(Decimal::ZERO)
345 };
346 let s = if found_state {
347 state_tax
348 } else {
349 Decimal::from_f64_retain(self.config.tax_rates.state_effective)
350 .unwrap_or(Decimal::ZERO)
351 };
352 f + s
353 } else {
354 fallback.income_tax_rate
355 },
356 income_tax_brackets: pack.tax.payroll_tax.income_tax_brackets.clone(),
357 fica_rate: if found_fica { fica } else { fallback.fica_rate },
358 health_rate: if found_health {
359 health
360 } else {
361 fallback.health_rate
362 },
363 retirement_rate: if found_retirement {
364 retirement
365 } else {
366 fallback.retirement_rate
367 },
368 employer_fica_rate: if found_fica {
369 fica
370 } else {
371 fallback.employer_fica_rate
372 },
373 }
374 }
375
376 fn labels_from_country_pack(pack: &CountryPack) -> DeductionLabels {
383 let mut labels = DeductionLabels::default();
384
385 for ded in &pack.payroll.statutory_deductions {
386 let code_upper = ded.code.to_uppercase();
387 let name_en_lower = ded.name_en.to_lowercase();
388
389 let label = if ded.name.is_empty() {
392 ded.name_en.clone()
393 } else {
394 ded.name.clone()
395 };
396 if label.is_empty() {
397 continue;
398 }
399
400 if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
403 && ded.rate == 0.0
404 {
405 if code_upper == "FIT"
406 || code_upper == "LOHNST"
407 || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
408 {
409 if labels.tax_withholding.is_none() {
410 labels.tax_withholding = Some(label);
411 }
412 } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
413 labels.tax_withholding = Some(match labels.tax_withholding.take() {
414 Some(existing) => format!("{existing}; {label}"),
415 None => label,
416 });
417 }
418 continue;
419 }
420
421 if code_upper == "FIT"
422 || code_upper == "LOHNST"
423 || code_upper == "SIT"
424 || name_en_lower.contains("income tax")
425 || name_en_lower.contains("state income tax")
426 {
427 labels.tax_withholding = Some(match labels.tax_withholding.take() {
430 Some(existing) => format!("{existing}; {label}"),
431 None => label,
432 });
433 } else if code_upper == "FICA" || name_en_lower.contains("social security") {
434 labels.social_security = Some(match labels.social_security.take() {
435 Some(existing) => format!("{existing}; {label}"),
436 None => label,
437 });
438 } else if name_en_lower.contains("health insurance") {
439 if labels.health_insurance.is_none() {
440 labels.health_insurance = Some(label);
441 }
442 } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
443 if labels.retirement_contribution.is_none() {
444 labels.retirement_contribution = Some(label);
445 }
446 } else {
447 labels.social_security = Some(match labels.social_security.take() {
450 Some(existing) => format!("{existing}; {label}"),
451 None => label,
452 });
453 }
454 }
455
456 let emp_labels: Vec<String> = pack
458 .payroll
459 .employer_contributions
460 .iter()
461 .filter_map(|c| {
462 let l = if c.name.is_empty() {
463 c.name_en.clone()
464 } else {
465 c.name.clone()
466 };
467 if l.is_empty() {
468 None
469 } else {
470 Some(l)
471 }
472 })
473 .collect();
474 if !emp_labels.is_empty() {
475 labels.employer_contribution = Some(emp_labels.join("; "));
476 }
477
478 labels
479 }
480
481 fn generate_with_rates_and_labels(
483 &mut self,
484 company_code: &str,
485 employees: &[(String, Decimal, Option<String>, Option<String>)],
486 period_start: NaiveDate,
487 period_end: NaiveDate,
488 currency: &str,
489 rates: &PayrollRates,
490 labels: &DeductionLabels,
491 ) -> (PayrollRun, Vec<PayrollLineItem>) {
492 let payroll_id = self.uuid_factory.next().to_string();
493
494 let mut line_items = Vec::with_capacity(employees.len());
495 let mut total_gross = Decimal::ZERO;
496 let mut total_deductions = Decimal::ZERO;
497 let mut total_net = Decimal::ZERO;
498 let mut total_employer_cost = Decimal::ZERO;
499
500 let benefits_enrolled = self.config.benefits_enrollment_rate;
501 let retirement_participating = self.config.retirement_participation_rate;
502
503 for (employee_id, base_salary, cost_center, department) in employees {
504 let line_id = self.line_uuid_factory.next().to_string();
505
506 let monthly_base = (*base_salary / Decimal::from(12)).round_dp(2);
508
509 let (overtime_pay, overtime_hours) = if self.rng.random_bool(0.10) {
511 let ot_hours = self.rng.random_range(1.0..=20.0);
512 let hourly_rate = *base_salary / Decimal::from(2080);
514 let ot_rate = hourly_rate * Decimal::from_f64_retain(1.5).unwrap_or(Decimal::ONE);
515 let ot_pay = (ot_rate
516 * Decimal::from_f64_retain(ot_hours).unwrap_or(Decimal::ZERO))
517 .round_dp(2);
518 (ot_pay, ot_hours)
519 } else {
520 (Decimal::ZERO, 0.0)
521 };
522
523 let bonus = if self.rng.random_bool(0.05) {
525 let pct = self.rng.random_range(0.01..=0.10);
526 (monthly_base * Decimal::from_f64_retain(pct).unwrap_or(Decimal::ZERO)).round_dp(2)
527 } else {
528 Decimal::ZERO
529 };
530
531 let gross_pay = monthly_base + overtime_pay + bonus;
532
533 let tax_withholding = if !rates.income_tax_brackets.is_empty() {
535 let annual = gross_pay * Decimal::from(12);
536 Self::compute_progressive_tax(annual, &rates.income_tax_brackets)
537 / Decimal::from(12)
538 } else {
539 (gross_pay * rates.income_tax_rate).round_dp(2)
540 };
541 let social_security = (gross_pay * rates.fica_rate).round_dp(2);
542
543 let health_insurance = if self.rng.random_bool(benefits_enrolled) {
544 (gross_pay * rates.health_rate).round_dp(2)
545 } else {
546 Decimal::ZERO
547 };
548
549 let retirement_contribution = if self.rng.random_bool(retirement_participating) {
550 (gross_pay * rates.retirement_rate).round_dp(2)
551 } else {
552 Decimal::ZERO
553 };
554
555 let other_deductions = if self.rng.random_bool(0.03) {
557 sample_decimal_range(&mut self.rng, Decimal::from(50), Decimal::from(500))
558 .round_dp(2)
559 } else {
560 Decimal::ZERO
561 };
562
563 let total_ded = tax_withholding
564 + social_security
565 + health_insurance
566 + retirement_contribution
567 + other_deductions;
568 let net_pay = gross_pay - total_ded;
569
570 let hours_worked = 160.0;
572
573 let employer_contrib = (gross_pay * rates.employer_fica_rate).round_dp(2);
575 let employer_cost = gross_pay + employer_contrib;
576
577 total_gross += gross_pay;
578 total_deductions += total_ded;
579 total_net += net_pay;
580 total_employer_cost += employer_cost;
581
582 line_items.push(PayrollLineItem {
583 payroll_id: payroll_id.clone(),
584 employee_id: employee_id.clone(),
585 line_id,
586 gross_pay,
587 base_salary: monthly_base,
588 overtime_pay,
589 bonus,
590 tax_withholding,
591 social_security,
592 health_insurance,
593 retirement_contribution,
594 other_deductions,
595 net_pay,
596 hours_worked,
597 overtime_hours,
598 pay_date: period_end,
599 cost_center: cost_center.clone(),
600 department: department.clone(),
601 tax_withholding_label: labels.tax_withholding.clone(),
602 social_security_label: labels.social_security.clone(),
603 health_insurance_label: labels.health_insurance.clone(),
604 retirement_contribution_label: labels.retirement_contribution.clone(),
605 employer_contribution_label: labels.employer_contribution.clone(),
606 });
607 }
608
609 let status_roll: f64 = self.rng.random();
611 let status = if status_roll < 0.60 {
612 PayrollRunStatus::Posted
613 } else if status_roll < 0.85 {
614 PayrollRunStatus::Approved
615 } else if status_roll < 0.95 {
616 PayrollRunStatus::Calculated
617 } else {
618 PayrollRunStatus::Draft
619 };
620
621 let approved_by = if matches!(
622 status,
623 PayrollRunStatus::Approved | PayrollRunStatus::Posted
624 ) {
625 if !self.employee_ids_pool.is_empty() {
626 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
627 Some(self.employee_ids_pool[idx].clone())
628 } else {
629 Some(format!("USR-{:04}", self.rng.random_range(201..=400)))
630 }
631 } else {
632 None
633 };
634
635 let posted_by = if status == PayrollRunStatus::Posted {
636 if !self.employee_ids_pool.is_empty() {
637 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
638 Some(self.employee_ids_pool[idx].clone())
639 } else {
640 Some(format!("USR-{:04}", self.rng.random_range(401..=500)))
641 }
642 } else {
643 None
644 };
645
646 let run = PayrollRun {
647 company_code: company_code.to_string(),
648 payroll_id: payroll_id.clone(),
649 pay_period_start: period_start,
650 pay_period_end: period_end,
651 run_date: period_end,
652 status,
653 total_gross,
654 total_deductions,
655 total_net,
656 total_employer_cost,
657 employee_count: employees.len() as u32,
658 currency: currency.to_string(),
659 posted_by,
660 approved_by,
661 };
662
663 (run, line_items)
664 }
665}
666
667#[cfg(test)]
668#[allow(clippy::unwrap_used)]
669mod tests {
670 use super::*;
671
672 fn test_employees() -> Vec<(String, Decimal, Option<String>, Option<String>)> {
673 vec![
674 (
675 "EMP-001".to_string(),
676 Decimal::from(60_000),
677 Some("CC-100".to_string()),
678 Some("Engineering".to_string()),
679 ),
680 (
681 "EMP-002".to_string(),
682 Decimal::from(85_000),
683 Some("CC-200".to_string()),
684 Some("Finance".to_string()),
685 ),
686 (
687 "EMP-003".to_string(),
688 Decimal::from(120_000),
689 None,
690 Some("Sales".to_string()),
691 ),
692 ]
693 }
694
/// Config-only generation yields consistent totals, sane line items and no labels.
#[test]
fn test_basic_payroll_generation() {
    let staff = test_employees();
    let first_day = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let last_day = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();

    let mut generator = PayrollGenerator::new(42);
    let (run, items) = generator.generate("C001", &staff, first_day, last_day, "USD");

    // Run header.
    assert_eq!(run.company_code, "C001");
    assert_eq!(run.currency, "USD");
    assert_eq!(run.employee_count, 3);
    assert_eq!(items.len(), 3);

    // Run totals.
    for total in [run.total_gross, run.total_deductions, run.total_net] {
        assert!(total > Decimal::ZERO);
    }
    assert!(run.total_employer_cost > run.total_gross);
    assert_eq!(run.total_net, run.total_gross - run.total_deductions);

    // Line items.
    for line in &items {
        assert_eq!(line.payroll_id, run.payroll_id);
        assert!(line.gross_pay > Decimal::ZERO);
        assert!(line.net_pay > Decimal::ZERO);
        assert!(line.net_pay < line.gross_pay);
        assert!(line.base_salary > Decimal::ZERO);
        assert_eq!(line.pay_date, last_day);
        // Without a country pack no labels are attached.
        assert!(line.tax_withholding_label.is_none());
        assert!(line.social_security_label.is_none());
    }
}
727
/// Two generators built from the same seed produce identical runs and line items.
#[test]
fn test_deterministic_payroll() {
    let staff = test_employees();
    let first_day = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
    let last_day = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();

    // Each invocation starts from a fresh generator with seed 42.
    let run_once = || {
        let mut generator = PayrollGenerator::new(42);
        generator.generate("C001", &staff, first_day, last_day, "USD")
    };

    let (run1, items1) = run_once();
    let (run2, items2) = run_once();

    assert_eq!(run1.payroll_id, run2.payroll_id);
    assert_eq!(run1.total_gross, run2.total_gross);
    assert_eq!(run1.total_net, run2.total_net);
    assert_eq!(run1.status, run2.status);
    assert_eq!(items1.len(), items2.len());
    for (first, second) in items1.iter().zip(items2.iter()) {
        assert_eq!(first.line_id, second.line_id);
        assert_eq!(first.gross_pay, second.gross_pay);
        assert_eq!(first.net_pay, second.net_pay);
    }
}
751
/// Net pay equals gross minus the five deduction components; the base is salary / 12.
#[test]
fn test_payroll_deduction_components() {
    let annual_salary = Decimal::from(100_000);
    let staff = vec![(
        "EMP-010".to_string(),
        annual_salary,
        Some("CC-300".to_string()),
        Some("HR".to_string()),
    )];
    let first_day = NaiveDate::from_ymd_opt(2024, 6, 1).unwrap();
    let last_day = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();

    let mut generator = PayrollGenerator::new(99);
    let (_run, items) = generator.generate("C001", &staff, first_day, last_day, "USD");
    assert_eq!(items.len(), 1);
    let line = &items[0];

    // Monthly base is the annual salary split over twelve months, rounded to cents.
    assert_eq!(
        line.base_salary,
        (annual_salary / Decimal::from(12)).round_dp(2)
    );

    let all_deductions = line.tax_withholding
        + line.social_security
        + line.health_insurance
        + line.retirement_contribution
        + line.other_deductions;
    assert_eq!(line.net_pay, line.gross_pay - all_deductions);

    // Tax and social security always apply under the default config.
    assert!(line.tax_withholding > Decimal::ZERO);
    assert!(line.social_security > Decimal::ZERO);
}
785
786 fn us_country_pack() -> CountryPack {
792 use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
793 CountryPack {
794 country_code: "US".to_string(),
795 payroll: PayrollCountryConfig {
796 statutory_deductions: vec![
797 PayrollDeduction {
798 code: "FICA".to_string(),
799 name_en: "Federal Insurance Contributions Act".to_string(),
800 deduction_type: "percentage".to_string(),
801 rate: 0.0765,
802 ..Default::default()
803 },
804 PayrollDeduction {
805 code: "FIT".to_string(),
806 name_en: "Federal Income Tax".to_string(),
807 deduction_type: "progressive".to_string(),
808 rate: 0.0, ..Default::default()
810 },
811 PayrollDeduction {
812 code: "SIT".to_string(),
813 name_en: "State Income Tax".to_string(),
814 deduction_type: "percentage".to_string(),
815 rate: 0.05,
816 ..Default::default()
817 },
818 ],
819 ..Default::default()
820 },
821 ..Default::default()
822 }
823 }
824
825 fn de_country_pack() -> CountryPack {
827 use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
828 CountryPack {
829 country_code: "DE".to_string(),
830 payroll: PayrollCountryConfig {
831 pay_frequency: "monthly".to_string(),
832 currency: "EUR".to_string(),
833 statutory_deductions: vec![
834 PayrollDeduction {
835 code: "LOHNST".to_string(),
836 name_en: "Income Tax".to_string(),
837 type_field: "progressive".to_string(),
838 rate: 0.0, ..Default::default()
840 },
841 PayrollDeduction {
842 code: "SOLI".to_string(),
843 name_en: "Solidarity Surcharge".to_string(),
844 type_field: "percentage".to_string(),
845 rate: 0.055,
846 ..Default::default()
847 },
848 PayrollDeduction {
849 code: "KiSt".to_string(),
850 name_en: "Church Tax".to_string(),
851 type_field: "percentage".to_string(),
852 rate: 0.08,
853 optional: true,
854 ..Default::default()
855 },
856 PayrollDeduction {
857 code: "RV".to_string(),
858 name_en: "Pension Insurance".to_string(),
859 type_field: "percentage".to_string(),
860 rate: 0.093,
861 ..Default::default()
862 },
863 PayrollDeduction {
864 code: "KV".to_string(),
865 name_en: "Health Insurance".to_string(),
866 type_field: "percentage".to_string(),
867 rate: 0.073,
868 ..Default::default()
869 },
870 PayrollDeduction {
871 code: "AV".to_string(),
872 name_en: "Unemployment Insurance".to_string(),
873 type_field: "percentage".to_string(),
874 rate: 0.013,
875 ..Default::default()
876 },
877 PayrollDeduction {
878 code: "PV".to_string(),
879 name_en: "Long-Term Care Insurance".to_string(),
880 type_field: "percentage".to_string(),
881 rate: 0.017,
882 ..Default::default()
883 },
884 ],
885 employer_contributions: vec![
886 PayrollDeduction {
887 code: "AG-RV".to_string(),
888 name_en: "Employer Pension Insurance".to_string(),
889 type_field: "percentage".to_string(),
890 rate: 0.093,
891 ..Default::default()
892 },
893 PayrollDeduction {
894 code: "AG-KV".to_string(),
895 name_en: "Employer Health Insurance".to_string(),
896 type_field: "percentage".to_string(),
897 rate: 0.073,
898 ..Default::default()
899 },
900 ],
901 ..Default::default()
902 },
903 ..Default::default()
904 }
905 }
906
/// Generating with the US pack keeps totals consistent and attaches labels.
#[test]
fn test_generate_with_us_country_pack() {
    let staff = test_employees();
    let first_day = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let last_day = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
    let us_pack = us_country_pack();

    let mut generator = PayrollGenerator::new(42);
    let (run, items) =
        generator.generate_with_country_pack("C001", &staff, first_day, last_day, "USD", &us_pack);

    assert_eq!(run.company_code, "C001");
    assert_eq!(run.employee_count, 3);
    assert_eq!(items.len(), 3);
    assert_eq!(run.total_net, run.total_gross - run.total_deductions);

    for line in &items {
        assert!(line.gross_pay > Decimal::ZERO);
        assert!(line.net_pay > Decimal::ZERO);
        assert!(line.net_pay < line.gross_pay);
        assert!(line.social_security > Decimal::ZERO);
        // Labels come from the pack's deduction names.
        assert!(line.tax_withholding_label.is_some());
        assert!(line.social_security_label.is_some());
    }
}
940
/// Generating with the German pack uses its pension/health rates and labels.
#[test]
fn test_generate_with_de_country_pack() {
    let staff = test_employees();
    let first_day = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let last_day = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
    let de_pack = de_country_pack();

    let mut generator = PayrollGenerator::new(42);
    let (run, items) =
        generator.generate_with_country_pack("DE01", &staff, first_day, last_day, "EUR", &de_pack);

    assert_eq!(run.company_code, "DE01");
    assert_eq!(items.len(), 3);
    assert_eq!(run.total_net, run.total_gross - run.total_deductions);

    // Pension maps to the retirement rate, health insurance to the health rate.
    let resolved = generator.rates_from_country_pack(&de_pack);
    assert_eq!(
        resolved.retirement_rate,
        Decimal::from_f64_retain(0.093).unwrap()
    );
    assert_eq!(
        resolved.health_rate,
        Decimal::from_f64_retain(0.073).unwrap()
    );

    let first_line = &items[0];
    assert_eq!(
        first_line.health_insurance_label.as_deref(),
        Some("Health Insurance")
    );
    assert_eq!(
        first_line.retirement_contribution_label.as_deref(),
        Some("Pension Insurance")
    );

    // Both employer contributions appear in the combined employer label.
    assert!(first_line.employer_contribution_label.is_some());
    let employer_label = first_line.employer_contribution_label.as_ref().unwrap();
    assert!(employer_label.contains("Employer Pension Insurance"));
    assert!(employer_label.contains("Employer Health Insurance"));
}
987
/// An empty (default) pack resolves to exactly the config-derived rates.
#[test]
fn test_country_pack_falls_back_to_config_for_missing_categories() {
    let generator = PayrollGenerator::new(42);
    let empty_pack = CountryPack::default();

    let from_pack = generator.rates_from_country_pack(&empty_pack);
    let from_config = generator.rates_from_config();

    assert_eq!(from_pack.income_tax_rate, from_config.income_tax_rate);
    assert_eq!(from_pack.fica_rate, from_config.fica_rate);
    assert_eq!(from_pack.health_rate, from_config.health_rate);
    assert_eq!(from_pack.retirement_rate, from_config.retirement_rate);
}
1001
/// Country-pack generation is reproducible for a fixed seed.
#[test]
fn test_country_pack_deterministic() {
    let staff = test_employees();
    let first_day = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
    let last_day = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
    let de_pack = de_country_pack();

    // Each invocation starts from a fresh generator with seed 42.
    let run_once = || {
        let mut generator = PayrollGenerator::new(42);
        generator.generate_with_country_pack("DE01", &staff, first_day, last_day, "EUR", &de_pack)
    };

    let (run1, items1) = run_once();
    let (run2, items2) = run_once();

    assert_eq!(run1.payroll_id, run2.payroll_id);
    assert_eq!(run1.total_gross, run2.total_gross);
    assert_eq!(run1.total_net, run2.total_net);
    for (first, second) in items1.iter().zip(items2.iter()) {
        assert_eq!(first.net_pay, second.net_pay);
    }
}
1036
/// The German pack's health and retirement rates differ from the config defaults.
#[test]
fn test_de_rates_differ_from_default() {
    let generator = PayrollGenerator::new(42);
    let de_pack = de_country_pack();

    let from_config = generator.rates_from_config();
    let from_pack = generator.rates_from_country_pack(&de_pack);

    assert_ne!(from_pack.health_rate, from_config.health_rate);
    assert_ne!(from_pack.retirement_rate, from_config.retirement_rate);
}
1051
/// A pack stored via `set_country_pack` makes plain `generate` attach labels.
#[test]
fn test_set_country_pack_uses_labels() {
    let mut generator = PayrollGenerator::new(42);
    generator.set_country_pack(de_country_pack());

    let staff = test_employees();
    let first_day = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let last_day = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();

    let (_run, items) = generator.generate("DE01", &staff, first_day, last_day, "EUR");

    let first_line = &items[0];
    assert!(first_line.tax_withholding_label.is_some());
    assert!(first_line.health_insurance_label.is_some());
    assert!(first_line.retirement_contribution_label.is_some());
    assert!(first_line.employer_contribution_label.is_some());
}
1072
/// Marginal tax over four brackets matches hand-computed totals.
#[test]
fn test_compute_progressive_tax_us_brackets() {
    // (above, up_to, rate); the last bracket is open-ended.
    let brackets: Vec<TaxBracket> = vec![
        (0.0, Some(11_000.0), 0.10),
        (11_000.0, Some(44_725.0), 0.12),
        (44_725.0, Some(95_375.0), 0.22),
        (95_375.0, None, 0.24),
    ]
    .into_iter()
    .map(|(above, up_to, rate)| TaxBracket {
        above: Some(above),
        up_to,
        rate,
    })
    .collect();

    // 60,000 => 11,000 * 10% + 33,725 * 12% + 15,275 * 22% = 8,507.50
    let tax_on_60k = PayrollGenerator::compute_progressive_tax(Decimal::from(60_000), &brackets);
    assert_eq!(tax_on_60k, Decimal::from_f64_retain(8507.50).unwrap());

    // Income exactly at the first ceiling is taxed entirely at 10%.
    let tax_on_11k = PayrollGenerator::compute_progressive_tax(Decimal::from(11_000), &brackets);
    assert_eq!(tax_on_11k, Decimal::from_f64_retain(1100.0).unwrap());
}
1111
/// Zero income produces zero tax.
#[test]
fn test_progressive_tax_zero_income() {
    let single_bracket = [TaxBracket {
        above: Some(0.0),
        up_to: Some(10_000.0),
        rate: 0.10,
    }];

    let tax = PayrollGenerator::compute_progressive_tax(Decimal::ZERO, &single_bracket);

    assert_eq!(tax, Decimal::ZERO);
}
1122
/// With progressive brackets, a high earner's effective tax rate exceeds a low earner's.
#[test]
fn test_us_pack_employees_have_varying_rates() {
    use datasynth_core::country::schema::{
        CountryTaxConfig, PayrollCountryConfig, PayrollDeduction, PayrollTaxBracketsConfig,
    };

    // Builds the same US pack for each generator: a zero-rate progressive FIT entry,
    // FICA at 7.65%, and three income-tax brackets (the last one open-ended).
    let build_pack = || {
        let deduction = |code: &str, name_en: &str, kind: &str, rate: f64| PayrollDeduction {
            code: code.to_string(),
            name_en: name_en.to_string(),
            deduction_type: kind.to_string(),
            rate,
            ..Default::default()
        };

        CountryPack {
            country_code: "US".to_string(),
            payroll: PayrollCountryConfig {
                statutory_deductions: vec![
                    deduction("FIT", "Federal Income Tax", "progressive", 0.0),
                    deduction("FICA", "Social Security", "percentage", 0.0765),
                ],
                ..Default::default()
            },
            tax: CountryTaxConfig {
                payroll_tax: PayrollTaxBracketsConfig {
                    income_tax_brackets: vec![
                        TaxBracket {
                            above: Some(0.0),
                            up_to: Some(11_000.0),
                            rate: 0.10,
                        },
                        TaxBracket {
                            above: Some(11_000.0),
                            up_to: Some(44_725.0),
                            rate: 0.12,
                        },
                        TaxBracket {
                            above: Some(44_725.0),
                            up_to: None,
                            rate: 0.22,
                        },
                    ],
                    ..Default::default()
                },
                ..Default::default()
            },
            ..Default::default()
        }
    };

    let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();

    // Runs a single-employee payroll on a fresh seed-42 generator and returns
    // withholding as a share of gross pay.
    let effective_rate = |employee_id: &str, annual_salary: i32| {
        let mut generator = PayrollGenerator::new(42);
        generator.set_country_pack(build_pack());
        let staff = vec![(
            employee_id.to_string(),
            Decimal::from(annual_salary),
            None,
            None,
        )];
        let (_, items) = generator.generate("C001", &staff, start, end, "USD");
        items[0].tax_withholding / items[0].gross_pay
    };

    let low_eff = effective_rate("LOW", 30_000);
    let high_eff = effective_rate("HIGH", 200_000);

    assert!(
        high_eff > low_eff,
        "High earner effective rate ({high_eff}) should exceed low earner ({low_eff})"
    );
}
1246
/// A default pack has no deductions, so every label stays unset.
#[test]
fn test_empty_pack_labels_are_none() {
    let DeductionLabels {
        tax_withholding,
        social_security,
        health_insurance,
        retirement_contribution,
        employer_contribution,
    } = PayrollGenerator::labels_from_country_pack(&CountryPack::default());

    assert!(tax_withholding.is_none());
    assert!(social_security.is_none());
    assert!(health_insurance.is_none());
    assert!(retirement_contribution.is_none());
    assert!(employer_contribution.is_none());
}
1257
/// The US pack yields a combined income-tax label and a FICA social-security label.
#[test]
fn test_us_pack_labels() {
    let labels = PayrollGenerator::labels_from_country_pack(&us_country_pack());

    // Federal and state income tax share the withholding label.
    assert!(labels.tax_withholding.is_some());
    let withholding = labels.tax_withholding.unwrap();
    assert!(withholding.contains("Federal Income Tax"));
    assert!(withholding.contains("State Income Tax"));

    assert!(labels.social_security.is_some());
    let social_security = labels.social_security.unwrap();
    assert!(social_security.contains("Federal Insurance Contributions Act"));
}
1274}