datasynth_core/distributions/
advanced_amount.rs1use rust_decimal::Decimal;
12
13use super::mixture::{
14 GaussianComponent, GaussianMixtureConfig, GaussianMixtureSampler, LogNormalComponent,
15 LogNormalMixtureConfig, LogNormalMixtureSampler,
16};
17use super::pareto::{ParetoConfig, ParetoSampler};
18
19#[derive(Clone)]
28pub enum AdvancedAmountSampler {
29 LogNormal(LogNormalMixtureSampler),
31 Gaussian(GaussianMixtureSampler),
33 Pareto(ParetoSampler),
35}
36
37impl AdvancedAmountSampler {
38 pub fn new_log_normal(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
40 Ok(Self::LogNormal(LogNormalMixtureSampler::new(seed, config)?))
41 }
42
43 pub fn new_gaussian(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
45 Ok(Self::Gaussian(GaussianMixtureSampler::new(seed, config)?))
46 }
47
48 pub fn new_pareto(seed: u64, config: ParetoConfig) -> Result<Self, String> {
50 Ok(Self::Pareto(ParetoSampler::new(seed, config)?))
51 }
52
53 pub fn sample_decimal(&mut self) -> Decimal {
55 match self {
56 Self::LogNormal(s) => s.sample_decimal(),
57 Self::Gaussian(s) => {
58 let value = s.sample().max(0.0);
59 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
60 }
61 Self::Pareto(s) => s.sample_decimal(),
62 }
63 }
64
65 pub fn reset(&mut self, seed: u64) {
67 match self {
68 Self::LogNormal(s) => s.reset(seed),
69 Self::Gaussian(s) => s.reset(seed),
70 Self::Pareto(s) => s.reset(seed),
71 }
72 }
73
74 pub fn ppf_decimal(&self, u: f64) -> Decimal {
79 match self {
80 Self::LogNormal(s) => s.ppf_decimal(u),
81 Self::Pareto(s) => s.ppf_decimal(u),
82 Self::Gaussian(s) => {
83 let v = s.ppf(u).max(0.0);
84 Decimal::from_f64_retain(v).unwrap_or(Decimal::ZERO)
85 }
86 }
87 }
88}
89
90pub fn log_normal_config_from_components(
93 components: Vec<(f64, f64, f64, Option<String>)>,
94 min_value: f64,
95 max_value: Option<f64>,
96 decimal_places: u8,
97) -> LogNormalMixtureConfig {
98 LogNormalMixtureConfig {
99 components: components
100 .into_iter()
101 .map(|(w, mu, sigma, label)| match label {
102 Some(l) => LogNormalComponent::with_label(w, mu, sigma, l),
103 None => LogNormalComponent::new(w, mu, sigma),
104 })
105 .collect(),
106 min_value,
107 max_value,
108 decimal_places,
109 }
110}
111
112pub fn gaussian_config_from_components(
115 components: Vec<(f64, f64, f64)>,
116 min_value: Option<f64>,
117 max_value: Option<f64>,
118) -> GaussianMixtureConfig {
119 GaussianMixtureConfig {
120 components: components
121 .into_iter()
122 .map(|(w, mu, sigma)| GaussianComponent::new(w, mu, sigma))
123 .collect(),
124 allow_negative: true,
125 min_value,
126 max_value,
127 }
128}
129
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Pulls `count` consecutive decimals out of `sampler`.
    fn draw(sampler: &mut AdvancedAmountSampler, count: usize) -> Vec<Decimal> {
        (0..count).map(|_| sampler.sample_decimal()).collect()
    }

    #[test]
    fn log_normal_sampler_produces_positive_values() {
        let config = log_normal_config_from_components(
            vec![(1.0, 7.0, 1.0, Some("test".into()))],
            0.01,
            None,
            2,
        );
        let mut sampler = AdvancedAmountSampler::new_log_normal(42, config).unwrap();
        let values = draw(&mut sampler, 1000);
        assert!(values.iter().all(|value| *value >= Decimal::ZERO));
    }

    #[test]
    fn gaussian_sampler_clamps_to_non_negative() {
        let config = gaussian_config_from_components(vec![(1.0, 0.0, 1.0)], None, None);
        let mut sampler = AdvancedAmountSampler::new_gaussian(42, config).unwrap();
        let values = draw(&mut sampler, 1000);
        assert!(values.iter().all(|value| *value >= Decimal::ZERO));
    }

    #[test]
    fn reset_restores_determinism() {
        let config =
            log_normal_config_from_components(vec![(1.0, 5.0, 1.0, None)], 0.01, None, 2);
        let mut sampler = AdvancedAmountSampler::new_log_normal(7, config).unwrap();
        let before = draw(&mut sampler, 5);
        sampler.reset(7);
        let after = draw(&mut sampler, 5);
        assert_eq!(before, after);
    }

    #[test]
    fn pareto_sampler_produces_heavy_tail() {
        let config = ParetoConfig {
            alpha: 1.5,
            x_min: 1000.0,
            max_value: None,
            decimal_places: 2,
        };
        let mut sampler = AdvancedAmountSampler::new_pareto(42, config).unwrap();
        let samples = draw(&mut sampler, 10_000);

        // No draw may fall below the scale parameter x_min.
        let min_sample = samples.iter().copied().min().unwrap();
        assert!(
            min_sample >= Decimal::from(1000),
            "Pareto sample {min_sample} < x_min"
        );

        // With alpha = 1.5, roughly (1000/10000)^1.5 ≈ 3% of draws exceed 10x x_min.
        let threshold = Decimal::from(10_000);
        let extreme_count = samples.iter().filter(|&&s| s > threshold).count();
        assert!(
            extreme_count > 100,
            "expected heavy tail with >100/10000 extreme samples, got {extreme_count}"
        );
    }

    #[test]
    fn pareto_reset_is_deterministic() {
        let config = ParetoConfig {
            alpha: 2.0,
            x_min: 100.0,
            max_value: Some(1_000_000.0),
            decimal_places: 2,
        };
        let mut sampler = AdvancedAmountSampler::new_pareto(9, config).unwrap();
        let before = draw(&mut sampler, 5);
        sampler.reset(9);
        let after = draw(&mut sampler, 5);
        assert_eq!(before, after);
    }
}