use crate::traits::{
Cdf, DiscreteDistr, Entropy, HasDensity, InverseCdf, Kurtosis, Mean,
Median, Parameterized, Sampleable, Skewness, Support, Variance,
};
use num::{FromPrimitive, Integer, ToPrimitive};
use rand::Rng;
use rand_distr::uniform::SampleUniform;
use std::f64;
use std::fmt;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Marker trait for the integer types usable as `DiscreteUniform` bounds.
///
/// It has no methods of its own; the blanket impl below provides it for every
/// type that is both `Integer` and `Copy`.
pub trait DuParam: Integer + Copy {}

impl<T: Integer + Copy> DuParam for T {}
/// Discrete uniform distribution described by two integer bounds, `a` and `b`.
///
/// The checked constructor only accepts `a < b`. The density, support and
/// sampling implementations in this file all treat both bounds as inclusive.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct DiscreteUniform<T: DuParam> {
/// Lower bound (inclusive).
a: T,
/// Upper bound (inclusive).
b: T,
}
/// Errors returned when constructing a `DiscreteUniform` with invalid bounds.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum DiscreteUniformError {
/// The lower bound `a` is greater than or equal to the upper bound `b`.
InvalidInterval,
}
/// Plain parameter record for a `DiscreteUniform`: its two bounds.
///
/// Derives the same basic traits as the other public types in this module so
/// that parameter records can be printed, cloned and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiscreteUniformParameters<T: DuParam> {
    /// Lower bound.
    pub a: T,
    /// Upper bound.
    pub b: T,
}
impl<T: DuParam> Parameterized for DiscreteUniform<T> {
type Parameters = DiscreteUniformParameters<T>;
fn emit_params(&self) -> Self::Parameters {
Self::Parameters {
a: self.a(),
b: self.b(),
}
}
fn from_params(params: Self::Parameters) -> Self {
Self::new_unchecked(params.a, params.b)
}
}
impl<T: DuParam> DiscreteUniform<T> {
    /// Creates a distribution over the bounds `a` and `b`.
    ///
    /// # Errors
    ///
    /// Returns `DiscreteUniformError::InvalidInterval` unless `a < b`.
    #[inline]
    pub fn new(a: T, b: T) -> Result<Self, DiscreteUniformError> {
        if a >= b {
            return Err(DiscreteUniformError::InvalidInterval);
        }
        Ok(Self { a, b })
    }

    /// Creates a distribution without checking that `a < b`.
    #[inline]
    pub fn new_unchecked(a: T, b: T) -> Self {
        Self { a, b }
    }

    /// Returns the lower bound.
    #[inline]
    pub fn a(&self) -> T {
        let Self { a, .. } = *self;
        a
    }

    /// Returns the upper bound.
    #[inline]
    pub fn b(&self) -> T {
        let Self { b, .. } = *self;
        b
    }
}
impl<T> From<&DiscreteUniform<T>> for String
where
T: DuParam + fmt::Display,
{
fn from(u: &DiscreteUniform<T>) -> String {
format!("DiscreteUniform({}, {})", u.a, u.b)
}
}
/// Displays the distribution using the same text as its `String` conversion.
#[cfg_attr(coverage_nightly, coverage(off))]
impl<T> fmt::Display for DiscreteUniform<T>
where
    T: DuParam + fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Delegate to the `String` conversion so the text is defined in one place.
        let rendered = String::from(self);
        f.write_str(&rendered)
    }
}
impl<X, T> HasDensity<X> for DiscreteUniform<T>
where
    T: DuParam + SampleUniform + Copy,
    X: Integer + From<T>,
{
    /// Returns `0.0` when `a <= x <= b` and negative infinity otherwise.
    ///
    /// NOTE(review): a log-value of `0.0` corresponds to a mass of 1 at every
    /// supported point; a normalized mass function over the inclusive range
    /// would give `-ln(b - a + 1)`. The module's tests pin the current value.
    /// Confirm which is intended.
    fn ln_f(&self, x: &X) -> f64 {
        let lower = X::from(self.a);
        let upper = X::from(self.b);
        let in_range = lower <= *x && *x <= upper;
        if in_range {
            0.0
        } else {
            f64::NEG_INFINITY
        }
    }
}
impl<X, T> Sampleable<X> for DiscreteUniform<T>
where
    T: DuParam + SampleUniform + Copy,
    X: Integer + From<T>,
{
    /// Draws one value from the inclusive range `[a, b]` and converts it to `X`.
    ///
    /// # Panics
    ///
    /// Panics if `rand` rejects the bounds when building the sampler.
    fn draw<R: Rng>(&self, rng: &mut R) -> X {
        let sampler = rand::distr::Uniform::new_inclusive(self.a, self.b)
            .expect("By construction, this should be valid");
        let value: T = rng.sample(sampler);
        X::from(value)
    }

    /// Draws `n` values from the inclusive range `[a, b]`, building the sampler once.
    ///
    /// # Panics
    ///
    /// Panics if `rand` rejects the bounds when building the sampler.
    fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<X> {
        let sampler = rand::distr::Uniform::new_inclusive(self.a, self.b)
            .expect("By construction, this should be valid");
        (0..n)
            .map(|_| {
                let value: T = rng.sample(&sampler);
                X::from(value)
            })
            .collect()
    }
}
impl<X, T> Support<X> for DiscreteUniform<T>
where
    X: Integer + From<T>,
    T: DuParam,
{
    /// Returns `true` when `x` lies in the inclusive range `[a, b]`.
    fn supports(&self, x: &X) -> bool {
        let lower = X::from(self.a);
        let upper = X::from(self.b);
        lower <= *x && upper >= *x
    }
}
// Empty impl: nothing is overridden here, so any behaviour comes from the
// `DiscreteDistr` trait definition itself (declared elsewhere in the crate).
impl<X, T> DiscreteDistr<X> for DiscreteUniform<T>
where
    X: Integer + From<T>,
    T: DuParam + SampleUniform + Into<f64>,
{}
impl<T> Entropy for DiscreteUniform<T>
where
    T: DuParam + Into<f64>,
{
    /// Returns `ln(b - a)`.
    ///
    /// Each bound is converted to `f64` before subtracting. Subtracting in `T`
    /// can overflow for signed types when the range is wider than `T::MAX`.
    ///
    /// NOTE(review): a uniform law over the `b - a + 1` inclusive points has
    /// entropy `ln(b - a + 1)`. The existing formula is kept because the
    /// module's tests pin its value. Confirm which is intended.
    fn entropy(&self) -> f64 {
        let lower: f64 = self.a.into();
        let upper: f64 = self.b.into();
        (upper - lower).ln()
    }
}
impl<T> Mean<f64> for DiscreteUniform<T>
where
    T: DuParam + SampleUniform + Into<f64>,
{
    /// Returns the midpoint `(a + b) / 2`; always `Some`.
    ///
    /// The bounds are converted to `f64` before they are added. Adding them in
    /// `T` overflows when `a + b` exceeds `T`'s range (for example two large
    /// `u32` bounds).
    fn mean(&self) -> Option<f64> {
        let lower: f64 = self.a.into();
        let upper: f64 = self.b.into();
        Some((upper + lower) / 2.0)
    }
}
impl<T> Median<f64> for DiscreteUniform<T>
where
    T: DuParam + SampleUniform + Into<f64>,
{
    /// Returns the midpoint `(a + b) / 2`; always `Some`.
    ///
    /// The bounds are converted to `f64` before they are added. Adding them in
    /// `T` overflows when `a + b` exceeds `T`'s range.
    fn median(&self) -> Option<f64> {
        let lower: f64 = self.a.into();
        let upper: f64 = self.b.into();
        Some((upper + lower) / 2.0)
    }
}
impl<T> Variance<f64> for DiscreteUniform<T>
where
    T: DuParam + SampleUniform + Into<f64>,
{
    /// Returns `n * n / 12` with `n = b - a + 1`; always `Some`.
    ///
    /// `n` is computed in `f64`. Computing `b - a + 1` in `T` overflows when
    /// the range spans the whole type (for example `0..=255` for `u8`).
    ///
    /// NOTE(review): the textbook variance of a discrete uniform law over `n`
    /// points is `(n * n - 1) / 12`. The existing formula is kept because the
    /// module's tests pin its value. Confirm which is intended.
    fn variance(&self) -> Option<f64> {
        let lower: f64 = self.a.into();
        let upper: f64 = self.b.into();
        let n = upper - lower + 1.0;
        Some(n * n / 12.0)
    }
}
impl<X, T> Cdf<X> for DiscreteUniform<T>
where
    X: Integer + From<T> + ToPrimitive + Copy,
    T: DuParam + SampleUniform + ToPrimitive,
{
    /// Cumulative distribution function.
    ///
    /// Returns `0.0` below `a`, `1.0` at or above `b`, and
    /// `(x - a + 1) / (b - a + 1)` in between.
    ///
    /// # Panics
    ///
    /// Panics if `x`, `a` or `b` cannot be converted to `f64`.
    fn cdf(&self, x: &X) -> f64 {
        if *x < X::from(self.a) {
            return 0.0;
        }
        if *x >= X::from(self.b) {
            return 1.0;
        }

        let x_f: f64 = x.to_f64().unwrap();
        let lower: f64 = self.a.to_f64().unwrap();
        let upper: f64 = self.b.to_f64().unwrap();

        // Points at or below `x`, divided by the total number of points.
        let at_or_below = x_f - lower + 1.0;
        let total = upper - lower + 1.0;
        at_or_below / total
    }
}
impl<X, T> InverseCdf<X> for DiscreteUniform<T>
where
    X: Integer + From<T> + FromPrimitive,
    T: DuParam + SampleUniform + ToPrimitive,
{
    /// Maps a probability `p` to `a + X::from_f64(p * (b - a))`.
    ///
    /// # Panics
    ///
    /// Panics if `b - a` cannot be converted to `f64`, or if `p * (b - a)`
    /// cannot be converted to `X`.
    fn invcdf(&self, p: f64) -> X {
        let span = self.b - self.a;
        let span_f: f64 = span.to_f64().unwrap();
        let offset = X::from_f64(p * span_f).unwrap();
        offset + X::from(self.a)
    }
}
impl<T> Skewness for DiscreteUniform<T>
where
    T: DuParam,
{
    /// Always returns `Some(0.0)`, regardless of the bounds.
    fn skewness(&self) -> Option<f64> {
        const SKEWNESS: f64 = 0.0;
        Some(SKEWNESS)
    }
}
impl<T> Kurtosis for DiscreteUniform<T>
where
    T: DuParam,
{
    /// Always returns `Some(-1.2)`, regardless of the bounds.
    ///
    /// NOTE(review): -1.2 looks like the excess kurtosis of a *continuous*
    /// uniform law; for a discrete one the value depends on the number of
    /// points. Confirm which is intended.
    fn kurtosis(&self) -> Option<f64> {
        const KURTOSIS: f64 = -1.2;
        Some(KURTOSIS)
    }
}
impl std::error::Error for DiscreteUniformError {}

/// Human-readable description of each construction error.
#[cfg_attr(coverage_nightly, coverage(off))]
impl fmt::Display for DiscreteUniformError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive match: adding a variant forces a message to be written.
        let message = match self {
            Self::InvalidInterval => {
                "a (lower) is greater than or equal to b (upper)"
            }
        };
        f.write_str(message)
    }
}
// Unit tests for `DiscreteUniform`: construction, moments, cdf/invcdf,
// sampling, and the parameter round trip.
#[cfg(test)]
mod tests {
use super::*;
use crate::misc::ks_test;
use crate::test_basic_impls;
// Absolute tolerance passed to `assert::close` in the numeric checks.
const TOL: f64 = 1E-12;
// p-value a single KS run must exceed to count as a pass in `draw_test`.
const KS_PVAL: f64 = 0.2;
// Number of independent KS runs attempted in `draw_test`.
const N_TRIES: usize = 5;
// Crate-provided macro; presumably generates generic trait-impl checks for
// the given instance — see its definition for what it covers.
test_basic_impls!(
u32,
DiscreteUniform<u32>,
DiscreteUniform::new(0_u32, 10_u32).unwrap()
);
// A valid interval stores both bounds unchanged.
#[test]
fn new() {
let u = DiscreteUniform::new(0, 10).unwrap();
assert!(u.a == 0);
assert!(u.b == 10);
}
// Equal bounds are rejected by the checked constructor.
#[test]
fn new_rejects_a_equal_to_b() {
assert!(DiscreteUniform::new(5, 5).is_err());
}
// A lower bound above the upper bound is rejected.
#[test]
fn new_rejects_a_gt_b() {
assert!(DiscreteUniform::new(5, 1).is_err());
}
// Mean of [0, 10] is the midpoint, 5.
#[test]
fn mean() {
let m: f64 = DiscreteUniform::new(0, 10).unwrap().mean().unwrap();
assert::close(m, 5.0, TOL);
}
// Median of [0, 10] is the midpoint, 5.
#[test]
fn median() {
let m: f64 = DiscreteUniform::new(0, 10).unwrap().median().unwrap();
assert::close(m, 5.0, TOL);
}
// Pins the current variance formula, n^2 / 12 with n = 11.
#[test]
fn variance() {
let v: f64 = DiscreteUniform::new(0, 10).unwrap().variance().unwrap();
assert::close(v, (11.0 * 11.0) / 12.0, TOL);
}
// Pins the current entropy formula, ln(b - a) = ln(2) for [2, 4].
#[test]
fn entropy() {
let h: f64 = DiscreteUniform::new(2, 4).unwrap().entropy();
assert::close(h, f64::consts::LN_2, TOL);
}
// Pins the current log-mass of 0.0 for an in-support point.
// NOTE(review): a normalized mass over the 11 points of [0, 10] would give
// -ln(11) — confirm the intended value.
#[test]
fn ln_pmf() {
let u = DiscreteUniform::new(0, 10).unwrap();
assert::close(u.ln_pmf(&2_u8), 0.0, TOL);
}
// cdf on [0, 10] at the lower bound, an interior point, and the upper bound.
#[test]
fn cdf() {
let u = DiscreteUniform::new(0_u32, 10_u32).unwrap();
assert::close(u.cdf(&0_u32), 1.0 / 11.0, TOL);
assert::close(u.cdf(&5_u32), 6.0 / 11.0, TOL);
assert::close(u.cdf(&10_u32), 1.0, TOL);
}
// invcdf(cdf(x)) must return x for 100 random points of [0, 100].
#[test]
fn cdf_inv_cdf_ident() {
let mut rng = rand::rng();
let ru = rand::distr::Uniform::new_inclusive(0_u32, 100_u32).unwrap();
let u = DiscreteUniform::new(0_u32, 100_u32).unwrap();
for _ in 0..100 {
let x: u32 = rng.sample(ru);
let cdf = u.cdf(&x);
let y: u32 = u.invcdf(cdf);
assert!(x == y);
}
}
// Runs `ks_test` on 1000 samples against the cdf, N_TRIES times; the test
// passes if at least one run yields p > KS_PVAL. The first element of the
// `ks_test` result is ignored (presumably the test statistic).
#[test]
fn draw_test() {
let mut rng = rand::rng();
let u = DiscreteUniform::new(0_u32, 100_u32).unwrap();
let cdf = |x: u64| u.cdf(&x);
let passes = (0..N_TRIES).fold(0, |acc, _| {
let xs: Vec<u64> = u.sample(1000, &mut rng);
let (_, p) = ks_test(&xs, cdf);
if p > KS_PVAL { acc + 1 } else { acc }
});
assert!(passes > 0);
}
// emit_params followed by from_params reproduces an equal distribution.
#[test]
fn emit_and_from_params_are_identity() {
let dist_a = DiscreteUniform::new(0_u32, 10_u32).unwrap();
let dist_b = DiscreteUniform::from_params(dist_a.emit_params());
assert_eq!(dist_a, dist_b);
}
}