use rand::{Rng, RngCore, distr::StandardUniform};
use super::element::{Element, ElementConversion};
/// Serializable description of a random distribution used to generate element values.
///
/// Turn a value of this type into a concrete sampler with [`Distribution::sampler`].
#[derive(Debug, Default, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Distribution {
/// Sample with `rand`'s `StandardUniform` distribution for the target element type.
#[default]
Default,
/// Bernoulli trial with the given success probability; samples are `1` on success
/// and `0` otherwise, converted to the target element type.
Bernoulli(f64),
/// Uniform distribution with bounds `(low, high)`, converted to the target element
/// type and passed to `rand::distr::Uniform::new` (half-open range `[low, high)`).
Uniform(f64, f64),
/// Normal distribution with `(mean, std)`; samples are drawn as `f64` and then
/// converted to the target element type.
Normal(f64, f64),
}
/// Draws values of element type `E` from a configured distribution, using a
/// mutably borrowed random number generator `R` for the lifetime `'a`.
///
/// Normally obtained from [`Distribution::sampler`]; the `new(kind, rng)`
/// constructor is generated by `#[derive(new)]`.
#[derive(new)]
pub struct DistributionSampler<'a, E, R>
where
StandardUniform: rand::distr::Distribution<E>,
E: rand::distr::uniform::SampleUniform,
R: RngCore,
{
// Concrete distribution that each call to `sample` draws from.
kind: DistributionSamplerKind<E>,
// Generator consumed by every draw; borrowed, so the caller keeps ownership.
rng: &'a mut R,
}
/// Concrete, already-validated distribution backing a [`DistributionSampler`].
pub enum DistributionSamplerKind<E>
where
StandardUniform: rand::distr::Distribution<E>,
E: rand::distr::uniform::SampleUniform,
{
/// `rand`'s standard distribution for `E`; samples directly as `E`.
Standard(rand::distr::StandardUniform),
/// Uniform distribution over `E`; samples directly as `E`.
Uniform(rand::distr::Uniform<E>),
/// Bernoulli trial producing a `bool`, which the sampler maps to `1`/`0`.
Bernoulli(rand::distr::Bernoulli),
/// Normal distribution that always samples `f64`, regardless of `E`; the sampler
/// converts each value to `E`.
Normal(rand_distr::Normal<f64>),
}
impl<E, R> DistributionSampler<'_, E, R>
where
    StandardUniform: rand::distr::Distribution<E>,
    E: rand::distr::uniform::SampleUniform,
    E: Element,
    R: RngCore,
{
    /// Draws one value of type `E` from the configured distribution.
    ///
    /// Standard and uniform distributions produce `E` directly. A Bernoulli draw is
    /// mapped to `1` on success and `0` on failure. A normal draw is produced as
    /// `f64` and then converted to `E`.
    pub fn sample(&mut self) -> E {
        // Borrow the two fields separately: the generator mutably, the kind shared.
        let rng = &mut *self.rng;
        match &self.kind {
            DistributionSamplerKind::Standard(standard) => rng.sample(standard),
            DistributionSamplerKind::Uniform(uniform) => rng.sample(uniform),
            DistributionSamplerKind::Bernoulli(bernoulli) => {
                let success: bool = rng.sample(bernoulli);
                i32::from(success).elem()
            }
            DistributionSamplerKind::Normal(normal) => {
                let value: f64 = rng.sample(normal);
                value.elem()
            }
        }
    }
}
impl Distribution {
pub fn sampler<R, E>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R>
where
R: RngCore,
E: Element + rand::distr::uniform::SampleUniform,
StandardUniform: rand::distr::Distribution<E>,
{
let kind = match self {
Distribution::Default => {
DistributionSamplerKind::Standard(rand::distr::StandardUniform {})
}
Distribution::Uniform(low, high) => DistributionSamplerKind::Uniform(
rand::distr::Uniform::new(low.elem::<E>(), high.elem::<E>()).unwrap(),
),
Distribution::Bernoulli(prob) => {
DistributionSamplerKind::Bernoulli(rand::distr::Bernoulli::new(prob).unwrap())
}
Distribution::Normal(mean, std) => {
DistributionSamplerKind::Normal(rand_distr::Normal::new(mean, std).unwrap())
}
};
DistributionSampler::new(kind, rng)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default distribution must be the `Default` variant, whether obtained
    /// through the `Default` trait or the inherent associated-function path.
    #[test]
    fn test_distribution_default() {
        let via_trait = <Distribution as Default>::default();
        let via_type = Distribution::default();
        for dist in [via_trait, via_type] {
            assert_eq!(dist, Distribution::Default);
        }
    }
}