use crate::{ProbabilityDistribution, StrError};
use rand::Rng;
use rand_distr::{Distribution, Uniform};
/// Defines the continuous Uniform distribution over the closed interval `[xmin, xmax]`
///
/// The density is constant inside the interval (both end points included) and zero outside it.
#[derive(Clone, Debug)]
pub struct DistributionUniform {
    /// Lower bound of the support
    xmin: f64,

    /// Upper bound of the support
    xmax: f64,

    /// Random-number sampler used by `sample`
    sampler: Uniform<f64>,
}
impl DistributionUniform {
    /// Allocates a new Uniform distribution on `[xmin, xmax]`
    ///
    /// # Input
    ///
    /// * `xmin` -- lower bound of the support
    /// * `xmax` -- upper bound of the support
    ///
    /// # Errors
    ///
    /// Returns `"invalid parameters"` when `xmin > xmax`, or when `Uniform::new`
    /// itself reports an error for the given bounds (every such error is mapped
    /// to the same message).
    pub fn new(xmin: f64, xmax: f64) -> Result<Self, StrError> {
        // reject reversed bounds before touching the sampler
        if xmin > xmax {
            return Err("invalid parameters");
        }
        let sampler = Uniform::new(xmin, xmax).map_err(|_| "invalid parameters")?;
        Ok(Self { xmin, xmax, sampler })
    }
}
impl ProbabilityDistribution for DistributionUniform {
    /// Evaluates the Probability Density Function (PDF)
    ///
    /// Returns `1 / (xmax - xmin)` when `xmin <= x <= xmax` (both end points included),
    /// `0.0` outside that interval, and `NaN` when `x` is `NaN`.
    fn pdf(&self, x: f64) -> f64 {
        if x.is_nan() {
            // Every comparison with NaN is false, so without this guard a NaN input would
            // fall through to the in-range branch and report a positive density.
            // Propagate NaN instead, matching what `cdf` yields for a NaN input.
            return f64::NAN;
        }
        if x < self.xmin || x > self.xmax {
            return 0.0;
        }
        1.0 / (self.xmax - self.xmin)
    }

    /// Evaluates the Cumulative Distribution Function (CDF)
    ///
    /// Returns `0.0` below `xmin`, `1.0` above `xmax`, and grows linearly in between
    /// (a `NaN` input yields `NaN`).
    fn cdf(&self, x: f64) -> f64 {
        if x < self.xmin {
            return 0.0;
        }
        if x > self.xmax {
            return 1.0;
        }
        (x - self.xmin) / (self.xmax - self.xmin)
    }

    /// Returns the Expected Value (mean): the midpoint `(xmin + xmax) / 2`
    fn mean(&self) -> f64 {
        (self.xmin + self.xmax) / 2.0
    }

    /// Returns the Variance: `(xmax - xmin)^2 / 12`
    fn variance(&self) -> f64 {
        (self.xmax - self.xmin) * (self.xmax - self.xmin) / 12.0
    }

    /// Draws one random value using the internal `Uniform` sampler
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        self.sampler.sample(rng)
    }
}
#[cfg(test)]
mod tests {
    use crate::{DistributionUniform, ProbabilityDistribution};
    use rand::prelude::StdRng;
    use rand::SeedableRng;
    use russell_lab::approx_eq;

    #[test]
    fn uniform_handles_errors() {
        // reversed bounds are rejected by the explicit check in `new`
        assert_eq!(DistributionUniform::new(2.0, 1.0).err(), Some("invalid parameters"));
        // a degenerate (empty) range is rejected by the sampler constructor
        assert_eq!(DistributionUniform::new(1.0, 1.0).err(), Some("invalid parameters"));
        // NaN bounds pass the `xmax < xmin` comparison but are rejected by the sampler constructor
        assert_eq!(DistributionUniform::new(f64::NAN, 1.0).err(), Some("invalid parameters"));
        assert_eq!(DistributionUniform::new(0.0, f64::NAN).err(), Some("invalid parameters"));
    }

    #[test]
    fn uniform_works() {
        // columns: x, xmin, xmax, expected pdf, expected cdf
        #[rustfmt::skip]
        let data = [
            [0.5, 1.5, 2.5, 0.0, 0.0],
            [1.0, 1.5, 2.5, 0.0, 0.0],
            [1.5, 1.5, 2.5, 1.0, 0.0],
            [2.0, 1.5, 2.5, 1.0, 0.5],
            [2.5, 1.5, 2.5, 1.0, 1.0],
            [3.0, 1.5, 2.5, 0.0, 1.0],
        ];
        for row in data {
            let [x, xmin, xmax, pdf, cdf] = row;
            let d = DistributionUniform::new(xmin, xmax).unwrap();
            approx_eq(d.pdf(x), pdf, 1e-14);
            approx_eq(d.cdf(x), cdf, 1e-14);
        }
    }

    #[test]
    fn mean_and_variance_work() {
        let d = DistributionUniform::new(1.0, 3.0).unwrap();
        approx_eq(d.mean(), 2.0, 1e-14);
        approx_eq(d.variance(), 1.0 / 3.0, 1e-14);
    }

    #[test]
    fn sample_works() {
        // reference values depend on the seed and on the draw order (x first, then y)
        let mut rng = StdRng::seed_from_u64(1234);
        let dist_x = DistributionUniform::new(0.0, 2.0).unwrap();
        let dist_y = DistributionUniform::new(0.0, 1.0).unwrap();
        let x = dist_x.sample(&mut rng);
        let y = dist_y.sample(&mut rng);
        approx_eq(x, 0.23691851694908816, 1e-15);
        approx_eq(y, 0.16964948689475423, 1e-15);
    }

    #[test]
    fn samples_stay_within_bounds() {
        let mut rng = StdRng::seed_from_u64(42);
        let d = DistributionUniform::new(-1.0, 3.0).unwrap();
        for _ in 0..1000 {
            let x = d.sample(&mut rng);
            assert!((-1.0..=3.0).contains(&x), "sample {} out of bounds", x);
        }
    }
}