Skip to main content

datasynth_generators/hr/
expense_report_generator.rs

1//! Expense report generator for the Hire-to-Retire (H2R) process.
2//!
3//! Generates employee expense reports with realistic line items across categories
4//! (travel, meals, lodging, transportation, etc.), policy violation detection,
5//! and approval workflow statuses.
6
7use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::utils::{sample_decimal_range, seeded_rng};
11use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use rust_decimal::Decimal;
15use smallvec::SmallVec;
16use std::collections::HashMap;
17use tracing::debug;
18
19/// Generates [`ExpenseReport`] records for employees over a period.
20pub struct ExpenseReportGenerator {
21    rng: ChaCha8Rng,
22    uuid_factory: DeterministicUuidFactory,
23    item_uuid_factory: DeterministicUuidFactory,
24    /// Stored via `with_config()`; will be used when per-generator config
25    /// drives expense category weights and policy thresholds.
26    #[allow(dead_code)]
27    config: ExpenseConfig,
28    /// Pool of real employee IDs for approved_by references.
29    employee_ids_pool: Vec<String>,
30    /// Pool of real cost center IDs.
31    cost_center_ids_pool: Vec<String>,
32    /// Mapping of employee_id → employee_name for denormalization (DS-011).
33    employee_names: HashMap<String, String>,
34    /// Optional country pack for locale-aware generation (set via
35    /// `set_country_pack`); will drive locale-specific currencies and
36    /// business rules in a future release.
37    #[allow(dead_code)]
38    country_pack: Option<datasynth_core::CountryPack>,
39}
40
41impl ExpenseReportGenerator {
42    /// Create a new expense report generator with default configuration.
43    pub fn new(seed: u64) -> Self {
44        Self {
45            rng: seeded_rng(seed, 0),
46            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
47            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
48                seed,
49                GeneratorType::ExpenseReport,
50                1,
51            ),
52            config: ExpenseConfig::default(),
53            employee_ids_pool: Vec::new(),
54            cost_center_ids_pool: Vec::new(),
55            employee_names: HashMap::new(),
56            country_pack: None,
57        }
58    }
59
60    /// Create an expense report generator with custom configuration.
61    pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
62        Self {
63            rng: seeded_rng(seed, 0),
64            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
65            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
66                seed,
67                GeneratorType::ExpenseReport,
68                1,
69            ),
70            config,
71            employee_ids_pool: Vec::new(),
72            cost_center_ids_pool: Vec::new(),
73            employee_names: HashMap::new(),
74            country_pack: None,
75        }
76    }
77
78    /// Set the country pack for locale-aware generation.
79    ///
80    /// When set, the generator can use locale-specific currencies and
81    /// business rules from the country pack.  Currently the pack is
82    /// stored for future expansion; existing behaviour is unchanged
83    /// when no pack is provided.
84    pub fn set_country_pack(&mut self, pack: datasynth_core::CountryPack) {
85        self.country_pack = Some(pack);
86    }
87
88    /// Set ID pools for cross-reference coherence.
89    ///
90    /// When pools are non-empty, the generator selects `approved_by` from
91    /// `employee_ids` and `cost_center` from `cost_center_ids` instead of
92    /// fabricating placeholder IDs.
93    pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
94        self.employee_ids_pool = employee_ids;
95        self.cost_center_ids_pool = cost_center_ids;
96        self
97    }
98
99    /// Set the employee name mapping for denormalization (DS-011).
100    ///
101    /// Maps employee IDs to their display names so that generated expense
102    /// reports include the employee name for graph export convenience.
103    pub fn with_employee_names(mut self, names: HashMap<String, String>) -> Self {
104        self.employee_names = names;
105        self
106    }
107
108    /// Generate expense reports for employees over the given period.
109    ///
110    /// Only `config.submission_rate` fraction of employees submit reports each
111    /// month within the period.
112    ///
113    /// # Arguments
114    ///
115    /// * `employee_ids` - Slice of employee identifiers
116    /// * `period_start` - Start of the period (inclusive)
117    /// * `period_end` - End of the period (inclusive)
118    /// * `config` - Expense management configuration
119    pub fn generate(
120        &mut self,
121        employee_ids: &[String],
122        period_start: NaiveDate,
123        period_end: NaiveDate,
124        config: &ExpenseConfig,
125    ) -> Vec<ExpenseReport> {
126        self.generate_with_currency(employee_ids, period_start, period_end, config, "USD")
127    }
128
129    /// Generate expense reports with a specific company currency.
130    pub fn generate_with_currency(
131        &mut self,
132        employee_ids: &[String],
133        period_start: NaiveDate,
134        period_end: NaiveDate,
135        config: &ExpenseConfig,
136        currency: &str,
137    ) -> Vec<ExpenseReport> {
138        debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
139        let mut reports = Vec::new();
140
141        // Iterate over each month in the period
142        let mut current_month_start = period_start;
143        while current_month_start <= period_end {
144            let month_end = self.month_end(current_month_start).min(period_end);
145
146            for employee_id in employee_ids {
147                // Only submission_rate fraction of employees submit per month
148                if self.rng.random_bool(config.submission_rate.min(1.0)) {
149                    let report = self.generate_report(
150                        employee_id,
151                        current_month_start,
152                        month_end,
153                        config,
154                        currency,
155                    );
156                    reports.push(report);
157                }
158            }
159
160            // Advance to next month
161            current_month_start = self.next_month_start(current_month_start);
162        }
163
164        reports
165    }
166
167    /// Generate a single expense report for an employee within a date range.
168    fn generate_report(
169        &mut self,
170        employee_id: &str,
171        period_start: NaiveDate,
172        period_end: NaiveDate,
173        config: &ExpenseConfig,
174        currency: &str,
175    ) -> ExpenseReport {
176        let report_id = self.uuid_factory.next().to_string();
177
178        // 1-5 line items per report
179        let item_count = self.rng.random_range(1..=5);
180        let mut line_items = SmallVec::with_capacity(item_count);
181        let mut total_amount = Decimal::ZERO;
182
183        for _ in 0..item_count {
184            let item = self.generate_line_item(period_start, period_end, currency);
185            total_amount += item.amount;
186            line_items.push(item);
187        }
188
189        // Submission date: usually within a few days after the last expense
190        let max_expense_date = line_items
191            .iter()
192            .map(|li: &ExpenseLineItem| li.date)
193            .max()
194            .unwrap_or(period_end);
195        let submission_lag = self.rng.random_range(0..=5);
196        let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
197
198        // Trip/purpose descriptions
199        let descriptions = [
200            "Client site visit",
201            "Conference attendance",
202            "Team offsite meeting",
203            "Customer presentation",
204            "Training workshop",
205            "Quarterly review travel",
206            "Sales meeting",
207            "Project kickoff",
208        ];
209        let description = descriptions[self.rng.random_range(0..descriptions.len())].to_string();
210
211        // Status distribution: 70% Approved, 10% Paid, 10% Submitted, 5% Rejected, 5% Draft
212        let status_roll: f64 = self.rng.random();
213        let status = if status_roll < 0.70 {
214            ExpenseStatus::Approved
215        } else if status_roll < 0.80 {
216            ExpenseStatus::Paid
217        } else if status_roll < 0.90 {
218            ExpenseStatus::Submitted
219        } else if status_roll < 0.95 {
220            ExpenseStatus::Rejected
221        } else {
222            ExpenseStatus::Draft
223        };
224
225        let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
226            if !self.employee_ids_pool.is_empty() {
227                let idx = self.rng.random_range(0..self.employee_ids_pool.len());
228                Some(self.employee_ids_pool[idx].clone())
229            } else {
230                Some(format!("MGR-{:04}", self.rng.random_range(1..=100)))
231            }
232        } else {
233            None
234        };
235
236        let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
237            let approval_lag = self.rng.random_range(1..=7);
238            Some(submission_date + chrono::Duration::days(approval_lag))
239        } else {
240            None
241        };
242
243        let paid_date = if status == ExpenseStatus::Paid {
244            approved_date.map(|ad| ad + chrono::Duration::days(self.rng.random_range(3..=14)))
245        } else {
246            None
247        };
248
249        // Cost center and department
250        let cost_center = if self.rng.random_bool(0.70) {
251            if !self.cost_center_ids_pool.is_empty() {
252                let idx = self.rng.random_range(0..self.cost_center_ids_pool.len());
253                Some(self.cost_center_ids_pool[idx].clone())
254            } else {
255                Some(format!("CC-{:03}", self.rng.random_range(100..=500)))
256            }
257        } else {
258            None
259        };
260
261        let department = if self.rng.random_bool(0.80) {
262            let departments = [
263                "Engineering",
264                "Sales",
265                "Marketing",
266                "Finance",
267                "HR",
268                "Operations",
269                "Legal",
270                "IT",
271                "Executive",
272            ];
273            Some(departments[self.rng.random_range(0..departments.len())].to_string())
274        } else {
275            None
276        };
277
278        // Policy violations: based on config.policy_violation_rate per line item
279        let policy_violation_rate = config.policy_violation_rate;
280        let mut policy_violations = Vec::new();
281        for item in &line_items {
282            if self.rng.random_bool(policy_violation_rate.min(1.0)) {
283                let violation = self.pick_violation(item);
284                policy_violations.push(violation);
285            }
286        }
287
288        ExpenseReport {
289            report_id,
290            employee_id: employee_id.to_string(),
291            submission_date,
292            description,
293            status,
294            total_amount,
295            currency: currency.to_string(),
296            line_items,
297            approved_by,
298            approved_date,
299            paid_date,
300            cost_center,
301            department,
302            policy_violations,
303            employee_name: self.employee_names.get(employee_id).cloned(),
304        }
305    }
306
307    /// Generate a single expense line item with a random category and amount.
308    fn generate_line_item(
309        &mut self,
310        period_start: NaiveDate,
311        period_end: NaiveDate,
312        currency: &str,
313    ) -> ExpenseLineItem {
314        let item_id = self.item_uuid_factory.next().to_string();
315
316        // Pick a category and generate an appropriate amount range
317        let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
318
319        let amount = sample_decimal_range(
320            &mut self.rng,
321            Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
322            Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
323        )
324        .round_dp(2);
325
326        // Date within the period
327        let days_in_period = (period_end - period_start).num_days().max(1);
328        let offset = self.rng.random_range(0..=days_in_period);
329        let date = period_start + chrono::Duration::days(offset);
330
331        // Receipt attached: 85% of the time
332        let receipt_attached = self.rng.random_bool(0.85);
333
334        ExpenseLineItem {
335            item_id,
336            category,
337            date,
338            amount,
339            currency: currency.to_string(),
340            description: desc,
341            receipt_attached,
342            merchant,
343        }
344    }
345
346    /// Pick an expense category with corresponding amount range, description, and merchant.
347    fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
348        let roll: f64 = self.rng.random();
349
350        if roll < 0.20 {
351            let merchants = [
352                "Delta Airlines",
353                "United Airlines",
354                "American Airlines",
355                "Southwest",
356            ];
357            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
358            (
359                ExpenseCategory::Travel,
360                200.0,
361                2000.0,
362                "Airfare - business travel".to_string(),
363                Some(merchant),
364            )
365        } else if roll < 0.40 {
366            let merchants = [
367                "Restaurant ABC",
368                "Cafe Express",
369                "Business Lunch Co",
370                "Steakhouse Prime",
371                "Sushi Palace",
372            ];
373            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
374            (
375                ExpenseCategory::Meals,
376                20.0,
377                100.0,
378                "Business meal".to_string(),
379                Some(merchant),
380            )
381        } else if roll < 0.55 {
382            let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
383            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
384            (
385                ExpenseCategory::Lodging,
386                100.0,
387                500.0,
388                "Hotel accommodation".to_string(),
389                Some(merchant),
390            )
391        } else if roll < 0.70 {
392            let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
393            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
394            (
395                ExpenseCategory::Transportation,
396                10.0,
397                200.0,
398                "Ground transportation".to_string(),
399                Some(merchant),
400            )
401        } else if roll < 0.80 {
402            (
403                ExpenseCategory::Office,
404                15.0,
405                300.0,
406                "Office supplies".to_string(),
407                Some("Office Depot".to_string()),
408            )
409        } else if roll < 0.88 {
410            (
411                ExpenseCategory::Entertainment,
412                50.0,
413                500.0,
414                "Client entertainment".to_string(),
415                None,
416            )
417        } else if roll < 0.95 {
418            (
419                ExpenseCategory::Training,
420                100.0,
421                1500.0,
422                "Professional development".to_string(),
423                None,
424            )
425        } else {
426            (
427                ExpenseCategory::Other,
428                10.0,
429                200.0,
430                "Miscellaneous expense".to_string(),
431                None,
432            )
433        }
434    }
435
436    /// Generate a policy violation description for a given line item.
437    fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
438        let violations = match item.category {
439            ExpenseCategory::Meals => vec![
440                "Exceeds daily meal limit",
441                "Alcohol included without approval",
442                "Missing itemized receipt",
443            ],
444            ExpenseCategory::Travel => vec![
445                "Booked outside preferred vendor",
446                "Class upgrade not pre-approved",
447                "Booking made less than 7 days in advance",
448            ],
449            ExpenseCategory::Lodging => vec![
450                "Exceeds nightly rate limit",
451                "Extended stay without approval",
452                "Non-preferred hotel chain",
453            ],
454            _ => vec![
455                "Missing receipt",
456                "Insufficient business justification",
457                "Exceeds category spending limit",
458            ],
459        };
460
461        violations[self.rng.random_range(0..violations.len())].to_string()
462    }
463
464    /// Get the last day of the month for a given date.
465    fn month_end(&self, date: NaiveDate) -> NaiveDate {
466        let (year, month) = if date.month() == 12 {
467            (date.year() + 1, 1)
468        } else {
469            (date.year(), date.month() + 1)
470        };
471        NaiveDate::from_ymd_opt(year, month, 1)
472            .unwrap_or(date)
473            .pred_opt()
474            .unwrap_or(date)
475    }
476
477    /// Get the first day of the next month.
478    fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
479        let (year, month) = if date.month() == 12 {
480            (date.year() + 1, 1)
481        } else {
482            (date.year(), date.month() + 1)
483        };
484        NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
485    }
486}
487
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    fn test_employee_ids() -> Vec<String> {
        (1..=10).map(|i| format!("EMP-{:04}", i)).collect()
    }

    #[test]
    fn test_basic_expense_generation() {
        let mut gen = ExpenseReportGenerator::new(42);
        let employees = test_employee_ids();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);

        // At most one report per employee per month; the exact count depends on
        // the default `submission_rate`.
        assert!(!reports.is_empty());
        assert!(
            reports.len() <= employees.len(),
            "Should not exceed employee count for a single month"
        );

        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(!report.employee_id.is_empty());
            assert!(report.total_amount > Decimal::ZERO);
            assert!(!report.line_items.is_empty());
            assert!(report.line_items.len() <= 5);

            // Total should equal sum of line items
            let line_sum: Decimal = report.line_items.iter().map(|li| li.amount).sum();
            assert_eq!(report.total_amount, line_sum);

            for item in &report.line_items {
                assert!(!item.item_id.is_empty());
                assert!(item.amount > Decimal::ZERO);
            }
        }
    }

    #[test]
    fn test_deterministic_expenses() {
        let employees = test_employee_ids();
        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
        let config = ExpenseConfig::default();

        let mut gen1 = ExpenseReportGenerator::new(42);
        let reports1 = gen1.generate(&employees, period_start, period_end, &config);

        let mut gen2 = ExpenseReportGenerator::new(42);
        let reports2 = gen2.generate(&employees, period_start, period_end, &config);

        assert_eq!(reports1.len(), reports2.len());
        for (a, b) in reports1.iter().zip(reports2.iter()) {
            assert_eq!(a.report_id, b.report_id);
            assert_eq!(a.employee_id, b.employee_id);
            assert_eq!(a.total_amount, b.total_amount);
            assert_eq!(a.status, b.status);
            assert_eq!(a.line_items.len(), b.line_items.len());
        }
    }

    #[test]
    fn test_expense_status_and_violations() {
        let mut gen = ExpenseReportGenerator::new(99);
        // Use more employees for a broader sample
        let employees: Vec<String> = (1..=30).map(|i| format!("EMP-{:04}", i)).collect();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);

        // With 30 employees over 6 months, we should have a decent sample
        assert!(
            reports.len() > 10,
            "Expected multiple reports, got {}",
            reports.len()
        );

        let approved = reports
            .iter()
            .filter(|r| r.status == ExpenseStatus::Approved)
            .count();
        let paid = reports
            .iter()
            .filter(|r| r.status == ExpenseStatus::Paid)
            .count();
        let submitted = reports
            .iter()
            .filter(|r| r.status == ExpenseStatus::Submitted)
            .count();
        let rejected = reports
            .iter()
            .filter(|r| r.status == ExpenseStatus::Rejected)
            .count();
        let draft = reports
            .iter()
            .filter(|r| r.status == ExpenseStatus::Draft)
            .count();

        // Approved should be the majority
        assert!(approved > 0, "Expected at least some approved reports");
        // Check that we have a mix of statuses
        assert!(
            paid + submitted + rejected + draft > 0,
            "Expected a mix of statuses beyond approved"
        );

        // Check policy violations exist somewhere
        let total_violations: usize = reports.iter().map(|r| r.policy_violations.len()).sum();
        assert!(
            total_violations > 0,
            "Expected at least some policy violations across {} reports",
            reports.len()
        );
    }

    #[test]
    fn test_country_pack_does_not_break_generation() {
        let mut gen = ExpenseReportGenerator::new(42);
        // Setting a default country pack should not alter basic generation behaviour.
        gen.set_country_pack(datasynth_core::CountryPack::default());

        let employees = test_employee_ids();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);

        assert!(!reports.is_empty());
        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(report.total_amount > Decimal::ZERO);
        }
    }

    /// Approval and payment fields must agree with the report status, and the
    /// workflow dates must be strictly increasing.
    #[test]
    fn test_approval_fields_match_status() {
        let mut gen = ExpenseReportGenerator::new(7);
        let employees = test_employee_ids();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 12, 31).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);
        assert!(!reports.is_empty());

        for report in &reports {
            let approved = matches!(report.status, ExpenseStatus::Approved | ExpenseStatus::Paid);
            assert_eq!(report.approved_by.is_some(), approved);
            assert_eq!(report.approved_date.is_some(), approved);
            assert_eq!(report.paid_date.is_some(), report.status == ExpenseStatus::Paid);

            if let Some(approved_date) = report.approved_date {
                assert!(approved_date > report.submission_date);
            }
            if let (Some(approved_date), Some(paid_date)) = (report.approved_date, report.paid_date)
            {
                assert!(paid_date > approved_date);
            }
        }
    }

    /// Configured ID pools and the name map must be the source of the
    /// corresponding report fields.
    #[test]
    fn test_pools_and_employee_names_are_used() {
        let employees = test_employee_ids();
        let cost_centers = vec!["CC-A".to_string(), "CC-B".to_string()];
        let names: HashMap<String, String> = employees
            .iter()
            .map(|id| (id.clone(), format!("Name {}", id)))
            .collect();

        let mut gen = ExpenseReportGenerator::new(11)
            .with_pools(employees.clone(), cost_centers.clone())
            .with_employee_names(names.clone());
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);
        assert!(!reports.is_empty());

        for report in &reports {
            if let Some(approver) = &report.approved_by {
                assert!(employees.contains(approver), "unknown approver {}", approver);
            }
            if let Some(cost_center) = &report.cost_center {
                assert!(cost_centers.contains(cost_center));
            }
            assert_eq!(report.employee_name.as_ref(), names.get(&report.employee_id));
        }
    }

    /// Every line item must be dated within the requested period.
    #[test]
    fn test_line_item_dates_within_period() {
        let mut gen = ExpenseReportGenerator::new(5);
        let employees = test_employee_ids();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);

        for report in &reports {
            for item in &report.line_items {
                assert!(
                    item.date >= period_start && item.date <= period_end,
                    "line item dated {} outside {}..={}",
                    item.date,
                    period_start,
                    period_end
                );
            }
        }
    }
}
630}