1use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::utils::{sample_decimal_range, seeded_rng};
11use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use rust_decimal::Decimal;
15use smallvec::SmallVec;
16use std::collections::HashMap;
17use tracing::debug;
18
/// Generates synthetic [`ExpenseReport`]s for a set of employees over a date range.
///
/// Randomness comes from a seeded `ChaCha8Rng` plus two `DeterministicUuidFactory`
/// instances, so generators built with the same seed and fed the same inputs
/// produce the same reports (see `test_deterministic_expenses`).
pub struct ExpenseReportGenerator {
    /// Seeded RNG behind every random choice (counts, amounts, dates, statuses, ...).
    rng: ChaCha8Rng,
    /// Source of report IDs.
    uuid_factory: DeterministicUuidFactory,
    /// Source of line-item IDs; built with sub-discriminator `1` so it is a
    /// different factory from `uuid_factory`.
    item_uuid_factory: DeterministicUuidFactory,
    /// Configuration used by `generate_from_config`.
    config: ExpenseConfig,
    /// Employee IDs that approvers are drawn from; when empty, synthetic
    /// `MGR-NNNN` IDs are generated instead.
    employee_ids_pool: Vec<String>,
    /// Cost-center IDs to draw from; when empty, synthetic `CC-NNN` IDs are
    /// generated instead.
    cost_center_ids_pool: Vec<String>,
    /// Employee ID -> display name, copied into `ExpenseReport::employee_name`
    /// when present.
    employee_names: HashMap<String, String>,
    /// Optional country pack; a non-empty `locale.default_currency` replaces
    /// `"USD"` in `generate`.
    country_pack: Option<datasynth_core::CountryPack>,
}
36
37impl ExpenseReportGenerator {
38 pub fn new(seed: u64) -> Self {
40 Self {
41 rng: seeded_rng(seed, 0),
42 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
43 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
44 seed,
45 GeneratorType::ExpenseReport,
46 1,
47 ),
48 config: ExpenseConfig::default(),
49 employee_ids_pool: Vec::new(),
50 cost_center_ids_pool: Vec::new(),
51 employee_names: HashMap::new(),
52 country_pack: None,
53 }
54 }
55
56 pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
58 Self {
59 rng: seeded_rng(seed, 0),
60 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
61 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
62 seed,
63 GeneratorType::ExpenseReport,
64 1,
65 ),
66 config,
67 employee_ids_pool: Vec::new(),
68 cost_center_ids_pool: Vec::new(),
69 employee_names: HashMap::new(),
70 country_pack: None,
71 }
72 }
73
74 pub fn set_country_pack(&mut self, pack: datasynth_core::CountryPack) {
81 self.country_pack = Some(pack);
82 }
83
84 pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
90 self.employee_ids_pool = employee_ids;
91 self.cost_center_ids_pool = cost_center_ids;
92 self
93 }
94
95 pub fn with_employee_names(mut self, names: HashMap<String, String>) -> Self {
100 self.employee_names = names;
101 self
102 }
103
104 pub fn generate_from_config(
109 &mut self,
110 employee_ids: &[String],
111 period_start: NaiveDate,
112 period_end: NaiveDate,
113 ) -> Vec<ExpenseReport> {
114 let config = self.config.clone();
115 self.generate(employee_ids, period_start, period_end, &config)
116 }
117
118 pub fn generate(
130 &mut self,
131 employee_ids: &[String],
132 period_start: NaiveDate,
133 period_end: NaiveDate,
134 config: &ExpenseConfig,
135 ) -> Vec<ExpenseReport> {
136 let currency = self
137 .country_pack
138 .as_ref()
139 .map(|cp| cp.locale.default_currency.clone())
140 .filter(|c| !c.is_empty())
141 .unwrap_or_else(|| "USD".to_string());
142 self.generate_with_currency(employee_ids, period_start, period_end, config, ¤cy)
143 }
144
145 pub fn generate_with_currency(
147 &mut self,
148 employee_ids: &[String],
149 period_start: NaiveDate,
150 period_end: NaiveDate,
151 config: &ExpenseConfig,
152 currency: &str,
153 ) -> Vec<ExpenseReport> {
154 debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
155 let mut reports = Vec::new();
156
157 let mut current_month_start = period_start;
159 while current_month_start <= period_end {
160 let month_end = self.month_end(current_month_start).min(period_end);
161
162 for employee_id in employee_ids {
163 if self.rng.random_bool(config.submission_rate.min(1.0)) {
165 let report = self.generate_report(
166 employee_id,
167 current_month_start,
168 month_end,
169 config,
170 currency,
171 );
172 reports.push(report);
173 }
174 }
175
176 current_month_start = self.next_month_start(current_month_start);
178 }
179
180 reports
181 }
182
183 fn generate_report(
185 &mut self,
186 employee_id: &str,
187 period_start: NaiveDate,
188 period_end: NaiveDate,
189 config: &ExpenseConfig,
190 currency: &str,
191 ) -> ExpenseReport {
192 let report_id = self.uuid_factory.next().to_string();
193
194 let item_count = self.rng.random_range(1..=5);
196 let mut line_items = SmallVec::with_capacity(item_count);
197 let mut total_amount = Decimal::ZERO;
198
199 for _ in 0..item_count {
200 let item = self.generate_line_item(period_start, period_end, currency);
201 total_amount += item.amount;
202 line_items.push(item);
203 }
204
205 let max_expense_date = line_items
207 .iter()
208 .map(|li: &ExpenseLineItem| li.date)
209 .max()
210 .unwrap_or(period_end);
211 let submission_lag = self.rng.random_range(0..=5);
212 let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
213
214 let descriptions = [
216 "Client site visit",
217 "Conference attendance",
218 "Team offsite meeting",
219 "Customer presentation",
220 "Training workshop",
221 "Quarterly review travel",
222 "Sales meeting",
223 "Project kickoff",
224 ];
225 let description = descriptions[self.rng.random_range(0..descriptions.len())].to_string();
226
227 let status_roll: f64 = self.rng.random();
229 let status = if status_roll < 0.70 {
230 ExpenseStatus::Approved
231 } else if status_roll < 0.80 {
232 ExpenseStatus::Paid
233 } else if status_roll < 0.90 {
234 ExpenseStatus::Submitted
235 } else if status_roll < 0.95 {
236 ExpenseStatus::Rejected
237 } else {
238 ExpenseStatus::Draft
239 };
240
241 let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
242 if !self.employee_ids_pool.is_empty() {
243 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
244 Some(self.employee_ids_pool[idx].clone())
245 } else {
246 Some(format!("MGR-{:04}", self.rng.random_range(1..=100)))
247 }
248 } else {
249 None
250 };
251
252 let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
253 let approval_lag = self.rng.random_range(1..=7);
254 Some(submission_date + chrono::Duration::days(approval_lag))
255 } else {
256 None
257 };
258
259 let paid_date = if status == ExpenseStatus::Paid {
260 approved_date.map(|ad| ad + chrono::Duration::days(self.rng.random_range(3..=14)))
261 } else {
262 None
263 };
264
265 let cost_center = if self.rng.random_bool(0.70) {
267 if !self.cost_center_ids_pool.is_empty() {
268 let idx = self.rng.random_range(0..self.cost_center_ids_pool.len());
269 Some(self.cost_center_ids_pool[idx].clone())
270 } else {
271 Some(format!("CC-{:03}", self.rng.random_range(100..=500)))
272 }
273 } else {
274 None
275 };
276
277 let department = if self.rng.random_bool(0.80) {
278 let departments = [
279 "Engineering",
280 "Sales",
281 "Marketing",
282 "Finance",
283 "HR",
284 "Operations",
285 "Legal",
286 "IT",
287 "Executive",
288 ];
289 Some(departments[self.rng.random_range(0..departments.len())].to_string())
290 } else {
291 None
292 };
293
294 let policy_violation_rate = config.policy_violation_rate;
296 let mut policy_violations = Vec::new();
297 for item in &line_items {
298 if self.rng.random_bool(policy_violation_rate.min(1.0)) {
299 let violation = self.pick_violation(item);
300 policy_violations.push(violation);
301 }
302 }
303
304 ExpenseReport {
305 report_id,
306 employee_id: employee_id.to_string(),
307 submission_date,
308 description,
309 status,
310 total_amount,
311 currency: currency.to_string(),
312 line_items,
313 approved_by,
314 approved_date,
315 paid_date,
316 cost_center,
317 department,
318 policy_violations,
319 employee_name: self.employee_names.get(employee_id).cloned(),
320 }
321 }
322
323 fn generate_line_item(
325 &mut self,
326 period_start: NaiveDate,
327 period_end: NaiveDate,
328 currency: &str,
329 ) -> ExpenseLineItem {
330 let item_id = self.item_uuid_factory.next().to_string();
331
332 let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
334
335 let amount = sample_decimal_range(
336 &mut self.rng,
337 Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
338 Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
339 )
340 .round_dp(2);
341
342 let days_in_period = (period_end - period_start).num_days().max(1);
344 let offset = self.rng.random_range(0..=days_in_period);
345 let date = period_start + chrono::Duration::days(offset);
346
347 let receipt_attached = self.rng.random_bool(0.85);
349
350 ExpenseLineItem {
351 item_id,
352 category,
353 date,
354 amount,
355 currency: currency.to_string(),
356 description: desc,
357 receipt_attached,
358 merchant,
359 }
360 }
361
362 fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
364 let roll: f64 = self.rng.random();
365
366 if roll < 0.20 {
367 let merchants = [
368 "Delta Airlines",
369 "United Airlines",
370 "American Airlines",
371 "Southwest",
372 ];
373 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
374 (
375 ExpenseCategory::Travel,
376 200.0,
377 2000.0,
378 "Airfare - business travel".to_string(),
379 Some(merchant),
380 )
381 } else if roll < 0.40 {
382 let merchants = [
383 "Restaurant ABC",
384 "Cafe Express",
385 "Business Lunch Co",
386 "Steakhouse Prime",
387 "Sushi Palace",
388 ];
389 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
390 (
391 ExpenseCategory::Meals,
392 20.0,
393 100.0,
394 "Business meal".to_string(),
395 Some(merchant),
396 )
397 } else if roll < 0.55 {
398 let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
399 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
400 (
401 ExpenseCategory::Lodging,
402 100.0,
403 500.0,
404 "Hotel accommodation".to_string(),
405 Some(merchant),
406 )
407 } else if roll < 0.70 {
408 let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
409 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
410 (
411 ExpenseCategory::Transportation,
412 10.0,
413 200.0,
414 "Ground transportation".to_string(),
415 Some(merchant),
416 )
417 } else if roll < 0.80 {
418 (
419 ExpenseCategory::Office,
420 15.0,
421 300.0,
422 "Office supplies".to_string(),
423 Some("Office Depot".to_string()),
424 )
425 } else if roll < 0.88 {
426 (
427 ExpenseCategory::Entertainment,
428 50.0,
429 500.0,
430 "Client entertainment".to_string(),
431 None,
432 )
433 } else if roll < 0.95 {
434 (
435 ExpenseCategory::Training,
436 100.0,
437 1500.0,
438 "Professional development".to_string(),
439 None,
440 )
441 } else {
442 (
443 ExpenseCategory::Other,
444 10.0,
445 200.0,
446 "Miscellaneous expense".to_string(),
447 None,
448 )
449 }
450 }
451
452 fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
454 let violations = match item.category {
455 ExpenseCategory::Meals => vec![
456 "Exceeds daily meal limit",
457 "Alcohol included without approval",
458 "Missing itemized receipt",
459 ],
460 ExpenseCategory::Travel => vec![
461 "Booked outside preferred vendor",
462 "Class upgrade not pre-approved",
463 "Booking made less than 7 days in advance",
464 ],
465 ExpenseCategory::Lodging => vec![
466 "Exceeds nightly rate limit",
467 "Extended stay without approval",
468 "Non-preferred hotel chain",
469 ],
470 _ => vec![
471 "Missing receipt",
472 "Insufficient business justification",
473 "Exceeds category spending limit",
474 ],
475 };
476
477 violations[self.rng.random_range(0..violations.len())].to_string()
478 }
479
480 fn month_end(&self, date: NaiveDate) -> NaiveDate {
482 let (year, month) = if date.month() == 12 {
483 (date.year() + 1, 1)
484 } else {
485 (date.year(), date.month() + 1)
486 };
487 NaiveDate::from_ymd_opt(year, month, 1)
488 .unwrap_or(date)
489 .pred_opt()
490 .unwrap_or(date)
491 }
492
493 fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
495 let (year, month) = if date.month() == 12 {
496 (date.year() + 1, 1)
497 } else {
498 (date.year(), date.month() + 1)
499 };
500 NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
501 }
502}
503
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Produces `count` sequential employee IDs: `EMP-0001`, `EMP-0002`, ...
    fn employee_ids(count: u32) -> Vec<String> {
        (1..=count).map(|n| format!("EMP-{:04}", n)).collect()
    }

    /// Builds a fixed, known-valid test date.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    #[test]
    fn test_basic_expense_generation() {
        let employees = employee_ids(10);
        let config = ExpenseConfig::default();
        let mut generator = ExpenseReportGenerator::new(42);

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 1, 31), &config);

        assert!(!reports.is_empty());
        // A single calendar month allows at most one report per employee.
        assert!(
            reports.len() <= employees.len(),
            "Should not exceed employee count for a single month"
        );

        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(!report.employee_id.is_empty());
            assert!(report.total_amount > Decimal::ZERO);
            assert!((1..=5).contains(&report.line_items.len()));

            // The header total must equal the sum of its line items.
            let items_total: Decimal = report.line_items.iter().map(|item| item.amount).sum();
            assert_eq!(report.total_amount, items_total);

            assert!(report
                .line_items
                .iter()
                .all(|item| !item.item_id.is_empty() && item.amount > Decimal::ZERO));
        }
    }

    #[test]
    fn test_deterministic_expenses() {
        let employees = employee_ids(10);
        let config = ExpenseConfig::default();

        // Each call builds a fresh generator from the same seed.
        let run = || {
            ExpenseReportGenerator::new(42).generate(
                &employees,
                ymd(2024, 3, 1),
                ymd(2024, 3, 31),
                &config,
            )
        };
        let (first, second) = (run(), run());

        assert_eq!(first.len(), second.len());
        for (lhs, rhs) in first.iter().zip(second.iter()) {
            assert_eq!(lhs.report_id, rhs.report_id);
            assert_eq!(lhs.employee_id, rhs.employee_id);
            assert_eq!(lhs.total_amount, rhs.total_amount);
            assert_eq!(lhs.status, rhs.status);
            assert_eq!(lhs.line_items.len(), rhs.line_items.len());
        }
    }

    #[test]
    fn test_expense_status_and_violations() {
        let employees = employee_ids(30);
        let config = ExpenseConfig::default();
        let mut generator = ExpenseReportGenerator::new(99);

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 6, 30), &config);

        // Thirty employees over six months should comfortably exceed ten reports.
        assert!(
            reports.len() > 10,
            "Expected multiple reports, got {}",
            reports.len()
        );

        let count_status =
            |wanted: ExpenseStatus| reports.iter().filter(|r| r.status == wanted).count();

        assert!(
            count_status(ExpenseStatus::Approved) > 0,
            "Expected at least some approved reports"
        );
        let other_statuses = count_status(ExpenseStatus::Paid)
            + count_status(ExpenseStatus::Submitted)
            + count_status(ExpenseStatus::Rejected)
            + count_status(ExpenseStatus::Draft);
        assert!(
            other_statuses > 0,
            "Expected a mix of statuses beyond approved"
        );

        let total_violations = reports
            .iter()
            .flat_map(|r| r.policy_violations.iter())
            .count();
        assert!(
            total_violations > 0,
            "Expected at least some policy violations across {} reports",
            reports.len()
        );
    }

    #[test]
    fn test_country_pack_does_not_break_generation() {
        let mut generator = ExpenseReportGenerator::new(42);
        generator.set_country_pack(datasynth_core::CountryPack::default());

        let employees = employee_ids(10);
        let config = ExpenseConfig::default();
        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 1, 31), &config);

        assert!(!reports.is_empty());
        assert!(reports
            .iter()
            .all(|r| !r.report_id.is_empty() && r.total_amount > Decimal::ZERO));
    }
}