use ndarray::LinalgScalar;
use rand::rngs::SmallRng;
use rand::{rng, Rng, SeedableRng};
use crate::core::{HasChains, MarkovChain};
use crate::distributions::Conditional;
/// A single Gibbs-sampling chain whose state is a vector of `S` values.
///
/// Each call to `step` (via the `MarkovChain` impl) resamples the coordinates of
/// `current_state` one at a time using `target`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GibbsMarkovChain<S, D: Conditional<S>> {
    /// Full-conditional sampler: `target.sample(i, state)` yields a new value for coordinate `i`.
    pub target: D,
    /// Current position of the chain, one entry per coordinate.
    pub current_state: Vec<S>,
    /// Seed that was used to initialise `rng`.
    pub seed: u64,
    /// Per-chain generator seeded from `seed`.
    ///
    /// NOTE(review): `step` does not read this generator; all randomness in a sweep comes
    /// from `target`. Confirm whether it is consumed elsewhere.
    pub rng: SmallRng,
}
impl<T, D> GibbsMarkovChain<T, D>
where
D: Conditional<T> + Clone,
T: LinalgScalar,
{
pub fn new(target: D, initial_state: &[T]) -> Self {
let seed = rand::rng().random::<u64>();
Self {
target,
current_state: initial_state.to_vec(),
seed,
rng: SmallRng::seed_from_u64(seed),
}
}
}
impl<S, D: Conditional<S>> MarkovChain<S> for GibbsMarkovChain<S, D> {
    /// Performs one systematic-scan Gibbs sweep and returns the updated state.
    ///
    /// Coordinates are visited in index order `0..len`. Each one is overwritten in place with
    /// `target.sample(i, state)`, so a coordinate conditions on the values of earlier
    /// coordinates that were already refreshed in this same sweep.
    ///
    /// NOTE(review): `self.rng` is not used here; randomness comes entirely from `target`.
    fn step(&mut self) -> &Vec<S> {
        let dim = self.current_state.len();
        for coord in 0..dim {
            let draw = self.target.sample(coord, &self.current_state);
            self.current_state[coord] = draw;
        }
        &self.current_state
    }

    /// Returns the state after the most recent sweep, or the initial state if `step` has not
    /// been called yet.
    fn current_state(&self) -> &Vec<S> {
        &self.current_state
    }
}
/// A set of [`GibbsMarkovChain`]s built from clones of one conditional `target`.
pub struct GibbsSampler<S, D>
where
    D: Conditional<S>,
{
    /// Conditional the chains were built from; each chain holds its own clone.
    pub target: D,
    /// The sampler's chains, one per initial state.
    pub chains: Vec<GibbsMarkovChain<S, D>>,
    /// Sampler-level seed; `set_seed` derives each chain's seed from it.
    pub seed: u64,
}
impl<T, D> GibbsSampler<T, D>
where
D: Conditional<T> + Clone + Send + Sync,
T: LinalgScalar,
{
pub fn new(target: D, initial_states: Vec<Vec<T>>) -> Self {
let seed = rng().random::<u64>();
Self {
target: target.clone(),
chains: initial_states
.into_iter()
.map(|x| GibbsMarkovChain::new(target.clone(), &x))
.collect(),
seed,
}
}
pub fn set_seed(mut self, seed: u64) -> Self {
self.seed = seed;
for (i, chain) in self.chains.iter_mut().enumerate() {
let chain_seed = seed + i as u64;
chain.seed = chain_seed;
chain.rng = SmallRng::seed_from_u64(chain_seed);
}
self
}
}
impl<S, D> HasChains<S> for GibbsSampler<S, D>
where
    S: Send,
    D: Conditional<S> + Clone + Send + Sync,
{
    type Chain = GibbsMarkovChain<S, D>;

    /// Gives mutable access to the sampler's chain vector, allowing chains to be edited,
    /// added, or removed in place.
    fn chains_mut(&mut self) -> &mut Vec<Self::Chain> {
        &mut self.chains
    }
}
#[cfg(test)]
mod tests {
// Tests for `GibbsMarkovChain` and `GibbsSampler`.
// - `ConstantConditional` exercises shapes and state handling deterministically.
// - `MixtureConditional` checks that the sampled moments of a two-component
//   Gaussian mixture land close to their closed-form values.
use super::*;
use crate::core::{init_det, ChainRunner};
use approx::assert_abs_diff_eq;
use ndarray::{Array3, Axis};
use rand_distr::Normal;
use std::f64::consts::PI;
/// Degenerate conditional that ignores the coordinate index and the rest of the
/// state and always returns `c`. A single sweep therefore sets every coordinate to `c`.
#[derive(Clone)]
struct ConstantConditional {
c: f64,
}
impl Conditional<f64> for ConstantConditional {
fn sample(&mut self, _i: usize, _given: &[f64]) -> f64 {
self.c
}
}
/// Full conditionals for the joint state `[x, z]` of a two-component Gaussian mixture.
///
/// Coordinate 0 is the observation `x`; coordinate 1 is the component label `z`
/// (0.0 or 1.0). `pi0` is the weight of component 0.
///
/// Randomness comes from the embedded `rng`. Cloning this struct also clones the
/// generator state.
#[derive(Clone)]
struct MixtureConditional {
mu0: f64,
sigma0: f64,
mu1: f64,
sigma1: f64,
pi0: f64, rng: SmallRng,
}
impl MixtureConditional {
/// Density of the normal distribution N(mu, sigma^2) evaluated at `x`.
fn normal_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
let var = sigma * sigma;
let coeff = 1.0 / ((2.0 * PI * var).sqrt());
let exp_val = (-((x - mu).powi(2)) / (2.0 * var)).exp();
coeff * exp_val
}
}
impl Conditional<f64> for MixtureConditional {
/// Draws a new value for coordinate `i` given the current state `given`.
///
/// # Panics
/// Panics if `i` is neither 0 nor 1. When `i == 0`, it also panics if
/// `Normal::new` rejects the selected component's sigma.
fn sample(&mut self, i: usize, given: &[f64]) -> f64 {
if i == 0 {
// x | z: draw from the component selected by the current label.
let z = given[1];
if z < 0.5 {
let normal = Normal::new(self.mu0, self.sigma0).unwrap();
self.rng.sample(normal)
} else {
let normal = Normal::new(self.mu1, self.sigma1).unwrap();
self.rng.sample(normal)
}
} else if i == 1 {
// z | x: Bernoulli draw with P(z = 1) = p1 / (p0 + p1), where each p is the
// component weight times its density at x.
let x = given[0];
let p0 = self.pi0 * MixtureConditional::normal_pdf(x, self.mu0, self.sigma0);
let p1 =
(1.0 - self.pi0) * MixtureConditional::normal_pdf(x, self.mu1, self.sigma1);
let total = p0 + p1;
// If both weighted densities underflow to zero (total not positive), pick
// a component uniformly.
let prob_z1 = if total > 0.0 { p1 / total } else { 0.5 };
if self.rng.random::<f64>() < prob_z1 {
1.0
} else {
0.0
}
} else {
panic!("Invalid coordinate index in MixtureConditional");
}
}
}
// One sweep with a constant conditional overwrites every coordinate with the constant.
#[test]
fn test_gibbs_chain_step() {
let conditional = ConstantConditional { c: 7.0 };
let initial_state = [0.0, 0.0, 0.0];
let mut chain = GibbsMarkovChain::new(conditional, &initial_state);
chain.step();
for &x in chain.current_state().iter() {
assert!((x - 7.0).abs() < f64::EPSILON, "Expected 7.0, got {}", x);
}
}
// A seeded sampler returns an array shaped (chains, draws, dims) filled with the constant.
// Presumably `init_det(4, 2)` yields 4 initial states of dimension 2, and `run(10, 5)`
// keeps 10 draws per chain after discarding 5 (argument order matches `n_collect,
// n_discard` in `assert_mixture_simulation`) -- confirm against `crate::core`.
#[test]
fn test_gibbs_sampler_run() {
let constant = 42.0;
let conditional = ConstantConditional { c: constant };
let mut sampler = GibbsSampler::new(conditional.clone(), init_det(4, 2)).set_seed(42);
let sample = sampler.run(10, 5).unwrap();
let shape = sample.shape();
assert_eq!(shape[0], 4);
assert_eq!(shape[1], 10);
assert_eq!(shape[2], 2);
assert_abs_diff_eq!(sample, Array3::<f64>::from_elem((4, 10, 2), 42.0));
}
// Same as above, but through `run_progress` without an explicit seed. The returned stats
// are only printed, not checked.
#[test]
fn test_gibbs_sampler_run_progress() {
let constant = 42.0;
let conditional = ConstantConditional { c: constant };
let mut sampler = GibbsSampler::new(conditional, init_det(4, 2));
let (sample, stats) = sampler.run_progress(10, 5).unwrap();
let shape = sample.shape();
println!("{stats}");
assert_eq!(shape[0], 4);
assert_eq!(shape[1], 10);
assert_eq!(shape[2], 2);
assert_abs_diff_eq!(sample, Array3::<f64>::from_elem((4, 10, 2), 42.0));
}
/// Runs a seeded Gibbs sampler on `MixtureConditional` and checks the empirical mean
/// and variance of `x` (coordinate 0). Each must fall within 10% relative error of the
/// mixture's theoretical moment.
///
/// NOTE(review): the mean tolerance is relative to `theo_mean`, so a mixture whose
/// theoretical mean is 0 can never pass.
/// NOTE(review): every chain receives a clone of `conditional`, including its `rng`
/// state. If `init_det` gives all chains the same start, they may produce identical
/// draws, so the effective sample size would be `n_collect` rather than
/// `n_chains * n_collect` -- confirm.
#[allow(clippy::too_many_arguments)]
fn assert_mixture_simulation(
mu0: f64,
sigma0: f64,
mu1: f64,
sigma1: f64,
pi0: f64,
n_chains: usize,
n_collect: usize,
n_discard: usize,
seed: u64,
) {
// Mixture mean, and variance via the law of total variance.
let theo_mean = pi0 * mu0 + (1.0 - pi0) * mu1;
let theo_var = pi0 * (sigma0.powi(2) + (mu0 - theo_mean).powi(2))
+ (1.0 - pi0) * (sigma1.powi(2) + (mu1 - theo_mean).powi(2));
let conditional = MixtureConditional {
mu0,
sigma0,
mu1,
sigma1,
pi0,
rng: SmallRng::seed_from_u64(seed),
};
let mut sampler = GibbsSampler::new(conditional, init_det(n_chains, 2)).set_seed(seed);
let sample = sampler.run(n_collect, n_discard).unwrap();
// Keep only the x coordinate and pool the draws from all chains.
let x = sample.index_axis(Axis(2), 0);
let x = x.flatten();
let sample_mean = x.mean().unwrap();
// ddof = 1: unbiased sample variance.
let sample_var = x.var(1.0);
assert!(
(sample_mean - theo_mean).abs() < theo_mean.abs() / 10.0,
"Empirical mean {} deviates too much from theoretical {}",
sample_mean,
theo_mean
);
assert!(
(sample_var - theo_var).abs() < theo_var.abs() / 10.0,
"Empirical variance {} deviates too much from theoretical {}",
sample_var,
theo_var
);
}
// Equal-weight, well-separated components (theoretical mean 0.5).
#[test]
fn test_gibbs_sampler_mixture_1() {
assert_mixture_simulation(
-2.0, 1.0, 3.0, 1.5, 0.5, 4, 100_000, 10_000, 42, );
}
// A low-weight, very wide component 0 next to a narrow component 1.
#[test]
fn test_gibbs_sampler_mixture_2() {
assert_mixture_simulation(
-42.0, 69.0, 1.0, 2.0, 0.123, 4, 100_000, 10_000, 42, );
}
// `step()` returns a reference to the chain's own state, not a copy.
#[test]
fn test_chain_step_return_value() {
let conditional = ConstantConditional { c: 42.0 };
let initial_state = [0.0, 0.0, 0.0];
let mut chain = GibbsMarkovChain::new(conditional, &initial_state);
let returned_ref = chain.step();
for &val in returned_ref.iter() {
assert!(
(val - 42.0).abs() < f64::EPSILON,
"Expected 42.0 after step, got {}",
val
);
}
assert!(
std::ptr::eq(returned_ref, chain.current_state()),
"step() should return a reference to the chain's internal state"
);
}
// Before any step, `current_state()` equals the initial state element-wise.
#[test]
fn test_chain_current_state_return_value() {
let conditional = ConstantConditional { c: 13.0 };
let initial_state = [1.0, 2.0, 3.0];
let chain = GibbsMarkovChain::new(conditional, &initial_state);
let state_ref = chain.current_state();
assert_eq!(
state_ref.len(),
initial_state.len(),
"Expected the current_state() to have length {}",
initial_state.len()
);
for (i, &val) in state_ref.iter().enumerate() {
assert!(
(val - initial_state[i]).abs() < f64::EPSILON,
"Expected coordinate {} to be {}, got {}",
i,
initial_state[i],
val
);
}
}
// `chains_mut()` exposes the sampler's real chain vector: edits and pushes through it
// are visible on `sampler.chains`.
#[test]
fn test_has_chains_for_gibbs_sampler() {
let conditional = ConstantConditional { c: 42.0 };
let mut sampler = GibbsSampler::new(conditional.clone(), init_det(1, 2)).set_seed(42);
let original_len = sampler.chains.len();
{
let chains_mut = sampler.chains_mut();
if let Some(first_chain) = chains_mut.first_mut() {
first_chain.current_state[0] = 100.0;
}
}
assert_eq!(
sampler.chains[0].current_state[0], 100.0,
"Expected the first coordinate of the first chain to be updated to 100.0"
);
{
let chains_mut = sampler.chains_mut();
chains_mut.push(GibbsMarkovChain::new(conditional, &[1.0, 1.0]));
}
assert_eq!(
sampler.chains.len(),
original_len + 1,
"Expected chains length to increase by 1"
);
}
}