probability/distribution/
categorical.rs1use alloc::{vec, vec::Vec};
2#[allow(unused_imports)]
3use special::Primitive;
4
5use distribution;
6use source::Source;
7
8#[derive(Clone, Debug)]
10pub struct Categorical {
11 k: usize,
12 p: Vec<f64>,
13 cumsum: Vec<f64>,
14}
15
16impl Categorical {
17 pub fn new(p: &[f64]) -> Self {
21 should!(is_probability_vector(p), {
22 const EPSILON: f64 = 1e-12;
23 p.iter().all(|&p| (0.0..=1.0).contains(&p))
24 && (p.iter().fold(0.0, |sum, &p| sum + p) - 1.0).abs() < EPSILON
25 });
26
27 let k = p.len();
28 let mut cumsum = p.to_vec();
29 for i in 1..(k - 1) {
30 cumsum[i] += cumsum[i - 1];
31 }
32 cumsum[k - 1] = 1.0;
33 Categorical {
34 k,
35 p: p.to_vec(),
36 cumsum,
37 }
38 }
39
40 #[inline(always)]
42 pub fn k(&self) -> usize {
43 self.k
44 }
45
46 #[inline(always)]
48 pub fn p(&self) -> &[f64] {
49 &self.p
50 }
51}
52
53impl distribution::Discrete for Categorical {
54 #[inline]
55 fn mass(&self, x: usize) -> f64 {
56 should!(x < self.k);
57 self.p[x]
58 }
59}
60
61impl distribution::Distribution for Categorical {
62 type Value = usize;
63
64 fn distribution(&self, x: f64) -> f64 {
65 if x < 0.0 {
66 return 0.0;
67 }
68 let x = x as usize;
69 if x >= self.k {
70 return 1.0;
71 }
72 self.cumsum[x]
73 }
74}
75
76impl distribution::Entropy for Categorical {
77 fn entropy(&self) -> f64 {
78 -self.p.iter().fold(0.0, |sum, p| sum + p * p.ln())
79 }
80}
81
82impl distribution::Inverse for Categorical {
83 fn inverse(&self, p: f64) -> usize {
84 should!((0.0..=1.0).contains(&p));
85 self.cumsum
86 .iter()
87 .position(|&sum| sum > 0.0 && sum >= p)
88 .unwrap_or_else(|| self.p.iter().rposition(|&p| p > 0.0).unwrap())
89 }
90}
91
92impl distribution::Kurtosis for Categorical {
93 fn kurtosis(&self) -> f64 {
94 use distribution::{Mean, Variance};
95 let (mean, variance) = (self.mean(), self.variance());
96 let kurt = self
97 .p
98 .iter()
99 .enumerate()
100 .fold(0.0, |sum, (i, p)| sum + (i as f64 - mean).powi(4) * p);
101 kurt / variance.powi(2) - 3.0
102 }
103}
104
105impl distribution::Mean for Categorical {
106 fn mean(&self) -> f64 {
107 self.p
108 .iter()
109 .enumerate()
110 .fold(0.0, |sum, (i, p)| sum + i as f64 * p)
111 }
112}
113
114impl distribution::Median for Categorical {
115 fn median(&self) -> f64 {
116 if self.p[0] > 0.5 {
117 return 0.0;
118 }
119 if self.p[0] == 0.5 {
120 return 0.5;
121 }
122 for (i, &sum) in self.cumsum.iter().enumerate() {
123 if sum == 0.5 {
124 return (2 * i - 1) as f64 / 2.0;
125 } else if sum > 0.5 {
126 return i as f64;
127 }
128 }
129 unreachable!()
130 }
131}
132
133impl distribution::Modes for Categorical {
134 fn modes(&self) -> Vec<usize> {
135 let mut modes = Vec::new();
136 let mut max = 0.0;
137 for (i, &p) in self.p.iter().enumerate() {
138 if p == max {
139 modes.push(i);
140 }
141 if p > max {
142 max = p;
143 modes = vec![i];
144 }
145 }
146 modes
147 }
148}
149
150impl distribution::Sample for Categorical {
151 #[inline]
152 fn sample<S>(&self, source: &mut S) -> usize
153 where
154 S: Source,
155 {
156 use distribution::Inverse;
157 self.inverse(source.read::<f64>())
158 }
159}
160
161impl distribution::Skewness for Categorical {
162 fn skewness(&self) -> f64 {
163 use distribution::{Mean, Variance};
164 let (mean, variance) = (self.mean(), self.variance());
165 let skew = self
166 .p
167 .iter()
168 .enumerate()
169 .fold(0.0, |sum, (i, p)| sum + (i as f64 - mean).powi(3) * p);
170 skew / (variance * variance.sqrt())
171 }
172}
173
174impl distribution::Variance for Categorical {
175 fn variance(&self) -> f64 {
176 use distribution::Mean;
177 let mean = self.mean();
178 self.p
179 .iter()
180 .enumerate()
181 .fold(0.0, |sum, (i, p)| sum + (i as f64 - mean).powi(2) * p)
182 }
183}
184
185#[cfg(test)]
186mod tests {
187 use alloc::{vec, vec::Vec};
188 use prelude::*;
189
190 macro_rules! new(
191 (equal $k:expr) => { Categorical::new(&[1.0 / $k as f64; $k]) };
192 ($p:expr) => { Categorical::new(&$p) };
193 );
194
195 #[test]
196 fn distribution() {
197 let d = new!([0.0, 0.75, 0.25, 0.0]);
198 let p = vec![0.0, 0.0, 0.75, 1.0, 1.0];
199
200 let x = (-1..4)
201 .map(|x| d.distribution(x as f64))
202 .collect::<Vec<_>>();
203 assert_eq!(&x, &p);
204
205 let x = (-1..4)
206 .map(|x| d.distribution(x as f64 + 0.5))
207 .collect::<Vec<_>>();
208 assert_eq!(&x, &p);
209
210 let d = new!(equal 3);
211 let p = vec![0.0, 1.0 / 3.0, 2.0 / 3.0, 1.0];
212
213 let x = (-1..3)
214 .map(|x| d.distribution(x as f64))
215 .collect::<Vec<_>>();
216 assert_eq!(&x, &p);
217
218 let x = (-1..3)
219 .map(|x| d.distribution(x as f64 + 0.5))
220 .collect::<Vec<_>>();
221 assert_eq!(&x, &p);
222 }
223
224 #[test]
225 fn entropy() {
226 use core::f64::consts::LN_2;
227 assert_eq!(new!(equal 2).entropy(), LN_2);
228 assert_eq!(new!([0.1, 0.2, 0.3, 0.4]).entropy(), 1.2798542258336676);
229 }
230
231 #[test]
232 fn inverse() {
233 let d = new!([0.0, 0.75, 0.25, 0.0]);
234 let p = vec![0.0, 0.75, 0.7500001, 1.0];
235 assert_eq!(
236 &p.iter().map(|&p| d.inverse(p)).collect::<Vec<_>>(),
237 &vec![1, 1, 2, 2]
238 );
239
240 let d = new!(equal 3);
241 let p = vec![0.0, 0.5, 0.75, 1.0];
242 assert_eq!(
243 &p.iter().map(|&p| d.inverse(p)).collect::<Vec<_>>(),
244 &vec![0, 1, 2, 2]
245 );
246 }
247
248 #[test]
249 fn kurtosis() {
250 assert_eq!(new!(equal 2).kurtosis(), -2.0);
251 assert_eq!(new!([0.1, 0.2, 0.3, 0.4]).kurtosis(), -0.7999999999999998);
252 }
253
254 #[test]
255 fn mass() {
256 let p = [0.0, 0.75, 0.25, 0.0];
257 let d = new!(p);
258 assert_eq!(&(0..4).map(|x| d.mass(x)).collect::<Vec<_>>(), &p.to_vec());
259
260 let d = new!(equal 3);
261 assert_eq!(
262 &(0..3).map(|x| d.mass(x)).collect::<Vec<_>>(),
263 &vec![1.0 / 3.0; 3]
264 )
265 }
266
267 #[test]
268 fn mean() {
269 assert_eq!(new!(equal 3).mean(), 1.0);
270 assert_eq!(new!([0.3, 0.3, 0.4]).mean(), 1.1);
271 assert_eq!(
272 new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).mean(),
273 1.5
274 );
275 }
276
277 #[test]
278 fn median() {
279 assert_eq!(new!([0.6, 0.2, 0.2]).median(), 0.0);
280 assert_eq!(new!(equal 2).median(), 0.5);
281 assert_eq!(new!([0.1, 0.2, 0.3, 0.4]).median(), 2.0);
282 assert_eq!(
283 new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).median(),
284 0.5
285 );
286 }
287
288 #[test]
289 fn modes() {
290 assert_eq!(new!([0.6, 0.2, 0.2]).modes(), vec![0]);
291 assert_eq!(new!(equal 2).modes(), vec![0, 1]);
292 assert_eq!(new!(equal 3).modes(), vec![0, 1, 2]);
293 assert_eq!(new!([0.4, 0.2, 0.4]).modes(), vec![0, 2]);
294 assert_eq!(
295 new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).modes(),
296 vec![1, 2]
297 );
298 }
299
300 #[test]
301 fn sample() {
302 let mut source = source::default(42);
303
304 let sum = Independent(&new!([0.0, 0.5, 0.5]), &mut source)
305 .take(100)
306 .fold(0, |a, b| a + b);
307 assert!(100 <= sum && sum <= 200);
308
309 let p = (0..11)
310 .map(|i| if i % 2 != 0 { 0.2 } else { 0.0 })
311 .collect::<Vec<_>>();
312 assert!(Independent(&new!(p), &mut source)
313 .take(1000)
314 .all(|x| x % 2 != 0));
315 }
316
317 #[test]
318 fn skewness() {
319 assert_eq!(new!(equal 6).skewness(), 0.0);
320 assert_eq!(
321 new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).skewness(),
322 0.0
323 );
324 assert_eq!(new!([0.1, 0.2, 0.3, 0.4]).skewness(), -0.6);
325 }
326
327 #[test]
328 fn variance() {
329 assert_eq!(new!(equal 3).variance(), 2.0 / 3.0);
330 assert_eq!(
331 new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).variance(),
332 11.0 / 12.0
333 );
334 }
335}