1use chrono::NaiveDate;
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rust_decimal::Decimal;
13use std::collections::{HashMap, HashSet};
14
15use datasynth_core::models::{TaxCode, TaxLine, TaxableDocumentType};
16
/// Configuration for [`TaxLineGenerator`].
#[derive(Debug, Clone)]
pub struct TaxLineGeneratorConfig {
    /// Product categories for which no tax lines are generated at all.
    /// Matching against a document's category is ASCII case-insensitive.
    pub exempt_categories: Vec<String>,
    /// Country codes treated as EU member states. A document whose seller and buyer
    /// are two *different* countries from this set is handled as a reverse charge.
    pub eu_countries: HashSet<String>,
}
29
30impl Default for TaxLineGeneratorConfig {
31 fn default() -> Self {
32 Self {
33 exempt_categories: Vec::new(),
34 eu_countries: HashSet::from([
35 "DE".into(),
36 "FR".into(),
37 "IT".into(),
38 "ES".into(),
39 "NL".into(),
40 "BE".into(),
41 "AT".into(),
42 "PT".into(),
43 "IE".into(),
44 "FI".into(),
45 "SE".into(),
46 "DK".into(),
47 "PL".into(),
48 "CZ".into(),
49 "RO".into(),
50 "HU".into(),
51 "BG".into(),
52 "HR".into(),
53 "SK".into(),
54 "SI".into(),
55 "LT".into(),
56 "LV".into(),
57 "EE".into(),
58 "CY".into(),
59 "LU".into(),
60 "MT".into(),
61 "GR".into(),
62 ]),
63 }
64 }
65}
66
/// Generates [`TaxLine`]s for taxable documents from a pool of [`TaxCode`]s.
///
/// Construct with [`TaxLineGenerator::new`]. Line ids are sequential per generator
/// instance (`TXLN-000001`, `TXLN-000002`, ...).
pub struct TaxLineGenerator {
    /// RNG seeded in `new` via `seeded_rng(seed, 0)`. One value is drawn per generated
    /// line, but that value does not currently affect the produced line.
    rng: ChaCha8Rng,
    /// Tax codes grouped by their `jurisdiction_id` (ids of the form `JUR-DE`, `JUR-US-CA`).
    tax_codes_by_jurisdiction: HashMap<String, Vec<TaxCode>>,
    /// Exempt categories and EU membership used to choose the tax treatment.
    config: TaxLineGeneratorConfig,
    /// Number of lines generated so far; the source of the sequential line ids.
    counter: u64,
}
108
109impl TaxLineGenerator {
110 pub fn new(config: TaxLineGeneratorConfig, tax_codes: Vec<TaxCode>, seed: u64) -> Self {
114 let mut tax_codes_by_jurisdiction: HashMap<String, Vec<TaxCode>> = HashMap::new();
115 for code in tax_codes {
116 tax_codes_by_jurisdiction
117 .entry(code.jurisdiction_id.clone())
118 .or_default()
119 .push(code);
120 }
121
122 Self {
123 rng: seeded_rng(seed, 0),
124 tax_codes_by_jurisdiction,
125 config,
126 counter: 0,
127 }
128 }
129
130 pub fn generate_for_document(
138 &mut self,
139 doc_type: TaxableDocumentType,
140 doc_id: &str,
141 seller_country: &str,
142 buyer_country: &str,
143 taxable_amount: Decimal,
144 date: NaiveDate,
145 product_category: Option<&str>,
146 ) -> Vec<TaxLine> {
147 if let Some(cat) = product_category {
149 if self
150 .config
151 .exempt_categories
152 .iter()
153 .any(|e| e.eq_ignore_ascii_case(cat))
154 {
155 return Vec::new();
156 }
157 }
158
159 let jurisdiction_country = match doc_type {
161 TaxableDocumentType::VendorInvoice => seller_country,
162 TaxableDocumentType::CustomerInvoice => {
163 buyer_country
166 }
167 TaxableDocumentType::JournalEntry => seller_country,
168 _ => seller_country,
170 };
171
172 let is_eu_cross_border = seller_country != buyer_country
174 && self.config.eu_countries.contains(seller_country)
175 && self.config.eu_countries.contains(buyer_country);
176
177 if is_eu_cross_border {
178 return self.generate_reverse_charge_line(
179 doc_type,
180 doc_id,
181 buyer_country,
182 taxable_amount,
183 date,
184 );
185 }
186
187 let jurisdiction_id = self.resolve_jurisdiction_id(jurisdiction_country);
189
190 let tax_code = match self.find_standard_code(&jurisdiction_id, date) {
192 Some(code) => code,
193 None => return Vec::new(), };
195
196 let tax_amount = tax_code.tax_amount(taxable_amount);
198 let is_deductible = matches!(doc_type, TaxableDocumentType::VendorInvoice);
199
200 let line = self.build_tax_line(
201 doc_type,
202 doc_id,
203 &tax_code.id,
204 &jurisdiction_id,
205 taxable_amount,
206 tax_amount,
207 is_deductible,
208 false, false, );
211
212 vec![line]
213 }
214
215 pub fn generate_batch(
219 &mut self,
220 doc_type: TaxableDocumentType,
221 documents: &[(String, String, String, Decimal, NaiveDate, Option<String>)],
222 ) -> Vec<TaxLine> {
223 let mut result = Vec::new();
224 for (doc_id, seller, buyer, amount, date, category) in documents {
225 let lines = self.generate_for_document(
226 doc_type,
227 doc_id,
228 seller,
229 buyer,
230 *amount,
231 *date,
232 category.as_deref(),
233 );
234 result.extend(lines);
235 }
236 result
237 }
238
239 fn generate_reverse_charge_line(
247 &mut self,
248 doc_type: TaxableDocumentType,
249 doc_id: &str,
250 buyer_country: &str,
251 taxable_amount: Decimal,
252 date: NaiveDate,
253 ) -> Vec<TaxLine> {
254 let buyer_jurisdiction_id = self.resolve_jurisdiction_id(buyer_country);
255
256 let tax_code = match self.find_standard_code(&buyer_jurisdiction_id, date) {
257 Some(code) => code,
258 None => return Vec::new(),
259 };
260
261 let tax_amount = tax_code.tax_amount(taxable_amount);
262 let is_deductible = matches!(doc_type, TaxableDocumentType::VendorInvoice);
263
264 let line = self.build_tax_line(
265 doc_type,
266 doc_id,
267 &tax_code.id,
268 &buyer_jurisdiction_id,
269 taxable_amount,
270 tax_amount,
271 is_deductible,
272 true, true, );
275
276 vec![line]
277 }
278
279 fn resolve_jurisdiction_id(&self, country_or_state: &str) -> String {
284 if let Some(state_code) = country_or_state.strip_prefix("US-") {
285 format!("JUR-US-{state_code}")
287 } else {
288 format!("JUR-{country_or_state}")
289 }
290 }
291
292 fn find_standard_code(&self, jurisdiction_id: &str, date: NaiveDate) -> Option<TaxCode> {
299 let codes = self.tax_codes_by_jurisdiction.get(jurisdiction_id)?;
300
301 let mut candidates: Vec<&TaxCode> = codes
303 .iter()
304 .filter(|c| c.is_active(date) && !c.is_exempt)
305 .collect();
306
307 if candidates.is_empty() {
308 return None;
309 }
310
311 candidates.sort_by(|a, b| b.rate.cmp(&a.rate));
313
314 Some(candidates[0].clone())
315 }
316
317 #[allow(clippy::too_many_arguments)]
319 fn build_tax_line(
320 &mut self,
321 doc_type: TaxableDocumentType,
322 doc_id: &str,
323 tax_code_id: &str,
324 jurisdiction_id: &str,
325 taxable_amount: Decimal,
326 tax_amount: Decimal,
327 is_deductible: bool,
328 is_reverse_charge: bool,
329 is_self_assessed: bool,
330 ) -> TaxLine {
331 self.counter += 1;
332 let line_id = format!("TXLN-{:06}", self.counter);
333
334 let _noise: f64 = self.rng.random();
336
337 TaxLine::new(
338 line_id,
339 doc_type,
340 doc_id,
341 1, tax_code_id,
343 jurisdiction_id,
344 taxable_amount,
345 tax_amount,
346 )
347 .with_deductible(is_deductible)
348 .with_reverse_charge(is_reverse_charge)
349 .with_self_assessed(is_self_assessed)
350 }
351}
352
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::tax::TaxCodeGenerator;
    use datasynth_config::schema::TaxConfig;
    use rust_decimal_macros::dec;

    /// Tax codes for DE, FR, GB and US (with sub-national jurisdictions enabled),
    /// generated from a fixed seed.
    fn make_tax_codes() -> Vec<TaxCode> {
        let mut tax_config = TaxConfig::default();
        tax_config.jurisdictions.countries =
            vec!["DE".into(), "FR".into(), "GB".into(), "US".into()];
        tax_config.jurisdictions.include_subnational = true;

        let mut code_generator = TaxCodeGenerator::with_config(42, tax_config);
        let (_jurisdictions, codes) = code_generator.generate();
        codes
    }

    /// Document date shared by every test.
    fn test_date() -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 6, 15).unwrap()
    }

    /// Generator with the default configuration over [`make_tax_codes`].
    fn default_generator(seed: u64) -> TaxLineGenerator {
        TaxLineGenerator::new(TaxLineGeneratorConfig::default(), make_tax_codes(), seed)
    }

    #[test]
    fn test_domestic_vendor_invoice() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-001",
            "DE", // seller
            "DE", // buyer
            dec!(10000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1, "Should produce one tax line");
        let line = lines.first().unwrap();
        assert_eq!(line.document_id, "INV-001");
        assert_eq!(line.jurisdiction_id, "JUR-DE");
        // 19 % of 10 000
        assert_eq!(line.tax_amount, dec!(1900.00));
        assert_eq!(line.taxable_amount, dec!(10000));
        assert!(line.is_deductible, "Vendor invoice input VAT is deductible");
        assert!(!line.is_reverse_charge);
        assert!(!line.is_self_assessed);
    }

    #[test]
    fn test_domestic_customer_invoice() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::CustomerInvoice,
            "CINV-001",
            "DE", // seller
            "DE", // buyer
            dec!(5000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1);
        let line = lines.first().unwrap();
        assert_eq!(line.document_id, "CINV-001");
        assert_eq!(line.jurisdiction_id, "JUR-DE");
        // 19 % of 5 000
        assert_eq!(line.tax_amount, dec!(950.00));
        assert!(
            !line.is_deductible,
            "Customer invoice output VAT is not deductible"
        );
        assert!(!line.is_reverse_charge);
    }

    #[test]
    fn test_eu_cross_border_reverse_charge() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-EU-001",
            "DE", // seller
            "FR", // buyer
            dec!(20000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1, "Should produce one reverse-charge line");
        let line = lines.first().unwrap();
        assert_eq!(line.document_id, "INV-EU-001");
        // The buyer's jurisdiction and rate apply (20 % of 20 000).
        assert_eq!(line.jurisdiction_id, "JUR-FR");
        assert_eq!(line.tax_amount, dec!(4000.00));
        assert!(line.is_reverse_charge, "Should be reverse charge");
        assert!(line.is_self_assessed, "Buyer should self-assess");
        assert!(
            line.is_deductible,
            "Vendor invoice reverse charge is still deductible"
        );
    }

    #[test]
    fn test_exempt_category() {
        let config = TaxLineGeneratorConfig {
            exempt_categories: vec!["financial_services".into(), "education".into()],
            ..Default::default()
        };
        let mut generator = TaxLineGenerator::new(config, make_tax_codes(), 42);

        let exact_case = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-EXEMPT",
            "DE",
            "DE",
            dec!(50000),
            test_date(),
            Some("financial_services"),
        );
        assert!(
            exact_case.is_empty(),
            "Exempt category should produce no tax lines"
        );

        let upper_case = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-EXEMPT-2",
            "DE",
            "DE",
            dec!(50000),
            test_date(),
            Some("FINANCIAL_SERVICES"),
        );
        assert!(
            upper_case.is_empty(),
            "Exempt category check should be case-insensitive"
        );
    }

    #[test]
    fn test_non_eu_cross_border() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-XBORDER",
            "US", // seller
            "DE", // buyer
            dec!(10000),
            test_date(),
            None,
        );

        let no_reverse_charge = lines.iter().all(|l| !l.is_reverse_charge);
        assert!(
            lines.is_empty() || no_reverse_charge,
            "Non-EU cross-border should NOT use reverse charge"
        );
    }

    #[test]
    fn test_us_sales_tax() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::CustomerInvoice,
            "CINV-US-001",
            "US",    // seller
            "US-CA", // buyer (state-level jurisdiction)
            dec!(1000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1, "Should produce one sales tax line");
        let line = lines.first().unwrap();
        assert_eq!(line.jurisdiction_id, "JUR-US-CA");
        // 7.25 % of 1 000
        assert_eq!(line.tax_amount, dec!(72.50));
        assert!(!line.is_deductible, "Customer invoice not deductible");
    }

    #[test]
    fn test_no_matching_code() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-UNKNOWN",
            "ZZ", // no tax codes exist for this country
            "ZZ",
            dec!(10000),
            test_date(),
            None,
        );

        assert!(
            lines.is_empty(),
            "Unknown jurisdiction should produce no tax lines"
        );
    }

    #[test]
    fn test_batch_generation() {
        let mut generator = default_generator(42);
        let date = test_date();

        // Domestic documents: seller and buyer share the same country.
        let documents: Vec<(String, String, String, Decimal, NaiveDate, Option<String>)> = [
            ("INV-B1", "DE", dec!(1000)),
            ("INV-B2", "FR", dec!(2000)),
            ("INV-B3", "GB", dec!(3000)),
        ]
        .iter()
        .map(|&(doc_id, country, amount)| {
            (
                doc_id.to_string(),
                country.to_string(),
                country.to_string(),
                amount,
                date,
                None,
            )
        })
        .collect();

        let lines = generator.generate_batch(TaxableDocumentType::VendorInvoice, &documents);

        assert_eq!(lines.len(), 3, "Should produce one line per document");

        let expectations = [
            ("INV-B1", dec!(190.00)),
            ("INV-B2", dec!(400.00)),
            ("INV-B3", dec!(600.00)),
        ];
        for (doc_id, expected_tax) in expectations {
            let line = lines.iter().find(|l| l.document_id == doc_id).unwrap();
            assert_eq!(line.tax_amount, expected_tax);
        }
    }

    #[test]
    fn test_deterministic() {
        let date = test_date();
        let run = |seed: u64| {
            let mut generator = default_generator(seed);
            generator.generate_for_document(
                TaxableDocumentType::VendorInvoice,
                "INV-DET",
                "DE",
                "DE",
                dec!(5000),
                date,
                None,
            )
        };

        let first = run(999);
        let second = run(999);

        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(&second) {
            assert_eq!(a.id, b.id);
            assert_eq!(a.tax_code_id, b.tax_code_id);
            assert_eq!(a.tax_amount, b.tax_amount);
            assert_eq!(a.jurisdiction_id, b.jurisdiction_id);
            assert_eq!(a.is_deductible, b.is_deductible);
            assert_eq!(a.is_reverse_charge, b.is_reverse_charge);
        }
    }

    #[test]
    fn test_line_counter_increments() {
        let mut generator = default_generator(42);
        let date = test_date();

        let cases = [
            ("INV-C1", dec!(1000), "TXLN-000001"),
            ("INV-C2", dec!(2000), "TXLN-000002"),
            ("INV-C3", dec!(3000), "TXLN-000003"),
        ];
        for (doc_id, amount, expected_id) in cases {
            let lines = generator.generate_for_document(
                TaxableDocumentType::VendorInvoice,
                doc_id,
                "DE",
                "DE",
                amount,
                date,
                None,
            );
            assert_eq!(lines[0].id, expected_id);
        }
    }
}