datasynth_eval/coherence/
tax.rs1use crate::error::EvalResult;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TaxThresholds {
12 pub min_tax_calculation_accuracy: f64,
14 pub rate_tolerance: f64,
16 pub min_return_accuracy: f64,
18 pub min_withholding_accuracy: f64,
20 pub min_treaty_compliance_rate: f64,
22}
23
24impl Default for TaxThresholds {
25 fn default() -> Self {
26 Self {
27 min_tax_calculation_accuracy: 0.999,
28 rate_tolerance: 0.001,
29 min_return_accuracy: 0.95,
30 min_withholding_accuracy: 0.999,
31 min_treaty_compliance_rate: 0.95,
32 }
33 }
34}
35
/// A single tax line to be checked for arithmetic consistency.
#[derive(Debug, Clone, PartialEq)]
pub struct TaxLineData {
    /// Tax code identifier. Not inspected by [`TaxEvaluator::evaluate`].
    pub tax_code_id: String,
    /// Base amount the rate is applied to.
    pub taxable_amount: f64,
    /// Tax recorded on the line; expected to equal `taxable_amount * rate`.
    pub tax_amount: f64,
    /// Rate as a fraction (e.g. `0.20` for 20%), multiplied directly by
    /// `taxable_amount`.
    pub rate: f64,
}
48
/// Totals from a tax return, checked for netting consistency.
#[derive(Debug, Clone, PartialEq)]
pub struct TaxReturnData {
    /// Return identifier. Not inspected by [`TaxEvaluator::evaluate`].
    pub return_id: String,
    /// Total output tax on the return.
    pub total_output_tax: f64,
    /// Total input tax on the return.
    pub total_input_tax: f64,
    /// Declared net amount; expected to equal
    /// `total_output_tax - total_input_tax` (negative for a refund position).
    pub net_payable: f64,
}
61
/// A withholding-tax record, checked for arithmetic and treaty-rate compliance.
#[derive(Debug, Clone, PartialEq)]
pub struct WithholdingData {
    /// Record identifier. Not inspected by [`TaxEvaluator::evaluate`].
    pub record_id: String,
    /// Amount the withholding rate is applied to.
    pub base_amount: f64,
    /// Rate actually applied, as a fraction.
    pub applied_rate: f64,
    /// Statutory rate; for treaty records `applied_rate` must not exceed it
    /// (plus tolerance).
    pub statutory_rate: f64,
    /// Amount withheld; expected to equal `base_amount * applied_rate`.
    pub withheld_amount: f64,
    /// Whether this record is included in the treaty-compliance check.
    pub has_treaty: bool,
}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct TaxEvaluation {
82 pub tax_calculation_accuracy: f64,
84 pub return_net_accuracy: f64,
86 pub withholding_accuracy: f64,
88 pub treaty_compliance_rate: f64,
90 pub total_tax_lines: usize,
92 pub total_returns: usize,
94 pub total_withholding: usize,
96 pub passes: bool,
98 pub issues: Vec<String>,
100}
101
102pub struct TaxEvaluator {
104 thresholds: TaxThresholds,
105}
106
107impl TaxEvaluator {
108 pub fn new() -> Self {
110 Self {
111 thresholds: TaxThresholds::default(),
112 }
113 }
114
115 pub fn with_thresholds(thresholds: TaxThresholds) -> Self {
117 Self { thresholds }
118 }
119
120 pub fn evaluate(
122 &self,
123 tax_lines: &[TaxLineData],
124 returns: &[TaxReturnData],
125 withholding: &[WithholdingData],
126 ) -> EvalResult<TaxEvaluation> {
127 let mut issues = Vec::new();
128 let tolerance = self.thresholds.rate_tolerance;
129
130 let tax_ok = tax_lines
132 .iter()
133 .filter(|t| {
134 let expected = t.taxable_amount * t.rate;
135 (t.tax_amount - expected).abs() <= tolerance * t.taxable_amount.abs().max(1.0)
136 })
137 .count();
138 let tax_calculation_accuracy = if tax_lines.is_empty() {
139 1.0
140 } else {
141 tax_ok as f64 / tax_lines.len() as f64
142 };
143
144 let return_ok = returns
146 .iter()
147 .filter(|r| {
148 let expected = r.total_output_tax - r.total_input_tax;
149 (r.net_payable - expected).abs() <= tolerance * r.total_output_tax.abs().max(1.0)
150 })
151 .count();
152 let return_net_accuracy = if returns.is_empty() {
153 1.0
154 } else {
155 return_ok as f64 / returns.len() as f64
156 };
157
158 let wh_ok = withholding
160 .iter()
161 .filter(|w| {
162 let expected = w.base_amount * w.applied_rate;
163 (w.withheld_amount - expected).abs() <= tolerance * w.base_amount.abs().max(1.0)
164 })
165 .count();
166 let withholding_accuracy = if withholding.is_empty() {
167 1.0
168 } else {
169 wh_ok as f64 / withholding.len() as f64
170 };
171
172 let treaty_records: Vec<_> = withholding.iter().filter(|w| w.has_treaty).collect();
174 let treaty_ok = treaty_records
175 .iter()
176 .filter(|w| w.applied_rate <= w.statutory_rate + tolerance)
177 .count();
178 let treaty_compliance_rate = if treaty_records.is_empty() {
179 1.0
180 } else {
181 treaty_ok as f64 / treaty_records.len() as f64
182 };
183
184 if tax_calculation_accuracy < self.thresholds.min_tax_calculation_accuracy {
186 issues.push(format!(
187 "Tax calculation accuracy {:.4} < {:.4}",
188 tax_calculation_accuracy, self.thresholds.min_tax_calculation_accuracy
189 ));
190 }
191 if return_net_accuracy < self.thresholds.min_return_accuracy {
192 issues.push(format!(
193 "Return net payable accuracy {:.4} < {:.4}",
194 return_net_accuracy, self.thresholds.min_return_accuracy
195 ));
196 }
197 if withholding_accuracy < self.thresholds.min_withholding_accuracy {
198 issues.push(format!(
199 "Withholding accuracy {:.4} < {:.4}",
200 withholding_accuracy, self.thresholds.min_withholding_accuracy
201 ));
202 }
203 if treaty_compliance_rate < self.thresholds.min_treaty_compliance_rate {
204 issues.push(format!(
205 "Treaty compliance rate {:.4} < {:.4}",
206 treaty_compliance_rate, self.thresholds.min_treaty_compliance_rate
207 ));
208 }
209
210 let passes = issues.is_empty();
211
212 Ok(TaxEvaluation {
213 tax_calculation_accuracy,
214 return_net_accuracy,
215 withholding_accuracy,
216 treaty_compliance_rate,
217 total_tax_lines: tax_lines.len(),
218 total_returns: returns.len(),
219 total_withholding: withholding.len(),
220 passes,
221 issues,
222 })
223 }
224}
225
226impl Default for TaxEvaluator {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tax line fixture.
    fn line(code: &str, taxable_amount: f64, tax_amount: f64, rate: f64) -> TaxLineData {
        TaxLineData {
            tax_code_id: code.to_string(),
            taxable_amount,
            tax_amount,
            rate,
        }
    }

    /// Builds a tax return fixture.
    fn tax_return(
        id: &str,
        total_output_tax: f64,
        total_input_tax: f64,
        net_payable: f64,
    ) -> TaxReturnData {
        TaxReturnData {
            return_id: id.to_string(),
            total_output_tax,
            total_input_tax,
            net_payable,
        }
    }

    /// Builds a withholding record fixture.
    fn withholding_record(
        id: &str,
        base_amount: f64,
        applied_rate: f64,
        statutory_rate: f64,
        withheld_amount: f64,
        has_treaty: bool,
    ) -> WithholdingData {
        WithholdingData {
            record_id: id.to_string(),
            base_amount,
            applied_rate,
            statutory_rate,
            withheld_amount,
            has_treaty,
        }
    }

    #[test]
    fn test_valid_tax_calculations() {
        let lines = [
            line("VAT20", 1000.0, 200.0, 0.20),
            line("VAT10", 500.0, 50.0, 0.10),
        ];
        let returns = [tax_return("RET001", 250.0, 100.0, 150.0)];
        let withholding = [withholding_record("WH001", 10000.0, 0.10, 0.15, 1000.0, true)];

        let evaluation = TaxEvaluator::new()
            .evaluate(&lines, &returns, &withholding)
            .unwrap();

        assert!(evaluation.passes);
        assert_eq!(evaluation.total_tax_lines, 2);
        assert_eq!(evaluation.total_returns, 1);
        assert_eq!(evaluation.treaty_compliance_rate, 1.0);
    }

    #[test]
    fn test_wrong_tax_amount() {
        // 300 recorded where 1000 * 0.20 = 200 is expected.
        let lines = [line("VAT20", 1000.0, 300.0, 0.20)];

        let evaluation = TaxEvaluator::new().evaluate(&lines, &[], &[]).unwrap();

        assert!(!evaluation.passes);
        assert!(evaluation.issues[0].contains("Tax calculation accuracy"));
    }

    #[test]
    fn test_wrong_net_payable() {
        // 200 declared where 250 - 100 = 150 is expected.
        let returns = [tax_return("RET001", 250.0, 100.0, 200.0)];

        let evaluation = TaxEvaluator::new().evaluate(&[], &returns, &[]).unwrap();

        assert!(!evaluation.passes);
        assert!(evaluation.issues[0].contains("Return net payable"));
    }

    #[test]
    fn test_treaty_violation() {
        // Treaty record applying 20% against a 15% statutory rate.
        let withholding = [withholding_record("WH001", 10000.0, 0.20, 0.15, 2000.0, true)];

        let evaluation = TaxEvaluator::new()
            .evaluate(&[], &[], &withholding)
            .unwrap();

        assert!(!evaluation.passes);
        let mentions_treaty = evaluation
            .issues
            .iter()
            .any(|issue| issue.contains("Treaty compliance"));
        assert!(mentions_treaty);
    }

    #[test]
    fn test_empty_data() {
        let evaluation = TaxEvaluator::new().evaluate(&[], &[], &[]).unwrap();
        assert!(evaluation.passes);
    }
}