1use chrono::NaiveDate;
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rust_decimal::Decimal;
12use std::collections::{HashMap, HashSet};
13
14use datasynth_core::models::{TaxCode, TaxLine, TaxableDocumentType};
15
/// Settings that control how [`TaxLineGenerator`] decides which tax lines to emit.
#[derive(Debug, Clone)]
pub struct TaxLineGeneratorConfig {
    /// Product categories for which no tax line is generated.
    ///
    /// The generator compares these against a document's category ignoring
    /// ASCII case.
    pub exempt_categories: Vec<String>,
    /// Country codes treated as EU members.
    ///
    /// A document whose seller and buyer are two *different* members of this
    /// set is handled as a reverse-charge transaction.
    pub eu_countries: HashSet<String>,
}
28
29impl Default for TaxLineGeneratorConfig {
30 fn default() -> Self {
31 Self {
32 exempt_categories: Vec::new(),
33 eu_countries: HashSet::from([
34 "DE".into(),
35 "FR".into(),
36 "IT".into(),
37 "ES".into(),
38 "NL".into(),
39 "BE".into(),
40 "AT".into(),
41 "PT".into(),
42 "IE".into(),
43 "FI".into(),
44 "SE".into(),
45 "DK".into(),
46 "PL".into(),
47 "CZ".into(),
48 "RO".into(),
49 "HU".into(),
50 "BG".into(),
51 "HR".into(),
52 "SK".into(),
53 "SI".into(),
54 "LT".into(),
55 "LV".into(),
56 "EE".into(),
57 "CY".into(),
58 "LU".into(),
59 "MT".into(),
60 "GR".into(),
61 ]),
62 }
63 }
64}
65
66pub struct TaxLineGenerator {
101 rng: ChaCha8Rng,
102 tax_codes_by_jurisdiction: HashMap<String, Vec<TaxCode>>,
104 config: TaxLineGeneratorConfig,
105 counter: u64,
106}
107
108impl TaxLineGenerator {
109 pub fn new(seed: u64, tax_codes: Vec<TaxCode>, config: TaxLineGeneratorConfig) -> Self {
113 let mut tax_codes_by_jurisdiction: HashMap<String, Vec<TaxCode>> = HashMap::new();
114 for code in tax_codes {
115 tax_codes_by_jurisdiction
116 .entry(code.jurisdiction_id.clone())
117 .or_default()
118 .push(code);
119 }
120
121 Self {
122 rng: ChaCha8Rng::seed_from_u64(seed),
123 tax_codes_by_jurisdiction,
124 config,
125 counter: 0,
126 }
127 }
128
129 pub fn generate_for_document(
137 &mut self,
138 doc_type: TaxableDocumentType,
139 doc_id: &str,
140 seller_country: &str,
141 buyer_country: &str,
142 taxable_amount: Decimal,
143 date: NaiveDate,
144 product_category: Option<&str>,
145 ) -> Vec<TaxLine> {
146 if let Some(cat) = product_category {
148 if self
149 .config
150 .exempt_categories
151 .iter()
152 .any(|e| e.eq_ignore_ascii_case(cat))
153 {
154 return Vec::new();
155 }
156 }
157
158 let jurisdiction_country = match doc_type {
160 TaxableDocumentType::VendorInvoice => seller_country,
161 TaxableDocumentType::CustomerInvoice => {
162 buyer_country
165 }
166 TaxableDocumentType::JournalEntry => seller_country,
167 _ => seller_country,
169 };
170
171 let is_eu_cross_border = seller_country != buyer_country
173 && self.config.eu_countries.contains(seller_country)
174 && self.config.eu_countries.contains(buyer_country);
175
176 if is_eu_cross_border {
177 return self.generate_reverse_charge_line(
178 doc_type,
179 doc_id,
180 buyer_country,
181 taxable_amount,
182 date,
183 );
184 }
185
186 let jurisdiction_id = self.resolve_jurisdiction_id(jurisdiction_country);
188
189 let tax_code = match self.find_standard_code(&jurisdiction_id, date) {
191 Some(code) => code,
192 None => return Vec::new(), };
194
195 let tax_amount = tax_code.tax_amount(taxable_amount);
197 let is_deductible = matches!(doc_type, TaxableDocumentType::VendorInvoice);
198
199 let line = self.build_tax_line(
200 doc_type,
201 doc_id,
202 &tax_code.id,
203 &jurisdiction_id,
204 taxable_amount,
205 tax_amount,
206 is_deductible,
207 false, false, );
210
211 vec![line]
212 }
213
214 pub fn generate_batch(
218 &mut self,
219 doc_type: TaxableDocumentType,
220 documents: &[(String, String, String, Decimal, NaiveDate, Option<String>)],
221 ) -> Vec<TaxLine> {
222 let mut result = Vec::new();
223 for (doc_id, seller, buyer, amount, date, category) in documents {
224 let lines = self.generate_for_document(
225 doc_type,
226 doc_id,
227 seller,
228 buyer,
229 *amount,
230 *date,
231 category.as_deref(),
232 );
233 result.extend(lines);
234 }
235 result
236 }
237
238 fn generate_reverse_charge_line(
246 &mut self,
247 doc_type: TaxableDocumentType,
248 doc_id: &str,
249 buyer_country: &str,
250 taxable_amount: Decimal,
251 date: NaiveDate,
252 ) -> Vec<TaxLine> {
253 let buyer_jurisdiction_id = self.resolve_jurisdiction_id(buyer_country);
254
255 let tax_code = match self.find_standard_code(&buyer_jurisdiction_id, date) {
256 Some(code) => code,
257 None => return Vec::new(),
258 };
259
260 let tax_amount = tax_code.tax_amount(taxable_amount);
261 let is_deductible = matches!(doc_type, TaxableDocumentType::VendorInvoice);
262
263 let line = self.build_tax_line(
264 doc_type,
265 doc_id,
266 &tax_code.id,
267 &buyer_jurisdiction_id,
268 taxable_amount,
269 tax_amount,
270 is_deductible,
271 true, true, );
274
275 vec![line]
276 }
277
278 fn resolve_jurisdiction_id(&self, country_or_state: &str) -> String {
283 if let Some(state_code) = country_or_state.strip_prefix("US-") {
284 format!("JUR-US-{state_code}")
286 } else {
287 format!("JUR-{country_or_state}")
288 }
289 }
290
291 fn find_standard_code(&self, jurisdiction_id: &str, date: NaiveDate) -> Option<TaxCode> {
298 let codes = self.tax_codes_by_jurisdiction.get(jurisdiction_id)?;
299
300 let mut candidates: Vec<&TaxCode> = codes
302 .iter()
303 .filter(|c| c.is_active(date) && !c.is_exempt)
304 .collect();
305
306 if candidates.is_empty() {
307 return None;
308 }
309
310 candidates.sort_by(|a, b| b.rate.cmp(&a.rate));
312
313 Some(candidates[0].clone())
314 }
315
316 #[allow(clippy::too_many_arguments)]
318 fn build_tax_line(
319 &mut self,
320 doc_type: TaxableDocumentType,
321 doc_id: &str,
322 tax_code_id: &str,
323 jurisdiction_id: &str,
324 taxable_amount: Decimal,
325 tax_amount: Decimal,
326 is_deductible: bool,
327 is_reverse_charge: bool,
328 is_self_assessed: bool,
329 ) -> TaxLine {
330 self.counter += 1;
331 let line_id = format!("TXLN-{:06}", self.counter);
332
333 let _noise: f64 = self.rng.gen();
335
336 TaxLine::new(
337 line_id,
338 doc_type,
339 doc_id,
340 1, tax_code_id,
342 jurisdiction_id,
343 taxable_amount,
344 tax_amount,
345 )
346 .with_deductible(is_deductible)
347 .with_reverse_charge(is_reverse_charge)
348 .with_self_assessed(is_self_assessed)
349 }
350}
351
#[cfg(test)]
// Tests unwrap on fixtures that are known to be valid.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::tax::TaxCodeGenerator;
    use datasynth_config::schema::TaxConfig;
    use rust_decimal_macros::dec;

    /// Generates tax codes for DE, FR, GB and US including sub-national codes.
    fn make_tax_codes() -> Vec<TaxCode> {
        let mut tax_config = TaxConfig::default();
        tax_config.jurisdictions.countries =
            vec!["DE".into(), "FR".into(), "GB".into(), "US".into()];
        tax_config.jurisdictions.include_subnational = true;

        let mut code_gen = TaxCodeGenerator::with_config(42, tax_config);
        let (_jurisdictions, codes) = code_gen.generate();
        codes
    }

    /// Fixed reference date shared by every test.
    fn test_date() -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 6, 15).unwrap()
    }

    /// Builds a generator over the fixture codes using the given config.
    fn generator_with(seed: u64, config: TaxLineGeneratorConfig) -> TaxLineGenerator {
        TaxLineGenerator::new(seed, make_tax_codes(), config)
    }

    /// Builds a generator over the fixture codes using the default config.
    fn default_generator(seed: u64) -> TaxLineGenerator {
        generator_with(seed, TaxLineGeneratorConfig::default())
    }

    /// Generates lines for a domestic German vendor invoice without a category.
    fn domestic_de_vendor_invoice(
        generator: &mut TaxLineGenerator,
        doc_id: &str,
        amount: Decimal,
    ) -> Vec<TaxLine> {
        generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            doc_id,
            "DE",
            "DE",
            amount,
            test_date(),
            None,
        )
    }

    #[test]
    fn test_domestic_vendor_invoice() {
        let mut generator = default_generator(42);

        let lines = domestic_de_vendor_invoice(&mut generator, "INV-001", dec!(10000));

        assert_eq!(lines.len(), 1, "Should produce one tax line");
        let line = &lines[0];
        assert_eq!(line.document_id, "INV-001");
        assert_eq!(line.jurisdiction_id, "JUR-DE");
        assert_eq!(line.tax_amount, dec!(1900.00));
        assert_eq!(line.taxable_amount, dec!(10000));
        assert!(line.is_deductible, "Vendor invoice input VAT is deductible");
        assert!(!line.is_reverse_charge);
        assert!(!line.is_self_assessed);
    }

    #[test]
    fn test_domestic_customer_invoice() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::CustomerInvoice,
            "CINV-001",
            "DE",
            "DE",
            dec!(5000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1);
        let line = &lines[0];
        assert_eq!(line.document_id, "CINV-001");
        assert_eq!(line.jurisdiction_id, "JUR-DE");
        assert_eq!(line.tax_amount, dec!(950.00));
        assert!(
            !line.is_deductible,
            "Customer invoice output VAT is not deductible"
        );
        assert!(!line.is_reverse_charge);
    }

    #[test]
    fn test_eu_cross_border_reverse_charge() {
        let mut generator = default_generator(42);

        // German seller, French buyer: both are in the default EU set.
        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-EU-001",
            "DE",
            "FR",
            dec!(20000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1, "Should produce one reverse-charge line");
        let line = &lines[0];
        assert_eq!(line.document_id, "INV-EU-001");
        assert_eq!(line.jurisdiction_id, "JUR-FR");
        assert_eq!(line.tax_amount, dec!(4000.00));
        assert!(line.is_reverse_charge, "Should be reverse charge");
        assert!(line.is_self_assessed, "Buyer should self-assess");
        assert!(
            line.is_deductible,
            "Vendor invoice reverse charge is still deductible"
        );
    }

    #[test]
    fn test_exempt_category() {
        let mut generator = generator_with(
            42,
            TaxLineGeneratorConfig {
                exempt_categories: vec!["financial_services".into(), "education".into()],
                ..Default::default()
            },
        );

        let exact_match = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-EXEMPT",
            "DE",
            "DE",
            dec!(50000),
            test_date(),
            Some("financial_services"),
        );
        assert!(
            exact_match.is_empty(),
            "Exempt category should produce no tax lines"
        );

        let upper_case_match = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-EXEMPT-2",
            "DE",
            "DE",
            dec!(50000),
            test_date(),
            Some("FINANCIAL_SERVICES"),
        );
        assert!(
            upper_case_match.is_empty(),
            "Exempt category check should be case-insensitive"
        );
    }

    #[test]
    fn test_non_eu_cross_border() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-XBORDER",
            "US",
            "DE",
            dec!(10000),
            test_date(),
            None,
        );

        // `any` on an empty result is false, so an empty result passes too.
        assert!(
            !lines.iter().any(|l| l.is_reverse_charge),
            "Non-EU cross-border should NOT use reverse charge"
        );
    }

    #[test]
    fn test_us_sales_tax() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::CustomerInvoice,
            "CINV-US-001",
            "US",
            "US-CA",
            dec!(1000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1, "Should produce one sales tax line");
        let line = &lines[0];
        assert_eq!(line.jurisdiction_id, "JUR-US-CA");
        assert_eq!(line.tax_amount, dec!(72.50));
        assert!(!line.is_deductible, "Customer invoice not deductible");
    }

    #[test]
    fn test_no_matching_code() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-UNKNOWN",
            "ZZ",
            "ZZ",
            dec!(10000),
            test_date(),
            None,
        );

        assert!(
            lines.is_empty(),
            "Unknown jurisdiction should produce no tax lines"
        );
    }

    #[test]
    fn test_batch_generation() {
        let mut generator = default_generator(42);
        let date = test_date();

        // (document id, domestic country, taxable amount, expected tax)
        let cases = [
            ("INV-B1", "DE", dec!(1000), dec!(190.00)),
            ("INV-B2", "FR", dec!(2000), dec!(400.00)),
            ("INV-B3", "GB", dec!(3000), dec!(600.00)),
        ];

        let documents: Vec<(String, String, String, Decimal, NaiveDate, Option<String>)> = cases
            .iter()
            .map(|(doc_id, country, amount, _)| {
                (
                    doc_id.to_string(),
                    country.to_string(),
                    country.to_string(),
                    *amount,
                    date,
                    None,
                )
            })
            .collect();

        let lines = generator.generate_batch(TaxableDocumentType::VendorInvoice, &documents);

        assert_eq!(lines.len(), 3, "Should produce one line per document");

        let doc_ids: Vec<&str> = lines.iter().map(|l| l.document_id.as_str()).collect();
        for (doc_id, _, _, _) in cases.iter() {
            assert!(doc_ids.contains(doc_id));
        }

        for (doc_id, _, _, expected_tax) in cases.iter() {
            let line = lines.iter().find(|l| l.document_id == *doc_id).unwrap();
            assert_eq!(line.tax_amount, *expected_tax);
        }
    }

    #[test]
    fn test_deterministic() {
        let mut first = default_generator(999);
        let mut second = default_generator(999);

        let lines1 = domestic_de_vendor_invoice(&mut first, "INV-DET", dec!(5000));
        let lines2 = domestic_de_vendor_invoice(&mut second, "INV-DET", dec!(5000));

        assert_eq!(lines1.len(), lines2.len());
        for (l1, l2) in lines1.iter().zip(lines2.iter()) {
            assert_eq!(l1.id, l2.id);
            assert_eq!(l1.tax_code_id, l2.tax_code_id);
            assert_eq!(l1.tax_amount, l2.tax_amount);
            assert_eq!(l1.jurisdiction_id, l2.jurisdiction_id);
            assert_eq!(l1.is_deductible, l2.is_deductible);
            assert_eq!(l1.is_reverse_charge, l2.is_reverse_charge);
        }
    }

    #[test]
    fn test_line_counter_increments() {
        let mut generator = default_generator(42);

        // (document id, taxable amount, expected line id)
        let cases = [
            ("INV-C1", dec!(1000), "TXLN-000001"),
            ("INV-C2", dec!(2000), "TXLN-000002"),
            ("INV-C3", dec!(3000), "TXLN-000003"),
        ];

        // Generate everything first, then assert, so all three calls always run.
        let generated: Vec<Vec<TaxLine>> = cases
            .iter()
            .map(|(doc_id, amount, _)| domestic_de_vendor_invoice(&mut generator, doc_id, *amount))
            .collect();

        for (lines, (_, _, expected_id)) in generated.iter().zip(cases.iter()) {
            assert_eq!(lines[0].id, *expected_id);
        }
    }
}