use rand::Rng;
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use rv::dist::Categorical;
use rv::dist::Gamma;
use rv::dist::Gaussian;
use rv::dist::NormalInvChiSquared;
use rv::dist::Poisson;
use rv::traits::*;
use thiserror::Error;
use crate::cc::feature::ColModel;
use crate::cc::feature::Column;
use crate::cc::feature::Feature;
use crate::cc::state::State;
use crate::codebook::ColType;
use crate::data::SparseContainer;
use crate::stats::prior::csd::CsdHyper;
use crate::stats::prior::nix::NixHyper;
use crate::stats::prior::pg::PgHyper;
use crate::stats::prior_process::Builder as AssignmentBuilder;
use crate::stats::prior_process::Process;
/// Builder for a randomly initialized [`State`].
///
/// The state's columns come from exactly one of two sources: `col_configs`
/// (column types for which synthetic data is generated) or `ftrs` (already
/// built features). Supplying both or neither makes [`Builder::build`] return
/// a [`BuildStateError`]. Every other option is optional and falls back to a
/// default inside [`Builder::build`].
#[derive(Debug, Clone, Default)]
pub struct Builder {
    /// Number of rows; defaults to 100. When features are supplied, the length
    /// of the first feature takes precedence over this value.
    pub n_rows: Option<usize>,
    /// Number of views the columns are split across; defaults to 1.
    pub n_views: Option<usize>,
    /// Number of categories passed to each view's row-assignment builder, and
    /// the number of components used for generated categorical columns;
    /// defaults to 1.
    pub n_cats: Option<usize>,
    /// Column types to generate synthetic data for (exclusive with `ftrs`).
    pub col_configs: Option<Vec<ColType>>,
    /// Pre-built features to use as columns (exclusive with `col_configs`).
    pub ftrs: Option<Vec<ColModel>>,
    /// Seed for the builder's RNG; when `None` the RNG is seeded from the OS.
    pub seed: Option<u64>,
    /// Prior process for the column-to-view assignment; when `None` a
    /// Dirichlet process built from `crate::consts::state_alpha_prior()` is
    /// used.
    pub prior_process: Option<Process>,
}
/// Errors returned by [`Builder::build`] when the column source is ambiguous
/// or missing.
#[derive(Debug, Error, PartialEq)]
pub enum BuildStateError {
    /// Both column configs and pre-built features were supplied; only one
    /// column source is allowed.
    #[error("Supply either features or column configs; not both")]
    BothColumnConfigsAndFeaturesPresent,
    /// Neither column configs nor pre-built features were supplied.
    #[error("No column configs or features supplied")]
    NeitherColumnConfigsAndFeaturesPresent,
}
impl Builder {
    /// Creates a builder with every option unset.
    pub fn new() -> Self {
        Builder::default()
    }

    /// Sets the number of rows used when generating columns from configs.
    ///
    /// Ignored when non-empty features are supplied: their length wins.
    #[must_use]
    pub fn n_rows(mut self, n_rows: usize) -> Self {
        self.n_rows = Some(n_rows);
        self
    }

    /// Sets the number of views the columns are split across.
    #[must_use]
    pub fn n_views(mut self, n_views: usize) -> Self {
        self.n_views = Some(n_views);
        self
    }

    /// Sets the number of categories used for each view's row assignment and
    /// for the components of generated categorical columns.
    #[must_use]
    pub fn n_cats(mut self, n_cats: usize) -> Self {
        self.n_cats = Some(n_cats);
        self
    }

    /// Uses `ftrs` as the state's columns instead of generating them.
    ///
    /// Replaces any previously supplied features. Must not be combined with
    /// [`Builder::column_config`] / [`Builder::column_configs`].
    #[must_use]
    pub fn features(mut self, ftrs: Vec<ColModel>) -> Self {
        self.ftrs = Some(ftrs);
        self
    }

    /// Appends one column config; its data is generated at build time.
    #[must_use]
    pub fn column_config(mut self, col_config: ColType) -> Self {
        self.col_configs
            .get_or_insert_with(Vec::new)
            .push(col_config);
        self
    }

    /// Appends `n` copies of `col_config`; their data is generated at build
    /// time.
    #[must_use]
    pub fn column_configs(mut self, n: usize, col_config: ColType) -> Self {
        let col_configs = self.col_configs.get_or_insert_with(Vec::new);
        let new_len = col_configs.len() + n;
        col_configs.resize(new_len, col_config);
        self
    }

    /// Seeds the builder with the next `u64` drawn from `rng`.
    #[must_use]
    pub fn seed_from_rng<R: Rng>(mut self, rng: &mut R) -> Self {
        self.seed = Some(rng.next_u64());
        self
    }

    /// Seeds the builder with `seed`.
    #[must_use]
    pub fn seed_from_u64(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets the prior process for the column-to-view assignment.
    #[must_use]
    pub fn prior_process(mut self, process: Process) -> Self {
        self.prior_process = Some(process);
        self
    }

    /// Builds the [`State`].
    ///
    /// Defaults:
    /// - rows: 100, or the length of the first supplied feature.
    /// - views: 1.
    /// - categories: 1.
    ///
    /// Columns are split across views in order. Each view gets
    /// `n_cols / n_views` columns, and the last view also takes the remainder.
    /// When no prior process was supplied, a Dirichlet process built from
    /// `crate::consts::state_alpha_prior()` drives the column assignment.
    ///
    /// # Errors
    /// - [`BuildStateError::BothColumnConfigsAndFeaturesPresent`] if both
    ///   column configs and features were supplied.
    /// - [`BuildStateError::NeitherColumnConfigsAndFeaturesPresent`] if
    ///   neither was supplied.
    ///
    /// # Panics
    /// - If `n_views` is 0.
    /// - If a per-view assignment builder rejects `n_cats` or fails to build.
    pub fn build(self) -> Result<State, BuildStateError> {
        let mut rng = match self.seed {
            Some(seed) => Xoshiro256Plus::seed_from_u64(seed),
            None => Xoshiro256Plus::from_os_rng(),
        };

        // Supplied features dictate the row count. An empty feature list has
        // no rows to inspect, so fall back to the configured/default count.
        let n_rows = self
            .ftrs
            .as_ref()
            .and_then(|ftrs| ftrs.first())
            .map(|ftr| ftr.len())
            .or(self.n_rows)
            .unwrap_or(100);
        let n_views = self.n_views.unwrap_or(1);
        let n_cats = self.n_cats.unwrap_or(1);

        // Validate the column source and take ownership of it in one step.
        let mut ftrs: Vec<ColModel> = match (self.col_configs, self.ftrs) {
            (Some(_), Some(_)) => {
                return Err(BuildStateError::BothColumnConfigsAndFeaturesPresent);
            }
            (None, None) => {
                return Err(
                    BuildStateError::NeitherColumnConfigsAndFeaturesPresent,
                );
            }
            (Some(col_configs), None) => col_configs
                .into_iter()
                .enumerate()
                .map(|(id, col_config)| {
                    gen_feature(id, col_config, n_rows, n_cats, &mut rng)
                })
                .collect(),
            (None, Some(ftrs)) => ftrs,
        };

        assert!(n_views > 0, "n_views must be at least 1");
        let ftrs_per_view = ftrs.len() / n_views;
        let mut col_asgn: Vec<usize> = Vec::with_capacity(ftrs.len());

        // NOTE: the RNG is consumed in a fixed order (assignment, then view,
        // per view) so seeded builds stay reproducible; do not reorder.
        let views = (0..n_views)
            .map(|view_ix| {
                // The last view absorbs the remainder of the integer division.
                let to_drain = if view_ix == n_views - 1 {
                    ftrs.len()
                } else {
                    ftrs_per_view
                };
                col_asgn.resize(col_asgn.len() + to_drain, view_ix);
                let ftrs_view = ftrs.drain(0..to_drain).collect();
                let prior_process = AssignmentBuilder::new(n_rows)
                    .with_n_cats(n_cats)
                    .expect("invalid n_cats for view row assignment")
                    .seed_from_rng(&mut rng)
                    .build()
                    .expect("failed to build view row assignment");
                crate::cc::view::Builder::from_prior_process(prior_process)
                    .features(ftrs_view)
                    .seed_from_rng(&mut rng)
                    .build()
            })
            .collect();
        assert!(ftrs.is_empty(), "every feature must be assigned to a view");

        let process = self.prior_process.unwrap_or_else(|| {
            Process::Dirichlet(
                crate::stats::prior_process::Dirichlet::from_prior(
                    crate::consts::state_alpha_prior(),
                    &mut rng,
                ),
            )
        });
        let process = AssignmentBuilder::from_vec(col_asgn)
            .seed_from_rng(&mut rng)
            .with_process(process)
            .build()
            .expect("failed to build column assignment");
        Ok(State::new(views, process))
    }
}
/// Generates a column of `n_rows` synthetic values of the given column type.
///
/// - Continuous: `Gaussian::standard()` draws, with a fixed
///   `NormalInvChiSquared(0, 1, 4, 4)` prior and a default `NixHyper`.
/// - Count: `Poisson(1)` draws, with a fixed `Gamma(1, 1)` prior and a default
///   `PgHyper`.
/// - Categorical: a prior is drawn from `CsdHyper::vague(k)`, then `n_cats`
///   component distributions are drawn from that prior. Row `i` is sampled from
///   component `i % n_cats`.
///
/// Any `hyper`/`prior` carried inside `col_config` is ignored.
///
/// # Panics
/// For a categorical column when `n_cats == 0` and `n_rows > 0`, because the
/// component index is taken modulo zero.
fn gen_feature<R: Rng>(
    id: usize,
    col_config: ColType,
    n_rows: usize,
    n_cats: usize,
    rng: &mut R,
) -> ColModel {
    match col_config {
        ColType::Continuous { .. } => {
            let values: Vec<f64> = Gaussian::standard().sample(n_rows, rng);
            ColModel::Continuous(Column::new(
                id,
                SparseContainer::from(values),
                NormalInvChiSquared::new_unchecked(0.0, 1.0, 4.0, 4.0),
                NixHyper::default(),
            ))
        }
        ColType::Count { .. } => {
            let values: Vec<u32> =
                Poisson::new_unchecked(1.0).sample(n_rows, rng);
            ColModel::Count(Column::new(
                id,
                SparseContainer::from(values),
                Gamma::new_unchecked(1.0, 1.0),
                PgHyper::default(),
            ))
        }
        ColType::Categorical { k, .. } => {
            let hyper = CsdHyper::vague(k);
            // Draw order (prior, then components, then rows) is what seeded
            // builds reproduce; keep it.
            let prior = hyper.draw(k, rng);

            let mut components: Vec<Categorical> = Vec::with_capacity(n_cats);
            for _ in 0..n_cats {
                let component: Categorical = prior.draw(rng);
                components.push(component);
            }

            let mut values: Vec<u32> = Vec::with_capacity(n_rows);
            for row_ix in 0..n_rows {
                let component = &components[row_ix % n_cats];
                let value: u32 = component.draw(rng);
                values.push(value);
            }

            ColModel::Categorical(Column::new(
                id,
                SparseContainer::from(values),
                prior,
                hyper,
            ))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cc::config::StateUpdateConfig;

    /// Builder with `n_cols` default continuous columns over `n_rows` rows.
    fn continuous_builder(n_cols: usize, n_rows: usize) -> Builder {
        Builder::new()
            .column_configs(
                n_cols,
                ColType::Continuous {
                    hyper: None,
                    prior: None,
                },
            )
            .n_rows(n_rows)
    }

    #[test]
    fn test_dimensions() {
        let state = continuous_builder(10, 50)
            .build()
            .expect("Failed to build state");
        assert_eq!(state.n_rows(), 50);
        assert_eq!(state.n_cols(), 10);
    }

    #[test]
    fn built_state_should_update() {
        let mut rng = Xoshiro256Plus::from_os_rng();
        let mut state = continuous_builder(10, 50)
            .seed_from_rng(&mut rng)
            .build()
            .expect("Failed to build state");
        let config = StateUpdateConfig {
            n_iters: 5,
            ..Default::default()
        };
        state.update(config, &mut rng);
    }

    #[test]
    fn seeding_state_works() {
        let build_seeded = || {
            let mut rng = Xoshiro256Plus::seed_from_u64(122_445);
            continuous_builder(10, 50)
                .seed_from_rng(&mut rng)
                .build()
                .expect("Failed to build state")
        };
        let state_1 = build_seeded();
        let state_2 = build_seeded();
        assert_eq!(state_1.asgn().asgn, state_2.asgn().asgn);
        for (view_1, view_2) in state_1.views.iter().zip(state_2.views.iter()) {
            assert_eq!(view_1.asgn().asgn, view_2.asgn().asgn);
        }
    }

    #[test]
    fn n_rows_overridden_by_features() {
        let n_cols = 5;
        let col_models = {
            let state = continuous_builder(n_cols, 11).build().unwrap();
            (0..n_cols)
                .map(|ix| state.feature(ix))
                .cloned()
                .collect::<Vec<_>>()
        };
        let state = Builder::new()
            .features(col_models)
            .n_rows(101)
            .build()
            .unwrap();
        assert_eq!(state.n_rows(), 11);
    }

    #[test]
    fn build_without_columns_or_features_errors() {
        let result = Builder::new().n_rows(10).build();
        assert_eq!(
            result.err(),
            Some(BuildStateError::NeitherColumnConfigsAndFeaturesPresent)
        );
    }

    #[test]
    fn build_with_columns_and_features_errors() {
        let state = continuous_builder(2, 10).build().unwrap();
        let col_models = vec![state.feature(0).clone()];
        let result = continuous_builder(2, 10).features(col_models).build();
        assert_eq!(
            result.err(),
            Some(BuildStateError::BothColumnConfigsAndFeaturesPresent)
        );
    }
}