use crate::stats::{collect_rhat, ChainStats, ChainTracker, RunStats};
use indicatif::ProgressBar;
use indicatif::{MultiProgress, ProgressStyle};
use ndarray::stack;
use ndarray::{prelude::*, LinalgScalar, ShapeError};
use ndarray_stats::QuantileExt;
use num_traits::{Float, FromPrimitive};
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rand_distr::{Distribution, StandardNormal};
use rayon::prelude::*;
use std::cmp::PartialEq;
use std::error::Error;
use std::marker::Send;
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread::{self};
use std::time::{Duration, Instant};
pub trait MarkovChain<T> {
fn step(&mut self) -> &Vec<T>;
fn current_state(&self) -> &Vec<T>;
}
pub fn run_chain<T, M>(chain: &mut M, n_collect: usize, n_discard: usize) -> Array2<T>
where
M: MarkovChain<T>,
T: LinalgScalar,
{
let dim = chain.current_state().len();
let mut out = Array2::<T>::zeros((n_collect, dim));
let total = n_collect + n_discard;
for i in 0..total {
let state = chain.step();
if i >= n_discard {
let state_arr = ArrayView::from_shape(state.len(), state.as_slice()).unwrap();
out.row_mut(i - n_discard).assign(&state_arr);
}
}
out
}
pub fn run_chain_progress<T, M>(
chain: &mut M,
n_collect: usize,
n_discard: usize,
tx: Sender<ChainStats>,
) -> Result<Array2<T>, String>
where
M: MarkovChain<T>,
T: LinalgScalar + PartialEq + num_traits::ToPrimitive,
{
let n_params = chain.current_state().len();
let mut out = Array2::<T>::zeros((n_collect, n_params));
let mut tracker = ChainTracker::new(n_params, chain.current_state());
let mut last = Instant::now();
let freq = Duration::from_secs(1);
let total = n_discard + n_collect;
for i in 0..total {
let current_state = chain.step();
tracker.step(current_state).map_err(|e| {
let msg = format!(
"Chain statistics tracker caused error: {}.\nAborting generation of further observations.",
e
);
println!("{}", msg);
msg
})?;
let now = Instant::now();
if (now >= last + freq) | (i == total - 1) {
if let Err(e) = tx.send(tracker.stats()) {
eprintln!("Sending chain statistics failed: {e}");
}
last = now;
}
if i >= n_discard {
out.row_mut(i - n_discard).assign(
&ArrayView1::from_shape(current_state.len(), current_state.as_slice()).unwrap(),
);
}
}
Ok(out)
}
pub trait HasChains<S> {
type Chain: MarkovChain<S> + Send;
fn chains_mut(&mut self) -> &mut Vec<Self::Chain>;
}
pub trait ChainRunner<T>: HasChains<T>
where
T: LinalgScalar + PartialEq + Send + num_traits::ToPrimitive,
{
fn run(&mut self, n_collect: usize, n_discard: usize) -> Result<Array3<T>, ShapeError> {
let results: Vec<Array2<T>> = self
.chains_mut()
.par_iter_mut()
.map(|chain| run_chain(chain, n_collect, n_discard))
.collect();
let views: Vec<ArrayView2<T>> = results.iter().map(|x| x.view()).collect();
let out: Array3<T> = stack(Axis(0), &views)?;
Ok(out)
}
fn run_progress(
&mut self,
n_collect: usize,
n_discard: usize,
) -> Result<(Array3<T>, RunStats), Box<dyn Error>> {
let chains = self.chains_mut();
let mut rxs: Vec<Receiver<ChainStats>> = vec![];
let mut txs: Vec<Sender<ChainStats>> = vec![];
(0..chains.len()).for_each(|_| {
let (tx, rx) = mpsc::channel();
rxs.push(rx);
txs.push(tx);
});
let progress_handle = thread::spawn(move || {
let sleep_ms = Duration::from_millis(250);
let timeout_ms = Duration::from_millis(0);
let multi = MultiProgress::new();
let pb_style = ProgressStyle::default_bar()
.template("{prefix:8} {bar:40.cyan/blue} {pos}/{len} ({eta}) | {msg}")
.unwrap()
.progress_chars("=>-");
let total: u64 = (n_collect + n_discard).try_into().unwrap();
let global_pb = multi.add(ProgressBar::new((rxs.len() as u64) * total));
global_pb.set_style(pb_style.clone());
global_pb.set_prefix("Global");
let mut active: Vec<(usize, ProgressBar)> = (0..rxs.len().min(5))
.map(|chain_idx| {
let pb = multi.add(ProgressBar::new(total));
pb.set_style(pb_style.clone());
pb.set_prefix(format!("Chain {chain_idx}"));
(chain_idx, pb)
})
.collect();
let mut next_active = active.len();
let mut n_finished = 0;
let mut most_recent = vec![None; rxs.len()];
let mut total_progress;
loop {
for (i, rx) in rxs.iter().enumerate() {
while let Ok(stats) = rx.recv_timeout(timeout_ms) {
most_recent[i] = Some(stats)
}
}
let mut to_replace = vec![false; active.len()];
let mut avg_p_accept = 0.0;
let mut n_available_stats = 0.0;
for (vec_idx, (i, pb)) in active.iter().enumerate() {
if let Some(stats) = &most_recent[*i] {
pb.set_position(stats.n);
pb.set_message(format!("p(accept)≈{:.2}", stats.p_accept));
avg_p_accept += stats.p_accept;
n_available_stats += 1.0;
if stats.n == total {
to_replace[vec_idx] = true;
n_finished += 1;
}
}
}
avg_p_accept /= n_available_stats;
total_progress = 0;
for stats in most_recent.iter().flatten() {
total_progress += stats.n;
}
global_pb.set_position(total_progress);
let valid: Vec<&ChainStats> = most_recent.iter().flatten().collect();
if valid.len() >= 2 {
let rhats = collect_rhat(valid.as_slice());
let max = rhats.max_skipnan();
global_pb.set_message(format!(
"p(accept)≈{:.2} max(rhat)≈{:.2}",
avg_p_accept, max
))
}
let mut to_remove = vec![];
for (i, replace) in to_replace.iter().enumerate() {
if *replace && next_active < most_recent.len() {
let pb = multi.add(ProgressBar::new(total));
pb.set_style(pb_style.clone());
pb.set_prefix(format!("Chain {next_active}"));
active[i] = (next_active, pb);
next_active += 1;
} else if *replace {
to_remove.push(i);
}
}
to_remove.sort();
for i in to_remove.iter().rev() {
active.remove(*i);
}
if n_finished >= most_recent.len() {
break;
}
std::thread::sleep(sleep_ms);
}
});
let chain_sample: Vec<Array2<T>> = thread::scope(|s| {
let handles: Vec<thread::ScopedJoinHandle<Array2<T>>> = chains
.iter_mut()
.zip(txs)
.map(|(chain, tx)| {
s.spawn(|| {
run_chain_progress(chain, n_collect, n_discard, tx)
.expect("Expected running chain to succeed.")
})
})
.collect();
handles
.into_iter()
.map(|h| {
h.join()
.expect("Expected thread to succeed in generating observation.")
})
.collect()
});
let sample: Array3<T> = stack(
Axis(0),
&chain_sample
.iter()
.map(|x| x.view())
.collect::<Vec<ArrayView2<T>>>(),
)?;
if let Err(e) = progress_handle.join() {
eprintln!("Progress bar thread emitted error message: {:?}", e);
}
let run_stats = RunStats::from(sample.view());
Ok((sample, run_stats))
}
}
impl<T: LinalgScalar + Send + PartialEq + num_traits::ToPrimitive, R: HasChains<T>> ChainRunner<T>
for R
{
}
pub fn init<T>(n: usize, d: usize) -> Vec<Vec<T>>
where
T: Float + FromPrimitive,
{
let rng = SmallRng::from_os_rng();
_init(n, d, rng)
}
pub fn init_det<T>(n: usize, d: usize) -> Vec<Vec<T>>
where
T: Float + FromPrimitive,
{
init_with_seed(n, d, 42)
}
pub fn init_with_seed<T>(n: usize, d: usize, seed: u64) -> Vec<Vec<T>>
where
T: Float + FromPrimitive,
{
let rng = SmallRng::seed_from_u64(seed);
_init(n, d, rng)
}
fn _init<T>(n: usize, d: usize, mut rng: SmallRng) -> Vec<Vec<T>>
where
T: Float + FromPrimitive,
{
(0..n)
.map(|_| {
(0..d)
.map(|_| {
let obs: f64 = StandardNormal.sample(&mut rng);
T::from_f64(obs).unwrap()
})
.collect()
})
.collect()
}