mod enum_test;
use std::collections::BTreeMap;
use enum_test::build_features;
use enum_test::normalize_assignment;
use enum_test::partition_to_ix;
use enum_test::Partition;
use lace::cc::alg::RowAssignAlg;
use lace::cc::feature::ColModel;
use lace::cc::feature::FType;
use lace::cc::feature::Feature;
use lace::cc::transition::ViewTransition;
use lace::cc::view::Builder;
use lace::cc::view::View;
use lace::stats::prior_process::Builder as PriorProcessBuilder;
use lace::stats::prior_process::Dirichlet;
use lace::stats::prior_process::PitmanYor;
use lace::stats::prior_process::Process;
use rand::Rng;
use rv::dist::Beta;
use rv::dist::Gamma;
use rv::misc::LogSumExp;
/// Number of attempts each stochastic enumeration test is given before it is
/// reported as failing (passed to `flaky_test_passes` by the generated tests).
const N_TRIES: u32 = 5;
/// Selects which prior process the enumeration tests place over row
/// assignments.
///
/// Converted into a concrete `Process`, with fixed starting hyperparameters,
/// through `From<ProcessType> for Process`. The type is a plain data-free
/// tag, so it also derives equality and hashing so callers can compare it or
/// use it as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ProcessType {
    /// Dirichlet process prior.
    Dirichlet,
    /// Pitman-Yor process prior.
    PitmanYor,
}
impl From<ProcessType> for Process {
    /// Builds the concrete prior process used by the enumeration tests.
    ///
    /// - `Dirichlet`: starts at `alpha = 1.0` with a `Gamma::default()`
    ///   prior on `alpha`.
    /// - `PitmanYor`: starts at `alpha = 1.2`, discount `d = 0.2`, with a
    ///   `Gamma::default()` prior on `alpha` and a `Beta::jeffreys()` prior
    ///   on `d`.
    fn from(process_type: ProcessType) -> Self {
        match process_type {
            ProcessType::Dirichlet => {
                let dirichlet = Dirichlet {
                    alpha: 1.0,
                    alpha_prior: Gamma::default(),
                };
                Process::Dirichlet(dirichlet)
            }
            ProcessType::PitmanYor => {
                let pitman_yor = PitmanYor {
                    alpha: 1.2,
                    d: 0.2,
                    alpha_prior: Gamma::default(),
                    d_prior: Beta::jeffreys(),
                };
                Process::PitmanYor(pitman_yor)
            }
        }
    }
}
/// Computes the unnormalized log posterior of every partition of the rows.
///
/// The row count `n` is taken from the length of the first feature. Every
/// partition `z` produced by `Partition::new(n)` is turned into a prior
/// process and a view over `features`. The value `view.score() + ln_pz` is
/// then stored under the key `partition_to_ix(&z)`. Here `ln_pz` is the
/// prior process's `ln_f_partition` of that assignment, presumably the log
/// prior mass of the partition.
///
/// Takes a slice rather than `&Vec` so any contiguous collection of features
/// can be passed. Existing `&Vec<ColModel>` callers still work through deref
/// coercion.
///
/// # Panics
///
/// Panics if `features` is empty, or if building a prior process fails.
fn calc_partition_ln_posterior<R: Rng>(
    features: &[ColModel],
    proc_type: ProcessType,
    mut rng: &mut R,
) -> BTreeMap<u64, f64> {
    let n = features
        .first()
        .expect("calc_partition_ln_posterior requires at least one feature")
        .len();
    let mut ln_posterior: BTreeMap<u64, f64> = BTreeMap::new();
    for z in Partition::new(n) {
        // Compute the key before `z` is moved into the builder.
        let ix = partition_to_ix(&z);
        let prior_process = PriorProcessBuilder::from_vec(z)
            .with_process(proc_type.into())
            .seed_from_rng(&mut rng)
            .build()
            .unwrap();
        let ln_pz = prior_process.ln_f_partition(&prior_process.asgn);
        let view: View = Builder::from_prior_process(prior_process)
            .features(features.to_vec())
            .seed_from_rng(&mut rng)
            .build();
        ln_posterior.insert(ix, view.score() + ln_pz);
    }
    ln_posterior
}
/// Converts log-space, unnormalized weights into probabilities that sum to 1.
///
/// The log normalizer is `logsumexp` over all values. Each entry then becomes
/// `exp(ln_p - log_normalizer)` under its original key.
fn norm_posterior(ln_posterior: &BTreeMap<u64, f64>) -> BTreeMap<u64, f64> {
    let log_normalizer = {
        let ln_values: Vec<f64> = ln_posterior.values().copied().collect();
        ln_values.iter().logsumexp()
    };
    ln_posterior
        .iter()
        .map(|(&ix, &ln_p)| (ix, (ln_p - log_normalizer).exp()))
        .collect()
}
/// Measures how closely a view's sampler reproduces the exact posterior over
/// row partitions.
///
/// Steps:
/// 1. Builds features with `build_features`.
/// 2. Enumerates the exact normalized posterior over every partition of
///    `n_rows` rows.
/// 3. Runs `n_runs` independently constructed views. Each view records
///    `n_iters` samples, and each sample follows one `view.update(10, ..)`
///    call with a row-assignment step (`row_alg`) and a component-parameter
///    step.
///
/// Every recorded partition adds `1 / (n_runs * n_iters)` to the empirical
/// posterior.
///
/// Returns the mean absolute difference between the exact and empirical
/// CDFs, accumulated over the enumerated partitions in ascending key order.
///
/// # Panics
///
/// Panics if a sampled partition is not one of the enumerated partitions, or
/// if building a prior process fails.
pub fn view_enum_test(
    n_rows: usize,
    n_cols: usize,
    n_runs: usize,
    n_iters: usize,
    ftype: FType,
    row_alg: RowAssignAlg,
    proc_type: ProcessType,
) -> f64 {
    let mut rng = rand::rng();
    let features = build_features(n_rows, n_cols, ftype, &mut rng);
    let ln_posterior =
        calc_partition_ln_posterior(&features, proc_type, &mut rng);
    let posterior = norm_posterior(&ln_posterior);

    let transitions: Vec<ViewTransition> = vec![
        ViewTransition::RowAssignment(row_alg),
        ViewTransition::ComponentParams,
    ];

    // Every recorded sample carries equal weight.
    let inc: f64 = ((n_runs * n_iters) as f64).recip();
    let mut est_posterior: BTreeMap<u64, f64> = BTreeMap::new();
    for _ in 0..n_runs {
        // Each run starts a new view from a freshly built prior process.
        let prior_process = PriorProcessBuilder::new(n_rows)
            .with_process(proc_type.into())
            .seed_from_rng(&mut rng)
            .build()
            .unwrap();
        let mut view = Builder::from_prior_process(prior_process)
            .features(features.clone())
            .seed_from_rng(&mut rng)
            .build();
        for _ in 0..n_iters {
            view.update(10, &transitions, &mut rng);
            // `normalize_assignment` presumably relabels clusters canonically
            // so label-permuted copies of a partition share one index.
            let normed = normalize_assignment(view.asgn().asgn.clone());
            let ix = partition_to_ix(&normed);
            assert!(
                posterior.contains_key(&ix),
                "invalid index!\n{:?}\n{:?}",
                view.asgn().asgn,
                normed
            );
            *est_posterior.entry(ix).or_insert(0.0) += inc;
        }
    }
    assert!(est_posterior.keys().all(|k| posterior.contains_key(k)));

    mean_abs_cdf_error(&posterior, &est_posterior)
}

/// Mean absolute difference between the CDF of `posterior` and the CDF of
/// `est_posterior`, evaluated at each key of `posterior` in ascending order.
///
/// Keys missing from `est_posterior` contribute zero estimated mass. The
/// summation order matches a left-to-right scan over `posterior`.
fn mean_abs_cdf_error(
    posterior: &BTreeMap<u64, f64>,
    est_posterior: &BTreeMap<u64, f64>,
) -> f64 {
    let mut cdf = 0.0;
    let mut est_cdf = 0.0;
    let mut total_err = 0.0;
    for (key, &p) in posterior {
        cdf += p;
        if let Some(&est_p) = est_posterior.get(key) {
            est_cdf += est_p;
        }
        total_err += (cdf - est_cdf).abs();
    }
    total_err / posterior.len() as f64
}
/// Retries a stochastic check up to `n_tries` times.
///
/// Returns `true` as soon as one call to `test_fn` returns `true`; the
/// remaining attempts are skipped. Returns `false` when every attempt fails,
/// including when `n_tries` is zero.
fn flaky_test_passes<F>(n_tries: u32, test_fn: F) -> bool
where
    F: Fn() -> bool,
{
    (0..n_tries).any(|_attempt| test_fn())
}
/// Generates `#[test]` functions that run `view_enum_test` for a given
/// feature type, prior process, and row-assignment algorithm.
///
/// Arms:
/// 1. `(FType, ProcessType, RowAssignAlg)`: one test named after the row
///    algorithm.
/// 2. `(module, FType, ProcessType, [RowAssignAlg, ..])`: a module holding
///    one test per listed row algorithm.
/// 3. A comma-separated list of arm-2 tuples.
macro_rules! view_enum_test {
    ($ftype: ident, $proctype: ident, $row_alg: ident) => {
        #[test]
        fn $row_alg() {
            // 4 rows, 1 column, 1 run, 5_000 recorded samples. The attempt
            // passes when the mean absolute CDF error returned by
            // `view_enum_test` is below 0.01.
            fn test_fn() -> bool {
                let err = view_enum_test(
                    4,
                    1,
                    1,
                    5_000,
                    FType::$ftype,
                    RowAssignAlg::$row_alg,
                    ProcessType::$proctype,
                );
                eprintln!("err: {}", err);
                err < 0.01
            }
            // The check is stochastic; allow up to N_TRIES attempts.
            assert!(flaky_test_passes(N_TRIES, test_fn));
        }
    };
    ($modname: ident, $ftype: ident, $proctype: ident, [$($row_alg: ident),+]) => {
        // Test functions are named after `RowAssignAlg` variants (e.g.
        // `Gibbs`), which are not snake_case.
        #[allow(non_snake_case)]
        mod $modname {
            use super::*;
            $(
                view_enum_test!($ftype, $proctype, $row_alg);
            )+
        }
    };
    ($(($modname: ident, $ftype: ident, $proctype: ident, $row_algs: tt)),+) => {
        // Expand each tuple through the module-generating arm above.
        $(
            view_enum_test!($modname, $ftype, $proctype, $row_algs);
        )+
    };
}
// Instantiate the enumeration tests. Each module below covers one
// (feature type, prior process) pair and holds one `#[test]` per listed
// row-assignment algorithm. For example, `ve_count_pyp::Sams` checks the
// `Sams` row sampler on `Count` features under a Pitman-Yor prior.
view_enum_test!(
    (
        ve_continuous_dp,
        Continuous,
        Dirichlet,
        [Gibbs, Slice, Sams]
    ),
    (
        ve_continuous_pyp,
        Continuous,
        PitmanYor,
        [Gibbs, Slice, Sams]
    ),
    (
        ve_categorical_dp,
        Categorical,
        Dirichlet,
        [Gibbs, Slice, Sams]
    ),
    (
        ve_categorical_pyp,
        Categorical,
        PitmanYor,
        [Gibbs, Slice, Sams]
    ),
    (ve_count_dp, Count, Dirichlet, [Gibbs, Slice, Sams]),
    (ve_count_pyp, Count, PitmanYor, [Gibbs, Slice, Sams])
);