probability/distribution/
categorical.rs

1use alloc::{vec, vec::Vec};
2#[allow(unused_imports)]
3use special::Primitive;
4
5use distribution;
6use source::Source;
7
/// A categorical distribution over the categories `0, 1, ..., k - 1`.
#[derive(Debug, Clone)]
pub struct Categorical {
    /// Number of categories.
    k: usize,
    /// Probability of each category, indexed by category.
    p: Vec<f64>,
    /// Running sums of `p`; the constructor sets the last entry to exactly one.
    cumsum: Vec<f64>,
}
15
16impl Categorical {
17    /// Create a categorical distribution with success probability `p`.
18    ///
19    /// It should hold that `p[i] >= 0`, `p[i] <= 1`, and `sum(p) == 1`.
20    pub fn new(p: &[f64]) -> Self {
21        should!(is_probability_vector(p), {
22            const EPSILON: f64 = 1e-12;
23            p.iter().all(|&p| (0.0..=1.0).contains(&p))
24                && (p.iter().fold(0.0, |sum, &p| sum + p) - 1.0).abs() < EPSILON
25        });
26
27        let k = p.len();
28        let mut cumsum = p.to_vec();
29        for i in 1..(k - 1) {
30            cumsum[i] += cumsum[i - 1];
31        }
32        cumsum[k - 1] = 1.0;
33        Categorical {
34            k,
35            p: p.to_vec(),
36            cumsum,
37        }
38    }
39
40    /// Return the number of categories.
41    #[inline(always)]
42    pub fn k(&self) -> usize {
43        self.k
44    }
45
46    /// Return the event probabilities.
47    #[inline(always)]
48    pub fn p(&self) -> &[f64] {
49        &self.p
50    }
51}
52
53impl distribution::Discrete for Categorical {
54    #[inline]
55    fn mass(&self, x: usize) -> f64 {
56        should!(x < self.k);
57        self.p[x]
58    }
59}
60
61impl distribution::Distribution for Categorical {
62    type Value = usize;
63
64    fn distribution(&self, x: f64) -> f64 {
65        if x < 0.0 {
66            return 0.0;
67        }
68        let x = x as usize;
69        if x >= self.k {
70            return 1.0;
71        }
72        self.cumsum[x]
73    }
74}
75
76impl distribution::Entropy for Categorical {
77    fn entropy(&self) -> f64 {
78        -self.p.iter().fold(0.0, |sum, p| sum + p * p.ln())
79    }
80}
81
82impl distribution::Inverse for Categorical {
83    fn inverse(&self, p: f64) -> usize {
84        should!((0.0..=1.0).contains(&p));
85        self.cumsum
86            .iter()
87            .position(|&sum| sum > 0.0 && sum >= p)
88            .unwrap_or_else(|| self.p.iter().rposition(|&p| p > 0.0).unwrap())
89    }
90}
91
92impl distribution::Kurtosis for Categorical {
93    fn kurtosis(&self) -> f64 {
94        use distribution::{Mean, Variance};
95        let (mean, variance) = (self.mean(), self.variance());
96        let kurt = self
97            .p
98            .iter()
99            .enumerate()
100            .fold(0.0, |sum, (i, p)| sum + (i as f64 - mean).powi(4) * p);
101        kurt / variance.powi(2) - 3.0
102    }
103}
104
105impl distribution::Mean for Categorical {
106    fn mean(&self) -> f64 {
107        self.p
108            .iter()
109            .enumerate()
110            .fold(0.0, |sum, (i, p)| sum + i as f64 * p)
111    }
112}
113
114impl distribution::Median for Categorical {
115    fn median(&self) -> f64 {
116        if self.p[0] > 0.5 {
117            return 0.0;
118        }
119        if self.p[0] == 0.5 {
120            return 0.5;
121        }
122        for (i, &sum) in self.cumsum.iter().enumerate() {
123            if sum == 0.5 {
124                return (2 * i - 1) as f64 / 2.0;
125            } else if sum > 0.5 {
126                return i as f64;
127            }
128        }
129        unreachable!()
130    }
131}
132
133impl distribution::Modes for Categorical {
134    fn modes(&self) -> Vec<usize> {
135        let mut modes = Vec::new();
136        let mut max = 0.0;
137        for (i, &p) in self.p.iter().enumerate() {
138            if p == max {
139                modes.push(i);
140            }
141            if p > max {
142                max = p;
143                modes = vec![i];
144            }
145        }
146        modes
147    }
148}
149
150impl distribution::Sample for Categorical {
151    #[inline]
152    fn sample<S>(&self, source: &mut S) -> usize
153    where
154        S: Source,
155    {
156        use distribution::Inverse;
157        self.inverse(source.read::<f64>())
158    }
159}
160
161impl distribution::Skewness for Categorical {
162    fn skewness(&self) -> f64 {
163        use distribution::{Mean, Variance};
164        let (mean, variance) = (self.mean(), self.variance());
165        let skew = self
166            .p
167            .iter()
168            .enumerate()
169            .fold(0.0, |sum, (i, p)| sum + (i as f64 - mean).powi(3) * p);
170        skew / (variance * variance.sqrt())
171    }
172}
173
174impl distribution::Variance for Categorical {
175    fn variance(&self) -> f64 {
176        use distribution::Mean;
177        let mean = self.mean();
178        self.p
179            .iter()
180            .enumerate()
181            .fold(0.0, |sum, (i, p)| sum + (i as f64 - mean).powi(2) * p)
182    }
183}
184
185#[cfg(test)]
186mod tests {
187    use alloc::{vec, vec::Vec};
188    use prelude::*;
189
190    macro_rules! new(
191        (equal $k:expr) => { Categorical::new(&[1.0 / $k as f64; $k]) };
192        ($p:expr) => { Categorical::new(&$p) };
193    );
194
195    #[test]
196    fn distribution() {
197        let d = new!([0.0, 0.75, 0.25, 0.0]);
198        let p = vec![0.0, 0.0, 0.75, 1.0, 1.0];
199
200        let x = (-1..4)
201            .map(|x| d.distribution(x as f64))
202            .collect::<Vec<_>>();
203        assert_eq!(&x, &p);
204
205        let x = (-1..4)
206            .map(|x| d.distribution(x as f64 + 0.5))
207            .collect::<Vec<_>>();
208        assert_eq!(&x, &p);
209
210        let d = new!(equal 3);
211        let p = vec![0.0, 1.0 / 3.0, 2.0 / 3.0, 1.0];
212
213        let x = (-1..3)
214            .map(|x| d.distribution(x as f64))
215            .collect::<Vec<_>>();
216        assert_eq!(&x, &p);
217
218        let x = (-1..3)
219            .map(|x| d.distribution(x as f64 + 0.5))
220            .collect::<Vec<_>>();
221        assert_eq!(&x, &p);
222    }
223
224    #[test]
225    fn entropy() {
226        use core::f64::consts::LN_2;
227        assert_eq!(new!(equal 2).entropy(), LN_2);
228        assert_eq!(new!([0.1, 0.2, 0.3, 0.4]).entropy(), 1.2798542258336676);
229    }
230
231    #[test]
232    fn inverse() {
233        let d = new!([0.0, 0.75, 0.25, 0.0]);
234        let p = vec![0.0, 0.75, 0.7500001, 1.0];
235        assert_eq!(
236            &p.iter().map(|&p| d.inverse(p)).collect::<Vec<_>>(),
237            &vec![1, 1, 2, 2]
238        );
239
240        let d = new!(equal 3);
241        let p = vec![0.0, 0.5, 0.75, 1.0];
242        assert_eq!(
243            &p.iter().map(|&p| d.inverse(p)).collect::<Vec<_>>(),
244            &vec![0, 1, 2, 2]
245        );
246    }
247
248    #[test]
249    fn kurtosis() {
250        assert_eq!(new!(equal 2).kurtosis(), -2.0);
251        assert_eq!(new!([0.1, 0.2, 0.3, 0.4]).kurtosis(), -0.7999999999999998);
252    }
253
254    #[test]
255    fn mass() {
256        let p = [0.0, 0.75, 0.25, 0.0];
257        let d = new!(p);
258        assert_eq!(&(0..4).map(|x| d.mass(x)).collect::<Vec<_>>(), &p.to_vec());
259
260        let d = new!(equal 3);
261        assert_eq!(
262            &(0..3).map(|x| d.mass(x)).collect::<Vec<_>>(),
263            &vec![1.0 / 3.0; 3]
264        )
265    }
266
267    #[test]
268    fn mean() {
269        assert_eq!(new!(equal 3).mean(), 1.0);
270        assert_eq!(new!([0.3, 0.3, 0.4]).mean(), 1.1);
271        assert_eq!(
272            new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).mean(),
273            1.5
274        );
275    }
276
277    #[test]
278    fn median() {
279        assert_eq!(new!([0.6, 0.2, 0.2]).median(), 0.0);
280        assert_eq!(new!(equal 2).median(), 0.5);
281        assert_eq!(new!([0.1, 0.2, 0.3, 0.4]).median(), 2.0);
282        assert_eq!(
283            new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).median(),
284            0.5
285        );
286    }
287
288    #[test]
289    fn modes() {
290        assert_eq!(new!([0.6, 0.2, 0.2]).modes(), vec![0]);
291        assert_eq!(new!(equal 2).modes(), vec![0, 1]);
292        assert_eq!(new!(equal 3).modes(), vec![0, 1, 2]);
293        assert_eq!(new!([0.4, 0.2, 0.4]).modes(), vec![0, 2]);
294        assert_eq!(
295            new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).modes(),
296            vec![1, 2]
297        );
298    }
299
300    #[test]
301    fn sample() {
302        let mut source = source::default(42);
303
304        let sum = Independent(&new!([0.0, 0.5, 0.5]), &mut source)
305            .take(100)
306            .fold(0, |a, b| a + b);
307        assert!(100 <= sum && sum <= 200);
308
309        let p = (0..11)
310            .map(|i| if i % 2 != 0 { 0.2 } else { 0.0 })
311            .collect::<Vec<_>>();
312        assert!(Independent(&new!(p), &mut source)
313            .take(1000)
314            .all(|x| x % 2 != 0));
315    }
316
317    #[test]
318    fn skewness() {
319        assert_eq!(new!(equal 6).skewness(), 0.0);
320        assert_eq!(
321            new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).skewness(),
322            0.0
323        );
324        assert_eq!(new!([0.1, 0.2, 0.3, 0.4]).skewness(), -0.6);
325    }
326
327    #[test]
328    fn variance() {
329        assert_eq!(new!(equal 3).variance(), 2.0 / 3.0);
330        assert_eq!(
331            new!([1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]).variance(),
332            11.0 / 12.0
333        );
334    }
335}