datasynth_core/distributions/
advanced_amount.rs1use rust_decimal::Decimal;
12
13use super::mixture::{
14 GaussianComponent, GaussianMixtureConfig, GaussianMixtureSampler, LogNormalComponent,
15 LogNormalMixtureConfig, LogNormalMixtureSampler,
16};
17use super::pareto::{ParetoConfig, ParetoSampler};
18
19#[derive(Clone)]
28pub enum AdvancedAmountSampler {
29 LogNormal(LogNormalMixtureSampler),
31 Gaussian(GaussianMixtureSampler),
33 Pareto(ParetoSampler),
35}
36
37impl AdvancedAmountSampler {
38 pub fn new_log_normal(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
40 Ok(Self::LogNormal(LogNormalMixtureSampler::new(seed, config)?))
41 }
42
43 pub fn new_gaussian(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
45 Ok(Self::Gaussian(GaussianMixtureSampler::new(seed, config)?))
46 }
47
48 pub fn new_pareto(seed: u64, config: ParetoConfig) -> Result<Self, String> {
50 Ok(Self::Pareto(ParetoSampler::new(seed, config)?))
51 }
52
53 pub fn sample_decimal(&mut self) -> Decimal {
55 match self {
56 Self::LogNormal(s) => s.sample_decimal(),
57 Self::Gaussian(s) => {
58 let value = s.sample().max(0.0);
59 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
60 }
61 Self::Pareto(s) => s.sample_decimal(),
62 }
63 }
64
65 pub fn reset(&mut self, seed: u64) {
67 match self {
68 Self::LogNormal(s) => s.reset(seed),
69 Self::Gaussian(s) => s.reset(seed),
70 Self::Pareto(s) => s.reset(seed),
71 }
72 }
73
74 pub fn ppf_decimal(&self, u: f64) -> Decimal {
79 match self {
80 Self::LogNormal(s) => s.ppf_decimal(u),
81 Self::Pareto(s) => s.ppf_decimal(u),
82 Self::Gaussian(s) => {
83 let v = s.ppf(u).max(0.0);
84 Decimal::from_f64_retain(v).unwrap_or(Decimal::ZERO)
85 }
86 }
87 }
88}
89
90pub fn log_normal_config_from_components(
93 components: Vec<(f64, f64, f64, Option<String>)>,
94 min_value: f64,
95 max_value: Option<f64>,
96 decimal_places: u8,
97) -> LogNormalMixtureConfig {
98 LogNormalMixtureConfig {
99 components: components
100 .into_iter()
101 .map(|(w, mu, sigma, label)| match label {
102 Some(l) => LogNormalComponent::with_label(w, mu, sigma, l),
103 None => LogNormalComponent::new(w, mu, sigma),
104 })
105 .collect(),
106 min_value,
107 max_value,
108 decimal_places,
109 }
110}
111
112pub fn gaussian_config_from_components(
115 components: Vec<(f64, f64, f64)>,
116 min_value: Option<f64>,
117 max_value: Option<f64>,
118) -> GaussianMixtureConfig {
119 GaussianMixtureConfig {
120 components: components
121 .into_iter()
122 .map(|(w, mu, sigma)| GaussianComponent::new(w, mu, sigma))
123 .collect(),
124 allow_negative: true,
125 min_value,
126 max_value,
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects `count` consecutive draws from `sampler`.
    fn draw(sampler: &mut AdvancedAmountSampler, count: usize) -> Vec<Decimal> {
        (0..count).map(|_| sampler.sample_decimal()).collect()
    }

    #[test]
    fn log_normal_sampler_produces_positive_values() {
        let cfg = log_normal_config_from_components(
            vec![(1.0, 7.0, 1.0, Some("test".into()))],
            0.01,
            None,
            2,
        );
        let mut sampler = AdvancedAmountSampler::new_log_normal(42, cfg).unwrap();

        let values = draw(&mut sampler, 1000);
        assert!(values.iter().all(|v| *v >= Decimal::ZERO));
    }

    #[test]
    fn gaussian_sampler_clamps_to_non_negative() {
        let cfg = gaussian_config_from_components(vec![(1.0, 0.0, 1.0)], None, None);
        let mut sampler = AdvancedAmountSampler::new_gaussian(42, cfg).unwrap();

        let values = draw(&mut sampler, 1000);
        assert!(values.iter().all(|v| *v >= Decimal::ZERO));
    }

    #[test]
    fn reset_restores_determinism() {
        let cfg = log_normal_config_from_components(vec![(1.0, 5.0, 1.0, None)], 0.01, None, 2);
        let mut sampler = AdvancedAmountSampler::new_log_normal(7, cfg).unwrap();

        let before_reset = draw(&mut sampler, 5);
        sampler.reset(7);
        let after_reset = draw(&mut sampler, 5);
        assert_eq!(before_reset, after_reset);
    }

    #[test]
    fn pareto_sampler_produces_heavy_tail() {
        let cfg = ParetoConfig {
            alpha: 1.5,
            x_min: 1000.0,
            max_value: None,
            decimal_places: 2,
        };
        let mut sampler = AdvancedAmountSampler::new_pareto(42, cfg).unwrap();
        let samples = draw(&mut sampler, 10_000);

        // Every draw must respect the distribution's lower bound.
        let min_sample = samples.iter().min().unwrap();
        assert!(
            *min_sample >= Decimal::from(1000),
            "Pareto sample {min_sample} < x_min"
        );

        // With alpha = 1.5, P(X > 10 * x_min) is roughly 3%, so well over 100
        // of 10_000 draws should land in the tail.
        let tail_threshold = Decimal::from(10_000);
        let extreme_count = samples.iter().filter(|s| **s > tail_threshold).count();
        assert!(
            extreme_count > 100,
            "expected heavy tail with >100/10000 extreme samples, got {extreme_count}"
        );
    }

    #[test]
    fn pareto_reset_is_deterministic() {
        let cfg = ParetoConfig {
            alpha: 2.0,
            x_min: 100.0,
            max_value: Some(1_000_000.0),
            decimal_places: 2,
        };
        let mut sampler = AdvancedAmountSampler::new_pareto(9, cfg).unwrap();

        let before_reset = draw(&mut sampler, 5);
        sampler.reset(9);
        assert_eq!(before_reset, draw(&mut sampler, 5));
    }
}