Skip to main content

datasynth_core/distributions/
amount.rs

1//! Transaction amount distribution sampler.
2//!
3//! Generates realistic transaction amounts using log-normal distributions
4//! and round-number bias commonly observed in accounting data.
5
6use rand::prelude::*;
7use rand_chacha::ChaCha8Rng;
8use rand_distr::{Distribution, LogNormal};
9use rust_decimal::Decimal;
10use serde::{Deserialize, Serialize};
11
12use super::benford::{BenfordSampler, FraudAmountGenerator, FraudAmountPattern, ThresholdConfig};
13
/// Configuration for amount distribution.
///
/// Consumed by `AmountSampler`, which draws a log-normal value, clamps it to
/// `[min_amount, max_amount]`, optionally snaps it to a "round" or "nice"
/// value, and then rounds it to `decimal_places`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AmountDistributionConfig {
    /// Minimum transaction amount (lower clamp bound for sampled amounts)
    pub min_amount: f64,
    /// Maximum transaction amount (upper clamp bound for sampled amounts)
    pub max_amount: f64,
    /// Log-normal mu parameter (location); the distribution's median is `exp(mu)`
    pub lognormal_mu: f64,
    /// Log-normal sigma parameter (scale)
    pub lognormal_sigma: f64,
    /// Number of decimal places to round to
    pub decimal_places: u8,
    /// Probability of a round number. `AmountSampler::sample_lognormal` snaps
    /// such draws to the nearest multiple of 100.
    pub round_number_probability: f64,
    /// Probability of a nice number. `AmountSampler::sample_lognormal` snaps
    /// such draws to the nearest multiple of 5. Both probabilities are checked
    /// against one shared uniform draw, so this branch is only reached when
    /// the round-number branch was not taken.
    pub nice_number_probability: f64,
}
32
33impl Default for AmountDistributionConfig {
34    fn default() -> Self {
35        Self {
36            min_amount: 0.01,
37            max_amount: 100_000_000.0, // 100 million
38            lognormal_mu: 7.0,         // Center around ~1000
39            lognormal_sigma: 2.5,      // Wide spread
40            decimal_places: 2,
41            round_number_probability: 0.25, // 25% chance of .00 ending
42            nice_number_probability: 0.15,  // 15% chance of nice numbers
43        }
44    }
45}
46
47impl AmountDistributionConfig {
48    /// Configuration for small transactions (e.g., retail).
49    pub fn small_transactions() -> Self {
50        Self {
51            min_amount: 0.01,
52            max_amount: 10_000.0,
53            lognormal_mu: 4.0, // Center around ~55
54            lognormal_sigma: 1.5,
55            decimal_places: 2,
56            round_number_probability: 0.30,
57            nice_number_probability: 0.20,
58        }
59    }
60
61    /// Configuration for medium transactions (e.g., B2B).
62    pub fn medium_transactions() -> Self {
63        Self {
64            min_amount: 100.0,
65            max_amount: 1_000_000.0,
66            lognormal_mu: 8.5, // Center around ~5000
67            lognormal_sigma: 2.0,
68            decimal_places: 2,
69            round_number_probability: 0.20,
70            nice_number_probability: 0.15,
71        }
72    }
73
74    /// Configuration for large transactions (e.g., enterprise).
75    pub fn large_transactions() -> Self {
76        Self {
77            min_amount: 1000.0,
78            max_amount: 100_000_000.0,
79            lognormal_mu: 10.0, // Center around ~22000
80            lognormal_sigma: 2.5,
81            decimal_places: 2,
82            round_number_probability: 0.15,
83            nice_number_probability: 0.10,
84        }
85    }
86}
87
/// Sampler for realistic transaction amounts.
///
/// Owns a seeded RNG. Regular amounts come either from the optional Benford
/// sampler (when enabled) or from a clamped log-normal distribution with
/// round-number bias; the optional fraud generator serves `sample_fraud`.
pub struct AmountSampler {
    /// RNG for sampling
    rng: ChaCha8Rng,
    /// Configuration
    config: AmountDistributionConfig,
    /// Log-normal distribution built from `config.lognormal_mu` / `config.lognormal_sigma`
    lognormal: LogNormal<f64>,
    /// Decimal multiplier for rounding (`10^config.decimal_places`, computed at construction)
    decimal_multiplier: f64,
    /// Optional Benford sampler for compliant generation
    benford_sampler: Option<BenfordSampler>,
    /// Optional fraud amount generator
    fraud_generator: Option<FraudAmountGenerator>,
    /// Whether Benford's Law compliance is enabled; `sample` only delegates to
    /// `benford_sampler` while this flag is set
    benford_enabled: bool,
}
105
106impl AmountSampler {
107    /// Create a new sampler with default configuration.
108    pub fn new(seed: u64) -> Self {
109        Self::with_config(seed, AmountDistributionConfig::default())
110    }
111
112    /// Create a sampler with custom configuration.
113    pub fn with_config(seed: u64, config: AmountDistributionConfig) -> Self {
114        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
115            .expect("Invalid log-normal parameters");
116        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
117
118        Self {
119            rng: ChaCha8Rng::seed_from_u64(seed),
120            config,
121            lognormal,
122            decimal_multiplier,
123            benford_sampler: None,
124            fraud_generator: None,
125            benford_enabled: false,
126        }
127    }
128
129    /// Create a sampler with Benford's Law compliance enabled.
130    pub fn with_benford(seed: u64, config: AmountDistributionConfig) -> Self {
131        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
132            .expect("Invalid log-normal parameters");
133        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
134
135        Self {
136            rng: ChaCha8Rng::seed_from_u64(seed),
137            benford_sampler: Some(BenfordSampler::new(seed + 100, config.clone())),
138            fraud_generator: Some(FraudAmountGenerator::new(
139                seed + 200,
140                config.clone(),
141                ThresholdConfig::default(),
142            )),
143            config,
144            lognormal,
145            decimal_multiplier,
146            benford_enabled: true,
147        }
148    }
149
150    /// Create a sampler with full fraud configuration.
151    pub fn with_fraud_config(
152        seed: u64,
153        config: AmountDistributionConfig,
154        threshold_config: ThresholdConfig,
155        benford_enabled: bool,
156    ) -> Self {
157        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
158            .expect("Invalid log-normal parameters");
159        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
160
161        Self {
162            rng: ChaCha8Rng::seed_from_u64(seed),
163            benford_sampler: if benford_enabled {
164                Some(BenfordSampler::new(seed + 100, config.clone()))
165            } else {
166                None
167            },
168            fraud_generator: Some(FraudAmountGenerator::new(
169                seed + 200,
170                config.clone(),
171                threshold_config,
172            )),
173            config,
174            lognormal,
175            decimal_multiplier,
176            benford_enabled,
177        }
178    }
179
180    /// Enable or disable Benford's Law compliance.
181    pub fn set_benford_enabled(&mut self, enabled: bool) {
182        self.benford_enabled = enabled;
183        if enabled && self.benford_sampler.is_none() {
184            // Initialize Benford sampler if not already present
185            let seed = self.rng.gen();
186            self.benford_sampler = Some(BenfordSampler::new(seed, self.config.clone()));
187        }
188    }
189
190    /// Check if Benford's Law compliance is enabled.
191    pub fn is_benford_enabled(&self) -> bool {
192        self.benford_enabled
193    }
194
195    /// Sample a single amount.
196    ///
197    /// If Benford's Law compliance is enabled, uses the Benford sampler.
198    /// Otherwise uses log-normal distribution with round-number bias.
199    pub fn sample(&mut self) -> Decimal {
200        // Use Benford sampler if enabled
201        if self.benford_enabled {
202            if let Some(ref mut benford) = self.benford_sampler {
203                return benford.sample();
204            }
205        }
206
207        // Fall back to log-normal sampling
208        self.sample_lognormal()
209    }
210
211    /// Sample using the log-normal distribution (original behavior).
212    pub fn sample_lognormal(&mut self) -> Decimal {
213        let mut amount = self.lognormal.sample(&mut self.rng);
214
215        // Clamp to configured range
216        amount = amount.clamp(self.config.min_amount, self.config.max_amount);
217
218        // Apply round number bias
219        let p: f64 = self.rng.gen();
220        if p < self.config.round_number_probability {
221            // Round to nearest whole number ending in 00
222            amount = (amount / 100.0).round() * 100.0;
223        } else if p < self.config.round_number_probability + self.config.nice_number_probability {
224            // Round to nearest 5 or 10
225            amount = (amount / 5.0).round() * 5.0;
226        }
227
228        // Round to configured decimal places
229        amount = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
230
231        // Ensure minimum after rounding
232        amount = amount.max(self.config.min_amount);
233
234        // Convert to Decimal with explicit 2 decimal place precision to avoid f64 noise
235        let amount_str = format!("{:.2}", amount);
236        amount_str.parse::<Decimal>().unwrap_or(Decimal::ONE)
237    }
238
239    /// Sample a fraud amount with the specified pattern.
240    ///
241    /// Returns a normal amount if fraud generator is not configured.
242    pub fn sample_fraud(&mut self, pattern: FraudAmountPattern) -> Decimal {
243        if let Some(ref mut fraud_gen) = self.fraud_generator {
244            fraud_gen.sample(pattern)
245        } else {
246            // Fallback to normal sampling
247            self.sample()
248        }
249    }
250
251    /// Sample multiple amounts that sum to a target total.
252    ///
253    /// Useful for generating line items that must balance.
254    pub fn sample_summing_to(&mut self, count: usize, total: Decimal) -> Vec<Decimal> {
255        use rust_decimal::prelude::ToPrimitive;
256
257        if count == 0 {
258            return Vec::new();
259        }
260        if count == 1 {
261            return vec![total];
262        }
263
264        let total_f64 = total.to_f64().unwrap_or(0.0);
265
266        // Generate random weights ensuring minimum weight
267        let mut weights: Vec<f64> = (0..count)
268            .map(|_| self.rng.gen::<f64>().max(0.01))
269            .collect();
270        let sum: f64 = weights.iter().sum();
271        weights.iter_mut().for_each(|w| *w /= sum);
272
273        // Calculate amounts based on weights, using string parsing for precision
274        let mut amounts: Vec<Decimal> = weights
275            .iter()
276            .map(|w| {
277                let amount = total_f64 * w;
278                let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
279                // Use string format for more reliable decimal conversion
280                let amount_str = format!("{:.2}", rounded);
281                amount_str.parse::<Decimal>().unwrap_or(Decimal::ZERO)
282            })
283            .collect();
284
285        // Adjust last amount to ensure exact sum
286        let current_sum: Decimal = amounts.iter().copied().sum();
287        let diff = total - current_sum;
288        let last_idx = amounts.len() - 1;
289        amounts[last_idx] += diff;
290
291        // If last amount became negative (rare edge case), redistribute
292        if amounts[last_idx] < Decimal::ZERO {
293            let mut remaining = amounts[last_idx].abs();
294            amounts[last_idx] = Decimal::ZERO;
295
296            // Distribute the negative amount across all earlier amounts
297            for amt in amounts.iter_mut().take(last_idx).rev() {
298                if remaining <= Decimal::ZERO {
299                    break;
300                }
301                let take = remaining.min(*amt);
302                *amt -= take;
303                remaining -= take;
304            }
305
306            // If still remaining (shouldn't happen with proper weights),
307            // absorb into the last amount as a negative value for safety
308            if remaining > Decimal::ZERO {
309                // Re-add to first non-zero amount - this ensures sum is correct
310                for amt in amounts.iter_mut() {
311                    if *amt > Decimal::ZERO {
312                        *amt -= remaining;
313                        break;
314                    }
315                }
316            }
317        }
318
319        amounts
320    }
321
322    /// Sample an amount within a specific range.
323    pub fn sample_in_range(&mut self, min: Decimal, max: Decimal) -> Decimal {
324        let min_f64 = min.to_string().parse::<f64>().unwrap_or(0.0);
325        let max_f64 = max.to_string().parse::<f64>().unwrap_or(1000000.0);
326
327        let range = max_f64 - min_f64;
328        let amount = min_f64 + self.rng.gen::<f64>() * range;
329
330        let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
331        Decimal::from_f64_retain(rounded).unwrap_or(min)
332    }
333
334    /// Reset the sampler with a new seed.
335    pub fn reset(&mut self, seed: u64) {
336        self.rng = ChaCha8Rng::seed_from_u64(seed);
337    }
338}
339
/// Sampler for currency exchange rates.
///
/// Rates are derived from a fixed table of base rates against USD and get a
/// small random perturbation on each lookup.
pub struct ExchangeRateSampler {
    /// Seeded RNG that drives the per-lookup perturbation
    rng: ChaCha8Rng,
    /// Base rates for common currency pairs (vs USD); `get_rate` divides the
    /// target currency's entry by the source currency's entry
    base_rates: std::collections::HashMap<String, f64>,
    /// Daily volatility. NOTE: `get_rate` applies this as the maximum relative
    /// deviation of a uniform draw (the factor lies in
    /// `[1 - volatility, 1 + volatility)`), not as a Gaussian standard deviation.
    volatility: f64,
}
348
349impl ExchangeRateSampler {
350    /// Create a new exchange rate sampler.
351    pub fn new(seed: u64) -> Self {
352        let mut base_rates = std::collections::HashMap::new();
353        // Approximate rates as of 2024
354        base_rates.insert("EUR".to_string(), 0.92);
355        base_rates.insert("GBP".to_string(), 0.79);
356        base_rates.insert("CHF".to_string(), 0.88);
357        base_rates.insert("JPY".to_string(), 149.0);
358        base_rates.insert("CNY".to_string(), 7.24);
359        base_rates.insert("CAD".to_string(), 1.36);
360        base_rates.insert("AUD".to_string(), 1.53);
361        base_rates.insert("INR".to_string(), 83.0);
362        base_rates.insert("USD".to_string(), 1.0);
363
364        Self {
365            rng: ChaCha8Rng::seed_from_u64(seed),
366            base_rates,
367            volatility: 0.005, // 0.5% daily volatility
368        }
369    }
370
371    /// Get exchange rate from one currency to another.
372    pub fn get_rate(&mut self, from: &str, to: &str) -> Decimal {
373        let from_usd = self.base_rates.get(from).copied().unwrap_or(1.0);
374        let to_usd = self.base_rates.get(to).copied().unwrap_or(1.0);
375
376        // Base rate
377        let base_rate = to_usd / from_usd;
378
379        // Add some random variation
380        let variation = 1.0 + (self.rng.gen::<f64>() - 0.5) * 2.0 * self.volatility;
381        let rate = base_rate * variation;
382
383        // Round to 6 decimal places
384        let rounded = (rate * 1_000_000.0).round() / 1_000_000.0;
385        Decimal::from_f64_retain(rounded).unwrap_or(Decimal::ONE)
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Two samplers built from the same seed must produce the same sequence.
    #[test]
    fn test_amount_sampler_determinism() {
        let mut first = AmountSampler::new(42);
        let mut second = AmountSampler::new(42);

        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    /// Every sampled amount must stay inside the configured bounds.
    #[test]
    fn test_amount_sampler_range() {
        let mut sampler = AmountSampler::with_config(
            42,
            AmountDistributionConfig {
                min_amount: 100.0,
                max_amount: 1000.0,
                ..Default::default()
            },
        );

        for amount in (0..1000).map(|_| sampler.sample()) {
            let amount_f64: f64 = amount.to_string().parse().unwrap();
            assert!(amount_f64 >= 100.0, "Amount {} below minimum", amount);
            assert!(amount_f64 <= 1000.0, "Amount {} above maximum", amount);
        }
    }

    /// Line items produced for a total must add up to exactly that total.
    #[test]
    fn test_summing_amounts() {
        let total = Decimal::from(10000);
        let amounts = AmountSampler::new(42).sample_summing_to(5, total);

        assert_eq!(amounts.len(), 5);

        let sum = amounts.iter().fold(Decimal::ZERO, |acc, part| acc + *part);
        assert_eq!(sum, total, "Sum {} doesn't match total {}", sum, total);
    }

    /// Rates stay near their base value, and USD to USD stays near 1.
    #[test]
    fn test_exchange_rate() {
        let mut sampler = ExchangeRateSampler::new(42);
        let mut rate_as_f64 =
            |from: &str, to: &str| -> f64 { sampler.get_rate(from, to).to_string().parse().unwrap() };

        let eur_f64 = rate_as_f64("EUR", "USD");
        assert!(
            eur_f64 > 0.8 && eur_f64 < 1.2,
            "EUR/USD rate {} out of range",
            eur_f64
        );

        let usd_f64 = rate_as_f64("USD", "USD");
        assert!(
            (usd_f64 - 1.0).abs() < 0.01,
            "USD/USD rate {} should be ~1.0",
            usd_f64
        );
    }
}