Skip to main content

datasynth_core/distributions/
advanced_amount.rs

//! Advanced amount sampler that lets callers choose between the legacy
//! [`AmountSampler`](super::AmountSampler) and the richer
//! [`LogNormalMixtureSampler`](super::LogNormalMixtureSampler) /
//! [`GaussianMixtureSampler`](super::GaussianMixtureSampler) /
//! [`ParetoSampler`](super::pareto::ParetoSampler) families.
//!
//! Exists so callers (notably the journal-entry generator) can swap in an
//! advanced sampler when `distributions.amounts.enabled = true`, while the
//! legacy `transactions.amounts` code path stays byte-for-byte unchanged
//! when the advanced config is absent.
10
11use rust_decimal::Decimal;
12
13use super::mixture::{
14    GaussianComponent, GaussianMixtureConfig, GaussianMixtureSampler, LogNormalComponent,
15    LogNormalMixtureConfig, LogNormalMixtureSampler,
16};
17use super::pareto::{ParetoConfig, ParetoSampler};
18
19/// Advanced amount sampler wrapping one of the supported distribution
20/// families. Callers keep their existing legacy [`AmountSampler`](super::
21/// AmountSampler) and only consult this wrapper when
22/// `distributions.amounts.enabled` (or another advanced sub-block like
23/// `distributions.pareto.enabled`) is true.
24///
25/// v3.4.4 added the `Pareto` variant for heavy-tailed monetary samples
26/// (capex, strategic contracts, fraud amounts).
27#[derive(Clone)]
28pub enum AdvancedAmountSampler {
29    /// Log-normal mixture (preferred for positive monetary amounts).
30    LogNormal(LogNormalMixtureSampler),
31    /// Gaussian mixture (useful for signed quantities like deltas).
32    Gaussian(GaussianMixtureSampler),
33    /// Pareto heavy-tailed distribution (v3.4.4+).
34    Pareto(ParetoSampler),
35}
36
37impl AdvancedAmountSampler {
38    /// Create a log-normal-mixture sampler.
39    pub fn new_log_normal(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
40        Ok(Self::LogNormal(LogNormalMixtureSampler::new(seed, config)?))
41    }
42
43    /// Create a Gaussian-mixture sampler.
44    pub fn new_gaussian(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
45        Ok(Self::Gaussian(GaussianMixtureSampler::new(seed, config)?))
46    }
47
48    /// Create a Pareto sampler (v3.4.4+).
49    pub fn new_pareto(seed: u64, config: ParetoConfig) -> Result<Self, String> {
50        Ok(Self::Pareto(ParetoSampler::new(seed, config)?))
51    }
52
53    /// Sample one amount as `Decimal`.
54    pub fn sample_decimal(&mut self) -> Decimal {
55        match self {
56            Self::LogNormal(s) => s.sample_decimal(),
57            Self::Gaussian(s) => {
58                let value = s.sample().max(0.0);
59                Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
60            }
61            Self::Pareto(s) => s.sample_decimal(),
62        }
63    }
64
65    /// Reset the underlying RNG.
66    pub fn reset(&mut self, seed: u64) {
67        match self {
68            Self::LogNormal(s) => s.reset(seed),
69            Self::Gaussian(s) => s.reset(seed),
70            Self::Pareto(s) => s.reset(seed),
71        }
72    }
73}
74
75/// Build a [`LogNormalMixtureConfig`] from a list of `(weight, mu, sigma,
76/// label)` tuples.  Thin convenience used by the config-layer converter.
77pub fn log_normal_config_from_components(
78    components: Vec<(f64, f64, f64, Option<String>)>,
79    min_value: f64,
80    max_value: Option<f64>,
81    decimal_places: u8,
82) -> LogNormalMixtureConfig {
83    LogNormalMixtureConfig {
84        components: components
85            .into_iter()
86            .map(|(w, mu, sigma, label)| match label {
87                Some(l) => LogNormalComponent::with_label(w, mu, sigma, l),
88                None => LogNormalComponent::new(w, mu, sigma),
89            })
90            .collect(),
91        min_value,
92        max_value,
93        decimal_places,
94    }
95}
96
97/// Build a [`GaussianMixtureConfig`] from a list of `(weight, mu, sigma)`
98/// tuples. Labels are ignored (the Gaussian component has no label field).
99pub fn gaussian_config_from_components(
100    components: Vec<(f64, f64, f64)>,
101    min_value: Option<f64>,
102    max_value: Option<f64>,
103) -> GaussianMixtureConfig {
104    GaussianMixtureConfig {
105        components: components
106            .into_iter()
107            .map(|(w, mu, sigma)| GaussianComponent::new(w, mu, sigma))
108            .collect(),
109        allow_negative: true,
110        min_value,
111        max_value,
112    }
113}
114
#[cfg(test)]
// `unwrap` is acceptable here: a failed sampler construction should abort the
// test with a panic.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Log-normal samples must never be negative.
    #[test]
    fn log_normal_sampler_produces_positive_values() {
        let cfg = log_normal_config_from_components(
            vec![(1.0, 7.0, 1.0, Some("test".into()))],
            0.01,
            None,
            2,
        );
        let mut sampler = AdvancedAmountSampler::new_log_normal(42, cfg).unwrap();
        for _ in 0..1000 {
            let v = sampler.sample_decimal();
            assert!(v >= Decimal::ZERO, "log-normal sample {v} is negative");
        }
    }

    /// A zero-mean Gaussian yields negative raw draws; `sample_decimal` must
    /// clamp them to zero.
    #[test]
    fn gaussian_sampler_clamps_to_non_negative() {
        let cfg = gaussian_config_from_components(vec![(1.0, 0.0, 1.0)], None, None);
        let mut sampler = AdvancedAmountSampler::new_gaussian(42, cfg).unwrap();
        for _ in 0..1000 {
            let v = sampler.sample_decimal();
            assert!(v >= Decimal::ZERO, "gaussian sample {v} is negative");
        }
    }

    /// Resetting with the original seed must replay the same sequence.
    #[test]
    fn reset_restores_determinism() {
        let cfg = log_normal_config_from_components(vec![(1.0, 5.0, 1.0, None)], 0.01, None, 2);
        // `cfg` is not used again, so it is moved rather than cloned.
        let mut a = AdvancedAmountSampler::new_log_normal(7, cfg).unwrap();
        let first: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        a.reset(7);
        let second: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        assert_eq!(first, second);
    }

    /// Pareto samples stay at or above `x_min` and show a heavy upper tail.
    #[test]
    fn pareto_sampler_produces_heavy_tail() {
        let cfg = ParetoConfig {
            alpha: 1.5,
            x_min: 1000.0,
            max_value: None,
            decimal_places: 2,
        };
        let mut sampler = AdvancedAmountSampler::new_pareto(42, cfg).unwrap();
        let samples: Vec<Decimal> = (0..10_000).map(|_| sampler.sample_decimal()).collect();
        // All samples must be >= x_min (Pareto support).
        let min_sample = samples.iter().min().unwrap();
        assert!(
            *min_sample >= Decimal::from(1000),
            "Pareto sample {min_sample} < x_min"
        );
        // Heavy tail: for a standard Pareto with alpha = 1.5,
        // P(X > 10 * x_min) = 10^-1.5, i.e. roughly 316 of 10,000 draws, so
        // requiring more than 100 leaves a wide margin.
        let extreme_count = samples
            .iter()
            .filter(|s| **s > Decimal::from(10_000))
            .count();
        assert!(
            extreme_count > 100,
            "expected heavy tail with >100/10000 extreme samples, got {extreme_count}"
        );
    }

    /// Resetting a Pareto sampler with its original seed must replay the
    /// same sequence, including when `max_value` is set.
    #[test]
    fn pareto_reset_is_deterministic() {
        let cfg = ParetoConfig {
            alpha: 2.0,
            x_min: 100.0,
            max_value: Some(1_000_000.0),
            decimal_places: 2,
        };
        let mut a = AdvancedAmountSampler::new_pareto(9, cfg).unwrap();
        let first: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        a.reset(9);
        let second: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        assert_eq!(first, second);
    }
}