Skip to main content

datasynth_eval/coherence/
tax.rs

1//! Tax coherence evaluator.
2//!
3//! Validates tax calculation accuracy, VAT/GST return coherence,
4//! and withholding tax compliance including treaty rate validation.
5
6use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8
/// Thresholds for tax evaluation.
///
/// The `min_*` fields are minimum pass fractions. `TaxEvaluator::evaluate` compares each
/// measured fraction against its minimum and records an issue when the fraction is
/// strictly below it. Because the measured values are fractions, these minimums are
/// presumably meant to lie in `[0, 1]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaxThresholds {
    /// Minimum fraction of tax lines where tax_amount = taxable_amount * rate.
    pub min_tax_calculation_accuracy: f64,
    /// Comparison tolerance shared by every check in the evaluator.
    ///
    /// For the amount checks (tax amount, net payable, withheld amount) it is relative:
    /// the allowed absolute difference is `rate_tolerance` times the magnitude of the
    /// record's reference amount, with that magnitude floored at 1.0. For the treaty
    /// check it is an absolute slack added to the statutory rate.
    pub rate_tolerance: f64,
    /// Minimum fraction of returns where net_payable = output - input.
    pub min_return_accuracy: f64,
    /// Minimum fraction of withholding records where withheld_amount = base * applied_rate.
    pub min_withholding_accuracy: f64,
    /// Minimum fraction of treaty records where applied_rate <= statutory_rate.
    pub min_treaty_compliance_rate: f64,
}
23
24impl Default for TaxThresholds {
25    fn default() -> Self {
26        Self {
27            min_tax_calculation_accuracy: 0.999,
28            rate_tolerance: 0.001,
29            min_return_accuracy: 0.95,
30            min_withholding_accuracy: 0.999,
31            min_treaty_compliance_rate: 0.95,
32        }
33    }
34}
35
/// Tax line data for calculation validation.
///
/// A line counts as correct when `tax_amount` matches `taxable_amount * rate`
/// within the evaluator's tolerance.
#[derive(Debug, Clone)]
pub struct TaxLineData {
    /// Tax code identifier. Not read by `TaxEvaluator::evaluate`.
    pub tax_code_id: String,
    /// Taxable amount (base). Also used to scale the comparison tolerance.
    pub taxable_amount: f64,
    /// Computed tax amount being validated.
    pub tax_amount: f64,
    /// Tax rate applied, as a fraction (e.g. `0.20` for 20%), since the expected
    /// tax is `taxable_amount * rate`.
    pub rate: f64,
}
48
/// Tax return data for net payable validation.
///
/// A return counts as correct when `net_payable` matches
/// `total_output_tax - total_input_tax` within the evaluator's tolerance.
#[derive(Debug, Clone)]
pub struct TaxReturnData {
    /// Return identifier. Not read by `TaxEvaluator::evaluate`.
    pub return_id: String,
    /// Total output tax (collected).
    pub total_output_tax: f64,
    /// Total input tax (paid).
    pub total_input_tax: f64,
    /// Reported net payable. Expected to equal output - input, so it may be negative
    /// when input tax exceeds output tax.
    pub net_payable: f64,
}
61
/// Withholding tax data for amount and treaty validation.
///
/// Every record is checked for `withheld_amount ≈ base_amount * applied_rate`.
/// Records with `has_treaty` set are additionally checked for
/// `applied_rate <= statutory_rate` (plus tolerance).
#[derive(Debug, Clone)]
pub struct WithholdingData {
    /// Record identifier. Not read by `TaxEvaluator::evaluate`.
    pub record_id: String,
    /// Base amount subject to withholding. Also used to scale the comparison tolerance.
    pub base_amount: f64,
    /// Applied withholding rate, as a fraction (expected withheld = base * rate).
    pub applied_rate: f64,
    /// Statutory withholding rate. Only compared against `applied_rate` for treaty records.
    pub statutory_rate: f64,
    /// Actual withheld amount being validated.
    pub withheld_amount: f64,
    /// Whether a treaty rate was applied. When `true`, the record participates
    /// in the treaty compliance check.
    pub has_treaty: bool,
}
78
/// Results of tax coherence evaluation.
///
/// Each rate/accuracy field is a fraction in `[0, 1]`. It is reported as `1.0`
/// when there were no records of that kind to check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaxEvaluation {
    /// Fraction of tax lines where tax_amount ≈ taxable_amount * rate.
    pub tax_calculation_accuracy: f64,
    /// Fraction of returns where net_payable ≈ output - input.
    pub return_net_accuracy: f64,
    /// Fraction of withholding records where withheld ≈ base * rate.
    pub withholding_accuracy: f64,
    /// Fraction of treaty records (`has_treaty == true`) where applied_rate <= statutory_rate.
    pub treaty_compliance_rate: f64,
    /// Total tax lines evaluated.
    pub total_tax_lines: usize,
    /// Total returns evaluated.
    pub total_returns: usize,
    /// Total withholding records evaluated (treaty and non-treaty).
    pub total_withholding: usize,
    /// Overall pass/fail: `true` exactly when `issues` is empty.
    pub passes: bool,
    /// One human-readable message per metric that fell below its threshold.
    pub issues: Vec<String>,
}
101
/// Evaluator for tax calculation coherence.
///
/// Holds the pass/fail thresholds. Construct with `TaxEvaluator::new` (defaults)
/// or `TaxEvaluator::with_thresholds`, then call `evaluate`.
pub struct TaxEvaluator {
    /// Thresholds and tolerance applied by `evaluate`.
    thresholds: TaxThresholds,
}
106
107impl TaxEvaluator {
108    /// Create a new evaluator with default thresholds.
109    pub fn new() -> Self {
110        Self {
111            thresholds: TaxThresholds::default(),
112        }
113    }
114
115    /// Create with custom thresholds.
116    pub fn with_thresholds(thresholds: TaxThresholds) -> Self {
117        Self { thresholds }
118    }
119
120    /// Evaluate tax data coherence.
121    pub fn evaluate(
122        &self,
123        tax_lines: &[TaxLineData],
124        returns: &[TaxReturnData],
125        withholding: &[WithholdingData],
126    ) -> EvalResult<TaxEvaluation> {
127        let mut issues = Vec::new();
128        let tolerance = self.thresholds.rate_tolerance;
129
130        // 1. Tax calculation accuracy: tax_amount ≈ taxable_amount * rate
131        let tax_ok = tax_lines
132            .iter()
133            .filter(|t| {
134                let expected = t.taxable_amount * t.rate;
135                (t.tax_amount - expected).abs() <= tolerance * t.taxable_amount.abs().max(1.0)
136            })
137            .count();
138        let tax_calculation_accuracy = if tax_lines.is_empty() {
139            1.0
140        } else {
141            tax_ok as f64 / tax_lines.len() as f64
142        };
143
144        // 2. Return net payable: net_payable ≈ output - input
145        let return_ok = returns
146            .iter()
147            .filter(|r| {
148                let expected = r.total_output_tax - r.total_input_tax;
149                (r.net_payable - expected).abs() <= tolerance * r.total_output_tax.abs().max(1.0)
150            })
151            .count();
152        let return_net_accuracy = if returns.is_empty() {
153            1.0
154        } else {
155            return_ok as f64 / returns.len() as f64
156        };
157
158        // 3. Withholding accuracy: withheld ≈ base * applied_rate
159        let wh_ok = withholding
160            .iter()
161            .filter(|w| {
162                let expected = w.base_amount * w.applied_rate;
163                (w.withheld_amount - expected).abs() <= tolerance * w.base_amount.abs().max(1.0)
164            })
165            .count();
166        let withholding_accuracy = if withholding.is_empty() {
167            1.0
168        } else {
169            wh_ok as f64 / withholding.len() as f64
170        };
171
172        // 4. Treaty compliance: has_treaty implies applied_rate <= statutory_rate
173        let treaty_records: Vec<_> = withholding.iter().filter(|w| w.has_treaty).collect();
174        let treaty_ok = treaty_records
175            .iter()
176            .filter(|w| w.applied_rate <= w.statutory_rate + tolerance)
177            .count();
178        let treaty_compliance_rate = if treaty_records.is_empty() {
179            1.0
180        } else {
181            treaty_ok as f64 / treaty_records.len() as f64
182        };
183
184        // Check thresholds
185        if tax_calculation_accuracy < self.thresholds.min_tax_calculation_accuracy {
186            issues.push(format!(
187                "Tax calculation accuracy {:.4} < {:.4}",
188                tax_calculation_accuracy, self.thresholds.min_tax_calculation_accuracy
189            ));
190        }
191        if return_net_accuracy < self.thresholds.min_return_accuracy {
192            issues.push(format!(
193                "Return net payable accuracy {:.4} < {:.4}",
194                return_net_accuracy, self.thresholds.min_return_accuracy
195            ));
196        }
197        if withholding_accuracy < self.thresholds.min_withholding_accuracy {
198            issues.push(format!(
199                "Withholding accuracy {:.4} < {:.4}",
200                withholding_accuracy, self.thresholds.min_withholding_accuracy
201            ));
202        }
203        if treaty_compliance_rate < self.thresholds.min_treaty_compliance_rate {
204            issues.push(format!(
205                "Treaty compliance rate {:.4} < {:.4}",
206                treaty_compliance_rate, self.thresholds.min_treaty_compliance_rate
207            ));
208        }
209
210        let passes = issues.is_empty();
211
212        Ok(TaxEvaluation {
213            tax_calculation_accuracy,
214            return_net_accuracy,
215            withholding_accuracy,
216            treaty_compliance_rate,
217            total_tax_lines: tax_lines.len(),
218            total_returns: returns.len(),
219            total_withholding: withholding.len(),
220            passes,
221            issues,
222        })
223    }
224}
225
226impl Default for TaxEvaluator {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
#[cfg(test)]
// Tests unwrap `EvalResult`s directly: a failed unwrap is itself the test failure we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_tax_calculations() {
        let evaluator = TaxEvaluator::new();
        let lines = vec![
            TaxLineData {
                tax_code_id: "VAT20".to_string(),
                taxable_amount: 1000.0,
                tax_amount: 200.0,
                rate: 0.20,
            },
            TaxLineData {
                tax_code_id: "VAT10".to_string(),
                taxable_amount: 500.0,
                tax_amount: 50.0,
                rate: 0.10,
            },
        ];
        let returns = vec![TaxReturnData {
            return_id: "RET001".to_string(),
            total_output_tax: 250.0,
            total_input_tax: 100.0,
            net_payable: 150.0,
        }];
        let withholding = vec![WithholdingData {
            record_id: "WH001".to_string(),
            base_amount: 10000.0,
            applied_rate: 0.10,
            statutory_rate: 0.15,
            withheld_amount: 1000.0,
            has_treaty: true,
        }];

        let result = evaluator.evaluate(&lines, &returns, &withholding).unwrap();
        assert!(result.passes);
        assert_eq!(result.total_tax_lines, 2);
        assert_eq!(result.total_returns, 1);
        assert_eq!(result.treaty_compliance_rate, 1.0);
    }

    #[test]
    fn test_wrong_tax_amount() {
        let evaluator = TaxEvaluator::new();
        let lines = vec![TaxLineData {
            tax_code_id: "VAT20".to_string(),
            taxable_amount: 1000.0,
            tax_amount: 300.0, // Wrong: should be 200.0
            rate: 0.20,
        }];

        let result = evaluator.evaluate(&lines, &[], &[]).unwrap();
        assert!(!result.passes);
        assert!(result.issues[0].contains("Tax calculation accuracy"));
    }

    #[test]
    fn test_wrong_net_payable() {
        let evaluator = TaxEvaluator::new();
        let returns = vec![TaxReturnData {
            return_id: "RET001".to_string(),
            total_output_tax: 250.0,
            total_input_tax: 100.0,
            net_payable: 200.0, // Wrong: should be 150.0
        }];

        let result = evaluator.evaluate(&[], &returns, &[]).unwrap();
        assert!(!result.passes);
        assert!(result.issues[0].contains("Return net payable"));
    }

    #[test]
    fn test_wrong_withholding_amount() {
        let evaluator = TaxEvaluator::new();
        let withholding = vec![WithholdingData {
            record_id: "WH001".to_string(),
            base_amount: 10000.0,
            applied_rate: 0.10,
            statutory_rate: 0.15,
            withheld_amount: 1500.0, // Wrong: should be 1000.0
            has_treaty: false,
        }];

        let result = evaluator.evaluate(&[], &[], &withholding).unwrap();
        assert!(!result.passes);
        assert_eq!(result.withholding_accuracy, 0.0);
        assert!(result.issues[0].contains("Withholding accuracy"));
    }

    #[test]
    fn test_treaty_violation() {
        let evaluator = TaxEvaluator::new();
        let withholding = vec![WithholdingData {
            record_id: "WH001".to_string(),
            base_amount: 10000.0,
            applied_rate: 0.20, // Higher than statutory
            statutory_rate: 0.15,
            withheld_amount: 2000.0,
            has_treaty: true,
        }];

        let result = evaluator.evaluate(&[], &[], &withholding).unwrap();
        assert!(!result.passes);
        assert!(result
            .issues
            .iter()
            .any(|i| i.contains("Treaty compliance")));
    }

    #[test]
    fn test_non_treaty_records_skip_treaty_check() {
        // Applied rate exceeds statutory, but without a treaty the rate comparison
        // does not apply; the withheld amount itself is correct.
        let evaluator = TaxEvaluator::new();
        let withholding = vec![WithholdingData {
            record_id: "WH002".to_string(),
            base_amount: 10000.0,
            applied_rate: 0.20,
            statutory_rate: 0.15,
            withheld_amount: 2000.0,
            has_treaty: false,
        }];

        let result = evaluator.evaluate(&[], &[], &withholding).unwrap();
        assert!(result.passes);
        assert_eq!(result.treaty_compliance_rate, 1.0);
        assert_eq!(result.total_withholding, 1);
    }

    #[test]
    fn test_custom_thresholds() {
        let evaluator = TaxEvaluator::with_thresholds(TaxThresholds {
            min_return_accuracy: 0.0,
            ..TaxThresholds::default()
        });
        let returns = vec![TaxReturnData {
            return_id: "RET001".to_string(),
            total_output_tax: 250.0,
            total_input_tax: 100.0,
            net_payable: 200.0, // Wrong, but the threshold tolerates it
        }];

        let result = evaluator.evaluate(&[], &returns, &[]).unwrap();
        assert_eq!(result.return_net_accuracy, 0.0);
        assert!(result.passes);
    }

    #[test]
    fn test_empty_data() {
        let evaluator = TaxEvaluator::new();
        let result = evaluator.evaluate(&[], &[], &[]).unwrap();
        assert!(result.passes);
        assert_eq!(result.tax_calculation_accuracy, 1.0);
        assert_eq!(result.return_net_accuracy, 1.0);
        assert_eq!(result.withholding_accuracy, 1.0);
        assert_eq!(result.treaty_compliance_rate, 1.0);
    }
}
332}