1use rand::distributions::Distribution;
3use rand::Rng;
4
/// `exp(-1/2)`, the success probability of a single `bernoulli` trial.
///
/// Written as the `f64` value closest to the true constant
/// (0.60653065971263342...); a shorter literal would be off by several ulps.
const EXP_MINUS_HALF: f64 = 0.606_530_659_712_633_4;
6
/// Sampler for a normal distribution over the integers, parameterised by a
/// mean and a standard deviation.
///
/// Values are drawn through the `Distribution<i32>` implementation.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DiscreteNormal {
    /// Mean of the distribution.
    mu: f64,
    /// Standard deviation of the distribution.
    sigma: f64,
}
23
24impl DiscreteNormal {
25 pub fn new(mu: f64, sigma: f64) -> Self {
26 Self { mu, sigma }
27 }
28
29 #[inline]
30 fn bernoulli(&self) -> bool {
31 rand::thread_rng().gen::<f64>() < EXP_MINUS_HALF
32 }
33
34 #[inline]
35 fn g(&self) -> u32 {
36 let mut res: u32 = 0;
37 while self.bernoulli() {
38 res += 1;
39 }
40 res
41 }
42
43 #[inline]
44 fn s(&self) -> u32 {
45 loop {
46 let k = self.g() as i32;
47 if k < 2 {
48 return k as u32;
49 }
50 let mut z = k * (k - 1) - 1;
51 while z != 0 && self.bernoulli() {
52 z -= 1;
53 }
54 if z < 0 {
55 return k as u32;
56 };
57 }
58 }
59}
60
61impl Distribution<i32> for DiscreteNormal {
62 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> i32 {
63 loop {
64 let k = self.s();
65
66 let s = if rng.gen::<bool>() { 1 } else { -1 };
67
68 let mut xn0 = (k as f64) * self.sigma + (s as f64) * self.mu;
69 let i0 = xn0.ceil();
70 xn0 = (i0 - xn0) / self.sigma;
71 let j = rng.gen::<u32>() % (self.sigma.ceil() as u32);
72
73 let x = xn0 + (j as f64) / self.sigma;
74 if x < 1.0 && !(x == 0.0 && s < 0 && k == 0) {
75 xn0 = (-x * ((k << 1) as f64 + x) / 2.0).exp();
76 if x == 0.0 || rng.gen::<f64>() <= xn0 {
77 return s * (i0 as u32 + j) as i32;
78 }
79 }
80 }
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use libm::erf;
    use std::collections::HashMap;

    /// Cumulative distribution function of the continuous normal
    /// distribution with mean `mu` and standard deviation `sigma`, at `x`.
    fn normcdf(x: f64, mu: f64, sigma: f64) -> f64 {
        let z = (x - mu) / (sigma * (2.0_f64).sqrt());
        (0.5) * (1.0 + erf(z))
    }

    /// Draws a batch of samples, compares the observed counts against bin
    /// probabilities derived from `normcdf`, and prints the resulting
    /// chi-square statistic. Nothing is asserted; the value is only reported.
    #[test]
    fn test_chisq() {
        let n: usize = 1000;
        let mu = 0.0;
        let sigma = 3.0;

        let dn = DiscreteNormal::new(mu, sigma);

        // Tally how often each integer value was drawn.
        let mut hist = HashMap::new();
        for _ in 0..n {
            let drawn = dn.sample(&mut rand::thread_rng());
            *hist.entry(drawn).or_insert(0) += 1;
        }

        // Distinct observed values in ascending order.
        let mut data: Vec<i32> = hist.keys().copied().collect();
        data.sort_unstable();

        let probs: Vec<f64> = data
            .iter()
            .map(|&value| normcdf(f64::from(value), mu, sigma))
            .collect();

        // First bin: CDF at the smallest observed value; every later bin:
        // CDF difference between neighbouring observed values.
        let expected: Vec<f64> = probs
            .first()
            .copied()
            .into_iter()
            .chain(probs.windows(2).map(|pair| pair[1] - pair[0]))
            .collect();

        let mut chisq = 0.0;
        for (value, prob) in data.iter().zip(&expected) {
            let o = hist[value];
            let e = *prob * n as f64;
            chisq += (o as f64 - e).powf(2.0) / e;
        }

        dbg!(chisq);
    }
}