1use chrono::NaiveDate;
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rust_decimal::Decimal;
13use std::collections::{HashMap, HashSet};
14
15use datasynth_core::models::{TaxCode, TaxLine, TaxableDocumentType};
16
/// Settings that steer how tax lines are generated.
#[derive(Debug, Clone)]
pub struct TaxLineGeneratorConfig {
    /// Product categories for which no tax line is produced.
    /// The generator compares these ASCII-case-insensitively.
    pub exempt_categories: Vec<String>,
    /// Country codes treated as EU members when detecting EU cross-border
    /// transactions.
    pub eu_countries: HashSet<String>,
}

impl Default for TaxLineGeneratorConfig {
    /// Returns a configuration with no exempt categories and the 27 EU
    /// member-state country codes.
    fn default() -> Self {
        const EU_MEMBER_CODES: [&str; 27] = [
            "DE", "FR", "IT", "ES", "NL", "BE", "AT", "PT", "IE", "FI", "SE", "DK", "PL", "CZ",
            "RO", "HU", "BG", "HR", "SK", "SI", "LT", "LV", "EE", "CY", "LU", "MT", "GR",
        ];

        let eu_countries: HashSet<String> = EU_MEMBER_CODES
            .iter()
            .map(|code| (*code).to_string())
            .collect();

        Self {
            exempt_categories: Vec::new(),
            eu_countries,
        }
    }
}
66
/// Generates [`TaxLine`] records for taxable documents.
///
/// Holds the tax codes grouped by jurisdiction, the generator configuration,
/// a seeded RNG and a running counter used to number the produced lines.
pub struct TaxLineGenerator {
    /// RNG created in `new` via `seeded_rng(seed, 0)`; one value is drawn per
    /// built tax line.
    rng: ChaCha8Rng,
    /// Tax codes keyed by their `jurisdiction_id`.
    tax_codes_by_jurisdiction: HashMap<String, Vec<TaxCode>>,
    /// Exempt categories and the EU country set consulted during generation.
    config: TaxLineGeneratorConfig,
    /// Number of tax lines built so far; formatted into `TXLN-NNNNNN` ids.
    counter: u64,
}
108
109impl TaxLineGenerator {
110 pub fn new(config: TaxLineGeneratorConfig, tax_codes: Vec<TaxCode>, seed: u64) -> Self {
114 let mut tax_codes_by_jurisdiction: HashMap<String, Vec<TaxCode>> = HashMap::new();
115 for code in tax_codes {
116 tax_codes_by_jurisdiction
117 .entry(code.jurisdiction_id.clone())
118 .or_default()
119 .push(code);
120 }
121
122 Self {
123 rng: seeded_rng(seed, 0),
124 tax_codes_by_jurisdiction,
125 config,
126 counter: 0,
127 }
128 }
129
130 pub fn generate_for_document(
138 &mut self,
139 doc_type: TaxableDocumentType,
140 doc_id: &str,
141 seller_country: &str,
142 buyer_country: &str,
143 taxable_amount: Decimal,
144 date: NaiveDate,
145 product_category: Option<&str>,
146 ) -> Vec<TaxLine> {
147 if let Some(cat) = product_category {
149 if self
150 .config
151 .exempt_categories
152 .iter()
153 .any(|e| e.eq_ignore_ascii_case(cat))
154 {
155 return Vec::new();
156 }
157 }
158
159 let jurisdiction_country = match doc_type {
161 TaxableDocumentType::VendorInvoice => seller_country,
162 TaxableDocumentType::CustomerInvoice => {
163 buyer_country
166 }
167 TaxableDocumentType::JournalEntry => seller_country,
168 _ => seller_country,
170 };
171
172 let is_eu_cross_border = seller_country != buyer_country
174 && self.config.eu_countries.contains(seller_country)
175 && self.config.eu_countries.contains(buyer_country);
176
177 if is_eu_cross_border {
178 return self.generate_reverse_charge_line(
179 doc_type,
180 doc_id,
181 buyer_country,
182 taxable_amount,
183 date,
184 );
185 }
186
187 let jurisdiction_id = self.resolve_jurisdiction_id(jurisdiction_country);
189
190 let tax_code = match self.find_standard_code(&jurisdiction_id, date) {
192 Some(code) => code,
193 None => return Vec::new(), };
195
196 let tax_amount = tax_code.tax_amount(taxable_amount);
198 let is_deductible = matches!(doc_type, TaxableDocumentType::VendorInvoice);
199
200 let line = self.build_tax_line(
201 doc_type,
202 doc_id,
203 &tax_code.id,
204 &jurisdiction_id,
205 taxable_amount,
206 tax_amount,
207 is_deductible,
208 false, false, );
211
212 vec![line]
213 }
214
215 pub fn generate_batch(
219 &mut self,
220 doc_type: TaxableDocumentType,
221 documents: &[(String, String, String, Decimal, NaiveDate, Option<String>)],
222 ) -> Vec<TaxLine> {
223 let mut result = Vec::new();
224 for (doc_id, seller, buyer, amount, date, category) in documents {
225 let lines = self.generate_for_document(
226 doc_type,
227 doc_id,
228 seller,
229 buyer,
230 *amount,
231 *date,
232 category.as_deref(),
233 );
234 result.extend(lines);
235 }
236 result
237 }
238
239 fn generate_reverse_charge_line(
247 &mut self,
248 doc_type: TaxableDocumentType,
249 doc_id: &str,
250 buyer_country: &str,
251 taxable_amount: Decimal,
252 date: NaiveDate,
253 ) -> Vec<TaxLine> {
254 let buyer_jurisdiction_id = self.resolve_jurisdiction_id(buyer_country);
255
256 let tax_code = match self.find_standard_code(&buyer_jurisdiction_id, date) {
257 Some(code) => code,
258 None => return Vec::new(),
259 };
260
261 let tax_amount = tax_code.tax_amount(taxable_amount);
262 let is_deductible = matches!(doc_type, TaxableDocumentType::VendorInvoice);
263
264 let line = self.build_tax_line(
265 doc_type,
266 doc_id,
267 &tax_code.id,
268 &buyer_jurisdiction_id,
269 taxable_amount,
270 tax_amount,
271 is_deductible,
272 true, true, );
275
276 vec![line]
277 }
278
279 fn resolve_jurisdiction_id(&self, country_or_state: &str) -> String {
284 if let Some(state_code) = country_or_state.strip_prefix("US-") {
285 format!("JUR-US-{state_code}")
287 } else {
288 format!("JUR-{country_or_state}")
289 }
290 }
291
292 fn find_standard_code(&self, jurisdiction_id: &str, date: NaiveDate) -> Option<TaxCode> {
299 let codes = self.tax_codes_by_jurisdiction.get(jurisdiction_id)?;
300
301 let mut candidates: Vec<&TaxCode> = codes
303 .iter()
304 .filter(|c| c.is_active(date) && !c.is_exempt)
305 .collect();
306
307 if candidates.is_empty() {
308 return None;
309 }
310
311 candidates.sort_by_key(|b| std::cmp::Reverse(b.rate));
313
314 Some(candidates[0].clone())
315 }
316
317 #[allow(clippy::too_many_arguments)]
319 fn build_tax_line(
320 &mut self,
321 doc_type: TaxableDocumentType,
322 doc_id: &str,
323 tax_code_id: &str,
324 jurisdiction_id: &str,
325 taxable_amount: Decimal,
326 tax_amount: Decimal,
327 is_deductible: bool,
328 is_reverse_charge: bool,
329 is_self_assessed: bool,
330 ) -> TaxLine {
331 self.counter += 1;
332 let line_id = format!("TXLN-{:06}", self.counter);
333
334 let _noise: f64 = self.rng.random();
336
337 TaxLine::new(
338 line_id,
339 doc_type,
340 doc_id,
341 1, tax_code_id,
343 jurisdiction_id,
344 taxable_amount,
345 tax_amount,
346 )
347 .with_deductible(is_deductible)
348 .with_reverse_charge(is_reverse_charge)
349 .with_self_assessed(is_self_assessed)
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tax::TaxCodeGenerator;
    use datasynth_config::schema::TaxConfig;
    use rust_decimal_macros::dec;

    /// Produces the tax codes for DE, FR, GB and US (subnational jurisdictions
    /// included) using a code generator seeded with 42.
    fn make_tax_codes() -> Vec<TaxCode> {
        let mut tax_config = TaxConfig::default();
        tax_config.jurisdictions.countries =
            vec!["DE".into(), "FR".into(), "GB".into(), "US".into()];
        tax_config.jurisdictions.include_subnational = true;

        let mut code_generator = TaxCodeGenerator::with_config(42, tax_config);
        let (_jurisdictions, codes) = code_generator.generate();
        codes
    }

    /// Fixed document date shared by all tests.
    fn test_date() -> NaiveDate {
        NaiveDate::from_ymd_opt(2024, 6, 15).unwrap()
    }

    /// Builds a generator with the default configuration and the shared codes.
    fn default_generator(seed: u64) -> TaxLineGenerator {
        TaxLineGenerator::new(TaxLineGeneratorConfig::default(), make_tax_codes(), seed)
    }

    #[test]
    fn test_domestic_vendor_invoice() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-001",
            "DE",
            "DE",
            dec!(10000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1, "Should produce one tax line");
        let line = &lines[0];
        assert_eq!(line.document_id, "INV-001");
        assert_eq!(line.jurisdiction_id, "JUR-DE");
        assert_eq!(line.tax_amount, dec!(1900.00));
        assert_eq!(line.taxable_amount, dec!(10000));
        assert!(line.is_deductible, "Vendor invoice input VAT is deductible");
        assert!(!line.is_reverse_charge);
        assert!(!line.is_self_assessed);
    }

    #[test]
    fn test_domestic_customer_invoice() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::CustomerInvoice,
            "CINV-001",
            "DE",
            "DE",
            dec!(5000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1);
        let line = &lines[0];
        assert_eq!(line.document_id, "CINV-001");
        assert_eq!(line.jurisdiction_id, "JUR-DE");
        assert_eq!(line.tax_amount, dec!(950.00));
        assert!(
            !line.is_deductible,
            "Customer invoice output VAT is not deductible"
        );
        assert!(!line.is_reverse_charge);
    }

    #[test]
    fn test_eu_cross_border_reverse_charge() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-EU-001",
            "DE",
            "FR",
            dec!(20000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1, "Should produce one reverse-charge line");
        let line = &lines[0];
        assert_eq!(line.document_id, "INV-EU-001");
        assert_eq!(line.jurisdiction_id, "JUR-FR");
        assert_eq!(line.tax_amount, dec!(4000.00));
        assert!(line.is_reverse_charge, "Should be reverse charge");
        assert!(line.is_self_assessed, "Buyer should self-assess");
        assert!(
            line.is_deductible,
            "Vendor invoice reverse charge is still deductible"
        );
    }

    #[test]
    fn test_exempt_category() {
        let exempting_config = TaxLineGeneratorConfig {
            exempt_categories: vec!["financial_services".into(), "education".into()],
            ..Default::default()
        };
        let mut generator = TaxLineGenerator::new(exempting_config, make_tax_codes(), 42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-EXEMPT",
            "DE",
            "DE",
            dec!(50000),
            test_date(),
            Some("financial_services"),
        );
        assert!(
            lines.is_empty(),
            "Exempt category should produce no tax lines"
        );

        let lines2 = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-EXEMPT-2",
            "DE",
            "DE",
            dec!(50000),
            test_date(),
            Some("FINANCIAL_SERVICES"),
        );
        assert!(
            lines2.is_empty(),
            "Exempt category check should be case-insensitive"
        );
    }

    #[test]
    fn test_non_eu_cross_border() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-XBORDER",
            "US",
            "DE",
            dec!(10000),
            test_date(),
            None,
        );

        assert!(
            lines.is_empty() || lines.iter().all(|l| !l.is_reverse_charge),
            "Non-EU cross-border should NOT use reverse charge"
        );
    }

    #[test]
    fn test_us_sales_tax() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::CustomerInvoice,
            "CINV-US-001",
            "US",
            "US-CA",
            dec!(1000),
            test_date(),
            None,
        );

        assert_eq!(lines.len(), 1, "Should produce one sales tax line");
        let line = &lines[0];
        assert_eq!(line.jurisdiction_id, "JUR-US-CA");
        assert_eq!(line.tax_amount, dec!(72.50));
        assert!(!line.is_deductible, "Customer invoice not deductible");
    }

    #[test]
    fn test_no_matching_code() {
        let mut generator = default_generator(42);

        let lines = generator.generate_for_document(
            TaxableDocumentType::VendorInvoice,
            "INV-UNKNOWN",
            "ZZ",
            "ZZ",
            dec!(10000),
            test_date(),
            None,
        );

        assert!(
            lines.is_empty(),
            "Unknown jurisdiction should produce no tax lines"
        );
    }

    #[test]
    fn test_batch_generation() {
        let mut generator = default_generator(42);
        let date = test_date();

        // One domestic document per country: (doc id, country, taxable amount).
        let inputs = [
            ("INV-B1", "DE", dec!(1000)),
            ("INV-B2", "FR", dec!(2000)),
            ("INV-B3", "GB", dec!(3000)),
        ];
        let documents: Vec<(String, String, String, Decimal, NaiveDate, Option<String>)> = inputs
            .iter()
            .map(|&(doc_id, country, amount)| {
                (
                    doc_id.to_string(),
                    country.to_string(),
                    country.to_string(),
                    amount,
                    date,
                    None,
                )
            })
            .collect();

        let lines = generator.generate_batch(TaxableDocumentType::VendorInvoice, &documents);

        assert_eq!(lines.len(), 3, "Should produce one line per document");

        let doc_ids: Vec<&str> = lines.iter().map(|l| l.document_id.as_str()).collect();
        let expected_tax = [
            ("INV-B1", dec!(190.00)),
            ("INV-B2", dec!(400.00)),
            ("INV-B3", dec!(600.00)),
        ];
        for &(doc_id, _) in expected_tax.iter() {
            assert!(doc_ids.contains(&doc_id));
        }
        for &(doc_id, tax) in expected_tax.iter() {
            let line = lines.iter().find(|l| l.document_id == doc_id).unwrap();
            assert_eq!(line.tax_amount, tax);
        }
    }

    #[test]
    fn test_deterministic() {
        let date = test_date();

        // Two independent generators built from the same seed and inputs.
        let run_once = || {
            let mut generator = default_generator(999);
            generator.generate_for_document(
                TaxableDocumentType::VendorInvoice,
                "INV-DET",
                "DE",
                "DE",
                dec!(5000),
                date,
                None,
            )
        };
        let lines1 = run_once();
        let lines2 = run_once();

        assert_eq!(lines1.len(), lines2.len());
        for (l1, l2) in lines1.iter().zip(lines2.iter()) {
            assert_eq!(l1.id, l2.id);
            assert_eq!(l1.tax_code_id, l2.tax_code_id);
            assert_eq!(l1.tax_amount, l2.tax_amount);
            assert_eq!(l1.jurisdiction_id, l2.jurisdiction_id);
            assert_eq!(l1.is_deductible, l2.is_deductible);
            assert_eq!(l1.is_reverse_charge, l2.is_reverse_charge);
        }
    }

    #[test]
    fn test_line_counter_increments() {
        let mut generator = default_generator(42);
        let date = test_date();

        // (doc id, taxable amount, id expected on the produced line).
        let cases = [
            ("INV-C1", dec!(1000), "TXLN-000001"),
            ("INV-C2", dec!(2000), "TXLN-000002"),
            ("INV-C3", dec!(3000), "TXLN-000003"),
        ];

        for &(doc_id, amount, expected_id) in cases.iter() {
            let lines = generator.generate_for_document(
                TaxableDocumentType::VendorInvoice,
                doc_id,
                "DE",
                "DE",
                amount,
                date,
                None,
            );
            assert_eq!(lines[0].id, expected_id);
        }
    }
}