datasynth_core/distributions/
beta.rs1use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Beta, Distribution};
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct BetaConfig {
18 pub alpha: f64,
21 pub beta: f64,
24 #[serde(default)]
26 pub lower_bound: f64,
27 #[serde(default = "default_upper_bound")]
29 pub upper_bound: f64,
30 #[serde(default = "default_decimal_places")]
32 pub decimal_places: u8,
33}
34
/// Serde fallback for `BetaConfig::upper_bound`: the top of the unit interval.
fn default_upper_bound() -> f64 {
    const UNIT_INTERVAL_MAX: f64 = 1.0;
    UNIT_INTERVAL_MAX
}
38
/// Serde fallback for `BetaConfig::decimal_places`: four digits after the point.
fn default_decimal_places() -> u8 {
    const DEFAULT_PRECISION: u8 = 4;
    DEFAULT_PRECISION
}
42
43impl Default for BetaConfig {
44 fn default() -> Self {
45 Self {
46 alpha: 2.0,
47 beta: 5.0,
48 lower_bound: 0.0,
49 upper_bound: 1.0,
50 decimal_places: 4,
51 }
52 }
53}
54
55impl BetaConfig {
56 pub fn new(alpha: f64, beta: f64) -> Self {
58 Self {
59 alpha,
60 beta,
61 ..Default::default()
62 }
63 }
64
65 pub fn percentage(alpha: f64, beta: f64) -> Self {
67 Self {
68 alpha,
69 beta,
70 lower_bound: 0.0,
71 upper_bound: 100.0,
72 decimal_places: 2,
73 }
74 }
75
76 pub fn discount_rate() -> Self {
78 Self {
79 alpha: 2.0, beta: 8.0,
81 lower_bound: 0.02, upper_bound: 0.15, decimal_places: 4,
84 }
85 }
86
87 pub fn cash_discount() -> Self {
89 Self {
90 alpha: 3.0,
91 beta: 3.0, lower_bound: 0.01, upper_bound: 0.03, decimal_places: 4,
95 }
96 }
97
98 pub fn completion_rate() -> Self {
100 Self {
101 alpha: 8.0, beta: 2.0,
103 lower_bound: 0.0,
104 upper_bound: 1.0,
105 decimal_places: 4,
106 }
107 }
108
109 pub fn match_rate() -> Self {
111 Self {
112 alpha: 10.0,
113 beta: 1.5,
114 lower_bound: 0.85,
115 upper_bound: 0.99,
116 decimal_places: 4,
117 }
118 }
119
120 pub fn quality_score() -> Self {
122 Self {
123 alpha: 5.0,
124 beta: 2.0,
125 lower_bound: 0.0,
126 upper_bound: 100.0,
127 decimal_places: 1,
128 }
129 }
130
131 pub fn uniform() -> Self {
133 Self {
134 alpha: 1.0,
135 beta: 1.0,
136 ..Default::default()
137 }
138 }
139
140 pub fn validate(&self) -> Result<(), String> {
142 if self.alpha <= 0.0 {
143 return Err("alpha must be positive".to_string());
144 }
145 if self.beta <= 0.0 {
146 return Err("beta must be positive".to_string());
147 }
148 if self.upper_bound <= self.lower_bound {
149 return Err("upper_bound must be greater than lower_bound".to_string());
150 }
151 Ok(())
152 }
153
154 pub fn expected_value(&self) -> f64 {
156 let raw_mean = self.alpha / (self.alpha + self.beta);
157 self.lower_bound + raw_mean * (self.upper_bound - self.lower_bound)
158 }
159
160 pub fn mode(&self) -> Option<f64> {
163 if self.alpha > 1.0 && self.beta > 1.0 {
164 let raw_mode = (self.alpha - 1.0) / (self.alpha + self.beta - 2.0);
165 Some(self.lower_bound + raw_mode * (self.upper_bound - self.lower_bound))
166 } else {
167 None
168 }
169 }
170
171 pub fn variance(&self) -> f64 {
173 let ab = self.alpha + self.beta;
174 let raw_variance = (self.alpha * self.beta) / (ab.powi(2) * (ab + 1.0));
175 raw_variance * (self.upper_bound - self.lower_bound).powi(2)
176 }
177}
178
179pub struct BetaSampler {
181 rng: ChaCha8Rng,
182 config: BetaConfig,
183 distribution: Beta<f64>,
184 decimal_multiplier: f64,
185 range: f64,
186}
187
188impl BetaSampler {
189 pub fn new(seed: u64, config: BetaConfig) -> Result<Self, String> {
191 config.validate()?;
192
193 let distribution = Beta::new(config.alpha, config.beta)
194 .map_err(|e| format!("Invalid Beta distribution: {}", e))?;
195
196 let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
197 let range = config.upper_bound - config.lower_bound;
198
199 Ok(Self {
200 rng: ChaCha8Rng::seed_from_u64(seed),
201 config,
202 distribution,
203 decimal_multiplier,
204 range,
205 })
206 }
207
208 pub fn sample(&mut self) -> f64 {
210 let raw_value = self.distribution.sample(&mut self.rng);
211 let scaled_value = self.config.lower_bound + raw_value * self.range;
212
213 (scaled_value * self.decimal_multiplier).round() / self.decimal_multiplier
215 }
216
217 pub fn sample_decimal(&mut self) -> Decimal {
219 let value = self.sample();
220 Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
221 }
222
223 pub fn sample_percentage(&mut self) -> f64 {
225 let raw_value = self.distribution.sample(&mut self.rng);
226 let scaled_value = self.config.lower_bound + raw_value * self.range;
227 (scaled_value * 100.0 * self.decimal_multiplier).round() / self.decimal_multiplier
228 }
229
230 pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
232 (0..n).map(|_| self.sample()).collect()
233 }
234
235 pub fn reset(&mut self, seed: u64) {
237 self.rng = ChaCha8Rng::seed_from_u64(seed);
238 }
239
240 pub fn config(&self) -> &BetaConfig {
242 &self.config
243 }
244}
245
/// Coarse classification of a Beta density, as produced by
/// `BetaConfig::shape`.
///
/// Derives `Hash` in addition to `Eq` so shapes can be used as map or set
/// keys, for example when tallying configurations by shape.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BetaShape {
    /// Both shape parameters are within `0.001` of `1.0`.
    Uniform,
    /// Both shape parameters are below `1.0`.
    UShaped,
    /// The shape parameters are within `0.001` of each other and `alpha > 1.0`.
    Symmetric,
    /// `alpha > beta`, with neither parameter below `1.0`.
    SkewedLeft,
    /// Every remaining case, e.g. `beta > alpha` with neither below `1.0`.
    SkewedRight,
    /// `alpha >= 1.0` and `beta < 1.0`.
    JShapedHigh,
    /// `alpha < 1.0` and `beta >= 1.0`.
    JShapedLow,
}
264
265impl BetaConfig {
266 pub fn shape(&self) -> BetaShape {
268 match (self.alpha, self.beta) {
269 (a, b) if (a - 1.0).abs() < 0.001 && (b - 1.0).abs() < 0.001 => BetaShape::Uniform,
270 (a, b) if a < 1.0 && b < 1.0 => BetaShape::UShaped,
271 (a, b) if (a - b).abs() < 0.001 && a > 1.0 => BetaShape::Symmetric,
272 (a, b) if a < 1.0 && b >= 1.0 => BetaShape::JShapedLow,
273 (a, b) if a >= 1.0 && b < 1.0 => BetaShape::JShapedHigh,
274 (a, b) if a > b => BetaShape::SkewedLeft,
275 _ => BetaShape::SkewedRight,
276 }
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed seed shared by the sampling tests so runs are reproducible.
    const SEED: u64 = 42;

    #[test]
    fn test_beta_validation() {
        assert!(BetaConfig::new(2.0, 5.0).validate().is_ok());

        // Negative alpha and zero beta are both outside the positive range.
        assert!(BetaConfig::new(-1.0, 5.0).validate().is_err());
        assert!(BetaConfig::new(2.0, 0.0).validate().is_err());
    }

    #[test]
    fn test_beta_validation_bounds() {
        // The upper bound must be strictly greater than the lower bound.
        let empty_interval = BetaConfig {
            lower_bound: 0.5,
            upper_bound: 0.5,
            ..BetaConfig::new(2.0, 5.0)
        };
        assert!(empty_interval.validate().is_err());

        let inverted_interval = BetaConfig {
            lower_bound: 1.0,
            upper_bound: 0.0,
            ..BetaConfig::new(2.0, 5.0)
        };
        assert!(inverted_interval.validate().is_err());
    }

    #[test]
    fn test_sampler_rejects_invalid_config() {
        assert!(BetaSampler::new(SEED, BetaConfig::new(0.0, 5.0)).is_err());
        assert!(BetaSampler::new(SEED, BetaConfig::new(2.0, -3.0)).is_err());
    }

    #[test]
    fn test_beta_sampling() {
        let mut sampler = BetaSampler::new(SEED, BetaConfig::new(2.0, 5.0)).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);
        assert!(samples.iter().all(|&x| (0.0..=1.0).contains(&x)));
    }

    #[test]
    fn test_beta_determinism() {
        let config = BetaConfig::new(2.0, 5.0);

        let mut sampler1 = BetaSampler::new(SEED, config.clone()).unwrap();
        let mut sampler2 = BetaSampler::new(SEED, config).unwrap();

        for _ in 0..100 {
            assert_eq!(sampler1.sample(), sampler2.sample());
        }
    }

    #[test]
    fn test_beta_reset_replays_sequence() {
        let mut sampler = BetaSampler::new(SEED, BetaConfig::new(2.0, 5.0)).unwrap();

        let first_run = sampler.sample_n(50);
        sampler.reset(SEED);
        let second_run = sampler.sample_n(50);

        assert_eq!(first_run, second_run);
    }

    #[test]
    fn test_beta_scaled_range() {
        let config = BetaConfig {
            alpha: 2.0,
            beta: 2.0,
            lower_bound: 0.02,
            upper_bound: 0.15,
            decimal_places: 4,
        };
        let mut sampler = BetaSampler::new(SEED, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert!(samples.iter().all(|&x| (0.02..=0.15).contains(&x)));
    }

    #[test]
    fn test_beta_expected_value() {
        // 2 / (2 + 5) is roughly 0.286 on the unit interval.
        let expected = BetaConfig::new(2.0, 5.0).expected_value();
        assert!((expected - 0.286).abs() < 0.01);
    }

    #[test]
    fn test_beta_variance() {
        // 2 * 5 / (7^2 * 8) = 10 / 392 on the unit interval.
        let variance = BetaConfig::new(2.0, 5.0).variance();
        assert!((variance - 10.0 / 392.0).abs() < 1e-9);

        // Stretching the interval by 100 scales the variance by 100^2.
        let scaled = BetaConfig::percentage(2.0, 5.0).variance();
        assert!((scaled - 10_000.0 * 10.0 / 392.0).abs() < 1e-6);
    }

    #[test]
    fn test_beta_mode() {
        // (2 - 1) / (2 + 5 - 2) = 0.2.
        let mode = BetaConfig::new(2.0, 5.0)
            .mode()
            .expect("alpha > 1 and beta > 1 must yield a mode");
        assert!((mode - 0.2).abs() < 0.001);

        // No mode is reported when a shape parameter is not above 1.
        assert!(BetaConfig::new(0.5, 5.0).mode().is_none());
    }

    #[test]
    fn test_beta_presets() {
        let presets = [
            BetaConfig::discount_rate(),
            BetaConfig::cash_discount(),
            BetaConfig::completion_rate(),
            BetaConfig::match_rate(),
            BetaConfig::quality_score(),
        ];

        for preset in presets.iter() {
            assert!(preset.validate().is_ok());
        }
    }

    #[test]
    fn test_beta_shape_detection() {
        assert_eq!(BetaConfig::uniform().shape(), BetaShape::Uniform);
        assert_eq!(BetaConfig::new(0.5, 0.5).shape(), BetaShape::UShaped);
        assert_eq!(BetaConfig::new(5.0, 5.0).shape(), BetaShape::Symmetric);
        assert_eq!(BetaConfig::new(8.0, 2.0).shape(), BetaShape::SkewedLeft);
        assert_eq!(BetaConfig::new(2.0, 8.0).shape(), BetaShape::SkewedRight);
        assert_eq!(BetaConfig::new(0.5, 2.0).shape(), BetaShape::JShapedLow);
        assert_eq!(BetaConfig::new(2.0, 0.5).shape(), BetaShape::JShapedHigh);
    }

    #[test]
    fn test_discount_rate_distribution() {
        let config = BetaConfig::discount_rate();
        let expected = config.expected_value();
        let mut sampler = BetaSampler::new(SEED, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert!(samples.iter().all(|&x| (0.02..=0.15).contains(&x)));

        // The empirical mean of 1000 draws should sit close to the analytic mean.
        let mean: f64 = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!((mean - expected).abs() < 0.01);
    }

    #[test]
    fn test_beta_percentage_sampling() {
        let mut sampler = BetaSampler::new(SEED, BetaConfig::percentage(2.0, 5.0)).unwrap();

        let samples = sampler.sample_n(1000);
        assert!(samples.iter().all(|&x| (0.0..=100.0).contains(&x)));
    }

    #[test]
    fn test_sample_percentage_scales_unit_interval() {
        // With unit-interval bounds the percentage helper lands in [0, 100].
        let mut sampler = BetaSampler::new(SEED, BetaConfig::new(2.0, 5.0)).unwrap();

        for _ in 0..200 {
            let pct = sampler.sample_percentage();
            assert!((0.0..=100.0).contains(&pct));
        }
    }
}