//! A conjugate model: a conjugate prior paired with the sufficient statistic of the observed data.
use crate::traits::*;
use rand::Rng;
use std::marker::PhantomData;
use std::sync::Arc;
/// A complete conjugate model: a prior bundled with the sufficient statistic
/// of the likelihood it is conjugate to
///
/// # Type parameters
///
/// - `X`: The type of the data/observations to be modeled
/// - `Fx`: The type of the likelihood, *f(x|θ)*
/// - `Pr`: The type of the prior on the parameters of `Fx`, π(θ)
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct ConjugateModel<X, Fx, Pr>
where
    Fx: Rv<X> + HasSuffStat<X>,
    Pr: ConjugatePrior<X, Fx>,
{
    /// Reference-counted pointer to the prior, an `Rv` implementing
    /// `ConjugatePrior` for `Fx`
    prior: Arc<Pr>,
    /// The `SuffStat` of `Fx`; updated through this type's `SuffStat` impl
    suffstat: Fx::Stat,
    /// Zero-sized marker for `X`, which appears in the bounds but in no
    /// stored field
    _phantom: PhantomData<X>,
}
impl<X, Fx, Pr> ConjugateModel<X, Fx, Pr>
where
    Fx: Rv<X> + HasSuffStat<X>,
    Pr: ConjugatePrior<X, Fx>,
{
    /// Build a conjugate model whose sufficient statistic starts out empty
    ///
    /// # Arguments
    ///
    /// - `fx`: A likelihood instance. It is only asked for an empty
    ///   sufficient statistic; the instance itself is not stored.
    /// - `pr`: Reference-counted pointer to the prior on the parameters of
    ///   `Fx`
    ///
    /// # Example
    ///
    /// ```
    /// use std::sync::Arc;
    /// use rv::prelude::*;
    /// use rv::ConjugateModel;
    ///
    /// let pr = Arc::new(Beta::jeffreys());
    /// let fx = Bernoulli::uniform();
    /// let model = ConjugateModel::<bool, Bernoulli, Beta>::new(&fx, pr);
    /// ```
    pub fn new(fx: &Fx, pr: Arc<Pr>) -> Self {
        let suffstat = fx.empty_suffstat();
        Self {
            prior: pr,
            suffstat,
            _phantom: PhantomData,
        }
    }

    /// Log marginal likelihood, *f(obs)*, of the data summarized by the
    /// stored sufficient statistic
    pub fn ln_m(&self) -> f64 {
        let observed = self.obs();
        self.prior.ln_m(&observed)
    }

    /// Log posterior predictive, *f(y|obs)*, of `y` given the data
    /// summarized by the stored sufficient statistic
    pub fn ln_pp(&self, y: &X) -> f64 {
        let observed = self.obs();
        self.prior.ln_pp(y, &observed)
    }

    /// Return the posterior distribution given the stored sufficient
    /// statistic
    ///
    /// # Example
    ///
    /// ```
    /// use std::sync::Arc;
    /// use rv::prelude::*;
    /// use rv::ConjugateModel;
    ///
    /// let flips: Vec<bool> = vec![true, false, true, false, false, false];
    ///
    /// let pr = Arc::new(Beta::new(1.0, 1.0).unwrap());
    /// let fx = Bernoulli::uniform();
    /// let mut model = ConjugateModel::<bool, Bernoulli, Beta>::new(&fx, pr);
    ///
    /// model.observe_many(&flips);
    ///
    /// let post = model.posterior();
    ///
    /// assert_eq!(post, Beta::new(3.0, 5.0).unwrap());
    /// ```
    pub fn posterior(&self) -> Pr::Posterior {
        let observed = self.obs();
        self.prior.posterior(&observed)
    }

    /// Return the observations as a `DataOrSuffStat::SuffStat` borrowing the
    /// stored sufficient statistic
    fn obs(&self) -> DataOrSuffStat<X, Fx> {
        let stat = &self.suffstat;
        DataOrSuffStat::SuffStat(stat)
    }
}
/// `SuffStat` for the model: every method is forwarded to the wrapped
/// sufficient statistic of `Fx`; the prior is never touched.
impl<X, Fx, Pr> SuffStat<X> for ConjugateModel<X, Fx, Pr>
where
    Fx: Rv<X> + HasSuffStat<X>,
    Pr: ConjugatePrior<X, Fx>,
{
    /// Forward to the wrapped statistic's `n`
    fn n(&self) -> usize {
        let stat = &self.suffstat;
        stat.n()
    }

    /// Forward `x` to the wrapped statistic's `observe`
    fn observe(&mut self, x: &X) {
        let stat = &mut self.suffstat;
        stat.observe(x);
    }

    /// Forward `x` to the wrapped statistic's `forget`
    fn forget(&mut self, x: &X) {
        let stat = &mut self.suffstat;
        stat.forget(x);
    }
}
/// The density of the model is its posterior predictive.
impl<X, Fx, Pr> HasDensity<X> for ConjugateModel<X, Fx, Pr>
where
    Fx: Rv<X> + HasSuffStat<X>,
    Pr: ConjugatePrior<X, Fx>,
{
    /// Log density of `x`: the prior's log posterior predictive given the
    /// stored sufficient statistic
    fn ln_f(&self, x: &X) -> f64 {
        // The inherent `ln_pp` evaluates `prior.ln_pp(x, &obs)`.
        self.ln_pp(x)
    }
}
/// Sampling from the model: a likelihood is first drawn from the posterior,
/// then a datum is drawn from that likelihood.
impl<X, Fx, Pr> Sampleable<X> for ConjugateModel<X, Fx, Pr>
where
    Fx: Rv<X> + HasSuffStat<X>,
    Pr: ConjugatePrior<X, Fx>,
{
    /// Draw a single value
    ///
    /// Computes the posterior, draws an `Fx` from it, then draws one `X`
    /// from that `Fx`.
    fn draw<R: Rng>(&self, rng: &mut R) -> X {
        let post = self.posterior();
        // `rng` is implicitly reborrowed at each call, so there is no need
        // to bind it mutably and pass `&mut &mut R`.
        let fx: Fx = post.draw(rng);
        fx.draw(rng)
    }

    /// Draw `n` values
    ///
    /// The posterior is computed once; each returned value comes from a
    /// freshly drawn `Fx`.
    fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<X> {
        let post = self.posterior();
        (0..n)
            .map(|_| {
                let fx: Fx = post.draw(rng);
                fx.draw(rng)
            })
            .collect()
    }
}