1use datasynth_core::utils::seeded_rng;
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rust_decimal::Decimal;
12use rust_decimal_macros::dec;
13use std::collections::HashMap;
14
15use datasynth_core::models::{WithholdingTaxRecord, WithholdingType};
16
17pub struct WithholdingGenerator {
38 rng: ChaCha8Rng,
39 treaty_rates: HashMap<(String, String), Decimal>,
41 default_rate: Decimal,
43 counter: u64,
44}
45
46impl WithholdingGenerator {
47 pub fn new(seed: u64, default_rate: Decimal) -> Self {
52 Self {
53 rng: seeded_rng(seed, 0),
54 treaty_rates: HashMap::new(),
55 default_rate,
56 counter: 0,
57 }
58 }
59
60 pub fn add_treaty_rate(&mut self, source_country: &str, vendor_country: &str, rate: Decimal) {
66 self.treaty_rates.insert(
67 (source_country.to_string(), vendor_country.to_string()),
68 rate,
69 );
70 }
71
72 pub fn with_standard_treaties(mut self) -> Self {
83 let treaties = [
84 ("US", "GB", dec!(0.00)),
85 ("US", "DE", dec!(0.00)),
86 ("US", "JP", dec!(0.00)),
87 ("US", "FR", dec!(0.00)),
88 ("US", "SG", dec!(0.00)),
89 ("US", "IN", dec!(0.15)),
90 ("US", "BR", dec!(0.15)),
91 ];
92
93 for (source, vendor, rate) in &treaties {
94 self.treaty_rates
95 .insert((source.to_string(), vendor.to_string()), *rate);
96 }
97
98 self
99 }
100
101 pub fn generate(
112 &mut self,
113 payments: &[(String, String, String, Decimal)],
114 source_country: &str,
115 ) -> Vec<WithholdingTaxRecord> {
116 let mut records = Vec::new();
117
118 for (payment_id, vendor_id, vendor_country, amount) in payments {
119 if vendor_country == source_country {
121 continue;
122 }
123
124 let key = (source_country.to_string(), vendor_country.clone());
125 let (applied_rate, treaty_rate) = match self.treaty_rates.get(&key) {
126 Some(&rate) => (rate, Some(rate)),
127 None => (self.default_rate, None),
128 };
129
130 self.counter += 1;
131 let record_id = format!("WHT-{:06}", self.counter);
132
133 let cert_suffix: u32 = self.rng.random_range(100_000..999_999);
135 let cert_number = format!("CERT-{}-{cert_suffix}", &record_id);
136
137 let mut record = WithholdingTaxRecord::new(
138 record_id,
139 payment_id,
140 vendor_id,
141 WithholdingType::ServiceWithholding,
142 self.default_rate,
143 applied_rate,
144 *amount,
145 )
146 .with_certificate_number(cert_number);
147
148 if let Some(rate) = treaty_rate {
149 record = record.with_treaty_rate(rate);
150 }
151
152 records.push(record);
153 }
154
155 records
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(payment_id, vendor_id, vendor_country, amount)` tuple in the
    /// shape consumed by `WithholdingGenerator::generate`.
    fn payment(
        id: &str,
        vendor_id: &str,
        vendor_country: &str,
        amount: Decimal,
    ) -> (String, String, String, Decimal) {
        (id.into(), vendor_id.into(), vendor_country.into(), amount)
    }

    /// Generator with a 30% statutory rate and the standard treaty table.
    fn standard_generator(seed: u64) -> WithholdingGenerator {
        WithholdingGenerator::new(seed, dec!(0.30)).with_standard_treaties()
    }

    /// Asserts that exactly one record was produced and returns it.
    fn only_record(records: &[WithholdingTaxRecord]) -> &WithholdingTaxRecord {
        assert_eq!(records.len(), 1);
        &records[0]
    }

    #[test]
    fn test_with_treaty_rate() {
        let mut generator = standard_generator(42);
        let records = generator.generate(&[payment("PAY-001", "V-GB-01", "GB", dec!(100000))], "US");

        let record = only_record(&records);
        assert_eq!(record.vendor_id, "V-GB-01");
        assert_eq!(record.applied_rate, dec!(0.00));
        assert_eq!(record.treaty_rate, Some(dec!(0.00)));
        assert_eq!(record.withheld_amount, dec!(0.00));
        assert_eq!(record.statutory_rate, dec!(0.30));
        assert!(record.has_treaty_benefit());
    }

    #[test]
    fn test_without_treaty() {
        let mut generator = standard_generator(42);
        let records = generator.generate(&[payment("PAY-002", "V-ZZ-01", "ZZ", dec!(50000))], "US");

        let record = only_record(&records);
        assert_eq!(record.applied_rate, dec!(0.30));
        assert_eq!(record.treaty_rate, None);
        assert_eq!(record.withheld_amount, dec!(15000.00));
        assert!(!record.has_treaty_benefit());
    }

    #[test]
    fn test_standard_treaties() {
        let generator = standard_generator(42);
        let expected = [
            ("GB", dec!(0.00)),
            ("DE", dec!(0.00)),
            ("JP", dec!(0.00)),
            ("FR", dec!(0.00)),
            ("SG", dec!(0.00)),
            ("IN", dec!(0.15)),
            ("BR", dec!(0.15)),
        ];

        assert_eq!(generator.treaty_rates.len(), expected.len());
        for (vendor_country, rate) in expected {
            let key = ("US".to_string(), vendor_country.to_string());
            assert_eq!(
                generator.treaty_rates.get(&key),
                Some(&rate),
                "treaty rate for US -> {vendor_country}"
            );
        }
    }

    #[test]
    fn test_domestic_excluded() {
        let mut generator = standard_generator(42);
        let payments = [
            payment("PAY-DOM", "V-US-01", "US", dec!(100000)),
            payment("PAY-XB", "V-GB-01", "GB", dec!(50000)),
        ];

        let records = generator.generate(&payments, "US");

        assert_eq!(only_record(&records).payment_id, "PAY-XB");
    }

    #[test]
    fn test_deterministic() {
        let payments = [
            payment("PAY-001", "V-GB-01", "GB", dec!(100000)),
            payment("PAY-002", "V-IN-01", "IN", dec!(50000)),
            payment("PAY-003", "V-ZZ-01", "ZZ", dec!(25000)),
        ];
        let run = |seed: u64| standard_generator(seed).generate(&payments, "US");

        let first = run(12345);
        let second = run(12345);

        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(&second) {
            assert_eq!(a.id, b.id);
            assert_eq!(a.payment_id, b.payment_id);
            assert_eq!(a.vendor_id, b.vendor_id);
            assert_eq!(a.applied_rate, b.applied_rate);
            assert_eq!(a.treaty_rate, b.treaty_rate);
            assert_eq!(a.withheld_amount, b.withheld_amount);
            assert_eq!(a.certificate_number, b.certificate_number);
        }
    }

    #[test]
    fn test_treaty_with_nonzero_rate() {
        let mut generator = standard_generator(42);
        let records = generator.generate(&[payment("PAY-IN", "V-IN-01", "IN", dec!(100000))], "US");

        let record = only_record(&records);
        assert_eq!(record.applied_rate, dec!(0.15));
        assert_eq!(record.treaty_rate, Some(dec!(0.15)));
        assert_eq!(record.withheld_amount, dec!(15000.00));
        assert_eq!(record.statutory_rate, dec!(0.30));
        assert!(
            record.has_treaty_benefit(),
            "15% treaty rate is less than 30% statutory"
        );
        assert_eq!(record.treaty_savings(), dec!(15000.00));
    }

    #[test]
    fn test_custom_treaty_rate() {
        let mut generator = WithholdingGenerator::new(42, dec!(0.25));
        generator.add_treaty_rate("DE", "US", dec!(0.05));

        let records = generator.generate(&[payment("PAY-001", "V-US-01", "US", dec!(200000))], "DE");

        let record = only_record(&records);
        assert_eq!(record.applied_rate, dec!(0.05));
        assert_eq!(record.treaty_rate, Some(dec!(0.05)));
        assert_eq!(record.withheld_amount, dec!(10000.00));
        assert_eq!(record.statutory_rate, dec!(0.25));
    }
}