Skip to main content

datasynth_generators/hr/
payroll_generator.rs

1//! Payroll generator for the Hire-to-Retire (H2R) process.
2//!
3//! Generates payroll runs with individual employee line items, computing
4//! gross pay (base salary + overtime + bonus), deductions (tax, social security,
5//! health insurance, retirement), and net pay.
6
7use chrono::NaiveDate;
8use datasynth_config::schema::PayrollConfig;
9use datasynth_core::country::schema::TaxBracket;
10use datasynth_core::models::{PayrollLineItem, PayrollRun, PayrollRunStatus};
11use datasynth_core::utils::{sample_decimal_range, seeded_rng};
12use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
13use datasynth_core::CountryPack;
14use rand::prelude::*;
15use rand_chacha::ChaCha8Rng;
16use rust_decimal::Decimal;
17use tracing::debug;
18
/// Resolved payroll deduction rates used during generation.
///
/// Built either from the generator's [`PayrollConfig`] or from a
/// [`CountryPack`]. Every rate is a fraction that is multiplied by a line
/// item's gross pay for the period.
#[derive(Debug, Clone)]
struct PayrollRates {
    /// Combined income tax rate (federal + state, or equivalent).
    ///
    /// Applied as a flat rate to gross pay only when `income_tax_brackets`
    /// is empty.
    income_tax_rate: Decimal,
    /// Progressive income tax brackets from country pack (if available).
    ///
    /// When non-empty, income tax withholding is computed from the
    /// annualised gross (`gross * 12`) through these brackets and divided
    /// back by 12, and `income_tax_rate` is not used. Always empty when the
    /// rates come from config.
    income_tax_brackets: Vec<TaxBracket>,
    /// Social security / FICA rate.
    ///
    /// When built from a country pack, this also absorbs statutory
    /// deductions that match no other category.
    fica_rate: Decimal,
    /// Employee health insurance rate; only charged to employees sampled as
    /// benefits-enrolled.
    health_rate: Decimal,
    /// Employee retirement / pension rate; only charged to employees sampled
    /// as retirement participants.
    retirement_rate: Decimal,
    /// Employer-side social security matching rate.
    ///
    /// Added on top of gross pay to form the employer cost. Both
    /// construction paths currently set it equal to `fica_rate`.
    employer_fica_rate: Decimal,
}
35
/// Country-pack-derived deduction labels applied to every line item in a run.
///
/// Each field is copied onto the matching `*_label` field of every generated
/// [`PayrollLineItem`]. All fields are `None` when no country pack is in
/// use (the `Default` value). A label prefers the deduction's localized
/// `name` and falls back to its English `name_en`.
#[derive(Debug, Clone, Default)]
struct DeductionLabels {
    /// Income tax deduction names (federal / state or equivalent), joined
    /// with `"; "` when several pack entries match.
    tax_withholding: Option<String>,
    /// Social security deduction names plus any statutory deduction that
    /// matches no other category, joined with `"; "`.
    social_security: Option<String>,
    /// Name of the first matching health insurance deduction.
    health_insurance: Option<String>,
    /// Name of the first matching pension / retirement deduction.
    retirement_contribution: Option<String>,
    /// All employer contribution names from the pack, joined with `"; "`.
    employer_contribution: Option<String>,
}
45
/// Generates [`PayrollRun`] and [`PayrollLineItem`] records from employee data.
///
/// All randomness comes from a seeded RNG and deterministic UUID factories.
/// Two generators built with the same seed and fed the same inputs therefore
/// produce identical runs and line items.
pub struct PayrollGenerator {
    /// Seeded RNG driving overtime, bonus, enrollment, other-deduction,
    /// status and approver/poster sampling.
    rng: ChaCha8Rng,
    /// Deterministic UUID source for each run's `payroll_id`.
    uuid_factory: DeterministicUuidFactory,
    /// Deterministic UUID source for each line item's `line_id` (built with
    /// sub-discriminator `1`, i.e. a separate stream from `uuid_factory`).
    line_uuid_factory: DeterministicUuidFactory,
    /// Payroll configuration: tax rates and benefits / retirement
    /// participation probabilities.
    config: PayrollConfig,
    /// Optional country pack; when set, `generate` resolves deduction rates
    /// and localized labels from it.
    country_pack: Option<CountryPack>,
    /// Pool of real employee IDs for approved_by / posted_by references.
    employee_ids_pool: Vec<String>,
    /// Pool of real cost center IDs (unused directly here since cost_center
    /// comes from the employee tuple, but kept for consistency).
    cost_center_ids_pool: Vec<String>,
}
59
60impl PayrollGenerator {
61    /// Create a new payroll generator with default configuration.
62    pub fn new(seed: u64) -> Self {
63        Self {
64            rng: seeded_rng(seed, 0),
65            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
66            line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
67                seed,
68                GeneratorType::PayrollRun,
69                1,
70            ),
71            config: PayrollConfig::default(),
72            country_pack: None,
73            employee_ids_pool: Vec::new(),
74            cost_center_ids_pool: Vec::new(),
75        }
76    }
77
78    /// Create a payroll generator with custom configuration.
79    pub fn with_config(seed: u64, config: PayrollConfig) -> Self {
80        Self {
81            rng: seeded_rng(seed, 0),
82            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
83            line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
84                seed,
85                GeneratorType::PayrollRun,
86                1,
87            ),
88            config,
89            country_pack: None,
90            employee_ids_pool: Vec::new(),
91            cost_center_ids_pool: Vec::new(),
92        }
93    }
94
95    /// Set ID pools for cross-reference coherence.
96    ///
97    /// When pools are non-empty, the generator selects `approved_by` and
98    /// `posted_by` from `employee_ids` instead of fabricating placeholder IDs.
99    pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
100        self.employee_ids_pool = employee_ids;
101        self.cost_center_ids_pool = cost_center_ids;
102        self
103    }
104
105    /// Set the country pack for localized deduction labels.
106    ///
107    /// When a country pack is set, generated [`PayrollLineItem`] records will
108    /// carry localized deduction labels derived from the pack's
109    /// `payroll.statutory_deductions` and `payroll.employer_contributions`.
110    /// The stored pack is also used by [`generate`] to resolve deduction rates,
111    /// so callers no longer need to pass the pack explicitly.
112    pub fn set_country_pack(&mut self, pack: CountryPack) {
113        self.country_pack = Some(pack);
114    }
115
116    /// Generate a payroll run and line items for the given employees and period.
117    ///
118    /// Uses tax rates from the [`PayrollConfig`] (defaults: 22% federal, 5% state,
119    /// 7.65% FICA, 3% health, 5% retirement).  If a country pack has been set via
120    /// [`set_country_pack`], the stored pack is used to resolve both rates and
121    /// localized deduction labels automatically.
122    ///
123    /// # Arguments
124    ///
125    /// * `company_code` - Company code owning the payroll
126    /// * `employees` - Slice of (employee_id, base_salary, cost_center, department)
127    /// * `period_start` - Start of the pay period (inclusive)
128    /// * `period_end` - End of the pay period (inclusive)
129    /// * `currency` - ISO 4217 currency code
130    pub fn generate(
131        &mut self,
132        company_code: &str,
133        employees: &[(String, Decimal, Option<String>, Option<String>)],
134        period_start: NaiveDate,
135        period_end: NaiveDate,
136        currency: &str,
137    ) -> (PayrollRun, Vec<PayrollLineItem>) {
138        debug!(company_code, employee_count = employees.len(), %period_start, %period_end, currency, "Generating payroll run");
139        if let Some(pack) = self.country_pack.as_ref() {
140            let rates = self.rates_from_country_pack(pack);
141            let labels = Self::labels_from_country_pack(pack);
142            self.generate_with_rates_and_labels(
143                company_code,
144                employees,
145                period_start,
146                period_end,
147                currency,
148                &rates,
149                &labels,
150            )
151        } else {
152            let rates = self.rates_from_config();
153            self.generate_with_rates_and_labels(
154                company_code,
155                employees,
156                period_start,
157                period_end,
158                currency,
159                &rates,
160                &DeductionLabels::default(),
161            )
162        }
163    }
164
165    /// Generate a payroll run using statutory deduction rates from a country pack.
166    ///
167    /// Iterates over `pack.payroll.statutory_deductions` to resolve rates by
168    /// deduction code / English name.  Any rate not found in the pack falls back
169    /// to the corresponding value from the generator's [`PayrollConfig`].
170    ///
171    /// # Deduction mapping
172    ///
173    /// | Pack code / `name_en` pattern              | Resolves to         |
174    /// |--------------------------------------------|---------------------|
175    /// | `FIT`, `LOHNST`, or `*Income Tax*` (not state) | federal income tax  |
176    /// | `SIT` or `*State Income Tax*`              | state income tax    |
177    /// | `FICA` or `*Social Security*`              | FICA / social security |
178    /// | `*Health Insurance*`                       | health insurance    |
179    /// | `*Pension*` or `*Retirement*`              | retirement / pension |
180    ///
181    /// For packs that have many small deductions (e.g. DE with pension, health,
182    /// unemployment, long-term care, solidarity surcharge, church tax), the rates
183    /// are summed into the closest category. Deductions not matching any category
184    /// above are accumulated into the FICA/social-security bucket.
185    pub fn generate_with_country_pack(
186        &mut self,
187        company_code: &str,
188        employees: &[(String, Decimal, Option<String>, Option<String>)],
189        period_start: NaiveDate,
190        period_end: NaiveDate,
191        currency: &str,
192        pack: &CountryPack,
193    ) -> (PayrollRun, Vec<PayrollLineItem>) {
194        let rates = self.rates_from_country_pack(pack);
195        let labels = Self::labels_from_country_pack(pack);
196        self.generate_with_rates_and_labels(
197            company_code,
198            employees,
199            period_start,
200            period_end,
201            currency,
202            &rates,
203            &labels,
204        )
205    }
206
207    // ------------------------------------------------------------------
208    // Private helpers
209    // ------------------------------------------------------------------
210
211    /// Build [`PayrollRates`] from the generator's config (original behaviour).
212    fn rates_from_config(&self) -> PayrollRates {
213        let federal_rate = Decimal::from_f64_retain(self.config.tax_rates.federal_effective)
214            .unwrap_or(Decimal::ZERO);
215        let state_rate = Decimal::from_f64_retain(self.config.tax_rates.state_effective)
216            .unwrap_or(Decimal::ZERO);
217        let fica_rate =
218            Decimal::from_f64_retain(self.config.tax_rates.fica).unwrap_or(Decimal::ZERO);
219
220        PayrollRates {
221            income_tax_rate: federal_rate + state_rate,
222            income_tax_brackets: Vec::new(),
223            fica_rate,
224            health_rate: Decimal::from_f64_retain(0.03).unwrap_or(Decimal::ZERO),
225            retirement_rate: Decimal::from_f64_retain(0.05).unwrap_or(Decimal::ZERO),
226            employer_fica_rate: fica_rate,
227        }
228    }
229
230    /// Compute progressive tax using marginal brackets.
231    ///
232    /// Iterates brackets in ascending order. Each bracket taxes only the
233    /// portion of income within that bracket at the bracket's rate. The
234    /// terminal bracket (no `up_to`) taxes all remaining income.
235    fn compute_progressive_tax(annual_income: Decimal, brackets: &[TaxBracket]) -> Decimal {
236        let mut total_tax = Decimal::ZERO;
237        let mut taxed_up_to = Decimal::ZERO;
238
239        for bracket in brackets {
240            let bracket_floor = bracket
241                .above
242                .and_then(Decimal::from_f64_retain)
243                .unwrap_or(taxed_up_to);
244            let bracket_rate = Decimal::from_f64_retain(bracket.rate).unwrap_or(Decimal::ZERO);
245
246            if annual_income <= bracket_floor {
247                break;
248            }
249
250            let taxable_in_bracket = if let Some(ceiling) = bracket.up_to {
251                let ceiling = Decimal::from_f64_retain(ceiling).unwrap_or(Decimal::ZERO);
252                (annual_income.min(ceiling) - bracket_floor).max(Decimal::ZERO)
253            } else {
254                // Terminal bracket — tax all remaining income
255                (annual_income - bracket_floor).max(Decimal::ZERO)
256            };
257
258            total_tax += (taxable_in_bracket * bracket_rate).round_dp(2);
259            taxed_up_to = bracket
260                .up_to
261                .and_then(Decimal::from_f64_retain)
262                .unwrap_or(annual_income);
263        }
264
265        total_tax.round_dp(2)
266    }
267
268    /// Build [`PayrollRates`] from a [`CountryPack`], falling back to config
269    /// values for any category not found.
270    fn rates_from_country_pack(&self, pack: &CountryPack) -> PayrollRates {
271        let fallback = self.rates_from_config();
272
273        // Accumulators – start at zero; we only use the fallback when a
274        // category has *no* matching deduction in the pack at all.
275        let mut federal_tax = Decimal::ZERO;
276        let mut state_tax = Decimal::ZERO;
277        let mut fica = Decimal::ZERO;
278        let mut health = Decimal::ZERO;
279        let mut retirement = Decimal::ZERO;
280
281        // Track which categories were populated from the pack.
282        let mut found_federal = false;
283        let mut found_state = false;
284        let mut found_fica = false;
285        let mut found_health = false;
286        let mut found_retirement = false;
287
288        for ded in &pack.payroll.statutory_deductions {
289            let code_upper = ded.code.to_uppercase();
290            let name_en_lower = ded.name_en.to_lowercase();
291            let rate = Decimal::from_f64_retain(ded.rate).unwrap_or(Decimal::ZERO);
292
293            // Progressive (bracket-based) income taxes have rate 0.0 as a
294            // placeholder. Mark the category as found so the config fallback
295            // is skipped — the actual tax will be computed per-employee from
296            // the bracket table in generate_with_rates_and_labels().
297            if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
298                && ded.rate == 0.0
299            {
300                if code_upper == "FIT"
301                    || code_upper == "LOHNST"
302                    || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
303                {
304                    found_federal = true;
305                }
306                continue;
307            }
308
309            if code_upper == "FIT"
310                || code_upper == "LOHNST"
311                || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
312            {
313                federal_tax += rate;
314                found_federal = true;
315            } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
316                state_tax += rate;
317                found_state = true;
318            } else if code_upper == "FICA" || name_en_lower.contains("social security") {
319                fica += rate;
320                found_fica = true;
321            } else if name_en_lower.contains("health insurance") {
322                health += rate;
323                found_health = true;
324            } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
325                retirement += rate;
326                found_retirement = true;
327            } else {
328                // Unrecognised statutory deductions (solidarity surcharge,
329                // church tax, unemployment insurance, long-term care, etc.)
330                // are accumulated into the social-security / FICA bucket so
331                // that total deductions still reflect the country's burden.
332                fica += rate;
333                found_fica = true;
334            }
335        }
336
337        PayrollRates {
338            income_tax_rate: if found_federal || found_state {
339                let f = if found_federal {
340                    federal_tax
341                } else {
342                    fallback.income_tax_rate
343                        - Decimal::from_f64_retain(self.config.tax_rates.state_effective)
344                            .unwrap_or(Decimal::ZERO)
345                };
346                let s = if found_state {
347                    state_tax
348                } else {
349                    Decimal::from_f64_retain(self.config.tax_rates.state_effective)
350                        .unwrap_or(Decimal::ZERO)
351                };
352                f + s
353            } else {
354                fallback.income_tax_rate
355            },
356            income_tax_brackets: pack.tax.payroll_tax.income_tax_brackets.clone(),
357            fica_rate: if found_fica { fica } else { fallback.fica_rate },
358            health_rate: if found_health {
359                health
360            } else {
361                fallback.health_rate
362            },
363            retirement_rate: if found_retirement {
364                retirement
365            } else {
366                fallback.retirement_rate
367            },
368            employer_fica_rate: if found_fica {
369                fica
370            } else {
371                fallback.employer_fica_rate
372            },
373        }
374    }
375
376    /// Build [`DeductionLabels`] from a country pack.
377    ///
378    /// Walks the pack's `statutory_deductions` and `employer_contributions` and
379    /// picks the matching deduction's localized `name` (falling back to
380    /// `name_en` when `name` is empty) for each category.  The matching logic
381    /// mirrors [`rates_from_country_pack`] so labels and rates stay consistent.
382    fn labels_from_country_pack(pack: &CountryPack) -> DeductionLabels {
383        let mut labels = DeductionLabels::default();
384
385        for ded in &pack.payroll.statutory_deductions {
386            let code_upper = ded.code.to_uppercase();
387            let name_en_lower = ded.name_en.to_lowercase();
388
389            // Pick the best human-readable label: prefer localized `name`, fall
390            // back to `name_en`.
391            let label = if ded.name.is_empty() {
392                ded.name_en.clone()
393            } else {
394                ded.name.clone()
395            };
396            if label.is_empty() {
397                continue;
398            }
399
400            // For progressive placeholders (rate 0), still capture the label
401            // since the config-fallback rate will be used for the amount.
402            if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
403                && ded.rate == 0.0
404            {
405                if code_upper == "FIT"
406                    || code_upper == "LOHNST"
407                    || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
408                {
409                    if labels.tax_withholding.is_none() {
410                        labels.tax_withholding = Some(label);
411                    }
412                } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
413                    labels.tax_withholding = Some(match labels.tax_withholding.take() {
414                        Some(existing) => format!("{existing}; {label}"),
415                        None => label,
416                    });
417                }
418                continue;
419            }
420
421            if code_upper == "FIT"
422                || code_upper == "LOHNST"
423                || code_upper == "SIT"
424                || name_en_lower.contains("income tax")
425                || name_en_lower.contains("state income tax")
426            {
427                // All income-tax-related deductions (federal, state, combined)
428                // are grouped under the tax_withholding label.
429                labels.tax_withholding = Some(match labels.tax_withholding.take() {
430                    Some(existing) => format!("{existing}; {label}"),
431                    None => label,
432                });
433            } else if code_upper == "FICA" || name_en_lower.contains("social security") {
434                labels.social_security = Some(match labels.social_security.take() {
435                    Some(existing) => format!("{existing}; {label}"),
436                    None => label,
437                });
438            } else if name_en_lower.contains("health insurance") {
439                if labels.health_insurance.is_none() {
440                    labels.health_insurance = Some(label);
441                }
442            } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
443                if labels.retirement_contribution.is_none() {
444                    labels.retirement_contribution = Some(label);
445                }
446            } else {
447                // Misc deductions (unemployment, church tax, etc.) — append to
448                // social_security label since those rates go into that bucket.
449                labels.social_security = Some(match labels.social_security.take() {
450                    Some(existing) => format!("{existing}; {label}"),
451                    None => label,
452                });
453            }
454        }
455
456        // Employer contributions
457        let emp_labels: Vec<String> = pack
458            .payroll
459            .employer_contributions
460            .iter()
461            .filter_map(|c| {
462                let l = if c.name.is_empty() {
463                    c.name_en.clone()
464                } else {
465                    c.name.clone()
466                };
467                if l.is_empty() {
468                    None
469                } else {
470                    Some(l)
471                }
472            })
473            .collect();
474        if !emp_labels.is_empty() {
475            labels.employer_contribution = Some(emp_labels.join("; "));
476        }
477
478        labels
479    }
480
481    /// Core generation logic parameterised on resolved rates and labels.
482    fn generate_with_rates_and_labels(
483        &mut self,
484        company_code: &str,
485        employees: &[(String, Decimal, Option<String>, Option<String>)],
486        period_start: NaiveDate,
487        period_end: NaiveDate,
488        currency: &str,
489        rates: &PayrollRates,
490        labels: &DeductionLabels,
491    ) -> (PayrollRun, Vec<PayrollLineItem>) {
492        let payroll_id = self.uuid_factory.next().to_string();
493
494        let mut line_items = Vec::with_capacity(employees.len());
495        let mut total_gross = Decimal::ZERO;
496        let mut total_deductions = Decimal::ZERO;
497        let mut total_net = Decimal::ZERO;
498        let mut total_employer_cost = Decimal::ZERO;
499
500        let benefits_enrolled = self.config.benefits_enrollment_rate;
501        let retirement_participating = self.config.retirement_participation_rate;
502
503        for (employee_id, base_salary, cost_center, department) in employees {
504            let line_id = self.line_uuid_factory.next().to_string();
505
506            // Monthly base component (annual salary / 12)
507            let monthly_base = (*base_salary / Decimal::from(12)).round_dp(2);
508
509            // Overtime: 10% chance, 1-20 hours at 1.5x hourly rate
510            let (overtime_pay, overtime_hours) = if self.rng.random_bool(0.10) {
511                let ot_hours = self.rng.random_range(1.0..=20.0);
512                // Hourly rate = annual salary / (52 weeks * 40 hours)
513                let hourly_rate = *base_salary / Decimal::from(2080);
514                let ot_rate = hourly_rate * Decimal::from_f64_retain(1.5).unwrap_or(Decimal::ONE);
515                let ot_pay = (ot_rate
516                    * Decimal::from_f64_retain(ot_hours).unwrap_or(Decimal::ZERO))
517                .round_dp(2);
518                (ot_pay, ot_hours)
519            } else {
520                (Decimal::ZERO, 0.0)
521            };
522
523            // Bonus: 5% chance for a monthly bonus (1-10% of monthly base)
524            let bonus = if self.rng.random_bool(0.05) {
525                let pct = self.rng.random_range(0.01..=0.10);
526                (monthly_base * Decimal::from_f64_retain(pct).unwrap_or(Decimal::ZERO)).round_dp(2)
527            } else {
528                Decimal::ZERO
529            };
530
531            let gross_pay = monthly_base + overtime_pay + bonus;
532
533            // Deductions — use progressive brackets when available
534            let tax_withholding = if !rates.income_tax_brackets.is_empty() {
535                let annual = gross_pay * Decimal::from(12);
536                Self::compute_progressive_tax(annual, &rates.income_tax_brackets)
537                    / Decimal::from(12)
538            } else {
539                (gross_pay * rates.income_tax_rate).round_dp(2)
540            };
541            let social_security = (gross_pay * rates.fica_rate).round_dp(2);
542
543            let health_insurance = if self.rng.random_bool(benefits_enrolled) {
544                (gross_pay * rates.health_rate).round_dp(2)
545            } else {
546                Decimal::ZERO
547            };
548
549            let retirement_contribution = if self.rng.random_bool(retirement_participating) {
550                (gross_pay * rates.retirement_rate).round_dp(2)
551            } else {
552                Decimal::ZERO
553            };
554
555            // Small random other deductions (garnishments, etc.): ~3% chance
556            let other_deductions = if self.rng.random_bool(0.03) {
557                sample_decimal_range(&mut self.rng, Decimal::from(50), Decimal::from(500))
558                    .round_dp(2)
559            } else {
560                Decimal::ZERO
561            };
562
563            let total_ded = tax_withholding
564                + social_security
565                + health_insurance
566                + retirement_contribution
567                + other_deductions;
568            let net_pay = gross_pay - total_ded;
569
570            // Standard 160 regular hours per month (8h * 20 business days)
571            let hours_worked = 160.0;
572
573            // Employer-side cost: gross + employer contribution match
574            let employer_contrib = (gross_pay * rates.employer_fica_rate).round_dp(2);
575            let employer_cost = gross_pay + employer_contrib;
576
577            total_gross += gross_pay;
578            total_deductions += total_ded;
579            total_net += net_pay;
580            total_employer_cost += employer_cost;
581
582            line_items.push(PayrollLineItem {
583                payroll_id: payroll_id.clone(),
584                employee_id: employee_id.clone(),
585                line_id,
586                gross_pay,
587                base_salary: monthly_base,
588                overtime_pay,
589                bonus,
590                tax_withholding,
591                social_security,
592                health_insurance,
593                retirement_contribution,
594                other_deductions,
595                net_pay,
596                hours_worked,
597                overtime_hours,
598                pay_date: period_end,
599                cost_center: cost_center.clone(),
600                department: department.clone(),
601                tax_withholding_label: labels.tax_withholding.clone(),
602                social_security_label: labels.social_security.clone(),
603                health_insurance_label: labels.health_insurance.clone(),
604                retirement_contribution_label: labels.retirement_contribution.clone(),
605                employer_contribution_label: labels.employer_contribution.clone(),
606            });
607        }
608
609        // Determine status
610        let status_roll: f64 = self.rng.random();
611        let status = if status_roll < 0.60 {
612            PayrollRunStatus::Posted
613        } else if status_roll < 0.85 {
614            PayrollRunStatus::Approved
615        } else if status_roll < 0.95 {
616            PayrollRunStatus::Calculated
617        } else {
618            PayrollRunStatus::Draft
619        };
620
621        let approved_by = if matches!(
622            status,
623            PayrollRunStatus::Approved | PayrollRunStatus::Posted
624        ) {
625            if !self.employee_ids_pool.is_empty() {
626                let idx = self.rng.random_range(0..self.employee_ids_pool.len());
627                Some(self.employee_ids_pool[idx].clone())
628            } else {
629                Some(format!("USR-{:04}", self.rng.random_range(201..=400)))
630            }
631        } else {
632            None
633        };
634
635        let posted_by = if status == PayrollRunStatus::Posted {
636            if !self.employee_ids_pool.is_empty() {
637                let idx = self.rng.random_range(0..self.employee_ids_pool.len());
638                Some(self.employee_ids_pool[idx].clone())
639            } else {
640                Some(format!("USR-{:04}", self.rng.random_range(401..=500)))
641            }
642        } else {
643            None
644        };
645
646        let run = PayrollRun {
647            company_code: company_code.to_string(),
648            payroll_id: payroll_id.clone(),
649            pay_period_start: period_start,
650            pay_period_end: period_end,
651            run_date: period_end,
652            status,
653            total_gross,
654            total_deductions,
655            total_net,
656            total_employer_cost,
657            employee_count: employees.len() as u32,
658            currency: currency.to_string(),
659            posted_by,
660            approved_by,
661        };
662
663        (run, line_items)
664    }
665}
666
667#[cfg(test)]
668#[allow(clippy::unwrap_used)]
669mod tests {
670    use super::*;
671
672    fn test_employees() -> Vec<(String, Decimal, Option<String>, Option<String>)> {
673        vec![
674            (
675                "EMP-001".to_string(),
676                Decimal::from(60_000),
677                Some("CC-100".to_string()),
678                Some("Engineering".to_string()),
679            ),
680            (
681                "EMP-002".to_string(),
682                Decimal::from(85_000),
683                Some("CC-200".to_string()),
684                Some("Finance".to_string()),
685            ),
686            (
687                "EMP-003".to_string(),
688                Decimal::from(120_000),
689                None,
690                Some("Sales".to_string()),
691            ),
692        ]
693    }
694
695    #[test]
696    fn test_basic_payroll_generation() {
697        let mut gen = PayrollGenerator::new(42);
698        let employees = test_employees();
699        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
700        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
701
702        let (run, items) = gen.generate("C001", &employees, period_start, period_end, "USD");
703
704        assert_eq!(run.company_code, "C001");
705        assert_eq!(run.currency, "USD");
706        assert_eq!(run.employee_count, 3);
707        assert_eq!(items.len(), 3);
708        assert!(run.total_gross > Decimal::ZERO);
709        assert!(run.total_deductions > Decimal::ZERO);
710        assert!(run.total_net > Decimal::ZERO);
711        assert!(run.total_employer_cost > run.total_gross);
712        // net = gross - deductions
713        assert_eq!(run.total_net, run.total_gross - run.total_deductions);
714
715        for item in &items {
716            assert_eq!(item.payroll_id, run.payroll_id);
717            assert!(item.gross_pay > Decimal::ZERO);
718            assert!(item.net_pay > Decimal::ZERO);
719            assert!(item.net_pay < item.gross_pay);
720            assert!(item.base_salary > Decimal::ZERO);
721            assert_eq!(item.pay_date, period_end);
722            // Without country pack, labels should be None
723            assert!(item.tax_withholding_label.is_none());
724            assert!(item.social_security_label.is_none());
725        }
726    }
727
728    #[test]
729    fn test_deterministic_payroll() {
730        let employees = test_employees();
731        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
732        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
733
734        let mut gen1 = PayrollGenerator::new(42);
735        let (run1, items1) = gen1.generate("C001", &employees, period_start, period_end, "USD");
736
737        let mut gen2 = PayrollGenerator::new(42);
738        let (run2, items2) = gen2.generate("C001", &employees, period_start, period_end, "USD");
739
740        assert_eq!(run1.payroll_id, run2.payroll_id);
741        assert_eq!(run1.total_gross, run2.total_gross);
742        assert_eq!(run1.total_net, run2.total_net);
743        assert_eq!(run1.status, run2.status);
744        assert_eq!(items1.len(), items2.len());
745        for (a, b) in items1.iter().zip(items2.iter()) {
746            assert_eq!(a.line_id, b.line_id);
747            assert_eq!(a.gross_pay, b.gross_pay);
748            assert_eq!(a.net_pay, b.net_pay);
749        }
750    }
751
752    #[test]
753    fn test_payroll_deduction_components() {
754        let mut gen = PayrollGenerator::new(99);
755        let employees = vec![(
756            "EMP-010".to_string(),
757            Decimal::from(100_000),
758            Some("CC-300".to_string()),
759            Some("HR".to_string()),
760        )];
761        let period_start = NaiveDate::from_ymd_opt(2024, 6, 1).unwrap();
762        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
763
764        let (_run, items) = gen.generate("C001", &employees, period_start, period_end, "USD");
765        assert_eq!(items.len(), 1);
766
767        let item = &items[0];
768        // base_salary should be approximately 100000/12 = 8333.33
769        let expected_monthly = (Decimal::from(100_000) / Decimal::from(12)).round_dp(2);
770        assert_eq!(item.base_salary, expected_monthly);
771
772        // Deductions should sum correctly
773        let deduction_sum = item.tax_withholding
774            + item.social_security
775            + item.health_insurance
776            + item.retirement_contribution
777            + item.other_deductions;
778        let expected_net = item.gross_pay - deduction_sum;
779        assert_eq!(item.net_pay, expected_net);
780
781        // Tax withholding should be reasonable (22% federal + 5% state = 27% of gross)
782        assert!(item.tax_withholding > Decimal::ZERO);
783        assert!(item.social_security > Decimal::ZERO);
784    }
785
786    // ---------------------------------------------------------------
787    // Country-pack tests
788    // ---------------------------------------------------------------
789
790    /// Helper: build a US-like country pack with explicit statutory deductions.
791    fn us_country_pack() -> CountryPack {
792        use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
793        CountryPack {
794            country_code: "US".to_string(),
795            payroll: PayrollCountryConfig {
796                statutory_deductions: vec![
797                    PayrollDeduction {
798                        code: "FICA".to_string(),
799                        name_en: "Federal Insurance Contributions Act".to_string(),
800                        deduction_type: "percentage".to_string(),
801                        rate: 0.0765,
802                        ..Default::default()
803                    },
804                    PayrollDeduction {
805                        code: "FIT".to_string(),
806                        name_en: "Federal Income Tax".to_string(),
807                        deduction_type: "progressive".to_string(),
808                        rate: 0.0, // progressive placeholder
809                        ..Default::default()
810                    },
811                    PayrollDeduction {
812                        code: "SIT".to_string(),
813                        name_en: "State Income Tax".to_string(),
814                        deduction_type: "percentage".to_string(),
815                        rate: 0.05,
816                        ..Default::default()
817                    },
818                ],
819                ..Default::default()
820            },
821            ..Default::default()
822        }
823    }
824
825    /// Helper: build a DE-like country pack.
826    fn de_country_pack() -> CountryPack {
827        use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
828        CountryPack {
829            country_code: "DE".to_string(),
830            payroll: PayrollCountryConfig {
831                pay_frequency: "monthly".to_string(),
832                currency: "EUR".to_string(),
833                statutory_deductions: vec![
834                    PayrollDeduction {
835                        code: "LOHNST".to_string(),
836                        name_en: "Income Tax".to_string(),
837                        type_field: "progressive".to_string(),
838                        rate: 0.0, // progressive placeholder
839                        ..Default::default()
840                    },
841                    PayrollDeduction {
842                        code: "SOLI".to_string(),
843                        name_en: "Solidarity Surcharge".to_string(),
844                        type_field: "percentage".to_string(),
845                        rate: 0.055,
846                        ..Default::default()
847                    },
848                    PayrollDeduction {
849                        code: "KiSt".to_string(),
850                        name_en: "Church Tax".to_string(),
851                        type_field: "percentage".to_string(),
852                        rate: 0.08,
853                        optional: true,
854                        ..Default::default()
855                    },
856                    PayrollDeduction {
857                        code: "RV".to_string(),
858                        name_en: "Pension Insurance".to_string(),
859                        type_field: "percentage".to_string(),
860                        rate: 0.093,
861                        ..Default::default()
862                    },
863                    PayrollDeduction {
864                        code: "KV".to_string(),
865                        name_en: "Health Insurance".to_string(),
866                        type_field: "percentage".to_string(),
867                        rate: 0.073,
868                        ..Default::default()
869                    },
870                    PayrollDeduction {
871                        code: "AV".to_string(),
872                        name_en: "Unemployment Insurance".to_string(),
873                        type_field: "percentage".to_string(),
874                        rate: 0.013,
875                        ..Default::default()
876                    },
877                    PayrollDeduction {
878                        code: "PV".to_string(),
879                        name_en: "Long-Term Care Insurance".to_string(),
880                        type_field: "percentage".to_string(),
881                        rate: 0.017,
882                        ..Default::default()
883                    },
884                ],
885                employer_contributions: vec![
886                    PayrollDeduction {
887                        code: "AG-RV".to_string(),
888                        name_en: "Employer Pension Insurance".to_string(),
889                        type_field: "percentage".to_string(),
890                        rate: 0.093,
891                        ..Default::default()
892                    },
893                    PayrollDeduction {
894                        code: "AG-KV".to_string(),
895                        name_en: "Employer Health Insurance".to_string(),
896                        type_field: "percentage".to_string(),
897                        rate: 0.073,
898                        ..Default::default()
899                    },
900                ],
901                ..Default::default()
902            },
903            ..Default::default()
904        }
905    }
906
907    #[test]
908    fn test_generate_with_us_country_pack() {
909        let mut gen = PayrollGenerator::new(42);
910        let employees = test_employees();
911        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
912        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
913        let pack = us_country_pack();
914
915        let (run, items) = gen.generate_with_country_pack(
916            "C001",
917            &employees,
918            period_start,
919            period_end,
920            "USD",
921            &pack,
922        );
923
924        assert_eq!(run.company_code, "C001");
925        assert_eq!(run.employee_count, 3);
926        assert_eq!(items.len(), 3);
927        assert_eq!(run.total_net, run.total_gross - run.total_deductions);
928
929        for item in &items {
930            assert!(item.gross_pay > Decimal::ZERO);
931            assert!(item.net_pay > Decimal::ZERO);
932            assert!(item.net_pay < item.gross_pay);
933            // FICA deduction should be present
934            assert!(item.social_security > Decimal::ZERO);
935            // US pack should produce labels
936            assert!(item.tax_withholding_label.is_some());
937            assert!(item.social_security_label.is_some());
938        }
939    }
940
941    #[test]
942    fn test_generate_with_de_country_pack() {
943        let mut gen = PayrollGenerator::new(42);
944        let employees = test_employees();
945        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
946        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
947        let pack = de_country_pack();
948
949        let (run, items) = gen.generate_with_country_pack(
950            "DE01",
951            &employees,
952            period_start,
953            period_end,
954            "EUR",
955            &pack,
956        );
957
958        assert_eq!(run.company_code, "DE01");
959        assert_eq!(items.len(), 3);
960        assert_eq!(run.total_net, run.total_gross - run.total_deductions);
961
962        // DE pack should use pension rate 0.093 for retirement
963        // and health insurance rate 0.073
964        let rates = gen.rates_from_country_pack(&pack);
965        assert_eq!(
966            rates.retirement_rate,
967            Decimal::from_f64_retain(0.093).unwrap()
968        );
969        assert_eq!(rates.health_rate, Decimal::from_f64_retain(0.073).unwrap());
970
971        // Check DE labels are populated
972        let item = &items[0];
973        assert_eq!(
974            item.health_insurance_label.as_deref(),
975            Some("Health Insurance")
976        );
977        assert_eq!(
978            item.retirement_contribution_label.as_deref(),
979            Some("Pension Insurance")
980        );
981        // Employer contribution labels should include both AG-RV and AG-KV
982        assert!(item.employer_contribution_label.is_some());
983        let ec = item.employer_contribution_label.as_ref().unwrap();
984        assert!(ec.contains("Employer Pension Insurance"));
985        assert!(ec.contains("Employer Health Insurance"));
986    }
987
988    #[test]
989    fn test_country_pack_falls_back_to_config_for_missing_categories() {
990        // Empty pack: no statutory deductions => all rates fall back to config
991        let pack = CountryPack::default();
992        let gen = PayrollGenerator::new(42);
993        let rates_pack = gen.rates_from_country_pack(&pack);
994        let rates_cfg = gen.rates_from_config();
995
996        assert_eq!(rates_pack.income_tax_rate, rates_cfg.income_tax_rate);
997        assert_eq!(rates_pack.fica_rate, rates_cfg.fica_rate);
998        assert_eq!(rates_pack.health_rate, rates_cfg.health_rate);
999        assert_eq!(rates_pack.retirement_rate, rates_cfg.retirement_rate);
1000    }
1001
1002    #[test]
1003    fn test_country_pack_deterministic() {
1004        let employees = test_employees();
1005        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
1006        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
1007        let pack = de_country_pack();
1008
1009        let mut gen1 = PayrollGenerator::new(42);
1010        let (run1, items1) = gen1.generate_with_country_pack(
1011            "DE01",
1012            &employees,
1013            period_start,
1014            period_end,
1015            "EUR",
1016            &pack,
1017        );
1018
1019        let mut gen2 = PayrollGenerator::new(42);
1020        let (run2, items2) = gen2.generate_with_country_pack(
1021            "DE01",
1022            &employees,
1023            period_start,
1024            period_end,
1025            "EUR",
1026            &pack,
1027        );
1028
1029        assert_eq!(run1.payroll_id, run2.payroll_id);
1030        assert_eq!(run1.total_gross, run2.total_gross);
1031        assert_eq!(run1.total_net, run2.total_net);
1032        for (a, b) in items1.iter().zip(items2.iter()) {
1033            assert_eq!(a.net_pay, b.net_pay);
1034        }
1035    }
1036
1037    #[test]
1038    fn test_de_rates_differ_from_default() {
1039        // With the DE pack, the resolved rates should differ from config defaults
1040        let gen = PayrollGenerator::new(42);
1041        let pack = de_country_pack();
1042        let rates_cfg = gen.rates_from_config();
1043        let rates_de = gen.rates_from_country_pack(&pack);
1044
1045        // DE has no non-progressive income tax in pack -> income_tax_rate falls
1046        // back to config default for federal part.
1047        // But health (0.073 vs 0.03) and retirement (0.093 vs 0.05) should differ.
1048        assert_ne!(rates_de.health_rate, rates_cfg.health_rate);
1049        assert_ne!(rates_de.retirement_rate, rates_cfg.retirement_rate);
1050    }
1051
1052    #[test]
1053    fn test_set_country_pack_uses_labels() {
1054        let mut gen = PayrollGenerator::new(42);
1055        let pack = de_country_pack();
1056        gen.set_country_pack(pack);
1057
1058        let employees = test_employees();
1059        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
1060        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
1061
1062        // generate() should now use the stored pack for rates + labels
1063        let (_run, items) = gen.generate("DE01", &employees, period_start, period_end, "EUR");
1064
1065        let item = &items[0];
1066        // Labels should be populated from the DE pack
1067        assert!(item.tax_withholding_label.is_some());
1068        assert!(item.health_insurance_label.is_some());
1069        assert!(item.retirement_contribution_label.is_some());
1070        assert!(item.employer_contribution_label.is_some());
1071    }
1072
1073    #[test]
1074    fn test_compute_progressive_tax_us_brackets() {
1075        // Simplified US-style brackets for testing
1076        let brackets = vec![
1077            TaxBracket {
1078                above: Some(0.0),
1079                up_to: Some(11_000.0),
1080                rate: 0.10,
1081            },
1082            TaxBracket {
1083                above: Some(11_000.0),
1084                up_to: Some(44_725.0),
1085                rate: 0.12,
1086            },
1087            TaxBracket {
1088                above: Some(44_725.0),
1089                up_to: Some(95_375.0),
1090                rate: 0.22,
1091            },
1092            TaxBracket {
1093                above: Some(95_375.0),
1094                up_to: None,
1095                rate: 0.24,
1096            },
1097        ];
1098
1099        // $60,000 income
1100        let tax = PayrollGenerator::compute_progressive_tax(Decimal::from(60_000), &brackets);
1101        // 11,000 * 0.10 = 1,100
1102        // (44,725 - 11,000) * 0.12 = 4,047
1103        // (60,000 - 44,725) * 0.22 = 3,360.50
1104        // Total = 8,507.50
1105        assert_eq!(tax, Decimal::from_f64_retain(8507.50).unwrap());
1106
1107        // $11,000 income — only first bracket
1108        let tax = PayrollGenerator::compute_progressive_tax(Decimal::from(11_000), &brackets);
1109        assert_eq!(tax, Decimal::from_f64_retain(1100.0).unwrap());
1110    }
1111
1112    #[test]
1113    fn test_progressive_tax_zero_income() {
1114        let brackets = vec![TaxBracket {
1115            above: Some(0.0),
1116            up_to: Some(10_000.0),
1117            rate: 0.10,
1118        }];
1119        let tax = PayrollGenerator::compute_progressive_tax(Decimal::ZERO, &brackets);
1120        assert_eq!(tax, Decimal::ZERO);
1121    }
1122
1123    #[test]
1124    fn test_us_pack_employees_have_varying_rates() {
1125        use datasynth_core::country::schema::{
1126            CountryTaxConfig, PayrollCountryConfig, PayrollDeduction, PayrollTaxBracketsConfig,
1127        };
1128
1129        let brackets = vec![
1130            TaxBracket {
1131                above: Some(0.0),
1132                up_to: Some(11_000.0),
1133                rate: 0.10,
1134            },
1135            TaxBracket {
1136                above: Some(11_000.0),
1137                up_to: Some(44_725.0),
1138                rate: 0.12,
1139            },
1140            TaxBracket {
1141                above: Some(44_725.0),
1142                up_to: None,
1143                rate: 0.22,
1144            },
1145        ];
1146        let pack = CountryPack {
1147            country_code: "US".to_string(),
1148            payroll: PayrollCountryConfig {
1149                statutory_deductions: vec![
1150                    PayrollDeduction {
1151                        code: "FIT".to_string(),
1152                        name_en: "Federal Income Tax".to_string(),
1153                        deduction_type: "progressive".to_string(),
1154                        rate: 0.0,
1155                        ..Default::default()
1156                    },
1157                    PayrollDeduction {
1158                        code: "FICA".to_string(),
1159                        name_en: "Social Security".to_string(),
1160                        deduction_type: "percentage".to_string(),
1161                        rate: 0.0765,
1162                        ..Default::default()
1163                    },
1164                ],
1165                ..Default::default()
1166            },
1167            tax: CountryTaxConfig {
1168                payroll_tax: PayrollTaxBracketsConfig {
1169                    income_tax_brackets: brackets,
1170                    ..Default::default()
1171                },
1172                ..Default::default()
1173            },
1174            ..Default::default()
1175        };
1176
1177        let mut gen = PayrollGenerator::new(42);
1178        gen.set_country_pack(pack);
1179
1180        // Low earner vs high earner
1181        let low_earner = vec![("LOW".to_string(), Decimal::from(30_000), None, None)];
1182        let high_earner = vec![("HIGH".to_string(), Decimal::from(200_000), None, None)];
1183
1184        let start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
1185        let end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
1186
1187        let (_, low_items) = gen.generate("C001", &low_earner, start, end, "USD");
1188        let mut gen2 = PayrollGenerator::new(42);
1189        gen2.set_country_pack(CountryPack {
1190            country_code: "US".to_string(),
1191            payroll: PayrollCountryConfig {
1192                statutory_deductions: vec![
1193                    PayrollDeduction {
1194                        code: "FIT".to_string(),
1195                        name_en: "Federal Income Tax".to_string(),
1196                        deduction_type: "progressive".to_string(),
1197                        rate: 0.0,
1198                        ..Default::default()
1199                    },
1200                    PayrollDeduction {
1201                        code: "FICA".to_string(),
1202                        name_en: "Social Security".to_string(),
1203                        deduction_type: "percentage".to_string(),
1204                        rate: 0.0765,
1205                        ..Default::default()
1206                    },
1207                ],
1208                ..Default::default()
1209            },
1210            tax: CountryTaxConfig {
1211                payroll_tax: PayrollTaxBracketsConfig {
1212                    income_tax_brackets: vec![
1213                        TaxBracket {
1214                            above: Some(0.0),
1215                            up_to: Some(11_000.0),
1216                            rate: 0.10,
1217                        },
1218                        TaxBracket {
1219                            above: Some(11_000.0),
1220                            up_to: Some(44_725.0),
1221                            rate: 0.12,
1222                        },
1223                        TaxBracket {
1224                            above: Some(44_725.0),
1225                            up_to: None,
1226                            rate: 0.22,
1227                        },
1228                    ],
1229                    ..Default::default()
1230                },
1231                ..Default::default()
1232            },
1233            ..Default::default()
1234        });
1235        let (_, high_items) = gen2.generate("C001", &high_earner, start, end, "USD");
1236
1237        let low_eff = low_items[0].tax_withholding / low_items[0].gross_pay;
1238        let high_eff = high_items[0].tax_withholding / high_items[0].gross_pay;
1239
1240        // High earner should have a higher effective tax rate
1241        assert!(
1242            high_eff > low_eff,
1243            "High earner effective rate ({high_eff}) should exceed low earner ({low_eff})"
1244        );
1245    }
1246
1247    #[test]
1248    fn test_empty_pack_labels_are_none() {
1249        let pack = CountryPack::default();
1250        let labels = PayrollGenerator::labels_from_country_pack(&pack);
1251        assert!(labels.tax_withholding.is_none());
1252        assert!(labels.social_security.is_none());
1253        assert!(labels.health_insurance.is_none());
1254        assert!(labels.retirement_contribution.is_none());
1255        assert!(labels.employer_contribution.is_none());
1256    }
1257
1258    #[test]
1259    fn test_us_pack_labels() {
1260        let pack = us_country_pack();
1261        let labels = PayrollGenerator::labels_from_country_pack(&pack);
1262        // FIT is a progressive placeholder but label is still captured
1263        assert!(labels.tax_withholding.is_some());
1264        let tw = labels.tax_withholding.unwrap();
1265        assert!(tw.contains("Federal Income Tax"));
1266        assert!(tw.contains("State Income Tax"));
1267        // FICA label
1268        assert!(labels.social_security.is_some());
1269        assert!(labels
1270            .social_security
1271            .unwrap()
1272            .contains("Federal Insurance Contributions Act"));
1273    }
1274}