compute/distributions/
discreteuniform.rs1use crate::distributions::*;
2
/// A discrete uniform distribution over the integers from `lower` to `upper`,
/// with both endpoints carrying probability mass.
///
/// The constructor and setters reject bounds with `lower > upper`.
#[derive(Clone, Copy, Debug)]
pub struct DiscreteUniform {
    /// Smallest value in the support.
    lower: i64,
    /// Largest value in the support.
    upper: i64,
}
11
12impl DiscreteUniform {
13 pub fn new(lower: i64, upper: i64) -> Self {
18 if lower > upper {
19 panic!("`Upper` must be larger than `lower`.");
20 }
21 DiscreteUniform { lower, upper }
22 }
23 pub fn set_lower(&mut self, lower: i64) -> &mut Self {
24 if lower > self.upper {
25 panic!("Upper must be larger than lower.")
26 }
27 self.lower = lower;
28 self
29 }
30 pub fn set_upper(&mut self, upper: i64) -> &mut Self {
31 if self.lower > upper {
32 panic!("Upper must be larger than lower.")
33 }
34 self.upper = upper;
35 self
36 }
37}
38
39impl Default for DiscreteUniform {
40 fn default() -> Self {
41 Self::new(0, 1)
42 }
43}
44
45impl Distribution for DiscreteUniform {
46 type Output = f64;
47 fn sample(&self) -> f64 {
49 alea::i64_in_range(self.lower, self.upper) as f64
50 }
51}
52
53impl Distribution1D for DiscreteUniform {
54 fn update(&mut self, params: &[f64]) {
55 self.set_lower(params[0] as i64).set_upper(params[1] as i64);
56 }
57}
58
59impl Discrete for DiscreteUniform {
60 fn pmf(&self, x: i64) -> f64 {
68 if x < self.lower || x > self.upper {
69 0.
70 } else {
71 1. / (self.upper - self.lower + 1) as f64
72 }
73 }
74}
75
76impl Mean for DiscreteUniform {
77 type MeanType = f64;
78 fn mean(&self) -> f64 {
80 ((self.lower + self.upper) / 2) as f64
81 }
82}
83
84impl Variance for DiscreteUniform {
85 type VarianceType = f64;
86 fn var(&self) -> f64 {
88 (((self.upper - self.lower + 1) as f64).powi(2) - 1.) / 12.
89 }
90}
91
/// Every draw from `DiscreteUniform::new(-2, 6)` must lie in `[-2, 6]`.
#[test]
fn inrange() {
    let dist = DiscreteUniform::new(-2, 6);
    for x in dist.sample_n(100) {
        assert!((-2.0..=6.0).contains(&x));
    }
}