1use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::utils::{sample_decimal_range, seeded_rng};
11use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use rust_decimal::Decimal;
15use std::collections::HashMap;
16use tracing::debug;
17
/// Generator of synthetic [`ExpenseReport`]s.
///
/// All randomness comes from a seeded RNG and two seeded UUID factories, so two
/// generators built with the same seed and driven with the same inputs produce
/// the same reports.
pub struct ExpenseReportGenerator {
    /// Seeded random source used for every random draw made by the generator.
    rng: ChaCha8Rng,
    /// Produces the `report_id` of each generated report.
    uuid_factory: DeterministicUuidFactory,
    /// Produces the `item_id` of each line item; created with sub-discriminator `1`
    /// so it is a separate factory from the one used for report IDs.
    item_uuid_factory: DeterministicUuidFactory,
    // The generation methods take their configuration as a parameter, so this stored
    // copy is not read anywhere in this file; the lint is silenced rather than the
    // field removed.
    #[allow(dead_code)]
    /// Configuration handed to the constructor.
    config: ExpenseConfig,
    /// Candidate approver IDs; when empty, a synthetic `MGR-NNNN` ID is used instead.
    employee_ids_pool: Vec<String>,
    /// Candidate cost-center IDs; when empty, a synthetic `CC-NNN` ID is used instead.
    cost_center_ids_pool: Vec<String>,
    /// Lookup keyed by employee ID whose value is copied into a report's
    /// `employee_name` (left as `None` when the ID is absent).
    employee_names: HashMap<String, String>,
}
33
34impl ExpenseReportGenerator {
35 pub fn new(seed: u64) -> Self {
37 Self {
38 rng: seeded_rng(seed, 0),
39 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
40 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
41 seed,
42 GeneratorType::ExpenseReport,
43 1,
44 ),
45 config: ExpenseConfig::default(),
46 employee_ids_pool: Vec::new(),
47 cost_center_ids_pool: Vec::new(),
48 employee_names: HashMap::new(),
49 }
50 }
51
52 pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
54 Self {
55 rng: seeded_rng(seed, 0),
56 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
57 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
58 seed,
59 GeneratorType::ExpenseReport,
60 1,
61 ),
62 config,
63 employee_ids_pool: Vec::new(),
64 cost_center_ids_pool: Vec::new(),
65 employee_names: HashMap::new(),
66 }
67 }
68
69 pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
75 self.employee_ids_pool = employee_ids;
76 self.cost_center_ids_pool = cost_center_ids;
77 self
78 }
79
80 pub fn with_employee_names(mut self, names: HashMap<String, String>) -> Self {
85 self.employee_names = names;
86 self
87 }
88
89 pub fn generate(
101 &mut self,
102 employee_ids: &[String],
103 period_start: NaiveDate,
104 period_end: NaiveDate,
105 config: &ExpenseConfig,
106 ) -> Vec<ExpenseReport> {
107 self.generate_with_currency(employee_ids, period_start, period_end, config, "USD")
108 }
109
110 pub fn generate_with_currency(
112 &mut self,
113 employee_ids: &[String],
114 period_start: NaiveDate,
115 period_end: NaiveDate,
116 config: &ExpenseConfig,
117 currency: &str,
118 ) -> Vec<ExpenseReport> {
119 debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
120 let mut reports = Vec::new();
121
122 let mut current_month_start = period_start;
124 while current_month_start <= period_end {
125 let month_end = self.month_end(current_month_start).min(period_end);
126
127 for employee_id in employee_ids {
128 if self.rng.random_bool(config.submission_rate.min(1.0)) {
130 let report = self.generate_report(
131 employee_id,
132 current_month_start,
133 month_end,
134 config,
135 currency,
136 );
137 reports.push(report);
138 }
139 }
140
141 current_month_start = self.next_month_start(current_month_start);
143 }
144
145 reports
146 }
147
148 fn generate_report(
150 &mut self,
151 employee_id: &str,
152 period_start: NaiveDate,
153 period_end: NaiveDate,
154 config: &ExpenseConfig,
155 currency: &str,
156 ) -> ExpenseReport {
157 let report_id = self.uuid_factory.next().to_string();
158
159 let item_count = self.rng.random_range(1..=5);
161 let mut line_items = Vec::with_capacity(item_count);
162 let mut total_amount = Decimal::ZERO;
163
164 for _ in 0..item_count {
165 let item = self.generate_line_item(period_start, period_end, currency);
166 total_amount += item.amount;
167 line_items.push(item);
168 }
169
170 let max_expense_date = line_items
172 .iter()
173 .map(|li| li.date)
174 .max()
175 .unwrap_or(period_end);
176 let submission_lag = self.rng.random_range(0..=5);
177 let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
178
179 let descriptions = [
181 "Client site visit",
182 "Conference attendance",
183 "Team offsite meeting",
184 "Customer presentation",
185 "Training workshop",
186 "Quarterly review travel",
187 "Sales meeting",
188 "Project kickoff",
189 ];
190 let description = descriptions[self.rng.random_range(0..descriptions.len())].to_string();
191
192 let status_roll: f64 = self.rng.random();
194 let status = if status_roll < 0.70 {
195 ExpenseStatus::Approved
196 } else if status_roll < 0.80 {
197 ExpenseStatus::Paid
198 } else if status_roll < 0.90 {
199 ExpenseStatus::Submitted
200 } else if status_roll < 0.95 {
201 ExpenseStatus::Rejected
202 } else {
203 ExpenseStatus::Draft
204 };
205
206 let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
207 if !self.employee_ids_pool.is_empty() {
208 let idx = self.rng.random_range(0..self.employee_ids_pool.len());
209 Some(self.employee_ids_pool[idx].clone())
210 } else {
211 Some(format!("MGR-{:04}", self.rng.random_range(1..=100)))
212 }
213 } else {
214 None
215 };
216
217 let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
218 let approval_lag = self.rng.random_range(1..=7);
219 Some(submission_date + chrono::Duration::days(approval_lag))
220 } else {
221 None
222 };
223
224 let paid_date = if status == ExpenseStatus::Paid {
225 approved_date.map(|ad| ad + chrono::Duration::days(self.rng.random_range(3..=14)))
226 } else {
227 None
228 };
229
230 let cost_center = if self.rng.random_bool(0.70) {
232 if !self.cost_center_ids_pool.is_empty() {
233 let idx = self.rng.random_range(0..self.cost_center_ids_pool.len());
234 Some(self.cost_center_ids_pool[idx].clone())
235 } else {
236 Some(format!("CC-{:03}", self.rng.random_range(100..=500)))
237 }
238 } else {
239 None
240 };
241
242 let department = if self.rng.random_bool(0.80) {
243 let departments = [
244 "Engineering",
245 "Sales",
246 "Marketing",
247 "Finance",
248 "HR",
249 "Operations",
250 "Legal",
251 "IT",
252 "Executive",
253 ];
254 Some(departments[self.rng.random_range(0..departments.len())].to_string())
255 } else {
256 None
257 };
258
259 let policy_violation_rate = config.policy_violation_rate;
261 let mut policy_violations = Vec::new();
262 for item in &line_items {
263 if self.rng.random_bool(policy_violation_rate.min(1.0)) {
264 let violation = self.pick_violation(item);
265 policy_violations.push(violation);
266 }
267 }
268
269 ExpenseReport {
270 report_id,
271 employee_id: employee_id.to_string(),
272 submission_date,
273 description,
274 status,
275 total_amount,
276 currency: currency.to_string(),
277 line_items,
278 approved_by,
279 approved_date,
280 paid_date,
281 cost_center,
282 department,
283 policy_violations,
284 employee_name: self.employee_names.get(employee_id).cloned(),
285 }
286 }
287
288 fn generate_line_item(
290 &mut self,
291 period_start: NaiveDate,
292 period_end: NaiveDate,
293 currency: &str,
294 ) -> ExpenseLineItem {
295 let item_id = self.item_uuid_factory.next().to_string();
296
297 let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
299
300 let amount = sample_decimal_range(
301 &mut self.rng,
302 Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
303 Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
304 )
305 .round_dp(2);
306
307 let days_in_period = (period_end - period_start).num_days().max(1);
309 let offset = self.rng.random_range(0..=days_in_period);
310 let date = period_start + chrono::Duration::days(offset);
311
312 let receipt_attached = self.rng.random_bool(0.85);
314
315 ExpenseLineItem {
316 item_id,
317 category,
318 date,
319 amount,
320 currency: currency.to_string(),
321 description: desc,
322 receipt_attached,
323 merchant,
324 }
325 }
326
327 fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
329 let roll: f64 = self.rng.random();
330
331 if roll < 0.20 {
332 let merchants = [
333 "Delta Airlines",
334 "United Airlines",
335 "American Airlines",
336 "Southwest",
337 ];
338 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
339 (
340 ExpenseCategory::Travel,
341 200.0,
342 2000.0,
343 "Airfare - business travel".to_string(),
344 Some(merchant),
345 )
346 } else if roll < 0.40 {
347 let merchants = [
348 "Restaurant ABC",
349 "Cafe Express",
350 "Business Lunch Co",
351 "Steakhouse Prime",
352 "Sushi Palace",
353 ];
354 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
355 (
356 ExpenseCategory::Meals,
357 20.0,
358 100.0,
359 "Business meal".to_string(),
360 Some(merchant),
361 )
362 } else if roll < 0.55 {
363 let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
364 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
365 (
366 ExpenseCategory::Lodging,
367 100.0,
368 500.0,
369 "Hotel accommodation".to_string(),
370 Some(merchant),
371 )
372 } else if roll < 0.70 {
373 let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
374 let merchant = merchants[self.rng.random_range(0..merchants.len())].to_string();
375 (
376 ExpenseCategory::Transportation,
377 10.0,
378 200.0,
379 "Ground transportation".to_string(),
380 Some(merchant),
381 )
382 } else if roll < 0.80 {
383 (
384 ExpenseCategory::Office,
385 15.0,
386 300.0,
387 "Office supplies".to_string(),
388 Some("Office Depot".to_string()),
389 )
390 } else if roll < 0.88 {
391 (
392 ExpenseCategory::Entertainment,
393 50.0,
394 500.0,
395 "Client entertainment".to_string(),
396 None,
397 )
398 } else if roll < 0.95 {
399 (
400 ExpenseCategory::Training,
401 100.0,
402 1500.0,
403 "Professional development".to_string(),
404 None,
405 )
406 } else {
407 (
408 ExpenseCategory::Other,
409 10.0,
410 200.0,
411 "Miscellaneous expense".to_string(),
412 None,
413 )
414 }
415 }
416
417 fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
419 let violations = match item.category {
420 ExpenseCategory::Meals => vec![
421 "Exceeds daily meal limit",
422 "Alcohol included without approval",
423 "Missing itemized receipt",
424 ],
425 ExpenseCategory::Travel => vec![
426 "Booked outside preferred vendor",
427 "Class upgrade not pre-approved",
428 "Booking made less than 7 days in advance",
429 ],
430 ExpenseCategory::Lodging => vec![
431 "Exceeds nightly rate limit",
432 "Extended stay without approval",
433 "Non-preferred hotel chain",
434 ],
435 _ => vec![
436 "Missing receipt",
437 "Insufficient business justification",
438 "Exceeds category spending limit",
439 ],
440 };
441
442 violations[self.rng.random_range(0..violations.len())].to_string()
443 }
444
445 fn month_end(&self, date: NaiveDate) -> NaiveDate {
447 let (year, month) = if date.month() == 12 {
448 (date.year() + 1, 1)
449 } else {
450 (date.year(), date.month() + 1)
451 };
452 NaiveDate::from_ymd_opt(year, month, 1)
453 .unwrap_or(date)
454 .pred_opt()
455 .unwrap_or(date)
456 }
457
458 fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
460 let (year, month) = if date.month() == 12 {
461 (date.year() + 1, 1)
462 } else {
463 (date.year(), date.month() + 1)
464 };
465 NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
466 }
467}
468
469#[cfg(test)]
470#[allow(clippy::unwrap_used)]
471mod tests {
472 use super::*;
473
/// Ten employee IDs, `EMP-0001` through `EMP-0010`.
fn test_employee_ids() -> Vec<String> {
    let mut ids = Vec::with_capacity(10);
    for number in 1..=10 {
        ids.push(format!("EMP-{:04}", number));
    }
    ids
}
477
/// A single month yields at most one report per employee, and every report is
/// internally consistent (non-empty IDs, positive amounts, total equals the sum
/// of its line items).
#[test]
fn test_basic_expense_generation() {
    let employees = test_employee_ids();
    let january_first = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let january_last = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
    let config = ExpenseConfig::default();

    let mut generator = ExpenseReportGenerator::new(42);
    let reports = generator.generate(&employees, january_first, january_last, &config);

    assert!(!reports.is_empty());
    assert!(
        reports.len() <= employees.len(),
        "Should not exceed employee count for a single month"
    );

    for report in reports.iter() {
        assert!(!report.report_id.is_empty());
        assert!(!report.employee_id.is_empty());
        assert!(report.total_amount > Decimal::ZERO);

        let item_count = report.line_items.len();
        assert!(item_count >= 1);
        assert!(item_count <= 5);

        // The header total must match the sum of the individual line amounts.
        let mut line_sum = Decimal::ZERO;
        for item in report.line_items.iter() {
            assert!(!item.item_id.is_empty());
            assert!(item.amount > Decimal::ZERO);
            line_sum += item.amount;
        }
        assert_eq!(report.total_amount, line_sum);
    }
}
512
/// Two generators built from the same seed and given the same inputs must
/// produce matching reports.
#[test]
fn test_deterministic_expenses() {
    let employees = test_employee_ids();
    let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
    let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
    let config = ExpenseConfig::default();

    // Each call builds a fresh generator with the same seed.
    let run_once = || {
        let mut generator = ExpenseReportGenerator::new(42);
        generator.generate(&employees, period_start, period_end, &config)
    };
    let first_run = run_once();
    let second_run = run_once();

    assert_eq!(first_run.len(), second_run.len());
    for (left, right) in first_run.iter().zip(second_run.iter()) {
        assert_eq!(left.report_id, right.report_id);
        assert_eq!(left.employee_id, right.employee_id);
        assert_eq!(left.total_amount, right.total_amount);
        assert_eq!(left.status, right.status);
        assert_eq!(left.line_items.len(), right.line_items.len());
    }
}
535
/// Over six months and thirty employees the output should contain approved
/// reports, at least one report in another status, and some policy violations.
#[test]
fn test_expense_status_and_violations() {
    let employees: Vec<String> = (1..=30).map(|i| format!("EMP-{:04}", i)).collect();
    let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
    let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
    let config = ExpenseConfig::default();

    let mut generator = ExpenseReportGenerator::new(99);
    let reports = generator.generate(&employees, period_start, period_end, &config);

    assert!(
        reports.len() > 10,
        "Expected multiple reports, got {}",
        reports.len()
    );

    // Tally of reports currently in the given status.
    let count_with_status =
        |wanted: ExpenseStatus| reports.iter().filter(|r| r.status == wanted).count();

    let approved = count_with_status(ExpenseStatus::Approved);
    let non_approved = count_with_status(ExpenseStatus::Paid)
        + count_with_status(ExpenseStatus::Submitted)
        + count_with_status(ExpenseStatus::Rejected)
        + count_with_status(ExpenseStatus::Draft);

    assert!(approved > 0, "Expected at least some approved reports");
    assert!(
        non_approved > 0,
        "Expected a mix of statuses beyond approved"
    );

    let mut total_violations: usize = 0;
    for report in &reports {
        total_violations += report.policy_violations.len();
    }
    assert!(
        total_violations > 0,
        "Expected at least some policy violations across {} reports",
        reports.len()
    );
}
591}