Skip to main content

datasynth_core/distributions/
beta.rs

1//! Beta distribution for modeling proportions and percentages.
2//!
3//! The Beta distribution is ideal for:
4//! - Discount percentages
5//! - Completion rates
6//! - Proportion of revenue recognized
7//! - Match rates and quality scores
8
9use rand::prelude::*;
10use rand_chacha::ChaCha8Rng;
11use rand_distr::{Beta, Distribution};
12use rust_decimal::Decimal;
13use serde::{Deserialize, Serialize};
14
/// Configuration for a Beta distribution whose samples are rescaled from
/// `[0, 1]` onto `[lower_bound, upper_bound]`.
///
/// `alpha` and `beta` carry no serde default, so they must be present when a
/// configuration is deserialized; the remaining fields fall back to the
/// defaults noted on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BetaConfig {
    /// Alpha parameter (shape1) - controls skewness towards 1.
    /// Higher alpha = more mass towards the upper bound.
    pub alpha: f64,
    /// Beta parameter (shape2) - controls skewness towards 0.
    /// Higher beta = more mass towards the lower bound.
    pub beta: f64,
    /// Lower bound of the output range (default: 0.0, the `f64` default).
    #[serde(default)]
    pub lower_bound: f64,
    /// Upper bound of the output range (default: 1.0).
    #[serde(default = "default_upper_bound")]
    pub upper_bound: f64,
    /// Number of decimal places sampled values are rounded to (default: 4).
    #[serde(default = "default_decimal_places")]
    pub decimal_places: u8,
}
34
/// Serde fallback for `BetaConfig::upper_bound` when the field is omitted.
fn default_upper_bound() -> f64 {
    const UNIT_UPPER_BOUND: f64 = 1.0;
    UNIT_UPPER_BOUND
}
38
/// Serde fallback for `BetaConfig::decimal_places` when the field is omitted.
fn default_decimal_places() -> u8 {
    const ROUNDING_PLACES: u8 = 4;
    ROUNDING_PLACES
}
42
43impl Default for BetaConfig {
44    fn default() -> Self {
45        Self {
46            alpha: 2.0,
47            beta: 5.0,
48            lower_bound: 0.0,
49            upper_bound: 1.0,
50            decimal_places: 4,
51        }
52    }
53}
54
55impl BetaConfig {
56    /// Create a new Beta configuration.
57    pub fn new(alpha: f64, beta: f64) -> Self {
58        Self {
59            alpha,
60            beta,
61            ..Default::default()
62        }
63    }
64
65    /// Create a configuration scaled to a percentage range.
66    pub fn percentage(alpha: f64, beta: f64) -> Self {
67        Self {
68            alpha,
69            beta,
70            lower_bound: 0.0,
71            upper_bound: 100.0,
72            decimal_places: 2,
73        }
74    }
75
76    /// Create a configuration for discount percentages (typically 2-15%).
77    pub fn discount_rate() -> Self {
78        Self {
79            alpha: 2.0, // Skewed towards lower discounts
80            beta: 8.0,
81            lower_bound: 0.02, // 2% minimum
82            upper_bound: 0.15, // 15% maximum
83            decimal_places: 4,
84        }
85    }
86
87    /// Create a configuration for cash discount rates (1-3%).
88    pub fn cash_discount() -> Self {
89        Self {
90            alpha: 3.0,
91            beta: 3.0,         // Symmetric around 2%
92            lower_bound: 0.01, // 1% minimum
93            upper_bound: 0.03, // 3% maximum
94            decimal_places: 4,
95        }
96    }
97
98    /// Create a configuration for completion rates (biased towards high).
99    pub fn completion_rate() -> Self {
100        Self {
101            alpha: 8.0, // Strongly biased towards 1
102            beta: 2.0,
103            lower_bound: 0.0,
104            upper_bound: 1.0,
105            decimal_places: 4,
106        }
107    }
108
109    /// Create a configuration for match rates (typically 85-99%).
110    pub fn match_rate() -> Self {
111        Self {
112            alpha: 10.0,
113            beta: 1.5,
114            lower_bound: 0.85,
115            upper_bound: 0.99,
116            decimal_places: 4,
117        }
118    }
119
120    /// Create a configuration for quality scores (0-100, slightly skewed high).
121    pub fn quality_score() -> Self {
122        Self {
123            alpha: 5.0,
124            beta: 2.0,
125            lower_bound: 0.0,
126            upper_bound: 100.0,
127            decimal_places: 1,
128        }
129    }
130
131    /// Create a uniform distribution on [0, 1].
132    pub fn uniform() -> Self {
133        Self {
134            alpha: 1.0,
135            beta: 1.0,
136            ..Default::default()
137        }
138    }
139
140    /// Validate the configuration.
141    pub fn validate(&self) -> Result<(), String> {
142        if self.alpha <= 0.0 {
143            return Err("alpha must be positive".to_string());
144        }
145        if self.beta <= 0.0 {
146            return Err("beta must be positive".to_string());
147        }
148        if self.upper_bound <= self.lower_bound {
149            return Err("upper_bound must be greater than lower_bound".to_string());
150        }
151        Ok(())
152    }
153
154    /// Get the expected value (mean) of the distribution.
155    pub fn expected_value(&self) -> f64 {
156        let raw_mean = self.alpha / (self.alpha + self.beta);
157        self.lower_bound + raw_mean * (self.upper_bound - self.lower_bound)
158    }
159
160    /// Get the mode of the distribution.
161    /// Only defined for alpha > 1 and beta > 1.
162    pub fn mode(&self) -> Option<f64> {
163        if self.alpha > 1.0 && self.beta > 1.0 {
164            let raw_mode = (self.alpha - 1.0) / (self.alpha + self.beta - 2.0);
165            Some(self.lower_bound + raw_mode * (self.upper_bound - self.lower_bound))
166        } else {
167            None
168        }
169    }
170
171    /// Get the variance of the distribution.
172    pub fn variance(&self) -> f64 {
173        let ab = self.alpha + self.beta;
174        let raw_variance = (self.alpha * self.beta) / (ab.powi(2) * (ab + 1.0));
175        raw_variance * (self.upper_bound - self.lower_bound).powi(2)
176    }
177}
178
/// Beta distribution sampler.
///
/// Owns a seeded `ChaCha8Rng`, so two samplers built from the same seed and
/// configuration produce the same sequence of values.
pub struct BetaSampler {
    /// Random number generator, seeded in `new` and re-seeded by `reset`.
    rng: ChaCha8Rng,
    /// Configuration the sampler was built from (validated in `new`).
    config: BetaConfig,
    /// Unscaled Beta(alpha, beta) distribution the raw draws come from.
    distribution: Beta<f64>,
    /// `10^decimal_places`, precomputed for rounding sampled values.
    decimal_multiplier: f64,
    /// `upper_bound - lower_bound`, precomputed for scaling sampled values.
    range: f64,
}
187
188impl BetaSampler {
189    /// Create a new Beta sampler.
190    pub fn new(seed: u64, config: BetaConfig) -> Result<Self, String> {
191        config.validate()?;
192
193        let distribution = Beta::new(config.alpha, config.beta)
194            .map_err(|e| format!("Invalid Beta distribution: {}", e))?;
195
196        let decimal_multiplier = 10_f64.powi(config.decimal_places as i32);
197        let range = config.upper_bound - config.lower_bound;
198
199        Ok(Self {
200            rng: ChaCha8Rng::seed_from_u64(seed),
201            config,
202            distribution,
203            decimal_multiplier,
204            range,
205        })
206    }
207
208    /// Sample a value from the distribution.
209    pub fn sample(&mut self) -> f64 {
210        let raw_value = self.distribution.sample(&mut self.rng);
211        let scaled_value = self.config.lower_bound + raw_value * self.range;
212
213        // Round to decimal places
214        (scaled_value * self.decimal_multiplier).round() / self.decimal_multiplier
215    }
216
217    /// Sample a value as Decimal.
218    pub fn sample_decimal(&mut self) -> Decimal {
219        let value = self.sample();
220        Decimal::from_f64_retain(value).unwrap_or(Decimal::ZERO)
221    }
222
223    /// Sample a value as a percentage (multiplied by 100).
224    pub fn sample_percentage(&mut self) -> f64 {
225        let raw_value = self.distribution.sample(&mut self.rng);
226        let scaled_value = self.config.lower_bound + raw_value * self.range;
227        (scaled_value * 100.0 * self.decimal_multiplier).round() / self.decimal_multiplier
228    }
229
230    /// Sample multiple values.
231    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
232        (0..n).map(|_| self.sample()).collect()
233    }
234
235    /// Reset the sampler with a new seed.
236    pub fn reset(&mut self, seed: u64) {
237        self.rng = ChaCha8Rng::seed_from_u64(seed);
238    }
239
240    /// Get the configuration.
241    pub fn config(&self) -> &BetaConfig {
242        &self.config
243    }
244}
245
/// Qualitative shape of a Beta distribution, classified from its `alpha` and
/// `beta` parameters by `BetaConfig::shape`.
///
/// The conditions listed on each variant are checked in a fixed order by
/// `BetaConfig::shape`, and equality is tested with a tolerance of 0.001.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BetaShape {
    /// Uniform distribution (alpha = beta = 1, within tolerance)
    Uniform,
    /// U-shaped (alpha < 1 and beta < 1)
    UShaped,
    /// Unimodal symmetric (alpha = beta > 1, within tolerance)
    Symmetric,
    /// Unimodal skewed left (alpha > beta, both at least 1)
    SkewedLeft,
    /// Unimodal skewed right (alpha < beta); also the fallback when no other
    /// variant matches
    SkewedRight,
    /// J-shaped towards 1 (alpha >= 1, beta < 1)
    JShapedHigh,
    /// J-shaped towards 0 (alpha < 1, beta >= 1)
    JShapedLow,
}
264
265impl BetaConfig {
266    /// Determine the shape of this distribution.
267    pub fn shape(&self) -> BetaShape {
268        match (self.alpha, self.beta) {
269            (a, b) if (a - 1.0).abs() < 0.001 && (b - 1.0).abs() < 0.001 => BetaShape::Uniform,
270            (a, b) if a < 1.0 && b < 1.0 => BetaShape::UShaped,
271            (a, b) if (a - b).abs() < 0.001 && a > 1.0 => BetaShape::Symmetric,
272            (a, b) if a < 1.0 && b >= 1.0 => BetaShape::JShapedLow,
273            (a, b) if a >= 1.0 && b < 1.0 => BetaShape::JShapedHigh,
274            (a, b) if a > b => BetaShape::SkewedLeft,
275            _ => BetaShape::SkewedRight,
276        }
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_beta_validation() {
        let config = BetaConfig::new(2.0, 5.0);
        assert!(config.validate().is_ok());

        let invalid_alpha = BetaConfig::new(-1.0, 5.0);
        assert!(invalid_alpha.validate().is_err());

        let invalid_beta = BetaConfig::new(2.0, 0.0);
        assert!(invalid_beta.validate().is_err());
    }

    #[test]
    fn test_beta_bounds_validation() {
        // The range must have strictly positive width.
        let equal_bounds = BetaConfig {
            lower_bound: 0.5,
            upper_bound: 0.5,
            ..BetaConfig::default()
        };
        assert!(equal_bounds.validate().is_err());

        let inverted_bounds = BetaConfig {
            lower_bound: 1.0,
            upper_bound: 0.0,
            ..BetaConfig::default()
        };
        assert!(inverted_bounds.validate().is_err());
    }

    #[test]
    fn test_sampler_rejects_invalid_config() {
        assert!(BetaSampler::new(42, BetaConfig::new(-1.0, 5.0)).is_err());

        let inverted_bounds = BetaConfig {
            lower_bound: 1.0,
            upper_bound: 0.0,
            ..BetaConfig::default()
        };
        assert!(BetaSampler::new(42, inverted_bounds).is_err());
    }

    #[test]
    fn test_beta_sampling() {
        let config = BetaConfig::new(2.0, 5.0);
        let mut sampler = BetaSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // All samples should be in [0, 1]
        assert!(samples.iter().all(|&x| (0.0..=1.0).contains(&x)));
    }

    #[test]
    fn test_beta_determinism() {
        let config = BetaConfig::new(2.0, 5.0);

        let mut sampler1 = BetaSampler::new(42, config.clone()).unwrap();
        let mut sampler2 = BetaSampler::new(42, config).unwrap();

        for _ in 0..100 {
            assert_eq!(sampler1.sample(), sampler2.sample());
        }
    }

    #[test]
    fn test_beta_reset_replays_sequence() {
        let mut sampler = BetaSampler::new(42, BetaConfig::new(2.0, 5.0)).unwrap();

        let first_run = sampler.sample_n(50);
        sampler.reset(42);
        let second_run = sampler.sample_n(50);

        // Re-seeding with the original seed must reproduce the same draws.
        assert_eq!(first_run, second_run);
    }

    #[test]
    fn test_beta_scaled_range() {
        let config = BetaConfig {
            alpha: 2.0,
            beta: 2.0,
            lower_bound: 0.02,
            upper_bound: 0.15,
            decimal_places: 4,
        };
        let mut sampler = BetaSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert!(samples.iter().all(|&x| (0.02..=0.15).contains(&x)));
    }

    #[test]
    fn test_beta_expected_value() {
        let config = BetaConfig::new(2.0, 5.0);
        // E[X] = alpha / (alpha + beta) = 2/7 ≈ 0.286
        let expected = config.expected_value();
        assert!((expected - 0.286).abs() < 0.01);
    }

    #[test]
    fn test_beta_mode() {
        let config = BetaConfig::new(2.0, 5.0);
        // Mode = (alpha - 1) / (alpha + beta - 2) = 1/5 = 0.2
        let mode = config.mode();
        assert!(mode.is_some());
        assert!((mode.unwrap() - 0.2).abs() < 0.001);

        // No mode for alpha <= 1
        let no_mode_config = BetaConfig::new(0.5, 5.0);
        assert!(no_mode_config.mode().is_none());
    }

    #[test]
    fn test_beta_variance() {
        // Var[X] = alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1)) = 10/392
        let unit = BetaConfig::new(2.0, 5.0);
        assert!((unit.variance() - 10.0 / 392.0).abs() < 1e-9);

        // Scaling the range by 100 scales the variance by 100^2.
        let scaled = BetaConfig::percentage(2.0, 5.0);
        assert!((scaled.variance() - 10_000.0 * 10.0 / 392.0).abs() < 1e-6);
    }

    #[test]
    fn test_beta_presets() {
        let discount = BetaConfig::discount_rate();
        assert!(discount.validate().is_ok());

        let cash = BetaConfig::cash_discount();
        assert!(cash.validate().is_ok());

        let completion = BetaConfig::completion_rate();
        assert!(completion.validate().is_ok());

        let match_rate = BetaConfig::match_rate();
        assert!(match_rate.validate().is_ok());

        let quality = BetaConfig::quality_score();
        assert!(quality.validate().is_ok());
    }

    #[test]
    fn test_beta_shape_detection() {
        assert_eq!(BetaConfig::uniform().shape(), BetaShape::Uniform);
        assert_eq!(BetaConfig::new(0.5, 0.5).shape(), BetaShape::UShaped);
        assert_eq!(BetaConfig::new(5.0, 5.0).shape(), BetaShape::Symmetric);
        assert_eq!(BetaConfig::new(8.0, 2.0).shape(), BetaShape::SkewedLeft);
        assert_eq!(BetaConfig::new(2.0, 8.0).shape(), BetaShape::SkewedRight);
    }

    #[test]
    fn test_beta_j_shape_detection() {
        assert_eq!(BetaConfig::new(0.5, 2.0).shape(), BetaShape::JShapedLow);
        assert_eq!(BetaConfig::new(2.0, 0.5).shape(), BetaShape::JShapedHigh);
    }

    #[test]
    fn test_discount_rate_distribution() {
        let config = BetaConfig::discount_rate();
        let mut sampler = BetaSampler::new(42, config.clone()).unwrap();

        let samples = sampler.sample_n(1000);

        // All samples should be in [2%, 15%]
        assert!(samples.iter().all(|&x| (0.02..=0.15).contains(&x)));

        // Mean should be around the expected value
        let mean: f64 = samples.iter().sum::<f64>() / samples.len() as f64;
        let expected = config.expected_value();
        assert!((mean - expected).abs() < 0.01);
    }

    #[test]
    fn test_beta_percentage_sampling() {
        let config = BetaConfig::percentage(2.0, 5.0);
        let mut sampler = BetaSampler::new(42, config).unwrap();

        let samples = sampler.sample_n(1000);
        assert!(samples.iter().all(|&x| (0.0..=100.0).contains(&x)));
    }

    #[test]
    fn test_sample_percentage_scales_unit_range() {
        // A [0, 1] configuration sampled as a percentage lands in [0, 100].
        let mut sampler = BetaSampler::new(42, BetaConfig::new(2.0, 5.0)).unwrap();

        for _ in 0..1000 {
            let pct = sampler.sample_percentage();
            assert!((0.0..=100.0).contains(&pct));
        }
    }

    #[test]
    fn test_sample_decimal_within_bounds() {
        let mut sampler = BetaSampler::new(42, BetaConfig::new(2.0, 5.0)).unwrap();
        let one = Decimal::new(1, 0);

        for _ in 0..1000 {
            let value = sampler.sample_decimal();
            assert!(value >= Decimal::ZERO && value <= one);
        }
    }
}