1use chrono::Duration;
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use rust_decimal::Decimal;
13use rust_decimal_macros::dec;
14use serde::{Deserialize, Serialize};
15
16use datasynth_core::models::{TaxLine, TaxReturn, WithholdingTaxRecord};
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
24#[serde(rename_all = "snake_case")]
25pub enum TaxAnomalyType {
26 IncorrectTaxCode,
28 MissingTaxLine,
30 RateArbitrage,
32 LateFilingRisk,
34 TransferPricingDeviation,
36 WithholdingUnderstatement,
38}
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum TaxAnomalySeverity {
44 Low,
45 Medium,
46 High,
47 Critical,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct TaxAnomalyLabel {
57 pub id: String,
59 pub anomaly_type: TaxAnomalyType,
61 pub severity: TaxAnomalySeverity,
63 pub document_type: String,
65 pub document_id: String,
67 pub description: String,
69 pub original_value: Option<String>,
71 pub anomalous_value: Option<String>,
73}
74
/// Jurisdiction IDs used as targets for injected rate-arbitrage anomalies.
///
/// The suffixes look like ISO 3166-1 alpha-2 country codes — NOTE(review): confirm
/// against the jurisdiction catalog. Entries must stay distinct so that stepping to the
/// next entry always yields a different jurisdiction.
const LOW_TAX_JURISDICTIONS: &[&str] = &[
    "JUR-IE",
    "JUR-SG",
    "JUR-HK",
    "JUR-BM",
    "JUR-KY",
    "JUR-LU",
];
88
89pub struct TaxAnomalyInjector {
91 rng: ChaCha8Rng,
92 anomaly_rate: f64,
93 counter: u64,
94}
95
96impl TaxAnomalyInjector {
97 pub fn new(seed: u64, anomaly_rate: f64) -> Self {
104 Self {
105 rng: seeded_rng(seed, 0),
106 anomaly_rate: anomaly_rate.clamp(0.0, 1.0),
107 counter: 0,
108 }
109 }
110
111 pub fn inject_into_tax_lines(&mut self, lines: &mut Vec<TaxLine>) -> Vec<TaxAnomalyLabel> {
119 let mut labels = Vec::new();
120 let mut indices_to_remove: Vec<usize> = Vec::new();
121
122 for (i, line) in lines.iter_mut().enumerate() {
123 if !self.should_inject() {
124 continue;
125 }
126
127 let roll: f64 = self.rng.random();
128 if roll < 0.40 {
129 labels.push(self.inject_incorrect_tax_code(line));
131 } else if roll < 0.70 {
132 labels.push(self.create_missing_tax_line_label(line));
134 indices_to_remove.push(i);
135 } else if roll < 0.85 {
136 labels.push(self.inject_rate_arbitrage(line));
138 } else {
139 labels.push(self.inject_tax_line_understatement(line));
141 }
142 }
143
144 for &i in indices_to_remove.iter().rev() {
146 lines.remove(i);
147 }
148
149 labels
150 }
151
152 pub fn inject_into_returns(&mut self, returns: &mut [TaxReturn]) -> Vec<TaxAnomalyLabel> {
156 let mut labels = Vec::new();
157
158 for ret in returns.iter_mut() {
159 if !self.should_inject() {
160 continue;
161 }
162 labels.push(self.inject_late_filing(ret));
163 }
164
165 labels
166 }
167
168 pub fn inject_into_withholding(
172 &mut self,
173 records: &mut [WithholdingTaxRecord],
174 ) -> Vec<TaxAnomalyLabel> {
175 let mut labels = Vec::new();
176
177 for record in records.iter_mut() {
178 if !self.should_inject() {
179 continue;
180 }
181 labels.push(self.inject_withholding_understatement(record));
182 }
183
184 labels
185 }
186
187 fn should_inject(&mut self) -> bool {
193 self.rng.random::<f64>() < self.anomaly_rate
194 }
195
196 fn next_id(&mut self) -> String {
198 self.counter += 1;
199 format!("TXANO-{:06}", self.counter)
200 }
201
202 fn severity_from_impact(impact_ratio: Decimal) -> TaxAnomalySeverity {
204 if impact_ratio >= dec!(0.50) {
205 TaxAnomalySeverity::Critical
206 } else if impact_ratio >= dec!(0.25) {
207 TaxAnomalySeverity::High
208 } else if impact_ratio >= dec!(0.10) {
209 TaxAnomalySeverity::Medium
210 } else {
211 TaxAnomalySeverity::Low
212 }
213 }
214
215 fn inject_incorrect_tax_code(&mut self, line: &mut TaxLine) -> TaxAnomalyLabel {
217 let original_amount = line.tax_amount;
218 let original_rate = line.effective_rate();
219
220 let wrong_rates = [
224 dec!(0.05),
225 dec!(0.07),
226 dec!(0.10),
227 dec!(0.13),
228 dec!(0.15),
229 dec!(0.21),
230 dec!(0.23),
231 dec!(0.25),
232 ];
233
234 let idx = self.rng.random_range(0..wrong_rates.len());
235 let mut wrong_rate = wrong_rates[idx];
236 if wrong_rate == original_rate.round_dp(2) {
238 wrong_rate = wrong_rates[(idx + 1) % wrong_rates.len()];
239 }
240
241 let new_amount = (line.taxable_amount * wrong_rate).round_dp(2);
242 line.tax_amount = new_amount;
243
244 let impact = if original_amount.is_zero() {
245 dec!(1.0)
246 } else {
247 ((new_amount - original_amount).abs() / original_amount.abs()).round_dp(4)
248 };
249
250 TaxAnomalyLabel {
251 id: self.next_id(),
252 anomaly_type: TaxAnomalyType::IncorrectTaxCode,
253 severity: Self::severity_from_impact(impact),
254 document_type: "tax_line".to_string(),
255 document_id: line.id.clone(),
256 description: format!(
257 "Incorrect tax code applied: effective rate changed from {} to {} on tax line {}",
258 original_rate, wrong_rate, line.id
259 ),
260 original_value: Some(original_amount.to_string()),
261 anomalous_value: Some(new_amount.to_string()),
262 }
263 }
264
265 fn create_missing_tax_line_label(&mut self, line: &TaxLine) -> TaxAnomalyLabel {
268 TaxAnomalyLabel {
269 id: self.next_id(),
270 anomaly_type: TaxAnomalyType::MissingTaxLine,
271 severity: TaxAnomalySeverity::High,
272 document_type: "tax_line".to_string(),
273 document_id: line.id.clone(),
274 description: format!(
275 "Tax line {} removed from document {}: taxable amount {} has no tax applied",
276 line.id, line.document_id, line.taxable_amount
277 ),
278 original_value: Some(line.tax_amount.to_string()),
279 anomalous_value: None,
280 }
281 }
282
283 fn inject_rate_arbitrage(&mut self, line: &mut TaxLine) -> TaxAnomalyLabel {
285 let original_jurisdiction = line.jurisdiction_id.clone();
286
287 let idx = self.rng.random_range(0..LOW_TAX_JURISDICTIONS.len());
288 let new_jurisdiction = LOW_TAX_JURISDICTIONS[idx].to_string();
289
290 line.jurisdiction_id = new_jurisdiction.clone();
291
292 let reduction_factor =
294 dec!(0.25) + dec!(0.25) * Decimal::from(self.rng.random_range(0u32..4));
295 let original_amount = line.tax_amount;
296 line.tax_amount = (line.tax_amount * reduction_factor).round_dp(2);
297
298 TaxAnomalyLabel {
299 id: self.next_id(),
300 anomaly_type: TaxAnomalyType::RateArbitrage,
301 severity: TaxAnomalySeverity::Critical,
302 document_type: "tax_line".to_string(),
303 document_id: line.id.clone(),
304 description: format!(
305 "Rate arbitrage: jurisdiction changed from {} to {} on tax line {}",
306 original_jurisdiction, new_jurisdiction, line.id
307 ),
308 original_value: Some(format!(
309 "jurisdiction={original_jurisdiction}, tax_amount={original_amount}"
310 )),
311 anomalous_value: Some(format!(
312 "jurisdiction={}, tax_amount={}",
313 new_jurisdiction, line.tax_amount
314 )),
315 }
316 }
317
318 fn inject_tax_line_understatement(&mut self, line: &mut TaxLine) -> TaxAnomalyLabel {
320 let original_amount = line.tax_amount;
321
322 let reduction: f64 = 0.30 + self.rng.random::<f64>() * 0.40;
324 let reduction_dec = Decimal::from_f64_retain(reduction).unwrap_or(dec!(0.50));
325 let new_amount = (line.tax_amount * (Decimal::ONE - reduction_dec)).round_dp(2);
326 line.tax_amount = new_amount;
327
328 let impact = if original_amount.is_zero() {
329 dec!(0.50)
330 } else {
331 ((original_amount - new_amount) / original_amount).round_dp(4)
332 };
333
334 TaxAnomalyLabel {
335 id: self.next_id(),
336 anomaly_type: TaxAnomalyType::WithholdingUnderstatement,
337 severity: Self::severity_from_impact(impact),
338 document_type: "tax_line".to_string(),
339 document_id: line.id.clone(),
340 description: format!(
341 "Tax understatement on line {}: tax reduced from {} to {} ({:.0}% reduction)",
342 line.id,
343 original_amount,
344 new_amount,
345 reduction * 100.0
346 ),
347 original_value: Some(original_amount.to_string()),
348 anomalous_value: Some(new_amount.to_string()),
349 }
350 }
351
352 fn inject_late_filing(&mut self, ret: &mut TaxReturn) -> TaxAnomalyLabel {
354 let deadline = ret.filing_deadline;
355
356 let days_offset: i64 = self.rng.random_range(-2..=30);
358 let filing_date = deadline + Duration::days(days_offset);
359
360 ret.actual_filing_date = Some(filing_date);
361 ret.is_late = filing_date > deadline;
362
363 let severity = if days_offset > 14 {
364 TaxAnomalySeverity::Critical
365 } else if days_offset > 5 {
366 TaxAnomalySeverity::High
367 } else if days_offset > 0 {
368 TaxAnomalySeverity::Medium
369 } else {
370 TaxAnomalySeverity::Low
371 };
372
373 TaxAnomalyLabel {
374 id: self.next_id(),
375 anomaly_type: TaxAnomalyType::LateFilingRisk,
376 severity,
377 document_type: "tax_return".to_string(),
378 document_id: ret.id.clone(),
379 description: format!(
380 "Late filing risk for return {}: deadline={}, actual_filing_date={}, {} days {}",
381 ret.id,
382 deadline,
383 filing_date,
384 days_offset.unsigned_abs(),
385 if days_offset > 0 {
386 "past deadline"
387 } else {
388 "before deadline"
389 }
390 ),
391 original_value: Some(deadline.to_string()),
392 anomalous_value: Some(filing_date.to_string()),
393 }
394 }
395
396 fn inject_withholding_understatement(
399 &mut self,
400 record: &mut WithholdingTaxRecord,
401 ) -> TaxAnomalyLabel {
402 let original_rate = record.applied_rate;
403 let statutory = record.statutory_rate;
404
405 let fraction: f64 = 0.30 + self.rng.random::<f64>() * 0.40;
407 let fraction_dec = Decimal::from_f64_retain(fraction).unwrap_or(dec!(0.50));
408 let new_rate = (statutory * fraction_dec).round_dp(4);
409
410 record.applied_rate = new_rate;
411 record.treaty_rate = None; record.withheld_amount = (record.base_amount * new_rate).round_dp(2);
413 record.certificate_number = None; let impact = if statutory.is_zero() {
416 dec!(0.50)
417 } else {
418 ((statutory - new_rate) / statutory).round_dp(4)
419 };
420
421 TaxAnomalyLabel {
422 id: self.next_id(),
423 anomaly_type: TaxAnomalyType::WithholdingUnderstatement,
424 severity: Self::severity_from_impact(impact),
425 document_type: "withholding_record".to_string(),
426 document_id: record.id.clone(),
427 description: format!(
428 "Withholding understatement on {}: applied_rate reduced from {} to {} \
429 (statutory_rate={}) without treaty justification",
430 record.id, original_rate, new_rate, statutory
431 ),
432 original_value: Some(original_rate.to_string()),
433 anomalous_value: Some(new_rate.to_string()),
434 }
435 }
436}
437
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use chrono::NaiveDate;
    use datasynth_core::models::{TaxReturnType, TaxableDocumentType, WithholdingType};

    /// Builds a vendor-invoice VAT line in `JUR-DE` for document `DOC-001`.
    fn make_tax_line(id: &str, taxable: Decimal, tax: Decimal) -> TaxLine {
        TaxLine::new(
            id,
            TaxableDocumentType::VendorInvoice,
            "DOC-001",
            1,
            "TC-VAT-20",
            "JUR-DE",
            taxable,
            tax,
        )
    }

    /// Builds `count` lines with ids `TL-000`, `TL-001`, ... sharing the same amounts.
    fn numbered_lines(count: usize, taxable: Decimal, tax: Decimal) -> Vec<TaxLine> {
        (0..count)
            .map(|n| make_tax_line(&format!("TL-{n:03}"), taxable, tax))
            .collect()
    }

    /// Builds a Q1-2024 German VAT return whose filing deadline is 2024-04-30.
    fn make_tax_return(id: &str) -> TaxReturn {
        let period_start = NaiveDate::from_ymd_opt(2024, 1, 1).unwrap();
        let period_end = NaiveDate::from_ymd_opt(2024, 3, 31).unwrap();
        let deadline = NaiveDate::from_ymd_opt(2024, 4, 30).unwrap();
        TaxReturn::new(
            id,
            "ENT-001",
            "JUR-DE",
            period_start,
            period_end,
            TaxReturnType::VatReturn,
            dec!(50000),
            dec!(30000),
            deadline,
        )
    }

    /// Builds a service-withholding record at 30% (both rate arguments) on a 100k base.
    fn make_withholding(id: &str) -> WithholdingTaxRecord {
        let rate = dec!(0.30);
        WithholdingTaxRecord::new(
            id,
            "PAY-001",
            "V-100",
            WithholdingType::ServiceWithholding,
            rate,
            rate,
            dec!(100000),
        )
    }

    #[test]
    fn test_inject_tax_line_anomalies() {
        let mut lines = numbered_lines(10, dec!(10000), dec!(2000));
        let labels = TaxAnomalyInjector::new(42, 1.0).inject_into_tax_lines(&mut lines);

        assert_eq!(labels.len(), 10, "Expected 10 labels at 100% rate");

        let removed = labels
            .iter()
            .filter(|label| matches!(label.anomaly_type, TaxAnomalyType::MissingTaxLine))
            .count();
        assert_eq!(
            lines.len(),
            10 - removed,
            "Remaining lines should be 10 minus missing count"
        );
    }

    #[test]
    fn test_anomaly_rate_respected() {
        let mut injector = TaxAnomalyInjector::new(123, 0.10);
        let mut lines: Vec<TaxLine> = (0..1000)
            .map(|n| make_tax_line(&format!("TL-{n:04}"), dec!(5000), dec!(1000)))
            .collect();

        let injected = injector.inject_into_tax_lines(&mut lines).len();

        assert!(
            (50..=200).contains(&injected),
            "Expected ~100 anomalies at 10% rate, got {}",
            injected
        );
    }

    #[test]
    fn test_incorrect_tax_code_anomaly() {
        let mut injector = TaxAnomalyInjector::new(42, 1.0);
        let mut lines = numbered_lines(20, dec!(10000), dec!(2000));
        let before: Vec<Decimal> = lines.iter().map(|line| line.tax_amount).collect();

        let labels = injector.inject_into_tax_lines(&mut lines);

        let mut wrong_code_count = 0;
        for label in labels
            .iter()
            .filter(|label| label.anomaly_type == TaxAnomalyType::IncorrectTaxCode)
        {
            wrong_code_count += 1;
            assert_ne!(
                label.original_value, label.anomalous_value,
                "Incorrect tax code should change the tax amount"
            );
        }
        assert!(
            wrong_code_count > 0,
            "Expected at least one IncorrectTaxCode anomaly"
        );

        // Compare each surviving line with its pre-injection amount, matched by id.
        let found_changed = before.iter().enumerate().any(|(n, original)| {
            let wanted = format!("TL-{n:03}");
            matches!(
                lines.iter().find(|line| line.id == wanted),
                Some(line) if line.tax_amount != *original
            )
        });
        assert!(found_changed, "At least one tax_amount should be changed");
    }

    #[test]
    fn test_late_filing_anomaly() {
        let mut injector = TaxAnomalyInjector::new(42, 1.0);
        let mut returns: Vec<TaxReturn> = (0..10)
            .map(|n| make_tax_return(&format!("TR-{n:03}")))
            .collect();

        let labels = injector.inject_into_returns(&mut returns);

        assert_eq!(labels.len(), 10, "All returns should get anomalies at 100%");

        for (label, ret) in labels.iter().zip(&returns) {
            assert_eq!(label.anomaly_type, TaxAnomalyType::LateFilingRisk);

            let filed = match ret.actual_filing_date {
                Some(date) => date,
                None => panic!("Filing date should be set"),
            };
            let offset = (filed - ret.filing_deadline).num_days();

            assert!(
                (-2..=30).contains(&offset),
                "Filing date offset should be -2 to +30 days, got {}",
                offset
            );
            assert_eq!(
                ret.is_late,
                filed > ret.filing_deadline,
                "is_late flag should match actual vs deadline comparison"
            );
        }
    }

    #[test]
    fn test_withholding_understatement() {
        let mut records: Vec<WithholdingTaxRecord> = (0..10)
            .map(|n| make_withholding(&format!("WHT-{n:03}")))
            .collect();

        let labels = TaxAnomalyInjector::new(42, 1.0).inject_into_withholding(&mut records);

        assert_eq!(labels.len(), 10, "All records should get anomalies at 100%");

        for (label, rec) in labels.iter().zip(&records) {
            assert_eq!(
                label.anomaly_type,
                TaxAnomalyType::WithholdingUnderstatement
            );
            assert!(
                rec.applied_rate < rec.statutory_rate,
                "applied_rate ({}) should be less than statutory_rate ({})",
                rec.applied_rate,
                rec.statutory_rate
            );
            assert!(
                rec.treaty_rate.is_none(),
                "Treaty rate should be removed for unjustified understatement"
            );
            assert_eq!(
                rec.withheld_amount,
                (rec.base_amount * rec.applied_rate).round_dp(2),
                "withheld_amount should match base_amount * applied_rate"
            );
        }
    }

    #[test]
    fn test_labels_have_descriptions() {
        let mut injector = TaxAnomalyInjector::new(42, 1.0);
        let mut lines = numbered_lines(5, dec!(10000), dec!(2000));
        let mut returns = vec![make_tax_return("TR-001")];
        let mut records = vec![make_withholding("WHT-001")];

        // Same call order as production use: lines, then returns, then withholding.
        let mut all_labels = injector.inject_into_tax_lines(&mut lines);
        all_labels.extend(injector.inject_into_returns(&mut returns));
        all_labels.extend(injector.inject_into_withholding(&mut records));

        assert!(
            !all_labels.is_empty(),
            "Should have at least some labels to test"
        );

        for label in &all_labels {
            assert!(
                !label.description.is_empty(),
                "Label {} should have a non-empty description",
                label.id
            );
            assert!(!label.id.is_empty(), "Label should have a non-empty ID");
            assert!(
                !label.document_type.is_empty(),
                "Label {} should have a non-empty document_type",
                label.id
            );
            assert!(
                !label.document_id.is_empty(),
                "Label {} should have a non-empty document_id",
                label.id
            );
        }
    }

    #[test]
    fn test_deterministic() {
        let run = || {
            let mut injector = TaxAnomalyInjector::new(999, 0.5);
            let mut lines = numbered_lines(50, dec!(10000), dec!(2000));
            let labels = injector.inject_into_tax_lines(&mut lines);
            (labels, lines)
        };

        let (labels1, lines1) = run();
        let (labels2, lines2) = run();

        assert_eq!(labels1.len(), labels2.len(), "Label counts should match");
        assert_eq!(
            lines1.len(),
            lines2.len(),
            "Remaining line counts should match"
        );

        for (a, b) in labels1.iter().zip(&labels2) {
            assert_eq!(a.id, b.id, "Label IDs should match");
            assert_eq!(a.anomaly_type, b.anomaly_type, "Anomaly types should match");
            assert_eq!(a.severity, b.severity, "Severities should match");
            assert_eq!(a.document_id, b.document_id, "Document IDs should match");
            assert_eq!(
                a.original_value, b.original_value,
                "Original values should match"
            );
            assert_eq!(
                a.anomalous_value, b.anomalous_value,
                "Anomalous values should match"
            );
        }

        for (a, b) in lines1.iter().zip(&lines2) {
            assert_eq!(a.id, b.id);
            assert_eq!(a.tax_amount, b.tax_amount);
            assert_eq!(a.jurisdiction_id, b.jurisdiction_id);
        }
    }

    #[test]
    fn test_zero_rate_no_anomalies() {
        let mut lines = numbered_lines(100, dec!(10000), dec!(2000));

        let labels = TaxAnomalyInjector::new(42, 0.0).inject_into_tax_lines(&mut lines);

        assert!(labels.is_empty(), "Zero rate should produce no anomalies");
        assert_eq!(lines.len(), 100, "No lines should be removed");
    }

    #[test]
    fn test_label_ids_are_sequential() {
        let mut lines = numbered_lines(5, dec!(10000), dec!(2000));
        let labels = TaxAnomalyInjector::new(42, 1.0).inject_into_tax_lines(&mut lines);

        for (seq, label) in (1..).zip(&labels) {
            assert_eq!(
                label.id,
                format!("TXANO-{:06}", seq),
                "Labels should have sequential IDs"
            );
        }
    }

    #[test]
    fn test_serde_roundtrip() {
        let original = TaxAnomalyLabel {
            id: "TXANO-000001".to_string(),
            anomaly_type: TaxAnomalyType::IncorrectTaxCode,
            severity: TaxAnomalySeverity::High,
            document_type: "tax_line".to_string(),
            document_id: "TL-001".to_string(),
            description: "Test anomaly".to_string(),
            original_value: Some("2000".to_string()),
            anomalous_value: Some("1500".to_string()),
        };

        let json = serde_json::to_string_pretty(&original).unwrap();
        let restored: TaxAnomalyLabel = serde_json::from_str(&json).unwrap();

        assert_eq!(restored.id, original.id);
        assert_eq!(restored.anomaly_type, original.anomaly_type);
        assert_eq!(restored.severity, original.severity);
        assert_eq!(restored.document_type, original.document_type);
        assert_eq!(restored.document_id, original.document_id);
        assert_eq!(restored.original_value, original.original_value);
        assert_eq!(restored.anomalous_value, original.anomalous_value);
    }
}