Skip to main content

datasynth_generators/tax/
withholding_generator.rs

//! Withholding Tax Generator.
//!
//! Generates [`WithholdingTaxRecord`]s for cross-border payments, applying
//! treaty rates when a bilateral tax treaty exists between the source country
//! and the vendor country, or falling back to a configurable default
//! withholding rate.
7
8use datasynth_core::utils::seeded_rng;
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rust_decimal::Decimal;
12use rust_decimal_macros::dec;
13use std::collections::HashMap;
14
15use datasynth_core::models::{WithholdingTaxRecord, WithholdingType};
16
17// ---------------------------------------------------------------------------
18// Generator
19// ---------------------------------------------------------------------------
20
21/// Generates withholding tax records for cross-border vendor payments.
22///
23/// For each payment where `vendor_country != source_country`:
24/// - Looks up the treaty rate for the `(source_country, vendor_country)` pair.
25/// - If a treaty exists, applies the treaty rate.
26/// - If no treaty exists, applies the default withholding rate.
27/// - Computes `withheld_amount = base_amount * applied_rate`.
28///
29/// Domestic payments (where `vendor_country == source_country`) are excluded
30/// from withholding and produce no records.
31///
32/// # Standard Treaty Rates
33///
34/// The [`with_standard_treaties`](WithholdingGenerator::with_standard_treaties)
35/// method loads a US-centric treaty network with service withholding rates for
36/// major trading partners (GB, DE, JP, FR, SG, IN, BR).
37pub struct WithholdingGenerator {
38    rng: ChaCha8Rng,
39    /// Treaty rates indexed by `(source_country, vendor_country)`.
40    treaty_rates: HashMap<(String, String), Decimal>,
41    /// Default withholding rate when no treaty applies.
42    default_rate: Decimal,
43    counter: u64,
44}
45
46impl WithholdingGenerator {
47    /// Creates a new withholding generator with the given seed and default rate.
48    ///
49    /// The default rate is applied when no treaty rate exists for a given
50    /// country pair. A common value is `0.30` (30%).
51    pub fn new(seed: u64, default_rate: Decimal) -> Self {
52        Self {
53            rng: seeded_rng(seed, 0),
54            treaty_rates: HashMap::new(),
55            default_rate,
56            counter: 0,
57        }
58    }
59
60    /// Adds a treaty rate for a specific country pair.
61    ///
62    /// The rate is stored for the `(source_country, vendor_country)` direction.
63    /// Treaty rates are directional: a US-DE treaty rate applies when the US is
64    /// the source and DE is the vendor, but not vice versa.
65    pub fn add_treaty_rate(&mut self, source_country: &str, vendor_country: &str, rate: Decimal) {
66        self.treaty_rates.insert(
67            (source_country.to_string(), vendor_country.to_string()),
68            rate,
69        );
70    }
71
72    /// Loads the standard US treaty network for service withholding rates.
73    ///
74    /// Treaty rates (service withholding, US perspective):
75    /// - US-GB: 0% (services)
76    /// - US-DE: 0% (services)
77    /// - US-JP: 0% (services)
78    /// - US-FR: 0% (services)
79    /// - US-SG: 0% (services)
80    /// - US-IN: 15% (services)
81    /// - US-BR: 15% (services)
82    pub fn with_standard_treaties(mut self) -> Self {
83        let treaties = [
84            ("US", "GB", dec!(0.00)),
85            ("US", "DE", dec!(0.00)),
86            ("US", "JP", dec!(0.00)),
87            ("US", "FR", dec!(0.00)),
88            ("US", "SG", dec!(0.00)),
89            ("US", "IN", dec!(0.15)),
90            ("US", "BR", dec!(0.15)),
91        ];
92
93        for (source, vendor, rate) in &treaties {
94            self.treaty_rates
95                .insert((source.to_string(), vendor.to_string()), *rate);
96        }
97
98        self
99    }
100
101    /// Generate withholding records for cross-border payments.
102    ///
103    /// Each payment is a tuple of `(payment_id, vendor_id, vendor_country, amount)`.
104    /// Domestic payments (where `vendor_country == source_country`) are excluded.
105    ///
106    /// For each cross-border payment:
107    /// - If a treaty rate exists for `(source_country, vendor_country)`, the
108    ///   treaty rate is applied.
109    /// - Otherwise the `default_rate` is applied.
110    /// - `withheld_amount = base_amount * applied_rate`.
111    pub fn generate(
112        &mut self,
113        payments: &[(String, String, String, Decimal)],
114        source_country: &str,
115    ) -> Vec<WithholdingTaxRecord> {
116        let mut records = Vec::new();
117
118        for (payment_id, vendor_id, vendor_country, amount) in payments {
119            // Skip domestic payments
120            if vendor_country == source_country {
121                continue;
122            }
123
124            let key = (source_country.to_string(), vendor_country.clone());
125            let (applied_rate, treaty_rate) = match self.treaty_rates.get(&key) {
126                Some(&rate) => (rate, Some(rate)),
127                None => (self.default_rate, None),
128            };
129
130            self.counter += 1;
131            let record_id = format!("WHT-{:06}", self.counter);
132
133            // Generate a certificate number with some randomness
134            let cert_suffix: u32 = self.rng.gen_range(100_000..999_999);
135            let cert_number = format!("CERT-{}-{cert_suffix}", &record_id);
136
137            let mut record = WithholdingTaxRecord::new(
138                record_id,
139                payment_id,
140                vendor_id,
141                WithholdingType::ServiceWithholding,
142                self.default_rate,
143                applied_rate,
144                *amount,
145            )
146            .with_certificate_number(cert_number);
147
148            if let Some(rate) = treaty_rate {
149                record = record.with_treaty_rate(rate);
150            }
151
152            records.push(record);
153        }
154
155        records
156    }
157}
158
159// ---------------------------------------------------------------------------
160// Tests
161// ---------------------------------------------------------------------------
162
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a payment tuple in the `(payment_id, vendor_id, vendor_country, amount)`
    /// shape accepted by [`WithholdingGenerator::generate`].
    fn payment(
        id: &str,
        vendor_id: &str,
        vendor_country: &str,
        amount: Decimal,
    ) -> (String, String, String, Decimal) {
        (
            String::from(id),
            String::from(vendor_id),
            String::from(vendor_country),
            amount,
        )
    }

    /// Generator with a 30% default rate and the standard US treaty network.
    fn standard_generator(seed: u64) -> WithholdingGenerator {
        WithholdingGenerator::new(seed, dec!(0.30)).with_standard_treaties()
    }

    #[test]
    fn test_with_treaty_rate() {
        let mut generator = standard_generator(42);
        let payments = [payment("PAY-001", "V-GB-01", "GB", dec!(100000))];

        let records = generator.generate(&payments, "US");
        assert_eq!(records.len(), 1);

        let rec = &records[0];
        assert_eq!(rec.vendor_id, "V-GB-01");
        assert_eq!(rec.applied_rate, dec!(0.00));
        assert_eq!(rec.treaty_rate, Some(dec!(0.00)));
        assert_eq!(rec.withheld_amount, dec!(0.00));
        assert_eq!(rec.statutory_rate, dec!(0.30));
        assert!(rec.has_treaty_benefit());
    }

    #[test]
    fn test_without_treaty() {
        let mut generator = standard_generator(42);
        // ZZ is deliberately absent from the treaty network.
        let payments = [payment("PAY-002", "V-ZZ-01", "ZZ", dec!(50000))];

        let records = generator.generate(&payments, "US");
        assert_eq!(records.len(), 1);

        let rec = &records[0];
        assert_eq!(rec.applied_rate, dec!(0.30));
        assert_eq!(rec.treaty_rate, None);
        assert_eq!(rec.withheld_amount, dec!(15000.00));
        assert!(!rec.has_treaty_benefit());
    }

    #[test]
    fn test_standard_treaties() {
        let generator = standard_generator(42);
        let expected = [
            ("GB", dec!(0.00)),
            ("DE", dec!(0.00)),
            ("JP", dec!(0.00)),
            ("FR", dec!(0.00)),
            ("SG", dec!(0.00)),
            ("IN", dec!(0.15)),
            ("BR", dec!(0.15)),
        ];

        // Exactly the seven standard US pairs are loaded, each with its rate.
        assert_eq!(generator.treaty_rates.len(), 7);
        for &(vendor, rate) in &expected {
            let key = ("US".to_string(), vendor.to_string());
            assert_eq!(generator.treaty_rates.get(&key), Some(&rate));
        }
    }

    #[test]
    fn test_domestic_excluded() {
        let mut generator = standard_generator(42);
        let payments = [
            payment("PAY-DOM", "V-US-01", "US", dec!(100000)),
            payment("PAY-XB", "V-GB-01", "GB", dec!(50000)),
        ];

        let records = generator.generate(&payments, "US");

        // The domestic payment is dropped; only the cross-border one remains.
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].payment_id, "PAY-XB");
    }

    #[test]
    fn test_deterministic() {
        let payments = [
            payment("PAY-001", "V-GB-01", "GB", dec!(100000)),
            payment("PAY-002", "V-IN-01", "IN", dec!(50000)),
            payment("PAY-003", "V-ZZ-01", "ZZ", dec!(25000)),
        ];

        // Two independent generators built from the same seed.
        let run = || standard_generator(12345).generate(&payments, "US");
        let first = run();
        let second = run();

        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(&second) {
            assert_eq!(a.id, b.id);
            assert_eq!(a.payment_id, b.payment_id);
            assert_eq!(a.vendor_id, b.vendor_id);
            assert_eq!(a.applied_rate, b.applied_rate);
            assert_eq!(a.treaty_rate, b.treaty_rate);
            assert_eq!(a.withheld_amount, b.withheld_amount);
            assert_eq!(a.certificate_number, b.certificate_number);
        }
    }

    #[test]
    fn test_treaty_with_nonzero_rate() {
        let mut generator = standard_generator(42);
        // India carries a 15% treaty rate for services.
        let payments = [payment("PAY-IN", "V-IN-01", "IN", dec!(100000))];

        let records = generator.generate(&payments, "US");
        assert_eq!(records.len(), 1);

        let rec = &records[0];
        assert_eq!(rec.applied_rate, dec!(0.15));
        assert_eq!(rec.treaty_rate, Some(dec!(0.15)));
        assert_eq!(rec.withheld_amount, dec!(15000.00));
        assert_eq!(rec.statutory_rate, dec!(0.30));
        assert!(
            rec.has_treaty_benefit(),
            "15% treaty rate is less than 30% statutory"
        );
        assert_eq!(rec.treaty_savings(), dec!(15000.00));
    }

    #[test]
    fn test_custom_treaty_rate() {
        let mut generator = WithholdingGenerator::new(42, dec!(0.25));
        generator.add_treaty_rate("DE", "US", dec!(0.05));
        let payments = [payment("PAY-001", "V-US-01", "US", dec!(200000))];

        let records = generator.generate(&payments, "DE");
        assert_eq!(records.len(), 1);

        let rec = &records[0];
        assert_eq!(rec.applied_rate, dec!(0.05));
        assert_eq!(rec.treaty_rate, Some(dec!(0.05)));
        assert_eq!(rec.withheld_amount, dec!(10000.00));
        assert_eq!(rec.statutory_rate, dec!(0.25));
    }
}