datasynth_core/distributions/
advanced_amount.rs1use rust_decimal::Decimal;
12
13use super::mixture::{
14 GaussianComponent, GaussianMixtureConfig, GaussianMixtureSampler, LogNormalComponent,
15 LogNormalMixtureConfig, LogNormalMixtureSampler,
16};
17use super::pareto::{ParetoConfig, ParetoSampler};
18
19#[derive(Clone)]
28pub enum AdvancedAmountSampler {
29 LogNormal(LogNormalMixtureSampler),
31 Gaussian(GaussianMixtureSampler),
33 Pareto(ParetoSampler),
35}
36
37impl AdvancedAmountSampler {
38 pub fn new_log_normal(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
40 Ok(Self::LogNormal(LogNormalMixtureSampler::new(seed, config)?))
41 }
42
43 pub fn new_gaussian(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
45 Ok(Self::Gaussian(GaussianMixtureSampler::new(seed, config)?))
46 }
47
48 pub fn new_pareto(seed: u64, config: ParetoConfig) -> Result<Self, String> {
50 Ok(Self::Pareto(ParetoSampler::new(seed, config)?))
51 }
52
53 pub fn sample_decimal(&mut self) -> Decimal {
55 match self {
56 Self::LogNormal(s) => s.sample_decimal(),
57 Self::Gaussian(s) => {
58 let value = s.sample().max(0.0);
59 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
60 }
61 Self::Pareto(s) => s.sample_decimal(),
62 }
63 }
64
65 pub fn reset(&mut self, seed: u64) {
67 match self {
68 Self::LogNormal(s) => s.reset(seed),
69 Self::Gaussian(s) => s.reset(seed),
70 Self::Pareto(s) => s.reset(seed),
71 }
72 }
73}
74
75pub fn log_normal_config_from_components(
78 components: Vec<(f64, f64, f64, Option<String>)>,
79 min_value: f64,
80 max_value: Option<f64>,
81 decimal_places: u8,
82) -> LogNormalMixtureConfig {
83 LogNormalMixtureConfig {
84 components: components
85 .into_iter()
86 .map(|(w, mu, sigma, label)| match label {
87 Some(l) => LogNormalComponent::with_label(w, mu, sigma, l),
88 None => LogNormalComponent::new(w, mu, sigma),
89 })
90 .collect(),
91 min_value,
92 max_value,
93 decimal_places,
94 }
95}
96
97pub fn gaussian_config_from_components(
100 components: Vec<(f64, f64, f64)>,
101 min_value: Option<f64>,
102 max_value: Option<f64>,
103) -> GaussianMixtureConfig {
104 GaussianMixtureConfig {
105 components: components
106 .into_iter()
107 .map(|(w, mu, sigma)| GaussianComponent::new(w, mu, sigma))
108 .collect(),
109 allow_negative: true,
110 min_value,
111 max_value,
112 }
113}
114
#[cfg(test)]
// Unwrapping is fine in tests: a failed constructor means the test setup itself is wrong.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Collects `count` consecutive draws from `sampler`.
    fn draw(sampler: &mut AdvancedAmountSampler, count: usize) -> Vec<Decimal> {
        (0..count).map(|_| sampler.sample_decimal()).collect()
    }

    #[test]
    fn log_normal_sampler_produces_positive_values() {
        let config = log_normal_config_from_components(
            vec![(1.0, 7.0, 1.0, Some("test".into()))],
            0.01,
            None,
            2,
        );
        let mut sampler = AdvancedAmountSampler::new_log_normal(42, config).unwrap();
        for value in draw(&mut sampler, 1000) {
            assert!(value >= Decimal::ZERO, "negative log-normal sample {value}");
        }
    }

    #[test]
    fn gaussian_sampler_clamps_to_non_negative() {
        // mu = 0, sigma = 1: presumably about half of the raw draws are
        // negative, so the clamp in `sample_decimal` is exercised.
        let config = gaussian_config_from_components(vec![(1.0, 0.0, 1.0)], None, None);
        let mut sampler = AdvancedAmountSampler::new_gaussian(42, config).unwrap();
        for value in draw(&mut sampler, 1000) {
            assert!(value >= Decimal::ZERO, "negative Gaussian sample {value}");
        }
    }

    #[test]
    fn reset_restores_determinism() {
        let config =
            log_normal_config_from_components(vec![(1.0, 5.0, 1.0, None)], 0.01, None, 2);
        let mut sampler = AdvancedAmountSampler::new_log_normal(7, config).unwrap();
        let before = draw(&mut sampler, 5);
        sampler.reset(7);
        let after = draw(&mut sampler, 5);
        assert_eq!(before, after);
    }

    #[test]
    fn pareto_sampler_produces_heavy_tail() {
        let config = ParetoConfig {
            alpha: 1.5,
            x_min: 1000.0,
            max_value: None,
            decimal_places: 2,
        };
        let mut sampler = AdvancedAmountSampler::new_pareto(42, config).unwrap();
        let samples = draw(&mut sampler, 10_000);

        let smallest = samples.iter().min().unwrap();
        assert!(
            *smallest >= Decimal::from(1000),
            "Pareto sample {smallest} < x_min"
        );

        // Assuming the standard Pareto survival function (x_min / x)^alpha,
        // P(X > 10 * x_min) = 10^-1.5, about 3.2%. That is roughly 316 of
        // 10 000 draws, so the threshold of 100 leaves a wide margin.
        let threshold = Decimal::from(10_000);
        let extreme_count = samples.iter().filter(|s| **s > threshold).count();
        assert!(
            extreme_count > 100,
            "expected heavy tail with >100/10000 extreme samples, got {extreme_count}"
        );
    }

    #[test]
    fn pareto_reset_is_deterministic() {
        let config = ParetoConfig {
            alpha: 2.0,
            x_min: 100.0,
            max_value: Some(1_000_000.0),
            decimal_places: 2,
        };
        let mut sampler = AdvancedAmountSampler::new_pareto(9, config).unwrap();
        let before = draw(&mut sampler, 5);
        sampler.reset(9);
        let after = draw(&mut sampler, 5);
        assert_eq!(before, after);
    }
}