1use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::utils::{sample_decimal_range, seeded_rng};
11use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
12use rand::prelude::*;
13use rand_chacha::ChaCha8Rng;
14use rust_decimal::Decimal;
15use tracing::debug;
16
17pub struct ExpenseReportGenerator {
19 rng: ChaCha8Rng,
20 uuid_factory: DeterministicUuidFactory,
21 item_uuid_factory: DeterministicUuidFactory,
22 #[allow(dead_code)]
24 config: ExpenseConfig,
25 employee_ids_pool: Vec<String>,
27 cost_center_ids_pool: Vec<String>,
29}
30
31impl ExpenseReportGenerator {
32 pub fn new(seed: u64) -> Self {
34 Self {
35 rng: seeded_rng(seed, 0),
36 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
37 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
38 seed,
39 GeneratorType::ExpenseReport,
40 1,
41 ),
42 config: ExpenseConfig::default(),
43 employee_ids_pool: Vec::new(),
44 cost_center_ids_pool: Vec::new(),
45 }
46 }
47
48 pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
50 Self {
51 rng: seeded_rng(seed, 0),
52 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
53 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
54 seed,
55 GeneratorType::ExpenseReport,
56 1,
57 ),
58 config,
59 employee_ids_pool: Vec::new(),
60 cost_center_ids_pool: Vec::new(),
61 }
62 }
63
64 pub fn with_pools(mut self, employee_ids: Vec<String>, cost_center_ids: Vec<String>) -> Self {
70 self.employee_ids_pool = employee_ids;
71 self.cost_center_ids_pool = cost_center_ids;
72 self
73 }
74
75 pub fn generate(
87 &mut self,
88 employee_ids: &[String],
89 period_start: NaiveDate,
90 period_end: NaiveDate,
91 config: &ExpenseConfig,
92 ) -> Vec<ExpenseReport> {
93 self.generate_with_currency(employee_ids, period_start, period_end, config, "USD")
94 }
95
96 pub fn generate_with_currency(
98 &mut self,
99 employee_ids: &[String],
100 period_start: NaiveDate,
101 period_end: NaiveDate,
102 config: &ExpenseConfig,
103 currency: &str,
104 ) -> Vec<ExpenseReport> {
105 debug!(employee_count = employee_ids.len(), %period_start, %period_end, currency, "Generating expense reports");
106 let mut reports = Vec::new();
107
108 let mut current_month_start = period_start;
110 while current_month_start <= period_end {
111 let month_end = self.month_end(current_month_start).min(period_end);
112
113 for employee_id in employee_ids {
114 if self.rng.gen_bool(config.submission_rate.min(1.0)) {
116 let report = self.generate_report(
117 employee_id,
118 current_month_start,
119 month_end,
120 config,
121 currency,
122 );
123 reports.push(report);
124 }
125 }
126
127 current_month_start = self.next_month_start(current_month_start);
129 }
130
131 reports
132 }
133
134 fn generate_report(
136 &mut self,
137 employee_id: &str,
138 period_start: NaiveDate,
139 period_end: NaiveDate,
140 config: &ExpenseConfig,
141 currency: &str,
142 ) -> ExpenseReport {
143 let report_id = self.uuid_factory.next().to_string();
144
145 let item_count = self.rng.gen_range(1..=5);
147 let mut line_items = Vec::with_capacity(item_count);
148 let mut total_amount = Decimal::ZERO;
149
150 for _ in 0..item_count {
151 let item = self.generate_line_item(period_start, period_end, currency);
152 total_amount += item.amount;
153 line_items.push(item);
154 }
155
156 let max_expense_date = line_items
158 .iter()
159 .map(|li| li.date)
160 .max()
161 .unwrap_or(period_end);
162 let submission_lag = self.rng.gen_range(0..=5);
163 let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
164
165 let descriptions = [
167 "Client site visit",
168 "Conference attendance",
169 "Team offsite meeting",
170 "Customer presentation",
171 "Training workshop",
172 "Quarterly review travel",
173 "Sales meeting",
174 "Project kickoff",
175 ];
176 let description = descriptions[self.rng.gen_range(0..descriptions.len())].to_string();
177
178 let status_roll: f64 = self.rng.gen();
180 let status = if status_roll < 0.70 {
181 ExpenseStatus::Approved
182 } else if status_roll < 0.80 {
183 ExpenseStatus::Paid
184 } else if status_roll < 0.90 {
185 ExpenseStatus::Submitted
186 } else if status_roll < 0.95 {
187 ExpenseStatus::Rejected
188 } else {
189 ExpenseStatus::Draft
190 };
191
192 let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
193 if !self.employee_ids_pool.is_empty() {
194 let idx = self.rng.gen_range(0..self.employee_ids_pool.len());
195 Some(self.employee_ids_pool[idx].clone())
196 } else {
197 Some(format!("MGR-{:04}", self.rng.gen_range(1..=100)))
198 }
199 } else {
200 None
201 };
202
203 let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
204 let approval_lag = self.rng.gen_range(1..=7);
205 Some(submission_date + chrono::Duration::days(approval_lag))
206 } else {
207 None
208 };
209
210 let paid_date = if status == ExpenseStatus::Paid {
211 approved_date.map(|ad| ad + chrono::Duration::days(self.rng.gen_range(3..=14)))
212 } else {
213 None
214 };
215
216 let cost_center = if self.rng.gen_bool(0.70) {
218 if !self.cost_center_ids_pool.is_empty() {
219 let idx = self.rng.gen_range(0..self.cost_center_ids_pool.len());
220 Some(self.cost_center_ids_pool[idx].clone())
221 } else {
222 Some(format!("CC-{:03}", self.rng.gen_range(100..=500)))
223 }
224 } else {
225 None
226 };
227
228 let department = if self.rng.gen_bool(0.80) {
229 let departments = [
230 "Engineering",
231 "Sales",
232 "Marketing",
233 "Finance",
234 "HR",
235 "Operations",
236 "Legal",
237 "IT",
238 "Executive",
239 ];
240 Some(departments[self.rng.gen_range(0..departments.len())].to_string())
241 } else {
242 None
243 };
244
245 let policy_violation_rate = config.policy_violation_rate;
247 let mut policy_violations = Vec::new();
248 for item in &line_items {
249 if self.rng.gen_bool(policy_violation_rate.min(1.0)) {
250 let violation = self.pick_violation(item);
251 policy_violations.push(violation);
252 }
253 }
254
255 ExpenseReport {
256 report_id,
257 employee_id: employee_id.to_string(),
258 submission_date,
259 description,
260 status,
261 total_amount,
262 currency: currency.to_string(),
263 line_items,
264 approved_by,
265 approved_date,
266 paid_date,
267 cost_center,
268 department,
269 policy_violations,
270 }
271 }
272
273 fn generate_line_item(
275 &mut self,
276 period_start: NaiveDate,
277 period_end: NaiveDate,
278 currency: &str,
279 ) -> ExpenseLineItem {
280 let item_id = self.item_uuid_factory.next().to_string();
281
282 let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
284
285 let amount = sample_decimal_range(
286 &mut self.rng,
287 Decimal::from_f64_retain(amount_min).unwrap_or(Decimal::ONE),
288 Decimal::from_f64_retain(amount_max).unwrap_or(Decimal::ONE),
289 )
290 .round_dp(2);
291
292 let days_in_period = (period_end - period_start).num_days().max(1);
294 let offset = self.rng.gen_range(0..=days_in_period);
295 let date = period_start + chrono::Duration::days(offset);
296
297 let receipt_attached = self.rng.gen_bool(0.85);
299
300 ExpenseLineItem {
301 item_id,
302 category,
303 date,
304 amount,
305 currency: currency.to_string(),
306 description: desc,
307 receipt_attached,
308 merchant,
309 }
310 }
311
312 fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
314 let roll: f64 = self.rng.gen();
315
316 if roll < 0.20 {
317 let merchants = [
318 "Delta Airlines",
319 "United Airlines",
320 "American Airlines",
321 "Southwest",
322 ];
323 let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
324 (
325 ExpenseCategory::Travel,
326 200.0,
327 2000.0,
328 "Airfare - business travel".to_string(),
329 Some(merchant),
330 )
331 } else if roll < 0.40 {
332 let merchants = [
333 "Restaurant ABC",
334 "Cafe Express",
335 "Business Lunch Co",
336 "Steakhouse Prime",
337 "Sushi Palace",
338 ];
339 let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
340 (
341 ExpenseCategory::Meals,
342 20.0,
343 100.0,
344 "Business meal".to_string(),
345 Some(merchant),
346 )
347 } else if roll < 0.55 {
348 let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
349 let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
350 (
351 ExpenseCategory::Lodging,
352 100.0,
353 500.0,
354 "Hotel accommodation".to_string(),
355 Some(merchant),
356 )
357 } else if roll < 0.70 {
358 let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
359 let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
360 (
361 ExpenseCategory::Transportation,
362 10.0,
363 200.0,
364 "Ground transportation".to_string(),
365 Some(merchant),
366 )
367 } else if roll < 0.80 {
368 (
369 ExpenseCategory::Office,
370 15.0,
371 300.0,
372 "Office supplies".to_string(),
373 Some("Office Depot".to_string()),
374 )
375 } else if roll < 0.88 {
376 (
377 ExpenseCategory::Entertainment,
378 50.0,
379 500.0,
380 "Client entertainment".to_string(),
381 None,
382 )
383 } else if roll < 0.95 {
384 (
385 ExpenseCategory::Training,
386 100.0,
387 1500.0,
388 "Professional development".to_string(),
389 None,
390 )
391 } else {
392 (
393 ExpenseCategory::Other,
394 10.0,
395 200.0,
396 "Miscellaneous expense".to_string(),
397 None,
398 )
399 }
400 }
401
402 fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
404 let violations = match item.category {
405 ExpenseCategory::Meals => vec![
406 "Exceeds daily meal limit",
407 "Alcohol included without approval",
408 "Missing itemized receipt",
409 ],
410 ExpenseCategory::Travel => vec![
411 "Booked outside preferred vendor",
412 "Class upgrade not pre-approved",
413 "Booking made less than 7 days in advance",
414 ],
415 ExpenseCategory::Lodging => vec![
416 "Exceeds nightly rate limit",
417 "Extended stay without approval",
418 "Non-preferred hotel chain",
419 ],
420 _ => vec![
421 "Missing receipt",
422 "Insufficient business justification",
423 "Exceeds category spending limit",
424 ],
425 };
426
427 violations[self.rng.gen_range(0..violations.len())].to_string()
428 }
429
430 fn month_end(&self, date: NaiveDate) -> NaiveDate {
432 let (year, month) = if date.month() == 12 {
433 (date.year() + 1, 1)
434 } else {
435 (date.year(), date.month() + 1)
436 };
437 NaiveDate::from_ymd_opt(year, month, 1)
438 .unwrap_or(date)
439 .pred_opt()
440 .unwrap_or(date)
441 }
442
443 fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
445 let (year, month) = if date.month() == 12 {
446 (date.year() + 1, 1)
447 } else {
448 (date.year(), date.month() + 1)
449 };
450 NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
451 }
452}
453
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Ten synthetic employee IDs, `EMP-0001` through `EMP-0010`.
    fn test_employee_ids() -> Vec<String> {
        (1..=10).map(|i| format!("EMP-{:04}", i)).collect()
    }

    #[test]
    fn test_basic_expense_generation() {
        let mut gen = ExpenseReportGenerator::new(42);
        let employees = test_employee_ids();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);

        assert!(!reports.is_empty());
        // At most one report per employee per month.
        assert!(
            reports.len() <= employees.len(),
            "Should not exceed employee count for a single month"
        );

        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(!report.employee_id.is_empty());
            assert_eq!(report.currency, "USD");
            assert!(report.total_amount > Decimal::ZERO);
            assert!(!report.line_items.is_empty());
            assert!(report.line_items.len() <= 5);

            // The report total must equal the sum of its line items.
            let line_sum: Decimal = report.line_items.iter().map(|li| li.amount).sum();
            assert_eq!(report.total_amount, line_sum);

            for item in &report.line_items {
                assert!(!item.item_id.is_empty());
                assert!(item.amount > Decimal::ZERO);
                assert!(
                    item.date >= period_start && item.date <= period_end,
                    "Line item dated {} outside {}..={}",
                    item.date,
                    period_start,
                    period_end
                );
                // Reports are submitted on or after their latest expense.
                assert!(item.date <= report.submission_date);
            }
        }
    }

    #[test]
    fn test_deterministic_expenses() {
        let employees = test_employee_ids();
        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
        let config = ExpenseConfig::default();

        let mut gen1 = ExpenseReportGenerator::new(42);
        let reports1 = gen1.generate(&employees, period_start, period_end, &config);

        let mut gen2 = ExpenseReportGenerator::new(42);
        let reports2 = gen2.generate(&employees, period_start, period_end, &config);

        assert_eq!(reports1.len(), reports2.len());
        for (a, b) in reports1.iter().zip(reports2.iter()) {
            assert_eq!(a.report_id, b.report_id);
            assert_eq!(a.employee_id, b.employee_id);
            assert_eq!(a.total_amount, b.total_amount);
            assert_eq!(a.status, b.status);
            assert_eq!(a.submission_date, b.submission_date);
            assert_eq!(a.line_items.len(), b.line_items.len());
            for (x, y) in a.line_items.iter().zip(b.line_items.iter()) {
                assert_eq!(x.item_id, y.item_id);
                assert_eq!(x.amount, y.amount);
                assert_eq!(x.date, y.date);
            }
        }
    }

    #[test]
    fn test_expense_status_and_violations() {
        let mut gen = ExpenseReportGenerator::new(99);
        let employees: Vec<String> = (1..=30).map(|i| format!("EMP-{:04}", i)).collect();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);

        // 30 employees over six months should yield plenty of reports.
        assert!(
            reports.len() > 10,
            "Expected multiple reports, got {}",
            reports.len()
        );

        let count_with = |status: ExpenseStatus| {
            reports.iter().filter(|r| r.status == status).count()
        };
        let approved = count_with(ExpenseStatus::Approved);
        let paid = count_with(ExpenseStatus::Paid);
        let submitted = count_with(ExpenseStatus::Submitted);
        let rejected = count_with(ExpenseStatus::Rejected);
        let draft = count_with(ExpenseStatus::Draft);

        assert!(approved > 0, "Expected at least some approved reports");
        assert!(
            paid + submitted + rejected + draft > 0,
            "Expected a mix of statuses beyond approved"
        );

        // Approval/payment fields must be consistent with the lifecycle status.
        for report in &reports {
            let approved_or_paid = matches!(
                report.status,
                ExpenseStatus::Approved | ExpenseStatus::Paid
            );
            assert_eq!(report.approved_by.is_some(), approved_or_paid);
            assert_eq!(report.approved_date.is_some(), approved_or_paid);
            assert_eq!(
                report.paid_date.is_some(),
                report.status == ExpenseStatus::Paid
            );
            if let Some(approved_on) = report.approved_date {
                assert!(approved_on > report.submission_date);
            }
            if let (Some(approved_on), Some(paid_on)) = (report.approved_date, report.paid_date) {
                assert!(paid_on > approved_on);
            }
            for item in &report.line_items {
                assert!(item.date >= period_start && item.date <= period_end);
            }
        }

        let total_violations: usize = reports.iter().map(|r| r.policy_violations.len()).sum();
        assert!(
            total_violations > 0,
            "Expected at least some policy violations across {} reports",
            reports.len()
        );
    }
}