use rand::Rng;
use rv::dist::Gamma;
use rv::dist::Gaussian;
use rv::dist::InvGamma;
use rv::dist::NormalInvChiSquared;
use rv::traits::*;
use serde::Deserialize;
use serde::Serialize;
use crate::stats::UpdatePrior;
use crate::utils::mean_var;
/// Fixed NIX prior with `m = 0`, `k = 1`, `v = 10`, `s2 = 1`.
///
/// NOTE(review): judging by the name this is the prior used for Geweke
/// tests — confirm against callers.
pub fn geweke() -> NormalInvChiSquared {
    let (m, k, v, s2) = (0.0, 1.0, 10.0, 1.0);
    NormalInvChiSquared::new_unchecked(m, k, v, s2)
}
/// Build a NIX prior centred on `xs`.
///
/// `m` is the mean of `xs` and `s2` is the variance, both as returned by
/// `mean_var`; `k` and `v` are fixed at `1.0`. The prior is built with
/// `new_unchecked`, so no parameter validation happens here.
pub fn from_data(xs: &[f64]) -> NormalInvChiSquared {
    let (mean, var) = mean_var(xs);
    let (k, v) = (1.0, 1.0);
    NormalInvChiSquared::new_unchecked(mean, k, v, var)
}
/// Sample a NIX prior from the hyperprior `hyper` using `rng`.
///
/// Thin wrapper over `NixHyper::draw`.
pub fn from_hyper(
    hyper: NixHyper,
    rng: &mut impl Rng,
) -> NormalInvChiSquared {
    // `rng` is already a `&mut`, so it can be handed on directly; the
    // random stream consumed is the same as borrowing it again.
    hyper.draw(rng)
}
impl UpdatePrior<f64, Gaussian, NixHyper> for NormalInvChiSquared {
    /// Resample this NIX prior given the current Gaussian `components`.
    ///
    /// Runs `mh_prior` starting from the current value. The target
    /// log-likelihood is the sum of `ln_f` over `components`, and the sampler
    /// is given a closure that draws fresh NIX values from `hyper`. The
    /// sampler's `x` result replaces `self`.
    ///
    /// NOTE(review): presumably `200` is the number of MH steps and the
    /// `hyper` draws act as independence proposals — confirm against
    /// `crate::stats::mh::mh_prior`.
    ///
    /// Returns the sum of the log densities of the updated `m`, `k`, `v` and
    /// `s2` under the corresponding priors in `hyper`.
    fn update_prior<R: Rng>(
        &mut self,
        components: &[&Gaussian],
        hyper: &NixHyper,
        rng: &mut R,
    ) -> f64 {
        use crate::stats::mh::mh_prior;

        let n_steps = 200;

        // Log-likelihood of all components under a candidate NIX prior.
        let ln_lik = |candidate: &NormalInvChiSquared| {
            components
                .iter()
                .map(|gauss| candidate.ln_f(gauss))
                .sum::<f64>()
        };

        *self = mh_prior(
            self.clone(),
            ln_lik,
            |mut r| hyper.draw(&mut r),
            n_steps,
            rng,
        )
        .x;

        // Summation order (m, k, v, s2) matches the original exactly.
        let (m, k, v, s2) = (self.m(), self.k(), self.v(), self.s2());
        hyper.pr_m.ln_f(&m)
            + hyper.pr_k.ln_f(&k)
            + hyper.pr_v.ln_f(&v)
            + hyper.pr_s2.ln_f(&s2)
    }
}
/// Hyperprior over the four parameters of a [`NormalInvChiSquared`] prior.
///
/// Each field is a separate distribution over one NIX parameter. Sampling a
/// NIX prior from this hyperprior draws each parameter from its own field,
/// in the order `m`, `k`, `v`, `s2`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct NixHyper {
    /// Prior on the NIX `m` parameter.
    pub pr_m: Gaussian,
    /// Prior on the NIX `k` parameter.
    pub pr_k: Gamma,
    /// Prior on the NIX `v` parameter.
    pub pr_v: InvGamma,
    /// Prior on the NIX `s2` parameter.
    pub pr_s2: InvGamma,
}
impl Default for NixHyper {
    /// Default hyperprior.
    ///
    /// - `m`: `Gaussian(0.0, 1.0)`
    /// - `k`: `Gamma(2.0, 1.0)`
    /// - `v` and `s2`: `InvGamma(2.0, 2.0)`
    ///
    /// Built with the `new_unchecked` constructors; all of the constants are
    /// positive.
    fn default() -> Self {
        let pr_m = Gaussian::new_unchecked(0.0, 1.0);
        let pr_k = Gamma::new_unchecked(2.0, 1.0);
        let pr_v = InvGamma::new_unchecked(2.0, 2.0);
        let pr_s2 = InvGamma::new_unchecked(2.0, 2.0);
        Self {
            pr_m,
            pr_k,
            pr_v,
            pr_s2,
        }
    }
}
impl NixHyper {
    /// Assemble a hyperprior from its four component priors.
    pub fn new(
        pr_m: Gaussian,
        pr_k: Gamma,
        pr_v: InvGamma,
        pr_s2: InvGamma,
    ) -> Self {
        Self {
            pr_m,
            pr_k,
            pr_v,
            pr_s2,
        }
    }

    /// Fixed, fairly concentrated hyperprior.
    ///
    /// NOTE(review): judging by the name this is intended for Geweke tests —
    /// confirm against callers.
    pub fn geweke() -> Self {
        NixHyper {
            pr_m: Gaussian::new(0.0, 0.1).expect("constant Gaussian params are valid"),
            pr_k: Gamma::new(40.0, 40.0).expect("constant Gamma params are valid"),
            pr_v: InvGamma::new(21.0, 120.0).expect("constant InvGamma params are valid"),
            pr_s2: InvGamma::new(40.0, 40.0).expect("constant InvGamma params are valid"),
        }
    }

    /// Build a hyperprior scaled to the observed values `xs`.
    ///
    /// With `(m, v) = mean_var(xs)`, `s = sqrt(v)` and `n = xs.len()`:
    /// - `pr_m` is `Gaussian(m, s)`
    /// - `pr_k` is `Gamma(1.0, 1.0)`
    /// - `pr_v` is `InvGamma(ln n, ln n)`
    /// - `pr_s2` is `InvGamma(ln n, v)`
    ///
    /// # Panics
    ///
    /// - Panics if `xs` has fewer than two elements, because `ln n` must be a
    ///   positive shape.
    /// - Panics if the mean or variance from `mean_var` cannot parameterize
    ///   the priors above, e.g. when all values are equal (zero variance) or
    ///   when any value is non-finite.
    pub fn from_data(xs: &[f64]) -> Self {
        // Fail early with a clear message. Such inputs always panicked,
        // previously via an opaque `unwrap`, because ln(0) and ln(1) are not
        // valid InvGamma shapes.
        assert!(
            xs.len() >= 2,
            "NixHyper::from_data requires at least two data points, got {}",
            xs.len()
        );

        let (m, v) = mean_var(xs);
        let s = v.sqrt();
        let logn = (xs.len() as f64).ln();
        NixHyper {
            pr_m: Gaussian::new(m, s)
                .expect("data mean must be finite and its standard deviation positive"),
            pr_k: Gamma::new(1.0, 1.0).expect("constant Gamma params are valid"),
            pr_v: InvGamma::new(logn, logn).expect("ln(n) is positive for n >= 2"),
            pr_s2: InvGamma::new(logn, v)
                .expect("data variance must be finite and positive"),
        }
    }

    /// Sample a NIX prior by drawing each parameter from its hyperprior.
    ///
    /// Parameters are drawn in the order `m`, `k`, `v`, `s2`, and the result
    /// is built with `new_unchecked`.
    pub fn draw(&self, rng: &mut impl Rng) -> NormalInvChiSquared {
        // `rng` is passed through directly (implicit reborrow) rather than as
        // `&mut &mut _`; the random stream consumed is identical.
        let m: f64 = self.pr_m.draw(rng);
        let k: f64 = self.pr_k.draw(rng);
        let v: f64 = self.pr_v.draw(rng);
        let s2: f64 = self.pr_s2.draw(rng);
        NormalInvChiSquared::new_unchecked(m, k, v, s2)
    }
}