Skip to main content

datasynth_generators/tax/
withholding_generator.rs

//! Withholding Tax Generator.
//!
//! Generates [`WithholdingTaxRecord`]s for cross-border payments, applying the
//! treaty rate when one is registered for the `(source country, vendor country)`
//! pair (rates are directional), or falling back to a configurable default
//! withholding rate.
7
8use rand::prelude::*;
9use rand_chacha::ChaCha8Rng;
10use rust_decimal::Decimal;
11use rust_decimal_macros::dec;
12use std::collections::HashMap;
13
14use datasynth_core::models::{WithholdingTaxRecord, WithholdingType};
15
16// ---------------------------------------------------------------------------
17// Generator
18// ---------------------------------------------------------------------------
19
20/// Generates withholding tax records for cross-border vendor payments.
21///
22/// For each payment where `vendor_country != source_country`:
23/// - Looks up the treaty rate for the `(source_country, vendor_country)` pair.
24/// - If a treaty exists, applies the treaty rate.
25/// - If no treaty exists, applies the default withholding rate.
26/// - Computes `withheld_amount = base_amount * applied_rate`.
27///
28/// Domestic payments (where `vendor_country == source_country`) are excluded
29/// from withholding and produce no records.
30///
31/// # Standard Treaty Rates
32///
33/// The [`with_standard_treaties`](WithholdingGenerator::with_standard_treaties)
34/// method loads a US-centric treaty network with service withholding rates for
35/// major trading partners (GB, DE, JP, FR, SG, IN, BR).
36pub struct WithholdingGenerator {
37    rng: ChaCha8Rng,
38    /// Treaty rates indexed by `(source_country, vendor_country)`.
39    treaty_rates: HashMap<(String, String), Decimal>,
40    /// Default withholding rate when no treaty applies.
41    default_rate: Decimal,
42    counter: u64,
43}
44
45impl WithholdingGenerator {
46    /// Creates a new withholding generator with the given seed and default rate.
47    ///
48    /// The default rate is applied when no treaty rate exists for a given
49    /// country pair. A common value is `0.30` (30%).
50    pub fn new(seed: u64, default_rate: Decimal) -> Self {
51        Self {
52            rng: ChaCha8Rng::seed_from_u64(seed),
53            treaty_rates: HashMap::new(),
54            default_rate,
55            counter: 0,
56        }
57    }
58
59    /// Adds a treaty rate for a specific country pair.
60    ///
61    /// The rate is stored for the `(source_country, vendor_country)` direction.
62    /// Treaty rates are directional: a US-DE treaty rate applies when the US is
63    /// the source and DE is the vendor, but not vice versa.
64    pub fn add_treaty_rate(&mut self, source_country: &str, vendor_country: &str, rate: Decimal) {
65        self.treaty_rates.insert(
66            (source_country.to_string(), vendor_country.to_string()),
67            rate,
68        );
69    }
70
71    /// Loads the standard US treaty network for service withholding rates.
72    ///
73    /// Treaty rates (service withholding, US perspective):
74    /// - US-GB: 0% (services)
75    /// - US-DE: 0% (services)
76    /// - US-JP: 0% (services)
77    /// - US-FR: 0% (services)
78    /// - US-SG: 0% (services)
79    /// - US-IN: 15% (services)
80    /// - US-BR: 15% (services)
81    pub fn with_standard_treaties(mut self) -> Self {
82        let treaties = [
83            ("US", "GB", dec!(0.00)),
84            ("US", "DE", dec!(0.00)),
85            ("US", "JP", dec!(0.00)),
86            ("US", "FR", dec!(0.00)),
87            ("US", "SG", dec!(0.00)),
88            ("US", "IN", dec!(0.15)),
89            ("US", "BR", dec!(0.15)),
90        ];
91
92        for (source, vendor, rate) in &treaties {
93            self.treaty_rates
94                .insert((source.to_string(), vendor.to_string()), *rate);
95        }
96
97        self
98    }
99
100    /// Generate withholding records for cross-border payments.
101    ///
102    /// Each payment is a tuple of `(payment_id, vendor_id, vendor_country, amount)`.
103    /// Domestic payments (where `vendor_country == source_country`) are excluded.
104    ///
105    /// For each cross-border payment:
106    /// - If a treaty rate exists for `(source_country, vendor_country)`, the
107    ///   treaty rate is applied.
108    /// - Otherwise the `default_rate` is applied.
109    /// - `withheld_amount = base_amount * applied_rate`.
110    pub fn generate(
111        &mut self,
112        payments: &[(String, String, String, Decimal)],
113        source_country: &str,
114    ) -> Vec<WithholdingTaxRecord> {
115        let mut records = Vec::new();
116
117        for (payment_id, vendor_id, vendor_country, amount) in payments {
118            // Skip domestic payments
119            if vendor_country == source_country {
120                continue;
121            }
122
123            let key = (source_country.to_string(), vendor_country.clone());
124            let (applied_rate, treaty_rate) = match self.treaty_rates.get(&key) {
125                Some(&rate) => (rate, Some(rate)),
126                None => (self.default_rate, None),
127            };
128
129            self.counter += 1;
130            let record_id = format!("WHT-{:06}", self.counter);
131
132            // Generate a certificate number with some randomness
133            let cert_suffix: u32 = self.rng.gen_range(100_000..999_999);
134            let cert_number = format!("CERT-{}-{cert_suffix}", &record_id);
135
136            let mut record = WithholdingTaxRecord::new(
137                record_id,
138                payment_id,
139                vendor_id,
140                WithholdingType::ServiceWithholding,
141                self.default_rate,
142                applied_rate,
143                *amount,
144            )
145            .with_certificate_number(cert_number);
146
147            if let Some(rate) = treaty_rate {
148                record = record.with_treaty_rate(rate);
149            }
150
151            records.push(record);
152        }
153
154        records
155    }
156}
157
158// ---------------------------------------------------------------------------
159// Tests
160// ---------------------------------------------------------------------------
161
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Payment tuple shape accepted by `WithholdingGenerator::generate`.
    type Payment = (String, String, String, Decimal);

    /// Builds an owned payment tuple from borrowed parts.
    fn payment(id: &str, vendor_id: &str, vendor_country: &str, amount: Decimal) -> Payment {
        (
            id.to_owned(),
            vendor_id.to_owned(),
            vendor_country.to_owned(),
            amount,
        )
    }

    /// Generator with a 30% default rate and the standard US treaty table.
    fn standard_generator(seed: u64) -> WithholdingGenerator {
        WithholdingGenerator::new(seed, dec!(0.30)).with_standard_treaties()
    }

    #[test]
    fn test_with_treaty_rate() {
        let mut generator = standard_generator(42);
        let batch = [payment("PAY-001", "V-GB-01", "GB", dec!(100000))];

        let out = generator.generate(&batch, "US");
        assert_eq!(out.len(), 1);

        // GB is covered by a 0% treaty rate, so nothing is withheld.
        let record = &out[0];
        assert_eq!(record.vendor_id, "V-GB-01");
        assert_eq!(record.applied_rate, dec!(0.00));
        assert_eq!(record.treaty_rate, Some(dec!(0.00)));
        assert_eq!(record.withheld_amount, dec!(0.00));
        assert_eq!(record.statutory_rate, dec!(0.30));
        assert!(record.has_treaty_benefit());
    }

    #[test]
    fn test_without_treaty() {
        let mut generator = standard_generator(42);
        // ZZ is not in the treaty network
        let batch = [payment("PAY-002", "V-ZZ-01", "ZZ", dec!(50000))];

        let out = generator.generate(&batch, "US");
        assert_eq!(out.len(), 1);

        // With no treaty, the 30% default rate applies.
        let record = &out[0];
        assert_eq!(record.applied_rate, dec!(0.30));
        assert_eq!(record.treaty_rate, None);
        assert_eq!(record.withheld_amount, dec!(15000.00));
        assert!(!record.has_treaty_benefit());
    }

    #[test]
    fn test_standard_treaties() {
        let generator = standard_generator(42);

        // Every standard treaty is keyed with the US as source country.
        let expected = [
            ("GB", dec!(0.00)),
            ("DE", dec!(0.00)),
            ("JP", dec!(0.00)),
            ("FR", dec!(0.00)),
            ("SG", dec!(0.00)),
            ("IN", dec!(0.15)),
            ("BR", dec!(0.15)),
        ];

        // Verify all standard treaty rates are loaded
        assert_eq!(generator.treaty_rates.len(), 7);

        for (vendor, rate) in &expected {
            let key = ("US".to_string(), vendor.to_string());
            assert_eq!(generator.treaty_rates.get(&key), Some(rate));
        }
    }

    #[test]
    fn test_domestic_excluded() {
        let mut generator = standard_generator(42);
        let batch = [
            payment("PAY-DOM", "V-US-01", "US", dec!(100000)),
            payment("PAY-XB", "V-GB-01", "GB", dec!(50000)),
        ];

        let out = generator.generate(&batch, "US");

        // Only the cross-border payment should produce a record
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].payment_id, "PAY-XB");
    }

    #[test]
    fn test_deterministic() {
        let batch = [
            payment("PAY-001", "V-GB-01", "GB", dec!(100000)),
            payment("PAY-002", "V-IN-01", "IN", dec!(50000)),
            payment("PAY-003", "V-ZZ-01", "ZZ", dec!(25000)),
        ];

        // Two independent generators with the same seed must agree record by record.
        let first = standard_generator(12345).generate(&batch, "US");
        let second = standard_generator(12345).generate(&batch, "US");

        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(&second) {
            assert_eq!(a.id, b.id);
            assert_eq!(a.payment_id, b.payment_id);
            assert_eq!(a.vendor_id, b.vendor_id);
            assert_eq!(a.applied_rate, b.applied_rate);
            assert_eq!(a.treaty_rate, b.treaty_rate);
            assert_eq!(a.withheld_amount, b.withheld_amount);
            assert_eq!(a.certificate_number, b.certificate_number);
        }
    }

    #[test]
    fn test_treaty_with_nonzero_rate() {
        let mut generator = standard_generator(42);
        // India has a 15% treaty rate for services
        let batch = [payment("PAY-IN", "V-IN-01", "IN", dec!(100000))];

        let out = generator.generate(&batch, "US");
        assert_eq!(out.len(), 1);

        let record = &out[0];
        assert_eq!(record.applied_rate, dec!(0.15));
        assert_eq!(record.treaty_rate, Some(dec!(0.15)));
        assert_eq!(record.withheld_amount, dec!(15000.00));
        assert_eq!(record.statutory_rate, dec!(0.30));
        assert!(
            record.has_treaty_benefit(),
            "15% treaty rate is less than 30% statutory"
        );
        assert_eq!(record.treaty_savings(), dec!(15000.00));
    }

    #[test]
    fn test_custom_treaty_rate() {
        // No standard treaties here: only the single DE -> US rate is registered.
        let mut generator = WithholdingGenerator::new(42, dec!(0.25));
        generator.add_treaty_rate("DE", "US", dec!(0.05));

        let batch = [payment("PAY-001", "V-US-01", "US", dec!(200000))];

        let out = generator.generate(&batch, "DE");
        assert_eq!(out.len(), 1);

        let record = &out[0];
        assert_eq!(record.applied_rate, dec!(0.05));
        assert_eq!(record.treaty_rate, Some(dec!(0.05)));
        assert_eq!(record.withheld_amount, dec!(10000.00));
        assert_eq!(record.statutory_rate, dec!(0.25));
    }
}