// compute/distributions/discreteuniform.rs

1use crate::distributions::*;
2
/// Implements the [discrete uniform distribution](https://en.wikipedia.org/wiki/Discrete_uniform_distribution).
///
/// Every integer in the closed interval `[lower, upper]` is equally likely. The constructor and
/// the setters maintain the invariant `lower <= upper`.
#[derive(Debug, Clone, Copy)]
pub struct DiscreteUniform {
    /// Lower bound (inclusive) for the discrete uniform distribution.
    lower: i64,
    /// Upper bound (inclusive) for the discrete uniform distribution.
    upper: i64,
}

impl DiscreteUniform {
    /// Create a new discrete uniform distribution with lower bound `lower` and upper bound
    /// `upper` (inclusive on both ends).
    ///
    /// # Panics
    /// Panics if `lower > upper`. `lower == upper` is accepted and describes a degenerate
    /// distribution concentrated on that single value.
    pub fn new(lower: i64, upper: i64) -> Self {
        if lower > upper {
            panic!("`Upper` must be larger than `lower`.");
        }
        DiscreteUniform { lower, upper }
    }

    /// Returns the (inclusive) lower bound.
    #[must_use]
    pub fn lower(&self) -> i64 {
        self.lower
    }

    /// Returns the (inclusive) upper bound.
    #[must_use]
    pub fn upper(&self) -> i64 {
        self.upper
    }

    /// Sets the (inclusive) lower bound, returning `self` so calls can be chained.
    ///
    /// # Panics
    /// Panics if `lower` is greater than the current upper bound. When moving both bounds,
    /// order the calls so the invariant holds after each one.
    pub fn set_lower(&mut self, lower: i64) -> &mut Self {
        if lower > self.upper {
            panic!("Upper must be larger than lower.")
        }
        self.lower = lower;
        self
    }

    /// Sets the (inclusive) upper bound, returning `self` so calls can be chained.
    ///
    /// # Panics
    /// Panics if `upper` is less than the current lower bound.
    pub fn set_upper(&mut self, upper: i64) -> &mut Self {
        if self.lower > upper {
            panic!("Upper must be larger than lower.")
        }
        self.upper = upper;
        self
    }
}
38
39impl Default for DiscreteUniform {
40    fn default() -> Self {
41        Self::new(0, 1)
42    }
43}
44
45impl Distribution for DiscreteUniform {
46    type Output = f64;
47    /// Samples from the given discrete uniform distribution.
48    fn sample(&self) -> f64 {
49        alea::i64_in_range(self.lower, self.upper) as f64
50    }
51}
52
53impl Distribution1D for DiscreteUniform {
54    fn update(&mut self, params: &[f64]) {
55        self.set_lower(params[0] as i64).set_upper(params[1] as i64);
56    }
57}
58
59impl Discrete for DiscreteUniform {
60    /// Calculates the [probability mass
61    /// function](https://en.wikipedia.org/wiki/Probability_mass_function) for the given discrete uniform
62    /// distribution at `x`.
63    ///
64    /// # Remarks
65    ///
66    /// Returns `0.` if `x` is not in `[lower, upper]`
67    fn pmf(&self, x: i64) -> f64 {
68        if x < self.lower || x > self.upper {
69            0.
70        } else {
71            1. / (self.upper - self.lower + 1) as f64
72        }
73    }
74}
75
76impl Mean for DiscreteUniform {
77    type MeanType = f64;
78    /// Calculates the mean, which for a Uniform(a, b) distribution is given by `(a + b) / 2`.
79    fn mean(&self) -> f64 {
80        ((self.lower + self.upper) / 2) as f64
81    }
82}
83
84impl Variance for DiscreteUniform {
85    type VarianceType = f64;
86    /// Calculates the variance of the given Uniform distribution.
87    fn var(&self) -> f64 {
88        (((self.upper - self.lower + 1) as f64).powi(2) - 1.) / 12.
89    }
90}
91
/// Every draw from `DiscreteUniform::new(-2, 6)` must lie within `[-2, 6]`.
#[test]
fn inrange() {
    let (lower, upper) = (-2_i64, 6_i64);
    let dist = DiscreteUniform::new(lower, upper);
    for x in dist.sample_n(100) {
        assert!(x >= lower as f64);
        assert!(x <= upper as f64);
    }
}