use num_traits::Float;
use rand::distr::Distribution as RandDistribution;
use rand::prelude::*;
use rand_distr::StandardUniform;
use std::marker::{PhantomData, Send};
use crate::core::{HasChains, MarkovChain};
use crate::distributions::{Proposal, Target};
/// Metropolis–Hastings sampler that owns several chains, each built from the
/// same target and proposal.
///
/// The sampler keeps its own copy of `target` and `proposal`; every chain in
/// `chains` holds an independent clone of both (see `MetropolisHastings::new`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MetropolisHastings<S, T, D, Q>
where
    S: Clone,
    T: Float,
    D: Clone,
    Q: Clone,
{
    /// Target distribution handed to `new`.
    pub target: D,
    /// Proposal distribution handed to `new`.
    pub proposal: Q,
    /// One chain per initial state passed to `new`.
    pub chains: Vec<MHMarkovChain<S, T, D, Q>>,
}
/// A single Metropolis–Hastings chain: its own target/proposal copies, its
/// current position, and the RNG used for the accept/reject draw.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MHMarkovChain<S, T, D, Q> {
    /// Target whose `unnorm_logp` is evaluated in `step`.
    pub target: D,
    /// Proposal used to draw candidates and their log-densities in `step`.
    pub proposal: Q,
    /// Current position; overwritten whenever a candidate is accepted.
    pub current_state: Vec<S>,
    /// Source of the uniform draw for the acceptance test.
    pub rng: SmallRng,
    /// `T` (the log-density scalar type) appears only in trait bounds, so it
    /// is anchored here.
    phantom: PhantomData<T>,
}
impl<S, T, D, Q> MetropolisHastings<S, T, D, Q>
where
    D: Target<S, T> + std::clone::Clone + Send,
    Q: Proposal<S, T> + std::clone::Clone + Send,
    T: Float + Send,
    S: Clone + std::cmp::PartialEq + Send + num_traits::Zero + std::fmt::Debug + 'static,
{
    /// Creates a sampler with one chain per entry of `initial_states`.
    ///
    /// Each chain receives its own clone of `target` and `proposal`. Chain RNGs
    /// are seeded non-deterministically by [`MHMarkovChain::new`]; call
    /// [`seed`](Self::seed) afterwards for a reproducible run.
    pub fn new(target: D, proposal: Q, initial_states: Vec<Vec<S>>) -> Self {
        let chains = initial_states
            .into_iter()
            .map(|s| MHMarkovChain::new(target.clone(), proposal.clone(), s))
            .collect();
        Self {
            target,
            proposal,
            chains,
        }
    }

    /// Reseeds every chain deterministically from `seed` and returns the sampler.
    ///
    /// Chain `i` seeds its acceptance RNG with `seed + 1 + i`, wrapping on
    /// overflow. Its proposal is reseeded via `set_seed` with that value plus
    /// the constant `0x9E3779B97F4A7C15`, so the proposal and the acceptance RNG
    /// never receive the same seed.
    pub fn seed(mut self, seed: u64) -> Self {
        for (i, chain) in self.chains.iter_mut().enumerate() {
            // Wrapping arithmetic: plain `1 + seed + i` panics in debug builds
            // when `seed` is close to `u64::MAX` (e.g. a randomly drawn seed).
            // Release builds already wrapped, so seeds are unchanged there.
            let chain_seed = seed.wrapping_add(1).wrapping_add(i as u64);
            chain.rng = SmallRng::seed_from_u64(chain_seed);
            let proposal_seed = chain_seed.wrapping_add(0x9E3779B97F4A7C15);
            chain.proposal = chain.proposal.clone().set_seed(proposal_seed);
        }
        self
    }
}
/// Exposes the sampler's chains through `HasChains`, presumably so the
/// generic runners in `crate::core` (e.g. `ChainRunner`) can step them.
impl<S, T, D, Q> HasChains<Vec<S>> for MetropolisHastings<S, T, D, Q>
where
    S: Clone + PartialEq + Send + num_traits::Zero + std::fmt::Debug + 'static,
    T: Float + Send,
    D: Target<S, T> + Clone + Send,
    Q: Proposal<S, T> + Clone + Send,
    StandardUniform: RandDistribution<T>,
{
    type Chain = MHMarkovChain<S, T, D, Q>;

    /// Mutable access to all chains, in the order they were created.
    fn chains_mut(&mut self) -> &mut Vec<Self::Chain> {
        &mut self.chains
    }
}
impl<S, T, D, Q> MHMarkovChain<S, T, D, Q>
where
    S: Clone + std::cmp::PartialEq + num_traits::Zero,
    T: Float,
    D: Target<S, T> + Clone,
    Q: Proposal<S, T> + Clone,
{
    /// Builds one chain positioned at `initial_state`.
    ///
    /// The acceptance RNG is a `SmallRng` seeded with a `u64` drawn from
    /// `rand::rng()`, so chains built here differ between runs unless they are
    /// reseeded (e.g. via `MetropolisHastings::seed`).
    pub fn new(target: D, proposal: Q, initial_state: Vec<S>) -> Self {
        let entropy: u64 = rand::rng().random();
        Self {
            rng: SmallRng::seed_from_u64(entropy),
            current_state: initial_state,
            target,
            proposal,
            phantom: PhantomData,
        }
    }
}
/// Metropolis–Hastings transition kernel for a single chain.
impl<S, T, D, Q> MarkovChain<Vec<S>> for MHMarkovChain<S, T, D, Q>
where
    S: Clone + PartialEq + num_traits::Zero,
    T: Float,
    D: Target<S, T> + Clone,
    Q: Proposal<S, T> + Clone,
    StandardUniform: RandDistribution<T>,
{
    /// Performs one Metropolis–Hastings update and returns the resulting state.
    ///
    /// A candidate `x'` is drawn from the proposal at the current state `x`.
    /// The log acceptance ratio is
    /// `(unnorm_logp(x') + logp(x', x)) - (unnorm_logp(x) + logp(x, x'))`.
    /// The candidate replaces `x` when that ratio exceeds `ln(u)`, where `u`
    /// is drawn from `StandardUniform`. A NaN ratio compares false and
    /// therefore rejects. The returned state is unchanged on rejection.
    fn step(&mut self) -> &Vec<S> {
        // Call order matters: the proposal may carry RNG state.
        let candidate: Vec<S> = self.proposal.sample(&self.current_state);
        let log_target_here = self.target.unnorm_logp(&self.current_state);
        let log_target_there = self.target.unnorm_logp(&candidate);
        let log_q_forward = self.proposal.logp(&self.current_state, &candidate);
        let log_q_backward = self.proposal.logp(&candidate, &self.current_state);
        let log_alpha = (log_target_there + log_q_backward) - (log_target_here + log_q_forward);
        let ln_u = self.rng.random::<T>().ln();
        if ln_u < log_alpha {
            self.current_state = candidate;
        }
        &self.current_state
    }

    /// The chain's current position.
    fn current_state(&self) -> &Vec<S> {
        &self.current_state
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChainRunner, init_det};
    use crate::distributions::{Gaussian2D, IsotropicGaussian};
    use crate::stats::{RunStats, basic_stats, split_rhat_mean_ess};
    use approx::assert_abs_diff_eq;
    use ndarray::{Array3, Axis, arr1, arr2};
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    /// Samples a correlated 2-D Gaussian with `n_chains` seeded chains.
    ///
    /// `sample_size` is the total number of kept draws across all chains and
    /// must be a positive multiple of `n_chains`. With `use_progress` the run
    /// goes through `run_progress`; otherwise it uses `run` followed by
    /// `RunStats::from`. With more than one chain, the pooled sample mean and
    /// covariance are compared against the target's.
    fn run_gaussian_2d_test(sample_size: usize, n_chains: usize, use_progress: bool) {
        const BURNIN: usize = 500;
        const SEED: u64 = 42;
        assert!(n_chains > 0 && sample_size > 0 && sample_size.is_multiple_of(n_chains));
        let target = Gaussian2D {
            mean: arr1(&[0.0, 1.0]),
            cov: arr2(&[[4.0, 2.0], [2.0, 3.0]]),
        };
        let proposal = IsotropicGaussian::new(1.0).set_seed(SEED);
        let mut mh =
            MetropolisHastings::new(target.clone(), proposal, init_det(n_chains, 2)).seed(SEED);
        let (sample, _stats) = if use_progress {
            mh.run_progress(sample_size / n_chains, BURNIN).unwrap()
        } else {
            let sample = mh.run(sample_size / n_chains, BURNIN).unwrap();
            let stats = RunStats::from(sample.view());
            (sample, stats)
        };
        assert_eq!(sample.shape(), [n_chains, sample_size / n_chains, 2]);
        // With a single chain only the output shape is checked.
        if n_chains <= 1 {
            return;
        }
        // Pool every chain's draws into one (sample_size, 2) matrix.
        let stacked = sample
            .into_shape_with_order((sample_size, 2))
            .expect("Failed to reshape sample");
        let mean = stacked.mean_axis(Axis(0)).unwrap();
        let centered = &stacked - &mean;
        // Sample covariance with the n - 1 denominator.
        let cov = centered.t().dot(&centered) / (stacked.nrows() as f64 - 1.0);
        assert_abs_diff_eq!(mean, target.mean, epsilon = 0.3);
        assert_abs_diff_eq!(cov, target.cov, epsilon = 0.5);
    }

    #[test]
    fn test_single_1_chain() {
        run_gaussian_2d_test(100, 1, false);
    }

    #[test]
    fn test_3_chains() {
        run_gaussian_2d_test(6000, 3, false);
    }

    #[test]
    fn test_progress_1_chain() {
        run_gaussian_2d_test(100, 1, true);
    }

    #[test]
    fn test_progress_3_chains() {
        run_gaussian_2d_test(6000, 3, true);
    }

    #[test]
    #[ignore = "Slow test: run only when explicitly requested"]
    fn test_16_chains_long() {
        run_gaussian_2d_test(80_000_000, 16, false);
    }

    #[test]
    #[ignore = "Slow test: run only when explicitly requested"]
    fn test_progress_16_chains_long() {
        run_gaussian_2d_test(80_000_000, 16, true);
    }

    /// Repeats a freshly seeded 3-chain run `n_runs` times. It then checks that
    /// the mean and standard deviation of the per-coordinate ESS fall within
    /// fixed bands.
    #[test]
    #[ignore = "Benchmark test: run only when explicitly requested"]
    fn test_mean_ess_2d_gaussian() {
        let n_runs = 100;
        let n_chains = 3;
        let burn_in = 500_usize;
        let sample_size_chain = 1500_usize;
        let collected = sample_size_chain - burn_in;
        let mut ess_x1s = Vec::with_capacity(n_runs);
        let mut ess_x2s = Vec::with_capacity(n_runs);
        let mut outer_rng = SmallRng::seed_from_u64(42);
        for _ in 0..n_runs {
            let target = Gaussian2D {
                mean: arr1(&[0.0, 1.0]),
                cov: arr2(&[[4.0, 2.0], [2.0, 3.0]]),
            };
            let proposal = IsotropicGaussian::new(1.0);
            let mut mh = MetropolisHastings::new(target, proposal, init_det(n_chains, 2));
            let run_seed: u64 = outer_rng.random();
            mh = mh.seed(run_seed);
            let sample = mh.run(collected, burn_in).expect("MH run failed");
            assert_eq!(sample.shape(), &[n_chains, collected, 2]);
            // `split_rhat_mean_ess` is called with an f32 view here.
            let mut sample_f32 = Array3::<f32>::zeros((n_chains, collected, 2));
            for c in 0..n_chains {
                for t in 0..collected {
                    sample_f32[[c, t, 0]] = sample[[c, t, 0]] as f32;
                    sample_f32[[c, t, 1]] = sample[[c, t, 1]] as f32;
                }
            }
            let (_, ess_vec) = split_rhat_mean_ess(sample_f32.view());
            let ess_x1 = ess_vec[0];
            let ess_x2 = ess_vec[1];
            ess_x1s.push(ess_x1);
            ess_x2s.push(ess_x2);
        }
        let ess_x1_array = ndarray::Array1::from_vec(ess_x1s);
        let ess_x2_array = ndarray::Array1::from_vec(ess_x2s);
        let stats_x1 = basic_stats("ESS(x1)", ess_x1_array);
        let stats_x2 = basic_stats("ESS(x2)", ess_x2_array);
        println!("{stats_x1}\n{stats_x2}");
        assert!(
            stats_x1.mean >= 65.0 && stats_x1.mean <= 125.0,
            "Expected ESS(x1) to average in [65, 125]"
        );
        // Both bounds must test x2; the upper bound previously read `stats_x1`.
        assert!(
            stats_x2.mean >= 83.0 && stats_x2.mean <= 143.0,
            "Expected ESS(x2) to average in [83, 143]"
        );
        assert!(
            stats_x1.std >= 20.0 && stats_x1.std <= 40.0,
            "Expected std(ESS(x1)) in [20, 40]"
        );
        assert!(
            stats_x2.std >= 20.0 && stats_x2.std <= 40.0,
            "Expected std(ESS(x2)) in [20, 40]"
        );
    }

    /// Mirrors the README example: 4 unseeded chains, 1000 kept draws each.
    #[test]
    fn readme_test() {
        let target = Gaussian2D {
            mean: arr1(&[0.0, 0.0]),
            cov: arr2(&[[1.0, 0.0], [0.0, 1.0]]),
        };
        let proposal = IsotropicGaussian::new(1.0);
        let mut mh = MetropolisHastings::new(target, proposal, init_det(4, 2));
        let sample = mh.run(1000, 100).unwrap();
        assert_eq!(sample.shape()[0], 4);
        assert_eq!(sample.shape()[1], 1000);
    }
}