1use datasynth_core::utils::seeded_rng;
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rust_decimal::Decimal;
12use rust_decimal_macros::dec;
13use std::collections::HashMap;
14
15use datasynth_core::models::{WithholdingTaxRecord, WithholdingType};
16
17pub struct WithholdingGenerator {
38 rng: ChaCha8Rng,
39 treaty_rates: HashMap<(String, String), Decimal>,
41 default_rate: Decimal,
43 counter: u64,
44}
45
46impl WithholdingGenerator {
47 pub fn new(seed: u64, default_rate: Decimal) -> Self {
52 Self {
53 rng: seeded_rng(seed, 0),
54 treaty_rates: HashMap::new(),
55 default_rate,
56 counter: 0,
57 }
58 }
59
60 pub fn add_treaty_rate(&mut self, source_country: &str, vendor_country: &str, rate: Decimal) {
66 self.treaty_rates.insert(
67 (source_country.to_string(), vendor_country.to_string()),
68 rate,
69 );
70 }
71
72 pub fn with_standard_treaties(mut self) -> Self {
83 let treaties = [
84 ("US", "GB", dec!(0.00)),
85 ("US", "DE", dec!(0.00)),
86 ("US", "JP", dec!(0.00)),
87 ("US", "FR", dec!(0.00)),
88 ("US", "SG", dec!(0.00)),
89 ("US", "IN", dec!(0.15)),
90 ("US", "BR", dec!(0.15)),
91 ];
92
93 for (source, vendor, rate) in &treaties {
94 self.treaty_rates
95 .insert((source.to_string(), vendor.to_string()), *rate);
96 }
97
98 self
99 }
100
101 pub fn generate(
112 &mut self,
113 payments: &[(String, String, String, Decimal)],
114 source_country: &str,
115 ) -> Vec<WithholdingTaxRecord> {
116 let mut records = Vec::new();
117
118 for (payment_id, vendor_id, vendor_country, amount) in payments {
119 if vendor_country == source_country {
121 continue;
122 }
123
124 let key = (source_country.to_string(), vendor_country.clone());
125 let (applied_rate, treaty_rate) = match self.treaty_rates.get(&key) {
126 Some(&rate) => (rate, Some(rate)),
127 None => (self.default_rate, None),
128 };
129
130 self.counter += 1;
131 let record_id = format!("WHT-{:06}", self.counter);
132
133 let cert_suffix: u32 = self.rng.gen_range(100_000..999_999);
135 let cert_number = format!("CERT-{}-{cert_suffix}", &record_id);
136
137 let mut record = WithholdingTaxRecord::new(
138 record_id,
139 payment_id,
140 vendor_id,
141 WithholdingType::ServiceWithholding,
142 self.default_rate,
143 applied_rate,
144 *amount,
145 )
146 .with_certificate_number(cert_number);
147
148 if let Some(rate) = treaty_rate {
149 record = record.with_treaty_rate(rate);
150 }
151
152 records.push(record);
153 }
154
155 records
156 }
157}
158
#[cfg(test)]
// Lint allowance retained from the original module.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a `(payment_id, vendor_id, vendor_country, amount)` tuple in the shape
    /// that `WithholdingGenerator::generate` consumes.
    fn payment(
        id: &str,
        vendor_id: &str,
        vendor_country: &str,
        amount: Decimal,
    ) -> (String, String, String, Decimal) {
        (id.into(), vendor_id.into(), vendor_country.into(), amount)
    }

    /// Generator with a 30% default rate and the standard US treaty table loaded.
    fn us_generator(seed: u64) -> WithholdingGenerator {
        WithholdingGenerator::new(seed, dec!(0.30)).with_standard_treaties()
    }

    #[test]
    fn test_with_treaty_rate() {
        let mut generator = us_generator(42);
        let batch = [payment("PAY-001", "V-GB-01", "GB", dec!(100000))];

        let out = generator.generate(&batch, "US");

        assert_eq!(out.len(), 1);
        let first = &out[0];
        assert_eq!(first.vendor_id, "V-GB-01");
        assert_eq!(first.applied_rate, dec!(0.00));
        assert_eq!(first.treaty_rate, Some(dec!(0.00)));
        assert_eq!(first.withheld_amount, dec!(0.00));
        assert_eq!(first.statutory_rate, dec!(0.30));
        assert!(first.has_treaty_benefit());
    }

    #[test]
    fn test_without_treaty() {
        let mut generator = us_generator(42);
        let batch = [payment("PAY-002", "V-ZZ-01", "ZZ", dec!(50000))];

        let out = generator.generate(&batch, "US");

        assert_eq!(out.len(), 1);
        let first = &out[0];
        assert_eq!(first.applied_rate, dec!(0.30));
        assert_eq!(first.treaty_rate, None);
        assert_eq!(first.withheld_amount, dec!(15000.00));
        assert!(!first.has_treaty_benefit());
    }

    #[test]
    fn test_standard_treaties() {
        let generator = us_generator(42);
        let expected = [
            ("GB", dec!(0.00)),
            ("DE", dec!(0.00)),
            ("JP", dec!(0.00)),
            ("FR", dec!(0.00)),
            ("SG", dec!(0.00)),
            ("IN", dec!(0.15)),
            ("BR", dec!(0.15)),
        ];

        assert_eq!(generator.treaty_rates.len(), expected.len());

        for (vendor_country, rate) in expected {
            let key = ("US".to_string(), vendor_country.to_string());
            assert_eq!(
                generator.treaty_rates.get(&key),
                Some(&rate),
                "treaty rate for US -> {vendor_country}"
            );
        }
    }

    #[test]
    fn test_domestic_excluded() {
        let mut generator = us_generator(42);
        let batch = [
            payment("PAY-DOM", "V-US-01", "US", dec!(100000)),
            payment("PAY-XB", "V-GB-01", "GB", dec!(50000)),
        ];

        let out = generator.generate(&batch, "US");

        assert_eq!(out.len(), 1);
        assert_eq!(out[0].payment_id, "PAY-XB");
    }

    #[test]
    fn test_deterministic() {
        let batch = [
            payment("PAY-001", "V-GB-01", "GB", dec!(100000)),
            payment("PAY-002", "V-IN-01", "IN", dec!(50000)),
            payment("PAY-003", "V-ZZ-01", "ZZ", dec!(25000)),
        ];

        // Two independent generators with the same seed must agree field by field.
        let run = || us_generator(12345).generate(&batch, "US");
        let first_run = run();
        let second_run = run();

        assert_eq!(first_run.len(), second_run.len());
        for (a, b) in first_run.iter().zip(&second_run) {
            assert_eq!(a.id, b.id);
            assert_eq!(a.payment_id, b.payment_id);
            assert_eq!(a.vendor_id, b.vendor_id);
            assert_eq!(a.applied_rate, b.applied_rate);
            assert_eq!(a.treaty_rate, b.treaty_rate);
            assert_eq!(a.withheld_amount, b.withheld_amount);
            assert_eq!(a.certificate_number, b.certificate_number);
        }
    }

    #[test]
    fn test_treaty_with_nonzero_rate() {
        let mut generator = us_generator(42);
        let batch = [payment("PAY-IN", "V-IN-01", "IN", dec!(100000))];

        let out = generator.generate(&batch, "US");

        assert_eq!(out.len(), 1);
        let first = &out[0];
        assert_eq!(first.applied_rate, dec!(0.15));
        assert_eq!(first.treaty_rate, Some(dec!(0.15)));
        assert_eq!(first.withheld_amount, dec!(15000.00));
        assert_eq!(first.statutory_rate, dec!(0.30));
        assert!(
            first.has_treaty_benefit(),
            "15% treaty rate is less than 30% statutory"
        );
        assert_eq!(first.treaty_savings(), dec!(15000.00));
    }

    #[test]
    fn test_custom_treaty_rate() {
        let mut generator = WithholdingGenerator::new(42, dec!(0.25));
        generator.add_treaty_rate("DE", "US", dec!(0.05));
        let batch = [payment("PAY-001", "V-US-01", "US", dec!(200000))];

        let out = generator.generate(&batch, "DE");

        assert_eq!(out.len(), 1);
        let first = &out[0];
        assert_eq!(first.applied_rate, dec!(0.05));
        assert_eq!(first.treaty_rate, Some(dec!(0.05)));
        assert_eq!(first.withheld_amount, dec!(10000.00));
        assert_eq!(first.statutory_rate, dec!(0.25));
    }
}