Skip to main content

datasynth_generators/hr/expense_report_generator.rs

1//! Expense report generator for the Hire-to-Retire (H2R) process.
2//!
3//! Generates employee expense reports with realistic line items across categories
4//! (travel, meals, lodging, transportation, etc.), policy violation detection,
5//! and approval workflow statuses.
6
7use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::distributions::TemporalContext;
10use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
11use datasynth_core::utils::{sample_decimal_range, seeded_rng};
12use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
13use rand::prelude::*;
14use rand_chacha::ChaCha8Rng;
15use rust_decimal::Decimal;
16use smallvec::SmallVec;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tracing::debug;
20
21/// Generates [`ExpenseReport`] records for employees over a period.
22pub struct ExpenseReportGenerator {
23    rng: ChaCha8Rng,
24    uuid_factory: DeterministicUuidFactory,
25    item_uuid_factory: DeterministicUuidFactory,
26    /// Per-generator config drives expense submission rate and policy thresholds.
27    config: ExpenseConfig,
28    /// Pool of real employee IDs for approved_by references.
29    employee_ids_pool: Vec<String>,
30    /// Pool of real cost center IDs.
31    cost_center_ids_pool: Vec<String>,
32    /// Mapping of employee_id → employee_name for denormalization (DS-011).
33    employee_names: HashMap<String, String>,
34    /// Optional country pack for locale-aware generation (set via
35    /// `set_country_pack`); drives locale-specific currencies.
36    country_pack: Option<datasynth_core::CountryPack>,
37    /// v3.4.2+ temporal context — when set, submission / approval / paid /
38    /// line-item dates snap to the next business day. `None` preserves
39    /// legacy raw-rng behavior (byte-identical to v3.4.1).
40    temporal_context: Option<Arc<TemporalContext>>,
41}
42
43impl ExpenseReportGenerator {
44    /// Create a new expense report generator with default configuration.
45    pub fn new(seed: u64) -> Self {
46        Self {
47            rng: seeded_rng(seed, 0),
48            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
49            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
50                seed,
51                GeneratorType::ExpenseReport,
52                1,
53            ),
54            config: ExpenseConfig::default(),
55            employee_ids_pool: Vec::new(),
56            cost_center_ids_pool: Vec::new(),
57            employee_names: HashMap::new(),
58            country_pack: None,
59            temporal_context: None,
60        }
61    }
62
63    /// Create an expense report generator with custom configuration.
64    pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
65        Self {
66            rng: seeded_rng(seed, 0),
67            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
68            item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
69                seed,
70                GeneratorType::ExpenseReport,
71                1,
72            ),
73            config,
74            employee_ids_pool: Vec::new(),
75            cost_center_ids_pool: Vec::new(),
76            employee_names: HashMap::new(),
77            country_pack: None,
78            temporal_context: None,
79        }
80    }
81
82    /// Set the shared [`TemporalContext`] so submission / approval / paid /
83    /// line-item dates snap to the next business day.
84    pub fn set_temporal_context(&mut self, ctx: Arc<TemporalContext>) {
85        self.temporal_context = Some(ctx);
86    }
87
88    /// Builder variant of [`Self::set_temporal_context`].
89    pub fn with_temporal_context(mut self, ctx: Arc<TemporalContext>) -> Self {
90        self.temporal_context = Some(ctx);
91        self
92    }
93
94    /// Snap a date to the next business day when a [`TemporalContext`] is
95    /// present; otherwise return it unchanged.
96    fn snap_to_business_day(&self, date: NaiveDate) -> NaiveDate {
97        match &self.temporal_context {
98            Some(ctx) => ctx.adjust_to_business_day(date),
99            None => date,
100        }
101    }
102
103    /// Set the country pack for locale-aware generation.
104    ///
105    /// When set, the generator can use locale-specific currencies and
106    /// business rules from the country pack.  Currently the pack is
107    /// stored for future expansion; existing behaviour is unchanged
108    /// when no pack is provided.
109    pub fn set_country_pack(&mut self, pack: datasynth_core::CountryPack) {
110        self.country_pack = Some(pack);
111    }
112
113    /// Set ID pools for cross-reference coherence.
114    ///
115    /// When pools are non-empty, the generator selects `approved_by` from
116    /// `employee_ids` and `cost_center` from `cost_center_ids` instead of
117    /// fabricating placeholder IDs.
118    pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
119        self.employee_ids_pool = employee_ids;
120        self.cost_center_ids_pool = cost_center_ids;
121        self
122    }
123
124    /// Set the employee name mapping for denormalization (DS-011).
125    ///
126    /// Maps employee IDs to their display names so that generated expense
127    /// reports include the employee name for graph export convenience.
128    pub fn with_employee_names(mut self, names: HashMap<String, String>) -> Self {
129        self.employee_names = names;
130        self
131    }
132
133    /// Generate expense reports using the stored config and country pack.
134    ///
135    /// Uses `self.config` for submission rate and policy thresholds, and
136    /// `self.country_pack` for locale-specific currency.
137    pub fn generate_from_config(
138        &mut self,
139        employee_ids: &[String],
140        period_start: NaiveDate,
141        period_end: NaiveDate,
142    ) -> Vec<ExpenseReport> {
143        let config = self.config.clone();
144        self.generate(employee_ids, period_start, period_end, &config)
145    }
146
147    /// Generate expense reports for employees over the given period.
148    ///
149    /// Only `config.submission_rate` fraction of employees submit reports each
150    /// month within the period.
151    ///
152    /// # Arguments
153    ///
154    /// * `employee_ids` - Slice of employee identifiers
155    /// * `period_start` - Start of the period (inclusive)
156    /// * `period_end` - End of the period (inclusive)
157    /// * `config` - Expense management configuration (overrides stored config)
158    pub fn generate(
159        &mut self,
160        employee_ids: &[String],
161        period_start: NaiveDate,
162        period_end: NaiveDate,
163        config: &ExpenseConfig,
164    ) -> Vec<ExpenseReport> {
165        let currency = self
166            .country_pack
167            .as_ref()
168            .map(|cp| cp.locale.default_currency.clone())
169            .filter(|c| !c.is_empty())
170            .unwrap_or_else(|| "USD".to_string());
171        self.generate_with_currency(employee_ids, period_start, period_end, config, &currency)
172    }
173
174    /// Generate expense reports with a specific company currency.
175    pub fn generate_with_currency(
176        &mut self,
177        employee_ids: &[String],
178        period_start: NaiveDate,
179        period_end: NaiveDate,
180        config: &ExpenseConfig,
181        currency: &str,
182    ) -> Vec<ExpenseReport> {
183        debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
184        let mut reports = Vec::new();
185
186        // Iterate over each month in the period
187        let mut current_month_start = period_start;
188        while current_month_start <= period_end {
189            let month_end = self.month_end(current_month_start).min(period_end);
190
191            for employee_id in employee_ids {
192                // Only submission_rate fraction of employees submit per month
193                if self.rng.random_bool(config.submission_rate.min(1.0)) {
194                    let report = self.generate_report(
195                        employee_id,
196                        current_month_start,
197                        month_end,
198                        config,
199                        currency,
200                    );
201                    reports.push(report);
202                }
203            }
204
205            // Advance to next month
206            current_month_start = self.next_month_start(current_month_start);
207        }
208
209        reports
210    }
211
212    /// Generate a single expense report for an employee within a date range.
213    fn generate_report(
214        &mut self,
215        employee_id: &str,
216        period_start: NaiveDate,
217        period_end: NaiveDate,
218        config: &ExpenseConfig,
219        currency: &str,
220    ) -> ExpenseReport {
221        let report_id = self.uuid_factory.next().to_string();
222
223        // 1-5 line items per report
224        let item_count = self.rng.random_range(1..=5);
225        let mut line_items = SmallVec::with_capacity(item_count);
226        let mut total_amount = Decimal::ZERO;
227
228        for _ in 0..item_count {
229            let item = self.generate_line_item(period_start, period_end, currency);
230            total_amount += item.amount;
231            line_items.push(item);
232        }
233
234        // Submission date: usually within a few days after the last expense
235        let max_expense_date = line_items
236            .iter()
237            .map(|li: &ExpenseLineItem| li.date)
238            .max()
239            .unwrap_or(period_end);
240        let submission_lag = self.rng.random_range(0..=5);
241        let raw_submission = max_expense_date + chrono::Duration::days(submission_lag);
242        let submission_date = self.snap_to_business_day(raw_submission);
243
244        // Trip/purpose descriptions
245        let descriptions = [
246            "Client site visit",
247            "Conference attendance",
248            "Team offsite meeting",
249            "Customer presentation",
250            "Training workshop",
251            "Quarterly review travel",
252            "Sales meeting",
253            "Project kickoff",
254        ];
255        let description = descriptions[self.rng.random_range(0..descriptions.len())].to_string();
256
257        // Status distribution: 70% Approved, 10% Paid, 10% Submitted, 5% Rejected, 5% Draft
258        let status_roll: f64 = self.rng.random();
259        let status = if status_roll < 0.70 {
260            ExpenseStatus::Approved
261        } else if status_roll < 0.80 {
262            ExpenseStatus::Paid
263        } else if status_roll < 0.90 {
264            ExpenseStatus::Submitted
265        } else if status_roll < 0.95 {
266            ExpenseStatus::Rejected
267        } else {
268            ExpenseStatus::Draft
269        };
270
271        let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
272            if !self.employee_ids_pool.is_empty() {
273                let idx = self.rng.random_range(0..self.employee_ids_pool.len());
274                Some(self.employee_ids_pool[idx].clone())
275            } else {
276                Some(format!("MGR-{:04}", self.rng.random_range(1..=100)))
277            }
278        } else {
279            None
280        };
281
282        let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
283            let approval_lag = self.rng.random_range(1..=7);
284            let raw = submission_date + chrono::Duration::days(approval_lag);
285            Some(self.snap_to_business_day(raw))
286        } else {
287            None
288        };
289
290        let paid_date = if status == ExpenseStatus::Paid {
291            approved_date.map(|ad| {
292                let raw = ad + chrono::Duration::days(self.rng.random_range(3..=14));
293                self.snap_to_business_day(raw)
294            })
295        } else {
296            None
297        };
298
299        // Cost center and department
300        let cost_center = if self.rng.random_bool(0.70) {
301            if !self.cost_center_ids_pool.is_empty() {
302                let idx = self.rng.random_range(0..self.cost_center_ids_pool.len());
303                Some(self.cost_center_ids_pool[idx].clone())
304            } else {
305                Some(format!("CC-{:03}", self.rng.random_range(100..=500)))
306            }
307        } else {
308            None
309        };
310
311        let department = if self.rng.random_bool(0.80) {
312            let departments = [
313                "Engineering",
314                "Sales",
315                "Marketing",
316                "Finance",
317                "HR",
318                "Operations",
319                "Legal",
320                "IT",
321                "Executive",
322            ];
323            Some(departments[self.rng.random_range(0..departments.len())].to_string())
324        } else {
325            None
326        };
327
328        // Policy violations: based on config.policy_violation_rate per line item
329        let policy_violation_rate = config.policy_violation_rate;
330        let mut policy_violations = Vec::new();
331        for item in &line_items {
332            if self.rng.random_bool(policy_violation_rate.min(1.0)) {
333                let violation = self.pick_violation(item);
334                policy_violations.push(violation);
335            }
336        }
337
338        ExpenseReport {
339            report_id,
340            employee_id: employee_id.to_string(),
341            submission_date,
342            description,
343            status,
344            total_amount,
345            currency: currency.to_string(),
346            line_items,
347            approved_by,
348            approved_date,
349            paid_date,
350            cost_center,
351            department,
352            policy_violations,
353            employee_name: self.employee_names.get(employee_id).cloned(),
354        }
355    }
356
357    /// Generate a single expense line item with a random category and amount.
358    fn generate_line_item(
359        &mut self,
360        period_start: NaiveDate,
361        period_end: NaiveDate,
362        currency: &str,
363    ) -> ExpenseLineItem {
364        let item_id = self.item_uuid_factory.next().to_string();
365
366        // Pick a category and generate an appropriate amount range
367        let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
368
369        let amount = sample_decimal_range(
370            &mut self.rng,
371            Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
372            Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
373        )
374        .round_dp(2);
375
376        // Date within the period. With a TemporalContext, snap to the next
377        // business day; otherwise keep the raw offset (legacy behavior).
378        let days_in_period = (period_end - period_start).num_days().max(1);
379        let offset = self.rng.random_range(0..=days_in_period);
380        let raw_date = period_start + chrono::Duration::days(offset);
381        let date = match &self.temporal_context {
382            Some(ctx) => {
383                let snapped = ctx.adjust_to_business_day(raw_date);
384                if snapped > period_end {
385                    ctx.adjust_to_previous_business_day(period_end)
386                } else {
387                    snapped
388                }
389            }
390            None => raw_date,
391        };
392
393        // Receipt attached: 85% of the time
394        let receipt_attached = self.rng.random_bool(0.85);
395
396        ExpenseLineItem {
397            item_id,
398            category,
399            date,
400            amount,
401            currency: currency.to_string(),
402            description: desc,
403            receipt_attached,
404            merchant,
405        }
406    }
407
408    /// Pick an expense category with corresponding amount range, description, and merchant.
409    fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
410        let roll: f64 = self.rng.random();
411
412        if roll < 0.20 {
413            let merchants = [
414                "Delta Airlines",
415                "United Airlines",
416                "American Airlines",
417                "Southwest",
418            ];
419            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
420            (
421                ExpenseCategory::Travel,
422                200.0,
423                2000.0,
424                "Airfare - business travel".to_string(),
425                Some(merchant),
426            )
427        } else if roll < 0.40 {
428            let merchants = [
429                "Restaurant ABC",
430                "Cafe Express",
431                "Business Lunch Co",
432                "Steakhouse Prime",
433                "Sushi Palace",
434            ];
435            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
436            (
437                ExpenseCategory::Meals,
438                20.0,
439                100.0,
440                "Business meal".to_string(),
441                Some(merchant),
442            )
443        } else if roll < 0.55 {
444            let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
445            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
446            (
447                ExpenseCategory::Lodging,
448                100.0,
449                500.0,
450                "Hotel accommodation".to_string(),
451                Some(merchant),
452            )
453        } else if roll < 0.70 {
454            let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
455            let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
456            (
457                ExpenseCategory::Transportation,
458                10.0,
459                200.0,
460                "Ground transportation".to_string(),
461                Some(merchant),
462            )
463        } else if roll < 0.80 {
464            (
465                ExpenseCategory::Office,
466                15.0,
467                300.0,
468                "Office supplies".to_string(),
469                Some("Office Depot".to_string()),
470            )
471        } else if roll < 0.88 {
472            (
473                ExpenseCategory::Entertainment,
474                50.0,
475                500.0,
476                "Client entertainment".to_string(),
477                None,
478            )
479        } else if roll < 0.95 {
480            (
481                ExpenseCategory::Training,
482                100.0,
483                1500.0,
484                "Professional development".to_string(),
485                None,
486            )
487        } else {
488            (
489                ExpenseCategory::Other,
490                10.0,
491                200.0,
492                "Miscellaneous expense".to_string(),
493                None,
494            )
495        }
496    }
497
498    /// Generate a policy violation description for a given line item.
499    fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
500        let violations = match item.category {
501            ExpenseCategory::Meals => vec![
502                "Exceeds daily meal limit",
503                "Alcohol included without approval",
504                "Missing itemized receipt",
505            ],
506            ExpenseCategory::Travel => vec![
507                "Booked outside preferred vendor",
508                "Class upgrade not pre-approved",
509                "Booking made less than 7 days in advance",
510            ],
511            ExpenseCategory::Lodging => vec![
512                "Exceeds nightly rate limit",
513                "Extended stay without approval",
514                "Non-preferred hotel chain",
515            ],
516            _ => vec![
517                "Missing receipt",
518                "Insufficient business justification",
519                "Exceeds category spending limit",
520            ],
521        };
522
523        violations[self.rng.random_range(0..violations.len())].to_string()
524    }
525
526    /// Get the last day of the month for a given date.
527    fn month_end(&self, date: NaiveDate) -> NaiveDate {
528        let (year, month) = if date.month() == 12 {
529            (date.year() + 1, 1)
530        } else {
531            (date.year(), date.month() + 1)
532        };
533        NaiveDate::from_ymd_opt(year, month, 1)
534            .unwrap_or(date)
535            .pred_opt()
536            .unwrap_or(date)
537    }
538
539    /// Get the first day of the next month.
540    fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
541        let (year, month) = if date.month() == 12 {
542            (date.year() + 1, 1)
543        } else {
544            (date.year(), date.month() + 1)
545        };
546        NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
547    }
548}
549
550#[cfg(test)]
551#[allow(clippy::unwrap_used)]
552mod tests {
553    use super::*;
554
555    fn test_employee_ids() -> Vec<String> {
556        (1..=10).map(|i| format!("EMP-{:04}", i)).collect()
557    }
558
559    #[test]
560    fn test_basic_expense_generation() {
561        let mut gen = ExpenseReportGenerator::new(42);
562        let employees = test_employee_ids();
563        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
564        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
565        let config = ExpenseConfig::default();
566
567        let reports = gen.generate(&employees, period_start, period_end, &config);
568
569        // With 30% submission rate and 10 employees, expect ~3 reports per month
570        assert!(!reports.is_empty());
571        assert!(
572            reports.len() <= employees.len(),
573            "Should not exceed employee count for a single month"
574        );
575
576        for report in &reports {
577            assert!(!report.report_id.is_empty());
578            assert!(!report.employee_id.is_empty());
579            assert!(report.total_amount > Decimal::ZERO);
580            assert!(!report.line_items.is_empty());
581            assert!(report.line_items.len() <= 5);
582
583            // Total should equal sum of line items
584            let line_sum: Decimal = report.line_items.iter().map(|li| li.amount).sum();
585            assert_eq!(report.total_amount, line_sum);
586
587            for item in &report.line_items {
588                assert!(!item.item_id.is_empty());
589                assert!(item.amount > Decimal::ZERO);
590            }
591        }
592    }
593
594    #[test]
595    fn test_deterministic_expenses() {
596        let employees = test_employee_ids();
597        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
598        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
599        let config = ExpenseConfig::default();
600
601        let mut gen1 = ExpenseReportGenerator::new(42);
602        let reports1 = gen1.generate(&employees, period_start, period_end, &config);
603
604        let mut gen2 = ExpenseReportGenerator::new(42);
605        let reports2 = gen2.generate(&employees, period_start, period_end, &config);
606
607        assert_eq!(reports1.len(), reports2.len());
608        for (a, b) in reports1.iter().zip(reports2.iter()) {
609            assert_eq!(a.report_id, b.report_id);
610            assert_eq!(a.employee_id, b.employee_id);
611            assert_eq!(a.total_amount, b.total_amount);
612            assert_eq!(a.status, b.status);
613            assert_eq!(a.line_items.len(), b.line_items.len());
614        }
615    }
616
617    #[test]
618    fn test_expense_status_and_violations() {
619        let mut gen = ExpenseReportGenerator::new(99);
620        // Use more employees for a broader sample
621        let employees: Vec<String> = (1..=30).map(|i| format!("EMP-{:04}", i)).collect();
622        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
623        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
624        let config = ExpenseConfig::default();
625
626        let reports = gen.generate(&employees, period_start, period_end, &config);
627
628        // With 30 employees over 6 months, we should have a decent sample
629        assert!(
630            reports.len() > 10,
631            "Expected multiple reports, got {}",
632            reports.len()
633        );
634
635        let approved = reports
636            .iter()
637            .filter(|r| r.status == ExpenseStatus::Approved)
638            .count();
639        let paid = reports
640            .iter()
641            .filter(|r| r.status == ExpenseStatus::Paid)
642            .count();
643        let submitted = reports
644            .iter()
645            .filter(|r| r.status == ExpenseStatus::Submitted)
646            .count();
647        let rejected = reports
648            .iter()
649            .filter(|r| r.status == ExpenseStatus::Rejected)
650            .count();
651        let draft = reports
652            .iter()
653            .filter(|r| r.status == ExpenseStatus::Draft)
654            .count();
655
656        // Approved should be the majority
657        assert!(approved > 0, "Expected at least some approved reports");
658        // Check that we have a mix of statuses
659        assert!(
660            paid + submitted + rejected + draft > 0,
661            "Expected a mix of statuses beyond approved"
662        );
663
664        // Check policy violations exist somewhere
665        let total_violations: usize = reports.iter().map(|r| r.policy_violations.len()).sum();
666        assert!(
667            total_violations > 0,
668            "Expected at least some policy violations across {} reports",
669            reports.len()
670        );
671    }
672
673    #[test]
674    fn test_country_pack_does_not_break_generation() {
675        let mut gen = ExpenseReportGenerator::new(42);
676        // Setting a default country pack should not alter basic generation behaviour.
677        gen.set_country_pack(datasynth_core::CountryPack::default());
678
679        let employees = test_employee_ids();
680        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
681        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
682        let config = ExpenseConfig::default();
683
684        let reports = gen.generate(&employees, period_start, period_end, &config);
685
686        assert!(!reports.is_empty());
687        for report in &reports {
688            assert!(!report.report_id.is_empty());
689            assert!(report.total_amount > Decimal::ZERO);
690        }
691    }
692}