Skip to main content

datasynth_generators/hr/
payroll_generator.rs

//! Payroll generator for the Hire-to-Retire (H2R) process.
//!
//! Generates payroll runs with individual employee line items, computing
//! gross pay (base salary + overtime + bonus), deductions (tax, social security,
//! health insurance, retirement, and occasional other deductions), and net pay.
6
use chrono::NaiveDate;
use datasynth_config::schema::PayrollConfig;
use datasynth_core::models::{PayrollLineItem, PayrollRun, PayrollRunStatus};
use datasynth_core::utils::{sample_decimal_range, seeded_rng};
use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
use datasynth_core::CountryPack;
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rust_decimal::prelude::FromPrimitive;
use rust_decimal::Decimal;
use tracing::debug;
17
18/// Resolved payroll deduction rates used during generation.
19#[derive(Debug, Clone)]
20struct PayrollRates {
21    /// Combined income tax rate (federal + state, or equivalent).
22    income_tax_rate: Decimal,
23    /// Social security / FICA rate.
24    fica_rate: Decimal,
25    /// Employee health insurance rate.
26    health_rate: Decimal,
27    /// Employee retirement / pension rate.
28    retirement_rate: Decimal,
29    /// Employer-side social security matching rate.
30    employer_fica_rate: Decimal,
31}
32
/// Localized deduction labels copied onto every line item of a run.
///
/// Built from a country pack; every field stays `None` when no pack is in
/// use. Some fields may combine several pack deduction names joined with
/// `"; "`.
#[derive(Clone, Debug, Default)]
struct DeductionLabels {
    /// Label for the income-tax withholding amount.
    tax_withholding: Option<String>,
    /// Label for the social-security amount.
    social_security: Option<String>,
    /// Label for the health-insurance amount.
    health_insurance: Option<String>,
    /// Label for the retirement-contribution amount.
    retirement_contribution: Option<String>,
    /// Names of the pack's employer contributions.
    employer_contribution: Option<String>,
}
42
43/// Generates [`PayrollRun`] and [`PayrollLineItem`] records from employee data.
44pub struct PayrollGenerator {
45    rng: ChaCha8Rng,
46    uuid_factory: DeterministicUuidFactory,
47    line_uuid_factory: DeterministicUuidFactory,
48    config: PayrollConfig,
49    country_pack: Option<CountryPack>,
50    /// Pool of real employee IDs for approved_by / posted_by references.
51    employee_ids_pool: Vec<String>,
52    /// Pool of real cost center IDs (unused directly here since cost_center
53    /// comes from the employee tuple, but kept for consistency).
54    cost_center_ids_pool: Vec<String>,
55}
56
57impl PayrollGenerator {
58    /// Create a new payroll generator with default configuration.
59    pub fn new(seed: u64) -> Self {
60        Self {
61            rng: seeded_rng(seed, 0),
62            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
63            line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
64                seed,
65                GeneratorType::PayrollRun,
66                1,
67            ),
68            config: PayrollConfig::default(),
69            country_pack: None,
70            employee_ids_pool: Vec::new(),
71            cost_center_ids_pool: Vec::new(),
72        }
73    }
74
75    /// Create a payroll generator with custom configuration.
76    pub fn with_config(seed: u64, config: PayrollConfig) -> Self {
77        Self {
78            rng: seeded_rng(seed, 0),
79            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
80            line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
81                seed,
82                GeneratorType::PayrollRun,
83                1,
84            ),
85            config,
86            country_pack: None,
87            employee_ids_pool: Vec::new(),
88            cost_center_ids_pool: Vec::new(),
89        }
90    }
91
92    /// Set ID pools for cross-reference coherence.
93    ///
94    /// When pools are non-empty, the generator selects `approved_by` and
95    /// `posted_by` from `employee_ids` instead of fabricating placeholder IDs.
96    pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
97        self.employee_ids_pool = employee_ids;
98        self.cost_center_ids_pool = cost_center_ids;
99        self
100    }
101
102    /// Set the country pack for localized deduction labels.
103    ///
104    /// When a country pack is set, generated [`PayrollLineItem`] records will
105    /// carry localized deduction labels derived from the pack's
106    /// `payroll.statutory_deductions` and `payroll.employer_contributions`.
107    /// The stored pack is also used by [`generate`] to resolve deduction rates,
108    /// so callers no longer need to pass the pack explicitly.
109    pub fn set_country_pack(&mut self, pack: CountryPack) {
110        self.country_pack = Some(pack);
111    }
112
113    /// Generate a payroll run and line items for the given employees and period.
114    ///
115    /// Uses tax rates from the [`PayrollConfig`] (defaults: 22% federal, 5% state,
116    /// 7.65% FICA, 3% health, 5% retirement).  If a country pack has been set via
117    /// [`set_country_pack`], the stored pack is used to resolve both rates and
118    /// localized deduction labels automatically.
119    ///
120    /// # Arguments
121    ///
122    /// * `company_code` - Company code owning the payroll
123    /// * `employees` - Slice of (employee_id, base_salary, cost_center, department)
124    /// * `period_start` - Start of the pay period (inclusive)
125    /// * `period_end` - End of the pay period (inclusive)
126    /// * `currency` - ISO 4217 currency code
127    pub fn generate(
128        &mut self,
129        company_code: &str,
130        employees: &[(String, Decimal, Option<String>, Option<String>)],
131        period_start: NaiveDate,
132        period_end: NaiveDate,
133        currency: &str,
134    ) -> (PayrollRun, Vec<PayrollLineItem>) {
135        debug!(company_code, employee_count = employees.len(), %period_start, %period_end, currency, "Generating payroll run");
136        if let Some(pack) = self.country_pack.as_ref() {
137            let rates = self.rates_from_country_pack(pack);
138            let labels = Self::labels_from_country_pack(pack);
139            self.generate_with_rates_and_labels(
140                company_code,
141                employees,
142                period_start,
143                period_end,
144                currency,
145                &rates,
146                &labels,
147            )
148        } else {
149            let rates = self.rates_from_config();
150            self.generate_with_rates_and_labels(
151                company_code,
152                employees,
153                period_start,
154                period_end,
155                currency,
156                &rates,
157                &DeductionLabels::default(),
158            )
159        }
160    }
161
162    /// Generate a payroll run using statutory deduction rates from a country pack.
163    ///
164    /// Iterates over `pack.payroll.statutory_deductions` to resolve rates by
165    /// deduction code / English name.  Any rate not found in the pack falls back
166    /// to the corresponding value from the generator's [`PayrollConfig`].
167    ///
168    /// # Deduction mapping
169    ///
170    /// | Pack code / `name_en` pattern              | Resolves to         |
171    /// |--------------------------------------------|---------------------|
172    /// | `FIT`, `LOHNST`, or `*Income Tax*` (not state) | federal income tax  |
173    /// | `SIT` or `*State Income Tax*`              | state income tax    |
174    /// | `FICA` or `*Social Security*`              | FICA / social security |
175    /// | `*Health Insurance*`                       | health insurance    |
176    /// | `*Pension*` or `*Retirement*`              | retirement / pension |
177    ///
178    /// For packs that have many small deductions (e.g. DE with pension, health,
179    /// unemployment, long-term care, solidarity surcharge, church tax), the rates
180    /// are summed into the closest category. Deductions not matching any category
181    /// above are accumulated into the FICA/social-security bucket.
182    pub fn generate_with_country_pack(
183        &mut self,
184        company_code: &str,
185        employees: &[(String, Decimal, Option<String>, Option<String>)],
186        period_start: NaiveDate,
187        period_end: NaiveDate,
188        currency: &str,
189        pack: &CountryPack,
190    ) -> (PayrollRun, Vec<PayrollLineItem>) {
191        let rates = self.rates_from_country_pack(pack);
192        let labels = Self::labels_from_country_pack(pack);
193        self.generate_with_rates_and_labels(
194            company_code,
195            employees,
196            period_start,
197            period_end,
198            currency,
199            &rates,
200            &labels,
201        )
202    }
203
204    // ------------------------------------------------------------------
205    // Private helpers
206    // ------------------------------------------------------------------
207
208    /// Build [`PayrollRates`] from the generator's config (original behaviour).
209    fn rates_from_config(&self) -> PayrollRates {
210        let federal_rate = Decimal::from_f64_retain(self.config.tax_rates.federal_effective)
211            .unwrap_or(Decimal::ZERO);
212        let state_rate = Decimal::from_f64_retain(self.config.tax_rates.state_effective)
213            .unwrap_or(Decimal::ZERO);
214        let fica_rate =
215            Decimal::from_f64_retain(self.config.tax_rates.fica).unwrap_or(Decimal::ZERO);
216
217        PayrollRates {
218            income_tax_rate: federal_rate + state_rate,
219            fica_rate,
220            health_rate: Decimal::from_f64_retain(0.03).unwrap_or(Decimal::ZERO),
221            retirement_rate: Decimal::from_f64_retain(0.05).unwrap_or(Decimal::ZERO),
222            employer_fica_rate: fica_rate,
223        }
224    }
225
226    /// Build [`PayrollRates`] from a [`CountryPack`], falling back to config
227    /// values for any category not found.
228    fn rates_from_country_pack(&self, pack: &CountryPack) -> PayrollRates {
229        let fallback = self.rates_from_config();
230
231        // Accumulators – start at zero; we only use the fallback when a
232        // category has *no* matching deduction in the pack at all.
233        let mut federal_tax = Decimal::ZERO;
234        let mut state_tax = Decimal::ZERO;
235        let mut fica = Decimal::ZERO;
236        let mut health = Decimal::ZERO;
237        let mut retirement = Decimal::ZERO;
238
239        // Track which categories were populated from the pack.
240        let mut found_federal = false;
241        let mut found_state = false;
242        let mut found_fica = false;
243        let mut found_health = false;
244        let mut found_retirement = false;
245
246        for ded in &pack.payroll.statutory_deductions {
247            let code_upper = ded.code.to_uppercase();
248            let name_en_lower = ded.name_en.to_lowercase();
249            let rate = Decimal::from_f64_retain(ded.rate).unwrap_or(Decimal::ZERO);
250
251            // Skip progressive (bracket-based) income taxes that have rate 0.0
252            // in the pack — these are placeholders indicating bracket lookup is
253            // needed. We will fall back to the config's effective rate instead.
254            if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
255                && ded.rate == 0.0
256            {
257                continue;
258            }
259
260            if code_upper == "FIT"
261                || code_upper == "LOHNST"
262                || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
263            {
264                federal_tax += rate;
265                found_federal = true;
266            } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
267                state_tax += rate;
268                found_state = true;
269            } else if code_upper == "FICA" || name_en_lower.contains("social security") {
270                fica += rate;
271                found_fica = true;
272            } else if name_en_lower.contains("health insurance") {
273                health += rate;
274                found_health = true;
275            } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
276                retirement += rate;
277                found_retirement = true;
278            } else {
279                // Unrecognised statutory deductions (solidarity surcharge,
280                // church tax, unemployment insurance, long-term care, etc.)
281                // are accumulated into the social-security / FICA bucket so
282                // that total deductions still reflect the country's burden.
283                fica += rate;
284                found_fica = true;
285            }
286        }
287
288        PayrollRates {
289            income_tax_rate: if found_federal || found_state {
290                let f = if found_federal {
291                    federal_tax
292                } else {
293                    fallback.income_tax_rate
294                        - Decimal::from_f64_retain(self.config.tax_rates.state_effective)
295                            .unwrap_or(Decimal::ZERO)
296                };
297                let s = if found_state {
298                    state_tax
299                } else {
300                    Decimal::from_f64_retain(self.config.tax_rates.state_effective)
301                        .unwrap_or(Decimal::ZERO)
302                };
303                f + s
304            } else {
305                fallback.income_tax_rate
306            },
307            fica_rate: if found_fica { fica } else { fallback.fica_rate },
308            health_rate: if found_health {
309                health
310            } else {
311                fallback.health_rate
312            },
313            retirement_rate: if found_retirement {
314                retirement
315            } else {
316                fallback.retirement_rate
317            },
318            employer_fica_rate: if found_fica {
319                fica
320            } else {
321                fallback.employer_fica_rate
322            },
323        }
324    }
325
326    /// Build [`DeductionLabels`] from a country pack.
327    ///
328    /// Walks the pack's `statutory_deductions` and `employer_contributions` and
329    /// picks the matching deduction's localized `name` (falling back to
330    /// `name_en` when `name` is empty) for each category.  The matching logic
331    /// mirrors [`rates_from_country_pack`] so labels and rates stay consistent.
332    fn labels_from_country_pack(pack: &CountryPack) -> DeductionLabels {
333        let mut labels = DeductionLabels::default();
334
335        for ded in &pack.payroll.statutory_deductions {
336            let code_upper = ded.code.to_uppercase();
337            let name_en_lower = ded.name_en.to_lowercase();
338
339            // Pick the best human-readable label: prefer localized `name`, fall
340            // back to `name_en`.
341            let label = if ded.name.is_empty() {
342                ded.name_en.clone()
343            } else {
344                ded.name.clone()
345            };
346            if label.is_empty() {
347                continue;
348            }
349
350            // For progressive placeholders (rate 0), still capture the label
351            // since the config-fallback rate will be used for the amount.
352            if (ded.deduction_type == "progressive" || ded.type_field == "progressive")
353                && ded.rate == 0.0
354            {
355                if code_upper == "FIT"
356                    || code_upper == "LOHNST"
357                    || (name_en_lower.contains("income tax") && !name_en_lower.contains("state"))
358                {
359                    if labels.tax_withholding.is_none() {
360                        labels.tax_withholding = Some(label);
361                    }
362                } else if code_upper == "SIT" || name_en_lower.contains("state income tax") {
363                    labels.tax_withholding = Some(match labels.tax_withholding.take() {
364                        Some(existing) => format!("{existing}; {label}"),
365                        None => label,
366                    });
367                }
368                continue;
369            }
370
371            if code_upper == "FIT"
372                || code_upper == "LOHNST"
373                || code_upper == "SIT"
374                || name_en_lower.contains("income tax")
375                || name_en_lower.contains("state income tax")
376            {
377                // All income-tax-related deductions (federal, state, combined)
378                // are grouped under the tax_withholding label.
379                labels.tax_withholding = Some(match labels.tax_withholding.take() {
380                    Some(existing) => format!("{existing}; {label}"),
381                    None => label,
382                });
383            } else if code_upper == "FICA" || name_en_lower.contains("social security") {
384                labels.social_security = Some(match labels.social_security.take() {
385                    Some(existing) => format!("{existing}; {label}"),
386                    None => label,
387                });
388            } else if name_en_lower.contains("health insurance") {
389                if labels.health_insurance.is_none() {
390                    labels.health_insurance = Some(label);
391                }
392            } else if name_en_lower.contains("pension") || name_en_lower.contains("retirement") {
393                if labels.retirement_contribution.is_none() {
394                    labels.retirement_contribution = Some(label);
395                }
396            } else {
397                // Misc deductions (unemployment, church tax, etc.) — append to
398                // social_security label since those rates go into that bucket.
399                labels.social_security = Some(match labels.social_security.take() {
400                    Some(existing) => format!("{existing}; {label}"),
401                    None => label,
402                });
403            }
404        }
405
406        // Employer contributions
407        let emp_labels: Vec<String> = pack
408            .payroll
409            .employer_contributions
410            .iter()
411            .filter_map(|c| {
412                let l = if c.name.is_empty() {
413                    c.name_en.clone()
414                } else {
415                    c.name.clone()
416                };
417                if l.is_empty() {
418                    None
419                } else {
420                    Some(l)
421                }
422            })
423            .collect();
424        if !emp_labels.is_empty() {
425            labels.employer_contribution = Some(emp_labels.join("; "));
426        }
427
428        labels
429    }
430
431    /// Core generation logic parameterised on resolved rates and labels.
432    fn generate_with_rates_and_labels(
433        &mut self,
434        company_code: &str,
435        employees: &[(String, Decimal, Option<String>, Option<String>)],
436        period_start: NaiveDate,
437        period_end: NaiveDate,
438        currency: &str,
439        rates: &PayrollRates,
440        labels: &DeductionLabels,
441    ) -> (PayrollRun, Vec<PayrollLineItem>) {
442        let payroll_id = self.uuid_factory.next().to_string();
443
444        let mut line_items = Vec::with_capacity(employees.len());
445        let mut total_gross = Decimal::ZERO;
446        let mut total_deductions = Decimal::ZERO;
447        let mut total_net = Decimal::ZERO;
448        let mut total_employer_cost = Decimal::ZERO;
449
450        let benefits_enrolled = self.config.benefits_enrollment_rate;
451        let retirement_participating = self.config.retirement_participation_rate;
452
453        for (employee_id, base_salary, cost_center, department) in employees {
454            let line_id = self.line_uuid_factory.next().to_string();
455
456            // Monthly base component (annual salary / 12)
457            let monthly_base = (*base_salary / Decimal::from(12)).round_dp(2);
458
459            // Overtime: 10% chance, 1-20 hours at 1.5x hourly rate
460            let (overtime_pay, overtime_hours) = if self.rng.gen_bool(0.10) {
461                let ot_hours = self.rng.gen_range(1.0..=20.0);
462                // Hourly rate = annual salary / (52 weeks * 40 hours)
463                let hourly_rate = *base_salary / Decimal::from(2080);
464                let ot_rate = hourly_rate * Decimal::from_f64_retain(1.5).unwrap_or(Decimal::ONE);
465                let ot_pay = (ot_rate
466                    * Decimal::from_f64_retain(ot_hours).unwrap_or(Decimal::ZERO))
467                .round_dp(2);
468                (ot_pay, ot_hours)
469            } else {
470                (Decimal::ZERO, 0.0)
471            };
472
473            // Bonus: 5% chance for a monthly bonus (1-10% of monthly base)
474            let bonus = if self.rng.gen_bool(0.05) {
475                let pct = self.rng.gen_range(0.01..=0.10);
476                (monthly_base * Decimal::from_f64_retain(pct).unwrap_or(Decimal::ZERO)).round_dp(2)
477            } else {
478                Decimal::ZERO
479            };
480
481            let gross_pay = monthly_base + overtime_pay + bonus;
482
483            // Deductions
484            let tax_withholding = (gross_pay * rates.income_tax_rate).round_dp(2);
485            let social_security = (gross_pay * rates.fica_rate).round_dp(2);
486
487            let health_insurance = if self.rng.gen_bool(benefits_enrolled) {
488                (gross_pay * rates.health_rate).round_dp(2)
489            } else {
490                Decimal::ZERO
491            };
492
493            let retirement_contribution = if self.rng.gen_bool(retirement_participating) {
494                (gross_pay * rates.retirement_rate).round_dp(2)
495            } else {
496                Decimal::ZERO
497            };
498
499            // Small random other deductions (garnishments, etc.): ~3% chance
500            let other_deductions = if self.rng.gen_bool(0.03) {
501                sample_decimal_range(&mut self.rng, Decimal::from(50), Decimal::from(500))
502                    .round_dp(2)
503            } else {
504                Decimal::ZERO
505            };
506
507            let total_ded = tax_withholding
508                + social_security
509                + health_insurance
510                + retirement_contribution
511                + other_deductions;
512            let net_pay = gross_pay - total_ded;
513
514            // Standard 160 regular hours per month (8h * 20 business days)
515            let hours_worked = 160.0;
516
517            // Employer-side cost: gross + employer contribution match
518            let employer_contrib = (gross_pay * rates.employer_fica_rate).round_dp(2);
519            let employer_cost = gross_pay + employer_contrib;
520
521            total_gross += gross_pay;
522            total_deductions += total_ded;
523            total_net += net_pay;
524            total_employer_cost += employer_cost;
525
526            line_items.push(PayrollLineItem {
527                payroll_id: payroll_id.clone(),
528                employee_id: employee_id.clone(),
529                line_id,
530                gross_pay,
531                base_salary: monthly_base,
532                overtime_pay,
533                bonus,
534                tax_withholding,
535                social_security,
536                health_insurance,
537                retirement_contribution,
538                other_deductions,
539                net_pay,
540                hours_worked,
541                overtime_hours,
542                pay_date: period_end,
543                cost_center: cost_center.clone(),
544                department: department.clone(),
545                tax_withholding_label: labels.tax_withholding.clone(),
546                social_security_label: labels.social_security.clone(),
547                health_insurance_label: labels.health_insurance.clone(),
548                retirement_contribution_label: labels.retirement_contribution.clone(),
549                employer_contribution_label: labels.employer_contribution.clone(),
550            });
551        }
552
553        // Determine status
554        let status_roll: f64 = self.rng.gen();
555        let status = if status_roll < 0.60 {
556            PayrollRunStatus::Posted
557        } else if status_roll < 0.85 {
558            PayrollRunStatus::Approved
559        } else if status_roll < 0.95 {
560            PayrollRunStatus::Calculated
561        } else {
562            PayrollRunStatus::Draft
563        };
564
565        let approved_by = if matches!(
566            status,
567            PayrollRunStatus::Approved | PayrollRunStatus::Posted
568        ) {
569            if !self.employee_ids_pool.is_empty() {
570                let idx = self.rng.gen_range(0..self.employee_ids_pool.len());
571                Some(self.employee_ids_pool[idx].clone())
572            } else {
573                Some(format!("USR-{:04}", self.rng.gen_range(201..=400)))
574            }
575        } else {
576            None
577        };
578
579        let posted_by = if status == PayrollRunStatus::Posted {
580            if !self.employee_ids_pool.is_empty() {
581                let idx = self.rng.gen_range(0..self.employee_ids_pool.len());
582                Some(self.employee_ids_pool[idx].clone())
583            } else {
584                Some(format!("USR-{:04}", self.rng.gen_range(401..=500)))
585            }
586        } else {
587            None
588        };
589
590        let run = PayrollRun {
591            company_code: company_code.to_string(),
592            payroll_id: payroll_id.clone(),
593            pay_period_start: period_start,
594            pay_period_end: period_end,
595            run_date: period_end,
596            status,
597            total_gross,
598            total_deductions,
599            total_net,
600            total_employer_cost,
601            employee_count: employees.len() as u32,
602            currency: currency.to_string(),
603            posted_by,
604            approved_by,
605        };
606
607        (run, line_items)
608    }
609}
610
611#[cfg(test)]
612#[allow(clippy::unwrap_used)]
613mod tests {
614    use super::*;
615
616    fn test_employees() -> Vec<(String, Decimal, Option<String>, Option<String>)> {
617        vec![
618            (
619                "EMP-001".to_string(),
620                Decimal::from(60_000),
621                Some("CC-100".to_string()),
622                Some("Engineering".to_string()),
623            ),
624            (
625                "EMP-002".to_string(),
626                Decimal::from(85_000),
627                Some("CC-200".to_string()),
628                Some("Finance".to_string()),
629            ),
630            (
631                "EMP-003".to_string(),
632                Decimal::from(120_000),
633                None,
634                Some("Sales".to_string()),
635            ),
636        ]
637    }
638
639    #[test]
640    fn test_basic_payroll_generation() {
641        let mut gen = PayrollGenerator::new(42);
642        let employees = test_employees();
643        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
644        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
645
646        let (run, items) = gen.generate("C001", &employees, period_start, period_end, "USD");
647
648        assert_eq!(run.company_code, "C001");
649        assert_eq!(run.currency, "USD");
650        assert_eq!(run.employee_count, 3);
651        assert_eq!(items.len(), 3);
652        assert!(run.total_gross > Decimal::ZERO);
653        assert!(run.total_deductions > Decimal::ZERO);
654        assert!(run.total_net > Decimal::ZERO);
655        assert!(run.total_employer_cost > run.total_gross);
656        // net = gross - deductions
657        assert_eq!(run.total_net, run.total_gross - run.total_deductions);
658
659        for item in &items {
660            assert_eq!(item.payroll_id, run.payroll_id);
661            assert!(item.gross_pay > Decimal::ZERO);
662            assert!(item.net_pay > Decimal::ZERO);
663            assert!(item.net_pay < item.gross_pay);
664            assert!(item.base_salary > Decimal::ZERO);
665            assert_eq!(item.pay_date, period_end);
666            // Without country pack, labels should be None
667            assert!(item.tax_withholding_label.is_none());
668            assert!(item.social_security_label.is_none());
669        }
670    }
671
672    #[test]
673    fn test_deterministic_payroll() {
674        let employees = test_employees();
675        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
676        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
677
678        let mut gen1 = PayrollGenerator::new(42);
679        let (run1, items1) = gen1.generate("C001", &employees, period_start, period_end, "USD");
680
681        let mut gen2 = PayrollGenerator::new(42);
682        let (run2, items2) = gen2.generate("C001", &employees, period_start, period_end, "USD");
683
684        assert_eq!(run1.payroll_id, run2.payroll_id);
685        assert_eq!(run1.total_gross, run2.total_gross);
686        assert_eq!(run1.total_net, run2.total_net);
687        assert_eq!(run1.status, run2.status);
688        assert_eq!(items1.len(), items2.len());
689        for (a, b) in items1.iter().zip(items2.iter()) {
690            assert_eq!(a.line_id, b.line_id);
691            assert_eq!(a.gross_pay, b.gross_pay);
692            assert_eq!(a.net_pay, b.net_pay);
693        }
694    }
695
696    #[test]
697    fn test_payroll_deduction_components() {
698        let mut gen = PayrollGenerator::new(99);
699        let employees = vec![(
700            "EMP-010".to_string(),
701            Decimal::from(100_000),
702            Some("CC-300".to_string()),
703            Some("HR".to_string()),
704        )];
705        let period_start = NaiveDate::from_ymd_opt(2024, 6, 1).unwrap();
706        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
707
708        let (_run, items) = gen.generate("C001", &employees, period_start, period_end, "USD");
709        assert_eq!(items.len(), 1);
710
711        let item = &items[0];
712        // base_salary should be approximately 100000/12 = 8333.33
713        let expected_monthly = (Decimal::from(100_000) / Decimal::from(12)).round_dp(2);
714        assert_eq!(item.base_salary, expected_monthly);
715
716        // Deductions should sum correctly
717        let deduction_sum = item.tax_withholding
718            + item.social_security
719            + item.health_insurance
720            + item.retirement_contribution
721            + item.other_deductions;
722        let expected_net = item.gross_pay - deduction_sum;
723        assert_eq!(item.net_pay, expected_net);
724
725        // Tax withholding should be reasonable (22% federal + 5% state = 27% of gross)
726        assert!(item.tax_withholding > Decimal::ZERO);
727        assert!(item.social_security > Decimal::ZERO);
728    }
729
730    // ---------------------------------------------------------------
731    // Country-pack tests
732    // ---------------------------------------------------------------
733
734    /// Helper: build a US-like country pack with explicit statutory deductions.
735    fn us_country_pack() -> CountryPack {
736        use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
737        CountryPack {
738            country_code: "US".to_string(),
739            payroll: PayrollCountryConfig {
740                statutory_deductions: vec![
741                    PayrollDeduction {
742                        code: "FICA".to_string(),
743                        name_en: "Federal Insurance Contributions Act".to_string(),
744                        deduction_type: "percentage".to_string(),
745                        rate: 0.0765,
746                        ..Default::default()
747                    },
748                    PayrollDeduction {
749                        code: "FIT".to_string(),
750                        name_en: "Federal Income Tax".to_string(),
751                        deduction_type: "progressive".to_string(),
752                        rate: 0.0, // progressive placeholder
753                        ..Default::default()
754                    },
755                    PayrollDeduction {
756                        code: "SIT".to_string(),
757                        name_en: "State Income Tax".to_string(),
758                        deduction_type: "percentage".to_string(),
759                        rate: 0.05,
760                        ..Default::default()
761                    },
762                ],
763                ..Default::default()
764            },
765            ..Default::default()
766        }
767    }
768
769    /// Helper: build a DE-like country pack.
770    fn de_country_pack() -> CountryPack {
771        use datasynth_core::country::schema::{PayrollCountryConfig, PayrollDeduction};
772        CountryPack {
773            country_code: "DE".to_string(),
774            payroll: PayrollCountryConfig {
775                pay_frequency: "monthly".to_string(),
776                currency: "EUR".to_string(),
777                statutory_deductions: vec![
778                    PayrollDeduction {
779                        code: "LOHNST".to_string(),
780                        name_en: "Income Tax".to_string(),
781                        type_field: "progressive".to_string(),
782                        rate: 0.0, // progressive placeholder
783                        ..Default::default()
784                    },
785                    PayrollDeduction {
786                        code: "SOLI".to_string(),
787                        name_en: "Solidarity Surcharge".to_string(),
788                        type_field: "percentage".to_string(),
789                        rate: 0.055,
790                        ..Default::default()
791                    },
792                    PayrollDeduction {
793                        code: "KiSt".to_string(),
794                        name_en: "Church Tax".to_string(),
795                        type_field: "percentage".to_string(),
796                        rate: 0.08,
797                        optional: true,
798                        ..Default::default()
799                    },
800                    PayrollDeduction {
801                        code: "RV".to_string(),
802                        name_en: "Pension Insurance".to_string(),
803                        type_field: "percentage".to_string(),
804                        rate: 0.093,
805                        ..Default::default()
806                    },
807                    PayrollDeduction {
808                        code: "KV".to_string(),
809                        name_en: "Health Insurance".to_string(),
810                        type_field: "percentage".to_string(),
811                        rate: 0.073,
812                        ..Default::default()
813                    },
814                    PayrollDeduction {
815                        code: "AV".to_string(),
816                        name_en: "Unemployment Insurance".to_string(),
817                        type_field: "percentage".to_string(),
818                        rate: 0.013,
819                        ..Default::default()
820                    },
821                    PayrollDeduction {
822                        code: "PV".to_string(),
823                        name_en: "Long-Term Care Insurance".to_string(),
824                        type_field: "percentage".to_string(),
825                        rate: 0.017,
826                        ..Default::default()
827                    },
828                ],
829                employer_contributions: vec![
830                    PayrollDeduction {
831                        code: "AG-RV".to_string(),
832                        name_en: "Employer Pension Insurance".to_string(),
833                        type_field: "percentage".to_string(),
834                        rate: 0.093,
835                        ..Default::default()
836                    },
837                    PayrollDeduction {
838                        code: "AG-KV".to_string(),
839                        name_en: "Employer Health Insurance".to_string(),
840                        type_field: "percentage".to_string(),
841                        rate: 0.073,
842                        ..Default::default()
843                    },
844                ],
845                ..Default::default()
846            },
847            ..Default::default()
848        }
849    }
850
851    #[test]
852    fn test_generate_with_us_country_pack() {
853        let mut gen = PayrollGenerator::new(42);
854        let employees = test_employees();
855        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
856        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
857        let pack = us_country_pack();
858
859        let (run, items) = gen.generate_with_country_pack(
860            "C001",
861            &employees,
862            period_start,
863            period_end,
864            "USD",
865            &pack,
866        );
867
868        assert_eq!(run.company_code, "C001");
869        assert_eq!(run.employee_count, 3);
870        assert_eq!(items.len(), 3);
871        assert_eq!(run.total_net, run.total_gross - run.total_deductions);
872
873        for item in &items {
874            assert!(item.gross_pay > Decimal::ZERO);
875            assert!(item.net_pay > Decimal::ZERO);
876            assert!(item.net_pay < item.gross_pay);
877            // FICA deduction should be present
878            assert!(item.social_security > Decimal::ZERO);
879            // US pack should produce labels
880            assert!(item.tax_withholding_label.is_some());
881            assert!(item.social_security_label.is_some());
882        }
883    }
884
885    #[test]
886    fn test_generate_with_de_country_pack() {
887        let mut gen = PayrollGenerator::new(42);
888        let employees = test_employees();
889        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
890        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
891        let pack = de_country_pack();
892
893        let (run, items) = gen.generate_with_country_pack(
894            "DE01",
895            &employees,
896            period_start,
897            period_end,
898            "EUR",
899            &pack,
900        );
901
902        assert_eq!(run.company_code, "DE01");
903        assert_eq!(items.len(), 3);
904        assert_eq!(run.total_net, run.total_gross - run.total_deductions);
905
906        // DE pack should use pension rate 0.093 for retirement
907        // and health insurance rate 0.073
908        let rates = gen.rates_from_country_pack(&pack);
909        assert_eq!(
910            rates.retirement_rate,
911            Decimal::from_f64_retain(0.093).unwrap()
912        );
913        assert_eq!(rates.health_rate, Decimal::from_f64_retain(0.073).unwrap());
914
915        // Check DE labels are populated
916        let item = &items[0];
917        assert_eq!(
918            item.health_insurance_label.as_deref(),
919            Some("Health Insurance")
920        );
921        assert_eq!(
922            item.retirement_contribution_label.as_deref(),
923            Some("Pension Insurance")
924        );
925        // Employer contribution labels should include both AG-RV and AG-KV
926        assert!(item.employer_contribution_label.is_some());
927        let ec = item.employer_contribution_label.as_ref().unwrap();
928        assert!(ec.contains("Employer Pension Insurance"));
929        assert!(ec.contains("Employer Health Insurance"));
930    }
931
932    #[test]
933    fn test_country_pack_falls_back_to_config_for_missing_categories() {
934        // Empty pack: no statutory deductions => all rates fall back to config
935        let pack = CountryPack::default();
936        let gen = PayrollGenerator::new(42);
937        let rates_pack = gen.rates_from_country_pack(&pack);
938        let rates_cfg = gen.rates_from_config();
939
940        assert_eq!(rates_pack.income_tax_rate, rates_cfg.income_tax_rate);
941        assert_eq!(rates_pack.fica_rate, rates_cfg.fica_rate);
942        assert_eq!(rates_pack.health_rate, rates_cfg.health_rate);
943        assert_eq!(rates_pack.retirement_rate, rates_cfg.retirement_rate);
944    }
945
946    #[test]
947    fn test_country_pack_deterministic() {
948        let employees = test_employees();
949        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
950        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
951        let pack = de_country_pack();
952
953        let mut gen1 = PayrollGenerator::new(42);
954        let (run1, items1) = gen1.generate_with_country_pack(
955            "DE01",
956            &employees,
957            period_start,
958            period_end,
959            "EUR",
960            &pack,
961        );
962
963        let mut gen2 = PayrollGenerator::new(42);
964        let (run2, items2) = gen2.generate_with_country_pack(
965            "DE01",
966            &employees,
967            period_start,
968            period_end,
969            "EUR",
970            &pack,
971        );
972
973        assert_eq!(run1.payroll_id, run2.payroll_id);
974        assert_eq!(run1.total_gross, run2.total_gross);
975        assert_eq!(run1.total_net, run2.total_net);
976        for (a, b) in items1.iter().zip(items2.iter()) {
977            assert_eq!(a.net_pay, b.net_pay);
978        }
979    }
980
981    #[test]
982    fn test_de_rates_differ_from_default() {
983        // With the DE pack, the resolved rates should differ from config defaults
984        let gen = PayrollGenerator::new(42);
985        let pack = de_country_pack();
986        let rates_cfg = gen.rates_from_config();
987        let rates_de = gen.rates_from_country_pack(&pack);
988
989        // DE has no non-progressive income tax in pack -> income_tax_rate falls
990        // back to config default for federal part.
991        // But health (0.073 vs 0.03) and retirement (0.093 vs 0.05) should differ.
992        assert_ne!(rates_de.health_rate, rates_cfg.health_rate);
993        assert_ne!(rates_de.retirement_rate, rates_cfg.retirement_rate);
994    }
995
996    #[test]
997    fn test_set_country_pack_uses_labels() {
998        let mut gen = PayrollGenerator::new(42);
999        let pack = de_country_pack();
1000        gen.set_country_pack(pack);
1001
1002        let employees = test_employees();
1003        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
1004        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
1005
1006        // generate() should now use the stored pack for rates + labels
1007        let (_run, items) = gen.generate("DE01", &employees, period_start, period_end, "EUR");
1008
1009        let item = &items[0];
1010        // Labels should be populated from the DE pack
1011        assert!(item.tax_withholding_label.is_some());
1012        assert!(item.health_insurance_label.is_some());
1013        assert!(item.retirement_contribution_label.is_some());
1014        assert!(item.employer_contribution_label.is_some());
1015    }
1016
1017    #[test]
1018    fn test_empty_pack_labels_are_none() {
1019        let pack = CountryPack::default();
1020        let labels = PayrollGenerator::labels_from_country_pack(&pack);
1021        assert!(labels.tax_withholding.is_none());
1022        assert!(labels.social_security.is_none());
1023        assert!(labels.health_insurance.is_none());
1024        assert!(labels.retirement_contribution.is_none());
1025        assert!(labels.employer_contribution.is_none());
1026    }
1027
1028    #[test]
1029    fn test_us_pack_labels() {
1030        let pack = us_country_pack();
1031        let labels = PayrollGenerator::labels_from_country_pack(&pack);
1032        // FIT is a progressive placeholder but label is still captured
1033        assert!(labels.tax_withholding.is_some());
1034        let tw = labels.tax_withholding.unwrap();
1035        assert!(tw.contains("Federal Income Tax"));
1036        assert!(tw.contains("State Income Tax"));
1037        // FICA label
1038        assert!(labels.social_security.is_some());
1039        assert!(labels
1040            .social_security
1041            .unwrap()
1042            .contains("Federal Insurance Contributions Act"));
1043    }
1044}