use enum_dispatch::enum_dispatch;
use rand::Rng;
use rv::dist::Categorical;
use rv::dist::Gamma;
use rv::dist::Gaussian;
use rv::dist::NormalInvChiSquared;
use rv::dist::Poisson;
use rv::dist::SymmetricDirichlet;
use super::Component;
use crate::cc::feature::ColModel;
use crate::cc::feature::Column;
use crate::cc::feature::FType;
use crate::data::Datum;
use crate::data::FeatureData;
use crate::stats::assignment::Assignment;
use crate::stats::prior::csd::CsdHyper;
use crate::stats::prior::nix::NixHyper;
use crate::stats::prior::pg::PgHyper;
use crate::stats::MixtureType;
#[enum_dispatch(ColModel)]
/// Common feature interface, dispatched over the `ColModel` enum through
/// `enum_dispatch`.
///
/// NOTE(review): only the signatures are visible in this file. Apart from the
/// two default methods (`is_empty`, `is_present`), the per-method notes below
/// are inferred from names and types and should be confirmed against the
/// implementors.
pub trait Feature {
/// Returns the feature's id.
fn id(&self) -> usize;
/// Sets the feature's id to `id`.
fn set_id(&mut self, id: usize);
/// Returns the feature's length (presumably the number of rows — confirm).
fn len(&self) -> usize;
/// Returns `true` when `len()` is zero.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Presumably the number of components currently held — TODO confirm.
fn k(&self) -> usize;
/// Presumably accumulates scores under component `k` into `scores`;
/// the expected length/layout of `scores` is not visible here — confirm.
fn accum_score(&self, scores: &mut [f64], k: usize);
/// Presumably (re)initializes `k` components, using `rng` for any draws.
fn init_components(&mut self, k: usize, rng: &mut impl Rng);
/// Presumably updates the existing components, using `rng` for any draws.
fn update_components(&mut self, rng: &mut impl Rng);
/// Presumably rebuilds the feature's state to match the assignment `asgn`.
fn reassign(&mut self, asgn: &Assignment, rng: &mut impl Rng);
/// Returns a score for the feature in its current state. The unit test in
/// this file expects it to agree with `asgn_score` (within 1E-8) right
/// after `reassign` has been called with the same assignment.
fn score(&self) -> f64;
/// Returns the score of the feature under the supplied assignment `asgn`.
fn asgn_score(&self, asgn: &Assignment) -> f64;
/// Presumably updates the prior parameters; the meaning of the returned
/// `f64` is not visible here — confirm with implementors.
fn update_prior_params(&mut self, rng: &mut impl Rng) -> f64;
/// Presumably appends a new, empty component.
fn append_empty_component(&mut self, rng: &mut impl Rng);
/// Presumably removes the component at index `k`.
fn drop_component(&mut self, k: usize);
/// Presumably a predictive score of the datum at `row_ix` under
/// component `k` — confirm.
fn predictive_score_at(&self, row_ix: usize, k: usize) -> f64;
/// Presumably the score of the datum at `row_ix` as a singleton — confirm.
fn singleton_score(&self, row_ix: usize) -> f64;
/// NOTE(review): semantics of `logm` are not visible here; likely a log
/// marginal quantity for component `k` — confirm.
fn logm(&self, k: usize) -> f64;
/// Presumably makes component `k` observe the datum at `row_ix`.
fn observe_datum(&mut self, row_ix: usize, k: usize);
/// Presumably makes component `k` forget the datum at `row_ix`
/// (the counterpart of `observe_datum`).
fn forget_datum(&mut self, row_ix: usize, k: usize);
/// Presumably appends the datum `x` to the feature's data.
fn append_datum(&mut self, x: Datum);
/// Presumably writes the datum `x` at `row_ix`; whether this overwrites or
/// shifts existing data is not visible here — confirm.
fn insert_datum(&mut self, row_ix: usize, x: Datum);
/// Reports whether the entry at `ix` is missing.
fn is_missing(&self, ix: usize) -> bool;
/// Returns the negation of `is_missing(ix)`.
fn is_present(&self, ix: usize) -> bool {
!self.is_missing(ix)
}
/// Returns the entry at `ix` as a `Datum`.
fn datum(&self, ix: usize) -> Datum;
/// Presumably moves the data out of the feature and returns it; see
/// `repop_data` for what looks like the inverse operation — confirm.
fn take_data(&mut self) -> FeatureData;
/// Presumably removes and returns the datum at `row_ix` (associated with
/// component `k`); the condition for `None` is not visible here — confirm.
fn take_datum(&mut self, row_ix: usize, k: usize) -> Option<Datum>;
/// Presumably returns a copy of the data, leaving the feature unchanged.
fn clone_data(&self) -> FeatureData;
/// Presumably draws a `Datum` from component `k` using `rng`.
fn draw(&self, k: usize, rng: &mut impl Rng) -> Datum;
/// Presumably repopulates the feature with `data`.
fn repop_data(&mut self, data: FeatureData);
/// Presumably accumulates per-component weights of `datum` into `weights`;
/// the effect of `scaled` is not visible here — confirm.
#[allow(clippy::ptr_arg)]
fn accum_weights(
&self,
datum: &Datum,
weights: &mut Vec<f64>,
scaled: bool,
);
/// Presumably the exponentiated counterpart of `accum_weights` — confirm.
#[allow(clippy::ptr_arg)]
fn accum_exp_weights(&self, datum: &Datum, weights: &mut Vec<f64>);
/// Presumably the log probability of `datum` under component `k`.
fn cpnt_logp(&self, datum: &Datum, k: usize) -> f64;
/// Presumably the likelihood of `datum` under component `k`.
fn cpnt_likelihood(&self, datum: &Datum, k: usize) -> f64;
/// Returns the feature's `FType`.
fn ftype(&self) -> FType;
/// Returns component `k` as a `Component`.
fn component(&self, k: usize) -> Component;
/// Converts the feature into a `MixtureType`, consuming the supplied
/// `weights` (presumably one weight per component — confirm).
fn to_mixture(&self, weights: Vec<f64>) -> MixtureType;
/// NOTE(review): presumably initializes the feature for Geweke testing
/// given `asgn`; confirm with implementors.
fn geweke_init<R: Rng>(&mut self, asgn: &Assignment, rng: &mut R);
}
#[enum_dispatch(ColModel)]
/// Crate-private extension of `Feature`, also dispatched over the `ColModel`
/// enum through `enum_dispatch`.
pub(crate) trait FeatureHelper: Feature {
/// NOTE(review): presumably deletes the datum at index `ix`; whether later
/// entries are shifted is not visible here — confirm with implementors.
fn del_datum(&mut self, ix: usize);
}
#[cfg(test)]
mod tests {
    use super::*;

    use approx::*;
    use rv::dist::Gaussian;
    use rv::traits::Sampleable;

    use crate::data::SparseContainer;
    use crate::stats::prior_process::Builder as PriorProcessBuilder;

    /// After `reassign` has been called with an assignment, `score` and
    /// `asgn_score` for that same assignment must agree to within 1E-8.
    /// The check is repeated over 100 freshly generated assignments and
    /// standard-Gaussian data sets of 100 rows each.
    #[test]
    fn score_and_asgn_score_equivalency() {
        let n_rows = 100;
        let n_trials = 100;

        let mut rng = rand::rng();
        let standard_normal = Gaussian::standard();
        let hyper = NixHyper::default();
        let prior = NormalInvChiSquared::new_unchecked(0.0, 1.0, 1.0, 1.0);

        for _trial in 0..n_trials {
            // Build the assignment before drawing the data, as the two steps
            // are kept in their original order.
            let process = PriorProcessBuilder::new(n_rows).build().unwrap();
            let asgn = process.asgn;

            let draws: Vec<f64> = standard_normal.sample(n_rows, &mut rng);
            let container = SparseContainer::from(draws);

            let mut column =
                Column::new(0, container, prior.clone(), hyper.clone());
            column.reassign(&asgn, &mut rng);

            let cached_score = column.score();
            let recomputed_score = column.asgn_score(&asgn);
            assert_relative_eq!(
                cached_score,
                recomputed_score,
                epsilon = 1E-8
            );
        }
    }
}