use crate::{
ConditionDifferentiableDistribution, DependentJoint, Distribution, IndependentJoint,
NormalParams, RandomVariable, SampleableDistribution, ValueDifferentiableDistribution,
};
use crate::{DistributionError, NormalError};
use rand::prelude::*;
use rand_distr::Normal as RandNormal;
use std::{ops::BitAnd, ops::Mul};
/// Univariate normal (Gaussian) distribution over `f64`.
///
/// The distribution itself is stateless: its mean `mu` and standard deviation `sigma` are
/// supplied per call through [`NormalParams`] (the `Condition` type). Because it carries no
/// data it is `Copy`, so combinators such as `*` and `&`, which take `self` by value, do not
/// consume the caller's binding.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Normal;
impl Distribution for Normal {
    type Value = f64;
    type Condition = NormalParams;

    /// Unnormalized Gaussian density `exp(-(x - mu)^2 / (2 * sigma^2))`.
    ///
    /// The normalizing factor `1 / (sigma * sqrt(2 * pi))` is not included, so the result is
    /// proportional to the true density in `x` (and equals `1.0` at `x == mu`).
    fn p_kernel(&self, x: &Self::Value, theta: &Self::Condition) -> Result<f64, DistributionError> {
        let deviation = x - theta.mu();
        let two_variance = 2.0 * theta.sigma().powi(2);
        let exponent = -deviation.powi(2) / two_variance;
        Ok(exponent.exp())
    }
}
/// `normal * other` pairs this [`Normal`] with another distribution that is conditioned on the
/// same [`NormalParams`], producing an [`IndependentJoint`] of the two.
impl<R, V> Mul<R> for Normal
where
    R: Distribution<Value = V, Condition = NormalParams>,
    V: RandomVariable,
{
    type Output = IndependentJoint<Normal, R, f64, V, NormalParams>;

    fn mul(self, other: R) -> Self::Output {
        IndependentJoint::new(self, other)
    }
}
/// `normal & other` combines this [`Normal`] with a distribution whose *value* type is
/// [`NormalParams`] — the same type this `Normal` takes as its condition — producing a
/// [`DependentJoint`]. `other` itself may be conditioned on any [`RandomVariable`] `U`.
impl<R, U> BitAnd<R> for Normal
where
    R: Distribution<Value = NormalParams, Condition = U>,
    U: RandomVariable,
{
    type Output = DependentJoint<Normal, R, f64, NormalParams, U>;

    fn bitand(self, other: R) -> Self::Output {
        DependentJoint::new(self, other)
    }
}
impl SampleableDistribution for Normal {
    /// Draws a single value from `N(mu, sigma^2)` using `rand_distr`'s normal sampler.
    ///
    /// # Errors
    ///
    /// Returns [`DistributionError::InvalidParameters`] wrapping
    /// [`NormalError::SigmaMustBePositive`] whenever `rand_distr` rejects `(mu, sigma)`;
    /// the underlying `rand_distr` error is discarded.
    fn sample(
        &self,
        theta: &Self::Condition,
        rng: &mut dyn RngCore,
    ) -> Result<Self::Value, DistributionError> {
        let sampler = RandNormal::new(theta.mu(), theta.sigma()).map_err(|_| {
            DistributionError::InvalidParameters(NormalError::SigmaMustBePositive.into())
        })?;
        Ok(rng.sample(sampler))
    }
}
impl ValueDifferentiableDistribution for Normal {
    /// Derivative of the log-density with respect to the value `x`.
    ///
    /// The normalizing constant does not depend on `x`, so this is
    /// `d/dx ln p(x) = -(x - mu) / sigma^2`, returned as a one-element vector.
    fn ln_diff_value(
        &self,
        x: &Self::Value,
        theta: &Self::Condition,
    ) -> Result<Vec<f64>, DistributionError> {
        let (mu, sigma) = (theta.mu(), theta.sigma());
        let variance = sigma.powi(2);
        let gradient = -(x - mu) / variance;
        Ok(vec![gradient])
    }
}
impl ConditionDifferentiableDistribution for Normal {
    /// Gradient of the log-density with respect to the parameters, ordered `[mu, sigma]`.
    ///
    /// [`Distribution::p_kernel`] drops the normalizing factor `1 / (sigma * sqrt(2 * pi))`.
    /// That is harmless for derivatives in `x`, but the factor depends on `sigma`, so it must
    /// be kept when differentiating with respect to the condition:
    ///
    /// * `d/d mu    ln p = (x - mu) / sigma^2`
    /// * `d/d sigma ln p = (x - mu)^2 / sigma^3 - 1 / sigma`
    fn ln_diff_condition(
        &self,
        x: &Self::Value,
        theta: &Self::Condition,
    ) -> Result<Vec<f64>, DistributionError> {
        let sigma = theta.sigma();
        let mu = theta.mu();
        let residual = x - mu;
        let f_mu = residual / sigma.powi(2);
        // The `- 1/sigma` term is the derivative of `ln(1/sigma)` from the normalizing
        // constant. Without it the sigma-gradient is never negative, so gradient-based
        // inference could only ever increase sigma.
        let f_sigma = residual.powi(2) / sigma.powi(3) - sigma.recip();
        Ok(vec![f_mu, f_sigma])
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        ConditionDifferentiableDistribution, Distribution, Normal, NormalParams,
        SampleableDistribution, ValueDifferentiableDistribution,
    };
    use rand::prelude::*;

    /// Step size for central finite differences used to cross-check analytic gradients.
    const H: f64 = 1e-6;

    /// `ln p_kernel(x | mu, sigma)`; reference function for the finite-difference checks.
    fn ln_kernel(x: f64, mu: f64, sigma: f64) -> f64 {
        Normal
            .p_kernel(&x, &NormalParams::new(mu, sigma).unwrap())
            .unwrap()
            .ln()
    }

    #[test]
    fn it_works() {
        let n = Normal;
        let mut rng = StdRng::from_seed([1; 32]);
        let mu = 2.0;
        let sigma = 3.0;
        let params = NormalParams::new(mu, sigma).unwrap();
        let draws = 10_000;
        let samples: Vec<f64> = (0..draws)
            .map(|_| n.sample(&params, &mut rng).unwrap())
            .collect();
        assert!(samples.iter().all(|x| x.is_finite()));
        let mean = samples.iter().sum::<f64>() / draws as f64;
        // Standard error of the mean is sigma / sqrt(draws) = 0.03; 0.2 is a wide margin.
        assert!(
            (mean - mu).abs() < 0.2,
            "sample mean {} too far from {}",
            mean,
            mu
        );
    }

    #[test]
    fn it_works2() {
        let n = Normal;
        let mu = 2.0;
        let sigma = 3.0;
        let x = 0.5;
        let f = n
            .ln_diff_value(&x, &NormalParams::new(mu, sigma).unwrap())
            .unwrap();
        assert_eq!(f.len(), 1);
        // Analytic value: -(x - mu) / sigma^2 = 1.5 / 9.
        assert!((f[0] - 1.5 / 9.0).abs() < 1e-12);
        let numeric = (ln_kernel(x + H, mu, sigma) - ln_kernel(x - H, mu, sigma)) / (2.0 * H);
        assert!((f[0] - numeric).abs() < 1e-6);
    }

    #[test]
    fn it_works_3() {
        let n = Normal;
        let mu = 2.0;
        let sigma = 3.0;
        let x = 0.5;
        let f = n
            .ln_diff_condition(&x, &NormalParams::new(mu, sigma).unwrap())
            .unwrap();
        assert_eq!(f.len(), 2);
        // d/dmu ln p = (x - mu) / sigma^2 = -1.5 / 9.
        assert!((f[0] + 1.5 / 9.0).abs() < 1e-12);
        let numeric_mu = (ln_kernel(x, mu + H, sigma) - ln_kernel(x, mu - H, sigma)) / (2.0 * H);
        assert!((f[0] - numeric_mu).abs() < 1e-6);
        assert!(f[1].is_finite());
    }
}