Skip to main content

datasynth_core/distributions/
amount.rs

1//! Transaction amount distribution sampler.
2//!
3//! Generates realistic transaction amounts using log-normal distributions
4//! and round-number bias commonly observed in accounting data.
5
6use rand::prelude::*;
7use rand_chacha::ChaCha8Rng;
8use rand_distr::{Distribution, LogNormal};
9use rust_decimal::Decimal;
10use serde::{Deserialize, Serialize};
11
12use super::benford::{BenfordSampler, FraudAmountGenerator, FraudAmountPattern, ThresholdConfig};
13
14/// Configuration for amount distribution.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AmountDistributionConfig {
17    /// Minimum transaction amount
18    pub min_amount: f64,
19    /// Maximum transaction amount
20    pub max_amount: f64,
21    /// Log-normal mu parameter (location)
22    pub lognormal_mu: f64,
23    /// Log-normal sigma parameter (scale)
24    pub lognormal_sigma: f64,
25    /// Number of decimal places to round to
26    pub decimal_places: u8,
27    /// Probability of round number (ending in .00)
28    pub round_number_probability: f64,
29    /// Probability of nice number (ending in 0 or 5)
30    pub nice_number_probability: f64,
31}
32
33impl Default for AmountDistributionConfig {
34    fn default() -> Self {
35        Self {
36            min_amount: 0.01,
37            max_amount: 100_000_000.0, // 100 million
38            lognormal_mu: 7.0,         // Center around ~1000
39            lognormal_sigma: 2.5,      // Wide spread
40            decimal_places: 2,
41            round_number_probability: 0.25, // 25% chance of .00 ending
42            nice_number_probability: 0.15,  // 15% chance of nice numbers
43        }
44    }
45}
46
47impl AmountDistributionConfig {
48    /// Configuration for small transactions (e.g., retail).
49    pub fn small_transactions() -> Self {
50        Self {
51            min_amount: 0.01,
52            max_amount: 10_000.0,
53            lognormal_mu: 4.0, // Center around ~55
54            lognormal_sigma: 1.5,
55            decimal_places: 2,
56            round_number_probability: 0.30,
57            nice_number_probability: 0.20,
58        }
59    }
60
61    /// Configuration for medium transactions (e.g., B2B).
62    pub fn medium_transactions() -> Self {
63        Self {
64            min_amount: 100.0,
65            max_amount: 1_000_000.0,
66            lognormal_mu: 8.5, // Center around ~5000
67            lognormal_sigma: 2.0,
68            decimal_places: 2,
69            round_number_probability: 0.20,
70            nice_number_probability: 0.15,
71        }
72    }
73
74    /// Configuration for large transactions (e.g., enterprise).
75    pub fn large_transactions() -> Self {
76        Self {
77            min_amount: 1000.0,
78            max_amount: 100_000_000.0,
79            lognormal_mu: 10.0, // Center around ~22000
80            lognormal_sigma: 2.5,
81            decimal_places: 2,
82            round_number_probability: 0.15,
83            nice_number_probability: 0.10,
84        }
85    }
86}
87
88/// Sampler for realistic transaction amounts.
89pub struct AmountSampler {
90    /// RNG for sampling
91    rng: ChaCha8Rng,
92    /// Configuration
93    config: AmountDistributionConfig,
94    /// Log-normal distribution
95    lognormal: LogNormal<f64>,
96    /// Decimal multiplier for rounding
97    decimal_multiplier: f64,
98    /// Optional Benford sampler for compliant generation
99    benford_sampler: Option<BenfordSampler>,
100    /// Optional fraud amount generator
101    fraud_generator: Option<FraudAmountGenerator>,
102    /// Whether Benford's Law compliance is enabled
103    benford_enabled: bool,
104}
105
106impl AmountSampler {
107    /// Create a new sampler with default configuration.
108    pub fn new(seed: u64) -> Self {
109        Self::with_config(seed, AmountDistributionConfig::default())
110    }
111
112    /// Create a sampler with custom configuration.
113    pub fn with_config(seed: u64, config: AmountDistributionConfig) -> Self {
114        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
115            .expect("Invalid log-normal parameters");
116        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
117
118        Self {
119            rng: ChaCha8Rng::seed_from_u64(seed),
120            config,
121            lognormal,
122            decimal_multiplier,
123            benford_sampler: None,
124            fraud_generator: None,
125            benford_enabled: false,
126        }
127    }
128
129    /// Create a sampler with Benford's Law compliance enabled.
130    pub fn with_benford(seed: u64, config: AmountDistributionConfig) -> Self {
131        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
132            .expect("Invalid log-normal parameters");
133        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
134
135        Self {
136            rng: ChaCha8Rng::seed_from_u64(seed),
137            benford_sampler: Some(BenfordSampler::new(seed + 100, config.clone())),
138            fraud_generator: Some(FraudAmountGenerator::new(
139                seed + 200,
140                config.clone(),
141                ThresholdConfig::default(),
142            )),
143            config,
144            lognormal,
145            decimal_multiplier,
146            benford_enabled: true,
147        }
148    }
149
150    /// Create a sampler with full fraud configuration.
151    pub fn with_fraud_config(
152        seed: u64,
153        config: AmountDistributionConfig,
154        threshold_config: ThresholdConfig,
155        benford_enabled: bool,
156    ) -> Self {
157        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
158            .expect("Invalid log-normal parameters");
159        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
160
161        Self {
162            rng: ChaCha8Rng::seed_from_u64(seed),
163            benford_sampler: if benford_enabled {
164                Some(BenfordSampler::new(seed + 100, config.clone()))
165            } else {
166                None
167            },
168            fraud_generator: Some(FraudAmountGenerator::new(
169                seed + 200,
170                config.clone(),
171                threshold_config,
172            )),
173            config,
174            lognormal,
175            decimal_multiplier,
176            benford_enabled,
177        }
178    }
179
180    /// Enable or disable Benford's Law compliance.
181    pub fn set_benford_enabled(&mut self, enabled: bool) {
182        self.benford_enabled = enabled;
183        if enabled && self.benford_sampler.is_none() {
184            // Initialize Benford sampler if not already present
185            let seed = self.rng.random();
186            self.benford_sampler = Some(BenfordSampler::new(seed, self.config.clone()));
187        }
188    }
189
190    /// Check if Benford's Law compliance is enabled.
191    pub fn is_benford_enabled(&self) -> bool {
192        self.benford_enabled
193    }
194
195    /// Sample a single amount.
196    ///
197    /// If Benford's Law compliance is enabled, uses the Benford sampler.
198    /// Otherwise uses log-normal distribution with round-number bias.
199    #[inline]
200    pub fn sample(&mut self) -> Decimal {
201        // Use Benford sampler if enabled
202        if self.benford_enabled {
203            if let Some(ref mut benford) = self.benford_sampler {
204                return benford.sample();
205            }
206        }
207
208        // Fall back to log-normal sampling
209        self.sample_lognormal()
210    }
211
212    /// Sample using the log-normal distribution (original behavior).
213    #[inline]
214    pub fn sample_lognormal(&mut self) -> Decimal {
215        let mut amount = self.lognormal.sample(&mut self.rng);
216
217        // Clamp to configured range
218        amount = amount.clamp(self.config.min_amount, self.config.max_amount);
219
220        // Apply round number bias
221        let p: f64 = self.rng.random();
222        if p < self.config.round_number_probability {
223            // Round to nearest whole number ending in 00
224            amount = (amount / 100.0).round() * 100.0;
225        } else if p < self.config.round_number_probability + self.config.nice_number_probability {
226            // Round to nearest 5 or 10
227            amount = (amount / 5.0).round() * 5.0;
228        }
229
230        // Round to configured decimal places
231        amount = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
232
233        // Ensure minimum after rounding
234        amount = amount.max(self.config.min_amount);
235
236        // Convert to Decimal using fast integer math instead of string formatting.
237        // Multiply by 100, truncate to integer, then construct Decimal with scale 2.
238        // This avoids the overhead of format!() + parse() (~15x faster).
239        let cents = (amount * 100.0).round() as i64;
240        Decimal::new(cents, 2)
241    }
242
243    /// Sample a fraud amount with the specified pattern.
244    ///
245    /// Returns a normal amount if fraud generator is not configured.
246    pub fn sample_fraud(&mut self, pattern: FraudAmountPattern) -> Decimal {
247        if let Some(ref mut fraud_gen) = self.fraud_generator {
248            fraud_gen.sample(pattern)
249        } else {
250            // Fallback to normal sampling
251            self.sample()
252        }
253    }
254
255    /// Sample multiple amounts that sum to a target total.
256    ///
257    /// Useful for generating line items that must balance.
258    /// The sum of returned amounts is guaranteed to equal `total` exactly.
259    /// Every returned amount is guaranteed to be > 0 when `total > 0` and
260    /// `count * 0.01 <= total`.
261    pub fn sample_summing_to(&mut self, count: usize, total: Decimal) -> Vec<Decimal> {
262        use rust_decimal::prelude::ToPrimitive;
263
264        let min_amount = Decimal::new(1, 2); // 0.01
265
266        if count == 0 {
267            return Vec::new();
268        }
269        if count == 1 {
270            return vec![total];
271        }
272
273        let total_f64 = total.to_f64().unwrap_or(0.0);
274
275        // Generate random weights ensuring minimum weight
276        let mut weights: Vec<f64> = (0..count)
277            .map(|_| self.rng.random::<f64>().max(0.01))
278            .collect();
279        let sum: f64 = weights.iter().sum();
280        weights.iter_mut().for_each(|w| *w /= sum);
281
282        // Calculate amounts based on weights, using fast integer math for precision
283        let mut amounts: Vec<Decimal> = weights
284            .iter()
285            .map(|w| {
286                let amount = total_f64 * w;
287                let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
288                // Convert via integer cents — avoids format!()/parse() overhead
289                let cents = (rounded * 100.0).round() as i64;
290                Decimal::new(cents, 2)
291            })
292            .collect();
293
294        // Adjust last amount to ensure exact sum
295        let current_sum: Decimal = amounts.iter().copied().sum();
296        let diff = total - current_sum;
297        let last_idx = amounts.len() - 1;
298        amounts[last_idx] += diff;
299
300        // If last amount became negative (rare edge case), redistribute
301        if amounts[last_idx] < Decimal::ZERO {
302            let mut remaining = amounts[last_idx].abs();
303            amounts[last_idx] = Decimal::ZERO;
304
305            // Distribute the negative amount across all earlier amounts
306            for amt in amounts.iter_mut().take(last_idx).rev() {
307                if remaining <= Decimal::ZERO {
308                    break;
309                }
310                let take = remaining.min(*amt);
311                *amt -= take;
312                remaining -= take;
313            }
314
315            // If still remaining, absorb into the first non-zero amount
316            if remaining > Decimal::ZERO {
317                for amt in amounts.iter_mut() {
318                    if *amt > Decimal::ZERO {
319                        *amt -= remaining;
320                        break;
321                    }
322                }
323            }
324        }
325
326        // Post-process: fix zero-amount lines by transferring min_amount from the
327        // largest line. This preserves the exact sum while eliminating zeros.
328        // Only attempt when total is large enough to support min_amount per line.
329        if total >= min_amount * Decimal::from(count as u32) {
330            loop {
331                // Find a zero-amount line
332                let zero_idx = amounts.iter().position(|a| *a == Decimal::ZERO);
333                let Some(zi) = zero_idx else { break };
334
335                // Find the largest amount (must be > min_amount to donate)
336                let donor = amounts
337                    .iter()
338                    .enumerate()
339                    .filter(|&(j, _)| j != zi)
340                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
341                    .map(|(j, _)| j);
342
343                if let Some(di) = donor {
344                    if amounts[di] > min_amount {
345                        amounts[zi] = min_amount;
346                        amounts[di] -= min_amount;
347                    } else {
348                        break; // No donor has enough headroom
349                    }
350                } else {
351                    break;
352                }
353            }
354        }
355
356        amounts
357    }
358
359    /// Sample an amount within a specific range.
360    pub fn sample_in_range(&mut self, min: Decimal, max: Decimal) -> Decimal {
361        let min_f64 = min.to_string().parse::<f64>().unwrap_or(0.0);
362        let max_f64 = max.to_string().parse::<f64>().unwrap_or(1000000.0);
363
364        let range = max_f64 - min_f64;
365        let amount = min_f64 + self.rng.random::<f64>() * range;
366
367        let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
368        Decimal::from_f64_retain(rounded).unwrap_or(min)
369    }
370
371    /// Reset the sampler with a new seed.
372    pub fn reset(&mut self, seed: u64) {
373        self.rng = ChaCha8Rng::seed_from_u64(seed);
374    }
375}
376
377/// Sampler for currency exchange rates.
378pub struct ExchangeRateSampler {
379    rng: ChaCha8Rng,
380    /// Base rates for common currency pairs (vs USD)
381    base_rates: std::collections::HashMap<String, f64>,
382    /// Daily volatility (standard deviation)
383    volatility: f64,
384}
385
386impl ExchangeRateSampler {
387    /// Create a new exchange rate sampler.
388    pub fn new(seed: u64) -> Self {
389        let mut base_rates = std::collections::HashMap::new();
390        // Approximate rates as of 2024
391        base_rates.insert("EUR".to_string(), 0.92);
392        base_rates.insert("GBP".to_string(), 0.79);
393        base_rates.insert("CHF".to_string(), 0.88);
394        base_rates.insert("JPY".to_string(), 149.0);
395        base_rates.insert("CNY".to_string(), 7.24);
396        base_rates.insert("CAD".to_string(), 1.36);
397        base_rates.insert("AUD".to_string(), 1.53);
398        base_rates.insert("INR".to_string(), 83.0);
399        base_rates.insert("USD".to_string(), 1.0);
400
401        Self {
402            rng: ChaCha8Rng::seed_from_u64(seed),
403            base_rates,
404            volatility: 0.005, // 0.5% daily volatility
405        }
406    }
407
408    /// Get exchange rate from one currency to another.
409    pub fn get_rate(&mut self, from: &str, to: &str) -> Decimal {
410        let from_usd = self.base_rates.get(from).copied().unwrap_or(1.0);
411        let to_usd = self.base_rates.get(to).copied().unwrap_or(1.0);
412
413        // Base rate
414        let base_rate = to_usd / from_usd;
415
416        // Add some random variation
417        let variation = 1.0 + (self.rng.random::<f64>() - 0.5) * 2.0 * self.volatility;
418        let rate = base_rate * variation;
419
420        // Round to 6 decimal places
421        let rounded = (rate * 1_000_000.0).round() / 1_000_000.0;
422        Decimal::from_f64_retain(rounded).unwrap_or(Decimal::ONE)
423    }
424}
425
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Two samplers with the same seed must produce identical sequences.
    #[test]
    fn test_amount_sampler_determinism() {
        let (mut first, mut second) = (AmountSampler::new(42), AmountSampler::new(42));
        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    /// Sampled amounts must stay inside the configured range.
    #[test]
    fn test_amount_sampler_range() {
        let mut sampler = AmountSampler::with_config(
            42,
            AmountDistributionConfig {
                min_amount: 100.0,
                max_amount: 1000.0,
                ..Default::default()
            },
        );
        let (lower, upper) = (Decimal::from(100), Decimal::from(1000));

        for _ in 0..1000 {
            let amount = sampler.sample();
            assert!(amount >= lower, "Amount {} below minimum", amount);
            assert!(amount <= upper, "Amount {} above maximum", amount);
        }
    }

    /// Split amounts must add back up to the requested total exactly.
    #[test]
    fn test_summing_amounts() {
        let total = Decimal::from(10000);
        let amounts = AmountSampler::new(42).sample_summing_to(5, total);

        assert_eq!(amounts.len(), 5);

        let sum = amounts.iter().fold(Decimal::ZERO, |acc, amount| acc + *amount);
        assert_eq!(sum, total, "Sum {} doesn't match total {}", sum, total);
    }

    /// Quoted rates stay near their reference values.
    #[test]
    fn test_exchange_rate() {
        let mut sampler = ExchangeRateSampler::new(42);

        let eur_usd = sampler.get_rate("EUR", "USD");
        assert!(
            eur_usd > Decimal::new(8, 1) && eur_usd < Decimal::new(12, 1),
            "EUR/USD rate {} out of range",
            eur_usd
        );

        let usd_usd = sampler.get_rate("USD", "USD");
        assert!(
            (usd_usd - Decimal::ONE).abs() < Decimal::new(1, 2),
            "USD/USD rate {} should be ~1.0",
            usd_usd
        );
    }
}