pub use crate::data::DataOrSuffStat;
use crate::data::extract_stat_then;
use rand::Rng;
/// A type that can be decomposed into, and rebuilt from, a parameter
/// representation.
pub trait Parameterized: Sized {
    /// The parameter representation of the implementing type.
    type Parameters;

    /// Returns the parameters describing `self`.
    fn emit_params(&self) -> Self::Parameters;

    /// Builds a value from its parameters.
    fn from_params(params: Self::Parameters) -> Self;

    /// Emits the parameters of `self`, transforms them with `f`, and builds a
    /// new value from the result. `self` itself is not modified.
    fn map_params(
        &self,
        f: impl Fn(Self::Parameters) -> Self::Parameters,
    ) -> Self {
        Self::from_params(f(self.emit_params()))
    }
}
/// A type from which values of type `X` can be drawn at random.
pub trait Sampleable<X> {
    /// Draws a single value using `rng`.
    fn draw<R: Rng>(&self, rng: &mut R) -> X;

    /// Draws `n` values, in order, by calling [`draw`](Self::draw) `n` times.
    fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<X> {
        let mut draws = Vec::with_capacity(n);
        for _ in 0..n {
            draws.push(self.draw(rng));
        }
        draws
    }

    /// Returns an endless iterator that calls [`draw`](Self::draw) lazily,
    /// once per `next()`. The iterator borrows both `self` and `rng` for `'r`.
    fn sample_stream<'r, R: Rng>(
        &'r self,
        rng: &'r mut R,
    ) -> Box<dyn Iterator<Item = X> + 'r> {
        let next_draw = move || self.draw(rng);
        Box::new(std::iter::repeat_with(next_draw))
    }
}
/// A type that assigns a (log-)density to values of type `X`.
pub trait HasDensity<X> {
    /// Natural log of the density at `x`. This is the only required method.
    fn ln_f(&self, x: &X) -> f64;

    /// Density at `x`; by default `exp(ln_f(x))`.
    fn f(&self, x: &X) -> f64 {
        let ln_density = self.ln_f(x);
        ln_density.exp()
    }
}
/// Shorthand for anything that can both be sampled for `X` and evaluate a
/// density on `X`. Implemented automatically for every such type.
pub trait Rv<X>: Sampleable<X> + HasDensity<X> {}
impl<X, T: Sampleable<X> + HasDensity<X>> Rv<X> for T {}

/// Shorthand for a type that samples values of type `S` while evaluating a
/// density on observations of type `O`. Implemented automatically for every
/// such type.
pub trait Process<S, O>: Sampleable<S> + HasDensity<O> {}
impl<S, O, T: Sampleable<S> + HasDensity<O>> Process<S, O> for T {}
/// Membership test for a type's support. `ContinuousDistr` and `DiscreteDistr`
/// use it to decide whether to evaluate the density or return `-inf`.
pub trait Support<X> {
/// Returns `true` if `x` lies in the support.
fn supports(&self, x: &X) -> bool;
}
/// A continuous distribution over `X` whose density is restricted to its
/// [`Support`].
///
/// Both [`pdf`](Self::pdf) and [`ln_pdf`](Self::ln_pdf) honour the support.
/// Outside it the density is `0.0` and the log-density is `-inf`, whatever
/// the raw [`HasDensity::f`] / [`HasDensity::ln_f`] would return there.
pub trait ContinuousDistr<X>: HasDensity<X> + Support<X> {
    /// Probability density at `x`; `0.0` when `x` is outside the support.
    fn pdf(&self, x: &X) -> f64 {
        // Gate on the support so `pdf` agrees with `ln_pdf` (exp(-inf) == 0.0).
        // In-support points still go through `f` rather than
        // `ln_pdf(x).exp()`, so any override of `f` is still honoured.
        if self.supports(x) {
            self.f(x)
        } else {
            0.0
        }
    }

    /// Natural log of the density at `x`; `-inf` when `x` is outside the
    /// support.
    fn ln_pdf(&self, x: &X) -> f64 {
        if self.supports(x) {
            self.ln_f(x)
        } else {
            f64::NEG_INFINITY
        }
    }
}
/// A distribution with a cumulative distribution function.
pub trait Cdf<X>: HasDensity<X> {
    /// Cumulative distribution function evaluated at `x`.
    fn cdf(&self, x: &X) -> f64;

    /// Survival function at `x`; by default the complement `1 - cdf(x)`.
    fn sf(&self, x: &X) -> f64 {
        let below = self.cdf(x);
        1.0 - below
    }
}
/// A distribution whose cumulative distribution function can be inverted.
pub trait InverseCdf<X>: HasDensity<X> + Support<X> {
    /// Inverse of the CDF at probability `p`.
    fn invcdf(&self, p: f64) -> X;

    /// Alias for [`invcdf`](Self::invcdf).
    fn quantile(&self, p: f64) -> X {
        self.invcdf(p)
    }

    /// Central interval holding probability mass `p`: the quantiles at
    /// `(1 - p) / 2` and `p + (1 - p) / 2`, returned as `(lower, upper)`.
    fn interval(&self, p: f64) -> (X, X) {
        let tail = (1.0 - p) / 2.0;
        let lower = self.quantile(tail);
        let upper = self.quantile(p + tail);
        (lower, upper)
    }
}
/// A discrete distribution over `X` whose mass is restricted to its
/// [`Support`].
pub trait DiscreteDistr<X>: Rv<X> + Support<X> {
    /// Probability mass at `x`, computed as `exp(ln_pmf(x))`. It is therefore
    /// `0.0` outside the support.
    fn pmf(&self, x: &X) -> f64 {
        let ln_mass = self.ln_pmf(x);
        ln_mass.exp()
    }

    /// Natural log of the probability mass at `x`; `-inf` when `x` is outside
    /// the support.
    fn ln_pmf(&self, x: &X) -> f64 {
        if !self.supports(x) {
            return f64::NEG_INFINITY;
        }
        self.ln_f(x)
    }
}
/// Mean of a distribution. `None` indicates the mean does not exist or is not
/// available (the shift/scale test macros treat `None` as "does not exist").
pub trait Mean<X> {
fn mean(&self) -> Option<X>;
}
/// Median of a distribution, or `None` when it is not available.
pub trait Median<X> {
fn median(&self) -> Option<X>;
}
/// Mode of a distribution, or `None` when it is not available.
pub trait Mode<X> {
fn mode(&self) -> Option<X>;
}
/// Variance of a distribution, or `None` when it does not exist or is not
/// available.
pub trait Variance<X> {
fn variance(&self) -> Option<X>;
}
/// Entropy of a distribution. Unlike the other summaries this is not optional.
pub trait Entropy {
fn entropy(&self) -> f64;
}
/// Skewness of a distribution, or `None` when it is not available.
pub trait Skewness {
fn skewness(&self) -> Option<f64>;
}
/// Kurtosis of a distribution, or `None` when it is not available.
/// NOTE(review): whether this is raw or excess kurtosis is decided per
/// implementation — confirm before relying on it.
pub trait Kurtosis {
fn kurtosis(&self) -> Option<f64>;
}
/// Kullback–Leibler divergence between two values of the same type.
pub trait KlDivergence {
    /// KL divergence taken with `self` first and `other` second. KL is
    /// asymmetric in general, so the argument order matters.
    fn kl(&self, other: &Self) -> f64;

    /// Symmetrised divergence: the sum of `self.kl(other)` and
    /// `other.kl(self)`, evaluated in that order.
    fn kl_sym(&self, other: &Self) -> f64 {
        let forward = self.kl(other);
        let backward = other.kl(self);
        forward + backward
    }
}
/// A prior over the likelihood type `Fx` whose posterior, marginal
/// likelihood, and posterior predictive are available in closed form. All of
/// them are computed from data or from `Fx`'s sufficient statistic.
///
/// Implementors must override at least one of
/// [`posterior`](Self::posterior) and
/// [`posterior_from_suffstat`](Self::posterior_from_suffstat). Each default
/// is written in terms of the other, so leaving both in place recurses
/// without end.
pub trait ConjugatePrior<X, Fx>: Sampleable<Fx>
where
    Fx: HasDensity<X> + HasSuffStat<X>,
{
    /// Type of the posterior distribution.
    type Posterior: Sampleable<Fx>;
    /// Cache type produced by [`ln_m_cache`](Self::ln_m_cache).
    type MCache;
    /// Cache type produced by [`ln_pp_cache`](Self::ln_pp_cache).
    type PpCache;

    /// Returns a statistic of type `Fx::Stat` (presumably with no
    /// observations — confirm per implementation).
    fn empty_stat(&self) -> Fx::Stat;

    /// Posterior given a sufficient statistic. By default `stat` is wrapped
    /// in [`DataOrSuffStat::SuffStat`] and handed to
    /// [`posterior`](Self::posterior).
    fn posterior_from_suffstat(&self, stat: &Fx::Stat) -> Self::Posterior {
        self.posterior(&DataOrSuffStat::SuffStat(stat))
    }

    /// Posterior given data or a sufficient statistic. By default a statistic
    /// is extracted via `extract_stat_then` and passed to
    /// [`posterior_from_suffstat`](Self::posterior_from_suffstat).
    fn posterior(&self, x: &DataOrSuffStat<X, Fx>) -> Self::Posterior {
        extract_stat_then(self, x, |stat| self.posterior_from_suffstat(stat))
    }

    /// Builds the cache used by [`ln_m_with_cache`](Self::ln_m_with_cache).
    fn ln_m_cache(&self) -> Self::MCache;

    /// Log marginal likelihood of `x`, using a precomputed `cache`.
    fn ln_m_with_cache(
        &self,
        cache: &Self::MCache,
        x: &DataOrSuffStat<X, Fx>,
    ) -> f64;

    /// Log marginal likelihood of `x`. A fresh cache is built for this call
    /// only.
    fn ln_m(&self, x: &DataOrSuffStat<X, Fx>) -> f64 {
        self.ln_m_with_cache(&self.ln_m_cache(), x)
    }

    /// Builds the posterior-predictive cache conditioned on `x`.
    fn ln_pp_cache(&self, x: &DataOrSuffStat<X, Fx>) -> Self::PpCache;

    /// Log posterior predictive of `y`, using a precomputed `cache`.
    fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64;

    /// Log posterior predictive of `y` given `x`. A fresh cache is built for
    /// this call only.
    fn ln_pp(&self, y: &X, x: &DataOrSuffStat<X, Fx>) -> f64 {
        self.ln_pp_with_cache(&self.ln_pp_cache(x), y)
    }

    /// Marginal likelihood of `x`: `exp(ln_m(x))`.
    fn m(&self, x: &DataOrSuffStat<X, Fx>) -> f64 {
        let ln_m = self.ln_m(x);
        ln_m.exp()
    }

    /// Posterior predictive of `y` from a precomputed `cache`:
    /// `exp(ln_pp_with_cache(cache, y))`.
    fn pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 {
        let ln_pp = self.ln_pp_with_cache(cache, y);
        ln_pp.exp()
    }

    /// Posterior predictive of `y` given `x`: `exp(ln_pp(y, x))`.
    fn pp(&self, y: &X, x: &DataOrSuffStat<X, Fx>) -> f64 {
        let ln_pp = self.ln_pp(y, x);
        ln_pp.exp()
    }
}
/// Provides a `(lower, upper)` pair of `f64` bounds.
/// NOTE(review): presumably the integration range for numerical quadrature
/// over the distribution — confirm against callers.
pub trait QuadBounds {
fn quad_bounds(&self) -> (f64, f64);
}
/// A type whose density over observations `X` can be evaluated from a
/// sufficient statistic instead of the raw observations.
pub trait HasSuffStat<X> {
/// The sufficient-statistic type; it must implement [`SuffStat`].
type Stat: SuffStat<X>;
/// Returns a fresh statistic (presumably with `n() == 0` — confirm per
/// implementation).
fn empty_suffstat(&self) -> Self::Stat;
/// Log density evaluated from `stat`. Presumably this is the joint
/// log-likelihood of the summarised observations — confirm per implementation.
fn ln_f_stat(&self, stat: &Self::Stat) -> f64;
}
/// An incrementally updatable sufficient statistic over observations of
/// type `X`.
pub trait SuffStat<X> {
    /// Number of observations currently summarised.
    fn n(&self) -> usize;

    /// Adds a single observation.
    fn observe(&mut self, x: &X);

    /// Removes a single, previously observed, observation.
    fn forget(&mut self, x: &X);

    /// Adds every element of `xs`, in order, via [`observe`](Self::observe).
    fn observe_many(&mut self, xs: &[X]) {
        for x in xs {
            self.observe(x);
        }
    }

    /// Removes every element of `xs`, in order, via [`forget`](Self::forget).
    fn forget_many(&mut self, xs: &[X]) {
        for x in xs {
            self.forget(x);
        }
    }

    /// Folds another statistic of the same type into `self`.
    fn merge(&mut self, other: Self);
}
/// A value that can be translated by an additive `shift`.
/// The `impl_shiftable!` macro provides a generic implementation.
pub trait Shiftable {
/// Type produced by shifting.
type Output;
/// Error returned by the checked [`shifted`](Self::shifted).
type Error;
/// Shifts `self` by `shift`, returning `Err` when the construction is
/// rejected.
fn shifted(self, shift: f64) -> Result<Self::Output, Self::Error>
where
Self: Sized;
/// Shifts `self` by `shift` without going through the checked path.
fn shifted_unchecked(self, shift: f64) -> Self::Output
where
Self: Sized;
}
/// Implements `Shiftable` for `$type`. The output is
/// `$crate::prelude::Shifted<Self>` and the error is `ShiftedError`;
/// `shifted` delegates to `Shifted::new` and `shifted_unchecked` to
/// `Shifted::new_unchecked`.
///
/// The expansion emits `use` declarations for `Shifted` and `ShiftedError`
/// into the invoking module. Invoking the macro twice in the same module
/// therefore produces duplicate imports. The trait name `Shiftable` must
/// already be in scope at the call site.
#[macro_export]
macro_rules! impl_shiftable {
($type:ty) => {
// These imports land in the caller's module; callers may rely on them.
use $crate::prelude::Shifted;
use $crate::prelude::ShiftedError;
impl Shiftable for $type {
type Output = Shifted<Self>;
type Error = ShiftedError;
fn shifted(self, shift: f64) -> Result<Self::Output, Self::Error>
where
Self: Sized,
{
Shifted::new(self, shift)
}
fn shifted_unchecked(self, shift: f64) -> Self::Output
where
Self: Sized,
{
Shifted::new_unchecked(self, shift)
}
}
};
}
/// A value that can be rescaled by a multiplicative `scale`.
/// The `impl_scalable!` macro provides a generic implementation.
pub trait Scalable {
/// Type produced by scaling.
type Output;
/// Error returned by the checked [`scaled`](Self::scaled).
type Error;
/// Scales `self` by `scale`, returning `Err` when the construction is
/// rejected.
fn scaled(self, scale: f64) -> Result<Self::Output, Self::Error>
where
Self: Sized;
/// Scales `self` by `scale` without going through the checked path.
fn scaled_unchecked(self, scale: f64) -> Self::Output
where
Self: Sized;
}
/// Implements `Scalable` for `$type`. The output is
/// `$crate::prelude::Scaled<Self>` and the error is `ScaledError`;
/// `scaled` delegates to `Scaled::new` and `scaled_unchecked` to
/// `Scaled::new_unchecked`.
///
/// The expansion emits `use` declarations for `Scaled` and `ScaledError`
/// into the invoking module. Invoking the macro twice in the same module
/// therefore produces duplicate imports. The trait name `Scalable` must
/// already be in scope at the call site.
#[macro_export]
macro_rules! impl_scalable {
($type:ty) => {
// These imports land in the caller's module; callers may rely on them.
use $crate::prelude::Scaled;
use $crate::prelude::ScaledError;
impl Scalable for $type {
type Output = Scaled<Self>;
type Error = ScaledError;
fn scaled(self, scale: f64) -> Result<Self::Output, Self::Error>
where
Self: Sized,
{
Scaled::new(self, scale)
}
fn scaled_unchecked(self, scale: f64) -> Self::Output
where
Self: Sized,
{
Scaled::new_unchecked(self, scale)
}
}
};
}
#[cfg(test)]
mod test {
    /// Generates a proptest checking that `dist.shifted_unchecked(shift)`
    /// and `Shifted::new_unchecked(dist, shift)` agree on `mean()`, both in
    /// value (within 1e-10) and in existence (`Some`/`None`).
    ///
    /// Glob-imports `proptest::prelude::*` into the invoking scope and defines
    /// a test named `shiftable_mean`.
    #[macro_export]
    macro_rules! test_shiftable_mean {
        ($expr:expr) => {
            use proptest::prelude::*;
            proptest! {
                #[test]
                fn shiftable_mean(shift in -100.0..100.0) {
                    let dist = $expr;
                    let shifted = dist.clone().shifted_unchecked(shift);
                    // Fully qualified, so the macro does not depend on
                    // `Shifted` being imported at the call site.
                    let manual = $crate::prelude::Shifted::new_unchecked(dist, shift);
                    let mean_shifted = shifted.mean();
                    let mean_manual = manual.mean();
                    match (mean_shifted, mean_manual) {
                        (Some(mean_shifted), Some(mean_manual)) => {
                            let mean_shifted: f64 = mean_shifted;
                            let mean_manual: f64 = mean_manual;
                            prop_assert!(
                                $crate::misc::eq_or_close(mean_shifted, mean_manual, 1e-10),
                                "means differ: {} vs {}",
                                mean_shifted,
                                mean_manual
                            );
                        }
                        (None, None) => {}
                        _ => {
                            prop_assert!(false, "Shifting should not affect existence of mean");
                        }
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing an `Option<f64>`-returning method
    /// `$ident` between `shifted_unchecked` and `Shifted::new_unchecked`.
    /// The optional `$ext` identifier is appended to the test name so the
    /// macro can be invoked several times in one module.
    #[macro_export]
    macro_rules! test_shiftable_method {
        ($expr:expr, $ident:ident) => {
            test_shiftable_method!($expr, $ident, );
        };
        ($expr:expr, $ident:ident, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<shiftable_ $ident $(_ $ext)?>](shift in -100.0..100.0) {
                        let dist = $expr;
                        let shifted = dist.clone().shifted_unchecked(shift).$ident();
                        let manual = $crate::prelude::Shifted::new_unchecked(dist, shift).$ident();
                        match (shifted, manual) {
                            (Some(shifted), Some(manual)) => {
                                let shifted: f64 = shifted;
                                let manual: f64 = manual;
                                proptest::prop_assert!($crate::misc::eq_or_close(shifted, manual, 1e-10),
                                    "{}s differ: {} vs {}", stringify!($ident), shifted, manual);
                            }
                            (None, None) => {},
                            _ => {
                                proptest::prop_assert!(false, "Shifting should not affect existence of {}",
                                    stringify!($ident));
                            }
                        }
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing `ln_f` at `y` between
    /// `shifted_unchecked` and `Shifted::new_unchecked`.
    #[macro_export]
    macro_rules! test_shiftable_density {
        ($expr:expr) => {
            test_shiftable_density!($expr, );
        };
        ($expr:expr, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<shiftable_density $(_ $ext)?>](y in -100.0..100.0, shift in -100.0..100.0) {
                        let dist = $expr;
                        let shifted: f64 = dist.clone().shifted_unchecked(shift).ln_f(&y);
                        let manual: f64 = $crate::prelude::Shifted::new_unchecked(dist, shift).ln_f(&y);
                        proptest::prop_assert!($crate::misc::eq_or_close(shifted, manual, 1e-10),
                            "densities differ: {} vs {}", shifted, manual);
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing `cdf` at `x` between
    /// `shifted_unchecked` and `Shifted::new_unchecked`.
    #[macro_export]
    macro_rules! test_shiftable_cdf {
        ($expr:expr) => {
            test_shiftable_cdf!($expr, );
        };
        ($expr:expr, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<shiftable_cdf $(_ $ext)?>](x in -100.0..100.0, shift in -100.0..100.0) {
                        let dist = $expr;
                        let shifted: f64 = dist.clone().shifted_unchecked(shift).cdf(&x);
                        let manual: f64 = $crate::prelude::Shifted::new_unchecked(dist, shift).cdf(&x);
                        proptest::prop_assert!($crate::misc::eq_or_close(shifted, manual, 1e-10),
                            "cdfs differ: {} vs {}", shifted, manual);
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing `invcdf` at `p` in `[0, 1)` between
    /// `shifted_unchecked` and `Shifted::new_unchecked`.
    #[macro_export]
    macro_rules! test_shiftable_invcdf {
        ($expr:expr) => {
            test_shiftable_invcdf!($expr, );
        };
        ($expr:expr, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<shiftable_invcdf $(_ $ext)?>](p in 0.0..1.0, shift in -100.0..100.0) {
                        let dist = $expr;
                        let shifted: f64 = dist.clone().shifted_unchecked(shift).invcdf(p);
                        let manual: f64 = $crate::prelude::Shifted::new_unchecked(dist, shift).invcdf(p);
                        proptest::prop_assert!($crate::misc::eq_or_close(shifted, manual, 1e-10),
                            "invcdfs differ: {} vs {}", shifted, manual);
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing `entropy` between `shifted_unchecked`
    /// and `Shifted::new_unchecked`.
    #[macro_export]
    macro_rules! test_shiftable_entropy {
        ($expr:expr) => {
            test_shiftable_entropy!($expr, );
        };
        ($expr:expr, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<shiftable_entropy $(_ $ext)?>](shift in -100.0..100.0) {
                        let dist = $expr;
                        let shifted: f64 = dist.clone().shifted_unchecked(shift).entropy();
                        let manual: f64 = $crate::prelude::Shifted::new_unchecked(dist, shift).entropy();
                        proptest::prop_assert!($crate::misc::eq_or_close(shifted, manual, 1e-10),
                            "entropies differ: {} vs {}", shifted, manual);
                    }
                }
            }
        };
    }

    /// Generates a proptest checking that `dist.scaled_unchecked(scale)` and
    /// `Scaled::new_unchecked(dist, scale)` agree on `mean()`, both in value
    /// (within 1e-10) and in existence (`Some`/`None`).
    ///
    /// The unchecked constructors are used deliberately: the checked
    /// `scaled` / `Scaled::new` return a `Result`, which has no `mean()`
    /// method. Scales are drawn from the same strictly positive range as the
    /// other scalable tests.
    ///
    /// Glob-imports `proptest::prelude::*` into the invoking scope and defines
    /// a test named `scalable_mean`.
    #[macro_export]
    macro_rules! test_scalable_mean {
        ($expr:expr) => {
            use proptest::prelude::*;
            proptest! {
                #[test]
                fn scalable_mean(scale in 1e-10..100.0) {
                    let dist = $expr;
                    let scaled = dist.clone().scaled_unchecked(scale);
                    let manual = $crate::prelude::Scaled::new_unchecked(dist, scale);
                    let mean_scaled = scaled.mean();
                    let mean_manual = manual.mean();
                    match (mean_scaled, mean_manual) {
                        (Some(mean_scaled), Some(mean_manual)) => {
                            let mean_scaled: f64 = mean_scaled;
                            let mean_manual: f64 = mean_manual;
                            prop_assert!(
                                $crate::misc::eq_or_close(mean_scaled, mean_manual, 1e-10),
                                "means differ: {} vs {}",
                                mean_scaled,
                                mean_manual
                            );
                        }
                        (None, None) => {}
                        _ => {
                            prop_assert!(false, "Scaling should not affect existence of mean");
                        }
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing an `Option<f64>`-returning method
    /// `$ident` between `scaled_unchecked` and `Scaled::new_unchecked`.
    /// The optional `$ext` identifier is appended to the test name.
    #[macro_export]
    macro_rules! test_scalable_method {
        ($expr:expr, $ident:ident) => {
            test_scalable_method!($expr, $ident, );
        };
        ($expr:expr, $ident:ident, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<scalable_ $ident $(_ $ext)?>](scale in 1e-10..100.0) {
                        let dist = $expr;
                        let scaled = dist.clone().scaled_unchecked(scale).$ident();
                        let manual = $crate::prelude::Scaled::new_unchecked(dist, scale).$ident();
                        match (scaled, manual) {
                            (Some(scaled), Some(manual)) => {
                                let scaled: f64 = scaled;
                                let manual: f64 = manual;
                                proptest::prop_assert!($crate::misc::eq_or_close(scaled, manual, 1e-10),
                                    "{}s differ: {} vs {}", stringify!($ident), scaled, manual);
                            }
                            (None, None) => {},
                            _ => {
                                proptest::prop_assert!(false, "Scaling should not affect existence of {}",
                                    stringify!($ident));
                            }
                        }
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing `ln_f` at `y` between
    /// `scaled_unchecked` and `Scaled::new_unchecked`.
    #[macro_export]
    macro_rules! test_scalable_density {
        ($expr:expr) => {
            test_scalable_density!($expr, );
        };
        ($expr:expr, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<scalable_density $(_ $ext)?>](y in -100.0..100.0, scale in 1e-10..100.0) {
                        let dist = $expr;
                        let scaled: f64 = dist.clone().scaled_unchecked(scale).ln_f(&y);
                        let manual: f64 = $crate::prelude::Scaled::new_unchecked(dist, scale).ln_f(&y);
                        proptest::prop_assert!($crate::misc::eq_or_close(scaled, manual, 1e-10),
                            "densities differ: {} vs {}", scaled, manual);
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing `cdf` at `x` between
    /// `scaled_unchecked` and `Scaled::new_unchecked`.
    #[macro_export]
    macro_rules! test_scalable_cdf {
        ($expr:expr) => {
            test_scalable_cdf!($expr, );
        };
        ($expr:expr, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<scalable_cdf $(_ $ext)?>](x in -100.0..100.0, scale in 1e-10..100.0) {
                        let dist = $expr;
                        let scaled: f64 = dist.clone().scaled_unchecked(scale).cdf(&x);
                        let manual: f64 = $crate::prelude::Scaled::new_unchecked(dist, scale).cdf(&x);
                        proptest::prop_assert!($crate::misc::eq_or_close(scaled, manual, 1e-10),
                            "cdfs differ: {} vs {}", scaled, manual);
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing `invcdf` at `p` in `[0, 1)` between
    /// `scaled_unchecked` and `Scaled::new_unchecked`.
    #[macro_export]
    macro_rules! test_scalable_invcdf {
        ($expr:expr) => {
            test_scalable_invcdf!($expr, );
        };
        ($expr:expr, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<scalable_invcdf $(_ $ext)?>](p in 0.0..1.0, scale in 1e-10..100.0) {
                        let dist = $expr;
                        let scaled: f64 = dist.clone().scaled_unchecked(scale).invcdf(p);
                        let manual: f64 = $crate::prelude::Scaled::new_unchecked(dist, scale).invcdf(p);
                        proptest::prop_assert!($crate::misc::eq_or_close(scaled, manual, 1e-10),
                            "invcdfs differ: {} vs {}", scaled, manual);
                    }
                }
            }
        };
    }

    /// Generates a proptest comparing `entropy` between `scaled_unchecked`
    /// and `Scaled::new_unchecked`.
    #[macro_export]
    macro_rules! test_scalable_entropy {
        ($expr:expr) => {
            test_scalable_entropy!($expr, );
        };
        ($expr:expr, $($ext:ident)?) => {
            paste::paste! {
                proptest::proptest! {
                    #[test]
                    fn [<scalable_entropy $(_ $ext)?>](scale in 1e-10..100.0) {
                        let dist = $expr;
                        let scaled: f64 = dist.clone().scaled_unchecked(scale).entropy();
                        let manual: f64 = $crate::prelude::Scaled::new_unchecked(dist, scale).entropy();
                        proptest::prop_assert!($crate::misc::eq_or_close(scaled, manual, 1e-10),
                            "entropies differ: {} vs {}", scaled, manual);
                    }
                }
            }
        };
    }
}