Skip to main content

amari_flynn/
distributions.rs

1//! Common probability distributions
2
3use crate::prob::Prob;
4use rand::Rng;
5use rand_distr::{Bernoulli as RandBernoulli, Distribution, Exp, Normal as RandNormal};
6
/// Uniform distribution over a range.
///
/// The bounds are stored as given; the `Uniform<i32>` sampler treats them as
/// the half-open range `[min, max)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Uniform<T> {
    /// Lower bound (inclusive in the `i32` sampler).
    min: T,
    /// Upper bound (exclusive in the `i32` sampler).
    max: T,
}
13
14impl Uniform<i32> {
15    /// Create uniform distribution over [min, max)
16    pub fn new(min: i32, max: i32) -> Self {
17        Self { min, max }
18    }
19
20    /// Sample from the distribution
21    pub fn sample(&self) -> Prob<i32> {
22        let mut rng = rand::thread_rng();
23        let value = rng.gen_range(self.min..self.max);
24        Prob::new(value)
25    }
26}
27
/// Bernoulli distribution (binary outcome).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Bernoulli {
    /// Probability of a `true` outcome; `Bernoulli::new` keeps it within `[0, 1]`.
    p: f64,
}
33
34impl Bernoulli {
35    /// Create Bernoulli distribution with probability p
36    pub fn new(p: f64) -> Self {
37        assert!((0.0..=1.0).contains(&p), "Probability must be in [0, 1]");
38        Self { p }
39    }
40
41    /// Sample from the distribution
42    pub fn sample(&self) -> Prob<bool> {
43        let mut rng = rand::thread_rng();
44        let dist = RandBernoulli::new(self.p).unwrap();
45        let value = dist.sample(&mut rng);
46        Prob::with_probability(self.p, value)
47    }
48}
49
/// Normal (Gaussian) distribution.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Normal {
    /// Mean of the distribution.
    mean: f64,
    /// Standard deviation; `Normal::new` rejects non-positive values.
    std_dev: f64,
}
56
57impl Normal {
58    /// Create normal distribution with given mean and standard deviation
59    pub fn new(mean: f64, std_dev: f64) -> Self {
60        assert!(std_dev > 0.0, "Standard deviation must be positive");
61        Self { mean, std_dev }
62    }
63
64    /// Sample from the distribution
65    pub fn sample(&self) -> Prob<f64> {
66        let mut rng = rand::thread_rng();
67        let dist = RandNormal::new(self.mean, self.std_dev).unwrap();
68        let value = dist.sample(&mut rng);
69        Prob::new(value)
70    }
71}
72
/// Exponential distribution.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Exponential {
    /// Rate parameter; `Exponential::new` requires it to be greater than zero.
    lambda: f64,
}
78
79impl Exponential {
80    /// Create exponential distribution with rate parameter lambda
81    pub fn new(lambda: f64) -> Self {
82        assert!(lambda > 0.0, "Lambda must be positive");
83        Self { lambda }
84    }
85
86    /// Sample from the distribution
87    pub fn sample(&self) -> Prob<f64> {
88        let mut rng = rand::thread_rng();
89        let dist = Exp::new(self.lambda).unwrap();
90        let value = dist.sample(&mut rng);
91        Prob::new(value)
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uniform_range() {
        let dist = Uniform::new(1, 7);
        for _ in 0..100 {
            let sample = dist.sample().into_inner();
            assert!((1..7).contains(&sample));
        }
    }

    #[test]
    fn test_uniform_single_value_range() {
        // `[3, 4)` contains exactly one integer, so every draw must be 3.
        let dist = Uniform::new(3, 4);
        for _ in 0..20 {
            assert_eq!(dist.sample().into_inner(), 3);
        }
    }

    #[test]
    fn test_bernoulli() {
        let dist = Bernoulli::new(0.5);
        let sample = dist.sample();
        assert_eq!(sample.probability(), 0.5);
    }

    #[test]
    fn test_bernoulli_degenerate_probabilities() {
        // p = 1 and p = 0 leave no randomness, so the outcome is fixed.
        for _ in 0..20 {
            assert!(Bernoulli::new(1.0).sample().into_inner());
            assert!(!Bernoulli::new(0.0).sample().into_inner());
        }
    }

    #[test]
    #[should_panic(expected = "Probability must be in [0, 1]")]
    fn test_bernoulli_rejects_out_of_range() {
        let _ = Bernoulli::new(1.5);
    }

    #[test]
    fn test_normal() {
        let dist = Normal::new(0.0, 1.0);
        let sample = dist.sample();
        // The drawn value is random; only the attached probability is fixed
        // (this pins that `Prob::new` tags the draw with probability 1.0).
        assert_eq!(sample.probability(), 1.0);
        assert!(sample.into_inner().is_finite());
    }

    #[test]
    #[should_panic(expected = "Standard deviation must be positive")]
    fn test_normal_rejects_non_positive_std_dev() {
        let _ = Normal::new(0.0, 0.0);
    }

    #[test]
    fn test_exponential_non_negative() {
        let dist = Exponential::new(2.0);
        for _ in 0..100 {
            assert!(dist.sample().into_inner() >= 0.0);
        }
    }

    #[test]
    #[should_panic(expected = "Lambda must be positive")]
    fn test_exponential_rejects_non_positive_lambda() {
        let _ = Exponential::new(0.0);
    }
}