Skip to main content

datasynth_core/distributions/
advanced_amount.rs

//! Advanced amount sampler that dispatches between the legacy
//! [`AmountSampler`](super::AmountSampler) and the richer
//! [`LogNormalMixtureSampler`](super::LogNormalMixtureSampler) /
//! [`GaussianMixtureSampler`](super::GaussianMixtureSampler) /
//! [`ParetoSampler`](super::pareto::ParetoSampler) families.
//!
//! It exists so that callers (notably the journal-entry generator) can swap in
//! a mixture-model sampler when `distributions.amounts.enabled = true`, while
//! the legacy `transactions.amounts` code path stays byte-for-byte unchanged
//! when the advanced config is absent.
10
11use rust_decimal::Decimal;
12
13use super::mixture::{
14    GaussianComponent, GaussianMixtureConfig, GaussianMixtureSampler, LogNormalComponent,
15    LogNormalMixtureConfig, LogNormalMixtureSampler,
16};
17use super::pareto::{ParetoConfig, ParetoSampler};
18
19/// Advanced amount sampler wrapping one of the supported distribution
20/// families. Callers keep their existing legacy [`AmountSampler`](super::
21/// AmountSampler) and only consult this wrapper when
22/// `distributions.amounts.enabled` (or another advanced sub-block like
23/// `distributions.pareto.enabled`) is true.
24///
25/// v3.4.4 added the `Pareto` variant for heavy-tailed monetary samples
26/// (capex, strategic contracts, fraud amounts).
27#[derive(Clone)]
28pub enum AdvancedAmountSampler {
29    /// Log-normal mixture (preferred for positive monetary amounts).
30    LogNormal(LogNormalMixtureSampler),
31    /// Gaussian mixture (useful for signed quantities like deltas).
32    Gaussian(GaussianMixtureSampler),
33    /// Pareto heavy-tailed distribution (v3.4.4+).
34    Pareto(ParetoSampler),
35}
36
37impl AdvancedAmountSampler {
38    /// Create a log-normal-mixture sampler.
39    pub fn new_log_normal(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
40        Ok(Self::LogNormal(LogNormalMixtureSampler::new(seed, config)?))
41    }
42
43    /// Create a Gaussian-mixture sampler.
44    pub fn new_gaussian(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
45        Ok(Self::Gaussian(GaussianMixtureSampler::new(seed, config)?))
46    }
47
48    /// Create a Pareto sampler (v3.4.4+).
49    pub fn new_pareto(seed: u64, config: ParetoConfig) -> Result<Self, String> {
50        Ok(Self::Pareto(ParetoSampler::new(seed, config)?))
51    }
52
53    /// Sample one amount as `Decimal`.
54    pub fn sample_decimal(&mut self) -> Decimal {
55        match self {
56            Self::LogNormal(s) => s.sample_decimal(),
57            Self::Gaussian(s) => {
58                let value = s.sample().max(0.0);
59                Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
60            }
61            Self::Pareto(s) => s.sample_decimal(),
62        }
63    }
64
65    /// Reset the underlying RNG.
66    pub fn reset(&mut self, seed: u64) {
67        match self {
68            Self::LogNormal(s) => s.reset(seed),
69            Self::Gaussian(s) => s.reset(seed),
70            Self::Pareto(s) => s.reset(seed),
71        }
72    }
73
74    /// v4.1.6+ inverse CDF (quantile function) — returns the `Decimal`
75    /// quantile at uniform `u ∈ (0, 1)` for whichever underlying
76    /// sampler is active. Gaussian variant clamps negatives to zero
77    /// (monetary-amount semantics).
78    pub fn ppf_decimal(&self, u: f64) -> Decimal {
79        match self {
80            Self::LogNormal(s) => s.ppf_decimal(u),
81            Self::Pareto(s) => s.ppf_decimal(u),
82            Self::Gaussian(s) => {
83                let v = s.ppf(u).max(0.0);
84                Decimal::from_f64_retain(v).unwrap_or(Decimal::ZERO)
85            }
86        }
87    }
88}
89
90/// Build a [`LogNormalMixtureConfig`] from a list of `(weight, mu, sigma,
91/// label)` tuples.  Thin convenience used by the config-layer converter.
92pub fn log_normal_config_from_components(
93    components: Vec<(f64, f64, f64, Option<String>)>,
94    min_value: f64,
95    max_value: Option<f64>,
96    decimal_places: u8,
97) -> LogNormalMixtureConfig {
98    LogNormalMixtureConfig {
99        components: components
100            .into_iter()
101            .map(|(w, mu, sigma, label)| match label {
102                Some(l) => LogNormalComponent::with_label(w, mu, sigma, l),
103                None => LogNormalComponent::new(w, mu, sigma),
104            })
105            .collect(),
106        min_value,
107        max_value,
108        decimal_places,
109    }
110}
111
112/// Build a [`GaussianMixtureConfig`] from a list of `(weight, mu, sigma)`
113/// tuples. Labels are ignored (the Gaussian component has no label field).
114pub fn gaussian_config_from_components(
115    components: Vec<(f64, f64, f64)>,
116    min_value: Option<f64>,
117    max_value: Option<f64>,
118) -> GaussianMixtureConfig {
119    GaussianMixtureConfig {
120        components: components
121            .into_iter()
122            .map(|(w, mu, sigma)| GaussianComponent::new(w, mu, sigma))
123            .collect(),
124        allow_negative: true,
125        min_value,
126        max_value,
127    }
128}
129
#[cfg(test)]
// Tests unwrap constructor results on known-valid configs; a panic there is a
// genuine test failure, so `unwrap_used` is relaxed for this module only.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn log_normal_sampler_produces_positive_values() {
        let cfg = log_normal_config_from_components(
            vec![(1.0, 7.0, 1.0, Some("test".into()))],
            0.01,
            None,
            2,
        );
        let mut sampler = AdvancedAmountSampler::new_log_normal(42, cfg).unwrap();
        for _ in 0..1000 {
            let v = sampler.sample_decimal();
            assert!(v >= Decimal::ZERO);
        }
    }

    #[test]
    fn gaussian_sampler_clamps_to_non_negative() {
        let cfg = gaussian_config_from_components(vec![(1.0, 0.0, 1.0)], None, None);
        let mut sampler = AdvancedAmountSampler::new_gaussian(42, cfg).unwrap();
        for _ in 0..1000 {
            let v = sampler.sample_decimal();
            assert!(v >= Decimal::ZERO);
        }
    }

    #[test]
    fn reset_restores_determinism() {
        let cfg = log_normal_config_from_components(vec![(1.0, 5.0, 1.0, None)], 0.01, None, 2);
        let mut a = AdvancedAmountSampler::new_log_normal(7, cfg).unwrap();
        let first: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        a.reset(7);
        let second: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        assert_eq!(first, second);
    }

    #[test]
    fn gaussian_reset_restores_determinism() {
        let cfg = gaussian_config_from_components(vec![(1.0, 100.0, 10.0)], None, None);
        let mut a = AdvancedAmountSampler::new_gaussian(11, cfg).unwrap();
        let first: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        a.reset(11);
        let second: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        assert_eq!(first, second);
    }

    #[test]
    fn pareto_sampler_produces_heavy_tail() {
        let cfg = ParetoConfig {
            alpha: 1.5,
            x_min: 1000.0,
            max_value: None,
            decimal_places: 2,
        };
        let mut sampler = AdvancedAmountSampler::new_pareto(42, cfg).unwrap();
        let samples: Vec<Decimal> = (0..10_000).map(|_| sampler.sample_decimal()).collect();
        // All samples must be >= x_min (Pareto support).
        let min_sample = samples.iter().min().unwrap();
        assert!(
            *min_sample >= Decimal::from(1000),
            "Pareto sample {min_sample} < x_min"
        );
        // Heavy tail: P(X > 10 * x_min) = 10^-alpha = 10^-1.5 ≈ 3.2%, so roughly
        // 316 of 10_000 samples are expected above 10_000; >100 leaves slack.
        let extreme_count = samples
            .iter()
            .filter(|s| **s > Decimal::from(10_000))
            .count();
        assert!(
            extreme_count > 100,
            "expected heavy tail with >100/10000 extreme samples, got {extreme_count}"
        );
    }

    #[test]
    fn pareto_reset_is_deterministic() {
        let cfg = ParetoConfig {
            alpha: 2.0,
            x_min: 100.0,
            max_value: Some(1_000_000.0),
            decimal_places: 2,
        };
        let mut a = AdvancedAmountSampler::new_pareto(9, cfg).unwrap();
        let first: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        a.reset(9);
        let second: Vec<Decimal> = (0..5).map(|_| a.sample_decimal()).collect();
        assert_eq!(first, second);
    }

    #[test]
    fn gaussian_ppf_clamps_negative_quantiles_to_zero() {
        let cfg = gaussian_config_from_components(vec![(1.0, 0.0, 1.0)], None, None);
        let sampler = AdvancedAmountSampler::new_gaussian(42, cfg).unwrap();
        // For N(0, 1) the 10th percentile is about -1.28, so it must clamp to
        // zero. The 90th percentile is about +1.28 and must stay positive.
        assert_eq!(sampler.ppf_decimal(0.1), Decimal::ZERO);
        assert!(sampler.ppf_decimal(0.9) > Decimal::ZERO);
    }

    #[test]
    fn ppf_decimal_is_monotonic_for_every_variant() {
        let samplers = [
            AdvancedAmountSampler::new_log_normal(
                1,
                log_normal_config_from_components(vec![(1.0, 7.0, 1.0, None)], 0.01, None, 2),
            )
            .unwrap(),
            AdvancedAmountSampler::new_gaussian(
                1,
                gaussian_config_from_components(vec![(1.0, 100.0, 10.0)], None, None),
            )
            .unwrap(),
            AdvancedAmountSampler::new_pareto(
                1,
                ParetoConfig {
                    alpha: 2.0,
                    x_min: 100.0,
                    max_value: None,
                    decimal_places: 2,
                },
            )
            .unwrap(),
        ];
        for sampler in &samplers {
            let lo = sampler.ppf_decimal(0.1);
            let mid = sampler.ppf_decimal(0.5);
            let hi = sampler.ppf_decimal(0.9);
            assert!(
                lo <= mid && mid <= hi,
                "quantile function not monotonic: {lo} / {mid} / {hi}"
            );
        }
    }
}