use super::*;
use rand_distr;
use rand;
use crate::sim::*;
use std::default::Default;
/// Factor a [`Bernoulli`] can be conditioned on.
///
/// Within this file the alias is matched against three variants:
/// `Conjugate(Beta)`, `CondExpect(MultiNormal)` and `Empty`.
pub type BernoulliFactor = UnivariateFactor<Beta>;
/// Vector of Bernoulli variables, stored both on the probability scale
/// (`theta`) and on the natural/logit scale (`eta`).
#[derive(Debug)]
pub struct Bernoulli {
// Probability-scale parameter; `set_parameter` recomputes it as
// `link_inverse(eta)` and `mean()` returns it.
theta : DVector<f64>,
// Natural parameter (the logit of `theta`, see `link`).
eta : DVector<f64>,
// Factor this distribution is conditioned on (`Empty` when there is none).
factor : BernoulliFactor,
// Trajectory of natural-parameter values used by the `RandomWalk` impl;
// `None` until the first `step_to` call.
eta_traj : Option<EtaTrajectory>,
// One `rand_distr::Bernoulli` per entry of `theta`, rebuilt by `set_parameter`.
sampler : Vec<rand_distr::Bernoulli>,
// Cached log-partition values `log(1 + exp(eta))`, written by
// `update_log_partition`.
log_part : DVector<f64>,
// `Beta::sufficient_stat(theta)`; set to `Some` when conditioned on a `Beta`
// factor and reset to `None` when conditioned on a `MultiNormal`.
suf_theta : Option<DMatrix<f64>>
}
impl Bernoulli {

    /// Builds a Bernoulli with `n` entries that all share the same
    /// probability `theta` (0.5 when `None` is given).
    ///
    /// The value is interpreted on the probability scale: it is forwarded to
    /// `set_parameter` with `natural = false`.
    pub fn new(n : usize, theta : Option<f64>) -> Self {
        let prob = theta.unwrap_or(0.5);
        let theta_vec = DVector::from_element(n, prob);
        let mut distr = Self::default();
        distr.set_parameter(theta_vec.rows(0, theta_vec.nrows()), false);
        distr
    }

}
/// Conditioning of a Bernoulli on a conjugate `Beta` factor.
impl Conditional<Beta> for Bernoulli {

    /// Attaches `b` as the conjugate factor and caches the Beta sufficient
    /// statistic of the current `theta`.
    fn condition(mut self, b : Beta) -> Bernoulli {
        let n_rows = self.theta.nrows();
        let stat = Beta::sufficient_stat(self.theta.slice((0, 0), (n_rows, 1)));
        self.suf_theta = Some(stat);
        self.factor = BernoulliFactor::Conjugate(b);
        self
    }

    /// Borrows the `Beta` factor, or `None` when the factor is of another kind.
    fn view_factor(&self) -> Option<&Beta> {
        if let BernoulliFactor::Conjugate(beta) = &self.factor {
            Some(beta)
        } else {
            None
        }
    }

    /// Consumes the distribution and returns its `Beta` factor, if any.
    fn take_factor(self) -> Option<Beta> {
        if let BernoulliFactor::Conjugate(beta) = self.factor {
            Some(beta)
        } else {
            None
        }
    }

    /// Mutably borrows the `Beta` factor, or `None` when the factor is of
    /// another kind.
    fn factor_mut(&mut self) -> Option<&mut Beta> {
        if let BernoulliFactor::Conjugate(beta) = &mut self.factor {
            Some(beta)
        } else {
            None
        }
    }

}
/// Conditioning of a Bernoulli on a `MultiNormal` factor.
impl Conditional<MultiNormal> for Bernoulli {

    /// Attaches `m` as the `CondExpect` factor and drops any cached Beta
    /// sufficient statistic.
    fn condition(mut self, m : MultiNormal) -> Bernoulli {
        self.suf_theta = None;
        self.factor = BernoulliFactor::CondExpect(m);
        self
    }

    /// Borrows the `MultiNormal` factor, or `None` when the factor is of
    /// another kind.
    fn view_factor(&self) -> Option<&MultiNormal> {
        if let BernoulliFactor::CondExpect(normal) = &self.factor {
            Some(normal)
        } else {
            None
        }
    }

    /// Consumes the distribution and returns its `MultiNormal` factor, if any.
    fn take_factor(self) -> Option<MultiNormal> {
        if let BernoulliFactor::CondExpect(normal) = self.factor {
            Some(normal)
        } else {
            None
        }
    }

    /// Mutably borrows the `MultiNormal` factor, or `None` when the factor is
    /// of another kind.
    fn factor_mut(&mut self) -> Option<&mut MultiNormal> {
        if let BernoulliFactor::CondExpect(normal) = &mut self.factor {
            Some(normal)
        } else {
            None
        }
    }

}
/// Exponential-family view of the Bernoulli: natural parameter
/// `eta = logit(theta)` and log-partition `log(1 + exp(eta))`.
impl ExponentialFamily<U1> for Bernoulli
    where
        Self : Distribution
{

    /// Base measure: a vector of ones with one entry per row of `y`.
    fn base_measure(y : DMatrixSlice<'_, f64>) -> DVector<f64> {
        DVector::from_element(y.nrows(), 1.)
    }

    /// Sufficient statistic: a 1x1 matrix holding the mean of the single
    /// column of `y`.
    ///
    /// # Panics
    /// Panics if `y` does not have exactly one column.
    fn sufficient_stat(y : DMatrixSlice<'_, f64>) -> DMatrix<f64> {
        assert!(y.ncols() == 1);
        DMatrix::from_element(1, 1, y.sum() / (y.nrows() as f64) )
    }

    /// Evaluates `eta * t - log_partition` for a scalar sufficient statistic.
    ///
    /// # Panics
    /// Panics unless `t` is 1x1 and both `eta` and the cached log-partition
    /// have exactly one entry.
    fn suf_log_prob(&self, t : DMatrixSlice<'_, f64>) -> f64 {
        assert!(t.nrows() == 1 && t.ncols() == 1);
        assert!(self.eta.nrows() == 1);
        assert!(self.log_part.nrows() == 1);
        self.eta[0] * t[0] - self.log_part[0]
    }

    /// Overwrites the cached log-partition with `log(1 + exp(eta_i))`.
    ///
    /// Entries are paired positionally; only as many entries as the shorter
    /// of the cache and `eta` are written.
    fn update_log_partition<'a>(&'a mut self, eta : DVectorSlice<'_, f64>) {
        self.log_part.iter_mut().zip(eta.iter()).for_each(|(l, e)| {
            // Stable softplus: log(1 + exp(e)) = max(e, 0) + ln(1 + exp(-|e|)).
            // The naive form overflows to +inf once exp(e) is not
            // representable, and rounds to 0 for very negative e.
            *l = e.max(0.) + (-e.abs()).exp().ln_1p();
        });
    }

    /// Borrows the cached log-partition vector.
    fn log_partition<'a>(&'a self) -> &'a DVector<f64> {
        &self.log_part
    }

    /// Not implemented; calling it panics.
    fn update_grad(&mut self, _eta : DVectorSlice<'_, f64>) {
        unimplemented!()
    }

    /// Not implemented; calling it panics.
    fn grad(&self) -> &DVector<f64> {
        unimplemented!()
    }

    /// Logistic function: maps natural parameters to probabilities,
    /// `1 / (1 + exp(-eta))`, element-wise.
    fn link_inverse<S>(eta : &Matrix<f64, Dynamic, U1, S>) -> DVector<f64>
        where S : Storage<f64, Dynamic, U1>
    {
        eta.map(|e| 1. / (1. + (-1.* e).exp() ) )
    }

    /// Logit function: maps probabilities to natural parameters,
    /// `ln(p / (1 - p))`, element-wise.
    fn link<S>(theta : &Matrix<f64, Dynamic, U1, S>) -> DVector<f64>
        where S : Storage<f64, Dynamic, U1>
    {
        theta.map(|p| (p / (1. - p)).ln() )
    }

}
/// Maximum-likelihood summaries for 0/1 data.
impl Likelihood<U1> for Bernoulli {

    /// Proportion of ones in the single column of `y`.
    ///
    /// # Panics
    /// Panics if `y` has more than one column or contains a value other
    /// than `0.` or `1.`.
    fn mean_mle(y : DMatrixSlice<'_, f64>) -> f64 {
        assert!(y.ncols() == 1);
        let mut total = 0.0;
        for obs in y.iter() {
            assert!(*obs == 0. || *obs == 1.);
            total = total + obs;
        }
        total / (y.nrows() as f64)
    }

    /// Variance estimate `m * (1 - m)`, where `m` is the value of `mean_mle`.
    fn var_mle(y : DMatrixSlice<'_, f64>) -> f64 {
        let prop = Self::mean_mle(y);
        prop * (1. - prop)
    }

}
/// Conjugate update of the attached `Beta` factor from 0/1 observations.
impl Estimator<Beta> for Bernoulli {

    /// Adds the observed counts to the Beta factor's parameters
    /// (`a + sum(y)`, `b + n - sum(y)`) and returns the updated factor.
    ///
    /// Returns `Err` when the distribution has no conjugate (`Beta`) factor.
    ///
    /// # Panics
    /// Panics if `y` does not have exactly one column.
    fn fit<'a>(&'a mut self, y : DMatrix<f64>) -> Result<&'a Beta, &'static str> {
        assert!(y.ncols() == 1);
        if let BernoulliFactor::Conjugate(ref mut prior) = self.factor {
            let successes = y.column(0).sum();
            let trials = y.nrows() as f64;
            let current = prior.view_parameter(false);
            let (a, b) = (current[0], current[1]);
            let updated = DVector::from_column_slice(&[a + successes, b + trials - successes]);
            prior.set_parameter(updated.rows(0, updated.nrows()), false);
            Ok(&(*prior))
        } else {
            Err("Distribution does not have a conjugate factor")
        }
    }

}
impl Distribution for Bernoulli
where Self : Sized
{
fn set_parameter(&mut self, p : DVectorSlice<'_, f64>, natural : bool) {
let eta = if natural {
p.clone_owned()
} else {
Self::link(&p)
};
self.theta = Self::link_inverse(&eta);
self.update_log_partition(eta.rows(0,eta.nrows()));
self.eta = eta;
if let Some(ref mut suf) = self.suf_theta {
*suf = Beta::sufficient_stat(self.theta.slice((0,0), (self.theta.nrows(),1)));
}
self.sampler.clear();
for t in self.theta.iter() {
self.sampler.push(rand_distr::Bernoulli::new(*t).unwrap());
}
}
fn view_parameter(&self, natural : bool) -> &DVector<f64> {
match natural {
true => &self.eta,
false => &self.theta
}
}
fn mean<'a>(&'a self) -> &'a DVector<f64> {
&self.theta
}
fn mode(&self) -> DVector<f64> {
self.theta.clone()
}
fn var(&self) -> DVector<f64> {
self.theta.component_mul(&self.theta.map(|p| 1. - p))
}
fn cov(&self) -> Option<DMatrix<f64>> {
None
}
fn log_prob(&self, y : DMatrixSlice<f64>) -> f64 {
assert!(y.ncols() == 1);
let eta = match self.current() {
Some(eta) => eta,
None => self.eta.rows(0, self.eta.nrows())
};
let factor_lp = match &self.factor {
BernoulliFactor::Conjugate(b) => {
b.log_prob(self.suf_theta.as_ref().unwrap().slice((0,0), (1,2)))
},
BernoulliFactor::CondExpect(m) => {
m.suf_log_prob(eta.slice((0,0), (eta.nrows(), 1)))
},
BernoulliFactor::Empty => 0.
};
eta.dot(&y) - self.log_part[0] + factor_lp
}
fn sample(&self) -> DMatrix<f64> {
use rand_distr::{Distribution};
let mut samples = DMatrix::zeros(self.theta.nrows(), 1);
for (i, _) in self.theta.iter().enumerate() {
samples[(i,0)] = (self.sampler[i].sample(&mut rand::thread_rng()) as i32) as f64;
}
samples
}
}
/// Random-walk interface over a trajectory of natural-parameter values.
impl RandomWalk for Bernoulli {

    /// Column of the trajectory at its current position, or `None` when no
    /// trajectory has been started.
    fn current<'a>(&'a self) -> Option<DVectorSlice<'a, f64>> {
        self.eta_traj.as_ref().map(|eta_traj| eta_traj.traj.column(eta_traj.pos))
    }

    /// Advances the trajectory by the increment `diff_eta`. The `_update`
    /// flag is ignored.
    ///
    /// # Panics
    /// Panics if no trajectory has been started yet.
    fn step_by<'a>(&'a mut self, diff_eta : DVectorSlice<'a, f64>, _update : bool) {
        self.eta_traj
            .as_mut()
            .expect("step_by called before a trajectory was started")
            .step_increment(diff_eta);
    }

    /// Moves the trajectory to `new_eta`, starting a new trajectory when none
    /// exists. When `update` is true the distribution parameters are also
    /// set from `new_eta` (taken on the natural scale).
    ///
    /// # Panics
    /// Panics if `new_eta` is `None` while no trajectory exists, or while
    /// `update` is true.
    fn step_to<'a>(&'a mut self, new_eta : Option<DVectorSlice<'a, f64>>, update : bool) {
        if let Some(ref mut eta_traj) = self.eta_traj {
            eta_traj.step(new_eta)
        } else {
            let start = new_eta.expect("step_to needs a value to start a trajectory");
            self.eta_traj = Some(EtaTrajectory::new(start));
        }
        if update {
            let eta = new_eta.expect("step_to needs a value to update the parameters");
            self.set_parameter(eta, true);
        }
    }

    /// Collects the first `pos` trajectory columns, converted from the
    /// natural scale to the probability scale, into a `Sample`.
    ///
    /// Returns `None` when there is no trajectory or no column to collect.
    // NOTE(review): `take(pos)` leaves out the column at index `pos`, which
    // is the one `current()` returns - confirm this is intended.
    fn marginal(&self) -> Option<Sample> {
        let eta_traj = self.eta_traj.as_ref()?;
        // The trajectory holds natural parameters, so the inverse link
        // (eta -> theta) is the conversion needed here; no copy of the
        // trajectory matrix is required just to iterate its columns.
        let cols : Vec<DVector<f64>> = eta_traj.traj
            .column_iter()
            .take(eta_traj.pos)
            .map(|col| Self::link_inverse(&col))
            .collect();
        if cols.is_empty() {
            // `from_columns` cannot build a matrix out of zero columns.
            return None;
        }
        let t_cols = DMatrix::from_columns(&cols[..]);
        Some(Sample::new(t_cols))
    }

}
impl Default for Bernoulli {
fn default() -> Self {
Bernoulli {
theta : DVector::from_element(1, 0.5),
eta : DVector::from_element(1, 0.0),
factor : BernoulliFactor::Empty,
eta_traj : None,
sampler : Vec::new(),
log_part : DVector::from_element(1, (2.).ln()),
suf_theta : None,
}
}
}