use crate::distributions::*;
/// Discrete uniform distribution over the integers `lower, lower + 1, ..., upper`.
///
/// Both bounds belong to the support: `pmf` is non-zero exactly when
/// `lower <= x <= upper`. `new`, `set_lower` and `set_upper` panic rather than
/// allow `lower > upper`; equal bounds give a point mass.
#[derive(Debug, Clone, Copy)]
pub struct DiscreteUniform {
// Smallest value in the support (inclusive).
lower: i64,
// Largest value in the support (inclusive).
upper: i64,
}
impl DiscreteUniform {
pub fn new(lower: i64, upper: i64) -> Self {
if lower > upper {
panic!("`Upper` must be larger than `lower`.");
}
DiscreteUniform { lower, upper }
}
pub fn set_lower(&mut self, lower: i64) -> &mut Self {
if lower > self.upper {
panic!("Upper must be larger than lower.")
}
self.lower = lower;
self
}
pub fn set_upper(&mut self, upper: i64) -> &mut Self {
if self.lower > upper {
panic!("Upper must be larger than lower.")
}
self.upper = upper;
self
}
}
impl Default for DiscreteUniform {
fn default() -> Self {
Self::new(0, 1)
}
}
impl Distribution for DiscreteUniform {
    type Output = f64;

    /// Draws one integer from `alea::i64_in_range(lower, upper)` and returns it
    /// as an `f64`.
    ///
    /// NOTE(review): `pmf` treats `upper` as part of the support, which assumes
    /// `alea::i64_in_range` is inclusive of its upper argument — confirm against
    /// the `alea` documentation.
    fn sample(&self) -> f64 {
        let draw = alea::i64_in_range(self.lower, self.upper);
        draw as f64
    }
}
impl Distribution1D for DiscreteUniform {
    /// Replaces both bounds from `params = [lower, upper]`.
    ///
    /// The two new bounds are validated against each other before either field
    /// is written. Chaining `set_lower(..).set_upper(..)` instead would compare
    /// the new lower bound with the *old* upper bound, so shifting the support
    /// upward (e.g. from `[0, 1]` to `[5, 10]`) would panic even though the new
    /// bounds are valid.
    ///
    /// Each parameter is converted with `as i64`, which truncates toward zero
    /// (saturating at the `i64` limits; NaN becomes 0).
    ///
    /// # Panics
    ///
    /// Panics if `params` has fewer than two elements, or if the new lower bound
    /// is greater than the new upper bound.
    fn update(&mut self, params: &[f64]) {
        let (lower, upper) = (params[0] as i64, params[1] as i64);
        if lower > upper {
            panic!("Upper must be larger than lower.");
        }
        self.lower = lower;
        self.upper = upper;
    }
}
impl Discrete for DiscreteUniform {
    /// Probability mass at `x`: `1 / (upper - lower + 1)` when
    /// `lower <= x <= upper`, and `0` outside the support.
    fn pmf(&self, x: i64) -> f64 {
        if (self.lower..=self.upper).contains(&x) {
            // Count the support in i128: `upper - lower + 1` can exceed i64::MAX
            // (e.g. lower = i64::MIN, upper = i64::MAX) and would overflow in i64.
            let support = i128::from(self.upper) - i128::from(self.lower) + 1;
            1. / (support as f64)
        } else {
            0.
        }
    }
}
impl Mean for DiscreteUniform {
    type MeanType = f64;

    /// Returns the mean `(lower + upper) / 2`.
    ///
    /// The sum is formed in floating point so that half-integer means are kept
    /// (the default `[0, 1]` distribution has mean `0.5`, not `0`, which is what
    /// integer division would give) and so that `lower + upper` cannot overflow
    /// `i64`.
    fn mean(&self) -> f64 {
        (self.lower as f64 + self.upper as f64) / 2.
    }
}
impl Variance for DiscreteUniform {
    type VarianceType = f64;

    /// Returns the variance `(n^2 - 1) / 12`, where `n = upper - lower + 1` is
    /// the number of integers in the support.
    fn var(&self) -> f64 {
        // Count the support in i128: `upper - lower + 1` can exceed i64::MAX
        // (e.g. lower = i64::MIN, upper = i64::MAX) and would overflow in i64.
        let n = (i128::from(self.upper) - i128::from(self.lower) + 1) as f64;
        (n.powi(2) - 1.) / 12.
    }
}
#[test]
fn inrange() {
    // Every draw from the integers -2..=6 must lie in the closed interval [-2, 6].
    let dist = DiscreteUniform::new(-2, 6);
    for draw in dist.sample_n(100) {
        assert!((-2.0..=6.0).contains(&draw), "sample {} outside [-2, 6]", draw);
    }
}