Skip to main content

datasynth_core/distributions/
amount.rs

1//! Transaction amount distribution sampler.
2//!
3//! Generates realistic transaction amounts using log-normal distributions
4//! and round-number bias commonly observed in accounting data.
5
6use rand::prelude::*;
7use rand_chacha::ChaCha8Rng;
8use rand_distr::{Distribution, LogNormal};
9use rust_decimal::Decimal;
10use serde::{Deserialize, Serialize};
11
12use super::benford::{BenfordSampler, FraudAmountGenerator, FraudAmountPattern, ThresholdConfig};
13
14/// Configuration for amount distribution.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AmountDistributionConfig {
17    /// Minimum transaction amount
18    pub min_amount: f64,
19    /// Maximum transaction amount
20    pub max_amount: f64,
21    /// Log-normal mu parameter (location)
22    pub lognormal_mu: f64,
23    /// Log-normal sigma parameter (scale)
24    pub lognormal_sigma: f64,
25    /// Number of decimal places to round to
26    pub decimal_places: u8,
27    /// Probability of round number (ending in .00)
28    pub round_number_probability: f64,
29    /// Probability of nice number (ending in 0 or 5)
30    pub nice_number_probability: f64,
31}
32
33impl Default for AmountDistributionConfig {
34    fn default() -> Self {
35        Self {
36            min_amount: 0.01,
37            max_amount: 100_000_000.0, // 100 million
38            // Corpus-aligned central tendency. The reference GL corpus has a
39            // ~$8-10K per-line median (un-split 2-line JEs ~$10K); the old
40            // mu=7.0 centered at ~$1,100 (~10x low) and, once a JE total was
41            // split across many lines, produced a flood of sub-$10 lines
42            // (median ~$9) that collapsed the overall median and degraded
43            // Benford. P0b set mu=9.25 (median e^9.25 ~ $10.4K JE total ->
44            // ~$4.8K per-line after splitting), but the E2E line median came
45            // in ~1.8x under the corpus ~$8.5K. P0c lifts mu by ln(1.8)~0.58
46            // to 9.85 (median e^9.85 ~ $18.9K JE total -> ~$8.5K per-line).
47            // See docs/analysis/gl-corpus-realism-roadmap.md (P0b/P0c).
48            lognormal_mu: 9.85,
49            // Tightened so the per-line p99/p50 lands near the corpus ~200x
50            // rather than ~340x. exp(2.326*2.3) ~ 210x.
51            lognormal_sigma: 2.3,
52            decimal_places: 2,
53            round_number_probability: 0.25, // 25% chance of .00 ending
54            nice_number_probability: 0.15,  // 15% chance of nice numbers
55        }
56    }
57}
58
59impl AmountDistributionConfig {
60    /// Configuration for small transactions (e.g., retail).
61    pub fn small_transactions() -> Self {
62        Self {
63            min_amount: 0.01,
64            max_amount: 10_000.0,
65            lognormal_mu: 4.0, // Center around ~55
66            lognormal_sigma: 1.5,
67            decimal_places: 2,
68            round_number_probability: 0.30,
69            nice_number_probability: 0.20,
70        }
71    }
72
73    /// Configuration for medium transactions (e.g., B2B).
74    pub fn medium_transactions() -> Self {
75        Self {
76            min_amount: 100.0,
77            max_amount: 1_000_000.0,
78            lognormal_mu: 8.5, // Center around ~5000
79            lognormal_sigma: 2.0,
80            decimal_places: 2,
81            round_number_probability: 0.20,
82            nice_number_probability: 0.15,
83        }
84    }
85
86    /// Configuration for large transactions (e.g., enterprise).
87    pub fn large_transactions() -> Self {
88        Self {
89            min_amount: 1000.0,
90            max_amount: 100_000_000.0,
91            lognormal_mu: 10.0, // Center around ~22000
92            lognormal_sigma: 2.5,
93            decimal_places: 2,
94            round_number_probability: 0.15,
95            nice_number_probability: 0.10,
96        }
97    }
98}
99
100/// Sampler for realistic transaction amounts.
101pub struct AmountSampler {
102    /// RNG for sampling
103    rng: ChaCha8Rng,
104    /// Configuration
105    config: AmountDistributionConfig,
106    /// Log-normal distribution
107    lognormal: LogNormal<f64>,
108    /// Decimal multiplier for rounding
109    decimal_multiplier: f64,
110    /// Optional Benford sampler for compliant generation
111    benford_sampler: Option<BenfordSampler>,
112    /// Optional fraud amount generator
113    fraud_generator: Option<FraudAmountGenerator>,
114    /// Whether Benford's Law compliance is enabled
115    benford_enabled: bool,
116}
117
118impl AmountSampler {
119    /// Create a new sampler with default configuration.
120    pub fn new(seed: u64) -> Self {
121        Self::with_config(seed, AmountDistributionConfig::default())
122    }
123
124    /// Create a sampler with custom configuration.
125    pub fn with_config(seed: u64, config: AmountDistributionConfig) -> Self {
126        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
127            .expect("Invalid log-normal parameters");
128        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
129
130        Self {
131            rng: ChaCha8Rng::seed_from_u64(seed),
132            config,
133            lognormal,
134            decimal_multiplier,
135            benford_sampler: None,
136            fraud_generator: None,
137            benford_enabled: false,
138        }
139    }
140
141    /// Create a sampler with Benford's Law compliance enabled.
142    pub fn with_benford(seed: u64, config: AmountDistributionConfig) -> Self {
143        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
144            .expect("Invalid log-normal parameters");
145        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
146
147        Self {
148            rng: ChaCha8Rng::seed_from_u64(seed),
149            benford_sampler: Some(BenfordSampler::new(seed + 100, config.clone())),
150            fraud_generator: Some(FraudAmountGenerator::new(
151                seed + 200,
152                config.clone(),
153                ThresholdConfig::default(),
154            )),
155            config,
156            lognormal,
157            decimal_multiplier,
158            benford_enabled: true,
159        }
160    }
161
162    /// Create a sampler with full fraud configuration.
163    pub fn with_fraud_config(
164        seed: u64,
165        config: AmountDistributionConfig,
166        threshold_config: ThresholdConfig,
167        benford_enabled: bool,
168    ) -> Self {
169        let lognormal = LogNormal::new(config.lognormal_mu, config.lognormal_sigma)
170            .expect("Invalid log-normal parameters");
171        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
172
173        Self {
174            rng: ChaCha8Rng::seed_from_u64(seed),
175            benford_sampler: if benford_enabled {
176                Some(BenfordSampler::new(seed + 100, config.clone()))
177            } else {
178                None
179            },
180            fraud_generator: Some(FraudAmountGenerator::new(
181                seed + 200,
182                config.clone(),
183                threshold_config,
184            )),
185            config,
186            lognormal,
187            decimal_multiplier,
188            benford_enabled,
189        }
190    }
191
192    /// Enable or disable Benford's Law compliance.
193    pub fn set_benford_enabled(&mut self, enabled: bool) {
194        self.benford_enabled = enabled;
195        if enabled && self.benford_sampler.is_none() {
196            // Initialize Benford sampler if not already present
197            let seed = self.rng.random();
198            self.benford_sampler = Some(BenfordSampler::new(seed, self.config.clone()));
199        }
200    }
201
202    /// Check if Benford's Law compliance is enabled.
203    pub fn is_benford_enabled(&self) -> bool {
204        self.benford_enabled
205    }
206
207    /// Sample a single amount.
208    ///
209    /// If Benford's Law compliance is enabled, uses the Benford sampler.
210    /// Otherwise uses log-normal distribution with round-number bias.
211    #[inline]
212    pub fn sample(&mut self) -> Decimal {
213        // Use Benford sampler if enabled
214        if self.benford_enabled {
215            if let Some(ref mut benford) = self.benford_sampler {
216                return benford.sample();
217            }
218        }
219
220        // Fall back to log-normal sampling
221        self.sample_lognormal()
222    }
223
224    /// Sample using the log-normal distribution (original behavior).
225    #[inline]
226    pub fn sample_lognormal(&mut self) -> Decimal {
227        let mut amount = self.lognormal.sample(&mut self.rng);
228
229        // Clamp to configured range
230        amount = amount.clamp(self.config.min_amount, self.config.max_amount);
231
232        // Apply round number bias
233        let p: f64 = self.rng.random();
234        if p < self.config.round_number_probability {
235            // Round to nearest whole number ending in 00
236            amount = (amount / 100.0).round() * 100.0;
237        } else if p < self.config.round_number_probability + self.config.nice_number_probability {
238            // Round to nearest 5 or 10
239            amount = (amount / 5.0).round() * 5.0;
240        }
241
242        // Round to configured decimal places
243        amount = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
244
245        // Ensure minimum after rounding
246        amount = amount.max(self.config.min_amount);
247
248        // Convert to Decimal using fast integer math instead of string formatting.
249        // Multiply by 100, truncate to integer, then construct Decimal with scale 2.
250        // This avoids the overhead of format!() + parse() (~15x faster).
251        let cents = (amount * 100.0).round() as i64;
252        Decimal::new(cents, 2)
253    }
254
255    /// Sample a fraud amount with the specified pattern.
256    ///
257    /// Returns a normal amount if fraud generator is not configured.
258    pub fn sample_fraud(&mut self, pattern: FraudAmountPattern) -> Decimal {
259        if let Some(ref mut fraud_gen) = self.fraud_generator {
260            fraud_gen.sample(pattern)
261        } else {
262            // Fallback to normal sampling
263            self.sample()
264        }
265    }
266
267    /// Sample multiple amounts that sum to a target total.
268    ///
269    /// Useful for generating line items that must balance.
270    /// The sum of returned amounts is guaranteed to equal `total` exactly.
271    /// Every returned amount is guaranteed to be > 0 when `total > 0` and
272    /// `count * 0.01 <= total`.
273    pub fn sample_summing_to(&mut self, count: usize, total: Decimal) -> Vec<Decimal> {
274        use rust_decimal::prelude::ToPrimitive;
275
276        let min_amount = Decimal::new(1, 2); // 0.01
277
278        if count == 0 {
279            return Vec::new();
280        }
281        if count == 1 {
282            return vec![total];
283        }
284
285        let total_f64 = total.to_f64().unwrap_or(0.0);
286
287        // Generate random weights ensuring minimum weight
288        let mut weights: Vec<f64> = (0..count)
289            .map(|_| self.rng.random::<f64>().max(0.01))
290            .collect();
291        let sum: f64 = weights.iter().sum();
292        weights.iter_mut().for_each(|w| *w /= sum);
293
294        // Calculate amounts based on weights, using fast integer math for precision
295        let mut amounts: Vec<Decimal> = weights
296            .iter()
297            .map(|w| {
298                let amount = total_f64 * w;
299                let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
300                // Convert via integer cents — avoids format!()/parse() overhead
301                let cents = (rounded * 100.0).round() as i64;
302                Decimal::new(cents, 2)
303            })
304            .collect();
305
306        // Adjust last amount to ensure exact sum
307        let current_sum: Decimal = amounts.iter().copied().sum();
308        let diff = total - current_sum;
309        let last_idx = amounts.len() - 1;
310        amounts[last_idx] += diff;
311
312        // If last amount became negative (rare edge case), redistribute
313        if amounts[last_idx] < Decimal::ZERO {
314            let mut remaining = amounts[last_idx].abs();
315            amounts[last_idx] = Decimal::ZERO;
316
317            // Distribute the negative amount across all earlier amounts
318            for amt in amounts.iter_mut().take(last_idx).rev() {
319                if remaining <= Decimal::ZERO {
320                    break;
321                }
322                let take = remaining.min(*amt);
323                *amt -= take;
324                remaining -= take;
325            }
326
327            // If still remaining, absorb into the first non-zero amount
328            if remaining > Decimal::ZERO {
329                for amt in amounts.iter_mut() {
330                    if *amt > Decimal::ZERO {
331                        *amt -= remaining;
332                        break;
333                    }
334                }
335            }
336        }
337
338        // Post-process: fix zero-amount lines by transferring min_amount from the
339        // largest line. This preserves the exact sum while eliminating zeros.
340        // Only attempt when total is large enough to support min_amount per line.
341        if total >= min_amount * Decimal::from(count as u32) {
342            loop {
343                // Find a zero-amount line
344                let zero_idx = amounts.iter().position(|a| *a == Decimal::ZERO);
345                let Some(zi) = zero_idx else { break };
346
347                // Find the largest amount (must be > min_amount to donate)
348                let donor = amounts
349                    .iter()
350                    .enumerate()
351                    .filter(|&(j, _)| j != zi)
352                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
353                    .map(|(j, _)| j);
354
355                if let Some(di) = donor {
356                    if amounts[di] > min_amount {
357                        amounts[zi] = min_amount;
358                        amounts[di] -= min_amount;
359                    } else {
360                        break; // No donor has enough headroom
361                    }
362                } else {
363                    break;
364                }
365            }
366        }
367
368        amounts
369    }
370
371    /// Sample an amount within a specific range.
372    pub fn sample_in_range(&mut self, min: Decimal, max: Decimal) -> Decimal {
373        let min_f64 = min.to_string().parse::<f64>().unwrap_or(0.0);
374        let max_f64 = max.to_string().parse::<f64>().unwrap_or(1000000.0);
375
376        let range = max_f64 - min_f64;
377        let amount = min_f64 + self.rng.random::<f64>() * range;
378
379        let rounded = (amount * self.decimal_multiplier).round() / self.decimal_multiplier;
380        Decimal::from_f64_retain(rounded).unwrap_or(min)
381    }
382
383    /// Reset the sampler with a new seed.
384    pub fn reset(&mut self, seed: u64) {
385        self.rng = ChaCha8Rng::seed_from_u64(seed);
386    }
387
388    /// SP3.5b — Update the log-normal sigma parameter.
389    ///
390    /// Also rebuilds the internal `LogNormal` distribution so that subsequent
391    /// `sample_lognormal()` calls reflect the new sigma. No-op if `sigma` is
392    /// not positive (guard against calibrator edge-cases).
393    pub fn set_lognormal_sigma(&mut self, sigma: f64) {
394        if sigma > 0.0 {
395            self.config.lognormal_sigma = sigma;
396            if let Ok(dist) = LogNormal::new(self.config.lognormal_mu, sigma) {
397                self.lognormal = dist;
398            }
399        }
400    }
401
402    /// SP3.5b — Update the round-number probability parameter.
403    ///
404    /// Clamped to [0, 1] for safety.
405    pub fn set_round_number_probability(&mut self, p: f64) {
406        self.config.round_number_probability = p.clamp(0.0, 1.0);
407    }
408
409    /// Return the current log-normal sigma (for testing).
410    pub fn lognormal_sigma(&self) -> f64 {
411        self.config.lognormal_sigma
412    }
413
414    /// Return the current round-number probability (for testing).
415    pub fn round_number_probability(&self) -> f64 {
416        self.config.round_number_probability
417    }
418}
419
420/// Sampler for currency exchange rates.
421pub struct ExchangeRateSampler {
422    rng: ChaCha8Rng,
423    /// Base rates for common currency pairs (vs USD)
424    base_rates: std::collections::HashMap<String, f64>,
425    /// Daily volatility (standard deviation)
426    volatility: f64,
427}
428
429impl ExchangeRateSampler {
430    /// Create a new exchange rate sampler.
431    pub fn new(seed: u64) -> Self {
432        let mut base_rates = std::collections::HashMap::new();
433        // Approximate rates as of 2024
434        base_rates.insert("EUR".to_string(), 0.92);
435        base_rates.insert("GBP".to_string(), 0.79);
436        base_rates.insert("CHF".to_string(), 0.88);
437        base_rates.insert("JPY".to_string(), 149.0);
438        base_rates.insert("CNY".to_string(), 7.24);
439        base_rates.insert("CAD".to_string(), 1.36);
440        base_rates.insert("AUD".to_string(), 1.53);
441        base_rates.insert("INR".to_string(), 83.0);
442        base_rates.insert("USD".to_string(), 1.0);
443
444        Self {
445            rng: ChaCha8Rng::seed_from_u64(seed),
446            base_rates,
447            volatility: 0.005, // 0.5% daily volatility
448        }
449    }
450
451    /// Get exchange rate from one currency to another.
452    pub fn get_rate(&mut self, from: &str, to: &str) -> Decimal {
453        let from_usd = self.base_rates.get(from).copied().unwrap_or(1.0);
454        let to_usd = self.base_rates.get(to).copied().unwrap_or(1.0);
455
456        // Base rate
457        let base_rate = to_usd / from_usd;
458
459        // Add some random variation
460        let variation = 1.0 + (self.rng.random::<f64>() - 0.5) * 2.0 * self.volatility;
461        let rate = base_rate * variation;
462
463        // Round to 6 decimal places
464        let rounded = (rate * 1_000_000.0).round() / 1_000_000.0;
465        Decimal::from_f64_retain(rounded).unwrap_or(Decimal::ONE)
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// Convert a Decimal to f64 through its textual form.
    fn as_f64(value: Decimal) -> f64 {
        value.to_string().parse().unwrap()
    }

    #[test]
    fn test_amount_sampler_determinism() {
        let (mut left, mut right) = (AmountSampler::new(42), AmountSampler::new(42));
        (0..100).for_each(|_| assert_eq!(left.sample(), right.sample()));
    }

    #[test]
    fn test_amount_sampler_range() {
        let bounded = AmountDistributionConfig {
            min_amount: 100.0,
            max_amount: 1000.0,
            ..Default::default()
        };
        let mut sampler = AmountSampler::with_config(42, bounded);

        for amount in (0..1000).map(|_| sampler.sample()) {
            let value = as_f64(amount);
            assert!(value >= 100.0, "Amount {} below minimum", amount);
            assert!(value <= 1000.0, "Amount {} above maximum", amount);
        }
    }

    #[test]
    fn test_summing_amounts() {
        let total = Decimal::from(10000);
        let amounts = AmountSampler::new(42).sample_summing_to(5, total);
        assert_eq!(amounts.len(), 5);

        let sum: Decimal = amounts.iter().sum();
        assert_eq!(sum, total, "Sum {} doesn't match total {}", sum, total);
    }

    #[test]
    fn test_exchange_rate() {
        let mut sampler = ExchangeRateSampler::new(42);

        let eur_f64 = as_f64(sampler.get_rate("EUR", "USD"));
        assert!(
            eur_f64 > 0.8 && eur_f64 < 1.2,
            "EUR/USD rate {} out of range",
            eur_f64
        );

        let usd_f64 = as_f64(sampler.get_rate("USD", "USD"));
        assert!(
            (usd_f64 - 1.0).abs() < 0.01,
            "USD/USD rate {} should be ~1.0",
            usd_f64
        );
    }
}