1use chrono::{Datelike, NaiveDate};
8use datasynth_config::schema::ExpenseConfig;
9use datasynth_core::models::{ExpenseCategory, ExpenseLineItem, ExpenseReport, ExpenseStatus};
10use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
11use rand::prelude::*;
12use rand_chacha::ChaCha8Rng;
13use rust_decimal::Decimal;
14
15pub struct ExpenseReportGenerator {
17 rng: ChaCha8Rng,
18 uuid_factory: DeterministicUuidFactory,
19 item_uuid_factory: DeterministicUuidFactory,
20 config: ExpenseConfig,
21}
22
23impl ExpenseReportGenerator {
24 pub fn new(seed: u64) -> Self {
26 Self {
27 rng: ChaCha8Rng::seed_from_u64(seed),
28 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
29 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
30 seed,
31 GeneratorType::ExpenseReport,
32 1,
33 ),
34 config: ExpenseConfig::default(),
35 }
36 }
37
38 pub fn with_config(seed: u64, config: ExpenseConfig) -> Self {
40 Self {
41 rng: ChaCha8Rng::seed_from_u64(seed),
42 uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::ExpenseReport),
43 item_uuid_factory: DeterministicUuidFactory::with_sub_discriminator(
44 seed,
45 GeneratorType::ExpenseReport,
46 1,
47 ),
48 config,
49 }
50 }
51
52 pub fn generate(
64 &mut self,
65 employee_ids: &[String],
66 period_start: NaiveDate,
67 period_end: NaiveDate,
68 config: &ExpenseConfig,
69 ) -> Vec<ExpenseReport> {
70 let mut reports = Vec::new();
71
72 let mut current_month_start = period_start;
74 while current_month_start <= period_end {
75 let month_end = self.month_end(current_month_start).min(period_end);
76
77 for employee_id in employee_ids {
78 if self.rng.gen_bool(config.submission_rate.min(1.0)) {
80 let report =
81 self.generate_report(employee_id, current_month_start, month_end, config);
82 reports.push(report);
83 }
84 }
85
86 current_month_start = self.next_month_start(current_month_start);
88 }
89
90 reports
91 }
92
93 fn generate_report(
95 &mut self,
96 employee_id: &str,
97 period_start: NaiveDate,
98 period_end: NaiveDate,
99 config: &ExpenseConfig,
100 ) -> ExpenseReport {
101 let report_id = self.uuid_factory.next().to_string();
102
103 let item_count = self.rng.gen_range(1..=5);
105 let mut line_items = Vec::with_capacity(item_count);
106 let mut total_amount = Decimal::ZERO;
107
108 for _ in 0..item_count {
109 let item = self.generate_line_item(period_start, period_end);
110 total_amount += item.amount;
111 line_items.push(item);
112 }
113
114 let max_expense_date = line_items
116 .iter()
117 .map(|li| li.date)
118 .max()
119 .unwrap_or(period_end);
120 let submission_lag = self.rng.gen_range(0..=5);
121 let submission_date = max_expense_date + chrono::Duration::days(submission_lag);
122
123 let descriptions = [
125 "Client site visit",
126 "Conference attendance",
127 "Team offsite meeting",
128 "Customer presentation",
129 "Training workshop",
130 "Quarterly review travel",
131 "Sales meeting",
132 "Project kickoff",
133 ];
134 let description = descriptions[self.rng.gen_range(0..descriptions.len())].to_string();
135
136 let status_roll: f64 = self.rng.gen();
138 let status = if status_roll < 0.70 {
139 ExpenseStatus::Approved
140 } else if status_roll < 0.80 {
141 ExpenseStatus::Paid
142 } else if status_roll < 0.90 {
143 ExpenseStatus::Submitted
144 } else if status_roll < 0.95 {
145 ExpenseStatus::Rejected
146 } else {
147 ExpenseStatus::Draft
148 };
149
150 let approved_by = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
151 Some(format!("MGR-{:04}", self.rng.gen_range(1..=100)))
152 } else {
153 None
154 };
155
156 let approved_date = if matches!(status, ExpenseStatus::Approved | ExpenseStatus::Paid) {
157 let approval_lag = self.rng.gen_range(1..=7);
158 Some(submission_date + chrono::Duration::days(approval_lag))
159 } else {
160 None
161 };
162
163 let paid_date = if status == ExpenseStatus::Paid {
164 approved_date.map(|ad| ad + chrono::Duration::days(self.rng.gen_range(3..=14)))
165 } else {
166 None
167 };
168
169 let cost_center = if self.rng.gen_bool(0.70) {
171 Some(format!("CC-{:03}", self.rng.gen_range(100..=500)))
172 } else {
173 None
174 };
175
176 let department = if self.rng.gen_bool(0.80) {
177 let departments = [
178 "Engineering",
179 "Sales",
180 "Marketing",
181 "Finance",
182 "HR",
183 "Operations",
184 "Legal",
185 "IT",
186 "Executive",
187 ];
188 Some(departments[self.rng.gen_range(0..departments.len())].to_string())
189 } else {
190 None
191 };
192
193 let policy_violation_rate = config.policy_violation_rate;
195 let mut policy_violations = Vec::new();
196 for item in &line_items {
197 if self.rng.gen_bool(policy_violation_rate.min(1.0)) {
198 let violation = self.pick_violation(item);
199 policy_violations.push(violation);
200 }
201 }
202
203 ExpenseReport {
204 report_id,
205 employee_id: employee_id.to_string(),
206 submission_date,
207 description,
208 status,
209 total_amount,
210 currency: "USD".to_string(),
211 line_items,
212 approved_by,
213 approved_date,
214 paid_date,
215 cost_center,
216 department,
217 policy_violations,
218 }
219 }
220
221 fn generate_line_item(
223 &mut self,
224 period_start: NaiveDate,
225 period_end: NaiveDate,
226 ) -> ExpenseLineItem {
227 let item_id = self.item_uuid_factory.next().to_string();
228
229 let (category, amount_min, amount_max, desc, merchant) = self.pick_category();
231
232 let raw_amount = self.rng.gen_range(amount_min..=amount_max);
233 let amount = Decimal::from_f64_retain(raw_amount)
234 .unwrap_or(Decimal::ONE)
235 .round_dp(2);
236
237 let days_in_period = (period_end - period_start).num_days().max(1);
239 let offset = self.rng.gen_range(0..=days_in_period);
240 let date = period_start + chrono::Duration::days(offset);
241
242 let receipt_attached = self.rng.gen_bool(0.85);
244
245 ExpenseLineItem {
246 item_id,
247 category,
248 date,
249 amount,
250 currency: "USD".to_string(),
251 description: desc,
252 receipt_attached,
253 merchant,
254 }
255 }
256
257 fn pick_category(&mut self) -> (ExpenseCategory, f64, f64, String, Option<String>) {
259 let roll: f64 = self.rng.gen();
260
261 if roll < 0.20 {
262 let merchants = [
263 "Delta Airlines",
264 "United Airlines",
265 "American Airlines",
266 "Southwest",
267 ];
268 let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
269 (
270 ExpenseCategory::Travel,
271 200.0,
272 2000.0,
273 "Airfare - business travel".to_string(),
274 Some(merchant),
275 )
276 } else if roll < 0.40 {
277 let merchants = [
278 "Restaurant ABC",
279 "Cafe Express",
280 "Business Lunch Co",
281 "Steakhouse Prime",
282 "Sushi Palace",
283 ];
284 let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
285 (
286 ExpenseCategory::Meals,
287 20.0,
288 100.0,
289 "Business meal".to_string(),
290 Some(merchant),
291 )
292 } else if roll < 0.55 {
293 let merchants = ["Marriott", "Hilton", "Hyatt", "Holiday Inn", "Best Western"];
294 let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
295 (
296 ExpenseCategory::Lodging,
297 100.0,
298 500.0,
299 "Hotel accommodation".to_string(),
300 Some(merchant),
301 )
302 } else if roll < 0.70 {
303 let merchants = ["Uber", "Lyft", "Hertz", "Enterprise", "Airport Parking"];
304 let merchant = merchants[self.rng.gen_range(0..merchants.len())].to_string();
305 (
306 ExpenseCategory::Transportation,
307 10.0,
308 200.0,
309 "Ground transportation".to_string(),
310 Some(merchant),
311 )
312 } else if roll < 0.80 {
313 (
314 ExpenseCategory::Office,
315 15.0,
316 300.0,
317 "Office supplies".to_string(),
318 Some("Office Depot".to_string()),
319 )
320 } else if roll < 0.88 {
321 (
322 ExpenseCategory::Entertainment,
323 50.0,
324 500.0,
325 "Client entertainment".to_string(),
326 None,
327 )
328 } else if roll < 0.95 {
329 (
330 ExpenseCategory::Training,
331 100.0,
332 1500.0,
333 "Professional development".to_string(),
334 None,
335 )
336 } else {
337 (
338 ExpenseCategory::Other,
339 10.0,
340 200.0,
341 "Miscellaneous expense".to_string(),
342 None,
343 )
344 }
345 }
346
347 fn pick_violation(&mut self, item: &ExpenseLineItem) -> String {
349 let violations = match item.category {
350 ExpenseCategory::Meals => vec![
351 "Exceeds daily meal limit",
352 "Alcohol included without approval",
353 "Missing itemized receipt",
354 ],
355 ExpenseCategory::Travel => vec![
356 "Booked outside preferred vendor",
357 "Class upgrade not pre-approved",
358 "Booking made less than 7 days in advance",
359 ],
360 ExpenseCategory::Lodging => vec![
361 "Exceeds nightly rate limit",
362 "Extended stay without approval",
363 "Non-preferred hotel chain",
364 ],
365 _ => vec![
366 "Missing receipt",
367 "Insufficient business justification",
368 "Exceeds category spending limit",
369 ],
370 };
371
372 violations[self.rng.gen_range(0..violations.len())].to_string()
373 }
374
375 fn month_end(&self, date: NaiveDate) -> NaiveDate {
377 let (year, month) = if date.month() == 12 {
378 (date.year() + 1, 1)
379 } else {
380 (date.year(), date.month() + 1)
381 };
382 NaiveDate::from_ymd_opt(year, month, 1)
383 .unwrap_or(date)
384 .pred_opt()
385 .unwrap_or(date)
386 }
387
388 fn next_month_start(&self, date: NaiveDate) -> NaiveDate {
390 let (year, month) = if date.month() == 12 {
391 (date.year() + 1, 1)
392 } else {
393 (date.year(), date.month() + 1)
394 };
395 NaiveDate::from_ymd_opt(year, month, 1).unwrap_or(date)
396 }
397}
398
#[cfg(test)]
// Tests unwrap fixed, known-valid calendar dates; a panic there is a test bug.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Ten employee ids, `EMP-0001` through `EMP-0010`.
    fn test_employee_ids() -> Vec<String> {
        (1..=10).map(|i| format!("EMP-{:04}", i)).collect()
    }

    #[test]
    fn test_basic_expense_generation() {
        let mut gen = ExpenseReportGenerator::new(42);
        let employees = test_employee_ids();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 1, 31).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);

        assert!(!reports.is_empty());
        // One month: each employee gets at most one submission draw.
        assert!(
            reports.len() <= employees.len(),
            "Should not exceed employee count for a single month"
        );

        for report in &reports {
            assert!(!report.report_id.is_empty());
            assert!(!report.employee_id.is_empty());
            assert!(report.total_amount > Decimal::ZERO);
            assert!(!report.line_items.is_empty());
            assert!(report.line_items.len() <= 5);

            // The report total must equal the exact sum of its line items.
            let line_sum: Decimal = report.line_items.iter().map(|li| li.amount).sum();
            assert_eq!(report.total_amount, line_sum);

            for item in &report.line_items {
                assert!(!item.item_id.is_empty());
                assert!(item.amount > Decimal::ZERO);
                assert!(
                    (period_start..=period_end).contains(&item.date),
                    "line item date {} outside {}..={}",
                    item.date,
                    period_start,
                    period_end
                );
            }

            // Reports are submitted on or after their latest expense.
            let latest = report.line_items.iter().map(|li| li.date).max().unwrap();
            assert!(report.submission_date >= latest);
        }
    }

    #[test]
    fn test_deterministic_expenses() {
        let employees = test_employee_ids();
        let period_start = NaiveDate::from_ymd_opt(2024, 3, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
        let config = ExpenseConfig::default();

        let mut gen1 = ExpenseReportGenerator::new(42);
        let reports1 = gen1.generate(&employees, period_start, period_end, &config);

        let mut gen2 = ExpenseReportGenerator::new(42);
        let reports2 = gen2.generate(&employees, period_start, period_end, &config);

        assert_eq!(reports1.len(), reports2.len());
        for (a, b) in reports1.iter().zip(reports2.iter()) {
            assert_eq!(a.report_id, b.report_id);
            assert_eq!(a.employee_id, b.employee_id);
            assert_eq!(a.submission_date, b.submission_date);
            assert_eq!(a.total_amount, b.total_amount);
            assert_eq!(a.status, b.status);
            assert_eq!(a.policy_violations, b.policy_violations);
            assert_eq!(a.line_items.len(), b.line_items.len());
            for (x, y) in a.line_items.iter().zip(&b.line_items) {
                assert_eq!(x.item_id, y.item_id);
                assert_eq!(x.date, y.date);
                assert_eq!(x.amount, y.amount);
            }
        }
    }

    #[test]
    fn test_expense_status_and_violations() {
        let mut gen = ExpenseReportGenerator::new(99);
        let employees: Vec<String> = (1..=30).map(|i| format!("EMP-{:04}", i)).collect();
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 6, 30).unwrap();
        let config = ExpenseConfig::default();

        let reports = gen.generate(&employees, period_start, period_end, &config);

        // 30 employees over 6 months should yield plenty of reports.
        assert!(
            reports.len() > 10,
            "Expected multiple reports, got {}",
            reports.len()
        );

        let count = |status: ExpenseStatus| reports.iter().filter(|r| r.status == status).count();
        let approved = count(ExpenseStatus::Approved);
        let paid = count(ExpenseStatus::Paid);
        let submitted = count(ExpenseStatus::Submitted);
        let rejected = count(ExpenseStatus::Rejected);
        let draft = count(ExpenseStatus::Draft);

        assert!(approved > 0, "Expected at least some approved reports");
        assert!(
            paid + submitted + rejected + draft > 0,
            "Expected a mix of statuses beyond approved"
        );
        // The generator only emits these five statuses.
        assert_eq!(approved + paid + submitted + rejected + draft, reports.len());

        // Approval and payment data must be present together and chronological.
        for report in &reports {
            assert_eq!(report.approved_by.is_some(), report.approved_date.is_some());
            if let Some(approved_on) = report.approved_date {
                assert!(approved_on > report.submission_date);
            }
            if let Some(paid_on) = report.paid_date {
                let approved_on = report.approved_date.expect("paid reports must be approved");
                assert!(paid_on > approved_on);
            }
        }

        let total_violations: usize = reports.iter().map(|r| r.policy_violations.len()).sum();
        assert!(
            total_violations > 0,
            "Expected at least some policy violations across {} reports",
            reports.len()
        );
    }
}