#[cfg(test)]
use approx::RelativeEq;
use crate::traits::Rv;
use std::collections::BTreeMap;
/// Generates a `#[test]` named `should_impl_debug_clone_and_partialeq` that
/// exercises `Clone`, `PartialEq`, `Debug`, and `ln_f` on the expression `$fx`.
///
/// The bracketed forms (`[continuous]`, `[categorical]`, `[count]`,
/// `[binary]`) forward to the two-argument form with a fixed sample value of
/// the matching type; the two-argument form lets the caller pass the value
/// that `ln_f` is evaluated at.
///
/// Because the generated function has a fixed name, the macro can be invoked
/// at most once per module.
#[macro_export]
macro_rules! test_basic_impls {
    // Shorthand arms: each one only chooses the value handed to `ln_f`.
    ([continuous] $fx: expr) => {
        test_basic_impls!($fx, 0.5_f64);
    };
    ([categorical] $fx: expr) => {
        test_basic_impls!($fx, 0_usize);
    };
    ([count] $fx: expr) => {
        test_basic_impls!($fx, 3_u32);
    };
    ([binary] $fx: expr) => {
        test_basic_impls!($fx, true);
    };
    // General arm: `$fx` is the value under test, `$x` the point passed to `ln_f`.
    ($fx: expr, $x: expr) => {
        #[test]
        fn should_impl_debug_clone_and_partialeq() {
            let dist = $fx;
            let dist_clone = dist.clone();
            // `$fx` is expanded, and therefore evaluated, a second time here,
            // so a freshly built value is compared against the clone.
            assert_eq!($fx, dist_clone);

            // Two evaluations at the same point must agree.
            let y1 = dist.ln_f(&$x);
            let y2 = dist.ln_f(&$x);
            assert!((y1 - y2).abs() < std::f64::EPSILON);

            // The clone still equals the value it was cloned from.
            assert_eq!(dist_clone, dist);

            // Only checks that `Debug` formatting is available and does not panic.
            let _debug_repr = format!("{:?}", dist);
        }
    };
}
/// Element-wise relative-equality check over two exact-size sequences.
///
/// Returns `false` if the two sequences have different lengths. Otherwise
/// returns `true` only when every pair of corresponding elements satisfies
/// `RelativeEq::relative_eq` with the given `epsilon` and `max_relative`.
/// Two empty sequences compare equal.
#[cfg(test)]
#[allow(dead_code)]
pub fn relative_eq<T, I>(
    left: I,
    right: I,
    epsilon: T::Epsilon,
    max_relative: T::Epsilon,
) -> bool
where
    T: RelativeEq,
    T::Epsilon: Copy,
    I: IntoIterator<Item = T>,
    <I as IntoIterator>::IntoIter: ExactSizeIterator,
{
    let lhs = left.into_iter();
    let rhs = right.into_iter();

    // Short-circuit on a length mismatch so `zip` cannot silently ignore a
    // longer tail on either side.
    lhs.len() == rhs.len()
        && lhs
            .zip(rhs)
            .all(|(l, r)| l.relative_eq(&r, epsilon, max_relative))
}
/// Hooks a prior type must provide so that `GewekeTester` can drive it.
///
/// `Fx` is the type returned by the draw/update methods and `X` is the element
/// type of the data vectors passed alongside it.
pub trait GewekeTestable<Fx, X> {
/// Produces an `Fx` using only `rng`. `GewekeTester` calls this at every
/// prior-chain step and once to seed the posterior chain.
fn prior_draw<R: rand::Rng>(&self, rng: &mut R) -> Fx;
/// Produces a new `Fx` given the current `data`. `GewekeTester` calls this
/// once per thinning step of the posterior chain.
fn update_params<R: rand::Rng>(&self, data: &Vec<X>, rng: &mut R) -> Fx;
/// Returns named scalar statistics computed from `fx` and `xs`.
/// `GewekeTester` keys its per-chain histories by these names.
fn geweke_stats(&self, fx: &Fx, xs: &Vec<X>) -> BTreeMap<String, f64>;
}
/// State for running a prior chain and a posterior chain over the prior `Pr`
/// and comparing the statistics each chain records.
///
/// `Pr` supplies the draw/update/statistics hooks via `GewekeTestable`, `Fx`
/// is the type it produces, and `X` is the data type sampled from an `Fx`.
pub struct GewekeTester<Pr, Fx, X>
where
Pr: GewekeTestable<Fx, X>,
Pr: Rv<Fx>,
Fx: Rv<X>,
{
/// The prior under test.
pub pr: Pr,
/// Number of data points requested from `Fx::sample` at each step.
pub nx: usize,
/// Initialised empty by `new`; the chain runners in this file keep their
/// data in local variables and do not write to this field.
pub xs: Vec<X>,
/// Statistic name -> values recorded by `run_prior_chain`, one per iteration.
pub prior_chain_stats: BTreeMap<String, Vec<f64>>,
/// Statistic name -> values recorded by `run_posterior_chain`, one per
/// iteration.
pub posterior_chain_stats: BTreeMap<String, Vec<f64>>,
// Zero-sized marker so the otherwise-unused type parameter `Fx` is accepted.
_phantom: std::marker::PhantomData<Fx>,
}
/// Appends one observation of every statistic in `src` to its series in `sink`.
///
/// If `sink` is empty, an empty series with capacity `n` is first created for
/// each key of `src`.
///
/// # Panics
///
/// Panics with `"failed to push"` if `sink` is non-empty and lacks a key that
/// is present in `src`.
fn append_stats(
    n: usize,
    src: &BTreeMap<String, f64>,
    sink: &mut BTreeMap<String, Vec<f64>>,
) {
    // First call: create one pre-sized series per statistic name.
    if sink.is_empty() {
        sink.extend(
            src.keys()
                .map(|name| (name.clone(), Vec::with_capacity(n))),
        );
    }

    for (name, &value) in src {
        sink.get_mut(name).expect("failed to push").push(value);
    }
}
impl<Pr, Fx, X> GewekeTester<Pr, Fx, X>
where
    Pr: GewekeTestable<Fx, X>,
    Pr: Rv<Fx>,
    Fx: Rv<X>,
{
    /// Creates a tester for the prior `pr` that samples `nx` data points per
    /// step. Both chain histories, and `xs`, start out empty.
    pub fn new(pr: Pr, nx: usize) -> Self {
        Self {
            pr,
            nx,
            xs: Vec::new(),
            prior_chain_stats: BTreeMap::new(),
            posterior_chain_stats: BTreeMap::new(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Checks every per-statistic error from [`errs`](Self::errs) against
    /// `max_err`.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first statistic (in the order produced by
    /// `errs`) whose error is strictly greater than `max_err`.
    pub fn eval(&self, max_err: f64) -> Result<(), String> {
        for (name, err) in self.errs() {
            if err > max_err {
                return Err(format!(
                    "P-P Error {} ({}) exceeds max ({})",
                    name, err, max_err
                ));
            }
        }
        Ok(())
    }

    /// For each statistic recorded by the prior chain, builds an `Empirical`
    /// from the prior-chain values and one from the posterior-chain values
    /// and returns `(name, prior.err(&posterior))`.
    ///
    /// NOTE(review): `Empirical::err` is presumably a P-P discrepancy, judging
    /// by the message produced in `eval` — confirm against its definition.
    ///
    /// # Panics
    ///
    /// Panics if the posterior chain has no entry for a statistic recorded by
    /// the prior chain, since the lookup uses map indexing.
    pub fn errs(&self) -> Vec<(String, f64)> {
        use crate::dist::Empirical;

        self.prior_chain_stats
            .iter()
            .map(|(stat_name, prior_stats)| {
                let post_stats = &self.posterior_chain_stats[stat_name];
                // `Empirical::new` takes ownership, so the recorded values are
                // cloned to leave the histories intact.
                let emp_prior = Empirical::new(prior_stats.clone());
                let emp_post = Empirical::new(post_stats.clone());
                (stat_name.clone(), emp_prior.err(&emp_post))
            })
            .collect()
    }

    /// Runs the prior chain for `n` iterations, then the posterior chain for
    /// `n` recorded iterations with `thinning` updates between recordings.
    pub fn run_chains<R: rand::Rng>(
        &mut self,
        n: usize,
        thinning: usize,
        rng: &mut R,
    ) {
        self.run_prior_chain(n, rng);
        self.run_posterior_chain(n, thinning, rng);
    }

    /// Records `n` independent steps: each step draws a fresh `Fx` from the
    /// prior, samples `nx` data points from it, and appends the resulting
    /// statistics to `prior_chain_stats`.
    pub fn run_prior_chain<R: rand::Rng>(&mut self, n: usize, rng: &mut R) {
        for _ in 0..n {
            let fx = self.pr.prior_draw(rng);
            let xs: Vec<X> = fx.sample(self.nx, rng);
            let stats = self.pr.geweke_stats(&fx, &xs);
            append_stats(n, &stats, &mut self.prior_chain_stats);
        }
    }

    /// Records `n` steps of a chain that alternates between updating the
    /// parameters from the current data and resampling `nx` data points from
    /// the updated parameters.
    ///
    /// The chain starts from one prior draw and one sample from it. Each
    /// recorded step performs `thinning` update/resample rounds first; with
    /// `thinning == 0` the state never changes and the initial state is
    /// recorded `n` times.
    pub fn run_posterior_chain<R: rand::Rng>(
        &mut self,
        n: usize,
        thinning: usize,
        rng: &mut R,
    ) {
        let mut fx = self.pr.prior_draw(rng);
        let mut xs = fx.sample(self.nx, rng);

        for _ in 0..n {
            for _ in 0..thinning {
                fx = self.pr.update_params(&xs, rng);
                xs = fx.sample(self.nx, rng);
            }
            let stats = self.pr.geweke_stats(&fx, &xs);
            append_stats(n, &stats, &mut self.posterior_chain_stats);
        }
    }
}
/// Implements `GewekeTestable<Gaussian, f64>` for the prior type `$prior`.
///
/// `$fx` is used only as the second type argument of `ConjugatePrior` when
/// selecting which `posterior` to call. The expansion refers to `Gaussian`,
/// `GewekeTestable`, `ConjugatePrior`, `DataOrSuffStat`, `BTreeMap`, and
/// `rand::Rng` by bare name, so all of them must be in scope at the call site.
#[macro_export]
macro_rules! gaussian_prior_geweke_testable {
    ($prior: ty, $fx: ty) => {
        impl GewekeTestable<Gaussian, f64> for $prior {
            fn prior_draw<R: rand::Rng>(&self, rng: &mut R) -> Gaussian {
                self.draw(rng)
            }

            // Builds the posterior for `data` and draws a `Gaussian` from it.
            fn update_params<R: rand::Rng>(
                &self,
                data: &Vec<f64>,
                rng: &mut R,
            ) -> Gaussian {
                let posterior = <$prior as ConjugatePrior<f64, $fx>>::posterior(
                    self,
                    &DataOrSuffStat::from(data),
                );
                posterior.draw(rng)
            }

            // Reports the parameters of `fx` plus the mean of `xs` and the mean
            // squared deviation of `xs` from that mean (divisor is `xs.len()`).
            // Both data statistics are NaN when `xs` is empty.
            fn geweke_stats(
                &self,
                fx: &Gaussian,
                xs: &Vec<f64>,
            ) -> BTreeMap<String, f64> {
                let mu = fx.mu();
                let sigma = fx.sigma();

                let count = xs.len() as f64;
                let mean = xs.iter().sum::<f64>() / count;
                let mean_sq_dev =
                    xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / count;

                vec![
                    (String::from("mu"), mu),
                    (String::from("sigma"), sigma),
                    (String::from("x_mean"), mean),
                    (String::from("x_mse"), mean_sq_dev),
                ]
                .into_iter()
                .collect()
            }
        }
    };
}