use std::fmt::Debug;
use rv::data::BernoulliSuffStat;
use rv::data::CategoricalDatum;
use rv::data::CategoricalSuffStat;
use rv::data::GaussianSuffStat;
use rv::data::PoissonSuffStat;
use rv::dist::Bernoulli;
use rv::dist::Beta;
use rv::dist::Categorical;
use rv::dist::Gamma;
use rv::dist::Gaussian;
use rv::dist::NormalInvChiSquared;
use rv::dist::Poisson;
use rv::dist::SymmetricDirichlet;
use rv::traits::ConjugatePrior;
use rv::traits::HasDensity;
use rv::traits::HasSuffStat;
use rv::traits::Mode;
use rv::traits::Rv;
use rv::traits::Sampleable;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::cc::feature::Component;
use crate::cc::feature::FType;
use crate::data::SparseContainer;
use crate::data::TranslateContainer;
use crate::data::TranslateDatum;
use crate::stats::prior::csd::CsdHyper;
use crate::stats::prior::nix::NixHyper;
use crate::stats::prior::pg::PgHyper;
use crate::stats::UpdatePrior;
/// Accumulates log-likelihoods of the data in a container into a score buffer.
///
/// Any `Rv<X> + Sync` distribution can opt in with an empty `impl`; the
/// provided method does all the work.
pub trait AccumScore<X: Clone + Default>: Rv<X> + Sync {
    /// Forwards `scores` and a closure evaluating `self.ln_f` to the
    /// container's own `accum_score`.
    ///
    /// How the entries of `scores` are updated is decided by the container's
    /// implementation in `crate::data`.
    fn accum_score(&self, scores: &mut [f64], container: &SparseContainer<X>) {
        // The container-side trait shares this trait's name, so it is imported
        // locally to make its method available on `container`.
        use crate::data::AccumScore;
        let ln_f = |datum: &X| self.ln_f(datum);
        container.accum_score(scores, &ln_f)
    }
}
// Each likelihood relies on the provided `accum_score` method, so the impl
// bodies are empty.
impl<X: CategoricalDatum + Default> AccumScore<X> for Categorical {}
impl AccumScore<u32> for Poisson {}
impl AccumScore<f64> for Gaussian {}
impl AccumScore<bool> for Bernoulli {}
/// Associates an implementing type with a feature type ([`FType`]).
pub trait HasFType {
/// Returns the `FType` tied to the implementing type. There is no receiver,
/// so this is called on the type itself rather than on a value.
fn ftype() -> FType;
}
/// Generates a `HasFType` impl for a distribution type.
///
/// `impl_ftype!(Dist, Variant)` expands to an impl whose `ftype()` returns
/// `FType::Variant`. All paths are `$crate`-qualified.
macro_rules! impl_ftype {
    ($dist:ty, $variant:ident) => {
        impl $crate::cc::traits::HasFType for $dist {
            fn ftype() -> $crate::cc::feature::FType {
                $crate::cc::feature::FType::$variant
            }
        }
    };
}
// Map each likelihood type to its `FType` variant.
impl_ftype!(Poisson, Count);
impl_ftype!(Gaussian, Continuous);
impl_ftype!(Bernoulli, Binary);
impl_ftype!(Categorical, Categorical);
/// Shorthand for the set of bounds placed on a datum type: thread-shareable,
/// (de)serializable with serde, default-constructible, cloneable and
/// debug-printable. It has no methods of its own.
pub trait LaceDatum:
Sync + Serialize + DeserializeOwned + Default + Clone + Debug
{
}
// Blanket impl: every type that satisfies the bounds is automatically a
// `LaceDatum`.
impl<X> LaceDatum for X where
X: Sync + Serialize + DeserializeOwned + Default + Clone + Debug
{
}
/// Shorthand for a set of bounds: thread-shareable, (de)serializable with
/// serde, debug-printable, cloneable and comparable with `PartialEq`. It has
/// no methods of its own.
///
/// NOTE(review): presumably intended for sufficient-statistic types; nothing
/// in this part of the file uses it, so confirm against callers.
pub trait LaceStat:
Sync + Serialize + DeserializeOwned + Debug + Clone + PartialEq
{
}
// Blanket impl: every type that satisfies the bounds is automatically a
// `LaceStat`.
impl<X> LaceStat for X where
X: Sync + Serialize + DeserializeOwned + Debug + Clone + PartialEq
{
}
/// The full set of capabilities required of a likelihood over data of type
/// `X`: sampling, density evaluation, feature-type lookup, datum/container
/// translation, a mode, score accumulation, sufficient statistics, serde
/// support, and conversion into a [`Component`].
pub trait LaceLikelihood<X: LaceDatum>:
    Sampleable<X>
    + HasDensity<X>
    + HasFType
    + TranslateDatum<X>
    + TranslateContainer<X>
    + Mode<X>
    + AccumScore<X>
    + HasSuffStat<X>
    + Serialize
    + DeserializeOwned
    + Sync
    + Into<Component>
    + Clone
    + Debug
    + PartialEq
{
    /// Log density evaluated at the distribution's mode.
    ///
    /// Returns `None` when `mode()` returns `None`.
    fn ln_f_max(&self) -> Option<f64> {
        let mode = self.mode()?;
        Some(self.ln_f(&mode))
    }
}
// Blanket impl: any `Fx` meeting every supertrait bound of `LaceLikelihood`
// implements it automatically, provided its sufficient-statistic type
// (`Fx::Stat`) is also `Sync`, serde-(de)serializable, `Clone` and `Debug`.
impl<X, Fx> LaceLikelihood<X> for Fx
where
X: LaceDatum,
Fx: Sampleable<X>
+ HasDensity<X>
+ HasFType
+ TranslateDatum<X>
+ TranslateContainer<X>
+ Mode<X>
+ AccumScore<X>
+ HasSuffStat<X>
+ Serialize
+ DeserializeOwned
+ Sync
+ Into<Component>
+ Clone
+ Debug
+ PartialEq,
Fx::Stat: Sync + Serialize + DeserializeOwned + Clone + Debug,
{
}
/// A conjugate prior over likelihoods `Fx` for data of type `X`, with
/// hyper-prior type `H` (passed through to `UpdatePrior`).
///
/// Besides the supertrait capabilities, implementors supply three helpers
/// that depend on the concrete prior/likelihood pair.
pub trait LacePrior<X: LaceDatum, Fx: LaceLikelihood<X>, H>:
ConjugatePrior<X, Fx>
+ HasDensity<Fx>
+ UpdatePrior<X, Fx, H>
+ Serialize
+ DeserializeOwned
+ Sync
+ Clone
+ Debug
{
/// Returns a fresh sufficient statistic of the likelihood's `Stat` type.
fn empty_suffstat(&self) -> Fx::Stat;
/// Returns a placeholder component.
///
/// NOTE(review): going by the name, the result is a temporary value that
/// is not meant to be a meaningful draw from the prior — confirm with
/// callers.
fn invalid_temp_component(&self) -> Fx;
/// Returns a single score for a column, computed by summing one term per
/// sufficient statistic yielded by `stats`.
fn score_column<I: Iterator<Item = Fx::Stat>>(&self, stats: I) -> f64;
}
/// `SymmetricDirichlet` as the prior for `Categorical` likelihoods over `u32`
/// data.
impl LacePrior<u32, Categorical, CsdHyper> for SymmetricDirichlet {
    /// Creates an empty `CategoricalSuffStat` sized by this prior's `k()`.
    fn empty_suffstat(&self) -> CategoricalSuffStat {
        let n_categories = self.k();
        CategoricalSuffStat::new(n_categories)
    }

    /// Builds a placeholder `Categorical` from `k()` zeros, skipping
    /// validation.
    ///
    /// NOTE(review): `new_unchecked` presumably takes log-weights, in which
    /// case all-zero entries are unnormalized — confirm against `rv`.
    fn invalid_temp_component(&self) -> Categorical {
        let zeros = vec![0.0; self.k()];
        Categorical::new_unchecked(zeros)
    }

    /// Sums, over every statistic in `stats`, the term
    /// `lnΓ(kα) − lnΓ(kα + n) + Σ_j lnΓ(α + count_j) − k·lnΓ(α)`,
    /// where `α = self.alpha()`, `k = self.k()`, `n = stat.n()` and
    /// `count_j` ranges over `stat.counts()`.
    fn score_column<I: Iterator<Item = CategoricalSuffStat>>(
        &self,
        stats: I,
    ) -> f64 {
        // `Gamma` in this module names the rv distribution, so the special
        // function is reached through its absolute path.
        let ln_gamma = |x: f64| ::special::Gamma::ln_gamma(x).0;

        let alpha = self.alpha();
        let k = self.k() as f64;
        let sum_alpha = alpha * k;

        // Terms that do not depend on the statistic.
        let ln_gamma_sum_alpha = ln_gamma(sum_alpha);
        let k_ln_gamma_alpha = ln_gamma(alpha) * k;

        stats
            .map(|stat| {
                let ln_gamma_sum_post = ln_gamma(sum_alpha + stat.n() as f64);

                let mut sum_ln_gamma_post = 0.0;
                for &count in stat.counts().iter() {
                    sum_ln_gamma_post += ln_gamma(alpha + count);
                }

                ln_gamma_sum_alpha - ln_gamma_sum_post + sum_ln_gamma_post
                    - k_ln_gamma_alpha
            })
            .sum::<f64>()
    }
}
/// Computes `lnΓ(shape_n) − shape_n · ln(rate_n)`, where
/// `shape_n = shape + stat.sum()` and `rate_n = rate + stat.n()`.
///
/// The multiply and add are fused with `mul_add`.
#[inline]
fn poisson_zn(shape: f64, rate: f64, stat: &PoissonSuffStat) -> f64 {
    let post_shape = shape + stat.sum();
    let post_rate = rate + stat.n() as f64;
    post_shape.mul_add(
        -post_rate.ln(),
        ::special::Gamma::ln_gamma(post_shape).0,
    )
}
/// `Gamma` as the prior for `Poisson` likelihoods over `u32` data.
impl LacePrior<u32, Poisson, PgHyper> for Gamma {
    /// Creates an empty `PoissonSuffStat`.
    fn empty_suffstat(&self) -> PoissonSuffStat {
        PoissonSuffStat::new()
    }

    /// Builds a placeholder `Poisson` from the constant `1.0`, skipping
    /// validation.
    fn invalid_temp_component(&self) -> Poisson {
        Poisson::new_unchecked(1.0)
    }

    /// Sums `zn − z0 − stat.sum_ln_fact()` over every statistic in `stats`.
    ///
    /// `z0 = lnΓ(shape) − shape · ln(rate)` uses this prior's parameters, and
    /// `zn` is the same expression with the statistic folded in (see
    /// `poisson_zn`).
    fn score_column<I: Iterator<Item = PoissonSuffStat>>(
        &self,
        stats: I,
    ) -> f64 {
        let (shape, rate) = (self.shape(), self.rate());

        // Prior-only term, shared by every statistic.
        let z0 = shape.mul_add(-rate.ln(), ::special::Gamma::ln_gamma(shape).0);

        stats
            .map(|stat| poisson_zn(shape, rate, &stat) - z0 - stat.sum_ln_fact())
            .sum::<f64>()
    }
}
/// `Beta` as the prior for `Bernoulli` likelihoods over `bool` data. The
/// hyper-prior type is `()`.
impl LacePrior<bool, Bernoulli, ()> for Beta {
    /// Creates an empty `BernoulliSuffStat`.
    fn empty_suffstat(&self) -> BernoulliSuffStat {
        BernoulliSuffStat::new()
    }

    /// Returns `Bernoulli::uniform()` as the placeholder component.
    fn invalid_temp_component(&self) -> Bernoulli {
        Bernoulli::uniform()
    }

    /// Sums `ln_m_with_cache` over every statistic in `stats`, building the
    /// `ln_m` cache once up front and reusing it for each term.
    fn score_column<I: Iterator<Item = BernoulliSuffStat>>(
        &self,
        stats: I,
    ) -> f64 {
        use rv::data::DataOrSuffStat;

        let cache = <Self as ConjugatePrior<bool, Bernoulli>>::ln_m_cache(self);

        let ln_m = |stat: BernoulliSuffStat| {
            let wrapped = DataOrSuffStat::SuffStat::<bool, Bernoulli>(&stat);
            <Self as ConjugatePrior<bool, Bernoulli>>::ln_m_with_cache(
                self, &cache, &wrapped,
            )
        };

        stats.map(ln_m).sum::<f64>()
    }
}
/// `NormalInvChiSquared` as the prior for `Gaussian` likelihoods over `f64`
/// data.
impl LacePrior<f64, Gaussian, NixHyper> for NormalInvChiSquared {
    /// Creates an empty `GaussianSuffStat`.
    fn empty_suffstat(&self) -> GaussianSuffStat {
        GaussianSuffStat::new()
    }

    /// Returns `Gaussian::standard()` as the placeholder component.
    fn invalid_temp_component(&self) -> Gaussian {
        Gaussian::standard()
    }

    /// Sums `ln_m_with_cache` over every statistic in `stats`, building the
    /// `ln_m` cache once up front and reusing it for each term.
    fn score_column<I: Iterator<Item = GaussianSuffStat>>(
        &self,
        stats: I,
    ) -> f64 {
        use rv::data::DataOrSuffStat;

        let cache = <Self as ConjugatePrior<f64, Gaussian>>::ln_m_cache(self);

        let ln_m = |stat: GaussianSuffStat| {
            let wrapped = DataOrSuffStat::SuffStat(&stat);
            <Self as ConjugatePrior<f64, Gaussian>>::ln_m_with_cache(
                self, &cache, &wrapped,
            )
        };

        stats.map(ln_m).sum::<f64>()
    }
}