1use rand::prelude::*;
9use rand_chacha::ChaCha8Rng;
10use rust_decimal::Decimal;
11use rust_decimal_macros::dec;
12use std::collections::HashMap;
13
14use datasynth_core::models::{WithholdingTaxRecord, WithholdingType};
15
16pub struct WithholdingGenerator {
37 rng: ChaCha8Rng,
38 treaty_rates: HashMap<(String, String), Decimal>,
40 default_rate: Decimal,
42 counter: u64,
43}
44
45impl WithholdingGenerator {
46 pub fn new(seed: u64, default_rate: Decimal) -> Self {
51 Self {
52 rng: ChaCha8Rng::seed_from_u64(seed),
53 treaty_rates: HashMap::new(),
54 default_rate,
55 counter: 0,
56 }
57 }
58
59 pub fn add_treaty_rate(&mut self, source_country: &str, vendor_country: &str, rate: Decimal) {
65 self.treaty_rates.insert(
66 (source_country.to_string(), vendor_country.to_string()),
67 rate,
68 );
69 }
70
71 pub fn with_standard_treaties(mut self) -> Self {
82 let treaties = [
83 ("US", "GB", dec!(0.00)),
84 ("US", "DE", dec!(0.00)),
85 ("US", "JP", dec!(0.00)),
86 ("US", "FR", dec!(0.00)),
87 ("US", "SG", dec!(0.00)),
88 ("US", "IN", dec!(0.15)),
89 ("US", "BR", dec!(0.15)),
90 ];
91
92 for (source, vendor, rate) in &treaties {
93 self.treaty_rates
94 .insert((source.to_string(), vendor.to_string()), *rate);
95 }
96
97 self
98 }
99
100 pub fn generate(
111 &mut self,
112 payments: &[(String, String, String, Decimal)],
113 source_country: &str,
114 ) -> Vec<WithholdingTaxRecord> {
115 let mut records = Vec::new();
116
117 for (payment_id, vendor_id, vendor_country, amount) in payments {
118 if vendor_country == source_country {
120 continue;
121 }
122
123 let key = (source_country.to_string(), vendor_country.clone());
124 let (applied_rate, treaty_rate) = match self.treaty_rates.get(&key) {
125 Some(&rate) => (rate, Some(rate)),
126 None => (self.default_rate, None),
127 };
128
129 self.counter += 1;
130 let record_id = format!("WHT-{:06}", self.counter);
131
132 let cert_suffix: u32 = self.rng.gen_range(100_000..999_999);
134 let cert_number = format!("CERT-{}-{cert_suffix}", &record_id);
135
136 let mut record = WithholdingTaxRecord::new(
137 record_id,
138 payment_id,
139 vendor_id,
140 WithholdingType::ServiceWithholding,
141 self.default_rate,
142 applied_rate,
143 *amount,
144 )
145 .with_certificate_number(cert_number);
146
147 if let Some(rate) = treaty_rate {
148 record = record.with_treaty_rate(rate);
149 }
150
151 records.push(record);
152 }
153
154 records
155 }
156}
157
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a `(payment_id, vendor_id, vendor_country, amount)` tuple in the
    /// shape expected by `WithholdingGenerator::generate`.
    fn payment(
        id: &str,
        vendor_id: &str,
        vendor_country: &str,
        amount: Decimal,
    ) -> (String, String, String, Decimal) {
        (
            id.to_string(),
            vendor_id.to_string(),
            vendor_country.to_string(),
            amount,
        )
    }

    // Local bindings avoid the name `gen`, which is reserved in Rust 2024.

    #[test]
    fn test_with_treaty_rate() {
        let mut generator = WithholdingGenerator::new(42, dec!(0.30)).with_standard_treaties();

        let payments = vec![payment("PAY-001", "V-GB-01", "GB", dec!(100000))];

        let records = generator.generate(&payments, "US");

        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.vendor_id, "V-GB-01");
        assert_eq!(rec.applied_rate, dec!(0.00));
        assert_eq!(rec.treaty_rate, Some(dec!(0.00)));
        assert_eq!(rec.withheld_amount, dec!(0.00));
        assert_eq!(rec.statutory_rate, dec!(0.30));
        assert!(rec.has_treaty_benefit());
    }

    #[test]
    fn test_without_treaty() {
        let mut generator = WithholdingGenerator::new(42, dec!(0.30)).with_standard_treaties();

        // "ZZ" has no treaty entry, so the default rate applies.
        let payments = vec![payment("PAY-002", "V-ZZ-01", "ZZ", dec!(50000))];

        let records = generator.generate(&payments, "US");

        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.applied_rate, dec!(0.30));
        assert_eq!(rec.treaty_rate, None);
        assert_eq!(rec.withheld_amount, dec!(15000.00));
        assert!(!rec.has_treaty_benefit());
    }

    #[test]
    fn test_standard_treaties() {
        let generator = WithholdingGenerator::new(42, dec!(0.30)).with_standard_treaties();

        let expected = [
            ("GB", dec!(0.00)),
            ("DE", dec!(0.00)),
            ("JP", dec!(0.00)),
            ("FR", dec!(0.00)),
            ("SG", dec!(0.00)),
            ("IN", dec!(0.15)),
            ("BR", dec!(0.15)),
        ];

        assert_eq!(generator.treaty_rates.len(), expected.len());
        for (vendor_country, rate) in expected {
            assert_eq!(
                generator
                    .treaty_rates
                    .get(&("US".to_string(), vendor_country.to_string())),
                Some(&rate),
                "unexpected treaty rate for US -> {vendor_country}"
            );
        }
    }

    #[test]
    fn test_domestic_excluded() {
        let mut generator = WithholdingGenerator::new(42, dec!(0.30)).with_standard_treaties();

        let payments = vec![
            payment("PAY-DOM", "V-US-01", "US", dec!(100000)),
            payment("PAY-XB", "V-GB-01", "GB", dec!(50000)),
        ];

        let records = generator.generate(&payments, "US");

        // Only the cross-border payment produces a record.
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].payment_id, "PAY-XB");
    }

    #[test]
    fn test_record_ids_continue_across_calls() {
        let mut generator = WithholdingGenerator::new(7, dec!(0.30)).with_standard_treaties();

        let first = generator.generate(
            &[
                payment("PAY-A", "V-GB-01", "GB", dec!(1000)),
                payment("PAY-B", "V-US-01", "US", dec!(1000)),
                payment("PAY-C", "V-DE-01", "DE", dec!(1000)),
            ],
            "US",
        );
        let second = generator.generate(&[payment("PAY-D", "V-JP-01", "JP", dec!(1000))], "US");

        // The skipped domestic payment must not consume an ID, and the counter
        // must carry over into the second call.
        assert_eq!(first.len(), 2);
        assert_eq!(first[0].id, "WHT-000001");
        assert_eq!(first[1].id, "WHT-000002");
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].id, "WHT-000003");
    }

    #[test]
    fn test_deterministic() {
        let payments = vec![
            payment("PAY-001", "V-GB-01", "GB", dec!(100000)),
            payment("PAY-002", "V-IN-01", "IN", dec!(50000)),
            payment("PAY-003", "V-ZZ-01", "ZZ", dec!(25000)),
        ];

        let mut generator1 = WithholdingGenerator::new(12345, dec!(0.30)).with_standard_treaties();
        let records1 = generator1.generate(&payments, "US");

        let mut generator2 = WithholdingGenerator::new(12345, dec!(0.30)).with_standard_treaties();
        let records2 = generator2.generate(&payments, "US");

        assert_eq!(records1.len(), records2.len());
        for (r1, r2) in records1.iter().zip(records2.iter()) {
            assert_eq!(r1.id, r2.id);
            assert_eq!(r1.payment_id, r2.payment_id);
            assert_eq!(r1.vendor_id, r2.vendor_id);
            assert_eq!(r1.applied_rate, r2.applied_rate);
            assert_eq!(r1.treaty_rate, r2.treaty_rate);
            assert_eq!(r1.withheld_amount, r2.withheld_amount);
            assert_eq!(r1.certificate_number, r2.certificate_number);
        }
    }

    #[test]
    fn test_treaty_with_nonzero_rate() {
        let mut generator = WithholdingGenerator::new(42, dec!(0.30)).with_standard_treaties();

        // US -> IN carries a 15% treaty rate.
        let payments = vec![payment("PAY-IN", "V-IN-01", "IN", dec!(100000))];

        let records = generator.generate(&payments, "US");

        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.applied_rate, dec!(0.15));
        assert_eq!(rec.treaty_rate, Some(dec!(0.15)));
        assert_eq!(rec.withheld_amount, dec!(15000.00));
        assert_eq!(rec.statutory_rate, dec!(0.30));
        assert!(
            rec.has_treaty_benefit(),
            "15% treaty rate is less than 30% statutory"
        );
        assert_eq!(rec.treaty_savings(), dec!(15000.00));
    }

    #[test]
    fn test_custom_treaty_rate() {
        let mut generator = WithholdingGenerator::new(42, dec!(0.25));
        generator.add_treaty_rate("DE", "US", dec!(0.05));

        let payments = vec![payment("PAY-001", "V-US-01", "US", dec!(200000))];

        let records = generator.generate(&payments, "DE");

        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.applied_rate, dec!(0.05));
        assert_eq!(rec.treaty_rate, Some(dec!(0.05)));
        assert_eq!(rec.withheld_amount, dec!(10000.00));
        assert_eq!(rec.statutory_rate, dec!(0.25));
    }
}