1use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::utils::{sample_decimal_range, seeded_rng};
11use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use rust_decimal::Decimal;
15use smallvec::SmallVec;
16use std::collections::HashMap;
17use tracing::debug;
18
/// Generator of synthetic employee expense reports.
///
/// All randomness comes from one seeded RNG and the IDs come from seeded UUID
/// factories, so two generators built with the same seed and driven with the
/// same inputs produce the same reports.
pub struct ExpenseReportGenerator {
    /// Random source for every sampled value (created via `seeded_rng(seed, 0)`).
    rng: ChaCha8Rng,
    /// Issues the IDs assigned to expense reports.
    uuid_factory: DeterministicUuidFactory,
    /// Issues the IDs assigned to line items; built with sub-discriminator `1`
    /// on the same generator type as `uuid_factory`.
    item_uuid_factory: DeterministicUuidFactory,
    /// Configuration captured at construction time.
    ///
    /// The generation methods in this impl take their configuration as an
    /// explicit argument and do not read this field, hence the `dead_code` allow.
    #[allow(dead_code)]
    config: ExpenseConfig,
    /// Employee IDs used as approvers for approved/paid reports. When empty, a
    /// synthetic `MGR-NNNN` identifier is generated instead.
    employee_ids_pool: Vec<String>,
    /// Cost-center IDs to assign to reports. When empty, a synthetic `CC-NNN`
    /// identifier is generated instead.
    cost_center_ids_pool: Vec<String>,
    /// Lookup from employee ID to display name, used to fill `employee_name`.
    employee_names: HashMap<String, String>,
    /// Optional country pack installed through `set_country_pack`.
    ///
    /// NOTE(review): nothing in this impl reads it yet, hence the `dead_code`
    /// allow — confirm whether it is reserved for future locale-aware output.
    #[allow(dead_code)]
    country_pack: Option<datasynth_core::CountryPack>,
}
40
41impl ExpenseReportGenerator {
42 pub fn new(seed: u64) -> Self {
44 Self {
45 rng: seeded_rng(seed, 0),
46 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
47 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
48 seed,
49 GeneratorType::ExpenseReport,
50 1,
51 ),
52 config: ExpenseConfig::default(),
53 employee_ids_pool: Vec::new(),
54 cost_center_ids_pool: Vec::new(),
55 employee_names: HashMap::new(),
56 country_pack: None,
57 }
58 }
59
60 pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
62 Self {
63 rng: seeded_rng(seed, 0),
64 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
65 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
66 seed,
67 GeneratorType::ExpenseReport,
68 1,
69 ),
70 config,
71 employee_ids_pool: Vec::new(),
72 cost_center_ids_pool: Vec::new(),
73 employee_names: HashMap::new(),
74 country_pack: None,
75 }
76 }
77
78 pub fn set_country_pack(&mut self, pack: datasynth_core::CountryPack) {
85 self.country_pack = Some(pack);
86 }
87
88 pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
94 self.employee_ids_pool = employee_ids;
95 self.cost_center_ids_pool = cost_center_ids;
96 self
97 }
98
99 pub fn with_employee_names(mut self, names: HashMap<String, String>) -> Self {
104 self.employee_names = names;
105 self
106 }
107
108 pub fn generate(
120 &mut self,
121 employee_ids: &[String],
122 period_start: NaiveDate,
123 period_end: NaiveDate,
124 config: &ExpenseConfig,
125 ) -> Vec<ExpenseReport> {
126 self.generate_with_currency(employee_ids, period_start, period_end, config, "USD")
127 }
128
129 pub fn generate_with_currency(
131 &mut self,
132 employee_ids: &[String],
133 period_start: NaiveDate,
134 period_end: NaiveDate,
135 config: &ExpenseConfig,
136 currency: &str,
137 ) -> Vec<ExpenseReport> {
138 debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
139 let mut reports = Vec::new();
140
141 let mut current_month_start = period_start;
143 while current_month_start <= period_end {
144 let month_end = self.month_end(current_month_start).min(period_end);
145
146 for employee_id in employee_ids {
147 if self.rng.random_bool(config.submission_rate.min(1.0)) {
149 let report = self.generate_report(
150 employee_id,
151 current_month_start,
152 month_end,
153 config,
154 currency,
155 );
156 reports.push(report);
157 }
158 }
159
160 current_month_start = self.next_month_start(current_month_start);
162 }
163
164 reports
165 }
166
167 fn generate_report(
169 &mut self,
170 employee_id: &str,
171 period_start: NaiveDate,
172 period_end: NaiveDate,
173 config: &ExpenseConfig,
174 currency: &str,
175 ) -> ExpenseReport {
176 let report_id = self.uuid_factory.next().to_string();
177
178 let item_count = self.rng.random_range(1..=5);
180 let mut line_items = SmallVec::with_capacity(item_count);
181 let mut total_amount = Decimal::ZERO;
182
183 for _ in 0..item_count {
184 let item = self.generate_line_item(period_start, period_end, currency);
185 total_amount += item.amount;
186 line_items.push(item);
187 }
188
189 let max_expense_date = line_items
191 .iter()
192 .map(|li: &ExpenseLineItem| li.date)
193 .max()
194 .unwrap_or(period_end);
195 let submission_lag = self.rng.random_range(0..=5);
196 let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
197
198 let descriptions = [
200 "Client site visit",
201 "Conference attendance",
202 "Team offsite meeting",
203 "Customer presentation",
204 "Training workshop",
205 "Quarterly review travel",
206 "Sales meeting",
207 "Project kickoff",
208 ];
209 let description = descriptions[self.rng.random_range(0..descriptions.len())].to_string();
210
211 let status_roll: f64 = self.rng.random();
213 let status = if status_roll < 0.70 {
214 ExpenseStatus::Approved
215 } else if status_roll < 0.80 {
216 ExpenseStatus::Paid
217 } else if status_roll < 0.90 {
218 ExpenseStatus::Submitted
219 } else if status_roll < 0.95 {
220 ExpenseStatus::Rejected
221 } else {
222 ExpenseStatus::Draft
223 };
224
225 let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
226 if !self.employee_ids_pool.is_empty() {
227 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
228 Some(self.employee_ids_pool[idx].clone())
229 } else {
230 Some(format!("MGR-{:04}", self.rng.random_range(1..=100)))
231 }
232 } else {
233 None
234 };
235
236 let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
237 let approval_lag = self.rng.random_range(1..=7);
238 Some(submission_date + chrono::Duration::days(approval_lag))
239 } else {
240 None
241 };
242
243 let paid_date = if status == ExpenseStatus::Paid {
244 approved_date.map(|ad| ad + chrono::Duration::days(self.rng.random_range(3..=14)))
245 } else {
246 None
247 };
248
249 let cost_center = if self.rng.random_bool(0.70) {
251 if !self.cost_center_ids_pool.is_empty() {
252 let idx = self.rng.random_range(0..self.cost_center_ids_pool.len());
253 Some(self.cost_center_ids_pool[idx].clone())
254 } else {
255 Some(format!("CC-{:03}", self.rng.random_range(100..=500)))
256 }
257 } else {
258 None
259 };
260
261 let department = if self.rng.random_bool(0.80) {
262 let departments = [
263 "Engineering",
264 "Sales",
265 "Marketing",
266 "Finance",
267 "HR",
268 "Operations",
269 "Legal",
270 "IT",
271 "Executive",
272 ];
273 Some(departments[self.rng.random_range(0..departments.len())].to_string())
274 } else {
275 None
276 };
277
278 let policy_violation_rate = config.policy_violation_rate;
280 let mut policy_violations = Vec::new();
281 for item in &line_items {
282 if self.rng.random_bool(policy_violation_rate.min(1.0)) {
283 let violation = self.pick_violation(item);
284 policy_violations.push(violation);
285 }
286 }
287
288 ExpenseReport {
289 report_id,
290 employee_id: employee_id.to_string(),
291 submission_date,
292 description,
293 status,
294 total_amount,
295 currency: currency.to_string(),
296 line_items,
297 approved_by,
298 approved_date,
299 paid_date,
300 cost_center,
301 department,
302 policy_violations,
303 employee_name: self.employee_names.get(employee_id).cloned(),
304 }
305 }
306
307 fn generate_line_item(
309 &mut self,
310 period_start: NaiveDate,
311 period_end: NaiveDate,
312 currency: &str,
313 ) -> ExpenseLineItem {
314 let item_id = self.item_uuid_factory.next().to_string();
315
316 let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
318
319 let amount = sample_decimal_range(
320 &mut self.rng,
321 Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
322 Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
323 )
324 .round_dp(2);
325
326 let days_in_period = (period_end - period_start).num_days().max(1);
328 let offset = self.rng.random_range(0..=days_in_period);
329 let date = period_start + chrono::Duration::days(offset);
330
331 let receipt_attached = self.rng.random_bool(0.85);
333
334 ExpenseLineItem {
335 item_id,
336 category,
337 date,
338 amount,
339 currency: currency.to_string(),
340 description: desc,
341 receipt_attached,
342 merchant,
343 }
344 }
345
346 fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
348 let roll: f64 = self.rng.random();
349
350 if roll < 0.20 {
351 let merchants = [
352 "Delta Airlines",
353 "United Airlines",
354 "American Airlines",
355 "Southwest",
356 ];
357 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
358 (
359 ExpenseCategory::Travel,
360 200.0,
361 2000.0,
362 "Airfare - business travel".to_string(),
363 Some(merchant),
364 )
365 } else if roll < 0.40 {
366 let merchants = [
367 "Restaurant ABC",
368 "Cafe Express",
369 "Business Lunch Co",
370 "Steakhouse Prime",
371 "Sushi Palace",
372 ];
373 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
374 (
375 ExpenseCategory::Meals,
376 20.0,
377 100.0,
378 "Business meal".to_string(),
379 Some(merchant),
380 )
381 } else if roll < 0.55 {
382 let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
383 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
384 (
385 ExpenseCategory::Lodging,
386 100.0,
387 500.0,
388 "Hotel accommodation".to_string(),
389 Some(merchant),
390 )
391 } else if roll < 0.70 {
392 let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
393 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
394 (
395 ExpenseCategory::Transportation,
396 10.0,
397 200.0,
398 "Ground transportation".to_string(),
399 Some(merchant),
400 )
401 } else if roll < 0.80 {
402 (
403 ExpenseCategory::Office,
404 15.0,
405 300.0,
406 "Office supplies".to_string(),
407 Some("Office Depot".to_string()),
408 )
409 } else if roll < 0.88 {
410 (
411 ExpenseCategory::Entertainment,
412 50.0,
413 500.0,
414 "Client entertainment".to_string(),
415 None,
416 )
417 } else if roll < 0.95 {
418 (
419 ExpenseCategory::Training,
420 100.0,
421 1500.0,
422 "Professional development".to_string(),
423 None,
424 )
425 } else {
426 (
427 ExpenseCategory::Other,
428 10.0,
429 200.0,
430 "Miscellaneous expense".to_string(),
431 None,
432 )
433 }
434 }
435
436 fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
438 let violations = match item.category {
439 ExpenseCategory::Meals => vec![
440 "Exceeds daily meal limit",
441 "Alcohol included without approval",
442 "Missing itemized receipt",
443 ],
444 ExpenseCategory::Travel => vec![
445 "Booked outside preferred vendor",
446 "Class upgrade not pre-approved",
447 "Booking made less than 7 days in advance",
448 ],
449 ExpenseCategory::Lodging => vec![
450 "Exceeds nightly rate limit",
451 "Extended stay without approval",
452 "Non-preferred hotel chain",
453 ],
454 _ => vec![
455 "Missing receipt",
456 "Insufficient business justification",
457 "Exceeds category spending limit",
458 ],
459 };
460
461 violations[self.rng.random_range(0..violations.len())].to_string()
462 }
463
464 fn month_end(&self, date: NaiveDate) -> NaiveDate {
466 let (year, month) = if date.month() == 12 {
467 (date.year() + 1, 1)
468 } else {
469 (date.year(), date.month() + 1)
470 };
471 NaiveDate::from_ymd_opt(year, month, 1)
472 .unwrap_or(date)
473 .pred_opt()
474 .unwrap_or(date)
475 }
476
477 fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
479 let (year, month) = if date.month() == 12 {
480 (date.year() + 1, 1)
481 } else {
482 (date.year(), date.month() + 1)
483 };
484 NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
485 }
486}
487
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Ten sequential employee identifiers, `EMP-0001` through `EMP-0010`.
    fn test_employee_ids() -> Vec<String> {
        (1..=10).map(|i| format!("EMP-{i:04}")).collect()
    }

    /// Builds a calendar date, panicking on an invalid one (test-only helper).
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// A single month yields at most one report per employee, and every report
    /// is internally consistent (non-empty IDs, positive amounts, total equals
    /// the sum of its line items).
    #[test]
    fn test_basic_expense_generation() {
        let employees = test_employee_ids();
        let config = ExpenseConfig::default();
        let mut generator = ExpenseReportGenerator::new(42);

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 1, 31), &config);

        assert!(!reports.is_empty());
        assert!(
            reports.len() <= employees.len(),
            "Should not exceed employee count for a single month"
        );

        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(!report.employee_id.is_empty());
            assert!(report.total_amount > Decimal::ZERO);
            assert!(!report.line_items.is_empty());
            assert!(report.line_items.len() <= 5);

            let line_sum: Decimal = report.line_items.iter().map(|li| li.amount).sum();
            assert_eq!(report.total_amount, line_sum);

            for item in &report.line_items {
                assert!(!item.item_id.is_empty());
                assert!(item.amount > Decimal::ZERO);
            }
        }
    }

    /// Two generators built from the same seed produce matching reports.
    #[test]
    fn test_deterministic_expenses() {
        let employees = test_employee_ids();
        let (period_start, period_end) = (ymd(2024, 3, 1), ymd(2024, 3, 31));
        let config = ExpenseConfig::default();

        let first_run =
            ExpenseReportGenerator::new(42).generate(&employees, period_start, period_end, &config);
        let second_run =
            ExpenseReportGenerator::new(42).generate(&employees, period_start, period_end, &config);

        assert_eq!(first_run.len(), second_run.len());
        for (a, b) in first_run.iter().zip(second_run.iter()) {
            assert_eq!(a.report_id, b.report_id);
            assert_eq!(a.employee_id, b.employee_id);
            assert_eq!(a.total_amount, b.total_amount);
            assert_eq!(a.status, b.status);
            assert_eq!(a.line_items.len(), b.line_items.len());
        }
    }

    /// Over six months and thirty employees the output contains a mix of
    /// statuses and at least one policy violation.
    #[test]
    fn test_expense_status_and_violations() {
        let employees: Vec<String> = (1..=30).map(|i| format!("EMP-{i:04}")).collect();
        let config = ExpenseConfig::default();
        let mut generator = ExpenseReportGenerator::new(99);

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 6, 30), &config);

        assert!(
            reports.len() > 10,
            "Expected multiple reports, got {}",
            reports.len()
        );

        // Number of generated reports carrying the given status.
        let count_with = |status: ExpenseStatus| -> usize {
            reports.iter().filter(|r| r.status == status).count()
        };
        let approved = count_with(ExpenseStatus::Approved);
        let paid = count_with(ExpenseStatus::Paid);
        let submitted = count_with(ExpenseStatus::Submitted);
        let rejected = count_with(ExpenseStatus::Rejected);
        let draft = count_with(ExpenseStatus::Draft);

        assert!(approved > 0, "Expected at least some approved reports");
        assert!(
            paid + submitted + rejected + draft > 0,
            "Expected a mix of statuses beyond approved"
        );

        let total_violations: usize = reports.iter().map(|r| r.policy_violations.len()).sum();
        assert!(
            total_violations > 0,
            "Expected at least some policy violations across {} reports",
            reports.len()
        );
    }

    /// Installing a country pack leaves generation working.
    #[test]
    fn test_country_pack_does_not_break_generation() {
        let mut generator = ExpenseReportGenerator::new(42);
        generator.set_country_pack(datasynth_core::CountryPack::default());

        let employees = test_employee_ids();
        let config = ExpenseConfig::default();

        let reports = generator.generate(&employees, ymd(2024, 1, 1), ymd(2024, 1, 31), &config);

        assert!(!reports.is_empty());
        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(report.total_amount > Decimal::ZERO);
        }
    }
}