use self::GammaRepr::*;
use self::ChiSquaredRepr::*;
use rand::Rng;
use crate::normal::StandardNormal;
use crate::{Distribution, Exp1, Exp, Open01};
use crate::utils::Float;
/// The Gamma distribution `Gamma(shape, scale)`.
///
/// Constructed with [`Gamma::new`], which selects one of three sampling
/// strategies depending on whether `shape` is below, equal to, or above 1.
#[derive(Clone, Copy, Debug)]
pub struct Gamma<N> {
    // Sampling strategy chosen at construction time.
    repr: GammaRepr<N>,
}
/// Error type returned from [`Gamma::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `shape` is not strictly positive (this includes NaN).
    ShapeTooSmall,
    /// `scale` is not strictly positive (this includes NaN).
    ScaleTooSmall,
    /// `shape == 1` and the exponential distribution with rate `1 / scale`
    /// could not be constructed.
    ScaleTooLarge,
}
/// Internal sampling strategy for [`Gamma`], selected by the value of `shape`.
#[derive(Clone, Copy, Debug)]
enum GammaRepr<N> {
    /// `shape > 1`: Marsaglia–Tsang rejection sampler.
    Large(GammaLargeShape<N>),
    /// `shape == 1`: `Gamma(1, scale)` is the exponential distribution
    /// with rate `1 / scale`.
    One(Exp<N>),
    /// `shape < 1`: a `Gamma(shape + 1, scale)` sample boosted by a uniform power.
    Small(GammaSmallShape<N>)
}
/// Sampler used when `shape < 1`.
///
/// Relies on the identity: if `X ~ Gamma(shape + 1, scale)` and `U` is uniform
/// on (0, 1), then `X * U^(1 / shape) ~ Gamma(shape, scale)`.
#[derive(Clone, Copy, Debug)]
struct GammaSmallShape<N> {
    /// Precomputed `1 / shape`, the exponent applied to the uniform draw.
    inv_shape: N,
    /// Sampler for `Gamma(shape + 1, scale)`.
    large_shape: GammaLargeShape<N>
}
/// Marsaglia–Tsang rejection sampler; in this file it is only built with
/// a shape strictly greater than 1.
#[derive(Clone, Copy, Debug)]
struct GammaLargeShape<N> {
    /// Scale parameter; multiplies the unit-scale sample.
    scale: N,
    /// Precomputed `1 / sqrt(9 * d)`.
    c: N,
    /// Precomputed `shape - 1/3`.
    d: N
}
impl<N: Float> Gamma<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
#[inline]
pub fn new(shape: N, scale: N) -> Result<Gamma<N>, Error> {
if !(shape > N::from(0.0)) {
return Err(Error::ShapeTooSmall);
}
if !(scale > N::from(0.0)) {
return Err(Error::ScaleTooSmall);
}
let repr = if shape == N::from(1.0) {
One(Exp::new(N::from(1.0) / scale).map_err(|_| Error::ScaleTooLarge)?)
} else if shape < N::from(1.0) {
Small(GammaSmallShape::new_raw(shape, scale))
} else {
Large(GammaLargeShape::new_raw(shape, scale))
};
Ok(Gamma { repr })
}
}
impl<N: Float> GammaSmallShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
    /// Builds the `shape < 1` sampler without validating its arguments.
    ///
    /// Sampling is delegated to a `Gamma(shape + 1, scale)` sampler, whose
    /// output is later multiplied by `U^(1 / shape)`.
    fn new_raw(shape: N, scale: N) -> GammaSmallShape<N> {
        let one = N::from(1.0);
        let boosted = GammaLargeShape::new_raw(shape + one, scale);
        GammaSmallShape {
            large_shape: boosted,
            inv_shape: one / shape,
        }
    }
}
impl<N: Float> GammaLargeShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
    /// Precomputes the Marsaglia–Tsang constants `d = shape - 1/3` and
    /// `c = 1 / sqrt(9 * d)`. Arguments are not validated.
    fn new_raw(shape: N, scale: N) -> GammaLargeShape<N> {
        let d = shape - N::from(1. / 3.);
        let nine_d = N::from(9.) * d;
        let c = N::from(1.0) / nine_d.sqrt();
        GammaLargeShape { scale, c, d }
    }
}
impl<N: Float> Distribution<N> for Gamma<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
    /// Draws one sample by dispatching to the strategy chosen in [`Gamma::new`].
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
        match &self.repr {
            One(exponential) => exponential.sample(rng),
            Small(small) => small.sample(rng),
            Large(large) => large.sample(rng),
        }
    }
}
impl<N: Float> Distribution<N> for GammaSmallShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
    /// Samples `Gamma(shape, scale)` for `shape < 1` as
    /// `Gamma(shape + 1, scale) * U^(1 / shape)`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
        // The uniform is drawn before the boosted gamma sample; keep this
        // order so a seeded RNG produces the same stream of values.
        let uniform: N = rng.sample(Open01);
        let boost = uniform.powf(self.inv_shape);
        let boosted_sample = self.large_shape.sample(rng);
        boosted_sample * boost
    }
}
impl<N: Float> Distribution<N> for GammaLargeShape<N>
where StandardNormal: Distribution<N>, Open01: Distribution<N>
{
    /// Marsaglia–Tsang rejection sampler for `Gamma(d + 1/3, scale)`.
    ///
    /// Each attempt draws a standard normal `x` and then a uniform `u` from
    /// `Open01`. The loop repeats until a candidate is accepted.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
        loop {
            let x: N = rng.sample(StandardNormal);
            let v_cbrt = N::from(1.0) + self.c * x;
            // v = (1 + c*x)^3 must be positive; otherwise reject immediately
            // without drawing the uniform.
            if v_cbrt <= N::from(0.0) {
                continue
            }
            let v = v_cbrt * v_cbrt * v_cbrt;
            let u: N = rng.sample(Open01);
            let x_sqr = x * x;
            // The cheap polynomial squeeze test comes first. The logarithmic
            // acceptance test is evaluated only if it fails (short-circuit `||`).
            if u < N::from(1.0) - N::from(0.0331) * x_sqr * x_sqr ||
                u.ln() < N::from(0.5) * x_sqr + self.d * (N::from(1.0) - v + v.ln())
            {
                // Accepted: d*v is a unit-scale sample; apply the scale.
                return self.d * v * self.scale
            }
        }
    }
}
/// The chi-squared distribution `χ²(k)` with `k` degrees of freedom.
#[derive(Clone, Copy, Debug)]
pub struct ChiSquared<N> {
    // Sampling strategy selected by `ChiSquared::new`.
    repr: ChiSquaredRepr<N>,
}
/// Error type returned from [`ChiSquared::new`] (and forwarded by [`StudentT::new`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ChiSquaredError {
    /// `0.5 * k` is not strictly positive. This covers `k <= 0`, NaN, and a
    /// positive `k` so small that halving it underflows to zero.
    DoFTooSmall,
}
/// Internal sampling strategy for [`ChiSquared`].
#[derive(Clone, Copy, Debug)]
enum ChiSquaredRepr<N> {
    /// `k == 1`: sampled directly as the square of one standard normal draw.
    DoFExactlyOne,
    /// Any other `k`: `χ²(k)` is sampled as `Gamma(k / 2, 2)`.
    DoFAnythingElse(Gamma<N>),
}
impl<N: Float> ChiSquared<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
    /// Constructs `χ²(k)` with `k` degrees of freedom.
    ///
    /// `k == 1` is sampled directly as a squared standard normal. Every other
    /// `k` is delegated to `Gamma(k / 2, 2)`.
    ///
    /// # Errors
    /// [`ChiSquaredError::DoFTooSmall`] if `k / 2` is not strictly positive.
    /// That is, `k <= 0`, NaN, or a positive `k` so small that halving it
    /// underflows to zero.
    pub fn new(k: N) -> Result<ChiSquared<N>, ChiSquaredError> {
        let repr = if k == N::from(1.0) {
            DoFExactlyOne
        } else {
            let shape = N::from(0.5) * k;
            // Negated comparison so that NaN is rejected as well.
            if !(shape > N::from(0.0)) {
                return Err(ChiSquaredError::DoFTooSmall);
            }
            // With a positive shape and scale 2, `Gamma::new` is not expected
            // to fail. Still propagate its error rather than panicking inside
            // a constructor that already reports failure through `Result`.
            let gamma = Gamma::new(shape, N::from(2.0))
                .map_err(|_| ChiSquaredError::DoFTooSmall)?;
            DoFAnythingElse(gamma)
        };
        Ok(ChiSquared { repr })
    }
}
impl<N: Float> Distribution<N> for ChiSquared<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
    /// Draws one `χ²(k)` sample.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
        match &self.repr {
            DoFAnythingElse(gamma) => gamma.sample(rng),
            DoFExactlyOne => {
                // χ²(1) is the square of a single standard normal variate.
                let z: N = rng.sample(StandardNormal);
                z * z
            }
        }
    }
}
/// The Fisher F-distribution `F(m, n)`.
///
/// Sampled as `X / Y * (n / m)` with `X ~ χ²(m)` and `Y ~ χ²(n)`.
#[derive(Clone, Copy, Debug)]
pub struct FisherF<N> {
    // χ²(m): numerator draw.
    numer: ChiSquared<N>,
    // χ²(n): denominator draw.
    denom: ChiSquared<N>,
    // Precomputed `n / m`.
    dof_ratio: N,
}
/// Error type returned from [`FisherF::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FisherFError {
    /// `m` was rejected as too small, e.g. `m <= 0` or NaN.
    MTooSmall,
    /// `n` was rejected as too small, e.g. `n <= 0` or NaN.
    NTooSmall,
}
impl<N: Float> FisherF<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
    /// Constructs `F(m, n)`, with `m` numerator and `n` denominator degrees
    /// of freedom.
    ///
    /// # Errors
    /// - [`FisherFError::MTooSmall`] if `m` is not strictly positive (including
    ///   NaN), or is so small that [`ChiSquared::new`] rejects it.
    /// - [`FisherFError::NTooSmall`] likewise for `n`.
    pub fn new(m: N, n: N) -> Result<FisherF<N>, FisherFError> {
        if !(m > N::from(0.0)) {
            return Err(FisherFError::MTooSmall);
        }
        if !(n > N::from(0.0)) {
            return Err(FisherFError::NTooSmall);
        }
        // `ChiSquared::new` also rejects a positive argument whose half
        // underflows to zero (e.g. the smallest subnormal). Propagate that
        // error instead of unwrapping, which would panic here.
        let numer = ChiSquared::new(m).map_err(|_| FisherFError::MTooSmall)?;
        let denom = ChiSquared::new(n).map_err(|_| FisherFError::NTooSmall)?;
        Ok(FisherF {
            numer,
            denom,
            dof_ratio: n / m,
        })
    }
}
impl<N: Float> Distribution<N> for FisherF<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
    /// Draws one `F(m, n)` sample as `(X / Y) * (n / m)`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
        // Numerator is drawn before denominator.
        let x = self.numer.sample(rng);
        let y = self.denom.sample(rng);
        x / y * self.dof_ratio
    }
}
/// Student's t-distribution with `dof` degrees of freedom.
///
/// Sampled as `Z * sqrt(dof / V)` with `Z` standard normal and `V ~ χ²(dof)`.
#[derive(Clone, Copy, Debug)]
pub struct StudentT<N> {
    // χ²(dof) sampler for the denominator.
    chi: ChiSquared<N>,
    // Degrees of freedom, reused when scaling each sample.
    dof: N
}
impl<N: Float> StudentT<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
    /// Constructs Student's t-distribution with `n` degrees of freedom.
    ///
    /// # Errors
    /// Forwards [`ChiSquaredError`] from [`ChiSquared::new`] when `n` is rejected.
    pub fn new(n: N) -> Result<StudentT<N>, ChiSquaredError> {
        let chi = ChiSquared::new(n)?;
        Ok(StudentT { chi, dof: n })
    }
}
impl<N: Float> Distribution<N> for StudentT<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
    /// Draws one sample as `Z * sqrt(dof / V)`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
        // The normal is drawn before the chi-squared variate.
        let z: N = rng.sample(StandardNormal);
        let v = self.chi.sample(rng);
        z * (self.dof / v).sqrt()
    }
}
/// The Beta distribution `Beta(alpha, beta)`.
///
/// Sampled as `X / (X + Y)` with `X ~ Gamma(alpha, 1)` and `Y ~ Gamma(beta, 1)`.
#[derive(Clone, Copy, Debug)]
pub struct Beta<N> {
    // Gamma(alpha, 1) sampler.
    gamma_a: Gamma<N>,
    // Gamma(beta, 1) sampler.
    gamma_b: Gamma<N>,
}
/// Error type returned from [`Beta::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BetaError {
    /// `Gamma::new(alpha, 1)` failed; e.g. `alpha <= 0` or NaN.
    AlphaTooSmall,
    /// `Gamma::new(beta, 1)` failed; e.g. `beta <= 0` or NaN.
    BetaTooSmall,
}
impl<N: Float> Beta<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
    /// Constructs `Beta(alpha, beta)` from two unit-scale Gamma samplers.
    ///
    /// # Errors
    /// - [`BetaError::AlphaTooSmall`] if `Gamma::new(alpha, 1)` fails; this is
    ///   checked first.
    /// - [`BetaError::BetaTooSmall`] if `Gamma::new(beta, 1)` fails.
    pub fn new(alpha: N, beta: N) -> Result<Beta<N>, BetaError> {
        let unit_scale = N::from(1.);
        let gamma_a = Gamma::new(alpha, unit_scale)
            .map_err(|_| BetaError::AlphaTooSmall)?;
        let gamma_b = Gamma::new(beta, unit_scale)
            .map_err(|_| BetaError::BetaTooSmall)?;
        Ok(Beta { gamma_a, gamma_b })
    }
}
impl<N: Float> Distribution<N> for Beta<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
    /// Draws one sample as `X / (X + Y)`.
    ///
    /// NOTE(review): for very small `alpha` and `beta`, both Gamma draws can
    /// underflow to zero (the small-shape path multiplies by `U^(1/shape)`).
    /// This then yields `0 / 0 = NaN`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
        let a = self.gamma_a.sample(rng);
        let b = self.gamma_b.sample(rng);
        let total = a + b;
        a / total
    }
}
// Smoke tests: each distribution is constructed and sampled 1000 times with a
// seeded RNG. Sample values are not checked; the tests only detect panics and
// constructor errors.
#[cfg(test)]
mod test {
    use crate::Distribution;
    use super::{Beta, ChiSquared, StudentT, FisherF};

    // Exercises the `k == 1` squared-normal path.
    #[test]
    fn test_chi_squared_one() {
        let chi = ChiSquared::new(1.0).unwrap();
        // NOTE(review): `crate::test::rng` is defined outside this file;
        // presumably it returns a deterministically seeded RNG.
        let mut rng = crate::test::rng(201);
        for _ in 0..1000 {
            chi.sample(&mut rng);
        }
    }

    // k = 0.5 gives Gamma shape 0.25, exercising the small-shape path.
    #[test]
    fn test_chi_squared_small() {
        let chi = ChiSquared::new(0.5).unwrap();
        let mut rng = crate::test::rng(202);
        for _ in 0..1000 {
            chi.sample(&mut rng);
        }
    }

    // k = 30 gives Gamma shape 15, exercising the large-shape path.
    #[test]
    fn test_chi_squared_large() {
        let chi = ChiSquared::new(30.0).unwrap();
        let mut rng = crate::test::rng(203);
        for _ in 0..1000 {
            chi.sample(&mut rng);
        }
    }

    // Negative degrees of freedom must be rejected (the unwrap panics).
    #[test]
    #[should_panic]
    fn test_chi_squared_invalid_dof() {
        ChiSquared::new(-1.0).unwrap();
    }

    #[test]
    fn test_f() {
        let f = FisherF::new(2.0, 32.0).unwrap();
        let mut rng = crate::test::rng(204);
        for _ in 0..1000 {
            f.sample(&mut rng);
        }
    }

    #[test]
    fn test_t() {
        let t = StudentT::new(11.0).unwrap();
        let mut rng = crate::test::rng(205);
        for _ in 0..1000 {
            t.sample(&mut rng);
        }
    }

    // alpha = 1 uses the exponential path; beta = 2 uses the large-shape path.
    // NOTE(review): reuses seed 201 from `test_chi_squared_one`.
    #[test]
    fn test_beta() {
        let beta = Beta::new(1.0, 2.0).unwrap();
        let mut rng = crate::test::rng(201);
        for _ in 0..1000 {
            beta.sample(&mut rng);
        }
    }

    // Zero shape parameters must be rejected. Despite the name, Beta takes
    // shape parameters, not degrees of freedom.
    #[test]
    #[should_panic]
    fn test_beta_invalid_dof() {
        Beta::new(0., 0.).unwrap();
    }
}