use rand::distributions::{Bernoulli as BoolBernoulli, Distribution};
use rand::prelude::*;
use std::cmp::PartialOrd;
use std::convert::Infallible;
/// A distribution that produces samples of type `E` and reports its statistics as type `T`.
///
/// NOTE(review): the method names suggest a batched, array-valued distribution in which `E`
/// holds the sampled elements and `T` holds one statistic per batch member. Nothing in this
/// declaration fixes that layout, so confirm against the implementors.
pub trait ArrayDistribution<E, T> {
/// Returns the batch shape as a vector of sizes (presumably one per batch dimension).
fn batch_shape(&self) -> Vec<usize>;
/// Returns the element shape as a vector of sizes (presumably one per element dimension).
fn element_shape(&self) -> Vec<usize>;
/// Draws a sample of type `E`.
///
/// No random source is passed in, so implementors must obtain randomness on their own.
fn sample(&self) -> E;
/// Returns a `T` computed from the given `elements`; going by the name, the log
/// probabilities of those elements under this distribution.
fn log_probs(&self, elements: &E) -> T;
/// Returns a `T`; going by the name, the entropy of this distribution.
fn entropy(&self) -> T;
/// Returns a `T` relating `self` to `other`, which must be of the same type.
///
/// NOTE(review): the name suggests a KL divergence, but the direction (`self` relative to
/// `other`, or the reverse) is not visible here; confirm against the implementors.
fn kl_divergence_from(&self, other: &Self) -> T;
}
/// Fallible construction of a value from a mean of type `T`.
///
/// The associated [`FromMean::Error`] type lets each implementor decide how an unusable
/// mean is reported.
pub trait FromMean<T>: Sized {
    /// Error returned when a value cannot be built from the supplied mean.
    type Error;

    /// Builds `Self` from `mean`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor cannot construct a value from `mean`.
    #[allow(clippy::missing_errors_doc)]
    fn from_mean(mean: T) -> Result<Self, Self::Error>;
}
/// A type that reports a pair of bounds of type `T`.
///
/// The tests in this file read the pair as an inclusive `(lower, upper)` range and require
/// every sample to fall inside it.
pub trait Bounded<T>
where
    T: PartialOrd,
{
    /// Returns the bounds as a `(lower, upper)` pair.
    fn bounds(&self) -> (T, T);
}
/// A distribution that always yields the single value it wraps.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Deterministic<T>(T);

impl<T> Deterministic<T> {
    /// Wraps `value` so that it becomes the only outcome of the distribution.
    ///
    /// Usable in `const` contexts.
    pub const fn new(value: T) -> Self {
        Deterministic(value)
    }
}
/// Sampling ignores the random source and returns a copy of the wrapped value.
impl<T: Copy> Distribution<T> for Deterministic<T> {
    fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> T {
        let Self(value) = *self;
        value
    }
}
/// The bounds collapse to a single point: the wrapped value is both lower and upper bound.
impl Bounded<f64> for Deterministic<f64> {
    fn bounds(&self) -> (f64, f64) {
        let value = self.0;
        (value, value)
    }
}
/// The wrapped value is the mean, so construction from a mean cannot fail.
impl<T> FromMean<T> for Deterministic<T> {
    type Error = Infallible;

    fn from_mean(mean: T) -> Result<Self, Self::Error> {
        Ok(Self(mean))
    }
}
/// A Bernoulli distribution backed by `rand`'s boolean Bernoulli distribution.
#[derive(Debug, Clone, Copy)]
pub struct Bernoulli(BoolBernoulli);

impl Bernoulli {
    /// Creates the distribution, passing `mean` to `rand`'s `Bernoulli::new` as the
    /// success probability.
    ///
    /// # Errors
    ///
    /// Forwards the `BernoulliError` returned by `rand` when it rejects `mean`
    /// (per the `rand` documentation, when it is not a probability in `[0, 1]`).
    pub fn new(mean: f64) -> Result<Self, rand::distributions::BernoulliError> {
        BoolBernoulli::new(mean).map(Self)
    }
}
/// Samples are the underlying boolean outcome mapped to `1.0` (true) or `0.0` (false).
impl Distribution<f64> for Bernoulli {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
        let success: bool = self.0.sample(rng);
        // `bool -> u8 -> f64` is lossless and yields exactly 0.0 or 1.0.
        f64::from(u8::from(success))
    }
}
/// Bounds are the fixed pair `(0.0, 1.0)`, the two values a sample can take.
impl Bounded<f64> for Bernoulli {
    fn bounds(&self) -> (f64, f64) {
        let lower = 0.0;
        let upper = 1.0;
        (lower, upper)
    }
}
impl FromMean<f64> for Bernoulli {
type Error = rand::distributions::BernoulliError;
fn from_mean(mean: f64) -> Result<Self, Self::Error> {
Self::new(mean)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws `num_samples` values from `d`, using an RNG seeded with `seed`, and asserts
    /// that each one lies within the inclusive range reported by `d.bounds()`.
    fn check_bounded<T: PartialOrd, D: Distribution<T> + Bounded<T>>(
        d: &D,
        num_samples: usize,
        seed: u64,
    ) {
        let (lower_bound, upper_bound) = d.bounds();
        let mut rng = StdRng::seed_from_u64(seed);
        for _ in 0..num_samples {
            let x: T = d.sample(&mut rng);
            assert!(x >= lower_bound, "sample fell below the lower bound");
            assert!(x <= upper_bound, "sample exceeded the upper bound");
        }
    }

    /// Asserts that the average of `num_samples` draws from `d` is close to `expected_mean`.
    ///
    /// `stddev_upper_bound` is an upper bound on the standard deviation of one draw. The
    /// tolerance is the deviation `t` at which the Gaussian tail bound
    /// `exp(-num_samples * t^2 / (2 * stddev_upper_bound^2))` equals `false_positive_prob`.
    ///
    /// # Panics
    ///
    /// Panics with "Use more samples" when the tolerance is too loose to be meaningful
    /// (at least `0.1` and at least 10% of `|expected_mean|`), and panics when the
    /// empirical mean deviates from `expected_mean` by the tolerance or more, in either
    /// direction.
    fn check_mean<D: Distribution<f64> + FromMean<f64>>(
        d: &D,
        expected_mean: f64,
        num_samples: usize,
        stddev_upper_bound: f64,
        seed: u64,
    ) {
        let rng = StdRng::seed_from_u64(seed);
        let empirical_mean = Distribution::<f64>::sample_iter(&d, rng)
            .take(num_samples)
            .sum::<f64>()
            / (num_samples as f64);
        let false_positive_prob: f64 = 1e-5;
        let error_bound =
            stddev_upper_bound * (-2.0 / (num_samples as f64) * false_positive_prob.ln()).sqrt();
        assert!(
            error_bound < 0.1 || error_bound < 0.1 * expected_mean.abs(),
            "Use more samples"
        );
        // Compare the absolute deviation: a signed comparison would accept any empirical
        // mean that falls below the expected one, however far off it is.
        assert!(
            (empirical_mean - expected_mean).abs() < error_bound,
            "empirical mean {} deviates from expected mean {} by at least {}",
            empirical_mean,
            expected_mean,
            error_bound
        );
    }

    #[test]
    fn deterministic_bounded() {
        check_bounded(&Deterministic::new(0.7), 1000, 1);
    }

    #[test]
    fn deterministic_mean() {
        // The true standard deviation is zero; a tiny positive bound keeps the tolerance
        // strictly positive so that the strict `<` comparison can succeed.
        check_mean(&Deterministic::new(0.7), 0.7, 100, 1e-6, 2);
    }

    #[test]
    fn bernoulli_bounded() {
        check_bounded(&Bernoulli::new(0.7).unwrap(), 1000, 1);
    }

    #[test]
    fn bernoulli_mean() {
        let p: f64 = 0.7;
        // Standard deviation of a Bernoulli(p) variable.
        let stddev = (p * (1.0 - p)).sqrt();
        check_mean(&Bernoulli::new(p).unwrap(), p, 1000, stddev, 2);
    }
}