use lace_data::SparseContainer;
use lace_stats::rv::data::DataOrSuffStat;
use lace_stats::rv::traits::*;
use once_cell::sync::OnceCell;
use rand::Rng;
use serde::{Deserialize, Serialize};
use crate::feature::Component;
use crate::traits::AccumScore;
use crate::traits::{LaceDatum, LaceLikelihood, LaceStat};
/// A likelihood `fx` bundled with the sufficient statistic of the data
/// observed through this component and a lazily computed cache derived from
/// that statistic.
///
/// The cache is not part of the serialized form (see `ln_pp_cache`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ConjugateComponent<X, Fx, Pr>
where
X: LaceDatum,
Fx: LaceLikelihood<X>,
Fx::Stat: LaceStat,
Pr: ConjugatePrior<X, Fx>,
Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug,
{
/// The likelihood (component distribution). All density, sampling, mode and
/// entropy queries on the component are forwarded to this value.
#[serde(bound(deserialize = "Fx: serde::de::DeserializeOwned"))]
pub fx: Fx,
/// Sufficient statistic updated by `observe`/`forget`.
#[serde(bound(deserialize = "Fx: serde::de::DeserializeOwned"))]
pub stat: Fx::Stat,
/// Write-once cell holding the value produced by `Pr::ln_pp_cache` for the
/// current `stat`. Skipped by serde, so it is never written out and is
/// rebuilt from its `Default` (an empty cell) on deserialization.
#[serde(skip)]
pub ln_pp_cache: OnceCell<Pr::LnPpCache>,
}
/// Score accumulation is delegated to the wrapped likelihood.
impl<X, Fx, Pr> AccumScore<X> for ConjugateComponent<X, Fx, Pr>
where
    X: LaceDatum,
    Fx: LaceLikelihood<X>,
    Fx::Stat: LaceStat,
    Pr: ConjugatePrior<X, Fx>,
    Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug,
{
    /// Forwards `scores` and `container` to `fx.accum_score`.
    ///
    /// Neither the sufficient statistic nor the cache is consulted here.
    fn accum_score(&self, scores: &mut [f64], container: &SparseContainer<X>) {
        let likelihood = &self.fx;
        likelihood.accum_score(scores, container);
    }
}
impl<X, Fx, Pr> ConjugateComponent<X, Fx, Pr>
where
    X: LaceDatum,
    Fx: LaceLikelihood<X>,
    Fx::Stat: LaceStat,
    Pr: ConjugatePrior<X, Fx>,
    Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug,
{
    /// Builds a component around `fx` with the statistic returned by
    /// `fx.empty_suffstat()` and an unset cache.
    #[inline]
    pub fn new(fx: Fx) -> Self {
        Self {
            // Field initializers run top to bottom, so the statistic is
            // created while `fx` is still borrowable, before it is moved.
            stat: fx.empty_suffstat(),
            fx,
            ln_pp_cache: OnceCell::new(),
        }
    }

    /// Returns the component's statistic wrapped as
    /// `DataOrSuffStat::SuffStat`, borrowing from `self`.
    #[inline]
    pub fn obs(&self) -> DataOrSuffStat<X, Fx> {
        let stat = &self.stat;
        DataOrSuffStat::SuffStat(stat)
    }

    /// Discards any cached value by replacing the cell with a fresh, empty one.
    #[inline]
    pub fn reset_ln_pp_cache(&mut self) {
        self.ln_pp_cache = OnceCell::new();
    }

    /// Returns the cached value, computing it with
    /// `prior.ln_pp_cache(&self.obs())` if the cell is currently empty.
    ///
    /// `prior` is only consulted when the cell is empty; once a value is
    /// stored it is returned as-is until `reset_ln_pp_cache` is called.
    #[inline]
    pub fn ln_pp_cache(&self, prior: &Pr) -> &Pr::LnPpCache {
        let build_cache = || {
            let obs = self.obs();
            prior.ln_pp_cache(&obs)
        };
        self.ln_pp_cache.get_or_init(build_cache)
    }
}
/// Density evaluation and sampling are delegated to the wrapped likelihood.
impl<X, Fx, Pr> Rv<X> for ConjugateComponent<X, Fx, Pr>
where
    X: LaceDatum,
    Fx: LaceLikelihood<X>,
    Fx::Stat: LaceStat,
    Pr: ConjugatePrior<X, Fx>,
    Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug,
{
    /// Returns `fx.ln_f(x)`.
    fn ln_f(&self, x: &X) -> f64 {
        self.fx.ln_f(x)
    }

    /// Draws a single value from `fx` using `rng`.
    fn draw<R: Rng>(&self, rng: &mut R) -> X {
        // Hand the caller's generator straight through. Re-borrowing it as
        // `&mut &mut R` only works via rand's blanket impl for `&mut R` and
        // adds an extra layer of indirection for no benefit.
        self.fx.draw(rng)
    }

    /// Draws `n` values from `fx` using `rng`.
    fn sample<R: Rng>(&self, n: usize, rng: &mut R) -> Vec<X> {
        self.fx.sample(n, rng)
    }
}
/// Available whenever the wrapped likelihood itself implements `Mode<X>`.
impl<X, Fx, Pr> Mode<X> for ConjugateComponent<X, Fx, Pr>
where
    X: LaceDatum,
    Fx: LaceLikelihood<X> + Mode<X>,
    Fx::Stat: LaceStat,
    Pr: ConjugatePrior<X, Fx>,
    Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug,
{
    /// Returns whatever `fx` reports as its mode, unchanged.
    fn mode(&self) -> Option<X> {
        <Fx as Mode<X>>::mode(&self.fx)
    }
}
/// Available whenever the wrapped likelihood itself implements `Entropy`.
impl<X, Fx, Pr> Entropy for ConjugateComponent<X, Fx, Pr>
where
    X: LaceDatum,
    Fx: LaceLikelihood<X> + Entropy,
    Fx::Stat: LaceStat,
    Pr: ConjugatePrior<X, Fx>,
    Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug,
{
    /// Returns the entropy reported by `fx`, unchanged.
    fn entropy(&self) -> f64 {
        <Fx as Entropy>::entropy(&self.fx)
    }
}
impl<X, Fx, Pr> SuffStat<X> for ConjugateComponent<X, Fx, Pr>
where
X: LaceDatum,
Fx: LaceLikelihood<X>,
Fx::Stat: LaceStat,
Pr: ConjugatePrior<X, Fx>,
Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug,
{
fn n(&self) -> usize {
self.stat.n()
}
fn observe(&mut self, x: &X) {
self.reset_ln_pp_cache();
self.stat.observe(x);
}
fn forget(&mut self, x: &X) {
self.reset_ln_pp_cache();
self.stat.forget(x);
}
}
impl<X, Fx, Pr> From<ConjugateComponent<X, Fx, Pr>> for Component
where
X: LaceDatum,
Fx: LaceLikelihood<X>,
Fx::Stat: LaceStat,
Pr: ConjugatePrior<X, Fx>,
Pr::LnPpCache: Send + Sync + Clone + std::fmt::Debug,
{
fn from(cpnt: ConjugateComponent<X, Fx, Pr>) -> Component {
cpnt.fx.into()
}
}