1use chrono::NaiveDate;
8use datasynth_config::schema::PayrollConfig;
9use datasynth_core::models::{PayrollLineItem, PayrollRun, PayrollRunStatus};
10use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
11use rand::prelude::*;
12use rand_chacha::ChaCha8Rng;
13use rust_decimal::Decimal;
14
15pub struct PayrollGenerator {
17 rng: ChaCha8Rng,
18 uuid_factory: DeterministicUuidFactory,
19 line_uuid_factory: DeterministicUuidFactory,
20 config: PayrollConfig,
21}
22
23impl PayrollGenerator {
24 pub fn new(seed: u64) -> Self {
26 Self {
27 rng: ChaCha8Rng::seed_from_u64(seed),
28 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
29 line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
30 seed,
31 GeneratorType::PayrollRun,
32 1,
33 ),
34 config: PayrollConfig::default(),
35 }
36 }
37
38 pub fn with_config(seed: u64, config: PayrollConfig) -> Self {
40 Self {
41 rng: ChaCha8Rng::seed_from_u64(seed),
42 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::PayrollRun),
43 line_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
44 seed,
45 GeneratorType::PayrollRun,
46 1,
47 ),
48 config,
49 }
50 }
51
52 pub fn generate(
62 &mut self,
63 company_code: &str,
64 employees: &[(String, Decimal, Option<String>, Option<String>)],
65 period_start: NaiveDate,
66 period_end: NaiveDate,
67 currency: &str,
68 ) -> (PayrollRun, Vec<PayrollLineItem>) {
69 let payroll_id = self.uuid_factory.next().to_string();
70
71 let mut line_items = Vec::with_capacity(employees.len());
72 let mut total_gross = Decimal::ZERO;
73 let mut total_deductions = Decimal::ZERO;
74 let mut total_net = Decimal::ZERO;
75 let mut total_employer_cost = Decimal::ZERO;
76
77 let federal_rate = Decimal::from_f64_retain(self.config.tax_rates.federal_effective)
79 .unwrap_or(Decimal::ZERO);
80 let state_rate = Decimal::from_f64_retain(self.config.tax_rates.state_effective)
81 .unwrap_or(Decimal::ZERO);
82 let fica_rate =
83 Decimal::from_f64_retain(self.config.tax_rates.fica).unwrap_or(Decimal::ZERO);
84
85 let income_tax_rate = federal_rate + state_rate;
87
88 let health_rate = Decimal::from_f64_retain(0.03).unwrap_or(Decimal::ZERO);
90 let retirement_rate = Decimal::from_f64_retain(0.05).unwrap_or(Decimal::ZERO);
91
92 let benefits_enrolled = self.config.benefits_enrollment_rate;
93 let retirement_participating = self.config.retirement_participation_rate;
94
95 for (employee_id, base_salary, cost_center, department) in employees {
96 let line_id = self.line_uuid_factory.next().to_string();
97
98 let monthly_base = (*base_salary / Decimal::from(12)).round_dp(2);
100
101 let (overtime_pay, overtime_hours) = if self.rng.gen_bool(0.10) {
103 let ot_hours = self.rng.gen_range(1.0..=20.0);
104 let hourly_rate = *base_salary / Decimal::from(2080);
106 let ot_rate = hourly_rate * Decimal::from_f64_retain(1.5).unwrap_or(Decimal::ONE);
107 let ot_pay = (ot_rate
108 * Decimal::from_f64_retain(ot_hours).unwrap_or(Decimal::ZERO))
109 .round_dp(2);
110 (ot_pay, ot_hours)
111 } else {
112 (Decimal::ZERO, 0.0)
113 };
114
115 let bonus = if self.rng.gen_bool(0.05) {
117 let pct = self.rng.gen_range(0.01..=0.10);
118 (monthly_base * Decimal::from_f64_retain(pct).unwrap_or(Decimal::ZERO)).round_dp(2)
119 } else {
120 Decimal::ZERO
121 };
122
123 let gross_pay = monthly_base + overtime_pay + bonus;
124
125 let tax_withholding = (gross_pay * income_tax_rate).round_dp(2);
127 let social_security = (gross_pay * fica_rate).round_dp(2);
128
129 let health_insurance = if self.rng.gen_bool(benefits_enrolled) {
130 (gross_pay * health_rate).round_dp(2)
131 } else {
132 Decimal::ZERO
133 };
134
135 let retirement_contribution = if self.rng.gen_bool(retirement_participating) {
136 (gross_pay * retirement_rate).round_dp(2)
137 } else {
138 Decimal::ZERO
139 };
140
141 let other_deductions = if self.rng.gen_bool(0.03) {
143 let raw = self.rng.gen_range(50.0..=500.0);
144 Decimal::from_f64_retain(raw)
145 .unwrap_or(Decimal::ZERO)
146 .round_dp(2)
147 } else {
148 Decimal::ZERO
149 };
150
151 let total_ded = tax_withholding
152 + social_security
153 + health_insurance
154 + retirement_contribution
155 + other_deductions;
156 let net_pay = gross_pay - total_ded;
157
158 let hours_worked = 160.0;
160
161 let employer_fica = (gross_pay * fica_rate).round_dp(2);
163 let employer_cost = gross_pay + employer_fica;
164
165 total_gross += gross_pay;
166 total_deductions += total_ded;
167 total_net += net_pay;
168 total_employer_cost += employer_cost;
169
170 line_items.push(PayrollLineItem {
171 payroll_id: payroll_id.clone(),
172 employee_id: employee_id.clone(),
173 line_id,
174 gross_pay,
175 base_salary: monthly_base,
176 overtime_pay,
177 bonus,
178 tax_withholding,
179 social_security,
180 health_insurance,
181 retirement_contribution,
182 other_deductions,
183 net_pay,
184 hours_worked,
185 overtime_hours,
186 pay_date: period_end,
187 cost_center: cost_center.clone(),
188 department: department.clone(),
189 });
190 }
191
192 let status_roll: f64 = self.rng.gen();
194 let status = if status_roll < 0.60 {
195 PayrollRunStatus::Posted
196 } else if status_roll < 0.85 {
197 PayrollRunStatus::Approved
198 } else if status_roll < 0.95 {
199 PayrollRunStatus::Calculated
200 } else {
201 PayrollRunStatus::Draft
202 };
203
204 let approved_by = if matches!(
205 status,
206 PayrollRunStatus::Approved | PayrollRunStatus::Posted
207 ) {
208 Some(format!("USR-{:04}", self.rng.gen_range(201..=400)))
209 } else {
210 None
211 };
212
213 let posted_by = if status == PayrollRunStatus::Posted {
214 Some(format!("USR-{:04}", self.rng.gen_range(401..=500)))
215 } else {
216 None
217 };
218
219 let run = PayrollRun {
220 company_code: company_code.to_string(),
221 payroll_id: payroll_id.clone(),
222 pay_period_start: period_start,
223 pay_period_end: period_end,
224 run_date: period_end,
225 status,
226 total_gross,
227 total_deductions,
228 total_net,
229 total_employer_cost,
230 employee_count: employees.len() as u32,
231 currency: currency.to_string(),
232 posted_by,
233 approved_by,
234 };
235
236 (run, line_items)
237 }
238}
239
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a date from literals known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// Sums one `Decimal` field across all line items.
    fn total(items: &[PayrollLineItem], field: impl Fn(&PayrollLineItem) -> Decimal) -> Decimal {
        items.iter().map(field).sum()
    }

    /// Three employees in different salary bands; one has no cost center.
    fn test_employees() -> Vec<(String, Decimal, Option<String>, Option<String>)> {
        vec![
            (
                "EMP-001".to_string(),
                Decimal::from(60_000),
                Some("CC-100".to_string()),
                Some("Engineering".to_string()),
            ),
            (
                "EMP-002".to_string(),
                Decimal::from(85_000),
                Some("CC-200".to_string()),
                Some("Finance".to_string()),
            ),
            (
                "EMP-003".to_string(),
                Decimal::from(120_000),
                None,
                Some("Sales".to_string()),
            ),
        ]
    }

    // Note: locals are named `generator`, not `gen`, because `gen` is a reserved keyword in
    // the Rust 2024 edition.

    #[test]
    fn test_basic_payroll_generation() {
        let mut generator = PayrollGenerator::new(42);
        let employees = test_employees();
        let period_start = ymd(2024, 1, 1);
        let period_end = ymd(2024, 1, 31);

        let (run, items) =
            generator.generate("C001", &employees, period_start, period_end, "USD");

        assert_eq!(run.company_code, "C001");
        assert_eq!(run.currency, "USD");
        assert_eq!(run.employee_count, 3);
        assert_eq!(items.len(), 3);
        assert!(run.total_gross > Decimal::ZERO);
        assert!(run.total_deductions > Decimal::ZERO);
        assert!(run.total_net > Decimal::ZERO);
        assert!(run.total_employer_cost > run.total_gross);
        assert_eq!(run.total_net, run.total_gross - run.total_deductions);

        for item in &items {
            assert_eq!(item.payroll_id, run.payroll_id);
            assert!(item.gross_pay > Decimal::ZERO);
            assert!(item.net_pay > Decimal::ZERO);
            assert!(item.net_pay < item.gross_pay);
            assert!(item.base_salary > Decimal::ZERO);
            assert_eq!(item.pay_date, period_end);
        }
    }

    #[test]
    fn test_deterministic_payroll() {
        let employees = test_employees();
        let period_start = ymd(2024, 3, 1);
        let period_end = ymd(2024, 3, 31);

        let mut generator1 = PayrollGenerator::new(42);
        let (run1, items1) =
            generator1.generate("C001", &employees, period_start, period_end, "USD");

        let mut generator2 = PayrollGenerator::new(42);
        let (run2, items2) =
            generator2.generate("C001", &employees, period_start, period_end, "USD");

        assert_eq!(run1.payroll_id, run2.payroll_id);
        assert_eq!(run1.total_gross, run2.total_gross);
        assert_eq!(run1.total_net, run2.total_net);
        assert_eq!(run1.status, run2.status);
        assert_eq!(items1.len(), items2.len());
        for (a, b) in items1.iter().zip(items2.iter()) {
            assert_eq!(a.line_id, b.line_id);
            assert_eq!(a.gross_pay, b.gross_pay);
            assert_eq!(a.net_pay, b.net_pay);
        }
    }

    #[test]
    fn test_payroll_deduction_components() {
        let mut generator = PayrollGenerator::new(99);
        let employees = vec![(
            "EMP-010".to_string(),
            Decimal::from(100_000),
            Some("CC-300".to_string()),
            Some("HR".to_string()),
        )];

        let (_run, items) =
            generator.generate("C001", &employees, ymd(2024, 6, 1), ymd(2024, 6, 30), "USD");
        assert_eq!(items.len(), 1);

        let item = &items[0];
        let expected_monthly = (Decimal::from(100_000) / Decimal::from(12)).round_dp(2);
        assert_eq!(item.base_salary, expected_monthly);

        let deduction_sum = item.tax_withholding
            + item.social_security
            + item.health_insurance
            + item.retirement_contribution
            + item.other_deductions;
        assert_eq!(item.net_pay, item.gross_pay - deduction_sum);

        assert!(item.tax_withholding > Decimal::ZERO);
        assert!(item.social_security > Decimal::ZERO);
    }

    #[test]
    fn test_run_totals_match_line_items() {
        let mut generator = PayrollGenerator::new(7);
        let employees = test_employees();

        let (run, items) =
            generator.generate("C002", &employees, ymd(2024, 2, 1), ymd(2024, 2, 29), "EUR");

        assert_eq!(run.total_gross, total(&items, |i| i.gross_pay));
        assert_eq!(run.total_net, total(&items, |i| i.net_pay));
        assert_eq!(run.total_deductions, total(&items, |i| i.gross_pay - i.net_pay));
        // Employer cost is gross plus an employer FICA match equal to the employee FICA.
        assert_eq!(
            run.total_employer_cost - run.total_gross,
            total(&items, |i| i.social_security)
        );
    }

    #[test]
    fn test_approval_fields_follow_status() {
        let employees = test_employees();
        for seed in 0..50 {
            let mut generator = PayrollGenerator::new(seed);
            let (run, _items) =
                generator.generate("C001", &employees, ymd(2024, 4, 1), ymd(2024, 4, 30), "USD");

            let approved = matches!(
                run.status,
                PayrollRunStatus::Approved | PayrollRunStatus::Posted
            );
            let posted = matches!(run.status, PayrollRunStatus::Posted);
            assert_eq!(run.approved_by.is_some(), approved);
            assert_eq!(run.posted_by.is_some(), posted);
        }
    }

    #[test]
    fn test_empty_employee_list() {
        let mut generator = PayrollGenerator::new(1);

        let (run, items) =
            generator.generate("C001", &[], ymd(2024, 1, 1), ymd(2024, 1, 31), "USD");

        assert!(items.is_empty());
        assert_eq!(run.employee_count, 0);
        assert_eq!(run.total_gross, Decimal::ZERO);
        assert_eq!(run.total_deductions, Decimal::ZERO);
        assert_eq!(run.total_net, Decimal::ZERO);
        assert_eq!(run.total_employer_cost, Decimal::ZERO);
    }

    #[test]
    fn test_with_default_config_matches_new() {
        let employees = test_employees();
        let (start, end) = (ymd(2024, 5, 1), ymd(2024, 5, 31));

        let (run_a, items_a) =
            PayrollGenerator::new(11).generate("C001", &employees, start, end, "USD");
        let (run_b, items_b) = PayrollGenerator::with_config(11, PayrollConfig::default())
            .generate("C001", &employees, start, end, "USD");

        assert_eq!(run_a.payroll_id, run_b.payroll_id);
        assert_eq!(run_a.total_gross, run_b.total_gross);
        assert_eq!(run_a.total_net, run_b.total_net);
        assert_eq!(run_a.status, run_b.status);
        for (a, b) in items_a.iter().zip(items_b.iter()) {
            assert_eq!(a.line_id, b.line_id);
            assert_eq!(a.net_pay, b.net_pay);
        }
    }
}