1use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::distributions::TemporalContext;
10use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
11use datasynth_core::utils::{sample_decimal_range, seeded_rng};
12use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
13use rand::prelude::*;
14use rand_chacha::ChaCha8Rng;
15use rust_decimal::Decimal;
16use smallvec::SmallVec;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tracing::debug;
20
/// Generator for synthetic employee expense reports.
///
/// All randomness comes from a seeded RNG and deterministic UUID factories,
/// so two generators built with the same seed and driven with the same
/// inputs produce the same reports.
pub struct ExpenseReportGenerator {
    /// Seeded RNG used for every random draw made by this generator.
    rng: ChaCha8Rng,
    /// Factory for report identifiers (`GeneratorType::ExpenseReport`).
    uuid_factory: DeterministicUuidFactory,
    /// Factory for line-item identifiers; built with sub-discriminator `1`
    /// so its stream is separate from the report-id factory.
    item_uuid_factory: DeterministicUuidFactory,
    /// Configuration used by `generate_from_config`.
    config: ExpenseConfig,
    /// Candidate approver ids; when empty, synthetic `MGR-xxxx` ids are used.
    employee_ids_pool: Vec<String>,
    /// Candidate cost-center ids; when empty, synthetic `CC-xxx` ids are used.
    cost_center_ids_pool: Vec<String>,
    /// Lookup from employee id to display name, copied onto each report.
    employee_names: HashMap<String, String>,
    /// Optional country pack; its locale's default currency is used by
    /// `generate` when present and non-empty.
    country_pack: Option<datasynth_core::CountryPack>,
    /// Optional temporal context used to move generated dates onto business
    /// days; when absent, dates are left as drawn.
    temporal_context: Option<Arc<TemporalContext>>,
}
42
43impl ExpenseReportGenerator {
44 pub fn new(seed: u64) -> Self {
46 Self {
47 rng: seeded_rng(seed, 0),
48 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
49 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
50 seed,
51 GeneratorType::ExpenseReport,
52 1,
53 ),
54 config: ExpenseConfig::default(),
55 employee_ids_pool: Vec::new(),
56 cost_center_ids_pool: Vec::new(),
57 employee_names: HashMap::new(),
58 country_pack: None,
59 temporal_context: None,
60 }
61 }
62
63 pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
65 Self {
66 rng: seeded_rng(seed, 0),
67 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
68 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
69 seed,
70 GeneratorType::ExpenseReport,
71 1,
72 ),
73 config,
74 employee_ids_pool: Vec::new(),
75 cost_center_ids_pool: Vec::new(),
76 employee_names: HashMap::new(),
77 country_pack: None,
78 temporal_context: None,
79 }
80 }
81
82 pub fn set_temporal_context(&mut self, ctx: Arc<TemporalContext>) {
85 self.temporal_context = Some(ctx);
86 }
87
88 pub fn with_temporal_context(mut self, ctx: Arc<TemporalContext>) -> Self {
90 self.temporal_context = Some(ctx);
91 self
92 }
93
94 fn snap_to_business_day(&self, date: NaiveDate) -> NaiveDate {
97 match &self.temporal_context {
98 Some(ctx) => ctx.adjust_to_business_day(date),
99 None => date,
100 }
101 }
102
103 pub fn set_country_pack(&mut self, pack: datasynth_core::CountryPack) {
110 self.country_pack = Some(pack);
111 }
112
113 pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
119 self.employee_ids_pool = employee_ids;
120 self.cost_center_ids_pool = cost_center_ids;
121 self
122 }
123
124 pub fn with_employee_names(mut self, names: HashMap<String, String>) -> Self {
129 self.employee_names = names;
130 self
131 }
132
133 pub fn generate_from_config(
138 &mut self,
139 employee_ids: &[String],
140 period_start: NaiveDate,
141 period_end: NaiveDate,
142 ) -> Vec<ExpenseReport> {
143 let config = self.config.clone();
144 self.generate(employee_ids, period_start, period_end, &config)
145 }
146
147 pub fn generate(
159 &mut self,
160 employee_ids: &[String],
161 period_start: NaiveDate,
162 period_end: NaiveDate,
163 config: &ExpenseConfig,
164 ) -> Vec<ExpenseReport> {
165 let currency = self
166 .country_pack
167 .as_ref()
168 .map(|cp| cp.locale.default_currency.clone())
169 .filter(|c| !c.is_empty())
170 .unwrap_or_else(|| "USD".to_string());
171 self.generate_with_currency(employee_ids, period_start, period_end, config, ¤cy)
172 }
173
174 pub fn generate_with_currency(
176 &mut self,
177 employee_ids: &[String],
178 period_start: NaiveDate,
179 period_end: NaiveDate,
180 config: &ExpenseConfig,
181 currency: &str,
182 ) -> Vec<ExpenseReport> {
183 debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
184 let mut reports = Vec::new();
185
186 let mut current_month_start = period_start;
188 while current_month_start <= period_end {
189 let month_end = self.month_end(current_month_start).min(period_end);
190
191 for employee_id in employee_ids {
192 if self.rng.random_bool(config.submission_rate.min(1.0)) {
194 let report = self.generate_report(
195 employee_id,
196 current_month_start,
197 month_end,
198 config,
199 currency,
200 );
201 reports.push(report);
202 }
203 }
204
205 current_month_start = self.next_month_start(current_month_start);
207 }
208
209 reports
210 }
211
212 fn generate_report(
214 &mut self,
215 employee_id: &str,
216 period_start: NaiveDate,
217 period_end: NaiveDate,
218 config: &ExpenseConfig,
219 currency: &str,
220 ) -> ExpenseReport {
221 let report_id = self.uuid_factory.next().to_string();
222
223 let item_count = self.rng.random_range(1..=5);
225 let mut line_items = SmallVec::with_capacity(item_count);
226 let mut total_amount = Decimal::ZERO;
227
228 for _ in 0..item_count {
229 let item = self.generate_line_item(period_start, period_end, currency);
230 total_amount += item.amount;
231 line_items.push(item);
232 }
233
234 let max_expense_date = line_items
236 .iter()
237 .map(|li: &ExpenseLineItem| li.date)
238 .max()
239 .unwrap_or(period_end);
240 let submission_lag = self.rng.random_range(0..=5);
241 let raw_submission = max_expense_date + chrono::Duration::days(submission_lag);
242 let submission_date = self.snap_to_business_day(raw_submission);
243
244 let descriptions = [
246 "Client site visit",
247 "Conference attendance",
248 "Team offsite meeting",
249 "Customer presentation",
250 "Training workshop",
251 "Quarterly review travel",
252 "Sales meeting",
253 "Project kickoff",
254 ];
255 let description = descriptions[self.rng.random_range(0..descriptions.len())].to_string();
256
257 let status_roll: f64 = self.rng.random();
259 let status = if status_roll < 0.70 {
260 ExpenseStatus::Approved
261 } else if status_roll < 0.80 {
262 ExpenseStatus::Paid
263 } else if status_roll < 0.90 {
264 ExpenseStatus::Submitted
265 } else if status_roll < 0.95 {
266 ExpenseStatus::Rejected
267 } else {
268 ExpenseStatus::Draft
269 };
270
271 let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
272 if !self.employee_ids_pool.is_empty() {
273 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
274 Some(self.employee_ids_pool[idx].clone())
275 } else {
276 Some(format!("MGR-{:04}", self.rng.random_range(1..=100)))
277 }
278 } else {
279 None
280 };
281
282 let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
283 let approval_lag = self.rng.random_range(1..=7);
284 let raw = submission_date + chrono::Duration::days(approval_lag);
285 Some(self.snap_to_business_day(raw))
286 } else {
287 None
288 };
289
290 let paid_date = if status == ExpenseStatus::Paid {
291 approved_date.map(|ad| {
292 let raw = ad + chrono::Duration::days(self.rng.random_range(3..=14));
293 self.snap_to_business_day(raw)
294 })
295 } else {
296 None
297 };
298
299 let cost_center = if self.rng.random_bool(0.70) {
301 if !self.cost_center_ids_pool.is_empty() {
302 let idx = self.rng.random_range(0..self.cost_center_ids_pool.len());
303 Some(self.cost_center_ids_pool[idx].clone())
304 } else {
305 Some(format!("CC-{:03}", self.rng.random_range(100..=500)))
306 }
307 } else {
308 None
309 };
310
311 let department = if self.rng.random_bool(0.80) {
312 let departments = [
313 "Engineering",
314 "Sales",
315 "Marketing",
316 "Finance",
317 "HR",
318 "Operations",
319 "Legal",
320 "IT",
321 "Executive",
322 ];
323 Some(departments[self.rng.random_range(0..departments.len())].to_string())
324 } else {
325 None
326 };
327
328 let policy_violation_rate = config.policy_violation_rate;
330 let mut policy_violations = Vec::new();
331 for item in &line_items {
332 if self.rng.random_bool(policy_violation_rate.min(1.0)) {
333 let violation = self.pick_violation(item);
334 policy_violations.push(violation);
335 }
336 }
337
338 ExpenseReport {
339 report_id,
340 employee_id: employee_id.to_string(),
341 submission_date,
342 description,
343 status,
344 total_amount,
345 currency: currency.to_string(),
346 line_items,
347 approved_by,
348 approved_date,
349 paid_date,
350 cost_center,
351 department,
352 policy_violations,
353 employee_name: self.employee_names.get(employee_id).cloned(),
354 }
355 }
356
357 fn generate_line_item(
359 &mut self,
360 period_start: NaiveDate,
361 period_end: NaiveDate,
362 currency: &str,
363 ) -> ExpenseLineItem {
364 let item_id = self.item_uuid_factory.next().to_string();
365
366 let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
368
369 let amount = sample_decimal_range(
370 &mut self.rng,
371 Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
372 Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
373 )
374 .round_dp(2);
375
376 let days_in_period = (period_end - period_start).num_days().max(1);
379 let offset = self.rng.random_range(0..=days_in_period);
380 let raw_date = period_start + chrono::Duration::days(offset);
381 let date = match &self.temporal_context {
382 Some(ctx) => {
383 let snapped = ctx.adjust_to_business_day(raw_date);
384 if snapped > period_end {
385 ctx.adjust_to_previous_business_day(period_end)
386 } else {
387 snapped
388 }
389 }
390 None => raw_date,
391 };
392
393 let receipt_attached = self.rng.random_bool(0.85);
395
396 ExpenseLineItem {
397 item_id,
398 category,
399 date,
400 amount,
401 currency: currency.to_string(),
402 description: desc,
403 receipt_attached,
404 merchant,
405 }
406 }
407
408 fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
410 let roll: f64 = self.rng.random();
411
412 if roll < 0.20 {
413 let merchants = [
414 "Delta Airlines",
415 "United Airlines",
416 "American Airlines",
417 "Southwest",
418 ];
419 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
420 (
421 ExpenseCategory::Travel,
422 200.0,
423 2000.0,
424 "Airfare - business travel".to_string(),
425 Some(merchant),
426 )
427 } else if roll < 0.40 {
428 let merchants = [
429 "Restaurant ABC",
430 "Cafe Express",
431 "Business Lunch Co",
432 "Steakhouse Prime",
433 "Sushi Palace",
434 ];
435 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
436 (
437 ExpenseCategory::Meals,
438 20.0,
439 100.0,
440 "Business meal".to_string(),
441 Some(merchant),
442 )
443 } else if roll < 0.55 {
444 let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
445 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
446 (
447 ExpenseCategory::Lodging,
448 100.0,
449 500.0,
450 "Hotel accommodation".to_string(),
451 Some(merchant),
452 )
453 } else if roll < 0.70 {
454 let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
455 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
456 (
457 ExpenseCategory::Transportation,
458 10.0,
459 200.0,
460 "Ground transportation".to_string(),
461 Some(merchant),
462 )
463 } else if roll < 0.80 {
464 (
465 ExpenseCategory::Office,
466 15.0,
467 300.0,
468 "Office supplies".to_string(),
469 Some("Office Depot".to_string()),
470 )
471 } else if roll < 0.88 {
472 (
473 ExpenseCategory::Entertainment,
474 50.0,
475 500.0,
476 "Client entertainment".to_string(),
477 None,
478 )
479 } else if roll < 0.95 {
480 (
481 ExpenseCategory::Training,
482 100.0,
483 1500.0,
484 "Professional development".to_string(),
485 None,
486 )
487 } else {
488 (
489 ExpenseCategory::Other,
490 10.0,
491 200.0,
492 "Miscellaneous expense".to_string(),
493 None,
494 )
495 }
496 }
497
498 fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
500 let violations = match item.category {
501 ExpenseCategory::Meals => vec![
502 "Exceeds daily meal limit",
503 "Alcohol included without approval",
504 "Missing itemized receipt",
505 ],
506 ExpenseCategory::Travel => vec![
507 "Booked outside preferred vendor",
508 "Class upgrade not pre-approved",
509 "Booking made less than 7 days in advance",
510 ],
511 ExpenseCategory::Lodging => vec![
512 "Exceeds nightly rate limit",
513 "Extended stay without approval",
514 "Non-preferred hotel chain",
515 ],
516 _ => vec![
517 "Missing receipt",
518 "Insufficient business justification",
519 "Exceeds category spending limit",
520 ],
521 };
522
523 violations[self.rng.random_range(0..violations.len())].to_string()
524 }
525
526 fn month_end(&self, date: NaiveDate) -> NaiveDate {
528 let (year, month) = if date.month() == 12 {
529 (date.year() + 1, 1)
530 } else {
531 (date.year(), date.month() + 1)
532 };
533 NaiveDate::from_ymd_opt(year, month, 1)
534 .unwrap_or(date)
535 .pred_opt()
536 .unwrap_or(date)
537 }
538
539 fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
541 let (year, month) = if date.month() == 12 {
542 (date.year() + 1, 1)
543 } else {
544 (date.year(), date.month() + 1)
545 };
546 NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
547 }
548}
549
550#[cfg(test)]
551#[allow(clippy::unwrap_used)]
552mod tests {
553 use super::*;
554
/// Ten employee ids, `EMP-0001` through `EMP-0010`.
fn test_employee_ids() -> Vec<String> {
    let mut ids = Vec::with_capacity(10);
    for n in 1..=10 {
        ids.push(format!("EMP-{:04}", n));
    }
    ids
}
558
/// One month of generation yields well-formed reports whose totals equal
/// the sum of their line items.
#[test]
fn test_basic_expense_generation() {
    let employees = test_employee_ids();
    let january_first = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let january_last = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
    let config = ExpenseConfig::default();
    let mut generator = ExpenseReportGenerator::new(42);

    let reports = generator.generate(&employees, january_first, january_last, &config);

    assert!(!reports.is_empty());
    // Each employee submits at most one report per month window.
    assert!(
        reports.len() <= employees.len(),
        "Should not exceed employee count for a single month"
    );

    for report in reports.iter() {
        assert!(!report.report_id.is_empty());
        assert!(!report.employee_id.is_empty());
        assert!(report.total_amount > Decimal::ZERO);
        assert!(!report.line_items.is_empty());
        assert!(report.line_items.len() <= 5);

        let mut line_sum = Decimal::ZERO;
        for item in report.line_items.iter() {
            assert!(!item.item_id.is_empty());
            assert!(item.amount > Decimal::ZERO);
            line_sum += item.amount;
        }
        assert_eq!(report.total_amount, line_sum);
    }
}
593
/// Two generators built from the same seed produce matching reports.
#[test]
fn test_deterministic_expenses() {
    let employees = test_employee_ids();
    let march_first = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
    let march_last = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
    let config = ExpenseConfig::default();

    // Fresh generator per run so no RNG state carries over between runs.
    let run_once = || {
        let mut generator = ExpenseReportGenerator::new(42);
        generator.generate(&employees, march_first, march_last, &config)
    };
    let first_run = run_once();
    let second_run = run_once();

    assert_eq!(first_run.len(), second_run.len());
    for (left, right) in first_run.iter().zip(second_run.iter()) {
        assert_eq!(left.report_id, right.report_id);
        assert_eq!(left.employee_id, right.employee_id);
        assert_eq!(left.total_amount, right.total_amount);
        assert_eq!(left.status, right.status);
        assert_eq!(left.line_items.len(), right.line_items.len());
    }
}
616
/// Six months across thirty employees produce a mix of statuses and at
/// least one policy violation.
#[test]
fn test_expense_status_and_violations() {
    let employees: Vec<String> = (1..=30).map(|i| format!("EMP-{:04}", i)).collect();
    let half_year_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let half_year_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
    let config = ExpenseConfig::default();
    let mut generator = ExpenseReportGenerator::new(99);

    let reports = generator.generate(&employees, half_year_start, half_year_end, &config);

    assert!(
        reports.len() > 10,
        "Expected multiple reports, got {}",
        reports.len()
    );

    // Number of reports carrying the given status.
    let count_with = |wanted: ExpenseStatus| reports.iter().filter(|r| r.status == wanted).count();

    let approved = count_with(ExpenseStatus::Approved);
    let non_approved = count_with(ExpenseStatus::Paid)
        + count_with(ExpenseStatus::Submitted)
        + count_with(ExpenseStatus::Rejected)
        + count_with(ExpenseStatus::Draft);

    assert!(approved > 0, "Expected at least some approved reports");
    assert!(
        non_approved > 0,
        "Expected a mix of statuses beyond approved"
    );

    let mut total_violations: usize = 0;
    for report in &reports {
        total_violations += report.policy_violations.len();
    }
    assert!(
        total_violations > 0,
        "Expected at least some policy violations across {} reports",
        reports.len()
    );
}
672
/// Installing a default country pack still yields valid reports.
#[test]
fn test_country_pack_does_not_break_generation() {
    let employees = test_employee_ids();
    let january_first = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let january_last = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
    let config = ExpenseConfig::default();

    let mut generator = ExpenseReportGenerator::new(42);
    generator.set_country_pack(datasynth_core::CountryPack::default());

    let reports = generator.generate(&employees, january_first, january_last, &config);

    assert!(!reports.is_empty());
    for report in reports.iter() {
        assert!(!report.report_id.is_empty());
        assert!(report.total_amount > Decimal::ZERO);
    }
}
692}