Skip to main content

datasynth_generators/hr/
expense_report_generator.rs

1//! Expense report generator for the Hire-to-Retire (H2R) process.
2//!
3//! Generates employee expense reports with realistic line items across categories
4//! (travel, meals, lodging, transportation, etc.), policy violation detection,
5//! and approval workflow statuses.
6
7use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::utils::{sample_decimal_range, seeded_rng};
11use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use rust_decimal::Decimal;
15use tracing::debug;
16
/// Generates [`ExpenseReport`] records for employees over a period.
///
/// All sampled values come from one seeded RNG, and report and line-item IDs
/// come from two separate UUID factories derived from the same seed.
pub struct ExpenseReportGenerator {
    /// Random source for every sampled value (item counts, amounts, dates, statuses).
    rng: ChaCha8Rng,
    /// Produces the `report_id` of each generated report.
    uuid_factory: DeterministicUuidFactory,
    /// Produces the `item_id` of each line item; the constructors build it with
    /// sub-discriminator `1`.
    item_uuid_factory: DeterministicUuidFactory,
    // Stored for future use by with_config(); generate() currently takes config as a parameter.
    #[allow(dead_code)]
    config: ExpenseConfig,
    /// Pool of real employee IDs for approved_by references.
    /// When empty, a placeholder `MGR-NNNN` ID is fabricated instead.
    employee_ids_pool: Vec<String>,
    /// Pool of real cost center IDs.
    /// When empty, a placeholder `CC-NNN` ID is fabricated instead.
    cost_center_ids_pool: Vec<String>,
}
30
31impl ExpenseReportGenerator {
32    /// Create a new expense report generator with default configuration.
33    pub fn new(seed: u64) -> Self {
34        Self {
35            rng: seeded_rng(seed, 0),
36            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
37            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
38                seed,
39                GeneratorType::ExpenseReport,
40                1,
41            ),
42            config: ExpenseConfig::default(),
43            employee_ids_pool: Vec::new(),
44            cost_center_ids_pool: Vec::new(),
45        }
46    }
47
48    /// Create an expense report generator with custom configuration.
49    pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
50        Self {
51            rng: seeded_rng(seed, 0),
52            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
53            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
54                seed,
55                GeneratorType::ExpenseReport,
56                1,
57            ),
58            config,
59            employee_ids_pool: Vec::new(),
60            cost_center_ids_pool: Vec::new(),
61        }
62    }
63
64    /// Set ID pools for cross-reference coherence.
65    ///
66    /// When pools are non-empty, the generator selects `approved_by` from
67    /// `employee_ids` and `cost_center` from `cost_center_ids` instead of
68    /// fabricating placeholder IDs.
69    pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
70        self.employee_ids_pool = employee_ids;
71        self.cost_center_ids_pool = cost_center_ids;
72        self
73    }
74
75    /// Generate expense reports for employees over the given period.
76    ///
77    /// Only `config.submission_rate` fraction of employees submit reports each
78    /// month within the period.
79    ///
80    /// # Arguments
81    ///
82    /// * `employee_ids` - Slice of employee identifiers
83    /// * `period_start` - Start of the period (inclusive)
84    /// * `period_end` - End of the period (inclusive)
85    /// * `config` - Expense management configuration
86    pub fn generate(
87        &mut self,
88        employee_ids: &[String],
89        period_start: NaiveDate,
90        period_end: NaiveDate,
91        config: &ExpenseConfig,
92    ) -> Vec<ExpenseReport> {
93        self.generate_with_currency(employee_ids, period_start, period_end, config, "USD")
94    }
95
96    /// Generate expense reports with a specific company currency.
97    pub fn generate_with_currency(
98        &mut self,
99        employee_ids: &[String],
100        period_start: NaiveDate,
101        period_end: NaiveDate,
102        config: &ExpenseConfig,
103        currency: &str,
104    ) -> Vec<ExpenseReport> {
105        debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
106        let mut reports = Vec::new();
107
108        // Iterate over each month in the period
109        let mut current_month_start = period_start;
110        while current_month_start <= period_end {
111            let month_end = self.month_end(current_month_start).min(period_end);
112
113            for employee_id in employee_ids {
114                // Only submission_rate fraction of employees submit per month
115                if self.rng.gen_bool(config.submission_rate.min(1.0)) {
116                    let report = self.generate_report(
117                        employee_id,
118                        current_month_start,
119                        month_end,
120                        config,
121                        currency,
122                    );
123                    reports.push(report);
124                }
125            }
126
127            // Advance to next month
128            current_month_start = self.next_month_start(current_month_start);
129        }
130
131        reports
132    }
133
134    /// Generate a single expense report for an employee within a date range.
135    fn generate_report(
136        &mut self,
137        employee_id: &str,
138        period_start: NaiveDate,
139        period_end: NaiveDate,
140        config: &ExpenseConfig,
141        currency: &str,
142    ) -> ExpenseReport {
143        let report_id = self.uuid_factory.next().to_string();
144
145        // 1-5 line items per report
146        let item_count = self.rng.gen_range(1..=5);
147        let mut line_items = Vec::with_capacity(item_count);
148        let mut total_amount = Decimal::ZERO;
149
150        for _ in 0..item_count {
151            let item = self.generate_line_item(period_start, period_end, currency);
152            total_amount += item.amount;
153            line_items.push(item);
154        }
155
156        // Submission date: usually within a few days after the last expense
157        let max_expense_date = line_items
158            .iter()
159            .map(|li| li.date)
160            .max()
161            .unwrap_or(period_end);
162        let submission_lag = self.rng.gen_range(0..=5);
163        let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
164
165        // Trip/purpose descriptions
166        let descriptions = [
167            "Client site visit",
168            "Conference attendance",
169            "Team offsite meeting",
170            "Customer presentation",
171            "Training workshop",
172            "Quarterly review travel",
173            "Sales meeting",
174            "Project kickoff",
175        ];
176        let description = descriptions[self.rng.gen_range(0..descriptions.len())].to_string();
177
178        // Status distribution: 70% Approved, 10% Paid, 10% Submitted, 5% Rejected, 5% Draft
179        let status_roll: f64 = self.rng.gen();
180        let status = if status_roll < 0.70 {
181            ExpenseStatus::Approved
182        } else if status_roll < 0.80 {
183            ExpenseStatus::Paid
184        } else if status_roll < 0.90 {
185            ExpenseStatus::Submitted
186        } else if status_roll < 0.95 {
187            ExpenseStatus::Rejected
188        } else {
189            ExpenseStatus::Draft
190        };
191
192        let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
193            if !self.employee_ids_pool.is_empty() {
194                let idx = self.rng.gen_range(0..self.employee_ids_pool.len());
195                Some(self.employee_ids_pool[idx].clone())
196            } else {
197                Some(format!("MGR-{:04}", self.rng.gen_range(1..=100)))
198            }
199        } else {
200            None
201        };
202
203        let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
204            let approval_lag = self.rng.gen_range(1..=7);
205            Some(submission_date + chrono::Duration::days(approval_lag))
206        } else {
207            None
208        };
209
210        let paid_date = if status == ExpenseStatus::Paid {
211            approved_date.map(|ad| ad + chrono::Duration::days(self.rng.gen_range(3..=14)))
212        } else {
213            None
214        };
215
216        // Cost center and department
217        let cost_center = if self.rng.gen_bool(0.70) {
218            if !self.cost_center_ids_pool.is_empty() {
219                let idx = self.rng.gen_range(0..self.cost_center_ids_pool.len());
220                Some(self.cost_center_ids_pool[idx].clone())
221            } else {
222                Some(format!("CC-{:03}", self.rng.gen_range(100..=500)))
223            }
224        } else {
225            None
226        };
227
228        let department = if self.rng.gen_bool(0.80) {
229            let departments = [
230                "Engineering",
231                "Sales",
232                "Marketing",
233                "Finance",
234                "HR",
235                "Operations",
236                "Legal",
237                "IT",
238                "Executive",
239            ];
240            Some(departments[self.rng.gen_range(0..departments.len())].to_string())
241        } else {
242            None
243        };
244
245        // Policy violations: based on config.policy_violation_rate per line item
246        let policy_violation_rate = config.policy_violation_rate;
247        let mut policy_violations = Vec::new();
248        for item in &line_items {
249            if self.rng.gen_bool(policy_violation_rate.min(1.0)) {
250                let violation = self.pick_violation(item);
251                policy_violations.push(violation);
252            }
253        }
254
255        ExpenseReport {
256            report_id,
257            employee_id: employee_id.to_string(),
258            submission_date,
259            description,
260            status,
261            total_amount,
262            currency: currency.to_string(),
263            line_items,
264            approved_by,
265            approved_date,
266            paid_date,
267            cost_center,
268            department,
269            policy_violations,
270        }
271    }
272
273    /// Generate a single expense line item with a random category and amount.
274    fn generate_line_item(
275        &mut self,
276        period_start: NaiveDate,
277        period_end: NaiveDate,
278        currency: &str,
279    ) -> ExpenseLineItem {
280        let item_id = self.item_uuid_factory.next().to_string();
281
282        // Pick a category and generate an appropriate amount range
283        let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
284
285        let amount = sample_decimal_range(
286            &mut self.rng,
287            Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
288            Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
289        )
290        .round_dp(2);
291
292        // Date within the period
293        let days_in_period = (period_end - period_start).num_days().max(1);
294        let offset = self.rng.gen_range(0..=days_in_period);
295        let date = period_start + chrono::Duration::days(offset);
296
297        // Receipt attached: 85% of the time
298        let receipt_attached = self.rng.gen_bool(0.85);
299
300        ExpenseLineItem {
301            item_id,
302            category,
303            date,
304            amount,
305            currency: currency.to_string(),
306            description: desc,
307            receipt_attached,
308            merchant,
309        }
310    }
311
312    /// Pick an expense category with corresponding amount range, description, and merchant.
313    fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
314        let roll: f64 = self.rng.gen();
315
316        if roll < 0.20 {
317            let merchants = [
318                "Delta Airlines",
319                "United Airlines",
320                "American Airlines",
321                "Southwest",
322            ];
323            let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
324            (
325                ExpenseCategory::Travel,
326                200.0,
327                2000.0,
328                "Airfare - business travel".to_string(),
329                Some(merchant),
330            )
331        } else if roll < 0.40 {
332            let merchants = [
333                "Restaurant ABC",
334                "Cafe Express",
335                "Business Lunch Co",
336                "Steakhouse Prime",
337                "Sushi Palace",
338            ];
339            let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
340            (
341                ExpenseCategory::Meals,
342                20.0,
343                100.0,
344                "Business meal".to_string(),
345                Some(merchant),
346            )
347        } else if roll < 0.55 {
348            let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
349            let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
350            (
351                ExpenseCategory::Lodging,
352                100.0,
353                500.0,
354                "Hotel accommodation".to_string(),
355                Some(merchant),
356            )
357        } else if roll < 0.70 {
358            let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
359            let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
360            (
361                ExpenseCategory::Transportation,
362                10.0,
363                200.0,
364                "Ground transportation".to_string(),
365                Some(merchant),
366            )
367        } else if roll < 0.80 {
368            (
369                ExpenseCategory::Office,
370                15.0,
371                300.0,
372                "Office supplies".to_string(),
373                Some("Office Depot".to_string()),
374            )
375        } else if roll < 0.88 {
376            (
377                ExpenseCategory::Entertainment,
378                50.0,
379                500.0,
380                "Client entertainment".to_string(),
381                None,
382            )
383        } else if roll < 0.95 {
384            (
385                ExpenseCategory::Training,
386                100.0,
387                1500.0,
388                "Professional development".to_string(),
389                None,
390            )
391        } else {
392            (
393                ExpenseCategory::Other,
394                10.0,
395                200.0,
396                "Miscellaneous expense".to_string(),
397                None,
398            )
399        }
400    }
401
402    /// Generate a policy violation description for a given line item.
403    fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
404        let violations = match item.category {
405            ExpenseCategory::Meals => vec![
406                "Exceeds daily meal limit",
407                "Alcohol included without approval",
408                "Missing itemized receipt",
409            ],
410            ExpenseCategory::Travel => vec![
411                "Booked outside preferred vendor",
412                "Class upgrade not pre-approved",
413                "Booking made less than 7 days in advance",
414            ],
415            ExpenseCategory::Lodging => vec![
416                "Exceeds nightly rate limit",
417                "Extended stay without approval",
418                "Non-preferred hotel chain",
419            ],
420            _ => vec![
421                "Missing receipt",
422                "Insufficient business justification",
423                "Exceeds category spending limit",
424            ],
425        };
426
427        violations[self.rng.gen_range(0..violations.len())].to_string()
428    }
429
430    /// Get the last day of the month for a given date.
431    fn month_end(&self, date: NaiveDate) -> NaiveDate {
432        let (year, month) = if date.month() == 12 {
433            (date.year() + 1, 1)
434        } else {
435            (date.year(), date.month() + 1)
436        };
437        NaiveDate::from_ymd_opt(year, month, 1)
438            .unwrap_or(date)
439            .pred_opt()
440            .unwrap_or(date)
441    }
442
443    /// Get the first day of the next month.
444    fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
445        let (year, month) = if date.month() == 12 {
446            (date.year() + 1, 1)
447        } else {
448            (date.year(), date.month() + 1)
449        };
450        NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
451    }
452}
453
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Ten sequential employee IDs, `EMP-0001` through `EMP-0010`.
    fn test_employee_ids() -> Vec<String> {
        let mut ids = Vec::new();
        for n in 1..=10 {
            ids.push(format!("EMP-{:04}", n));
        }
        ids
    }

    /// Shorthand for a calendar date that is known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    #[test]
    fn test_basic_expense_generation() {
        let employees = test_employee_ids();
        let config = ExpenseConfig::default();
        let mut generator = ExpenseReportGenerator::new(42);

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 1, 31), &config);

        // A single month yields at most one report per employee, and the default
        // submission rate should produce at least one.
        assert!(!reports.is_empty());
        assert!(
            reports.len() <= employees.len(),
            "Should not exceed employee count for a single month"
        );

        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(!report.employee_id.is_empty());
            assert!(report.total_amount > Decimal::ZERO);
            assert!(!report.line_items.is_empty());
            assert!(report.line_items.len() <= 5);

            // The header total must match the sum of the individual lines.
            let mut line_sum = Decimal::ZERO;
            for item in &report.line_items {
                line_sum += item.amount;
            }
            assert_eq!(report.total_amount, line_sum);

            for item in &report.line_items {
                assert!(!item.item_id.is_empty());
                assert!(item.amount > Decimal::ZERO);
            }
        }
    }

    #[test]
    fn test_deterministic_expenses() {
        let employees = test_employee_ids();
        let config = ExpenseConfig::default();
        let (period_start, period_end) = (ymd(2024, 3, 1), ymd(2024, 3, 31));

        // Two generators built from the same seed must produce matching output.
        let reports1 =
            ExpenseReportGenerator::new(42).generate(&employees, period_start, period_end, &config);
        let reports2 =
            ExpenseReportGenerator::new(42).generate(&employees, period_start, period_end, &config);

        assert_eq!(reports1.len(), reports2.len());
        for (a, b) in reports1.iter().zip(reports2.iter()) {
            assert_eq!(a.report_id, b.report_id);
            assert_eq!(a.employee_id, b.employee_id);
            assert_eq!(a.total_amount, b.total_amount);
            assert_eq!(a.status, b.status);
            assert_eq!(a.line_items.len(), b.line_items.len());
        }
    }

    #[test]
    fn test_expense_status_and_violations() {
        // Thirty employees over six months give a broader sample.
        let employees: Vec<String> = (1..=30).map(|i| format!("EMP-{:04}", i)).collect();
        let config = ExpenseConfig::default();
        let mut generator = ExpenseReportGenerator::new(99);

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 6, 30), &config);

        assert!(
            reports.len() > 10,
            "Expected multiple reports, got {}",
            reports.len()
        );

        // Number of reports carrying a given status.
        let count_with = |wanted: ExpenseStatus| -> usize {
            reports.iter().filter(|r| r.status == wanted).count()
        };

        let approved = count_with(ExpenseStatus::Approved);
        let paid = count_with(ExpenseStatus::Paid);
        let submitted = count_with(ExpenseStatus::Submitted);
        let rejected = count_with(ExpenseStatus::Rejected);
        let draft = count_with(ExpenseStatus::Draft);

        // Approved is the most likely status, so at least one must show up.
        assert!(approved > 0, "Expected at least some approved reports");
        // The remaining statuses should be represented too.
        assert!(
            paid + submitted + rejected + draft > 0,
            "Expected a mix of statuses beyond approved"
        );

        // Somewhere in the sample a policy violation must have been recorded.
        let mut total_violations: usize = 0;
        for report in &reports {
            total_violations += report.policy_violations.len();
        }
        assert!(
            total_violations > 0,
            "Expected at least some policy violations across {} reports",
            reports.len()
        );
    }
}