Skip to main content

datasynth_core/distributions/
amount.rs

//! Transaction amount distribution sampler.
//!
//! Generates realistic transaction amounts using log-normal distributions
//! and round-number bias commonly observed in accounting data.
5
6use rand::prelude::*;
7use rand_chacha::ChaCha8Rng;
8use rand_distr::{Distribution, LogNormal};
9use rust_decimal::Decimal;
10use serde::{Deserialize, Serialize};
11
12use super::benford::{BenfordSampler, FraudAmountGenerator, FraudAmountPattern, ThresholdConfig};
13
/// Configuration for amount distribution.
///
/// Drives the log-normal sampling path of `AmountSampler`: a raw log-normal
/// draw is clamped to `[min_amount, max_amount]`, optionally snapped to a
/// "round" or "nice" value, and finally rounded to `decimal_places`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AmountDistributionConfig {
    /// Minimum transaction amount (lower clamp bound for sampled values).
    pub min_amount: f64,
    /// Maximum transaction amount (upper clamp bound for the raw draw).
    pub max_amount: f64,
    /// Log-normal mu parameter (location); the distribution's median is `exp(mu)`.
    pub lognormal_mu: f64,
    /// Log-normal sigma parameter (scale).
    pub lognormal_sigma: f64,
    /// Number of decimal places to round to.
    pub decimal_places: u8,
    /// Probability that a sampled amount is snapped to the nearest multiple
    /// of 100 (a whole number ending in `00`).
    pub round_number_probability: f64,
    /// Probability that a sampled amount is snapped to the nearest multiple
    /// of 5 (a whole number ending in `0` or `5`). Only considered when the
    /// round-number snap was not selected for that sample.
    pub nice_number_probability: f64,
}
32
33impl Default for AmountDistributionConfig {
34    fn default() -> Self {
35        Self {
36            min_amount: 0.01,
37            max_amount: 100_000_000.0, // 100 million
38            lognormal_mu: 7.0,         // Center around ~1000
39            lognormal_sigma: 2.5,      // Wide spread
40            decimal_places: 2,
41            round_number_probability: 0.25, // 25% chance of .00 ending
42            nice_number_probability: 0.15,  // 15% chance of nice numbers
43        }
44    }
45}
46
47impl AmountDistributionConfig {
48    /// Configuration for small transactions (e.g., retail).
49    pub fn small_transactions() -> Self {
50        Self {
51            min_amount: 0.01,
52            max_amount: 10_000.0,
53            lognormal_mu: 4.0, // Center around ~55
54            lognormal_sigma: 1.5,
55            decimal_places: 2,
56            round_number_probability: 0.30,
57            nice_number_probability: 0.20,
58        }
59    }
60
61    /// Configuration for medium transactions (e.g., B2B).
62    pub fn medium_transactions() -> Self {
63        Self {
64            min_amount: 100.0,
65            max_amount: 1_000_000.0,
66            lognormal_mu: 8.5, // Center around ~5000
67            lognormal_sigma: 2.0,
68            decimal_places: 2,
69            round_number_probability: 0.20,
70            nice_number_probability: 0.15,
71        }
72    }
73
74    /// Configuration for large transactions (e.g., enterprise).
75    pub fn large_transactions() -> Self {
76        Self {
77            min_amount: 1000.0,
78            max_amount: 100_000_000.0,
79            lognormal_mu: 10.0, // Center around ~22000
80            lognormal_sigma: 2.5,
81            decimal_places: 2,
82            round_number_probability: 0.15,
83            nice_number_probability: 0.10,
84        }
85    }
86}
87
/// Sampler for realistic transaction amounts.
///
/// Bundles a seeded RNG, a pre-built log-normal distribution derived from
/// the configuration, and optional Benford / fraud sub-samplers.
pub struct AmountSampler {
    /// RNG for sampling (seeded via `seed_from_u64`, so output is reproducible per seed)
    rng: ChaCha8Rng,
    /// Configuration
    config: AmountDistributionConfig,
    /// Log-normal distribution built from `config.lognormal_mu` / `config.lognormal_sigma`
    lognormal: LogNormal<f64>,
    /// Decimal multiplier for rounding (`10^config.decimal_places`, cached)
    decimal_multiplier: f64,
    /// Optional Benford sampler for compliant generation
    benford_sampler: Option<BenfordSampler>,
    /// Optional fraud amount generator
    fraud_generator: Option<FraudAmountGenerator>,
    /// Whether Benford's Law compliance is enabled; when true and a Benford
    /// sampler is present, `sample()` delegates to it
    benford_enabled: bool,
}
105
106impl AmountSampler {
107    /// Create a new sampler with default configuration.
108    pub fn new(seed: u64) -> Self {
109        Self::with_config(seed, AmountDistributionConfig::default())
110    }
111
112    /// Create a sampler with custom configuration.
113    pub fn with_config(seed: u64, config: AmountDistributionConfig) -> Self {
114        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
115            .expect("Invalid log-normal parameters");
116        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
117
118        Self {
119            rng: ChaCha8Rng::seed_from_u64(seed),
120            config,
121            lognormal,
122            decimal_multiplier,
123            benford_sampler: None,
124            fraud_generator: None,
125            benford_enabled: false,
126        }
127    }
128
129    /// Create a sampler with Benford's Law compliance enabled.
130    pub fn with_benford(seed: u64, config: AmountDistributionConfig) -> Self {
131        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
132            .expect("Invalid log-normal parameters");
133        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
134
135        Self {
136            rng: ChaCha8Rng::seed_from_u64(seed),
137            benford_sampler: Some(BenfordSampler::new(seed + 100, config.clone())),
138            fraud_generator: Some(FraudAmountGenerator::new(
139                seed + 200,
140                config.clone(),
141                ThresholdConfig::default(),
142            )),
143            config,
144            lognormal,
145            decimal_multiplier,
146            benford_enabled: true,
147        }
148    }
149
150    /// Create a sampler with full fraud configuration.
151    pub fn with_fraud_config(
152        seed: u64,
153        config: AmountDistributionConfig,
154        threshold_config: ThresholdConfig,
155        benford_enabled: bool,
156    ) -> Self {
157        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
158            .expect("Invalid log-normal parameters");
159        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
160
161        Self {
162            rng: ChaCha8Rng::seed_from_u64(seed),
163            benford_sampler: if benford_enabled {
164                Some(BenfordSampler::new(seed + 100, config.clone()))
165            } else {
166                None
167            },
168            fraud_generator: Some(FraudAmountGenerator::new(
169                seed + 200,
170                config.clone(),
171                threshold_config,
172            )),
173            config,
174            lognormal,
175            decimal_multiplier,
176            benford_enabled,
177        }
178    }
179
180    /// Enable or disable Benford's Law compliance.
181    pub fn set_benford_enabled(&mut self, enabled: bool) {
182        self.benford_enabled = enabled;
183        if enabled && self.benford_sampler.is_none() {
184            // Initialize Benford sampler if not already present
185            let seed = self.rng.random();
186            self.benford_sampler = Some(BenfordSampler::new(seed, self.config.clone()));
187        }
188    }
189
190    /// Check if Benford's Law compliance is enabled.
191    pub fn is_benford_enabled(&self) -> bool {
192        self.benford_enabled
193    }
194
195    /// Sample a single amount.
196    ///
197    /// If Benford's Law compliance is enabled, uses the Benford sampler.
198    /// Otherwise uses log-normal distribution with round-number bias.
199    #[inline]
200    pub fn sample(&mut self) -> Decimal {
201        // Use Benford sampler if enabled
202        if self.benford_enabled {
203            if let Some(ref mut benford) = self.benford_sampler {
204                return benford.sample();
205            }
206        }
207
208        // Fall back to log-normal sampling
209        self.sample_lognormal()
210    }
211
212    /// Sample using the log-normal distribution (original behavior).
213    #[inline]
214    pub fn sample_lognormal(&mut self) -> Decimal {
215        let mut amount = self.lognormal.sample(&mut self.rng);
216
217        // Clamp to configured range
218        amount = amount.clamp(self.config.min_amount, self.config.max_amount);
219
220        // Apply round number bias
221        let p: f64 = self.rng.random();
222        if p < self.config.round_number_probability {
223            // Round to nearest whole number ending in 00
224            amount = (amount / 100.0).round() * 100.0;
225        } else if p < self.config.round_number_probability + self.config.nice_number_probability {
226            // Round to nearest 5 or 10
227            amount = (amount / 5.0).round() * 5.0;
228        }
229
230        // Round to configured decimal places
231        amount = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
232
233        // Ensure minimum after rounding
234        amount = amount.max(self.config.min_amount);
235
236        // Convert to Decimal using fast integer math instead of string formatting.
237        // Multiply by 100, truncate to integer, then construct Decimal with scale 2.
238        // This avoids the overhead of format!() + parse() (~15x faster).
239        let cents = (amount * 100.0).round() as i64;
240        Decimal::new(cents, 2)
241    }
242
243    /// Sample a fraud amount with the specified pattern.
244    ///
245    /// Returns a normal amount if fraud generator is not configured.
246    pub fn sample_fraud(&mut self, pattern: FraudAmountPattern) -> Decimal {
247        if let Some(ref mut fraud_gen) = self.fraud_generator {
248            fraud_gen.sample(pattern)
249        } else {
250            // Fallback to normal sampling
251            self.sample()
252        }
253    }
254
255    /// Sample multiple amounts that sum to a target total.
256    ///
257    /// Useful for generating line items that must balance.
258    /// The sum of returned amounts is guaranteed to equal `total` exactly.
259    /// Every returned amount is guaranteed to be > 0 when `total > 0` and
260    /// `count * 0.01 <= total`.
261    pub fn sample_summing_to(&mut self, count: usize, total: Decimal) -> Vec<Decimal> {
262        use rust_decimal::prelude::ToPrimitive;
263
264        let min_amount = Decimal::new(1, 2); // 0.01
265
266        if count == 0 {
267            return Vec::new();
268        }
269        if count == 1 {
270            return vec![total];
271        }
272
273        let total_f64 = total.to_f64().unwrap_or(0.0);
274
275        // Generate random weights ensuring minimum weight
276        let mut weights: Vec<f64> = (0..count)
277            .map(|_| self.rng.random::<f64>().max(0.01))
278            .collect();
279        let sum: f64 = weights.iter().sum();
280        weights.iter_mut().for_each(|w| *w /= sum);
281
282        // Calculate amounts based on weights, using fast integer math for precision
283        let mut amounts: Vec<Decimal> = weights
284            .iter()
285            .map(|w| {
286                let amount = total_f64 * w;
287                let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
288                // Convert via integer cents — avoids format!()/parse() overhead
289                let cents = (rounded * 100.0).round() as i64;
290                Decimal::new(cents, 2)
291            })
292            .collect();
293
294        // Adjust last amount to ensure exact sum
295        let current_sum: Decimal = amounts.iter().copied().sum();
296        let diff = total - current_sum;
297        let last_idx = amounts.len() - 1;
298        amounts[last_idx] += diff;
299
300        // If last amount became negative (rare edge case), redistribute
301        if amounts[last_idx] < Decimal::ZERO {
302            let mut remaining = amounts[last_idx].abs();
303            amounts[last_idx] = Decimal::ZERO;
304
305            // Distribute the negative amount across all earlier amounts
306            for amt in amounts.iter_mut().take(last_idx).rev() {
307                if remaining <= Decimal::ZERO {
308                    break;
309                }
310                let take = remaining.min(*amt);
311                *amt -= take;
312                remaining -= take;
313            }
314
315            // If still remaining, absorb into the first non-zero amount
316            if remaining > Decimal::ZERO {
317                for amt in amounts.iter_mut() {
318                    if *amt > Decimal::ZERO {
319                        *amt -= remaining;
320                        break;
321                    }
322                }
323            }
324        }
325
326        // Post-process: fix zero-amount lines by transferring min_amount from the
327        // largest line. This preserves the exact sum while eliminating zeros.
328        // Only attempt when total is large enough to support min_amount per line.
329        if total >= min_amount * Decimal::from(count as u32) {
330            loop {
331                // Find a zero-amount line
332                let zero_idx = amounts.iter().position(|a| *a == Decimal::ZERO);
333                let Some(zi) = zero_idx else { break };
334
335                // Find the largest amount (must be > min_amount to donate)
336                let donor = amounts
337                    .iter()
338                    .enumerate()
339                    .filter(|&(j, _)| j != zi)
340                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
341                    .map(|(j, _)| j);
342
343                if let Some(di) = donor {
344                    if amounts[di] > min_amount {
345                        amounts[zi] = min_amount;
346                        amounts[di] -= min_amount;
347                    } else {
348                        break; // No donor has enough headroom
349                    }
350                } else {
351                    break;
352                }
353            }
354        }
355
356        amounts
357    }
358
359    /// Sample an amount within a specific range.
360    pub fn sample_in_range(&mut self, min: Decimal, max: Decimal) -> Decimal {
361        let min_f64 = min.to_string().parse::<f64>().unwrap_or(0.0);
362        let max_f64 = max.to_string().parse::<f64>().unwrap_or(1000000.0);
363
364        let range = max_f64 - min_f64;
365        let amount = min_f64 + self.rng.random::<f64>() * range;
366
367        let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
368        Decimal::from_f64_retain(rounded).unwrap_or(min)
369    }
370
371    /// Reset the sampler with a new seed.
372    pub fn reset(&mut self, seed: u64) {
373        self.rng = ChaCha8Rng::seed_from_u64(seed);
374    }
375
376    /// SP3.5b — Update the log-normal sigma parameter.
377    ///
378    /// Also rebuilds the internal `LogNormal` distribution so that subsequent
379    /// `sample_lognormal()` calls reflect the new sigma. No-op if `sigma` is
380    /// not positive (guard against calibrator edge-cases).
381    pub fn set_lognormal_sigma(&mut self, sigma: f64) {
382        if sigma > 0.0 {
383            self.config.lognormal_sigma = sigma;
384            if let Ok(dist) = LogNormal::new(self.config.lognormal_mu, sigma) {
385                self.lognormal = dist;
386            }
387        }
388    }
389
390    /// SP3.5b — Update the round-number probability parameter.
391    ///
392    /// Clamped to [0, 1] for safety.
393    pub fn set_round_number_probability(&mut self, p: f64) {
394        self.config.round_number_probability = p.clamp(0.0, 1.0);
395    }
396
397    /// Return the current log-normal sigma (for testing).
398    pub fn lognormal_sigma(&self) -> f64 {
399        self.config.lognormal_sigma
400    }
401
402    /// Return the current round-number probability (for testing).
403    pub fn round_number_probability(&self) -> f64 {
404        self.config.round_number_probability
405    }
406}
407
/// Sampler for currency exchange rates.
///
/// Produces a base cross rate from a fixed table and perturbs it with a
/// small uniform random variation on every call.
pub struct ExchangeRateSampler {
    /// RNG used for the per-call rate variation
    rng: ChaCha8Rng,
    /// Base rates for common currency pairs (vs USD), keyed by currency code.
    /// `get_rate` computes `rate[to] / rate[from]`, i.e. the values are used
    /// as units of each currency per one USD.
    base_rates: std::collections::HashMap<String, f64>,
    /// Daily volatility (standard deviation); used as the half-width of the
    /// relative variation applied in `get_rate`
    volatility: f64,
}
416
417impl ExchangeRateSampler {
418    /// Create a new exchange rate sampler.
419    pub fn new(seed: u64) -> Self {
420        let mut base_rates = std::collections::HashMap::new();
421        // Approximate rates as of 2024
422        base_rates.insert("EUR".to_string(), 0.92);
423        base_rates.insert("GBP".to_string(), 0.79);
424        base_rates.insert("CHF".to_string(), 0.88);
425        base_rates.insert("JPY".to_string(), 149.0);
426        base_rates.insert("CNY".to_string(), 7.24);
427        base_rates.insert("CAD".to_string(), 1.36);
428        base_rates.insert("AUD".to_string(), 1.53);
429        base_rates.insert("INR".to_string(), 83.0);
430        base_rates.insert("USD".to_string(), 1.0);
431
432        Self {
433            rng: ChaCha8Rng::seed_from_u64(seed),
434            base_rates,
435            volatility: 0.005, // 0.5% daily volatility
436        }
437    }
438
439    /// Get exchange rate from one currency to another.
440    pub fn get_rate(&mut self, from: &str, to: &str) -> Decimal {
441        let from_usd = self.base_rates.get(from).copied().unwrap_or(1.0);
442        let to_usd = self.base_rates.get(to).copied().unwrap_or(1.0);
443
444        // Base rate
445        let base_rate = to_usd / from_usd;
446
447        // Add some random variation
448        let variation = 1.0 + (self.rng.random::<f64>() - 0.5) * 2.0 * self.volatility;
449        let rate = base_rate * variation;
450
451        // Round to 6 decimal places
452        let rounded = (rate * 1_000_000.0).round() / 1_000_000.0;
453        Decimal::from_f64_retain(rounded).unwrap_or(Decimal::ONE)
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert a `Decimal` to `f64` through its string form.
    fn as_f64(value: Decimal) -> f64 {
        value.to_string().parse().unwrap()
    }

    /// Two samplers built from the same seed must emit identical sequences.
    #[test]
    fn test_amount_sampler_determinism() {
        let mut first = AmountSampler::new(42);
        let mut second = AmountSampler::new(42);

        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    /// Every sampled amount must stay inside the configured bounds.
    #[test]
    fn test_amount_sampler_range() {
        let mut sampler = AmountSampler::with_config(
            42,
            AmountDistributionConfig {
                min_amount: 100.0,
                max_amount: 1000.0,
                ..Default::default()
            },
        );

        for _ in 0..1000 {
            let amount = sampler.sample();
            let amount_f64 = as_f64(amount);
            assert!(amount_f64 >= 100.0, "Amount {} below minimum", amount);
            assert!(amount_f64 <= 1000.0, "Amount {} above maximum", amount);
        }
    }

    /// Split amounts must have the requested length and add up to the total.
    #[test]
    fn test_summing_amounts() {
        let total = Decimal::from(10000);
        let amounts = AmountSampler::new(42).sample_summing_to(5, total);

        assert_eq!(amounts.len(), 5);

        let sum: Decimal = amounts.iter().sum();
        assert_eq!(sum, total, "Sum {} doesn't match total {}", sum, total);
    }

    /// Cross rates stay near their base value; a self-rate stays near 1.
    #[test]
    fn test_exchange_rate() {
        let mut sampler = ExchangeRateSampler::new(42);

        let eur_f64 = as_f64(sampler.get_rate("EUR", "USD"));
        assert!(
            eur_f64 > 0.8 && eur_f64 < 1.2,
            "EUR/USD rate {} out of range",
            eur_f64
        );

        let usd_f64 = as_f64(sampler.get_rate("USD", "USD"));
        assert!(
            (usd_f64 - 1.0).abs() < 0.01,
            "USD/USD rate {} should be ~1.0",
            usd_f64
        );
    }
}