1use chrono::Duration;
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rust_decimal::Decimal;
13use rust_decimal_macros::dec;
14use serde::{Deserialize, Serialize};
15
16use datasynth_core::models::{TaxLine, TaxReturn, WithholdingTaxRecord};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum TaxAnomalyType {
26 IncorrectTaxCode,
28 MissingTaxLine,
30 RateArbitrage,
32 LateFilingRisk,
34 TransferPricingDeviation,
36 WithholdingUnderstatement,
38}
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum TaxAnomalySeverity {
44 Low,
45 Medium,
46 High,
47 Critical,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct TaxAnomalyLabel {
57 pub id: String,
59 pub anomaly_type: TaxAnomalyType,
61 pub severity: TaxAnomalySeverity,
63 pub document_type: String,
65 pub document_id: String,
67 pub description: String,
69 pub original_value: Option<String>,
71 pub anomalous_value: Option<String>,
73}
74
/// Jurisdiction identifiers that the rate-arbitrage anomaly picks from when
/// it reassigns a tax line.
const LOW_TAX_JURISDICTIONS: &[&str] = &[
    "JUR-IE",
    "JUR-SG",
    "JUR-HK",
    "JUR-BM",
    "JUR-KY",
    "JUR-LU",
];
88
89pub struct TaxAnomalyInjector {
91 rng: ChaCha8Rng,
92 anomaly_rate: f64,
93 counter: u64,
94}
95
96impl TaxAnomalyInjector {
97 pub fn new(seed: u64, anomaly_rate: f64) -> Self {
104 Self {
105 rng: seeded_rng(seed, 0),
106 anomaly_rate: anomaly_rate.clamp(0.0, 1.0),
107 counter: 0,
108 }
109 }
110
111 pub fn inject_into_tax_lines(&mut self, lines: &mut Vec<TaxLine>) -> Vec<TaxAnomalyLabel> {
119 let mut labels = Vec::new();
120 let mut indices_to_remove: Vec<usize> = Vec::new();
121
122 for (i, line) in lines.iter_mut().enumerate() {
123 if !self.should_inject() {
124 continue;
125 }
126
127 let roll: f64 = self.rng.random();
128 if roll < 0.40 {
129 labels.push(self.inject_incorrect_tax_code(line));
131 } else if roll < 0.70 {
132 labels.push(self.create_missing_tax_line_label(line));
134 indices_to_remove.push(i);
135 } else if roll < 0.85 {
136 labels.push(self.inject_rate_arbitrage(line));
138 } else {
139 labels.push(self.inject_tax_line_understatement(line));
141 }
142 }
143
144 for &i in indices_to_remove.iter().rev() {
146 lines.remove(i);
147 }
148
149 labels
150 }
151
152 pub fn inject_into_returns(&mut self, returns: &mut [TaxReturn]) -> Vec<TaxAnomalyLabel> {
156 let mut labels = Vec::new();
157
158 for ret in returns.iter_mut() {
159 if !self.should_inject() {
160 continue;
161 }
162 labels.push(self.inject_late_filing(ret));
163 }
164
165 labels
166 }
167
168 pub fn inject_into_withholding(
172 &mut self,
173 records: &mut [WithholdingTaxRecord],
174 ) -> Vec<TaxAnomalyLabel> {
175 let mut labels = Vec::new();
176
177 for record in records.iter_mut() {
178 if !self.should_inject() {
179 continue;
180 }
181 labels.push(self.inject_withholding_understatement(record));
182 }
183
184 labels
185 }
186
187 fn should_inject(&mut self) -> bool {
193 self.rng.random::<f64>() < self.anomaly_rate
194 }
195
196 fn next_id(&mut self) -> String {
198 self.counter += 1;
199 format!("TXANO-{:06}", self.counter)
200 }
201
202 fn severity_from_impact(impact_ratio: Decimal) -> TaxAnomalySeverity {
204 if impact_ratio >= dec!(0.50) {
205 TaxAnomalySeverity::Critical
206 } else if impact_ratio >= dec!(0.25) {
207 TaxAnomalySeverity::High
208 } else if impact_ratio >= dec!(0.10) {
209 TaxAnomalySeverity::Medium
210 } else {
211 TaxAnomalySeverity::Low
212 }
213 }
214
215 fn inject_incorrect_tax_code(&mut self, line: &mut TaxLine) -> TaxAnomalyLabel {
217 let original_amount = line.tax_amount;
218 let original_rate = line.effective_rate();
219
220 let wrong_rates = [
224 dec!(0.05),
225 dec!(0.07),
226 dec!(0.10),
227 dec!(0.13),
228 dec!(0.15),
229 dec!(0.21),
230 dec!(0.23),
231 dec!(0.25),
232 ];
233
234 let idx = self.rng.random_range(0..wrong_rates.len());
235 let mut wrong_rate = wrong_rates[idx];
236 if wrong_rate == original_rate.round_dp(2) {
238 wrong_rate = wrong_rates[(idx + 1) % wrong_rates.len()];
239 }
240
241 let new_amount = (line.taxable_amount * wrong_rate).round_dp(2);
242 line.tax_amount = new_amount;
243
244 let impact = if original_amount.is_zero() {
245 dec!(1.0)
246 } else {
247 ((new_amount - original_amount).abs() / original_amount.abs()).round_dp(4)
248 };
249
250 TaxAnomalyLabel {
251 id: self.next_id(),
252 anomaly_type: TaxAnomalyType::IncorrectTaxCode,
253 severity: Self::severity_from_impact(impact),
254 document_type: "tax_line".to_string(),
255 document_id: line.id.clone(),
256 description: format!(
257 "Incorrect tax code applied: effective rate changed from {} to {} on tax line {}",
258 original_rate, wrong_rate, line.id
259 ),
260 original_value: Some(original_amount.to_string()),
261 anomalous_value: Some(new_amount.to_string()),
262 }
263 }
264
265 fn create_missing_tax_line_label(&mut self, line: &TaxLine) -> TaxAnomalyLabel {
268 TaxAnomalyLabel {
269 id: self.next_id(),
270 anomaly_type: TaxAnomalyType::MissingTaxLine,
271 severity: TaxAnomalySeverity::High,
272 document_type: "tax_line".to_string(),
273 document_id: line.id.clone(),
274 description: format!(
275 "Tax line {} removed from document {}: taxable amount {} has no tax applied",
276 line.id, line.document_id, line.taxable_amount
277 ),
278 original_value: Some(line.tax_amount.to_string()),
279 anomalous_value: None,
280 }
281 }
282
283 fn inject_rate_arbitrage(&mut self, line: &mut TaxLine) -> TaxAnomalyLabel {
285 let original_jurisdiction = line.jurisdiction_id.clone();
286
287 let idx = self.rng.random_range(0..LOW_TAX_JURISDICTIONS.len());
288 let new_jurisdiction = LOW_TAX_JURISDICTIONS[idx].to_string();
289
290 line.jurisdiction_id = new_jurisdiction.clone();
291
292 let reduction_factor =
294 dec!(0.25) + dec!(0.25) * Decimal::from(self.rng.random_range(0u32..4));
295 let original_amount = line.tax_amount;
296 line.tax_amount = (line.tax_amount * reduction_factor).round_dp(2);
297
298 TaxAnomalyLabel {
299 id: self.next_id(),
300 anomaly_type: TaxAnomalyType::RateArbitrage,
301 severity: TaxAnomalySeverity::Critical,
302 document_type: "tax_line".to_string(),
303 document_id: line.id.clone(),
304 description: format!(
305 "Rate arbitrage: jurisdiction changed from {} to {} on tax line {}",
306 original_jurisdiction, new_jurisdiction, line.id
307 ),
308 original_value: Some(format!(
309 "jurisdiction={original_jurisdiction}, tax_amount={original_amount}"
310 )),
311 anomalous_value: Some(format!(
312 "jurisdiction={}, tax_amount={}",
313 new_jurisdiction, line.tax_amount
314 )),
315 }
316 }
317
318 fn inject_tax_line_understatement(&mut self, line: &mut TaxLine) -> TaxAnomalyLabel {
320 let original_amount = line.tax_amount;
321
322 let reduction: f64 = 0.30 + self.rng.random::<f64>() * 0.40;
324 let reduction_dec = Decimal::from_f64_retain(reduction).unwrap_or(dec!(0.50));
325 let new_amount = (line.tax_amount * (Decimal::ONE - reduction_dec)).round_dp(2);
326 line.tax_amount = new_amount;
327
328 let impact = if original_amount.is_zero() {
329 dec!(0.50)
330 } else {
331 ((original_amount - new_amount) / original_amount).round_dp(4)
332 };
333
334 TaxAnomalyLabel {
335 id: self.next_id(),
336 anomaly_type: TaxAnomalyType::WithholdingUnderstatement,
337 severity: Self::severity_from_impact(impact),
338 document_type: "tax_line".to_string(),
339 document_id: line.id.clone(),
340 description: format!(
341 "Tax understatement on line {}: tax reduced from {} to {} ({:.0}% reduction)",
342 line.id,
343 original_amount,
344 new_amount,
345 reduction * 100.0
346 ),
347 original_value: Some(original_amount.to_string()),
348 anomalous_value: Some(new_amount.to_string()),
349 }
350 }
351
352 fn inject_late_filing(&mut self, ret: &mut TaxReturn) -> TaxAnomalyLabel {
354 let deadline = ret.filing_deadline;
355
356 let days_offset: i64 = self.rng.random_range(-2..=30);
358 let filing_date = deadline + Duration::days(days_offset);
359
360 ret.actual_filing_date = Some(filing_date);
361 ret.is_late = filing_date > deadline;
362
363 let severity = if days_offset > 14 {
364 TaxAnomalySeverity::Critical
365 } else if days_offset > 5 {
366 TaxAnomalySeverity::High
367 } else if days_offset > 0 {
368 TaxAnomalySeverity::Medium
369 } else {
370 TaxAnomalySeverity::Low
371 };
372
373 TaxAnomalyLabel {
374 id: self.next_id(),
375 anomaly_type: TaxAnomalyType::LateFilingRisk,
376 severity,
377 document_type: "tax_return".to_string(),
378 document_id: ret.id.clone(),
379 description: format!(
380 "Late filing risk for return {}: deadline={}, actual_filing_date={}, {} days {}",
381 ret.id,
382 deadline,
383 filing_date,
384 days_offset.unsigned_abs(),
385 if days_offset > 0 {
386 "past deadline"
387 } else {
388 "before deadline"
389 }
390 ),
391 original_value: Some(deadline.to_string()),
392 anomalous_value: Some(filing_date.to_string()),
393 }
394 }
395
396 fn inject_withholding_understatement(
399 &mut self,
400 record: &mut WithholdingTaxRecord,
401 ) -> TaxAnomalyLabel {
402 let original_rate = record.applied_rate;
403 let statutory = record.statutory_rate;
404
405 let fraction: f64 = 0.30 + self.rng.random::<f64>() * 0.40;
407 let fraction_dec = Decimal::from_f64_retain(fraction).unwrap_or(dec!(0.50));
408 let new_rate = (statutory * fraction_dec).round_dp(4);
409
410 record.applied_rate = new_rate;
411 record.treaty_rate = None; record.withheld_amount = (record.base_amount * new_rate).round_dp(2);
413 record.certificate_number = None; let impact = if statutory.is_zero() {
416 dec!(0.50)
417 } else {
418 ((statutory - new_rate) / statutory).round_dp(4)
419 };
420
421 TaxAnomalyLabel {
422 id: self.next_id(),
423 anomaly_type: TaxAnomalyType::WithholdingUnderstatement,
424 severity: Self::severity_from_impact(impact),
425 document_type: "withholding_record".to_string(),
426 document_id: record.id.clone(),
427 description: format!(
428 "Withholding understatement on {}: applied_rate reduced from {} to {} \
429 (statutory_rate={}) without treaty justification",
430 record.id, original_rate, new_rate, statutory
431 ),
432 original_value: Some(original_rate.to_string()),
433 anomalous_value: Some(new_rate.to_string()),
434 }
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use datasynth_core::models::{TaxReturnType, TaxableDocumentType, WithholdingType};

    // Shorthand for a calendar date that is known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    // Vendor-invoice tax line on document DOC-001 with the given amounts.
    fn make_tax_line(id: &str, taxable: Decimal, tax: Decimal) -> TaxLine {
        let doc_type = TaxableDocumentType::VendorInvoice;
        TaxLine::new(
            id,
            doc_type,
            "DOC-001",
            1,
            "TC-VAT-20",
            "JUR-DE",
            taxable,
            tax,
        )
    }

    // `count` tax lines named TL-000, TL-001, ... with 10000 taxable / 2000 tax.
    fn standard_lines(count: usize) -> Vec<TaxLine> {
        let mut out = Vec::with_capacity(count);
        for n in 0..count {
            out.push(make_tax_line(
                &format!("TL-{:03}", n),
                dec!(10000),
                dec!(2000),
            ));
        }
        out
    }

    // Q1-2024 VAT return for ENT-001 with a deadline of 2024-04-30.
    fn make_tax_return(id: &str) -> TaxReturn {
        let period_start = ymd(2024, 1, 1);
        let period_end = ymd(2024, 3, 31);
        let deadline = ymd(2024, 4, 30);
        TaxReturn::new(
            id,
            "ENT-001",
            "JUR-DE",
            period_start,
            period_end,
            TaxReturnType::VatReturn,
            dec!(50000),
            dec!(30000),
            deadline,
        )
    }

    // Service withholding record for payment PAY-001 / vendor V-100.
    fn make_withholding(id: &str) -> WithholdingTaxRecord {
        let kind = WithholdingType::ServiceWithholding;
        WithholdingTaxRecord::new(
            id,
            "PAY-001",
            "V-100",
            kind,
            dec!(0.30),
            dec!(0.30),
            dec!(100000),
        )
    }

    #[test]
    fn test_inject_tax_line_anomalies() {
        let mut lines = standard_lines(10);
        let mut injector = TaxAnomalyInjector::new(42, 1.0);

        let labels = injector.inject_into_tax_lines(&mut lines);
        assert_eq!(labels.len(), 10, "Expected 10 labels at 100% rate");

        // Every MissingTaxLine label corresponds to one removed line.
        let mut missing_count = 0;
        for label in &labels {
            if label.anomaly_type == TaxAnomalyType::MissingTaxLine {
                missing_count += 1;
            }
        }
        assert_eq!(
            lines.len(),
            10 - missing_count,
            "Remaining lines should be 10 minus missing count"
        );
    }

    #[test]
    fn test_anomaly_rate_respected() {
        let mut lines: Vec<TaxLine> = Vec::with_capacity(1000);
        for n in 0..1000 {
            lines.push(make_tax_line(
                &format!("TL-{:04}", n),
                dec!(5000),
                dec!(1000),
            ));
        }

        let mut injector = TaxAnomalyInjector::new(123, 0.10);
        let labels = injector.inject_into_tax_lines(&mut lines);

        // Generous band around the expected ~100 hits.
        let produced = labels.len();
        assert!(
            (50..=200).contains(&produced),
            "Expected ~100 anomalies at 10% rate, got {}",
            labels.len()
        );
    }

    #[test]
    fn test_incorrect_tax_code_anomaly() {
        let mut lines = standard_lines(20);
        let original_amounts: Vec<Decimal> = lines.iter().map(|l| l.tax_amount).collect();

        let mut injector = TaxAnomalyInjector::new(42, 1.0);
        let labels = injector.inject_into_tax_lines(&mut lines);

        let mut incorrect_seen = 0usize;
        for label in labels
            .iter()
            .filter(|l| l.anomaly_type == TaxAnomalyType::IncorrectTaxCode)
        {
            incorrect_seen += 1;
            assert_ne!(
                label.original_value, label.anomalous_value,
                "Incorrect tax code should change the tax amount"
            );
        }
        assert!(
            incorrect_seen != 0,
            "Expected at least one IncorrectTaxCode anomaly"
        );

        // Look each original line up by id among the survivors and check
        // whether any of them carries a different tax amount now.
        let found_changed = original_amounts.iter().enumerate().any(|(n, before)| {
            let wanted = format!("TL-{:03}", n);
            lines
                .iter()
                .find(|l| l.id.as_str() == wanted)
                .map_or(false, |l| l.tax_amount != *before)
        });
        assert!(found_changed, "At least one tax_amount should be changed");
    }

    #[test]
    fn test_late_filing_anomaly() {
        let mut returns: Vec<TaxReturn> = Vec::new();
        for n in 0..10 {
            returns.push(make_tax_return(&format!("TR-{:03}", n)));
        }

        let mut injector = TaxAnomalyInjector::new(42, 1.0);
        let labels = injector.inject_into_returns(&mut returns);
        assert_eq!(labels.len(), 10, "All returns should get anomalies at 100%");

        for (label, ret) in labels.iter().zip(&returns) {
            assert_eq!(label.anomaly_type, TaxAnomalyType::LateFilingRisk);

            let filing_date = ret
                .actual_filing_date
                .expect("Filing date should be set");
            let deadline = ret.filing_deadline;

            let diff = (filing_date - deadline).num_days();
            assert!(
                (-2..=30).contains(&diff),
                "Filing date offset should be -2 to +30 days, got {}",
                diff
            );

            let expected_late = filing_date > deadline;
            assert_eq!(
                ret.is_late, expected_late,
                "is_late flag should match actual vs deadline comparison"
            );
        }
    }

    #[test]
    fn test_withholding_understatement() {
        let mut records: Vec<WithholdingTaxRecord> = Vec::new();
        for n in 0..10 {
            records.push(make_withholding(&format!("WHT-{:03}", n)));
        }

        let mut injector = TaxAnomalyInjector::new(42, 1.0);
        let labels = injector.inject_into_withholding(&mut records);
        assert_eq!(labels.len(), 10, "All records should get anomalies at 100%");

        for (label, record) in labels.iter().zip(&records) {
            assert_eq!(
                label.anomaly_type,
                TaxAnomalyType::WithholdingUnderstatement
            );

            assert!(
                record.applied_rate < record.statutory_rate,
                "applied_rate ({}) should be less than statutory_rate ({})",
                record.applied_rate,
                record.statutory_rate
            );

            assert!(
                record.treaty_rate.is_none(),
                "Treaty rate should be removed for unjustified understatement"
            );

            // The withheld amount must have been recomputed from the new rate.
            let expected_withheld = (record.base_amount * record.applied_rate).round_dp(2);
            assert_eq!(
                record.withheld_amount, expected_withheld,
                "withheld_amount should match base_amount * applied_rate"
            );
        }
    }

    #[test]
    fn test_labels_have_descriptions() {
        let mut injector = TaxAnomalyInjector::new(42, 1.0);

        let mut lines = standard_lines(5);
        let line_labels = injector.inject_into_tax_lines(&mut lines);

        let mut returns = vec![make_tax_return("TR-001")];
        let return_labels = injector.inject_into_returns(&mut returns);

        let mut records = vec![make_withholding("WHT-001")];
        let wht_labels = injector.inject_into_withholding(&mut records);

        let total = line_labels.len() + return_labels.len() + wht_labels.len();
        assert!(total != 0, "Should have at least some labels to test");

        let everything = line_labels
            .iter()
            .chain(return_labels.iter())
            .chain(wht_labels.iter());
        for label in everything {
            assert!(
                !label.description.is_empty(),
                "Label {} should have a non-empty description",
                label.id
            );
            assert!(!label.id.is_empty(), "Label should have a non-empty ID");
            assert!(
                !label.document_type.is_empty(),
                "Label {} should have a non-empty document_type",
                label.id
            );
            assert!(
                !label.document_id.is_empty(),
                "Label {} should have a non-empty document_id",
                label.id
            );
        }
    }

    #[test]
    fn test_deterministic() {
        // Runs a fresh, identically seeded injector over a fresh input set.
        let run = || {
            let mut injector = TaxAnomalyInjector::new(999, 0.5);
            let mut lines = standard_lines(50);
            let labels = injector.inject_into_tax_lines(&mut lines);
            (labels, lines)
        };

        let (labels1, lines1) = run();
        let (labels2, lines2) = run();

        assert_eq!(labels1.len(), labels2.len(), "Label counts should match");
        assert_eq!(
            lines1.len(),
            lines2.len(),
            "Remaining line counts should match"
        );

        for (l1, l2) in labels1.iter().zip(&labels2) {
            assert_eq!(l1.id, l2.id, "Label IDs should match");
            assert_eq!(
                l1.anomaly_type, l2.anomaly_type,
                "Anomaly types should match"
            );
            assert_eq!(l1.severity, l2.severity, "Severities should match");
            assert_eq!(l1.document_id, l2.document_id, "Document IDs should match");
            assert_eq!(
                l1.original_value, l2.original_value,
                "Original values should match"
            );
            assert_eq!(
                l1.anomalous_value, l2.anomalous_value,
                "Anomalous values should match"
            );
        }

        for (ln1, ln2) in lines1.iter().zip(&lines2) {
            assert_eq!(ln1.id, ln2.id);
            assert_eq!(ln1.tax_amount, ln2.tax_amount);
            assert_eq!(ln1.jurisdiction_id, ln2.jurisdiction_id);
        }
    }

    #[test]
    fn test_zero_rate_no_anomalies() {
        let mut lines = standard_lines(100);
        let mut injector = TaxAnomalyInjector::new(42, 0.0);

        let labels = injector.inject_into_tax_lines(&mut lines);

        assert!(labels.is_empty(), "Zero rate should produce no anomalies");
        assert_eq!(lines.len(), 100, "No lines should be removed");
    }

    #[test]
    fn test_label_ids_are_sequential() {
        let mut lines = standard_lines(5);
        let mut injector = TaxAnomalyInjector::new(42, 1.0);
        let labels = injector.inject_into_tax_lines(&mut lines);

        // Identifiers start at 1 and increase by one per label.
        let mut sequence = 0usize;
        for label in &labels {
            sequence += 1;
            let expected_id = format!("TXANO-{:06}", sequence);
            assert_eq!(label.id, expected_id, "Labels should have sequential IDs");
        }
    }

    #[test]
    fn test_serde_roundtrip() {
        let label = TaxAnomalyLabel {
            id: String::from("TXANO-000001"),
            anomaly_type: TaxAnomalyType::IncorrectTaxCode,
            severity: TaxAnomalySeverity::High,
            document_type: String::from("tax_line"),
            document_id: String::from("TL-001"),
            description: String::from("Test anomaly"),
            original_value: Some(String::from("2000")),
            anomalous_value: Some(String::from("1500")),
        };

        let json = serde_json::to_string_pretty(&label).unwrap();
        let deserialized: TaxAnomalyLabel = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.id, label.id);
        assert_eq!(deserialized.anomaly_type, label.anomaly_type);
        assert_eq!(deserialized.severity, label.severity);
        assert_eq!(deserialized.document_type, label.document_type);
        assert_eq!(deserialized.document_id, label.document_id);
        assert_eq!(deserialized.original_value, label.original_value);
        assert_eq!(deserialized.anomalous_value, label.anomalous_value);
    }
}