Skip to main content

datasynth_generators/tax/
withholding_generator.rs

1//! Withholding Tax Generator.
2//!
3//! Generates [`WithholdingTaxRecord`]s for cross-border payments, applying
4//! treaty rates when a bilateral tax treaty exists between the source country
5//! and vendor country, or falling back to a configurable default withholding
6//! rate.
7
8use datasynth_core::utils::seeded_rng;
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rust_decimal::Decimal;
12use rust_decimal_macros::dec;
13use std::collections::HashMap;
14
15use datasynth_core::models::{WithholdingTaxRecord, WithholdingType};
16
17// ---------------------------------------------------------------------------
18// Generator
19// ---------------------------------------------------------------------------
20
21/// Generates withholding tax records for cross-border vendor payments.
22///
23/// For each payment where `vendor_country != source_country`:
24/// - Looks up the treaty rate for the `(source_country, vendor_country)` pair.
25/// - If a treaty exists, applies the treaty rate.
26/// - If no treaty exists, applies the default withholding rate.
27/// - Computes `withheld_amount = base_amount * applied_rate`.
28///
29/// Domestic payments (where `vendor_country == source_country`) are excluded
30/// from withholding and produce no records.
31///
32/// # Standard Treaty Rates
33///
34/// The [`with_standard_treaties`](WithholdingGenerator::with_standard_treaties)
35/// method loads a US-centric treaty network with service withholding rates for
36/// major trading partners (GB, DE, JP, FR, SG, IN, BR).
37pub struct WithholdingGenerator {
38    rng: ChaCha8Rng,
39    /// Treaty rates indexed by `(source_country, vendor_country)`.
40    treaty_rates: HashMap<(String, String), Decimal>,
41    /// Default withholding rate when no treaty applies.
42    default_rate: Decimal,
43    counter: u64,
44}
45
46impl WithholdingGenerator {
47    /// Creates a new withholding generator with the given seed and default rate.
48    ///
49    /// The default rate is applied when no treaty rate exists for a given
50    /// country pair. A common value is `0.30` (30%).
51    pub fn new(seed: u64, default_rate: Decimal) -> Self {
52        Self {
53            rng: seeded_rng(seed, 0),
54            treaty_rates: HashMap::new(),
55            default_rate,
56            counter: 0,
57        }
58    }
59
60    /// Adds a treaty rate for a specific country pair.
61    ///
62    /// The rate is stored for the `(source_country, vendor_country)` direction.
63    /// Treaty rates are directional: a US-DE treaty rate applies when the US is
64    /// the source and DE is the vendor, but not vice versa.
65    pub fn add_treaty_rate(&mut self, source_country: &str, vendor_country: &str, rate: Decimal) {
66        self.treaty_rates.insert(
67            (source_country.to_string(), vendor_country.to_string()),
68            rate,
69        );
70    }
71
72    /// Loads the standard US treaty network for service withholding rates.
73    ///
74    /// Treaty rates (service withholding, US perspective):
75    /// - US-GB: 0% (services)
76    /// - US-DE: 0% (services)
77    /// - US-JP: 0% (services)
78    /// - US-FR: 0% (services)
79    /// - US-SG: 0% (services)
80    /// - US-IN: 15% (services)
81    /// - US-BR: 15% (services)
82    pub fn with_standard_treaties(mut self) -> Self {
83        let treaties = [
84            ("US", "GB", dec!(0.00)),
85            ("US", "DE", dec!(0.00)),
86            ("US", "JP", dec!(0.00)),
87            ("US", "FR", dec!(0.00)),
88            ("US", "SG", dec!(0.00)),
89            ("US", "IN", dec!(0.15)),
90            ("US", "BR", dec!(0.15)),
91        ];
92
93        for (source, vendor, rate) in &treaties {
94            self.treaty_rates
95                .insert((source.to_string(), vendor.to_string()), *rate);
96        }
97
98        self
99    }
100
101    /// Generate withholding records for cross-border payments.
102    ///
103    /// Each payment is a tuple of `(payment_id, vendor_id, vendor_country, amount)`.
104    /// Domestic payments (where `vendor_country == source_country`) are excluded.
105    ///
106    /// For each cross-border payment:
107    /// - If a treaty rate exists for `(source_country, vendor_country)`, the
108    ///   treaty rate is applied.
109    /// - Otherwise the `default_rate` is applied.
110    /// - `withheld_amount = base_amount * applied_rate`.
111    pub fn generate(
112        &mut self,
113        payments: &[(String, String, String, Decimal)],
114        source_country: &str,
115    ) -> Vec<WithholdingTaxRecord> {
116        let mut records = Vec::new();
117
118        for (payment_id, vendor_id, vendor_country, amount) in payments {
119            // Skip domestic payments
120            if vendor_country == source_country {
121                continue;
122            }
123
124            let key = (source_country.to_string(), vendor_country.clone());
125            let (applied_rate, treaty_rate) = match self.treaty_rates.get(&key) {
126                Some(&rate) => (rate, Some(rate)),
127                None => (self.default_rate, None),
128            };
129
130            self.counter += 1;
131            let record_id = format!("WHT-{:06}", self.counter);
132
133            // Generate a certificate number with some randomness
134            let cert_suffix: u32 = self.rng.random_range(100_000..999_999);
135            let cert_number = format!("CERT-{}-{cert_suffix}", &record_id);
136
137            let mut record = WithholdingTaxRecord::new(
138                record_id,
139                payment_id,
140                vendor_id,
141                WithholdingType::ServiceWithholding,
142                self.default_rate,
143                applied_rate,
144                *amount,
145            )
146            .with_certificate_number(cert_number);
147
148            if let Some(rate) = treaty_rate {
149                record = record.with_treaty_rate(rate);
150            }
151
152            records.push(record);
153        }
154
155        records
156    }
157}
158
159// ---------------------------------------------------------------------------
160// Tests
161// ---------------------------------------------------------------------------
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `(payment_id, vendor_id, vendor_country, amount)` tuple in the
    /// shape accepted by `WithholdingGenerator::generate`.
    fn payment(
        id: &str,
        vendor_id: &str,
        vendor_country: &str,
        amount: Decimal,
    ) -> (String, String, String, Decimal) {
        (
            id.to_string(),
            vendor_id.to_string(),
            vendor_country.to_string(),
            amount,
        )
    }

    /// A generator with a 30% default rate and the standard US treaty table.
    ///
    /// Locals are named `generator` rather than `gen` because `gen` is a
    /// reserved keyword in the Rust 2024 edition.
    fn standard_generator(seed: u64) -> WithholdingGenerator {
        WithholdingGenerator::new(seed, dec!(0.30)).with_standard_treaties()
    }

    #[test]
    fn test_with_treaty_rate() {
        let mut generator = standard_generator(42);

        let payments = vec![payment("PAY-001", "V-GB-01", "GB", dec!(100000))];

        let records = generator.generate(&payments, "US");

        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.vendor_id, "V-GB-01");
        assert_eq!(rec.applied_rate, dec!(0.00));
        assert_eq!(rec.treaty_rate, Some(dec!(0.00)));
        assert_eq!(rec.withheld_amount, dec!(0.00));
        assert_eq!(rec.statutory_rate, dec!(0.30));
        assert!(rec.has_treaty_benefit());
    }

    #[test]
    fn test_without_treaty() {
        let mut generator = standard_generator(42);

        // ZZ is not in the treaty network
        let payments = vec![payment("PAY-002", "V-ZZ-01", "ZZ", dec!(50000))];

        let records = generator.generate(&payments, "US");

        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.applied_rate, dec!(0.30));
        assert_eq!(rec.treaty_rate, None);
        assert_eq!(rec.withheld_amount, dec!(15000.00));
        assert!(!rec.has_treaty_benefit());
    }

    #[test]
    fn test_standard_treaties() {
        let generator = standard_generator(42);

        // Every (US, vendor) pair the standard network is expected to contain.
        let expected = [
            ("GB", dec!(0.00)),
            ("DE", dec!(0.00)),
            ("JP", dec!(0.00)),
            ("FR", dec!(0.00)),
            ("SG", dec!(0.00)),
            ("IN", dec!(0.15)),
            ("BR", dec!(0.15)),
        ];

        // Verify all standard treaty rates are loaded, and nothing else.
        assert_eq!(generator.treaty_rates.len(), 7);
        assert_eq!(generator.treaty_rates.len(), expected.len());

        for (vendor, rate) in expected {
            assert_eq!(
                generator
                    .treaty_rates
                    .get(&("US".to_string(), vendor.to_string())),
                Some(&rate),
                "unexpected treaty rate for US-{vendor}"
            );
        }
    }

    #[test]
    fn test_domestic_excluded() {
        let mut generator = standard_generator(42);

        let payments = vec![
            payment("PAY-DOM", "V-US-01", "US", dec!(100000)),
            payment("PAY-XB", "V-GB-01", "GB", dec!(50000)),
        ];

        let records = generator.generate(&payments, "US");

        // Only the cross-border payment should produce a record
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].payment_id, "PAY-XB");
    }

    #[test]
    fn test_deterministic() {
        let payments = vec![
            payment("PAY-001", "V-GB-01", "GB", dec!(100000)),
            payment("PAY-002", "V-IN-01", "IN", dec!(50000)),
            payment("PAY-003", "V-ZZ-01", "ZZ", dec!(25000)),
        ];

        let records1 = standard_generator(12345).generate(&payments, "US");
        let records2 = standard_generator(12345).generate(&payments, "US");

        assert_eq!(records1.len(), records2.len());
        for (r1, r2) in records1.iter().zip(records2.iter()) {
            assert_eq!(r1.id, r2.id);
            assert_eq!(r1.payment_id, r2.payment_id);
            assert_eq!(r1.vendor_id, r2.vendor_id);
            assert_eq!(r1.applied_rate, r2.applied_rate);
            assert_eq!(r1.treaty_rate, r2.treaty_rate);
            assert_eq!(r1.withheld_amount, r2.withheld_amount);
            assert_eq!(r1.certificate_number, r2.certificate_number);
        }
    }

    #[test]
    fn test_treaty_with_nonzero_rate() {
        let mut generator = standard_generator(42);

        // India has a 15% treaty rate for services
        let payments = vec![payment("PAY-IN", "V-IN-01", "IN", dec!(100000))];

        let records = generator.generate(&payments, "US");

        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.applied_rate, dec!(0.15));
        assert_eq!(rec.treaty_rate, Some(dec!(0.15)));
        assert_eq!(rec.withheld_amount, dec!(15000.00));
        assert_eq!(rec.statutory_rate, dec!(0.30));
        assert!(
            rec.has_treaty_benefit(),
            "15% treaty rate is less than 30% statutory"
        );
        assert_eq!(rec.treaty_savings(), dec!(15000.00));
    }

    #[test]
    fn test_custom_treaty_rate() {
        let mut generator = WithholdingGenerator::new(42, dec!(0.25));
        generator.add_treaty_rate("DE", "US", dec!(0.05));

        let payments = vec![payment("PAY-001", "V-US-01", "US", dec!(200000))];

        let records = generator.generate(&payments, "DE");

        assert_eq!(records.len(), 1);
        let rec = &records[0];
        assert_eq!(rec.applied_rate, dec!(0.05));
        assert_eq!(rec.treaty_rate, Some(dec!(0.05)));
        assert_eq!(rec.withheld_amount, dec!(10000.00));
        assert_eq!(rec.statutory_rate, dec!(0.25));
    }
}