Skip to main content

datasynth_generators/hr/
expense_report_generator.rs

1//! Expense report generator for the Hire-to-Retire (H2R) process.
2//!
3//! Generates employee expense reports with realistic line items across categories
4//! (travel, meals, lodging, transportation, etc.), policy violation detection,
5//! and approval workflow statuses.
6
7use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::utils::{sample_decimal_range, seeded_rng};
11use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use rust_decimal::Decimal;
15use std::collections::HashMap;
16use tracing::debug;
17
/// Generates [`ExpenseReport`] records for employees over a period.
///
/// Holds the seeded random source, the UUID factories for report and
/// line-item IDs, and optional lookup data (ID pools, employee names) that the
/// builder-style `with_*` methods fill in after construction.
pub struct ExpenseReportGenerator {
    /// Seeded random source used for every sampling decision.
    rng: ChaCha8Rng,
    /// Factory for report IDs.
    uuid_factory: DeterministicUuidFactory,
    /// Factory for line-item IDs (constructed with sub-discriminator 1).
    item_uuid_factory: DeterministicUuidFactory,
    // Set by new()/with_config() but not read by any method in this file;
    // generate() currently takes its config as a parameter instead.
    #[allow(dead_code)]
    config: ExpenseConfig,
    /// Pool of real employee IDs for approved_by references.
    employee_ids_pool: Vec<String>,
    /// Pool of real cost center IDs.
    cost_center_ids_pool: Vec<String>,
    /// Mapping of employee_id → employee_name for denormalization (DS-011).
    employee_names: HashMap<String, String>,
}
33
34impl ExpenseReportGenerator {
35    /// Create a new expense report generator with default configuration.
36    pub fn new(seed: u64) -> Self {
37        Self {
38            rng: seeded_rng(seed, 0),
39            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
40            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
41                seed,
42                GeneratorType::ExpenseReport,
43                1,
44            ),
45            config: ExpenseConfig::default(),
46            employee_ids_pool: Vec::new(),
47            cost_center_ids_pool: Vec::new(),
48            employee_names: HashMap::new(),
49        }
50    }
51
52    /// Create an expense report generator with custom configuration.
53    pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
54        Self {
55            rng: seeded_rng(seed, 0),
56            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
57            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
58                seed,
59                GeneratorType::ExpenseReport,
60                1,
61            ),
62            config,
63            employee_ids_pool: Vec::new(),
64            cost_center_ids_pool: Vec::new(),
65            employee_names: HashMap::new(),
66        }
67    }
68
69    /// Set ID pools for cross-reference coherence.
70    ///
71    /// When pools are non-empty, the generator selects `approved_by` from
72    /// `employee_ids` and `cost_center` from `cost_center_ids` instead of
73    /// fabricating placeholder IDs.
74    pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
75        self.employee_ids_pool = employee_ids;
76        self.cost_center_ids_pool = cost_center_ids;
77        self
78    }
79
80    /// Set the employee name mapping for denormalization (DS-011).
81    ///
82    /// Maps employee IDs to their display names so that generated expense
83    /// reports include the employee name for graph export convenience.
84    pub fn with_employee_names(mut self, names: HashMap<String, String>) -> Self {
85        self.employee_names = names;
86        self
87    }
88
89    /// Generate expense reports for employees over the given period.
90    ///
91    /// Only `config.submission_rate` fraction of employees submit reports each
92    /// month within the period.
93    ///
94    /// # Arguments
95    ///
96    /// * `employee_ids` - Slice of employee identifiers
97    /// * `period_start` - Start of the period (inclusive)
98    /// * `period_end` - End of the period (inclusive)
99    /// * `config` - Expense management configuration
100    pub fn generate(
101        &mut self,
102        employee_ids: &[String],
103        period_start: NaiveDate,
104        period_end: NaiveDate,
105        config: &ExpenseConfig,
106    ) -> Vec<ExpenseReport> {
107        self.generate_with_currency(employee_ids, period_start, period_end, config, "USD")
108    }
109
110    /// Generate expense reports with a specific company currency.
111    pub fn generate_with_currency(
112        &mut self,
113        employee_ids: &[String],
114        period_start: NaiveDate,
115        period_end: NaiveDate,
116        config: &ExpenseConfig,
117        currency: &str,
118    ) -> Vec<ExpenseReport> {
119        debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
120        let mut reports = Vec::new();
121
122        // Iterate over each month in the period
123        let mut current_month_start = period_start;
124        while current_month_start <= period_end {
125            let month_end = self.month_end(current_month_start).min(period_end);
126
127            for employee_id in employee_ids {
128                // Only submission_rate fraction of employees submit per month
129                if self.rng.random_bool(config.submission_rate.min(1.0)) {
130                    let report = self.generate_report(
131                        employee_id,
132                        current_month_start,
133                        month_end,
134                        config,
135                        currency,
136                    );
137                    reports.push(report);
138                }
139            }
140
141            // Advance to next month
142            current_month_start = self.next_month_start(current_month_start);
143        }
144
145        reports
146    }
147
148    /// Generate a single expense report for an employee within a date range.
149    fn generate_report(
150        &mut self,
151        employee_id: &str,
152        period_start: NaiveDate,
153        period_end: NaiveDate,
154        config: &ExpenseConfig,
155        currency: &str,
156    ) -> ExpenseReport {
157        let report_id = self.uuid_factory.next().to_string();
158
159        // 1-5 line items per report
160        let item_count = self.rng.random_range(1..=5);
161        let mut line_items = Vec::with_capacity(item_count);
162        let mut total_amount = Decimal::ZERO;
163
164        for _ in 0..item_count {
165            let item = self.generate_line_item(period_start, period_end, currency);
166            total_amount += item.amount;
167            line_items.push(item);
168        }
169
170        // Submission date: usually within a few days after the last expense
171        let max_expense_date = line_items
172            .iter()
173            .map(|li| li.date)
174            .max()
175            .unwrap_or(period_end);
176        let submission_lag = self.rng.random_range(0..=5);
177        let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
178
179        // Trip/purpose descriptions
180        let descriptions = [
181            "Client site visit",
182            "Conference attendance",
183            "Team offsite meeting",
184            "Customer presentation",
185            "Training workshop",
186            "Quarterly review travel",
187            "Sales meeting",
188            "Project kickoff",
189        ];
190        let description = descriptions[self.rng.random_range(0..descriptions.len())].to_string();
191
192        // Status distribution: 70% Approved, 10% Paid, 10% Submitted, 5% Rejected, 5% Draft
193        let status_roll: f64 = self.rng.random();
194        let status = if status_roll < 0.70 {
195            ExpenseStatus::Approved
196        } else if status_roll < 0.80 {
197            ExpenseStatus::Paid
198        } else if status_roll < 0.90 {
199            ExpenseStatus::Submitted
200        } else if status_roll < 0.95 {
201            ExpenseStatus::Rejected
202        } else {
203            ExpenseStatus::Draft
204        };
205
206        let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
207            if !self.employee_ids_pool.is_empty() {
208                let idx = self.rng.random_range(0..self.employee_ids_pool.len());
209                Some(self.employee_ids_pool[idx].clone())
210            } else {
211                Some(format!("MGR-{:04}", self.rng.random_range(1..=100)))
212            }
213        } else {
214            None
215        };
216
217        let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
218            let approval_lag = self.rng.random_range(1..=7);
219            Some(submission_date + chrono::Duration::days(approval_lag))
220        } else {
221            None
222        };
223
224        let paid_date = if status == ExpenseStatus::Paid {
225            approved_date.map(|ad| ad + chrono::Duration::days(self.rng.random_range(3..=14)))
226        } else {
227            None
228        };
229
230        // Cost center and department
231        let cost_center = if self.rng.random_bool(0.70) {
232            if !self.cost_center_ids_pool.is_empty() {
233                let idx = self.rng.random_range(0..self.cost_center_ids_pool.len());
234                Some(self.cost_center_ids_pool[idx].clone())
235            } else {
236                Some(format!("CC-{:03}", self.rng.random_range(100..=500)))
237            }
238        } else {
239            None
240        };
241
242        let department = if self.rng.random_bool(0.80) {
243            let departments = [
244                "Engineering",
245                "Sales",
246                "Marketing",
247                "Finance",
248                "HR",
249                "Operations",
250                "Legal",
251                "IT",
252                "Executive",
253            ];
254            Some(departments[self.rng.random_range(0..departments.len())].to_string())
255        } else {
256            None
257        };
258
259        // Policy violations: based on config.policy_violation_rate per line item
260        let policy_violation_rate = config.policy_violation_rate;
261        let mut policy_violations = Vec::new();
262        for item in &line_items {
263            if self.rng.random_bool(policy_violation_rate.min(1.0)) {
264                let violation = self.pick_violation(item);
265                policy_violations.push(violation);
266            }
267        }
268
269        ExpenseReport {
270            report_id,
271            employee_id: employee_id.to_string(),
272            submission_date,
273            description,
274            status,
275            total_amount,
276            currency: currency.to_string(),
277            line_items,
278            approved_by,
279            approved_date,
280            paid_date,
281            cost_center,
282            department,
283            policy_violations,
284            employee_name: self.employee_names.get(employee_id).cloned(),
285        }
286    }
287
288    /// Generate a single expense line item with a random category and amount.
289    fn generate_line_item(
290        &mut self,
291        period_start: NaiveDate,
292        period_end: NaiveDate,
293        currency: &str,
294    ) -> ExpenseLineItem {
295        let item_id = self.item_uuid_factory.next().to_string();
296
297        // Pick a category and generate an appropriate amount range
298        let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
299
300        let amount = sample_decimal_range(
301            &mut self.rng,
302            Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
303            Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
304        )
305        .round_dp(2);
306
307        // Date within the period
308        let days_in_period = (period_end - period_start).num_days().max(1);
309        let offset = self.rng.random_range(0..=days_in_period);
310        let date = period_start + chrono::Duration::days(offset);
311
312        // Receipt attached: 85% of the time
313        let receipt_attached = self.rng.random_bool(0.85);
314
315        ExpenseLineItem {
316            item_id,
317            category,
318            date,
319            amount,
320            currency: currency.to_string(),
321            description: desc,
322            receipt_attached,
323            merchant,
324        }
325    }
326
327    /// Pick an expense category with corresponding amount range, description, and merchant.
328    fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
329        let roll: f64 = self.rng.random();
330
331        if roll < 0.20 {
332            let merchants = [
333                "Delta Airlines",
334                "United Airlines",
335                "American Airlines",
336                "Southwest",
337            ];
338            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
339            (
340                ExpenseCategory::Travel,
341                200.0,
342                2000.0,
343                "Airfare - business travel".to_string(),
344                Some(merchant),
345            )
346        } else if roll < 0.40 {
347            let merchants = [
348                "Restaurant ABC",
349                "Cafe Express",
350                "Business Lunch Co",
351                "Steakhouse Prime",
352                "Sushi Palace",
353            ];
354            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
355            (
356                ExpenseCategory::Meals,
357                20.0,
358                100.0,
359                "Business meal".to_string(),
360                Some(merchant),
361            )
362        } else if roll < 0.55 {
363            let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
364            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
365            (
366                ExpenseCategory::Lodging,
367                100.0,
368                500.0,
369                "Hotel accommodation".to_string(),
370                Some(merchant),
371            )
372        } else if roll < 0.70 {
373            let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
374            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
375            (
376                ExpenseCategory::Transportation,
377                10.0,
378                200.0,
379                "Ground transportation".to_string(),
380                Some(merchant),
381            )
382        } else if roll < 0.80 {
383            (
384                ExpenseCategory::Office,
385                15.0,
386                300.0,
387                "Office supplies".to_string(),
388                Some("Office Depot".to_string()),
389            )
390        } else if roll < 0.88 {
391            (
392                ExpenseCategory::Entertainment,
393                50.0,
394                500.0,
395                "Client entertainment".to_string(),
396                None,
397            )
398        } else if roll < 0.95 {
399            (
400                ExpenseCategory::Training,
401                100.0,
402                1500.0,
403                "Professional development".to_string(),
404                None,
405            )
406        } else {
407            (
408                ExpenseCategory::Other,
409                10.0,
410                200.0,
411                "Miscellaneous expense".to_string(),
412                None,
413            )
414        }
415    }
416
417    /// Generate a policy violation description for a given line item.
418    fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
419        let violations = match item.category {
420            ExpenseCategory::Meals => vec![
421                "Exceeds daily meal limit",
422                "Alcohol included without approval",
423                "Missing itemized receipt",
424            ],
425            ExpenseCategory::Travel => vec![
426                "Booked outside preferred vendor",
427                "Class upgrade not pre-approved",
428                "Booking made less than 7 days in advance",
429            ],
430            ExpenseCategory::Lodging => vec![
431                "Exceeds nightly rate limit",
432                "Extended stay without approval",
433                "Non-preferred hotel chain",
434            ],
435            _ => vec![
436                "Missing receipt",
437                "Insufficient business justification",
438                "Exceeds category spending limit",
439            ],
440        };
441
442        violations[self.rng.random_range(0..violations.len())].to_string()
443    }
444
445    /// Get the last day of the month for a given date.
446    fn month_end(&self, date: NaiveDate) -> NaiveDate {
447        let (year, month) = if date.month() == 12 {
448            (date.year() + 1, 1)
449        } else {
450            (date.year(), date.month() + 1)
451        };
452        NaiveDate::from_ymd_opt(year, month, 1)
453            .unwrap_or(date)
454            .pred_opt()
455            .unwrap_or(date)
456    }
457
458    /// Get the first day of the next month.
459    fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
460        let (year, month) = if date.month() == 12 {
461            (date.year() + 1, 1)
462        } else {
463            (date.year(), date.month() + 1)
464        };
465        NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
466    }
467}
468
469#[cfg(test)]
470#[allow(clippy::unwrap_used)]
471mod tests {
472    use super::*;
473
474    fn test_employee_ids() -> Vec<String> {
475        (1..=10).map(|i| format!("EMP-{:04}", i)).collect()
476    }
477
478    #[test]
479    fn test_basic_expense_generation() {
480        let mut gen = ExpenseReportGenerator::new(42);
481        let employees = test_employee_ids();
482        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
483        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
484        let config = ExpenseConfig::default();
485
486        let reports = gen.generate(&employees, period_start, period_end, &config);
487
488        // With 30% submission rate and 10 employees, expect ~3 reports per month
489        assert!(!reports.is_empty());
490        assert!(
491            reports.len() <= employees.len(),
492            "Should not exceed employee count for a single month"
493        );
494
495        for report in &reports {
496            assert!(!report.report_id.is_empty());
497            assert!(!report.employee_id.is_empty());
498            assert!(report.total_amount > Decimal::ZERO);
499            assert!(!report.line_items.is_empty());
500            assert!(report.line_items.len() <= 5);
501
502            // Total should equal sum of line items
503            let line_sum: Decimal = report.line_items.iter().map(|li| li.amount).sum();
504            assert_eq!(report.total_amount, line_sum);
505
506            for item in &report.line_items {
507                assert!(!item.item_id.is_empty());
508                assert!(item.amount > Decimal::ZERO);
509            }
510        }
511    }
512
513    #[test]
514    fn test_deterministic_expenses() {
515        let employees = test_employee_ids();
516        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
517        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
518        let config = ExpenseConfig::default();
519
520        let mut gen1 = ExpenseReportGenerator::new(42);
521        let reports1 = gen1.generate(&employees, period_start, period_end, &config);
522
523        let mut gen2 = ExpenseReportGenerator::new(42);
524        let reports2 = gen2.generate(&employees, period_start, period_end, &config);
525
526        assert_eq!(reports1.len(), reports2.len());
527        for (a, b) in reports1.iter().zip(reports2.iter()) {
528            assert_eq!(a.report_id, b.report_id);
529            assert_eq!(a.employee_id, b.employee_id);
530            assert_eq!(a.total_amount, b.total_amount);
531            assert_eq!(a.status, b.status);
532            assert_eq!(a.line_items.len(), b.line_items.len());
533        }
534    }
535
536    #[test]
537    fn test_expense_status_and_violations() {
538        let mut gen = ExpenseReportGenerator::new(99);
539        // Use more employees for a broader sample
540        let employees: Vec<String> = (1..=30).map(|i| format!("EMP-{:04}", i)).collect();
541        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
542        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
543        let config = ExpenseConfig::default();
544
545        let reports = gen.generate(&employees, period_start, period_end, &config);
546
547        // With 30 employees over 6 months, we should have a decent sample
548        assert!(
549            reports.len() > 10,
550            "Expected multiple reports, got {}",
551            reports.len()
552        );
553
554        let approved = reports
555            .iter()
556            .filter(|r| r.status == ExpenseStatus::Approved)
557            .count();
558        let paid = reports
559            .iter()
560            .filter(|r| r.status == ExpenseStatus::Paid)
561            .count();
562        let submitted = reports
563            .iter()
564            .filter(|r| r.status == ExpenseStatus::Submitted)
565            .count();
566        let rejected = reports
567            .iter()
568            .filter(|r| r.status == ExpenseStatus::Rejected)
569            .count();
570        let draft = reports
571            .iter()
572            .filter(|r| r.status == ExpenseStatus::Draft)
573            .count();
574
575        // Approved should be the majority
576        assert!(approved > 0, "Expected at least some approved reports");
577        // Check that we have a mix of statuses
578        assert!(
579            paid + submitted + rejected + draft > 0,
580            "Expected a mix of statuses beyond approved"
581        );
582
583        // Check policy violations exist somewhere
584        let total_violations: usize = reports.iter().map(|r| r.policy_violations.len()).sum();
585        assert!(
586            total_violations > 0,
587            "Expected at least some policy violations across {} reports",
588            reports.len()
589        );
590    }
591}