use std::path::PathBuf;
use rand::Rng;
use serde::Serialize;
use thiserror::Error;
use crate::cc::alg::ColAssignAlg;
use crate::cc::alg::RowAssignAlg;
use crate::cc::config::StateUpdateConfig;
use crate::cc::state::BuildStateError;
use crate::cc::state::Builder;
use crate::cc::state::State;
use crate::cc::transition::StateTransition;
use crate::codebook::Codebook;
use crate::defaults;
/// Where a benchmark run obtains its starting [`State`].
///
/// Each variant corresponds to one branch of `BencherSetup::gen_state`.
#[derive(Debug, Clone)]
enum BencherSetup {
/// Read the data from the CSV file at `path` and convert it to column
/// models with the help of `codebook`.
Csv {
// Boxed so the variant stores a pointer instead of the codebook inline.
codebook: Box<Codebook>,
// Location of the CSV file that is re-read on every state generation.
path: PathBuf,
},
/// Build the state from a state [`Builder`]; the builder is cloned and
/// re-seeded each time a state is generated.
Builder(Builder),
}
/// Errors that can occur while generating the [`State`] for a benchmark run.
#[derive(Debug, Error)]
pub enum GenerateStateError {
/// Converting the loaded data into column models failed.
#[error("error parsing csv: {0}")]
Parse(#[from] crate::error::DataParseError),
/// Reading the CSV file failed.
#[error("csv error: {0}")]
Read(#[from] crate::codebook::ReadError),
/// The state builder could not build a state.
#[error("error building state: {0}")]
BuildState(#[from] BuildStateError),
}
fn emit_prior_process<R: rand::Rng>(
prior_process: crate::codebook::PriorProcess,
rng: &mut R,
) -> crate::stats::prior_process::Process {
use crate::stats::prior_process::Dirichlet;
use crate::stats::prior_process::PitmanYor;
use crate::stats::prior_process::Process;
match prior_process {
crate::codebook::PriorProcess::Dirichlet { alpha_prior } => {
let inner = Dirichlet::from_prior(alpha_prior, rng);
Process::Dirichlet(inner)
}
crate::codebook::PriorProcess::PitmanYor {
alpha_prior,
d_prior,
} => {
let inner = PitmanYor::from_prior(alpha_prior, d_prior, rng);
Process::PitmanYor(inner)
}
}
}
impl BencherSetup {
    /// Generate a fresh [`State`] from this setup.
    ///
    /// * `Csv`: reads the CSV at `path`, draws the state-level and then the
    ///   view-level prior process from the codebook's specifications (falling
    ///   back to their `Default` when unset), converts the data into column
    ///   models, and initializes the state from the prior. The codebook
    ///   returned by the conversion replaces the stored one.
    /// * `Builder`: clones the builder, seeds the clone with a `u64` drawn
    ///   from `rng`, and builds the state.
    ///
    /// # Errors
    ///
    /// * [`GenerateStateError::Read`] if the CSV cannot be read.
    /// * [`GenerateStateError::Parse`] if the data cannot be converted into
    ///   column models. The stored codebook is left untouched in this case.
    /// * [`GenerateStateError::BuildState`] if the builder fails.
    fn gen_state(
        &mut self,
        mut rng: &mut impl Rng,
    ) -> Result<State, GenerateStateError> {
        match self {
            BencherSetup::Csv { codebook, path } => {
                let df = crate::codebook::data::read_csv(path)?;

                // Draw order matters for reproducibility under a seeded rng:
                // state-level process first, then view-level process.
                let state_prior_process = emit_prior_process(
                    codebook.state_prior_process.clone().unwrap_or_default(),
                    rng,
                );
                let view_prior_process = emit_prior_process(
                    codebook.view_prior_process.clone().unwrap_or_default(),
                    rng,
                );

                // `df_to_col_models` consumes the codebook and does not hand
                // it back on failure, so pass a copy: a parse error must not
                // leave this setup holding an empty placeholder codebook.
                let (updated_codebook, features) = crate::data::df_to_col_models(
                    (**codebook).clone(),
                    df,
                    &mut rng,
                )?;

                // Store the codebook produced by the conversion, reusing the
                // existing box allocation.
                **codebook = updated_codebook;

                Ok(State::from_prior(
                    features,
                    state_prior_process,
                    view_prior_process,
                    &mut rng,
                ))
            }
            BencherSetup::Builder(state_builder) => state_builder
                .clone()
                .seed_from_u64(rng.next_u64())
                .build()
                .map_err(GenerateStateError::BuildState),
        }
    }
}
/// Measurements collected from a single benchmark run (see `Bencher::run_once`).
#[derive(Debug, Clone, Serialize)]
pub struct BencherResult {
/// Elapsed time, in seconds, of each timed `State::update` call, in order.
pub time_sec: Vec<f64>,
/// The state's `diagnostics.loglike` values, taken after the last update.
pub score: Vec<f64>,
}
/// Times repeated `State::update` calls on freshly generated states.
///
/// Construct with `Bencher::from_csv` or `Bencher::from_builder`, adjust with
/// the chained setter methods, then call `run` or `run_once`.
#[derive(Debug, Clone)]
pub struct Bencher {
/// How the starting state of each run is generated.
setup: BencherSetup,
/// Number of runs performed by `run`; each run generates a new state.
n_runs: usize,
/// Number of timed update calls per run.
n_iters: usize,
/// Algorithm substituted into every `ColumnAssignment` transition.
col_asgn_alg: ColAssignAlg,
/// Algorithm substituted into every `RowAssignment` transition.
row_asgn_alg: RowAssignAlg,
/// Optional user-supplied update configuration; when `None`, a default
/// configuration with `n_iters: 1` is used.
config: Option<StateUpdateConfig>,
}
impl Bencher {
#[must_use]
pub fn from_csv(codebook: Codebook, path: PathBuf) -> Self {
Self {
setup: BencherSetup::Csv {
codebook: Box::new(codebook),
path,
},
n_runs: 1,
n_iters: 100,
col_asgn_alg: defaults::COL_ASSIGN_ALG,
row_asgn_alg: defaults::ROW_ASSIGN_ALG,
config: None,
}
}
#[must_use]
pub fn from_builder(state_builder: Builder) -> Self {
Self {
setup: BencherSetup::Builder(state_builder),
n_runs: 1,
n_iters: 100,
col_asgn_alg: defaults::COL_ASSIGN_ALG,
row_asgn_alg: defaults::ROW_ASSIGN_ALG,
config: None,
}
}
#[must_use]
pub fn n_runs(mut self, n_runs: usize) -> Self {
self.n_runs = n_runs;
self
}
#[must_use]
pub fn n_iters(mut self, n_iters: usize) -> Self {
self.n_iters = n_iters;
self
}
#[must_use]
pub fn row_assign_alg(mut self, alg: RowAssignAlg) -> Self {
self.row_asgn_alg = alg;
self
}
#[must_use]
pub fn col_assign_alg(mut self, alg: ColAssignAlg) -> Self {
self.col_asgn_alg = alg;
self
}
#[must_use]
pub fn update_config(mut self, config: StateUpdateConfig) -> Self {
config.transitions.iter().for_each(|&t| {
if let StateTransition::ColumnAssignment(alg) = t {
self.col_asgn_alg = alg;
}
});
config.transitions.iter().for_each(|&t| {
if let StateTransition::RowAssignment(alg) = t {
self.row_asgn_alg = alg;
}
});
self.config = Some(config);
self
}
fn state_config(&self) -> StateUpdateConfig {
let mut config = match self.config {
Some(ref config) => config.clone(),
None => StateUpdateConfig {
n_iters: 1,
..Default::default()
},
};
let transitions = config
.transitions
.iter()
.map(|t| match t {
StateTransition::RowAssignment(_) => {
StateTransition::RowAssignment(self.row_asgn_alg)
}
StateTransition::ColumnAssignment(_) => {
StateTransition::ColumnAssignment(self.col_asgn_alg)
}
_ => *t,
})
.collect::<Vec<_>>();
config.transitions = transitions;
config
}
pub fn run_once(&mut self, mut rng: &mut impl Rng) -> BencherResult {
use std::time::Instant;
let mut state: State = self.setup.gen_state(&mut rng).unwrap();
let config = self.state_config();
let time_sec: Vec<f64> = (0..self.n_iters)
.map(|_| {
let start = Instant::now();
state.update(config.clone(), &mut rng);
start.elapsed().as_secs_f64()
})
.collect();
BencherResult {
time_sec,
score: state.diagnostics.loglike,
}
}
pub fn run(&mut self, mut rng: &mut impl Rng) -> Vec<BencherResult> {
(0..self.n_runs).map(|_| self.run_once(&mut rng)).collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codebook::ColType;

    // Dimensions of the benchmark exercised by the test below.
    const N_RUNS: usize = 5;
    const N_ITERS: usize = 17;

    /// A small, fast bencher: 5 continuous columns, 50 rows.
    fn quick_bencher() -> Bencher {
        let col_type = ColType::Continuous {
            hyper: None,
            prior: None,
        };
        let builder = Builder::new().column_configs(5, col_type).n_rows(50);
        Bencher::from_builder(builder)
            .n_runs(N_RUNS)
            .n_iters(N_ITERS)
    }

    #[test]
    fn bencher_from_state_builder_should_return_properly_sized_result() {
        let mut bencher = quick_bencher();
        let mut rng = rand::rng();
        let results = bencher.run(&mut rng);

        // One result per run.
        assert_eq!(results.len(), N_RUNS);

        // Every run records one timing and one score per iteration.
        for result in &results {
            assert_eq!(result.time_sec.len(), N_ITERS);
        }
        for result in &results {
            assert_eq!(result.score.len(), N_ITERS);
        }
    }
}