//! Tax coherence evaluation (`datasynth_eval/coherence/tax.rs`).

use crate::error::EvalResult;
use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct TaxThresholds {
12 pub min_tax_calculation_accuracy: f64,
14 pub rate_tolerance: f64,
16 pub min_return_accuracy: f64,
18 pub min_withholding_accuracy: f64,
20 pub min_treaty_compliance_rate: f64,
22}
23
24impl Default for TaxThresholds {
25 fn default() -> Self {
26 Self {
27 min_tax_calculation_accuracy: 0.999,
28 rate_tolerance: 0.001,
29 min_return_accuracy: 0.95,
30 min_withholding_accuracy: 0.999,
31 min_treaty_compliance_rate: 0.95,
32 }
33 }
34}
35
/// A single tax line: a taxable amount, the rate applied and the resulting tax.
///
/// `TaxEvaluator::evaluate` expects `tax_amount ≈ taxable_amount * rate`.
#[derive(Debug, Clone, PartialEq)]
pub struct TaxLineData {
    /// Identifier of the tax code applied (for example `"VAT20"`).
    pub tax_code_id: String,
    /// Amount the tax is levied on.
    pub taxable_amount: f64,
    /// Tax amount recorded on the line.
    pub tax_amount: f64,
    /// Rate as a fraction (for example `0.20` for 20%).
    pub rate: f64,
}
48
/// Summary figures of one tax return.
///
/// `TaxEvaluator::evaluate` expects `net_payable ≈ total_output_tax - total_input_tax`.
#[derive(Debug, Clone, PartialEq)]
pub struct TaxReturnData {
    /// Identifier of the return.
    pub return_id: String,
    /// Total output tax declared on the return.
    pub total_output_tax: f64,
    /// Total input tax declared on the return.
    pub total_input_tax: f64,
    /// Net amount declared as payable (negative values are not rejected here).
    pub net_payable: f64,
}
61
/// A single withholding-tax record.
///
/// `TaxEvaluator::evaluate` expects `withheld_amount ≈ base_amount * applied_rate`.
/// When `has_treaty` is set, it also expects `applied_rate` not to exceed
/// `statutory_rate` beyond the configured tolerance.
#[derive(Debug, Clone, PartialEq)]
pub struct WithholdingData {
    /// Identifier of the withholding record.
    pub record_id: String,
    /// Amount withholding is computed on.
    pub base_amount: f64,
    /// Rate actually applied, as a fraction.
    pub applied_rate: f64,
    /// Statutory rate, as a fraction; the ceiling for treaty-flagged records.
    pub statutory_rate: f64,
    /// Amount actually withheld.
    pub withheld_amount: f64,
    /// Whether a treaty applies; only these records enter the treaty compliance rate.
    pub has_treaty: bool,
}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct TaxEvaluation {
82 pub tax_calculation_accuracy: f64,
84 pub return_net_accuracy: f64,
86 pub withholding_accuracy: f64,
88 pub treaty_compliance_rate: f64,
90 pub total_tax_lines: usize,
92 pub total_returns: usize,
94 pub total_withholding: usize,
96 pub passes: bool,
98 pub issues: Vec<String>,
100}
101
102pub struct TaxEvaluator {
104 thresholds: TaxThresholds,
105}
106
107impl TaxEvaluator {
108 pub fn new() -> Self {
110 Self {
111 thresholds: TaxThresholds::default(),
112 }
113 }
114
115 pub fn with_thresholds(thresholds: TaxThresholds) -> Self {
117 Self { thresholds }
118 }
119
120 pub fn evaluate(
122 &self,
123 tax_lines: &[TaxLineData],
124 returns: &[TaxReturnData],
125 withholding: &[WithholdingData],
126 ) -> EvalResult<TaxEvaluation> {
127 let mut issues = Vec::new();
128 let tolerance = self.thresholds.rate_tolerance;
129
130 let tax_ok = tax_lines
132 .iter()
133 .filter(|t| {
134 let expected = t.taxable_amount * t.rate;
135 (t.tax_amount - expected).abs() <= tolerance * t.taxable_amount.abs().max(1.0)
136 })
137 .count();
138 let tax_calculation_accuracy = if tax_lines.is_empty() {
139 1.0
140 } else {
141 tax_ok as f64 / tax_lines.len() as f64
142 };
143
144 let return_ok = returns
146 .iter()
147 .filter(|r| {
148 let expected = r.total_output_tax - r.total_input_tax;
149 (r.net_payable - expected).abs() <= tolerance * r.total_output_tax.abs().max(1.0)
150 })
151 .count();
152 let return_net_accuracy = if returns.is_empty() {
153 1.0
154 } else {
155 return_ok as f64 / returns.len() as f64
156 };
157
158 let wh_ok = withholding
160 .iter()
161 .filter(|w| {
162 let expected = w.base_amount * w.applied_rate;
163 (w.withheld_amount - expected).abs() <= tolerance * w.base_amount.abs().max(1.0)
164 })
165 .count();
166 let withholding_accuracy = if withholding.is_empty() {
167 1.0
168 } else {
169 wh_ok as f64 / withholding.len() as f64
170 };
171
172 let treaty_records: Vec<_> = withholding.iter().filter(|w| w.has_treaty).collect();
174 let treaty_ok = treaty_records
175 .iter()
176 .filter(|w| w.applied_rate <= w.statutory_rate + tolerance)
177 .count();
178 let treaty_compliance_rate = if treaty_records.is_empty() {
179 1.0
180 } else {
181 treaty_ok as f64 / treaty_records.len() as f64
182 };
183
184 if tax_calculation_accuracy < self.thresholds.min_tax_calculation_accuracy {
186 issues.push(format!(
187 "Tax calculation accuracy {:.4} < {:.4}",
188 tax_calculation_accuracy, self.thresholds.min_tax_calculation_accuracy
189 ));
190 }
191 if return_net_accuracy < self.thresholds.min_return_accuracy {
192 issues.push(format!(
193 "Return net payable accuracy {:.4} < {:.4}",
194 return_net_accuracy, self.thresholds.min_return_accuracy
195 ));
196 }
197 if withholding_accuracy < self.thresholds.min_withholding_accuracy {
198 issues.push(format!(
199 "Withholding accuracy {:.4} < {:.4}",
200 withholding_accuracy, self.thresholds.min_withholding_accuracy
201 ));
202 }
203 if treaty_compliance_rate < self.thresholds.min_treaty_compliance_rate {
204 issues.push(format!(
205 "Treaty compliance rate {:.4} < {:.4}",
206 treaty_compliance_rate, self.thresholds.min_treaty_compliance_rate
207 ));
208 }
209
210 let passes = issues.is_empty();
211
212 Ok(TaxEvaluation {
213 tax_calculation_accuracy,
214 return_net_accuracy,
215 withholding_accuracy,
216 treaty_compliance_rate,
217 total_tax_lines: tax_lines.len(),
218 total_returns: returns.len(),
219 total_withholding: withholding.len(),
220 passes,
221 issues,
222 })
223 }
224}
225
226impl Default for TaxEvaluator {
227 fn default() -> Self {
228 Self::new()
229 }
230}
231
#[cfg(test)]
// Unwrapping is intended in tests: an unexpected `Err` should fail the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_tax_calculations() {
        let evaluator = TaxEvaluator::new();
        let lines = vec![
            TaxLineData {
                tax_code_id: "VAT20".to_string(),
                taxable_amount: 1000.0,
                tax_amount: 200.0,
                rate: 0.20,
            },
            TaxLineData {
                tax_code_id: "VAT10".to_string(),
                taxable_amount: 500.0,
                tax_amount: 50.0,
                rate: 0.10,
            },
        ];
        let returns = vec![TaxReturnData {
            return_id: "RET001".to_string(),
            total_output_tax: 250.0,
            total_input_tax: 100.0,
            net_payable: 150.0,
        }];
        let withholding = vec![WithholdingData {
            record_id: "WH001".to_string(),
            base_amount: 10000.0,
            applied_rate: 0.10,
            statutory_rate: 0.15,
            withheld_amount: 1000.0,
            has_treaty: true,
        }];

        let result = evaluator.evaluate(&lines, &returns, &withholding).unwrap();
        assert!(result.passes);
        assert_eq!(result.total_tax_lines, 2);
        assert_eq!(result.total_returns, 1);
        assert_eq!(result.treaty_compliance_rate, 1.0);
    }

    #[test]
    fn test_wrong_tax_amount() {
        let evaluator = TaxEvaluator::new();
        let lines = vec![TaxLineData {
            tax_code_id: "VAT20".to_string(),
            taxable_amount: 1000.0,
            tax_amount: 300.0, // should be 1000.0 * 0.20 = 200.0
            rate: 0.20,
        }];

        let result = evaluator.evaluate(&lines, &[], &[]).unwrap();
        assert!(!result.passes);
        assert!(result.issues[0].contains("Tax calculation accuracy"));
    }

    #[test]
    fn test_amount_within_tolerance() {
        let evaluator = TaxEvaluator::new();
        let lines = vec![
            // Off by 0.5 on a 1000.0 base: inside 0.001 * 1000.0 = 1.0.
            TaxLineData {
                tax_code_id: "VAT20".to_string(),
                taxable_amount: 1000.0,
                tax_amount: 200.5,
                rate: 0.20,
            },
            // Tiny base: the 1.0 floor gives an absolute tolerance of 0.001.
            TaxLineData {
                tax_code_id: "VAT20".to_string(),
                taxable_amount: 0.5,
                tax_amount: 0.1005,
                rate: 0.20,
            },
        ];

        let result = evaluator.evaluate(&lines, &[], &[]).unwrap();
        assert!(result.passes);
        assert_eq!(result.tax_calculation_accuracy, 1.0);
    }

    #[test]
    fn test_wrong_net_payable() {
        let evaluator = TaxEvaluator::new();
        let returns = vec![TaxReturnData {
            return_id: "RET001".to_string(),
            total_output_tax: 250.0,
            total_input_tax: 100.0,
            net_payable: 200.0, // should be 250.0 - 100.0 = 150.0
        }];

        let result = evaluator.evaluate(&[], &returns, &[]).unwrap();
        assert!(!result.passes);
        assert!(result.issues[0].contains("Return net payable"));
    }

    #[test]
    fn test_wrong_withheld_amount() {
        let evaluator = TaxEvaluator::new();
        let withholding = vec![WithholdingData {
            record_id: "WH001".to_string(),
            base_amount: 10000.0,
            applied_rate: 0.10,
            statutory_rate: 0.15,
            withheld_amount: 900.0, // should be 10000.0 * 0.10 = 1000.0
            has_treaty: false,
        }];

        let result = evaluator.evaluate(&[], &[], &withholding).unwrap();
        assert!(!result.passes);
        assert_eq!(result.withholding_accuracy, 0.0);
        assert_eq!(result.treaty_compliance_rate, 1.0);
        assert!(result.issues[0].contains("Withholding accuracy"));
    }

    #[test]
    fn test_treaty_violation() {
        let evaluator = TaxEvaluator::new();
        let withholding = vec![WithholdingData {
            record_id: "WH001".to_string(),
            base_amount: 10000.0,
            applied_rate: 0.20, // exceeds the 0.15 statutory rate despite the treaty
            statutory_rate: 0.15,
            withheld_amount: 2000.0,
            has_treaty: true,
        }];

        let result = evaluator.evaluate(&[], &[], &withholding).unwrap();
        assert!(!result.passes);
        assert!(result
            .issues
            .iter()
            .any(|i| i.contains("Treaty compliance")));
    }

    #[test]
    fn test_non_treaty_records_excluded_from_compliance() {
        let evaluator = TaxEvaluator::new();
        let withholding = vec![WithholdingData {
            record_id: "WH002".to_string(),
            base_amount: 10000.0,
            applied_rate: 0.20, // above statutory, but no treaty applies
            statutory_rate: 0.15,
            withheld_amount: 2000.0,
            has_treaty: false,
        }];

        let result = evaluator.evaluate(&[], &[], &withholding).unwrap();
        assert!(result.passes);
        assert_eq!(result.treaty_compliance_rate, 1.0);
        assert_eq!(result.total_withholding, 1);
    }

    #[test]
    fn test_custom_thresholds() {
        let lines = vec![
            TaxLineData {
                tax_code_id: "VAT20".to_string(),
                taxable_amount: 1000.0,
                tax_amount: 200.0,
                rate: 0.20,
            },
            TaxLineData {
                tax_code_id: "VAT20".to_string(),
                taxable_amount: 1000.0,
                tax_amount: 300.0,
                rate: 0.20,
            },
        ];

        let strict = TaxEvaluator::new().evaluate(&lines, &[], &[]).unwrap();
        assert!(!strict.passes);
        assert_eq!(strict.tax_calculation_accuracy, 0.5);

        // A minimum equal to the measured rate passes (the check is strictly `<`).
        let relaxed = TaxEvaluator::with_thresholds(TaxThresholds {
            min_tax_calculation_accuracy: 0.5,
            ..TaxThresholds::default()
        });
        let result = relaxed.evaluate(&lines, &[], &[]).unwrap();
        assert!(result.passes);
        assert!(result.issues.is_empty());
    }

    #[test]
    fn test_empty_data() {
        let evaluator = TaxEvaluator::new();
        let result = evaluator.evaluate(&[], &[], &[]).unwrap();
        assert!(result.passes);
        assert!(result.issues.is_empty());
        assert_eq!(result.tax_calculation_accuracy, 1.0);
        assert_eq!(result.return_net_accuracy, 1.0);
        assert_eq!(result.withholding_accuracy, 1.0);
        assert_eq!(result.treaty_compliance_rate, 1.0);
        assert_eq!(result.total_tax_lines, 0);
        assert_eq!(result.total_returns, 0);
        assert_eq!(result.total_withholding, 0);
    }
}