Skip to main content

datasynth_eval/coherence/
tax.rs

1//! Tax coherence evaluator.
2//!
3//! Validates tax calculation accuracy, VAT/GST return coherence,
4//! and withholding tax compliance including treaty rate validation.
5
6use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8
9/// Thresholds for tax evaluation.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TaxThresholds {
12    /// Minimum accuracy for tax_amount = taxable_amount * rate.
13    pub min_tax_calculation_accuracy: f64,
14    /// Tolerance for tax amount comparisons.
15    pub rate_tolerance: f64,
16    /// Minimum accuracy for return net_payable = output - input.
17    pub min_return_accuracy: f64,
18    /// Minimum accuracy for withheld_amount = base * applied_rate.
19    pub min_withholding_accuracy: f64,
20    /// Minimum rate of treaty records where applied_rate <= statutory_rate.
21    pub min_treaty_compliance_rate: f64,
22}
23
24impl Default for TaxThresholds {
25    fn default() -> Self {
26        Self {
27            min_tax_calculation_accuracy: 0.999,
28            rate_tolerance: 0.001,
29            min_return_accuracy: 0.95,
30            min_withholding_accuracy: 0.999,
31            min_treaty_compliance_rate: 0.95,
32        }
33    }
34}
35
/// One tax line, as consumed by the calculation-accuracy check.
#[derive(Clone, Debug)]
pub struct TaxLineData {
    /// Identifier of the tax code the line was posted under.
    pub tax_code_id: String,
    /// Base amount the tax is levied on.
    pub taxable_amount: f64,
    /// Tax amount recorded on the line.
    pub tax_amount: f64,
    /// Rate multiplied with `taxable_amount` to obtain the expected tax.
    pub rate: f64,
}
48
/// One tax return, as consumed by the net-payable check.
#[derive(Clone, Debug)]
pub struct TaxReturnData {
    /// Identifier of the return.
    pub return_id: String,
    /// Sum of output tax (tax collected).
    pub total_output_tax: f64,
    /// Sum of input tax (tax paid).
    pub total_input_tax: f64,
    /// Reported net amount; expected to equal output minus input.
    pub net_payable: f64,
}
61
/// One withholding-tax record, as consumed by the withholding-accuracy and
/// treaty-compliance checks.
#[derive(Clone, Debug)]
pub struct WithholdingData {
    /// Identifier of the record.
    pub record_id: String,
    /// Amount the withholding is computed on.
    pub base_amount: f64,
    /// Rate that was actually used for the withholding.
    pub applied_rate: f64,
    /// Statutory rate the applied rate is compared against.
    pub statutory_rate: f64,
    /// Amount that was actually withheld.
    pub withheld_amount: f64,
    /// `true` when a treaty rate was applied; only such records take part
    /// in the treaty-compliance check.
    pub has_treaty: bool,
}
78
79/// Results of tax coherence evaluation.
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct TaxEvaluation {
82    /// Fraction of tax lines where tax_amount ≈ taxable_amount * rate.
83    pub tax_calculation_accuracy: f64,
84    /// Fraction of returns where net_payable ≈ output - input.
85    pub return_net_accuracy: f64,
86    /// Fraction of withholding records where withheld ≈ base * rate.
87    pub withholding_accuracy: f64,
88    /// Fraction of treaty records where applied_rate <= statutory_rate.
89    pub treaty_compliance_rate: f64,
90    /// Total tax lines evaluated.
91    pub total_tax_lines: usize,
92    /// Total returns evaluated.
93    pub total_returns: usize,
94    /// Total withholding records evaluated.
95    pub total_withholding: usize,
96    /// Overall pass/fail.
97    pub passes: bool,
98    /// Issues found.
99    pub issues: Vec<String>,
100}
101
102/// Evaluator for tax calculation coherence.
103pub struct TaxEvaluator {
104    thresholds: TaxThresholds,
105}
106
107impl TaxEvaluator {
108    /// Create a new evaluator with default thresholds.
109    pub fn new() -> Self {
110        Self {
111            thresholds: TaxThresholds::default(),
112        }
113    }
114
115    /// Create with custom thresholds.
116    pub fn with_thresholds(thresholds: TaxThresholds) -> Self {
117        Self { thresholds }
118    }
119
120    /// Evaluate tax data coherence.
121    pub fn evaluate(
122        &self,
123        tax_lines: &[TaxLineData],
124        returns: &[TaxReturnData],
125        withholding: &[WithholdingData],
126    ) -> EvalResult<TaxEvaluation> {
127        let mut issues = Vec::new();
128        let tolerance = self.thresholds.rate_tolerance;
129
130        // 1. Tax calculation accuracy: tax_amount ≈ taxable_amount * rate
131        let tax_ok = tax_lines
132            .iter()
133            .filter(|t| {
134                let expected = t.taxable_amount * t.rate;
135                (t.tax_amount - expected).abs() <= tolerance * t.taxable_amount.abs().max(1.0)
136            })
137            .count();
138        let tax_calculation_accuracy = if tax_lines.is_empty() {
139            1.0
140        } else {
141            tax_ok as f64 / tax_lines.len() as f64
142        };
143
144        // 2. Return net payable: net_payable ≈ output - input
145        let return_ok = returns
146            .iter()
147            .filter(|r| {
148                let expected = r.total_output_tax - r.total_input_tax;
149                (r.net_payable - expected).abs() <= tolerance * r.total_output_tax.abs().max(1.0)
150            })
151            .count();
152        let return_net_accuracy = if returns.is_empty() {
153            1.0
154        } else {
155            return_ok as f64 / returns.len() as f64
156        };
157
158        // 3. Withholding accuracy: withheld ≈ base * applied_rate
159        let wh_ok = withholding
160            .iter()
161            .filter(|w| {
162                let expected = w.base_amount * w.applied_rate;
163                (w.withheld_amount - expected).abs() <= tolerance * w.base_amount.abs().max(1.0)
164            })
165            .count();
166        let withholding_accuracy = if withholding.is_empty() {
167            1.0
168        } else {
169            wh_ok as f64 / withholding.len() as f64
170        };
171
172        // 4. Treaty compliance: has_treaty implies applied_rate <= statutory_rate
173        let treaty_records: Vec<_> = withholding.iter().filter(|w| w.has_treaty).collect();
174        let treaty_ok = treaty_records
175            .iter()
176            .filter(|w| w.applied_rate <= w.statutory_rate + tolerance)
177            .count();
178        let treaty_compliance_rate = if treaty_records.is_empty() {
179            1.0
180        } else {
181            treaty_ok as f64 / treaty_records.len() as f64
182        };
183
184        // Check thresholds
185        if tax_calculation_accuracy < self.thresholds.min_tax_calculation_accuracy {
186            issues.push(format!(
187                "Tax calculation accuracy {:.4} < {:.4}",
188                tax_calculation_accuracy, self.thresholds.min_tax_calculation_accuracy
189            ));
190        }
191        if return_net_accuracy < self.thresholds.min_return_accuracy {
192            issues.push(format!(
193                "Return net payable accuracy {:.4} < {:.4}",
194                return_net_accuracy, self.thresholds.min_return_accuracy
195            ));
196        }
197        if withholding_accuracy < self.thresholds.min_withholding_accuracy {
198            issues.push(format!(
199                "Withholding accuracy {:.4} < {:.4}",
200                withholding_accuracy, self.thresholds.min_withholding_accuracy
201            ));
202        }
203        if treaty_compliance_rate < self.thresholds.min_treaty_compliance_rate {
204            issues.push(format!(
205                "Treaty compliance rate {:.4} < {:.4}",
206                treaty_compliance_rate, self.thresholds.min_treaty_compliance_rate
207            ));
208        }
209
210        let passes = issues.is_empty();
211
212        Ok(TaxEvaluation {
213            tax_calculation_accuracy,
214            return_net_accuracy,
215            withholding_accuracy,
216            treaty_compliance_rate,
217            total_tax_lines: tax_lines.len(),
218            total_returns: returns.len(),
219            total_withholding: withholding.len(),
220            passes,
221            issues,
222        })
223    }
224}
225
226impl Default for TaxEvaluator {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tax line from its code, base, rate and recorded tax amount.
    fn tax_line(code: &str, taxable_amount: f64, rate: f64, tax_amount: f64) -> TaxLineData {
        TaxLineData {
            tax_code_id: code.to_string(),
            taxable_amount,
            tax_amount,
            rate,
        }
    }

    /// Builds return "RET001" with the given totals and reported net.
    fn tax_return(total_output_tax: f64, total_input_tax: f64, net_payable: f64) -> TaxReturnData {
        TaxReturnData {
            return_id: "RET001".to_string(),
            total_output_tax,
            total_input_tax,
            net_payable,
        }
    }

    /// Builds withholding record "WH001" on a fixed base of 10000.0.
    fn withholding_record(
        applied_rate: f64,
        statutory_rate: f64,
        withheld_amount: f64,
        has_treaty: bool,
    ) -> WithholdingData {
        WithholdingData {
            record_id: "WH001".to_string(),
            base_amount: 10000.0,
            applied_rate,
            statutory_rate,
            withheld_amount,
            has_treaty,
        }
    }

    #[test]
    fn test_valid_tax_calculations() {
        let evaluator = TaxEvaluator::new();
        let lines = vec![
            tax_line("VAT20", 1000.0, 0.20, 200.0),
            tax_line("VAT10", 500.0, 0.10, 50.0),
        ];
        let returns = vec![tax_return(250.0, 100.0, 150.0)];
        let withholding = vec![withholding_record(0.10, 0.15, 1000.0, true)];

        let result = evaluator.evaluate(&lines, &returns, &withholding).unwrap();
        assert!(result.passes);
        assert_eq!(result.total_tax_lines, 2);
        assert_eq!(result.total_returns, 1);
        assert_eq!(result.total_withholding, 1);
        assert_eq!(result.treaty_compliance_rate, 1.0);
    }

    #[test]
    fn test_wrong_tax_amount() {
        let evaluator = TaxEvaluator::new();
        // Wrong: 1000.0 * 0.20 should be 200.0
        let lines = vec![tax_line("VAT20", 1000.0, 0.20, 300.0)];

        let result = evaluator.evaluate(&lines, &[], &[]).unwrap();
        assert!(!result.passes);
        assert!(result.issues[0].contains("Tax calculation accuracy"));
    }

    #[test]
    fn test_wrong_net_payable() {
        let evaluator = TaxEvaluator::new();
        // Wrong: 250.0 - 100.0 should be 150.0
        let returns = vec![tax_return(250.0, 100.0, 200.0)];

        let result = evaluator.evaluate(&[], &returns, &[]).unwrap();
        assert!(!result.passes);
        assert!(result.issues[0].contains("Return net payable"));
    }

    #[test]
    fn test_wrong_withheld_amount() {
        let evaluator = TaxEvaluator::new();
        // Wrong: 10000.0 * 0.10 should be 1000.0. No treaty, so only the
        // withholding-accuracy check can fail.
        let withholding = vec![withholding_record(0.10, 0.15, 1500.0, false)];

        let result = evaluator.evaluate(&[], &[], &withholding).unwrap();
        assert!(!result.passes);
        assert_eq!(result.issues.len(), 1);
        assert!(result.issues[0].contains("Withholding accuracy"));
        assert_eq!(result.treaty_compliance_rate, 1.0);
    }

    #[test]
    fn test_treaty_violation() {
        let evaluator = TaxEvaluator::new();
        // Applied rate is higher than the statutory rate.
        let withholding = vec![withholding_record(0.20, 0.15, 2000.0, true)];

        let result = evaluator.evaluate(&[], &[], &withholding).unwrap();
        assert!(!result.passes);
        assert!(result
            .issues
            .iter()
            .any(|i| i.contains("Treaty compliance")));
    }

    #[test]
    fn test_non_treaty_records_skip_treaty_check() {
        let evaluator = TaxEvaluator::new();
        // Applied rate exceeds statutory, but the record is not a treaty record.
        let withholding = vec![withholding_record(0.20, 0.15, 2000.0, false)];

        let result = evaluator.evaluate(&[], &[], &withholding).unwrap();
        assert!(result.passes);
        assert_eq!(result.treaty_compliance_rate, 1.0);
    }

    #[test]
    fn test_custom_thresholds() {
        let thresholds = TaxThresholds {
            min_tax_calculation_accuracy: 0.5,
            ..TaxThresholds::default()
        };
        let lines = vec![
            tax_line("VAT20", 1000.0, 0.20, 200.0),
            tax_line("VAT20", 1000.0, 0.20, 300.0), // Wrong: should be 200.0
        ];

        // Half of the lines are correct: enough for the relaxed minimum...
        let relaxed = TaxEvaluator::with_thresholds(thresholds)
            .evaluate(&lines, &[], &[])
            .unwrap();
        assert_eq!(relaxed.tax_calculation_accuracy, 0.5);
        assert!(relaxed.passes);

        // ...but not for the default one.
        let strict = TaxEvaluator::new().evaluate(&lines, &[], &[]).unwrap();
        assert!(!strict.passes);
    }

    #[test]
    fn test_empty_data() {
        let evaluator = TaxEvaluator::new();
        let result = evaluator.evaluate(&[], &[], &[]).unwrap();
        assert!(result.passes);
        assert!(result.issues.is_empty());
    }
}