Skip to main content

datasynth_generators/hr/
payroll_generator.rs

//! Payroll generator for the Hire-to-Retire (H2R) process.
//!
//! Generates payroll runs with individual employee line items, computing
//! gross pay (monthly base salary + overtime + bonus), deductions (income tax,
//! social security, health insurance, retirement, and occasional other
//! deductions such as garnishments), net pay, and employer cost (gross pay plus
//! the employer FICA match).
6
7use chrono::NaiveDate;
8use datasynth_config::schema::PayrollConfig;
9use datasynth_core::models::{PayrollLineItem, PayrollRun, PayrollRunStatus};
10use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
11use rand::prelude::*;
12use rand_chacha::ChaCha8Rng;
13use rust_decimal::Decimal;
14
15/// Generates [`PayrollRun`] and [`PayrollLineItem`] records from employee data.
16pub struct PayrollGenerator {
17    rng: ChaCha8Rng,
18    uuid_factory: DeterministicUuidFactory,
19    line_uuid_factory: DeterministicUuidFactory,
20    config: PayrollConfig,
21}
22
23impl PayrollGenerator {
24    /// Create a new payroll generator with default configuration.
25    pub fn new(seed: u64) -> Self {
26        Self {
27            rng: ChaCha8Rng::seed_from_u64(seed),
28            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
29            line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
30                seed,
31                GeneratorType::PayrollRun,
32                1,
33            ),
34            config: PayrollConfig::default(),
35        }
36    }
37
38    /// Create a payroll generator with custom configuration.
39    pub fn with_config(seed: u64, config: PayrollConfig) -> Self {
40        Self {
41            rng: ChaCha8Rng::seed_from_u64(seed),
42            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
43            line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
44                seed,
45                GeneratorType::PayrollRun,
46                1,
47            ),
48            config,
49        }
50    }
51
52    /// Generate a payroll run and line items for the given employees and period.
53    ///
54    /// # Arguments
55    ///
56    /// * `company_code` - Company code owning the payroll
57    /// * `employees` - Slice of (employee_id, base_salary, cost_center, department)
58    /// * `period_start` - Start of the pay period (inclusive)
59    /// * `period_end` - End of the pay period (inclusive)
60    /// * `currency` - ISO 4217 currency code
61    pub fn generate(
62        &mut self,
63        company_code: &str,
64        employees: &[(String, Decimal, Option<String>, Option<String>)],
65        period_start: NaiveDate,
66        period_end: NaiveDate,
67        currency: &str,
68    ) -> (PayrollRun, Vec<PayrollLineItem>) {
69        let payroll_id = self.uuid_factory.next().to_string();
70
71        let mut line_items = Vec::with_capacity(employees.len());
72        let mut total_gross = Decimal::ZERO;
73        let mut total_deductions = Decimal::ZERO;
74        let mut total_net = Decimal::ZERO;
75        let mut total_employer_cost = Decimal::ZERO;
76
77        // Tax rates from config
78        let federal_rate = Decimal::from_f64_retain(self.config.tax_rates.federal_effective)
79            .unwrap_or(Decimal::ZERO);
80        let state_rate = Decimal::from_f64_retain(self.config.tax_rates.state_effective)
81            .unwrap_or(Decimal::ZERO);
82        let fica_rate =
83            Decimal::from_f64_retain(self.config.tax_rates.fica).unwrap_or(Decimal::ZERO);
84
85        // Combined income tax rate (federal + state)
86        let income_tax_rate = federal_rate + state_rate;
87
88        // Standard deduction rates for health and retirement
89        let health_rate = Decimal::from_f64_retain(0.03).unwrap_or(Decimal::ZERO);
90        let retirement_rate = Decimal::from_f64_retain(0.05).unwrap_or(Decimal::ZERO);
91
92        let benefits_enrolled = self.config.benefits_enrollment_rate;
93        let retirement_participating = self.config.retirement_participation_rate;
94
95        for (employee_id, base_salary, cost_center, department) in employees {
96            let line_id = self.line_uuid_factory.next().to_string();
97
98            // Monthly base component (annual salary / 12)
99            let monthly_base = (*base_salary / Decimal::from(12)).round_dp(2);
100
101            // Overtime: 10% chance, 1-20 hours at 1.5x hourly rate
102            let (overtime_pay, overtime_hours) = if self.rng.gen_bool(0.10) {
103                let ot_hours = self.rng.gen_range(1.0..=20.0);
104                // Hourly rate = annual salary / (52 weeks * 40 hours)
105                let hourly_rate = *base_salary / Decimal::from(2080);
106                let ot_rate = hourly_rate * Decimal::from_f64_retain(1.5).unwrap_or(Decimal::ONE);
107                let ot_pay = (ot_rate
108                    * Decimal::from_f64_retain(ot_hours).unwrap_or(Decimal::ZERO))
109                .round_dp(2);
110                (ot_pay, ot_hours)
111            } else {
112                (Decimal::ZERO, 0.0)
113            };
114
115            // Bonus: 5% chance for a monthly bonus (1-10% of monthly base)
116            let bonus = if self.rng.gen_bool(0.05) {
117                let pct = self.rng.gen_range(0.01..=0.10);
118                (monthly_base * Decimal::from_f64_retain(pct).unwrap_or(Decimal::ZERO)).round_dp(2)
119            } else {
120                Decimal::ZERO
121            };
122
123            let gross_pay = monthly_base + overtime_pay + bonus;
124
125            // Deductions
126            let tax_withholding = (gross_pay * income_tax_rate).round_dp(2);
127            let social_security = (gross_pay * fica_rate).round_dp(2);
128
129            let health_insurance = if self.rng.gen_bool(benefits_enrolled) {
130                (gross_pay * health_rate).round_dp(2)
131            } else {
132                Decimal::ZERO
133            };
134
135            let retirement_contribution = if self.rng.gen_bool(retirement_participating) {
136                (gross_pay * retirement_rate).round_dp(2)
137            } else {
138                Decimal::ZERO
139            };
140
141            // Small random other deductions (garnishments, etc.): ~3% chance
142            let other_deductions = if self.rng.gen_bool(0.03) {
143                let raw = self.rng.gen_range(50.0..=500.0);
144                Decimal::from_f64_retain(raw)
145                    .unwrap_or(Decimal::ZERO)
146                    .round_dp(2)
147            } else {
148                Decimal::ZERO
149            };
150
151            let total_ded = tax_withholding
152                + social_security
153                + health_insurance
154                + retirement_contribution
155                + other_deductions;
156            let net_pay = gross_pay - total_ded;
157
158            // Standard 160 regular hours per month (8h * 20 business days)
159            let hours_worked = 160.0;
160
161            // Employer-side cost: gross + employer FICA match
162            let employer_fica = (gross_pay * fica_rate).round_dp(2);
163            let employer_cost = gross_pay + employer_fica;
164
165            total_gross += gross_pay;
166            total_deductions += total_ded;
167            total_net += net_pay;
168            total_employer_cost += employer_cost;
169
170            line_items.push(PayrollLineItem {
171                payroll_id: payroll_id.clone(),
172                employee_id: employee_id.clone(),
173                line_id,
174                gross_pay,
175                base_salary: monthly_base,
176                overtime_pay,
177                bonus,
178                tax_withholding,
179                social_security,
180                health_insurance,
181                retirement_contribution,
182                other_deductions,
183                net_pay,
184                hours_worked,
185                overtime_hours,
186                pay_date: period_end,
187                cost_center: cost_center.clone(),
188                department: department.clone(),
189            });
190        }
191
192        // Determine status
193        let status_roll: f64 = self.rng.gen();
194        let status = if status_roll < 0.60 {
195            PayrollRunStatus::Posted
196        } else if status_roll < 0.85 {
197            PayrollRunStatus::Approved
198        } else if status_roll < 0.95 {
199            PayrollRunStatus::Calculated
200        } else {
201            PayrollRunStatus::Draft
202        };
203
204        let approved_by = if matches!(
205            status,
206            PayrollRunStatus::Approved | PayrollRunStatus::Posted
207        ) {
208            Some(format!("USR-{:04}", self.rng.gen_range(201..=400)))
209        } else {
210            None
211        };
212
213        let posted_by = if status == PayrollRunStatus::Posted {
214            Some(format!("USR-{:04}", self.rng.gen_range(401..=500)))
215        } else {
216            None
217        };
218
219        let run = PayrollRun {
220            company_code: company_code.to_string(),
221            payroll_id: payroll_id.clone(),
222            pay_period_start: period_start,
223            pay_period_end: period_end,
224            run_date: period_end,
225            status,
226            total_gross,
227            total_deductions,
228            total_net,
229            total_employer_cost,
230            employee_count: employees.len() as u32,
231            currency: currency.to_string(),
232            posted_by,
233            approved_by,
234        };
235
236        (run, line_items)
237    }
238}
239
#[cfg(test)]
// `unwrap` is acceptable here: a failed date construction is a bug in the test itself.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Three employees at different salary levels, one without a cost center.
    fn test_employees() -> Vec<(String, Decimal, Option<String>, Option<String>)> {
        vec![
            (
                "EMP-001".to_string(),
                Decimal::from(60_000),
                Some("CC-100".to_string()),
                Some("Engineering".to_string()),
            ),
            (
                "EMP-002".to_string(),
                Decimal::from(85_000),
                Some("CC-200".to_string()),
                Some("Finance".to_string()),
            ),
            (
                "EMP-003".to_string(),
                Decimal::from(120_000),
                None,
                Some("Sales".to_string()),
            ),
        ]
    }

    #[test]
    fn test_basic_payroll_generation() {
        let mut generator = PayrollGenerator::new(42);
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();

        let (run, items) =
            generator.generate("C001", &employees, period_start, period_end, "USD");

        assert_eq!(run.company_code, "C001");
        assert_eq!(run.currency, "USD");
        assert_eq!(run.employee_count, 3);
        assert_eq!(items.len(), 3);
        assert!(run.total_gross > Decimal::ZERO);
        assert!(run.total_deductions > Decimal::ZERO);
        assert!(run.total_net > Decimal::ZERO);
        assert!(run.total_employer_cost > run.total_gross);
        // net = gross - deductions
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);

        for item in &items {
            assert_eq!(item.payroll_id, run.payroll_id);
            assert!(item.gross_pay > Decimal::ZERO);
            assert!(item.net_pay > Decimal::ZERO);
            assert!(item.net_pay < item.gross_pay);
            assert!(item.base_salary > Decimal::ZERO);
            assert_eq!(item.pay_date, period_end);
        }
    }

    #[test]
    fn test_deterministic_payroll() {
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();

        let mut first = PayrollGenerator::new(42);
        let (run1, items1) = first.generate("C001", &employees, period_start, period_end, "USD");

        let mut second = PayrollGenerator::new(42);
        let (run2, items2) =
            second.generate("C001", &employees, period_start, period_end, "USD");

        assert_eq!(run1.payroll_id, run2.payroll_id);
        assert_eq!(run1.total_gross, run2.total_gross);
        assert_eq!(run1.total_deductions, run2.total_deductions);
        assert_eq!(run1.total_net, run2.total_net);
        assert_eq!(run1.total_employer_cost, run2.total_employer_cost);
        assert_eq!(run1.status, run2.status);
        assert_eq!(run1.approved_by, run2.approved_by);
        assert_eq!(run1.posted_by, run2.posted_by);
        assert_eq!(items1.len(), items2.len());
        for (a, b) in items1.iter().zip(items2.iter()) {
            assert_eq!(a.line_id, b.line_id);
            assert_eq!(a.gross_pay, b.gross_pay);
            assert_eq!(a.overtime_hours, b.overtime_hours);
            assert_eq!(a.bonus, b.bonus);
            assert_eq!(a.other_deductions, b.other_deductions);
            assert_eq!(a.net_pay, b.net_pay);
        }
    }

    #[test]
    fn test_payroll_deduction_components() {
        let mut generator = PayrollGenerator::new(99);
        let employees = vec![(
            "EMP-010".to_string(),
            Decimal::from(100_000),
            Some("CC-300".to_string()),
            Some("HR".to_string()),
        )];
        let period_start = NaiveDate::from_ymd_opt(2024, 6, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();

        let (_run, items) =
            generator.generate("C001", &employees, period_start, period_end, "USD");
        assert_eq!(items.len(), 1);

        let item = &items[0];
        // base_salary is the annual salary / 12, rounded to cents (8333.33).
        let expected_monthly = (Decimal::from(100_000) / Decimal::from(12)).round_dp(2);
        assert_eq!(item.base_salary, expected_monthly);

        // Gross is built from its components.
        assert_eq!(item.gross_pay, item.base_salary + item.overtime_pay + item.bonus);

        // Deductions should sum correctly.
        let deduction_sum = item.tax_withholding
            + item.social_security
            + item.health_insurance
            + item.retirement_contribution
            + item.other_deductions;
        let expected_net = item.gross_pay - deduction_sum;
        assert_eq!(item.net_pay, expected_net);

        // With the default configuration both statutory deductions are non-zero.
        assert!(item.tax_withholding > Decimal::ZERO);
        assert!(item.social_security > Decimal::ZERO);
    }

    #[test]
    fn test_totals_match_line_items() {
        let mut generator = PayrollGenerator::new(7);
        let employees = test_employees();
        let period_start = NaiveDate::from_ymd_opt(2024, 2, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 2, 29).unwrap();

        let (run, items) =
            generator.generate("C002", &employees, period_start, period_end, "EUR");

        let sum_gross = items.iter().fold(Decimal::ZERO, |acc, i| acc + i.gross_pay);
        let sum_net = items.iter().fold(Decimal::ZERO, |acc, i| acc + i.net_pay);
        let sum_deductions = items.iter().fold(Decimal::ZERO, |acc, i| {
            acc + i.tax_withholding
                + i.social_security
                + i.health_insurance
                + i.retirement_contribution
                + i.other_deductions
        });
        // Employer cost per line is gross plus the employer FICA match.
        let sum_employer_cost = items
            .iter()
            .fold(Decimal::ZERO, |acc, i| acc + i.gross_pay + i.social_security);

        assert_eq!(run.total_gross, sum_gross);
        assert_eq!(run.total_net, sum_net);
        assert_eq!(run.total_deductions, sum_deductions);
        assert_eq!(run.total_employer_cost, sum_employer_cost);
        assert_eq!(run.pay_period_start, period_start);
        assert_eq!(run.pay_period_end, period_end);
    }

    #[test]
    fn test_empty_employee_list() {
        let mut generator = PayrollGenerator::new(1);
        let employees: Vec<(String, Decimal, Option<String>, Option<String>)> = Vec::new();
        let period_start = NaiveDate::from_ymd_opt(2024, 4, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 4, 30).unwrap();

        let (run, items) =
            generator.generate("C001", &employees, period_start, period_end, "USD");

        assert!(items.is_empty());
        assert_eq!(run.employee_count, 0);
        assert_eq!(run.total_gross, Decimal::ZERO);
        assert_eq!(run.total_deductions, Decimal::ZERO);
        assert_eq!(run.total_net, Decimal::ZERO);
        assert_eq!(run.total_employer_cost, Decimal::ZERO);
    }
}