Skip to main content

datasynth_generators/hr/
expense_report_generator.rs

1//! Expense report generator for the Hire-to-Retire (H2R) process.
2//!
3//! Generates employee expense reports with realistic line items across categories
4//! (travel, meals, lodging, transportation, etc.), policy violation detection,
5//! and approval workflow statuses.
6
7use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::utils::{sample_decimal_range, seeded_rng};
11use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use rust_decimal::Decimal;
15use smallvec::SmallVec;
16use std::collections::HashMap;
17use tracing::debug;
18
19/// Generates [`ExpenseReport`] records for employees over a period.
20pub struct ExpenseReportGenerator {
21    rng: ChaCha8Rng,
22    uuid_factory: DeterministicUuidFactory,
23    item_uuid_factory: DeterministicUuidFactory,
24    /// Per-generator config drives expense submission rate and policy thresholds.
25    config: ExpenseConfig,
26    /// Pool of real employee IDs for approved_by references.
27    employee_ids_pool: Vec<String>,
28    /// Pool of real cost center IDs.
29    cost_center_ids_pool: Vec<String>,
30    /// Mapping of employee_id → employee_name for denormalization (DS-011).
31    employee_names: HashMap<String, String>,
32    /// Optional country pack for locale-aware generation (set via
33    /// `set_country_pack`); drives locale-specific currencies.
34    country_pack: Option<datasynth_core::CountryPack>,
35}
36
37impl ExpenseReportGenerator {
38    /// Create a new expense report generator with default configuration.
39    pub fn new(seed: u64) -> Self {
40        Self {
41            rng: seeded_rng(seed, 0),
42            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
43            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
44                seed,
45                GeneratorType::ExpenseReport,
46                1,
47            ),
48            config: ExpenseConfig::default(),
49            employee_ids_pool: Vec::new(),
50            cost_center_ids_pool: Vec::new(),
51            employee_names: HashMap::new(),
52            country_pack: None,
53        }
54    }
55
56    /// Create an expense report generator with custom configuration.
57    pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
58        Self {
59            rng: seeded_rng(seed, 0),
60            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
61            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
62                seed,
63                GeneratorType::ExpenseReport,
64                1,
65            ),
66            config,
67            employee_ids_pool: Vec::new(),
68            cost_center_ids_pool: Vec::new(),
69            employee_names: HashMap::new(),
70            country_pack: None,
71        }
72    }
73
74    /// Set the country pack for locale-aware generation.
75    ///
76    /// When set, the generator can use locale-specific currencies and
77    /// business rules from the country pack.  Currently the pack is
78    /// stored for future expansion; existing behaviour is unchanged
79    /// when no pack is provided.
80    pub fn set_country_pack(&mut self, pack: datasynth_core::CountryPack) {
81        self.country_pack = Some(pack);
82    }
83
84    /// Set ID pools for cross-reference coherence.
85    ///
86    /// When pools are non-empty, the generator selects `approved_by` from
87    /// `employee_ids` and `cost_center` from `cost_center_ids` instead of
88    /// fabricating placeholder IDs.
89    pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
90        self.employee_ids_pool = employee_ids;
91        self.cost_center_ids_pool = cost_center_ids;
92        self
93    }
94
95    /// Set the employee name mapping for denormalization (DS-011).
96    ///
97    /// Maps employee IDs to their display names so that generated expense
98    /// reports include the employee name for graph export convenience.
99    pub fn with_employee_names(mut self, names: HashMap<String, String>) -> Self {
100        self.employee_names = names;
101        self
102    }
103
104    /// Generate expense reports using the stored config and country pack.
105    ///
106    /// Uses `self.config` for submission rate and policy thresholds, and
107    /// `self.country_pack` for locale-specific currency.
108    pub fn generate_from_config(
109        &mut self,
110        employee_ids: &[String],
111        period_start: NaiveDate,
112        period_end: NaiveDate,
113    ) -> Vec<ExpenseReport> {
114        let config = self.config.clone();
115        self.generate(employee_ids, period_start, period_end, &config)
116    }
117
118    /// Generate expense reports for employees over the given period.
119    ///
120    /// Only `config.submission_rate` fraction of employees submit reports each
121    /// month within the period.
122    ///
123    /// # Arguments
124    ///
125    /// * `employee_ids` - Slice of employee identifiers
126    /// * `period_start` - Start of the period (inclusive)
127    /// * `period_end` - End of the period (inclusive)
128    /// * `config` - Expense management configuration (overrides stored config)
129    pub fn generate(
130        &mut self,
131        employee_ids: &[String],
132        period_start: NaiveDate,
133        period_end: NaiveDate,
134        config: &ExpenseConfig,
135    ) -> Vec<ExpenseReport> {
136        let currency = self
137            .country_pack
138            .as_ref()
139            .map(|cp| cp.locale.default_currency.clone())
140            .filter(|c| !c.is_empty())
141            .unwrap_or_else(|| "USD".to_string());
142        self.generate_with_currency(employee_ids, period_start, period_end, config, &currency)
143    }
144
145    /// Generate expense reports with a specific company currency.
146    pub fn generate_with_currency(
147        &mut self,
148        employee_ids: &[String],
149        period_start: NaiveDate,
150        period_end: NaiveDate,
151        config: &ExpenseConfig,
152        currency: &str,
153    ) -> Vec<ExpenseReport> {
154        debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
155        let mut reports = Vec::new();
156
157        // Iterate over each month in the period
158        let mut current_month_start = period_start;
159        while current_month_start <= period_end {
160            let month_end = self.month_end(current_month_start).min(period_end);
161
162            for employee_id in employee_ids {
163                // Only submission_rate fraction of employees submit per month
164                if self.rng.random_bool(config.submission_rate.min(1.0)) {
165                    let report = self.generate_report(
166                        employee_id,
167                        current_month_start,
168                        month_end,
169                        config,
170                        currency,
171                    );
172                    reports.push(report);
173                }
174            }
175
176            // Advance to next month
177            current_month_start = self.next_month_start(current_month_start);
178        }
179
180        reports
181    }
182
183    /// Generate a single expense report for an employee within a date range.
184    fn generate_report(
185        &mut self,
186        employee_id: &str,
187        period_start: NaiveDate,
188        period_end: NaiveDate,
189        config: &ExpenseConfig,
190        currency: &str,
191    ) -> ExpenseReport {
192        let report_id = self.uuid_factory.next().to_string();
193
194        // 1-5 line items per report
195        let item_count = self.rng.random_range(1..=5);
196        let mut line_items = SmallVec::with_capacity(item_count);
197        let mut total_amount = Decimal::ZERO;
198
199        for _ in 0..item_count {
200            let item = self.generate_line_item(period_start, period_end, currency);
201            total_amount += item.amount;
202            line_items.push(item);
203        }
204
205        // Submission date: usually within a few days after the last expense
206        let max_expense_date = line_items
207            .iter()
208            .map(|li: &ExpenseLineItem| li.date)
209            .max()
210            .unwrap_or(period_end);
211        let submission_lag = self.rng.random_range(0..=5);
212        let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
213
214        // Trip/purpose descriptions
215        let descriptions = [
216            "Client site visit",
217            "Conference attendance",
218            "Team offsite meeting",
219            "Customer presentation",
220            "Training workshop",
221            "Quarterly review travel",
222            "Sales meeting",
223            "Project kickoff",
224        ];
225        let description = descriptions[self.rng.random_range(0..descriptions.len())].to_string();
226
227        // Status distribution: 70% Approved, 10% Paid, 10% Submitted, 5% Rejected, 5% Draft
228        let status_roll: f64 = self.rng.random();
229        let status = if status_roll < 0.70 {
230            ExpenseStatus::Approved
231        } else if status_roll < 0.80 {
232            ExpenseStatus::Paid
233        } else if status_roll < 0.90 {
234            ExpenseStatus::Submitted
235        } else if status_roll < 0.95 {
236            ExpenseStatus::Rejected
237        } else {
238            ExpenseStatus::Draft
239        };
240
241        let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
242            if !self.employee_ids_pool.is_empty() {
243                let idx = self.rng.random_range(0..self.employee_ids_pool.len());
244                Some(self.employee_ids_pool[idx].clone())
245            } else {
246                Some(format!("MGR-{:04}", self.rng.random_range(1..=100)))
247            }
248        } else {
249            None
250        };
251
252        let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
253            let approval_lag = self.rng.random_range(1..=7);
254            Some(submission_date + chrono::Duration::days(approval_lag))
255        } else {
256            None
257        };
258
259        let paid_date = if status == ExpenseStatus::Paid {
260            approved_date.map(|ad| ad + chrono::Duration::days(self.rng.random_range(3..=14)))
261        } else {
262            None
263        };
264
265        // Cost center and department
266        let cost_center = if self.rng.random_bool(0.70) {
267            if !self.cost_center_ids_pool.is_empty() {
268                let idx = self.rng.random_range(0..self.cost_center_ids_pool.len());
269                Some(self.cost_center_ids_pool[idx].clone())
270            } else {
271                Some(format!("CC-{:03}", self.rng.random_range(100..=500)))
272            }
273        } else {
274            None
275        };
276
277        let department = if self.rng.random_bool(0.80) {
278            let departments = [
279                "Engineering",
280                "Sales",
281                "Marketing",
282                "Finance",
283                "HR",
284                "Operations",
285                "Legal",
286                "IT",
287                "Executive",
288            ];
289            Some(departments[self.rng.random_range(0..departments.len())].to_string())
290        } else {
291            None
292        };
293
294        // Policy violations: based on config.policy_violation_rate per line item
295        let policy_violation_rate = config.policy_violation_rate;
296        let mut policy_violations = Vec::new();
297        for item in &line_items {
298            if self.rng.random_bool(policy_violation_rate.min(1.0)) {
299                let violation = self.pick_violation(item);
300                policy_violations.push(violation);
301            }
302        }
303
304        ExpenseReport {
305            report_id,
306            employee_id: employee_id.to_string(),
307            submission_date,
308            description,
309            status,
310            total_amount,
311            currency: currency.to_string(),
312            line_items,
313            approved_by,
314            approved_date,
315            paid_date,
316            cost_center,
317            department,
318            policy_violations,
319            employee_name: self.employee_names.get(employee_id).cloned(),
320        }
321    }
322
323    /// Generate a single expense line item with a random category and amount.
324    fn generate_line_item(
325        &mut self,
326        period_start: NaiveDate,
327        period_end: NaiveDate,
328        currency: &str,
329    ) -> ExpenseLineItem {
330        let item_id = self.item_uuid_factory.next().to_string();
331
332        // Pick a category and generate an appropriate amount range
333        let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
334
335        let amount = sample_decimal_range(
336            &mut self.rng,
337            Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
338            Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
339        )
340        .round_dp(2);
341
342        // Date within the period
343        let days_in_period = (period_end - period_start).num_days().max(1);
344        let offset = self.rng.random_range(0..=days_in_period);
345        let date = period_start + chrono::Duration::days(offset);
346
347        // Receipt attached: 85% of the time
348        let receipt_attached = self.rng.random_bool(0.85);
349
350        ExpenseLineItem {
351            item_id,
352            category,
353            date,
354            amount,
355            currency: currency.to_string(),
356            description: desc,
357            receipt_attached,
358            merchant,
359        }
360    }
361
362    /// Pick an expense category with corresponding amount range, description, and merchant.
363    fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
364        let roll: f64 = self.rng.random();
365
366        if roll < 0.20 {
367            let merchants = [
368                "Delta Airlines",
369                "United Airlines",
370                "American Airlines",
371                "Southwest",
372            ];
373            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
374            (
375                ExpenseCategory::Travel,
376                200.0,
377                2000.0,
378                "Airfare - business travel".to_string(),
379                Some(merchant),
380            )
381        } else if roll < 0.40 {
382            let merchants = [
383                "Restaurant ABC",
384                "Cafe Express",
385                "Business Lunch Co",
386                "Steakhouse Prime",
387                "Sushi Palace",
388            ];
389            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
390            (
391                ExpenseCategory::Meals,
392                20.0,
393                100.0,
394                "Business meal".to_string(),
395                Some(merchant),
396            )
397        } else if roll < 0.55 {
398            let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
399            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
400            (
401                ExpenseCategory::Lodging,
402                100.0,
403                500.0,
404                "Hotel accommodation".to_string(),
405                Some(merchant),
406            )
407        } else if roll < 0.70 {
408            let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
409            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
410            (
411                ExpenseCategory::Transportation,
412                10.0,
413                200.0,
414                "Ground transportation".to_string(),
415                Some(merchant),
416            )
417        } else if roll < 0.80 {
418            (
419                ExpenseCategory::Office,
420                15.0,
421                300.0,
422                "Office supplies".to_string(),
423                Some("Office Depot".to_string()),
424            )
425        } else if roll < 0.88 {
426            (
427                ExpenseCategory::Entertainment,
428                50.0,
429                500.0,
430                "Client entertainment".to_string(),
431                None,
432            )
433        } else if roll < 0.95 {
434            (
435                ExpenseCategory::Training,
436                100.0,
437                1500.0,
438                "Professional development".to_string(),
439                None,
440            )
441        } else {
442            (
443                ExpenseCategory::Other,
444                10.0,
445                200.0,
446                "Miscellaneous expense".to_string(),
447                None,
448            )
449        }
450    }
451
452    /// Generate a policy violation description for a given line item.
453    fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
454        let violations = match item.category {
455            ExpenseCategory::Meals => vec![
456                "Exceeds daily meal limit",
457                "Alcohol included without approval",
458                "Missing itemized receipt",
459            ],
460            ExpenseCategory::Travel => vec![
461                "Booked outside preferred vendor",
462                "Class upgrade not pre-approved",
463                "Booking made less than 7 days in advance",
464            ],
465            ExpenseCategory::Lodging => vec![
466                "Exceeds nightly rate limit",
467                "Extended stay without approval",
468                "Non-preferred hotel chain",
469            ],
470            _ => vec![
471                "Missing receipt",
472                "Insufficient business justification",
473                "Exceeds category spending limit",
474            ],
475        };
476
477        violations[self.rng.random_range(0..violations.len())].to_string()
478    }
479
480    /// Get the last day of the month for a given date.
481    fn month_end(&self, date: NaiveDate) -> NaiveDate {
482        let (year, month) = if date.month() == 12 {
483            (date.year() + 1, 1)
484        } else {
485            (date.year(), date.month() + 1)
486        };
487        NaiveDate::from_ymd_opt(year, month, 1)
488            .unwrap_or(date)
489            .pred_opt()
490            .unwrap_or(date)
491    }
492
493    /// Get the first day of the next month.
494    fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
495        let (year, month) = if date.month() == 12 {
496            (date.year() + 1, 1)
497        } else {
498            (date.year(), date.month() + 1)
499        };
500        NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
501    }
502}
503
#[cfg(test)]
// Dates below are fixed literals known to be valid, so `unwrap` cannot fail.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `count` sequential employee IDs in the `EMP-NNNN` format, starting at 1.
    fn numbered_employees(count: usize) -> Vec<String> {
        (1..=count).map(|n| format!("EMP-{:04}", n)).collect()
    }

    /// The default ten-employee roster: `EMP-0001` ..= `EMP-0010`.
    fn test_employee_ids() -> Vec<String> {
        numbered_employees(10)
    }

    /// Build a calendar date that is known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    #[test]
    fn test_basic_expense_generation() {
        let mut generator = ExpenseReportGenerator::new(42);
        let employees = test_employee_ids();
        let config = ExpenseConfig::default();

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 1, 31), &config);

        // A single month yields at most one report per submitting employee.
        assert!(!reports.is_empty());
        assert!(
            reports.len() <= employees.len(),
            "Should not exceed employee count for a single month"
        );

        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(!report.employee_id.is_empty());
            assert!(report.total_amount > Decimal::ZERO);
            assert!((1..=5).contains(&report.line_items.len()));

            // The header total must reconcile exactly with its line items.
            let line_sum: Decimal = report.line_items.iter().map(|li| li.amount).sum();
            assert_eq!(report.total_amount, line_sum);

            assert!(report
                .line_items
                .iter()
                .all(|item| !item.item_id.is_empty() && item.amount > Decimal::ZERO));
        }
    }

    #[test]
    fn test_deterministic_expenses() {
        let employees = test_employee_ids();
        let config = ExpenseConfig::default();

        // Two independent generators with the same seed must agree.
        let run = || {
            ExpenseReportGenerator::new(42).generate(
                &employees,
                ymd(2024, 3, 1),
                ymd(2024, 3, 31),
                &config,
            )
        };
        let first = run();
        let second = run();

        assert_eq!(first.len(), second.len());
        for (left, right) in first.iter().zip(&second) {
            assert_eq!(left.report_id, right.report_id);
            assert_eq!(left.employee_id, right.employee_id);
            assert_eq!(left.total_amount, right.total_amount);
            assert_eq!(left.status, right.status);
            assert_eq!(left.line_items.len(), right.line_items.len());
        }
    }

    #[test]
    fn test_expense_status_and_violations() {
        let mut generator = ExpenseReportGenerator::new(99);
        // A larger roster over six months gives a broader statistical sample.
        let employees = numbered_employees(30);
        let config = ExpenseConfig::default();

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 6, 30), &config);

        assert!(
            reports.len() > 10,
            "Expected multiple reports, got {}",
            reports.len()
        );

        let count_of =
            |wanted: ExpenseStatus| reports.iter().filter(|r| r.status == wanted).count();

        let approved = count_of(ExpenseStatus::Approved);
        let non_approved = count_of(ExpenseStatus::Paid)
            + count_of(ExpenseStatus::Submitted)
            + count_of(ExpenseStatus::Rejected)
            + count_of(ExpenseStatus::Draft);

        assert!(approved > 0, "Expected at least some approved reports");
        assert!(
            non_approved > 0,
            "Expected a mix of statuses beyond approved"
        );

        let total_violations: usize = reports.iter().map(|r| r.policy_violations.len()).sum();
        assert!(
            total_violations > 0,
            "Expected at least some policy violations across {} reports",
            reports.len()
        );
    }

    #[test]
    fn test_country_pack_does_not_break_generation() {
        let mut generator = ExpenseReportGenerator::new(42);
        // Installing the default country pack must not break basic generation.
        generator.set_country_pack(datasynth_core::CountryPack::default());

        let employees = test_employee_ids();
        let config = ExpenseConfig::default();

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 1, 31), &config);

        assert!(!reports.is_empty());
        assert!(reports
            .iter()
            .all(|r| !r.report_id.is_empty() && r.total_amount > Decimal::ZERO));
    }
}