pub trait LogProb: Send + Sync {
fn log_prob(&self, theta: &[f64]) -> f64;
}
pub struct LogProbFn<F: Fn(&[f64]) -> f64 + Send + Sync> {
f: F,
}
impl<F: Fn(&[f64]) -> f64 + Send + Sync> LogProbFn<F> {
pub fn new(f: F) -> Self {
Self { f }
}
}
impl<F: Fn(&[f64]) -> f64 + Send + Sync> LogProb for LogProbFn<F> {
fn log_prob(&self, theta: &[f64]) -> f64 {
(self.f)(theta)
}
}
pub trait Proposal: Send + Sync {
fn propose(&self, current: &[f64], rng: &mut McmcRng) -> Vec<f64>;
fn log_ratio(&self, proposed: &[f64], current: &[f64]) -> f64;
}
#[derive(Debug, Clone)]
pub struct GaussianProposal {
pub step_size: f64,
}
impl GaussianProposal {
pub fn new(step_size: f64) -> Self {
Self { step_size }
}
}
impl Proposal for GaussianProposal {
fn propose(&self, current: &[f64], rng: &mut McmcRng) -> Vec<f64> {
current
.iter()
.map(|&x| x + rng.next_normal_scaled(0.0, self.step_size))
.collect()
}
fn log_ratio(&self, _proposed: &[f64], _current: &[f64]) -> f64 {
0.0
}
}
#[derive(Debug, Clone)]
pub struct IndependentGaussianProposal {
pub mean: Vec<f64>,
pub std: Vec<f64>,
}
impl IndependentGaussianProposal {
pub fn new(mean: Vec<f64>, std: Vec<f64>) -> Self {
debug_assert_eq!(
mean.len(),
std.len(),
"mean and std must have the same length"
);
Self { mean, std }
}
}
#[inline]
fn log_normal_density(x: f64, mu: f64, sigma: f64) -> f64 {
let diff = x - mu;
-0.5 * (diff / sigma).powi(2) - sigma.ln()
}
impl Proposal for IndependentGaussianProposal {
fn propose(&self, _current: &[f64], rng: &mut McmcRng) -> Vec<f64> {
self.mean
.iter()
.zip(self.std.iter())
.map(|(&mu, &sigma)| rng.next_normal_scaled(mu, sigma))
.collect()
}
fn log_ratio(&self, proposed: &[f64], current: &[f64]) -> f64 {
let log_q_current: f64 = current
.iter()
.zip(self.mean.iter())
.zip(self.std.iter())
.map(|((&x, &mu), &sigma)| log_normal_density(x, mu, sigma))
.sum();
let log_q_proposed: f64 = proposed
.iter()
.zip(self.mean.iter())
.zip(self.std.iter())
.map(|((&x, &mu), &sigma)| log_normal_density(x, mu, sigma))
.sum();
log_q_current - log_q_proposed
}
}
#[derive(Debug, Clone)]
pub struct McmcRng {
state: u64,
}
impl McmcRng {
pub fn new(seed: u64) -> Self {
let state = seed.wrapping_add(6364136223846793005);
Self { state }
}
pub fn next_u64(&mut self) -> u64 {
self.state = self
.state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
self.state
}
pub fn next_f64(&mut self) -> f64 {
(self.next_u64() >> 11) as f64 * (1.0_f64 / (1u64 << 53) as f64)
}
pub fn next_normal(&mut self) -> f64 {
let u1 = self.next_f64().max(f64::MIN_POSITIVE); let u2 = self.next_f64();
let r = (-2.0 * u1.ln()).sqrt();
let theta = std::f64::consts::TAU * u2;
r * theta.cos()
}
pub fn next_normal_scaled(&mut self, mean: f64, std: f64) -> f64 {
mean + std * self.next_normal()
}
}
#[derive(Debug, Clone)]
pub struct McmcConfig {
pub n_samples: usize,
pub n_warmup: usize,
pub thin: usize,
pub seed: u64,
pub target_acceptance: f64,
}
impl Default for McmcConfig {
fn default() -> Self {
Self {
n_samples: 1000,
n_warmup: 500,
thin: 1,
seed: 42,
target_acceptance: 0.234,
}
}
}
impl McmcConfig {
pub fn new() -> Self {
Self::default()
}
pub fn n_samples(mut self, n: usize) -> Self {
self.n_samples = n;
self
}
pub fn n_warmup(mut self, n: usize) -> Self {
self.n_warmup = n;
self
}
pub fn thin(mut self, t: usize) -> Self {
self.thin = t;
self
}
pub fn seed(mut self, s: u64) -> Self {
self.seed = s;
self
}
}
#[derive(Debug, Clone)]
pub struct ChainDiagnostics {
pub n_samples: usize,
pub acceptance_rate: f64,
pub mean: Vec<f64>,
pub variance: Vec<f64>,
pub effective_sample_size: Vec<f64>,
pub r_hat: Option<Vec<f64>>,
}
#[derive(Debug, Clone)]
pub struct McmcResult {
pub samples: Vec<Vec<f64>>,
pub log_probs: Vec<f64>,
pub diagnostics: ChainDiagnostics,
}
impl McmcResult {
pub fn n_samples(&self) -> usize {
self.samples.len()
}
pub fn n_dims(&self) -> usize {
self.samples.first().map(|s| s.len()).unwrap_or(0)
}
pub fn marginal_samples(&self, dim: usize) -> Vec<f64> {
self.samples.iter().map(|s| s[dim]).collect()
}
pub fn posterior_mean(&self) -> Vec<f64> {
let n = self.n_samples();
if n == 0 {
return vec![];
}
let d = self.n_dims();
let mut mean = vec![0.0_f64; d];
for sample in &self.samples {
for (m, &v) in mean.iter_mut().zip(sample.iter()) {
*m += v;
}
}
mean.iter_mut().for_each(|m| *m /= n as f64);
mean
}
pub fn posterior_variance(&self) -> Vec<f64> {
let n = self.n_samples();
if n < 2 {
return vec![0.0; self.n_dims()];
}
let mean = self.posterior_mean();
let d = self.n_dims();
let mut var = vec![0.0_f64; d];
for sample in &self.samples {
for (v, (&x, &mu)) in var.iter_mut().zip(sample.iter().zip(mean.iter())) {
*v += (x - mu).powi(2);
}
}
var.iter_mut().for_each(|v| *v /= (n - 1) as f64);
var
}
pub fn credible_interval(&self, dim: usize, alpha: f64) -> (f64, f64) {
let mut marginal = self.marginal_samples(dim);
marginal.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let n = marginal.len();
if n == 0 {
return (f64::NAN, f64::NAN);
}
let lo_idx = ((alpha / 2.0) * n as f64) as usize;
let hi_idx = ((1.0 - alpha / 2.0) * n as f64) as usize;
let lo = marginal[lo_idx.min(n - 1)];
let hi = marginal[hi_idx.min(n - 1)];
(lo, hi)
}
}
#[derive(Debug)]
pub enum McmcError {
InvalidConfig(String),
DimensionMismatch,
NumericalError(String),
}
impl std::fmt::Display for McmcError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
McmcError::InvalidConfig(msg) => write!(f, "MCMC invalid configuration: {}", msg),
McmcError::DimensionMismatch => {
write!(f, "MCMC dimension mismatch between initial state and model")
}
McmcError::NumericalError(msg) => write!(f, "MCMC numerical error: {}", msg),
}
}
}
impl std::error::Error for McmcError {}
fn validate_config(config: &McmcConfig) -> Result<(), McmcError> {
if config.n_samples == 0 {
return Err(McmcError::InvalidConfig(
"n_samples must be > 0".to_string(),
));
}
if config.thin == 0 {
return Err(McmcError::InvalidConfig("thin must be > 0".to_string()));
}
Ok(())
}
fn slice_stats(data: &[f64]) -> (f64, f64) {
let n = data.len();
if n == 0 {
return (0.0, 0.0);
}
let mean = data.iter().sum::<f64>() / n as f64;
let var = if n < 2 {
0.0
} else {
data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (n - 1) as f64
};
(mean, var)
}
pub struct MetropolisHastings<P: LogProb, Q: Proposal> {
log_prob: P,
proposal: Q,
config: McmcConfig,
}
impl<P: LogProb, Q: Proposal> MetropolisHastings<P, Q> {
pub fn new(log_prob: P, proposal: Q, config: McmcConfig) -> Self {
Self {
log_prob,
proposal,
config,
}
}
pub fn sample(&self, initial: &[f64]) -> Result<McmcResult, McmcError> {
validate_config(&self.config)?;
if initial.is_empty() {
return Err(McmcError::InvalidConfig(
"initial state must be non-empty".to_string(),
));
}
let mut rng = McmcRng::new(self.config.seed);
let total_steps = self.config.n_warmup + self.config.n_samples * self.config.thin;
let mut current: Vec<f64> = initial.to_vec();
let mut current_lp = self.log_prob.log_prob(¤t);
if !current_lp.is_finite() {
return Err(McmcError::NumericalError(
"initial state has non-finite log probability".to_string(),
));
}
let mut samples: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
let mut log_probs: Vec<f64> = Vec::with_capacity(self.config.n_samples);
let mut n_accepted: usize = 0;
let mut step_in_sample: usize = 0;
for step in 0..total_steps {
let proposed = self.proposal.propose(¤t, &mut rng);
let proposed_lp = self.log_prob.log_prob(&proposed);
let log_accept = if proposed_lp.is_finite() {
let log_alpha =
proposed_lp - current_lp + self.proposal.log_ratio(&proposed, ¤t);
log_alpha.min(0.0)
} else {
f64::NEG_INFINITY
};
let u = rng.next_f64();
let accepted = u.ln() < log_accept;
if accepted {
current = proposed;
current_lp = proposed_lp;
if step >= self.config.n_warmup {
n_accepted += 1;
}
}
if step >= self.config.n_warmup {
step_in_sample += 1;
if step_in_sample == self.config.thin {
samples.push(current.clone());
log_probs.push(current_lp);
step_in_sample = 0;
}
}
}
let n_post_warmup_steps = self.config.n_samples * self.config.thin;
let acceptance_rate = if n_post_warmup_steps > 0 {
n_accepted as f64 / n_post_warmup_steps as f64
} else {
0.0
};
let diagnostics = compute_diagnostics_with_acceptance(&samples, acceptance_rate);
Ok(McmcResult {
samples,
log_probs,
diagnostics,
})
}
}
pub struct HamiltonianMonteCarlo<P: LogProb> {
log_prob: P,
step_size: f64,
n_leapfrog_steps: usize,
config: McmcConfig,
}
impl<P: LogProb> HamiltonianMonteCarlo<P> {
pub fn new(log_prob: P, step_size: f64, n_leapfrog_steps: usize, config: McmcConfig) -> Self {
Self {
log_prob,
step_size,
n_leapfrog_steps,
config,
}
}
pub fn sample(&self, initial: &[f64]) -> Result<McmcResult, McmcError> {
validate_config(&self.config)?;
if initial.is_empty() {
return Err(McmcError::InvalidConfig(
"initial state must be non-empty".to_string(),
));
}
if self.step_size <= 0.0 {
return Err(McmcError::InvalidConfig(
"step_size must be positive".to_string(),
));
}
if self.n_leapfrog_steps == 0 {
return Err(McmcError::InvalidConfig(
"n_leapfrog_steps must be > 0".to_string(),
));
}
let mut rng = McmcRng::new(self.config.seed);
let total_steps = self.config.n_warmup + self.config.n_samples * self.config.thin;
let d = initial.len();
let mut current: Vec<f64> = initial.to_vec();
let mut current_lp = self.log_prob.log_prob(¤t);
if !current_lp.is_finite() {
return Err(McmcError::NumericalError(
"initial state has non-finite log probability".to_string(),
));
}
let mut samples: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
let mut log_probs: Vec<f64> = Vec::with_capacity(self.config.n_samples);
let mut n_accepted: usize = 0;
let mut step_in_sample: usize = 0;
for step in 0..total_steps {
let momentum: Vec<f64> = (0..d).map(|_| rng.next_normal()).collect();
let ke_old: f64 = momentum.iter().map(|&r| 0.5 * r * r).sum();
let (proposed, new_momentum) = self.leapfrog(¤t, &momentum);
let proposed_lp = self.log_prob.log_prob(&proposed);
let ke_new: f64 = new_momentum.iter().map(|&r| 0.5 * r * r).sum();
let h_old = -current_lp + ke_old;
let h_new = -proposed_lp + ke_new;
let log_accept = if proposed_lp.is_finite() {
(h_old - h_new).min(0.0)
} else {
f64::NEG_INFINITY
};
let u = rng.next_f64();
let accepted = u.ln() < log_accept;
if accepted {
current = proposed;
current_lp = proposed_lp;
if step >= self.config.n_warmup {
n_accepted += 1;
}
}
if step >= self.config.n_warmup {
step_in_sample += 1;
if step_in_sample == self.config.thin {
samples.push(current.clone());
log_probs.push(current_lp);
step_in_sample = 0;
}
}
}
let n_post_warmup_steps = self.config.n_samples * self.config.thin;
let acceptance_rate = if n_post_warmup_steps > 0 {
n_accepted as f64 / n_post_warmup_steps as f64
} else {
0.0
};
let diagnostics = compute_diagnostics_with_acceptance(&samples, acceptance_rate);
Ok(McmcResult {
samples,
log_probs,
diagnostics,
})
}
fn grad_log_prob(&self, theta: &[f64], eps: f64) -> Vec<f64> {
let d = theta.len();
let mut grad = vec![0.0_f64; d];
let mut theta_plus = theta.to_vec();
let mut theta_minus = theta.to_vec();
for i in 0..d {
theta_plus[i] = theta[i] + eps;
theta_minus[i] = theta[i] - eps;
grad[i] = (self.log_prob.log_prob(&theta_plus) - self.log_prob.log_prob(&theta_minus))
/ (2.0 * eps);
theta_plus[i] = theta[i];
theta_minus[i] = theta[i];
}
grad
}
fn leapfrog(&self, theta: &[f64], momentum: &[f64]) -> (Vec<f64>, Vec<f64>) {
let eps = self.step_size;
let fd_eps = 1e-5_f64;
let mut q = theta.to_vec();
let mut p = momentum.to_vec();
let d = q.len();
let grad = self.grad_log_prob(&q, fd_eps);
for i in 0..d {
p[i] += 0.5 * eps * grad[i];
}
for step in 0..self.n_leapfrog_steps {
for i in 0..d {
q[i] += eps * p[i];
}
if step < self.n_leapfrog_steps - 1 {
let grad_q = self.grad_log_prob(&q, fd_eps);
for i in 0..d {
p[i] += eps * grad_q[i];
}
}
}
let grad_final = self.grad_log_prob(&q, fd_eps);
for i in 0..d {
p[i] += 0.5 * eps * grad_final[i];
}
for pi in p.iter_mut() {
*pi = -*pi;
}
(q, p)
}
}
pub fn effective_sample_size(samples: &[f64]) -> f64 {
let n = samples.len();
if n < 4 {
return n as f64;
}
let b = (n as f64).sqrt() as usize; let n_batches = n / b;
if n_batches < 2 {
return n as f64;
}
let overall_mean = samples.iter().sum::<f64>() / n as f64;
let chain_var = samples
.iter()
.map(|&x| (x - overall_mean).powi(2))
.sum::<f64>()
/ (n - 1) as f64;
if chain_var == 0.0 {
return 1.0;
}
let batch_mean_var: f64 = (0..n_batches)
.map(|k| {
let batch = &samples[k * b..(k + 1) * b];
let bm = batch.iter().sum::<f64>() / b as f64;
(bm - overall_mean).powi(2)
})
.sum::<f64>()
/ (n_batches - 1) as f64;
let ess = n as f64 * chain_var / (b as f64 * batch_mean_var);
ess.clamp(1.0, n as f64)
}
pub fn gelman_rubin(chains: &[Vec<f64>]) -> f64 {
let m = chains.len();
if m < 2 {
return f64::NAN;
}
let n = chains.iter().map(|c| c.len()).min().unwrap_or(0);
if n < 2 {
return f64::NAN;
}
let chain_means: Vec<f64> = chains
.iter()
.map(|c| c[..n].iter().sum::<f64>() / n as f64)
.collect();
let overall_mean = chain_means.iter().sum::<f64>() / m as f64;
let b = n as f64
* chain_means
.iter()
.map(|&mu| (mu - overall_mean).powi(2))
.sum::<f64>()
/ (m - 1) as f64;
let w = chains
.iter()
.zip(chain_means.iter())
.map(|(c, &mu)| c[..n].iter().map(|&x| (x - mu).powi(2)).sum::<f64>() / (n - 1) as f64)
.sum::<f64>()
/ m as f64;
if w == 0.0 {
return f64::NAN;
}
let var_hat = (n - 1) as f64 / n as f64 * w + b / n as f64;
(var_hat / w).sqrt()
}
pub fn autocorrelation(samples: &[f64], lag: usize) -> f64 {
let n = samples.len();
if n == 0 || lag >= n {
return 0.0;
}
if lag == 0 {
return 1.0;
}
let mean = samples.iter().sum::<f64>() / n as f64;
let variance = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
if variance == 0.0 {
return 1.0;
}
let n_pairs = n - lag;
let cov: f64 = samples[..n_pairs]
.iter()
.zip(samples[lag..].iter())
.map(|(&a, &b)| (a - mean) * (b - mean))
.sum::<f64>()
/ n_pairs as f64;
cov / variance
}
pub fn compute_diagnostics(samples: &[Vec<f64>]) -> ChainDiagnostics {
compute_diagnostics_with_acceptance(samples, 0.0)
}
pub(crate) fn compute_diagnostics_with_acceptance(
samples: &[Vec<f64>],
acceptance_rate: f64,
) -> ChainDiagnostics {
let n = samples.len();
if n == 0 {
return ChainDiagnostics {
n_samples: 0,
acceptance_rate,
mean: vec![],
variance: vec![],
effective_sample_size: vec![],
r_hat: None,
};
}
let d = samples[0].len();
let mut mean = vec![0.0_f64; d];
let mut variance = vec![0.0_f64; d];
let mut ess = vec![0.0_f64; d];
for dim in 0..d {
let col: Vec<f64> = samples.iter().map(|s| s[dim]).collect();
let (m, v) = slice_stats(&col);
mean[dim] = m;
variance[dim] = v;
ess[dim] = effective_sample_size(&col);
}
ChainDiagnostics {
n_samples: n,
acceptance_rate,
mean,
variance,
effective_sample_size: ess,
r_hat: None,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_rng_uniform_in_range() {
let mut rng = McmcRng::new(1234);
for _ in 0..10_000 {
let v = rng.next_f64();
assert!(v >= 0.0, "uniform sample below 0: {}", v);
assert!(v < 1.0, "uniform sample >= 1: {}", v);
}
}
#[test]
fn test_rng_normal_mean() {
let mut rng = McmcRng::new(42);
let samples: Vec<f64> = (0..1000).map(|_| rng.next_normal()).collect();
let mean = samples.iter().sum::<f64>() / samples.len() as f64;
assert!(
mean.abs() < 0.15,
"Box-Muller mean too far from 0: {}",
mean
);
}
#[test]
fn test_rng_normal_std() {
let mut rng = McmcRng::new(99);
let samples: Vec<f64> = (0..1000).map(|_| rng.next_normal()).collect();
let mean = samples.iter().sum::<f64>() / samples.len() as f64;
let var = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / samples.len() as f64;
let std = var.sqrt();
assert!(
(std - 1.0).abs() < 0.15,
"Box-Muller std too far from 1: {}",
std
);
}
#[test]
fn test_gaussian_proposal_log_ratio_is_zero() {
let proposal = GaussianProposal::new(0.1);
let current = vec![1.0, 2.0, 3.0];
let proposed = vec![1.1, 2.2, 3.3];
assert_eq!(
proposal.log_ratio(&proposed, ¤t),
0.0,
"Gaussian RW should be symmetric"
);
}
#[test]
fn test_gaussian_proposal_changes_state() {
let proposal = GaussianProposal::new(1.0);
let mut rng = McmcRng::new(7);
let current = vec![0.0, 0.0, 0.0];
let proposed = proposal.propose(¤t, &mut rng);
assert_ne!(proposed, current, "proposal should change the state");
}
fn standard_normal_lp() -> LogProbFn<impl Fn(&[f64]) -> f64 + Send + Sync> {
LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2))
}
#[test]
fn test_mh_standard_normal_mean() {
let lp = standard_normal_lp();
let proposal = GaussianProposal::new(1.0);
let config = McmcConfig::new().n_samples(2000).n_warmup(500).seed(123);
let sampler = MetropolisHastings::new(lp, proposal, config);
let result = sampler.sample(&[0.0]).expect("sampling failed");
let mean = result.posterior_mean()[0];
assert!(
mean.abs() < 0.3,
"MH posterior mean too far from 0: {}",
mean
);
}
#[test]
fn test_mh_standard_normal_variance() {
let lp = standard_normal_lp();
let proposal = GaussianProposal::new(1.0);
let config = McmcConfig::new().n_samples(2000).n_warmup(500).seed(77);
let sampler = MetropolisHastings::new(lp, proposal, config);
let result = sampler.sample(&[0.0]).expect("sampling failed");
let var = result.posterior_variance()[0];
assert!(
(var - 1.0).abs() < 0.5,
"MH posterior variance too far from 1: {}",
var
);
}
#[test]
fn test_mh_acceptance_rate_in_range() {
let lp = standard_normal_lp();
let proposal = GaussianProposal::new(1.0);
let config = McmcConfig::new().n_samples(1000).n_warmup(200).seed(55);
let sampler = MetropolisHastings::new(lp, proposal, config);
let result = sampler.sample(&[0.0]).expect("sampling failed");
let ar = result.diagnostics.acceptance_rate;
assert!(ar > 0.0, "acceptance rate should be > 0");
assert!(ar <= 1.0, "acceptance rate should be <= 1");
}
#[test]
fn test_mh_sample_count_matches_config() {
let lp = standard_normal_lp();
let proposal = GaussianProposal::new(1.0);
let n = 300;
let config = McmcConfig::new().n_samples(n).n_warmup(100).seed(11);
let sampler = MetropolisHastings::new(lp, proposal, config);
let result = sampler.sample(&[0.0]).expect("sampling failed");
assert_eq!(result.n_samples(), n, "sample count should match config");
}
#[test]
fn test_mh_warmup_discarded() {
let lp = standard_normal_lp();
let proposal = GaussianProposal::new(1.0);
let n_samples = 200;
let n_warmup = 100;
let config = McmcConfig::new()
.n_samples(n_samples)
.n_warmup(n_warmup)
.seed(42);
let sampler = MetropolisHastings::new(lp, proposal, config);
let result = sampler.sample(&[0.0]).expect("sampling failed");
assert_eq!(
result.n_samples(),
n_samples,
"warmup samples should not be included in result"
);
}
#[test]
fn test_marginal_samples_correct() {
let samples = vec![vec![1.0, 10.0], vec![2.0, 20.0], vec![3.0, 30.0]];
let result = McmcResult {
log_probs: vec![-1.0, -2.0, -3.0],
diagnostics: compute_diagnostics(&samples),
samples,
};
let m0 = result.marginal_samples(0);
assert_eq!(m0, vec![1.0, 2.0, 3.0]);
let m1 = result.marginal_samples(1);
assert_eq!(m1, vec![10.0, 20.0, 30.0]);
}
#[test]
fn test_credible_interval_contains_true_value() {
let lp = standard_normal_lp();
let proposal = GaussianProposal::new(1.0);
let config = McmcConfig::new().n_samples(2000).n_warmup(500).seed(88);
let sampler = MetropolisHastings::new(lp, proposal, config);
let result = sampler.sample(&[0.0]).expect("sampling failed");
let (lo, hi) = result.credible_interval(0, 0.05); assert!(
lo < 0.0 && 0.0 < hi,
"95% CI should contain the true mean 0.0; got ({}, {})",
lo,
hi
);
}
#[test]
fn test_hmc_standard_normal_mean() {
let lp = LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2));
let config = McmcConfig::new().n_samples(1000).n_warmup(500).seed(321);
let sampler = HamiltonianMonteCarlo::new(lp, 0.3, 10, config);
let result = sampler.sample(&[0.0]).expect("HMC failed");
let mean = result.posterior_mean()[0];
assert!(
mean.abs() < 0.4,
"HMC posterior mean too far from 0: {}",
mean
);
}
#[test]
fn test_hmc_acceptance_rate_high() {
let lp = LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2));
let config = McmcConfig::new().n_samples(500).n_warmup(200).seed(999);
let sampler = HamiltonianMonteCarlo::new(lp, 0.1, 5, config);
let result = sampler.sample(&[0.0]).expect("HMC failed");
let ar = result.diagnostics.acceptance_rate;
assert!(
ar > 0.5,
"HMC acceptance rate should be > 0.5 with small step size: {}",
ar
);
}
#[test]
fn test_hmc_gradient_finite_difference_accuracy() {
let hmc = HamiltonianMonteCarlo::new(
LogProbFn::new(|theta: &[f64]| -0.5 * theta[0].powi(2)),
0.1,
5,
McmcConfig::new(),
);
let grad = hmc.grad_log_prob(&[1.0], 1e-5);
assert!(
(grad[0] - (-1.0)).abs() < 1e-6,
"gradient inaccurate: expected -1, got {}",
grad[0]
);
}
#[test]
fn test_ess_positive_for_iid() {
let mut rng = McmcRng::new(1);
let samples: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
let ess = effective_sample_size(&samples);
assert!(ess > 0.0, "ESS should be positive");
}
#[test]
fn test_ess_at_most_n_samples() {
let mut rng = McmcRng::new(2);
let samples: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
let ess = effective_sample_size(&samples);
assert!(
ess <= samples.len() as f64,
"ESS should not exceed number of samples"
);
}
#[test]
fn test_autocorrelation_lag_zero() {
let samples: Vec<f64> = (0..100).map(|i| i as f64).collect();
let ac = autocorrelation(&samples, 0);
assert!(
(ac - 1.0).abs() < 1e-10,
"autocorrelation at lag 0 should be 1.0, got {}",
ac
);
}
#[test]
fn test_autocorrelation_large_lag_near_zero() {
let mut rng = McmcRng::new(3);
let samples: Vec<f64> = (0..500).map(|_| rng.next_normal()).collect();
let ac = autocorrelation(&samples, 100);
assert!(
ac.abs() < 0.2,
"autocorrelation at large lag should be near 0 for iid: {}",
ac
);
}
#[test]
fn test_gelman_rubin_converged_chains() {
let mut rng = McmcRng::new(5);
let chain1: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
let chain2: Vec<f64> = (0..200).map(|_| rng.next_normal()).collect();
let r_hat = gelman_rubin(&[chain1, chain2]);
assert!(
!r_hat.is_nan(),
"R-hat should not be NaN for well-behaved chains"
);
assert!(
r_hat < 1.2,
"R-hat should be near 1.0 for converged chains, got {}",
r_hat
);
}
#[test]
fn test_gelman_rubin_non_converged_chains() {
let chain1: Vec<f64> = (0..200).map(|i| i as f64 * 0.01).collect(); let chain2: Vec<f64> = (0..200).map(|i| 100.0 + i as f64 * 0.01).collect(); let r_hat = gelman_rubin(&[chain1, chain2]);
assert!(
r_hat > 1.1,
"R-hat should be > 1.1 for non-converged chains, got {}",
r_hat
);
}
#[test]
fn test_mcmc_config_builder_pattern() {
let cfg = McmcConfig::new()
.n_samples(500)
.n_warmup(250)
.thin(2)
.seed(17);
assert_eq!(cfg.n_samples, 500);
assert_eq!(cfg.n_warmup, 250);
assert_eq!(cfg.thin, 2);
assert_eq!(cfg.seed, 17);
}
#[test]
fn test_mcmc_error_display() {
let e = McmcError::InvalidConfig("test error".to_string());
let s = e.to_string();
assert!(
s.contains("test error"),
"error Display should contain the message"
);
let e2 = McmcError::DimensionMismatch;
assert!(
e2.to_string().len() > 0,
"DimensionMismatch display should not be empty"
);
let e3 = McmcError::NumericalError("NaN".to_string());
assert!(
e3.to_string().contains("NaN"),
"NumericalError display should contain the message"
);
}
}