use crate::data::DataOrSuffStat;
use rand::Rng;
/// Core interface of a random variable over values of type `X`.
///
/// Implementors provide a log-scale function ([`Rv::ln_f`]) and a way to draw
/// one value ([`Rv::draw`]); the remaining methods have defaults built on
/// those two.
pub trait Rv<X> {
    /// Evaluates the function at `x` on the natural scale.
    ///
    /// The default exponentiates [`Rv::ln_f`].
    fn f(&self, x: &X) -> f64 {
        let ln_value = self.ln_f(x);
        ln_value.exp()
    }

    /// Evaluates the natural logarithm of the function at `x`.
    fn ln_f(&self, x: &X) -> f64;

    /// Draws a single value, using `rng` as the source of randomness.
    fn draw<R: Rng>(&self, rng: &mut R) -> X;

    /// Draws `n` values and returns them in the order they were drawn.
    fn sample<R: Rng>(&self, n: usize, mut rng: &mut R) -> Vec<X> {
        let mut draws = Vec::with_capacity(n);
        for _ in 0..n {
            draws.push(self.draw(&mut rng));
        }
        draws
    }

    /// Returns an endless iterator that performs one draw per item requested.
    ///
    /// The iterator borrows both `self` and `rng` for `'r`.
    fn sample_stream<'r, R: Rng>(
        &'r self,
        mut rng: &'r mut R,
    ) -> Box<dyn Iterator<Item = X> + 'r> {
        let draw_next = move || self.draw(&mut rng);
        Box::new(std::iter::repeat_with(draw_next))
    }
}
/// Implements [`Rv`] for every pointer-like type (`&T`, `Box<T>`, `Rc<T>`, ...)
/// whose `Deref` target is itself an [`Rv`].
///
/// Every trait method, including those with default bodies, is forwarded to
/// the target so that a target's own overrides are honoured when it is used
/// through a pointer.
impl<Fx, X> Rv<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Rv<X>,
{
    /// Forwards to the target's `ln_f`.
    fn ln_f(&self, x: &X) -> f64 {
        self.deref().ln_f(x)
    }

    /// Forwards to the target's `f`.
    fn f(&self, x: &X) -> f64 {
        self.deref().f(x)
    }

    /// Forwards to the target's `draw`.
    fn draw<R: Rng>(&self, rng: &mut R) -> X {
        // Pass the generator through as-is instead of wrapping it in another
        // `&mut`, so the target sees the caller's generator type.
        self.deref().draw(rng)
    }

    /// Forwards to the target's `sample`.
    fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<X> {
        self.deref().sample(n, rng)
    }

    /// Forwards to the target's `sample_stream`.
    ///
    /// Without this, the trait default would be used and any specialised
    /// `sample_stream` on the target would be bypassed.
    fn sample_stream<'r, R: Rng>(
        &'r self,
        rng: &'r mut R,
    ) -> Box<dyn Iterator<Item = X> + 'r> {
        self.deref().sample_stream(rng)
    }
}
/// Reports whether a value of type `X` lies within the implementor's support.
pub trait Support<X> {
/// Returns `true` if `x` is supported and `false` otherwise.
fn supports(&self, x: &X) -> bool;
}
/// Implements [`Support`] for any `Deref` type whose target implements it.
impl<Fx, X> Support<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Support<X>,
{
    /// Delegates the membership check to the dereferenced target.
    fn supports(&self, x: &X) -> bool {
        (**self).supports(x)
    }
}
/// Density-style view of an [`Rv`] that also knows its [`Support`].
pub trait ContinuousDistr<X>: Rv<X> + Support<X> {
    /// Returns the density at `x`, computed as `exp(ln_pdf(x))`.
    ///
    /// # Panics
    ///
    /// With the default `ln_pdf`, panics if `x` is not supported.
    fn pdf(&self, x: &X) -> f64 {
        let ln_density = self.ln_pdf(x);
        ln_density.exp()
    }

    /// Returns the log density at `x`, which is `ln_f(x)` for supported `x`.
    ///
    /// # Panics
    ///
    /// Panics with `"x not in support"` if `supports(x)` is `false`.
    fn ln_pdf(&self, x: &X) -> f64 {
        assert!(self.supports(x), "x not in support");
        self.ln_f(x)
    }
}
/// Implements [`ContinuousDistr`] for any `Deref` type whose target
/// implements it.
///
/// Both methods are forwarded explicitly so that a target which overrides
/// `pdf` or `ln_pdf` keeps that behaviour when used through a pointer; an
/// empty impl would fall back to the trait defaults instead.
impl<Fx, X> ContinuousDistr<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: ContinuousDistr<X>,
{
    /// Forwards to the target's `pdf`.
    fn pdf(&self, x: &X) -> f64 {
        self.deref().pdf(x)
    }

    /// Forwards to the target's `ln_pdf`, including any panic it raises.
    fn ln_pdf(&self, x: &X) -> f64 {
        self.deref().ln_pdf(x)
    }
}
/// Adds a cumulative distribution function to an [`Rv`].
pub trait Cdf<X>: Rv<X> {
    /// Evaluates the cumulative distribution function at `x`.
    fn cdf(&self, x: &X) -> f64;

    /// Evaluates the survival function at `x`.
    ///
    /// The default is `1.0 - cdf(x)`.
    fn sf(&self, x: &X) -> f64 {
        let cumulative = self.cdf(x);
        1.0 - cumulative
    }
}
/// Implements [`Cdf`] for any `Deref` type whose target implements it.
impl<Fx, X> Cdf<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Cdf<X>,
{
    /// Forwards to the target's `cdf`.
    fn cdf(&self, x: &X) -> f64 {
        (**self).cdf(x)
    }

    /// Forwards to the target's `sf`, so an overridden `sf` is respected.
    fn sf(&self, x: &X) -> f64 {
        (**self).sf(x)
    }
}
/// Adds an inverse cumulative distribution function to an [`Rv`] with a
/// [`Support`].
pub trait InverseCdf<X>: Rv<X> + Support<X> {
    /// Evaluates the inverse CDF at `p`.
    fn invcdf(&self, p: f64) -> X;

    /// Alias for [`InverseCdf::invcdf`]; the default simply calls it.
    fn quantile(&self, p: f64) -> X {
        Self::invcdf(self, p)
    }

    /// Returns the pair of quantiles at `(1 - p) / 2` and `p + (1 - p) / 2`.
    ///
    /// `p` is passed through without validation.
    fn interval(&self, p: f64) -> (X, X) {
        let tail = (1.0 - p) / 2.0;
        let lower = self.quantile(tail);
        let upper = self.quantile(p + tail);
        (lower, upper)
    }
}
/// Implements [`InverseCdf`] for any `Deref` type whose target implements it.
impl<Fx, X> InverseCdf<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: InverseCdf<X>,
{
    /// Forwards to the target's `invcdf`.
    fn invcdf(&self, p: f64) -> X {
        (**self).invcdf(p)
    }

    /// Forwards to the target's `quantile`.
    fn quantile(&self, p: f64) -> X {
        (**self).quantile(p)
    }

    /// Forwards to the target's `interval`.
    fn interval(&self, p: f64) -> (X, X) {
        (**self).interval(p)
    }
}
/// Mass-function view of an [`Rv`] that also knows its [`Support`].
pub trait DiscreteDistr<X>: Rv<X> + Support<X> {
    /// Returns the mass at `x`, computed as `exp(ln_pmf(x))`.
    ///
    /// # Panics
    ///
    /// With the default `ln_pmf`, panics if `x` is not supported.
    fn pmf(&self, x: &X) -> f64 {
        let ln_mass = self.ln_pmf(x);
        ln_mass.exp()
    }

    /// Returns the log mass at `x`, which is `ln_f(x)` for supported `x`.
    ///
    /// # Panics
    ///
    /// Panics with `"x not in support"` if `supports(x)` is `false`.
    fn ln_pmf(&self, x: &X) -> f64 {
        assert!(self.supports(x), "x not in support");
        self.ln_f(x)
    }
}
/// Implements [`DiscreteDistr`] for any `Deref` type whose target
/// implements it.
///
/// Both methods are forwarded explicitly so that a target which overrides
/// `pmf` or `ln_pmf` keeps that behaviour when used through a pointer; an
/// empty impl would fall back to the trait defaults instead.
impl<Fx, X> DiscreteDistr<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: DiscreteDistr<X>,
{
    /// Forwards to the target's `pmf`.
    fn pmf(&self, x: &X) -> f64 {
        self.deref().pmf(x)
    }

    /// Forwards to the target's `ln_pmf`, including any panic it raises.
    fn ln_pmf(&self, x: &X) -> f64 {
        self.deref().ln_pmf(x)
    }
}
/// Exposes a mean expressed as a value of type `X`.
pub trait Mean<X> {
/// Returns the mean wrapped in an `Option`.
///
/// NOTE(review): `None` presumably signals that the mean is undefined or
/// unavailable — confirm against implementors.
fn mean(&self) -> Option<X>;
}
/// Implements [`Mean`] for any `Deref` type whose target implements it.
impl<Fx, X> Mean<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Mean<X>,
{
    /// Forwards to the target's `mean`.
    fn mean(&self) -> Option<X> {
        (**self).mean()
    }
}
/// Exposes a median expressed as a value of type `X`.
pub trait Median<X> {
/// Returns the median wrapped in an `Option`.
///
/// NOTE(review): `None` presumably signals that the median is undefined or
/// unavailable — confirm against implementors.
fn median(&self) -> Option<X>;
}
/// Implements [`Median`] for any `Deref` type whose target implements it.
impl<Fx, X> Median<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Median<X>,
{
    /// Forwards to the target's `median`.
    fn median(&self) -> Option<X> {
        (**self).median()
    }
}
/// Exposes a mode expressed as a value of type `X`.
pub trait Mode<X> {
/// Returns the mode wrapped in an `Option`.
///
/// NOTE(review): `None` presumably signals that no single mode is defined
/// or available — confirm against implementors.
fn mode(&self) -> Option<X>;
}
/// Implements [`Mode`] for any `Deref` type whose target implements it.
impl<Fx, X> Mode<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Mode<X>,
{
    /// Forwards to the target's `mode`.
    fn mode(&self) -> Option<X> {
        (**self).mode()
    }
}
/// Exposes a variance expressed as a value of type `X`.
pub trait Variance<X> {
/// Returns the variance wrapped in an `Option`.
///
/// NOTE(review): `None` presumably signals that the variance is undefined
/// or unavailable — confirm against implementors.
fn variance(&self) -> Option<X>;
}
/// Implements [`Variance`] for any `Deref` type whose target implements it.
impl<Fx, X> Variance<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Variance<X>,
{
    /// Forwards to the target's `variance`.
    fn variance(&self) -> Option<X> {
        (**self).variance()
    }
}
/// Exposes an entropy value.
pub trait Entropy {
/// Returns the entropy as an `f64`.
///
/// NOTE(review): the logarithm base (and hence the unit) is not visible
/// from this declaration — confirm against implementors.
fn entropy(&self) -> f64;
}
/// Implements [`Entropy`] for any `Deref` type whose target implements it.
impl<Fx> Entropy for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Entropy,
{
    /// Forwards to the target's `entropy`.
    fn entropy(&self) -> f64 {
        (**self).entropy()
    }
}
/// Exposes a skewness value.
pub trait Skewness {
/// Returns the skewness wrapped in an `Option`.
///
/// NOTE(review): `None` presumably signals that the skewness is undefined
/// or unavailable — confirm against implementors.
fn skewness(&self) -> Option<f64>;
}
/// Implements [`Skewness`] for any `Deref` type whose target implements it.
impl<Fx> Skewness for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Skewness,
{
    /// Forwards to the target's `skewness`.
    fn skewness(&self) -> Option<f64> {
        (**self).skewness()
    }
}
/// Exposes a kurtosis value.
pub trait Kurtosis {
/// Returns the kurtosis wrapped in an `Option`.
///
/// NOTE(review): whether this is raw or excess kurtosis, and when `None` is
/// returned, is not visible from this declaration — confirm against
/// implementors.
fn kurtosis(&self) -> Option<f64>;
}
/// Implements [`Kurtosis`] for any `Deref` type whose target implements it.
impl<Fx> Kurtosis for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: Kurtosis,
{
    /// Forwards to the target's `kurtosis`.
    fn kurtosis(&self) -> Option<f64> {
        (**self).kurtosis()
    }
}
/// Divergence between two values of the same type.
pub trait KlDivergence {
    /// Returns the divergence from `self` to `other`.
    fn kl(&self, other: &Self) -> f64;

    /// Returns the symmetrised divergence: `self.kl(other) + other.kl(self)`.
    fn kl_sym(&self, other: &Self) -> f64 {
        let forward = self.kl(other);
        let backward = other.kl(self);
        forward + backward
    }
}
/// Implements [`KlDivergence`] for any `Deref` type whose target implements
/// it, comparing the two dereferenced targets.
impl<Fx> KlDivergence for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: KlDivergence,
{
    /// Forwards to the target's `kl`, dereferencing both operands.
    fn kl(&self, other: &Self) -> f64 {
        (**self).kl(&**other)
    }

    /// Forwards to the target's `kl_sym`, dereferencing both operands.
    fn kl_sym(&self, other: &Self) -> f64 {
        (**self).kl_sym(&**other)
    }
}
/// Associates the implementor with a statistic type for data of type `X`.
pub trait HasSuffStat<X> {
/// The statistic type; it must implement `SuffStat<X>`.
type Stat: SuffStat<X>;
/// Creates a new `Stat` value.
///
/// NOTE(review): presumably a statistic with no observations recorded —
/// confirm against implementors.
fn empty_suffstat(&self) -> Self::Stat;
}
/// Implements [`HasSuffStat`] for any `Deref` type whose target implements
/// it, reusing the target's statistic type.
impl<Fx, X> HasSuffStat<X> for Fx
where
    Fx: std::ops::Deref,
    Fx::Target: HasSuffStat<X>,
{
    /// Same statistic type as the dereferenced target.
    type Stat = <Fx::Target as HasSuffStat<X>>::Stat;

    /// Forwards to the target's `empty_suffstat`.
    fn empty_suffstat(&self) -> Self::Stat {
        (**self).empty_suffstat()
    }
}
/// A statistic that is updated one value of type `X` at a time.
pub trait SuffStat<X> {
    /// Returns a count as `usize`.
    ///
    /// NOTE(review): presumably the number of observations currently
    /// recorded — confirm against implementors.
    fn n(&self) -> usize;

    /// Incorporates the value `x` into the statistic.
    fn observe(&mut self, x: &X);

    /// Removes the value `x` from the statistic.
    fn forget(&mut self, x: &X);

    /// Calls [`SuffStat::observe`] on each element of `xs`, in slice order.
    fn observe_many(&mut self, xs: &[X]) {
        for x in xs {
            self.observe(x);
        }
    }

    /// Calls [`SuffStat::forget`] on each element of `xs`, in slice order.
    fn forget_many(&mut self, xs: &[X]) {
        for x in xs {
            self.forget(x);
        }
    }
}
/// Implements [`SuffStat`] for any `DerefMut` type whose target implements
/// it, so a statistic can be updated through `&mut T`, `Box<T>`, and similar.
impl<S, X> SuffStat<X> for S
where
    S: std::ops::DerefMut,
    S::Target: SuffStat<X>,
{
    /// Forwards to the target's `n`.
    fn n(&self) -> usize {
        (**self).n()
    }

    /// Forwards to the target's `observe`.
    fn observe(&mut self, x: &X) {
        (**self).observe(x)
    }

    /// Forwards to the target's `forget`.
    fn forget(&mut self, x: &X) {
        (**self).forget(x)
    }

    /// Forwards to the target's `observe_many`.
    fn observe_many(&mut self, xs: &[X]) {
        (**self).observe_many(xs)
    }

    /// Forwards to the target's `forget_many`.
    fn forget_many(&mut self, xs: &[X]) {
        (**self).forget_many(xs)
    }
}
/// A random variable over `Fx` values (itself an `Rv<X>` with a statistic
/// type) that can be combined with data to produce a posterior.
///
/// NOTE(review): from the method names, `ln_m` is presumably a log marginal
/// likelihood and `ln_pp` a log posterior predictive — confirm against
/// implementors.
pub trait ConjugatePrior<X, Fx>: Rv<Fx>
where
    Fx: Rv<X> + HasSuffStat<X>,
{
    /// The type produced by [`ConjugatePrior::posterior`]; also an `Rv<Fx>`.
    type Posterior: Rv<Fx>;

    /// Builds the posterior from the observations in `x`.
    fn posterior(&self, x: &DataOrSuffStat<X, Fx>) -> Self::Posterior;

    /// Log-scale counterpart of [`ConjugatePrior::m`] for the observations `x`.
    fn ln_m(&self, x: &DataOrSuffStat<X, Fx>) -> f64;

    /// Log-scale counterpart of [`ConjugatePrior::pp`] for the value `y`
    /// given the observations `x`.
    fn ln_pp(&self, y: &X, x: &DataOrSuffStat<X, Fx>) -> f64;

    /// Natural-scale version of `ln_m`; the default exponentiates it.
    fn m(&self, x: &DataOrSuffStat<X, Fx>) -> f64 {
        let ln_value = self.ln_m(x);
        ln_value.exp()
    }

    /// Natural-scale version of `ln_pp`; the default exponentiates it.
    fn pp(&self, y: &X, x: &DataOrSuffStat<X, Fx>) -> f64 {
        let ln_value = self.ln_pp(y, x);
        ln_value.exp()
    }
}
/// Exposes a pair of `f64` bounds.
///
/// NOTE(review): from the name these are presumably integration limits for
/// numerical quadrature — confirm against callers.
pub trait QuadBounds {
/// Returns the two bounds as a tuple.
///
/// NOTE(review): presumably ordered `(lower, upper)` — confirm against
/// implementors.
fn quad_bounds(&self) -> (f64, f64);
}