amari_flynn/
distributions.rs1use crate::prob::Prob;
4use rand::Rng;
5use rand_distr::{Bernoulli as RandBernoulli, Distribution, Exp, Normal as RandNormal};
6
7#[derive(Clone, Copy, Debug)]
9pub struct Uniform<T> {
10 min: T,
11 max: T,
12}
13
14impl Uniform<i32> {
15 pub fn new(min: i32, max: i32) -> Self {
17 Self { min, max }
18 }
19
20 pub fn sample(&self) -> Prob<i32> {
22 let mut rng = rand::thread_rng();
23 let value = rng.gen_range(self.min..self.max);
24 Prob::new(value)
25 }
26}
27
28#[derive(Clone, Copy, Debug)]
30pub struct Bernoulli {
31 p: f64,
32}
33
34impl Bernoulli {
35 pub fn new(p: f64) -> Self {
37 assert!((0.0..=1.0).contains(&p), "Probability must be in [0, 1]");
38 Self { p }
39 }
40
41 pub fn sample(&self) -> Prob<bool> {
43 let mut rng = rand::thread_rng();
44 let dist = RandBernoulli::new(self.p).unwrap();
45 let value = dist.sample(&mut rng);
46 Prob::with_probability(self.p, value)
47 }
48}
49
50#[derive(Clone, Copy, Debug)]
52pub struct Normal {
53 mean: f64,
54 std_dev: f64,
55}
56
57impl Normal {
58 pub fn new(mean: f64, std_dev: f64) -> Self {
60 assert!(std_dev > 0.0, "Standard deviation must be positive");
61 Self { mean, std_dev }
62 }
63
64 pub fn sample(&self) -> Prob<f64> {
66 let mut rng = rand::thread_rng();
67 let dist = RandNormal::new(self.mean, self.std_dev).unwrap();
68 let value = dist.sample(&mut rng);
69 Prob::new(value)
70 }
71}
72
73#[derive(Clone, Copy, Debug)]
75pub struct Exponential {
76 lambda: f64,
77}
78
79impl Exponential {
80 pub fn new(lambda: f64) -> Self {
82 assert!(lambda > 0.0, "Lambda must be positive");
83 Self { lambda }
84 }
85
86 pub fn sample(&self) -> Prob<f64> {
88 let mut rng = rand::thread_rng();
89 let dist = Exp::new(self.lambda).unwrap();
90 let value = dist.sample(&mut rng);
91 Prob::new(value)
92 }
93}
94
#[cfg(test)]
mod tests {
    //! Sampling and constructor-precondition tests for the distributions above.

    use super::*;

    #[test]
    fn test_uniform_range() {
        let dist = Uniform::new(1, 7);
        for _ in 0..100 {
            let sample = dist.sample().into_inner();
            assert!((1..7).contains(&sample));
        }
    }

    /// An empty range (`min == max`) has nothing to sample, so using it must
    /// panic rather than return a value. No `expected` message is pinned, so
    /// the test holds whether the panic comes from `new` or from `sample`.
    #[test]
    #[should_panic]
    fn test_uniform_empty_range_panics() {
        let dist = Uniform::new(3, 3);
        let _ = dist.sample();
    }

    #[test]
    fn test_bernoulli() {
        let dist = Bernoulli::new(0.5);
        let sample = dist.sample();
        assert_eq!(sample.probability(), 0.5);
    }

    /// With `p = 0` or `p = 1` there is no randomness left, so the outcome is
    /// fixed on every draw.
    #[test]
    fn test_bernoulli_degenerate_outcomes() {
        let never = Bernoulli::new(0.0);
        let always = Bernoulli::new(1.0);
        for _ in 0..20 {
            assert!(!never.sample().into_inner());
            assert!(always.sample().into_inner());
        }
    }

    #[test]
    #[should_panic(expected = "Probability must be in [0, 1]")]
    fn test_bernoulli_rejects_out_of_range_p() {
        let _ = Bernoulli::new(1.5);
    }

    #[test]
    fn test_normal() {
        let dist = Normal::new(0.0, 1.0);
        let sample = dist.sample();
        assert_eq!(sample.probability(), 1.0);
    }

    #[test]
    #[should_panic(expected = "Standard deviation must be positive")]
    fn test_normal_rejects_non_positive_std_dev() {
        let _ = Normal::new(0.0, 0.0);
    }

    #[test]
    fn test_exponential_is_non_negative() {
        let dist = Exponential::new(2.0);
        for _ in 0..100 {
            assert!(dist.sample().into_inner() >= 0.0);
        }
    }

    #[test]
    #[should_panic(expected = "Lambda must be positive")]
    fn test_exponential_rejects_non_positive_lambda() {
        let _ = Exponential::new(0.0);
    }
}