use nalgebra::*;
use super::*;
use serde::{Serialize, Deserialize};
use std::f64::consts::PI;
/// Selector stored in the `cov_func` field of `LinearOp`.
///
/// NOTE(review): the variant names suggest an identity / log / logit transform, but no
/// code in this chunk reads this value — confirm the intended meaning.
/// Variant order is part of the serialized form for index-based serde formats; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
enum CovFunction {
/// Presumably no transformation (identity) — TODO confirm.
None,
/// Presumably a log transform — TODO confirm.
Log,
/// Presumably a logit transform — TODO confirm.
Logit
}
/// Linear-operator data that a `MultiNormal` may carry in its optional `op` field.
///
/// NOTE(review): no code in this chunk reads these fields, so the per-field notes below
/// are inferred from names only — confirm against the code that constructs `LinearOp`.
/// Field order is part of the serialized form for order-based serde formats; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearOp {
/// Presumably the matrix of the linear map — TODO confirm.
pub scale : DMatrix<f64>,
/// Presumably the offset added after applying `scale` — TODO confirm.
pub shift : DVector<f64>,
/// Presumably the mean of the transformed variable — TODO confirm.
pub lin_mu : DVector<f64>,
/// Presumably the precision (inverse covariance) of the transformed variable — TODO confirm.
pub lin_sigma_inv : DMatrix<f64>,
/// Optional precision after the `cov_func` transform, when one applies — TODO confirm.
pub transf_sigma_inv : Option<DMatrix<f64>>,
/// Transform selector (see `CovFunction`).
pub cov_func : CovFunction
}
/// Multivariate normal distribution, stored in precision (inverse covariance) form.
///
/// It may carry an optional distribution over its mean (`loc_factor`) and an optional
/// Wishart factor (`scale_factor`); `log_prob` adds their log-densities to its own.
/// Field order is part of the serialized form for order-based serde formats; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiNormal {
/// Mean vector; returned by `mean` and `mode`.
mu : DVector<f64>,
/// Coefficient dotted with the summed observations in `suf_log_prob` (the linear term
/// of the log-density).
/// NOTE(review): for a correct Gaussian log-density this must equal `sigma_inv * mu` and be
/// kept in sync whenever `mu` or `sigma_inv` change — confirm all writers do so.
scaled_mu : DVector<f64>,
/// Precision matrix; `cov` returns its inverse.
sigma_inv : DMatrix<f64>,
/// Optional linear operator; its contents are not used by the code in this chunk.
op : Option<LinearOp>,
/// Optional distribution over the mean; `log_prob` adds its log-density evaluated at `mu`.
loc_factor : Option<Box<MultiNormal>>,
/// Optional Wishart factor; `log_prob` adds its log-density evaluated at the diagonal
/// of `sigma_inv`.
scale_factor : Option<Wishart>,
/// Single-element vector written by `update_log_partition` and returned by `log_partition`.
log_part : DVector<f64>,
}
impl MultiNormal {

    /// Inverts a square scale matrix (e.g. covariance <-> precision) through its QR
    /// decomposition.
    ///
    /// # Panics
    /// Panics if `s` is not square or is singular.
    pub fn invert_scale(s : &DMatrix<f64>) -> DMatrix<f64> {
        s.clone()
            .qr()
            .try_inverse()
            .expect("MultiNormal::invert_scale: scale matrix is singular")
    }

    /// Converts a covariance matrix into the matching correlation matrix,
    /// `corr[i, j] = cov[i, j] / sqrt(cov[i, i] * cov[j, j])`.
    ///
    /// Entries in the row/column of a non-positive variance become non-finite.
    ///
    /// # Panics
    /// Panics if `cov` is not square.
    pub fn corr_from(mut cov : DMatrix<f64>) -> DMatrix<f64> {
        assert!(cov.is_square(), "MultiNormal::corr_from: covariance matrix must be square");
        // Reciprocal standard deviations, i.e. the diagonal of D^(-1/2).
        let inv_sd = cov.diagonal().map(|d| 1. / d.sqrt());
        // D^(-1/2) * cov * D^(-1/2) only rescales entries, so do it in place in O(p^2)
        // rather than with two dense O(p^3) matrix products.
        for j in 0..cov.ncols() {
            for i in 0..cov.nrows() {
                cov[(i, j)] = inv_sd[i] * (cov[(i, j)] * inv_sd[j]);
            }
        }
        cov
    }

    /// Builds a multivariate normal from its mean `mu` and covariance `sigma`.
    ///
    /// The covariance is stored inverted (`sigma_inv`). The linear coefficient of the
    /// log-density is set to `sigma_inv * mu`, and the log-partition is computed from it.
    /// No priors (`loc_factor`, `scale_factor`) and no linear operator are attached.
    ///
    /// # Panics
    /// Panics if `sigma` is singular or its dimension differs from the length of `mu`.
    pub fn new(mu : DVector<f64>, sigma : DMatrix<f64>) -> Self {
        let sigma_inv = Self::invert_scale(&sigma);
        // Natural location parameter eta_1 = sigma^-1 * mu: it is the coefficient of
        // sum(y_i) in the Gaussian log-density, which is what `suf_log_prob` reads.
        let scaled_mu = &sigma_inv * &mu;
        let mut norm = Self {
            mu,
            scaled_mu : scaled_mu.clone(),
            sigma_inv,
            op : None,
            loc_factor : None,
            scale_factor : None,
            // Placeholder, overwritten by `update_log_partition` below.
            log_part : DVector::from_element(1, 0.0),
        };
        norm.update_log_partition(scaled_mu.rows(0, scaled_mu.nrows()));
        norm
    }
}
impl ExponentialFamily<Dynamic> for MultiNormal {

    /// Gaussian base measure `(2 * pi)^(-k / 2)`, where `k = y.ncols()` is the dimension
    /// of one observation. The value does not depend on the number of rows of `y`.
    fn base_measure(y : DMatrixSlice<'_, f64>) -> DVector<f64> {
        DVector::from_element(1, (2. * PI).powf( - (y.ncols() as f64) / 2. ) )
    }

    /// Sufficient statistics of the observations in `y`, one observation per row.
    ///
    /// Returns a `k x (k + 1)` matrix with `k = y.ncols()`. Column 0 is the sum of the
    /// observations, and columns `1..=k` are the sum of their outer products `y_i * y_i^T`.
    fn sufficient_stat(y : DMatrixSlice<'_, f64>) -> DMatrix<f64> {
        let k = y.ncols();
        // Columns of `yt` are the observations as column vectors.
        let yt = y.transpose();
        let mut t = DMatrix::<f64>::zeros(k, k + 1);
        // Scratch buffer reused for every outer product.
        let mut outer = DMatrix::<f64>::zeros(k, k);
        for (yr, yc) in y.row_iter().zip(yt.column_iter()) {
            let mut sum = t.column_mut(0);
            sum += &yc;
            yc.mul_to(&yr, &mut outer);
            let mut sum_outer = t.columns_mut(1, k);
            sum_outer += &outer;
        }
        t
    }

    /// Evaluates `scaled_mu . t[:, 0] - 1/2 * <sigma_inv, t[:, 1..]>_F`: the data-dependent
    /// part of the log-density, given the statistics produced by `sufficient_stat`.
    ///
    /// # Panics
    /// Panics if the shape of `t` does not match the dimension of the distribution.
    fn suf_log_prob(&self, t : DMatrixSlice<'_, f64>) -> f64 {
        let linear = self.scaled_mu.dot(&t.column(0));
        let t_cov = t.columns(1, t.ncols() - 1);
        // Frobenius inner product; equals tr(sigma_inv * t_cov) since t_cov is symmetric.
        let quadratic = self.sigma_inv.dot(&t_cov);
        linear - 0.5 * quadratic
    }

    fn log_partition<'a>(&'a self) -> &'a DVector<f64> {
        &self.log_part
    }

    /// Recomputes the stored log-partition from the location natural parameter `eta` and
    /// the current precision.
    ///
    /// Stores `-1/4 * eta^T Sigma eta - 1/2 * ln|Sigma|`, with `Sigma = sigma_inv^-1`.
    /// A non-positive determinant makes the result non-finite.
    ///
    /// NOTE(review): with eta_1 = Sigma^-1 mu and eta_2 = -Sigma^-1 / 2, the textbook value
    /// is A(eta) = 1/2 * eta_1^T Sigma eta_1 + 1/2 * ln|Sigma|. The coefficients above differ.
    /// They are kept unchanged until the convention expected by `ExponentialFamily` is confirmed.
    fn update_log_partition<'a>(&'a mut self, eta : DVectorSlice<'_, f64>) {
        let cov = Self::invert_scale(&self.sigma_inv);
        let eta_cov_eta = eta.dot(&(&cov * &eta));
        let sigma_det = LU::new(cov).determinant();
        self.log_part = DVector::from_element(1, -0.25 * eta_cov_eta - 0.5 * sigma_det.ln());
    }

    fn link_inverse<S>(_eta : &Matrix<f64, Dynamic, U1, S>) -> DVector<f64>
        where S : Storage<f64, Dynamic, U1>
    {
        unimplemented!()
    }

    fn link<S>(_theta : &Matrix<f64, Dynamic, U1, S>) -> DVector<f64>
        where S : Storage<f64, Dynamic, U1>
    {
        unimplemented!()
    }

    fn update_grad(&mut self, _eta : DVectorSlice<'_, f64>) {
        unimplemented!()
    }

    fn grad(&self) -> &DVector<f64> {
        unimplemented!()
    }
}
impl Distribution for MultiNormal {

    fn view_parameter(&self, _natural : bool) -> &DVector<f64> {
        unimplemented!()
    }

    /// Replaces the mean with `p` and refreshes every quantity derived from it.
    ///
    /// `_natural` is currently ignored: `p` is always interpreted as the mean.
    ///
    /// # Panics
    /// Panics if `p` does not have the same length as the current mean.
    fn set_parameter(&mut self, p : DVectorSlice<'_, f64>, _natural : bool) {
        self.mu.copy_from(&p);
        // `log_prob` reads `scaled_mu`, not `mu`, so it must be recomputed here.
        // Otherwise changing the mean would leave the likelihood untouched.
        let scaled_mu = &self.sigma_inv * &self.mu;
        self.update_log_partition(scaled_mu.rows(0, scaled_mu.nrows()));
        self.scaled_mu = scaled_mu;
        // TODO: propagate the new mean to `self.op` once `LinearOp` updates exist.
    }

    fn mean<'a>(&'a self) -> &'a DVector<f64> {
        &self.mu
    }

    /// The Gaussian density peaks at its mean.
    fn mode(&self) -> DVector<f64> {
        self.mu.clone()
    }

    /// Marginal variances: the diagonal of the covariance matrix.
    fn var(&self) -> DVector<f64> {
        self.cov()
            .expect("MultiNormal::cov always returns Some")
            .diagonal()
    }

    /// Covariance matrix, obtained by inverting the stored precision.
    fn cov(&self) -> Option<DMatrix<f64>> {
        Some(Self::invert_scale(&self.sigma_inv))
    }

    /// Log-density terms for the observations in `y` (one observation per row), plus the
    /// log-densities of the optional mean and precision factors.
    ///
    /// Neither the log-partition nor the base measure is added here, so the result is
    /// unnormalized.
    fn log_prob(&self, y : DMatrixSlice<f64>) -> f64 {
        let t = Self::sufficient_stat(y);
        let lp = self.suf_log_prob(t.slice((0, 0), t.shape()));
        let loc_lp = match &self.loc_factor {
            Some(loc) => {
                // Evaluate the factor at the mean, presented as a single observation:
                // a 1 x dim matrix, matching the one-observation-per-row layout.
                let dim = self.mu.nrows();
                let mu_row = DMatrix::from_row_slice(1, dim, self.mu.data.as_slice());
                loc.log_prob(mu_row.slice((0, 0), (1, dim)))
            },
            None => 0.0
        };
        let scale_lp = match &self.scale_factor {
            Some(scale) => {
                // NOTE(review): only the diagonal of the precision is passed to the Wishart
                // factor. Confirm this is the input `Wishart::log_prob` expects.
                let sinv_diag : DVector<f64> = self.sigma_inv.diagonal();
                scale.log_prob(sinv_diag.slice((0, 0), (sinv_diag.nrows(), 1)))
            },
            None => 0.0
        };
        lp + loc_lp + scale_lp
    }

    fn sample(&self) -> DMatrix<f64> {
        unimplemented!()
    }
}