Skip to main content

datasynth_generators/hr/
expense_report_generator.rs

//! Expense report generator for the Hire-to-Retire (H2R) process.
//!
//! Generates employee expense reports with realistic line items across categories
//! (travel, meals, lodging, transportation, etc.), randomly assigned policy
//! violations, and approval workflow statuses.
6
7use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
11use rand::prelude::*;
12use rand_chacha::ChaCha8Rng;
13use rust_decimal::Decimal;
14
15/// Generates [`ExpenseReport`] records for employees over a period.
16pub struct ExpenseReportGenerator {
17    rng: ChaCha8Rng,
18    uuid_factory: DeterministicUuidFactory,
19    item_uuid_factory: DeterministicUuidFactory,
20    config: ExpenseConfig,
21}
22
23impl ExpenseReportGenerator {
24    /// Create a new expense report generator with default configuration.
25    pub fn new(seed: u64) -> Self {
26        Self {
27            rng: ChaCha8Rng::seed_from_u64(seed),
28            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
29            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
30                seed,
31                GeneratorType::ExpenseReport,
32                1,
33            ),
34            config: ExpenseConfig::default(),
35        }
36    }
37
38    /// Create an expense report generator with custom configuration.
39    pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
40        Self {
41            rng: ChaCha8Rng::seed_from_u64(seed),
42            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
43            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
44                seed,
45                GeneratorType::ExpenseReport,
46                1,
47            ),
48            config,
49        }
50    }
51
52    /// Generate expense reports for employees over the given period.
53    ///
54    /// Only `config.submission_rate` fraction of employees submit reports each
55    /// month within the period.
56    ///
57    /// # Arguments
58    ///
59    /// * `employee_ids` - Slice of employee identifiers
60    /// * `period_start` - Start of the period (inclusive)
61    /// * `period_end` - End of the period (inclusive)
62    /// * `config` - Expense management configuration
63    pub fn generate(
64        &mut self,
65        employee_ids: &[String],
66        period_start: NaiveDate,
67        period_end: NaiveDate,
68        config: &ExpenseConfig,
69    ) -> Vec<ExpenseReport> {
70        let mut reports = Vec::new();
71
72        // Iterate over each month in the period
73        let mut current_month_start = period_start;
74        while current_month_start <= period_end {
75            let month_end = self.month_end(current_month_start).min(period_end);
76
77            for employee_id in employee_ids {
78                // Only submission_rate fraction of employees submit per month
79                if self.rng.gen_bool(config.submission_rate.min(1.0)) {
80                    let report =
81                        self.generate_report(employee_id, current_month_start, month_end, config);
82                    reports.push(report);
83                }
84            }
85
86            // Advance to next month
87            current_month_start = self.next_month_start(current_month_start);
88        }
89
90        reports
91    }
92
93    /// Generate a single expense report for an employee within a date range.
94    fn generate_report(
95        &mut self,
96        employee_id: &str,
97        period_start: NaiveDate,
98        period_end: NaiveDate,
99        config: &ExpenseConfig,
100    ) -> ExpenseReport {
101        let report_id = self.uuid_factory.next().to_string();
102
103        // 1-5 line items per report
104        let item_count = self.rng.gen_range(1..=5);
105        let mut line_items = Vec::with_capacity(item_count);
106        let mut total_amount = Decimal::ZERO;
107
108        for _ in 0..item_count {
109            let item = self.generate_line_item(period_start, period_end);
110            total_amount += item.amount;
111            line_items.push(item);
112        }
113
114        // Submission date: usually within a few days after the last expense
115        let max_expense_date = line_items
116            .iter()
117            .map(|li| li.date)
118            .max()
119            .unwrap_or(period_end);
120        let submission_lag = self.rng.gen_range(0..=5);
121        let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
122
123        // Trip/purpose descriptions
124        let descriptions = [
125            "Client site visit",
126            "Conference attendance",
127            "Team offsite meeting",
128            "Customer presentation",
129            "Training workshop",
130            "Quarterly review travel",
131            "Sales meeting",
132            "Project kickoff",
133        ];
134        let description = descriptions[self.rng.gen_range(0..descriptions.len())].to_string();
135
136        // Status distribution: 70% Approved, 10% Paid, 10% Submitted, 5% Rejected, 5% Draft
137        let status_roll: f64 = self.rng.gen();
138        let status = if status_roll < 0.70 {
139            ExpenseStatus::Approved
140        } else if status_roll < 0.80 {
141            ExpenseStatus::Paid
142        } else if status_roll < 0.90 {
143            ExpenseStatus::Submitted
144        } else if status_roll < 0.95 {
145            ExpenseStatus::Rejected
146        } else {
147            ExpenseStatus::Draft
148        };
149
150        let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
151            Some(format!("MGR-{:04}", self.rng.gen_range(1..=100)))
152        } else {
153            None
154        };
155
156        let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
157            let approval_lag = self.rng.gen_range(1..=7);
158            Some(submission_date + chrono::Duration::days(approval_lag))
159        } else {
160            None
161        };
162
163        let paid_date = if status == ExpenseStatus::Paid {
164            approved_date.map(|ad| ad + chrono::Duration::days(self.rng.gen_range(3..=14)))
165        } else {
166            None
167        };
168
169        // Cost center and department
170        let cost_center = if self.rng.gen_bool(0.70) {
171            Some(format!("CC-{:03}", self.rng.gen_range(100..=500)))
172        } else {
173            None
174        };
175
176        let department = if self.rng.gen_bool(0.80) {
177            let departments = [
178                "Engineering",
179                "Sales",
180                "Marketing",
181                "Finance",
182                "HR",
183                "Operations",
184                "Legal",
185                "IT",
186                "Executive",
187            ];
188            Some(departments[self.rng.gen_range(0..departments.len())].to_string())
189        } else {
190            None
191        };
192
193        // Policy violations: based on config.policy_violation_rate per line item
194        let policy_violation_rate = config.policy_violation_rate;
195        let mut policy_violations = Vec::new();
196        for item in &line_items {
197            if self.rng.gen_bool(policy_violation_rate.min(1.0)) {
198                let violation = self.pick_violation(item);
199                policy_violations.push(violation);
200            }
201        }
202
203        ExpenseReport {
204            report_id,
205            employee_id: employee_id.to_string(),
206            submission_date,
207            description,
208            status,
209            total_amount,
210            currency: "USD".to_string(),
211            line_items,
212            approved_by,
213            approved_date,
214            paid_date,
215            cost_center,
216            department,
217            policy_violations,
218        }
219    }
220
221    /// Generate a single expense line item with a random category and amount.
222    fn generate_line_item(
223        &mut self,
224        period_start: NaiveDate,
225        period_end: NaiveDate,
226    ) -> ExpenseLineItem {
227        let item_id = self.item_uuid_factory.next().to_string();
228
229        // Pick a category and generate an appropriate amount range
230        let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
231
232        let raw_amount = self.rng.gen_range(amount_min..=amount_max);
233        let amount = Decimal::from_f64_retain(raw_amount)
234            .unwrap_or(Decimal::ONE)
235            .round_dp(2);
236
237        // Date within the period
238        let days_in_period = (period_end - period_start).num_days().max(1);
239        let offset = self.rng.gen_range(0..=days_in_period);
240        let date = period_start + chrono::Duration::days(offset);
241
242        // Receipt attached: 85% of the time
243        let receipt_attached = self.rng.gen_bool(0.85);
244
245        ExpenseLineItem {
246            item_id,
247            category,
248            date,
249            amount,
250            currency: "USD".to_string(),
251            description: desc,
252            receipt_attached,
253            merchant,
254        }
255    }
256
257    /// Pick an expense category with corresponding amount range, description, and merchant.
258    fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
259        let roll: f64 = self.rng.gen();
260
261        if roll < 0.20 {
262            let merchants = [
263                "Delta Airlines",
264                "United Airlines",
265                "American Airlines",
266                "Southwest",
267            ];
268            let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
269            (
270                ExpenseCategory::Travel,
271                200.0,
272                2000.0,
273                "Airfare - business travel".to_string(),
274                Some(merchant),
275            )
276        } else if roll < 0.40 {
277            let merchants = [
278                "Restaurant ABC",
279                "Cafe Express",
280                "Business Lunch Co",
281                "Steakhouse Prime",
282                "Sushi Palace",
283            ];
284            let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
285            (
286                ExpenseCategory::Meals,
287                20.0,
288                100.0,
289                "Business meal".to_string(),
290                Some(merchant),
291            )
292        } else if roll < 0.55 {
293            let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
294            let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
295            (
296                ExpenseCategory::Lodging,
297                100.0,
298                500.0,
299                "Hotel accommodation".to_string(),
300                Some(merchant),
301            )
302        } else if roll < 0.70 {
303            let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
304            let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
305            (
306                ExpenseCategory::Transportation,
307                10.0,
308                200.0,
309                "Ground transportation".to_string(),
310                Some(merchant),
311            )
312        } else if roll < 0.80 {
313            (
314                ExpenseCategory::Office,
315                15.0,
316                300.0,
317                "Office supplies".to_string(),
318                Some("Office Depot".to_string()),
319            )
320        } else if roll < 0.88 {
321            (
322                ExpenseCategory::Entertainment,
323                50.0,
324                500.0,
325                "Client entertainment".to_string(),
326                None,
327            )
328        } else if roll < 0.95 {
329            (
330                ExpenseCategory::Training,
331                100.0,
332                1500.0,
333                "Professional development".to_string(),
334                None,
335            )
336        } else {
337            (
338                ExpenseCategory::Other,
339                10.0,
340                200.0,
341                "Miscellaneous expense".to_string(),
342                None,
343            )
344        }
345    }
346
347    /// Generate a policy violation description for a given line item.
348    fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
349        let violations = match item.category {
350            ExpenseCategory::Meals => vec![
351                "Exceeds daily meal limit",
352                "Alcohol included without approval",
353                "Missing itemized receipt",
354            ],
355            ExpenseCategory::Travel => vec![
356                "Booked outside preferred vendor",
357                "Class upgrade not pre-approved",
358                "Booking made less than 7 days in advance",
359            ],
360            ExpenseCategory::Lodging => vec![
361                "Exceeds nightly rate limit",
362                "Extended stay without approval",
363                "Non-preferred hotel chain",
364            ],
365            _ => vec![
366                "Missing receipt",
367                "Insufficient business justification",
368                "Exceeds category spending limit",
369            ],
370        };
371
372        violations[self.rng.gen_range(0..violations.len())].to_string()
373    }
374
375    /// Get the last day of the month for a given date.
376    fn month_end(&self, date: NaiveDate) -> NaiveDate {
377        let (year, month) = if date.month() == 12 {
378            (date.year() + 1, 1)
379        } else {
380            (date.year(), date.month() + 1)
381        };
382        NaiveDate::from_ymd_opt(year, month, 1)
383            .unwrap_or(date)
384            .pred_opt()
385            .unwrap_or(date)
386    }
387
388    /// Get the first day of the next month.
389    fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
390        let (year, month) = if date.month() == 12 {
391            (date.year() + 1, 1)
392        } else {
393            (date.year(), date.month() + 1)
394        };
395        NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
396    }
397}
398
#[cfg(test)]
// `unwrap` is fine in tests: a failure there is a bug in the test fixture itself.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build a calendar date from literal components (always valid in these tests).
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// `count` synthetic employee IDs: `EMP-0001`, `EMP-0002`, ...
    fn employee_ids(count: u32) -> Vec<String> {
        (1..=count).map(|n| format!("EMP-{:04}", n)).collect()
    }

    /// The default ten-employee roster used by most tests.
    fn test_employee_ids() -> Vec<String> {
        employee_ids(10)
    }

    #[test]
    fn test_basic_expense_generation() {
        let roster = test_employee_ids();
        let settings = ExpenseConfig::default();
        let mut generator = ExpenseReportGenerator::new(42);

        let reports = generator.generate(&roster, ymd(2024, 1, 1), ymd(2024, 1, 31), &settings);

        // A single month yields at most one report per employee.
        assert!(!reports.is_empty());
        assert!(
            reports.len() <= roster.len(),
            "Should not exceed employee count for a single month"
        );

        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(!report.employee_id.is_empty());
            assert!(report.total_amount > Decimal::ZERO);
            assert!((1..=5).contains(&report.line_items.len()));

            // The report total is the exact sum of its line items.
            let itemised: Decimal = report.line_items.iter().map(|li| li.amount).sum();
            assert_eq!(report.total_amount, itemised);

            for line in &report.line_items {
                assert!(!line.item_id.is_empty());
                assert!(line.amount > Decimal::ZERO);
            }
        }
    }

    #[test]
    fn test_deterministic_expenses() {
        let roster = test_employee_ids();
        let settings = ExpenseConfig::default();
        let (start, end) = (ymd(2024, 3, 1), ymd(2024, 3, 31));

        // Two independent generators with the same seed must agree exactly.
        let run = || ExpenseReportGenerator::new(42).generate(&roster, start, end, &settings);
        let first = run();
        let second = run();

        assert_eq!(first.len(), second.len());
        for (lhs, rhs) in first.iter().zip(&second) {
            assert_eq!(lhs.report_id, rhs.report_id);
            assert_eq!(lhs.employee_id, rhs.employee_id);
            assert_eq!(lhs.total_amount, rhs.total_amount);
            assert_eq!(lhs.status, rhs.status);
            assert_eq!(lhs.line_items.len(), rhs.line_items.len());
        }
    }

    #[test]
    fn test_expense_status_and_violations() {
        // A larger roster over six months gives a broader statistical sample.
        let roster = employee_ids(30);
        let settings = ExpenseConfig::default();
        let mut generator = ExpenseReportGenerator::new(99);

        let reports = generator.generate(&roster, ymd(2024, 1, 1), ymd(2024, 6, 30), &settings);

        assert!(
            reports.len() > 10,
            "Expected multiple reports, got {}",
            reports.len()
        );

        let count_with =
            |wanted: ExpenseStatus| reports.iter().filter(|r| r.status == wanted).count();

        let approved = count_with(ExpenseStatus::Approved);
        let not_approved = count_with(ExpenseStatus::Paid)
            + count_with(ExpenseStatus::Submitted)
            + count_with(ExpenseStatus::Rejected)
            + count_with(ExpenseStatus::Draft);

        // Approved reports must appear, alongside at least one other status.
        assert!(approved > 0, "Expected at least some approved reports");
        assert!(
            not_approved > 0,
            "Expected a mix of statuses beyond approved"
        );

        // At least one policy violation somewhere across all reports.
        let any_violation = reports.iter().any(|r| !r.policy_violations.is_empty());
        assert!(
            any_violation,
            "Expected at least some policy violations across {} reports",
            reports.len()
        );
    }
}