use crate::diag_mass::{DiagMass, RunningVariance};
use crate::euclidean::BatchVector;
use ndarray::Array3;
use num_traits::{Float, FromPrimitive, One, ToPrimitive, Zero};
use rand::SeedableRng;
use rand::distr::Distribution as RandDistribution;
use rand::rngs::SmallRng;
use rand_distr::StandardNormal;
/// A differentiable log-density target evaluated for a whole batch of chains at once.
///
/// [`BatchedGenericHMC`] calls this once per transition at the current positions and
/// again after every leapfrog position update.
pub trait BatchedHamiltonianTarget<V>
where
    V: BatchVector,
{
    /// Evaluates the target at `position`, writing the gradient into `grad`.
    ///
    /// The returned energy is used as the log density: the sampler subtracts the
    /// current value from the proposed value and adds the result into the log
    /// acceptance ratio. NOTE(review): presumably one value per chain — confirm
    /// against `BatchVector::Energy`.
    ///
    /// The sampler does not zero `grad` before every call (only before the first
    /// evaluation of a transition), so implementations should overwrite `grad`
    /// rather than accumulate into it.
    fn logp_and_grad(&self, position: &V, grad: &mut V) -> V::Energy;
}
/// Hamiltonian Monte Carlo sampler that advances a batch of chains in lock-step.
///
/// All chains share one step size, one leapfrog count and one diagonal mass
/// matrix. During warmup the step size is tuned by dual averaging and the
/// diagonal mass is re-estimated from a running variance of sampled positions.
#[derive(Debug)]
pub struct BatchedGenericHMC<V, Target>
where
V: BatchVector,
Target: BatchedHamiltonianTarget<V>,
{
/// Log-density target evaluated via `logp_and_grad`.
target: Target,
/// Leapfrog step size used by `step`; updated every warmup iteration.
step_size: V::Scalar,
/// Dual-averaged step size; `step_size` is set to this at the end of warmup.
step_size_bar: V::Scalar,
/// Number of leapfrog steps per transition.
n_leapfrog: usize,
/// Mean acceptance probability that step-size adaptation aims for.
target_accept_p: V::Scalar,
/// Dual-averaging shrinkage parameter.
gamma: V::Scalar,
/// Dual-averaging iteration offset that damps early updates of `h_bar`.
t_0: usize,
/// Exponent controlling how fast the averaged step size forgets old iterates.
kappa: V::Scalar,
/// Log step size that dual averaging shrinks towards (`ln(10 * step_size)`).
mu: V::Scalar,
/// Running average of `target_accept_p - accept_p`.
h_bar: V::Scalar,
/// Diagonal mass matrix; starts as the identity.
mass: DiagMass<V::Scalar>,
/// Current positions of all chains.
position: V,
/// Scratch buffer: momentum drawn at the start of each transition.
momentum: V,
/// Scratch buffer: gradient written by the target.
grad: V,
/// Scratch buffer: leapfrog-integrated proposal positions.
proposal_pos: V,
/// Scratch buffer: leapfrog-integrated proposal momenta.
proposal_mom: V,
/// Random source for momenta and accept/reject draws.
rng: SmallRng,
/// Number of chains, taken from the initial position.
n_chains: usize,
/// Dimension of each chain, taken from the initial position.
dim: usize,
}
impl<V, Target> BatchedGenericHMC<V, Target>
where
    V: BatchVector,
    V::Scalar: Float + FromPrimitive + ToPrimitive + Zero,
    Target: BatchedHamiltonianTarget<V>,
    StandardNormal: RandDistribution<V::Scalar>,
{
    /// Creates a sampler whose chains start at `initial_position`.
    ///
    /// The number of chains and the per-chain dimension are read from
    /// `initial_position`. The mass matrix starts as the identity and the RNG is
    /// seeded from `rand::rng()`; call [`Self::set_seed`] for reproducible runs.
    /// Step-size adaptation defaults to a target acceptance of 0.8 with
    /// `gamma = 0.05`, `t_0 = 10` and `kappa = 0.75`.
    pub fn new(
        target: Target,
        initial_position: V,
        step_size: V::Scalar,
        n_leapfrog: usize,
    ) -> Self {
        let n_chains = initial_position.n_chains();
        let dim = initial_position.dim_per_chain();
        Self {
            target,
            step_size,
            step_size_bar: step_size,
            n_leapfrog,
            target_accept_p: V::Scalar::from_f64(0.8).unwrap(),
            gamma: V::Scalar::from_f64(0.05).unwrap(),
            t_0: 10,
            kappa: V::Scalar::from_f64(0.75).unwrap(),
            mu: (V::Scalar::from_f64(10.0).unwrap() * step_size).ln(),
            h_bar: V::Scalar::zero(),
            mass: DiagMass::identity(dim),
            momentum: initial_position.zeros_like(),
            grad: initial_position.zeros_like(),
            proposal_pos: initial_position.zeros_like(),
            proposal_mom: initial_position.zeros_like(),
            position: initial_position,
            rng: SmallRng::from_rng(&mut rand::rng()),
            n_chains,
            dim,
        }
    }

    /// Replaces the RNG with one deterministically seeded from `seed`.
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }

    /// Sets the mean acceptance probability targeted by step-size adaptation.
    ///
    /// # Panics
    ///
    /// Panics unless `0 < target_accept_p < 1`.
    pub fn set_target_accept(mut self, target_accept_p: V::Scalar) -> Self {
        assert!(
            target_accept_p > V::Scalar::zero() && target_accept_p < V::Scalar::one(),
            "target_accept must be in (0, 1)"
        );
        self.target_accept_p = target_accept_p;
        self
    }

    /// Runs `n_discard` warmup iterations, then collects `n_collect` transitions.
    ///
    /// Returns an array of shape `(n_chains, n_collect, dim)`. Entry
    /// `[chain, s, d]` is coordinate `d` of `chain` after collected transition `s`.
    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Array3<V::Scalar> {
        self.warmup(n_discard);
        let mut out = Array3::<V::Scalar>::zeros((self.n_chains, n_collect, self.dim));
        let mut scratch = vec![V::Scalar::zero(); self.n_chains * self.dim];
        for step_idx in 0..n_collect {
            let _ = self.step();
            // The indexing below assumes `write_to_slice` lays values out
            // chain-major (`chain * dim + d`).
            self.position.write_to_slice(&mut scratch);
            for chain_idx in 0..self.n_chains {
                for d in 0..self.dim {
                    out[[chain_idx, step_idx, d]] = scratch[chain_idx * self.dim + d];
                }
            }
        }
        out
    }

    /// Like [`Self::run`], but returns a clone of the batched position after each
    /// collected transition.
    pub fn run_positions(&mut self, n_collect: usize, n_discard: usize) -> Vec<V> {
        self.warmup(n_discard);
        let mut samples = Vec::with_capacity(n_collect);
        for _ in 0..n_collect {
            let _ = self.step();
            samples.push(self.position.clone());
        }
        samples
    }

    /// Runs `n_discard` adaptive warmup transitions. Does nothing when
    /// `n_discard == 0`.
    ///
    /// The step size is first initialised with [`Self::find_reasonable_step_size`]
    /// and then tuned by dual averaging on every iteration. When
    /// `n_discard / 2 >= 5`, positions from the first `n_discard / 2` iterations
    /// feed a running variance, and at that midpoint the diagonal mass is
    /// re-estimated. If the re-estimate succeeds, the step size is re-initialised
    /// and adaptation restarts. At the end, `step_size` is fixed to the averaged
    /// `step_size_bar`.
    pub(crate) fn warmup(&mut self, n_discard: usize) {
        if n_discard == 0 {
            return;
        }
        self.step_size = self.find_reasonable_step_size(self.step_size);
        self.reset_step_size_adaptation();
        let mass_update_iter = Self::mass_adaptation_iter(n_discard);
        let mut adapt_iter = 0;
        let mut running = RunningVariance::new(self.dim);
        let mut scratch = vec![V::Scalar::zero(); self.n_chains * self.dim];
        for iter in 1..=n_discard {
            let accept_p = self.step();
            adapt_iter += 1;
            self.adapt_step_size(accept_p, adapt_iter);
            if iter <= mass_update_iter {
                self.position.write_to_slice(&mut scratch);
                running.update_batch(&scratch);
            }
            if iter == mass_update_iter && self.update_mass_from_running(&running) {
                // The geometry changed, so restart step-size adaptation from the
                // best step size found so far.
                self.step_size = self.find_reasonable_step_size(self.step_size_bar);
                self.reset_step_size_adaptation();
                adapt_iter = 0;
            }
        }
        self.step_size = self.step_size_bar;
    }

    /// Returns the warmup iteration at which the mass matrix is re-estimated.
    ///
    /// This is `n_discard / 2`, or `0` (disabled, since warmup iterations start at
    /// 1) when that leaves fewer than 5 iterations of position samples.
    fn mass_adaptation_iter(n_discard: usize) -> usize {
        let midpoint = n_discard / 2;
        if midpoint >= 5 { midpoint } else { 0 }
    }

    /// Replaces the diagonal mass with the regularised variance in `running`.
    ///
    /// Returns `true` if the mass was replaced. Returns `false` if
    /// `running.sample_count()` is below 5 or no regularised variance is available.
    /// NOTE(review): whether `sample_count` counts batches or individual chain
    /// samples depends on `RunningVariance` — confirm.
    fn update_mass_from_running(&mut self, running: &RunningVariance<V::Scalar>) -> bool {
        if running.sample_count() < 5 {
            return false;
        }
        let regularize = V::Scalar::from_f64(0.05).unwrap();
        let jitter = V::Scalar::from_f64(1e-6).unwrap();
        let Some(var) = running.regularized_variance(regularize, jitter) else {
            return false;
        };
        self.mass = DiagMass::from_variance(var, jitter);
        true
    }

    /// Heuristically searches for a step size whose mean acceptance is near 0.5.
    ///
    /// First, `step_size` is halved until a trial transition gives a finite,
    /// positive acceptance. The step is then repeatedly doubled while acceptance
    /// stays above 0.5, or halved while it stays at or below 0.5, for at most 32
    /// trials. Trial transitions do not move the chains. The result is never
    /// below `V::Scalar::epsilon()`.
    fn find_reasonable_step_size(&mut self, mut step_size: V::Scalar) -> V::Scalar {
        let two = V::Scalar::from_f64(2.0).unwrap();
        let half = V::Scalar::from_f64(0.5).unwrap();
        let min_step = V::Scalar::epsilon();
        let max_iters = 32;
        let mut accept_p = self.mean_acceptance_for_step_size(step_size);
        while (!accept_p.is_finite() || accept_p <= V::Scalar::zero()) && step_size > min_step {
            step_size = (step_size / two).max(min_step);
            accept_p = self.mean_acceptance_for_step_size(step_size);
        }
        if !accept_p.is_finite() {
            return min_step;
        }
        let grow = accept_p > half;
        for _ in 0..max_iters {
            let candidate = if grow {
                step_size * two
            } else {
                step_size / two
            };
            if candidate <= min_step {
                break;
            }
            let candidate_accept = self.mean_acceptance_for_step_size(candidate);
            if !candidate_accept.is_finite() || candidate_accept <= V::Scalar::zero() {
                if grow {
                    // Growing has become unstable; keep the last good step size.
                    break;
                }
                step_size = candidate.max(min_step);
                continue;
            }
            // Stop once acceptance crosses 0.5 in the direction of travel.
            if (candidate_accept > half) != grow {
                break;
            }
            step_size = candidate;
        }
        step_size.max(min_step)
    }

    /// Restarts dual averaging around the current `step_size`.
    fn reset_step_size_adaptation(&mut self) {
        self.step_size_bar = self.step_size;
        self.mu = (V::Scalar::from_f64(10.0).unwrap() * self.step_size).ln();
        self.h_bar = V::Scalar::zero();
    }

    /// Performs one dual-averaging update of `step_size` and `step_size_bar`.
    ///
    /// `iter` is the 1-based adaptation iteration since the last reset.
    /// `accept_p` is the mean acceptance reported by the transition.
    fn adapt_step_size(&mut self, accept_p: V::Scalar, iter: usize) {
        let zero = V::Scalar::zero();
        let one = V::Scalar::one();
        // A divergent transition can report a NaN acceptance;
        // `find_reasonable_step_size` guards against the same source. Feeding NaN
        // into `h_bar` would make `h_bar`, `step_size` and `step_size_bar` NaN for
        // the rest of warmup, so count it as zero acceptance. Clamp anything else
        // into [0, 1].
        let accept_p = if accept_p.is_nan() {
            zero
        } else {
            accept_p.max(zero).min(one)
        };
        let m = V::Scalar::from_usize(iter).expect("iteration converts to scalar");
        let eta = one / V::Scalar::from_usize(iter + self.t_0).expect("iteration converts to scalar");
        self.h_bar = (one - eta) * self.h_bar + eta * (self.target_accept_p - accept_p);
        self.step_size = (self.mu - m.sqrt() / self.gamma * self.h_bar).exp();
        let eta_bar = m.powf(-self.kappa);
        self.step_size_bar =
            ((one - eta_bar) * self.step_size_bar.ln() + eta_bar * self.step_size.ln()).exp();
    }

    /// Performs one HMC transition at the current step size and updates the
    /// positions.
    ///
    /// Returns the batch acceptance statistic computed by `V::mean_acceptance`.
    pub fn step(&mut self) -> V::Scalar {
        self.step_with_step_size(self.step_size, true)
    }

    /// Runs a trial transition at `step_size` without moving the chains.
    ///
    /// Returns its mean acceptance.
    fn mean_acceptance_for_step_size(&mut self, step_size: V::Scalar) -> V::Scalar {
        self.step_with_step_size(step_size, false)
    }

    /// Draws momenta, integrates a proposal and computes its mean acceptance.
    ///
    /// When `apply` is `true`, each chain's proposal is accepted or rejected using
    /// a uniform draw (`V::accept_mask`) and the positions are updated. Otherwise
    /// only the scratch buffers and the RNG state change.
    fn step_with_step_size(&mut self, step_size: V::Scalar, apply: bool) -> V::Scalar {
        let mass_inv = self.mass.inv();
        let mass_sqrt = self.mass.sqrt();
        // Momentum draw: normal noise scaled by the square root of the diagonal mass.
        self.momentum.fill_random_normal(&mut self.rng);
        self.momentum.scale_diag_assign(mass_sqrt);
        let ke_current = self.momentum.kinetic_energy_diag(mass_inv);
        self.grad.fill_zero();
        let logp_current = self.target.logp_and_grad(&self.position, &mut self.grad);
        self.proposal_pos.assign(&self.position);
        self.proposal_mom.assign(&self.momentum);
        // `self.grad` already holds the gradient at `proposal_pos` (== position), so
        // the integrator starts from it instead of re-evaluating the target.
        let logp_proposed = Self::leapfrog(
            &self.target,
            &mut self.proposal_pos,
            &mut self.proposal_mom,
            &mut self.grad,
            step_size,
            self.n_leapfrog,
            mass_inv,
        );
        // With zero leapfrog steps the proposal is the current state.
        let logp_proposed = logp_proposed.as_ref().unwrap_or(&logp_current);
        let ke_proposed = self.proposal_mom.kinetic_energy_diag(mass_inv);
        let delta_logp = V::energy_sub(logp_proposed, &logp_current);
        let delta_ke = V::energy_sub(&ke_current, &ke_proposed);
        let log_accept = V::energy_add(&delta_logp, &delta_ke);
        let mean_accept = V::mean_acceptance(&log_accept);
        if apply {
            let u = self.position.sample_uniform(&mut self.rng);
            let ln_u = V::energy_ln(&u);
            let mask = V::accept_mask(&log_accept, &ln_u);
            self.position.masked_assign(&self.proposal_pos, &mask);
        }
        mean_accept
    }

    /// Advances `proposal_pos` and `proposal_mom` by `n_leapfrog` leapfrog steps.
    ///
    /// On entry, `grad` must already hold the target gradient at `proposal_pos`.
    /// On exit, it holds the gradient at the final position. Returns the log
    /// density at the final position, or `None` when `n_leapfrog == 0` (the state
    /// is then left untouched). The target is called without zeroing `grad`, so it
    /// must overwrite the gradient.
    fn leapfrog(
        target: &Target,
        proposal_pos: &mut V,
        proposal_mom: &mut V,
        grad: &mut V,
        step_size: V::Scalar,
        n_leapfrog: usize,
        inv_mass: &[V::Scalar],
    ) -> Option<V::Energy> {
        let half = V::Scalar::from_f64(0.5).unwrap() * step_size;
        let mut logp = None;
        for _ in 0..n_leapfrog {
            proposal_mom.add_scaled_assign(grad, half);
            proposal_pos.add_diag_scaled_assign(proposal_mom, inv_mass, step_size);
            logp = Some(target.logp_and_grad(proposal_pos, grad));
            proposal_mom.add_scaled_assign(grad, half);
        }
        logp
    }

    /// Current positions of all chains.
    pub fn positions(&self) -> &V {
        &self.position
    }

    /// The log-density target being sampled.
    pub fn target(&self) -> &Target {
        &self.target
    }

    /// Current leapfrog step size (fixed to the averaged value after warmup).
    pub fn step_size(&self) -> &V::Scalar {
        &self.step_size
    }

    /// Number of leapfrog steps per transition.
    pub fn n_leapfrog(&self) -> usize {
        self.n_leapfrog
    }

    /// A copy of the sampler's RNG in its current state.
    pub fn rng_clone(&self) -> SmallRng {
        self.rng.clone()
    }

    /// Diagonal of the current mass matrix, recovered as the reciprocal of the
    /// stored inverse.
    #[cfg(test)]
    pub(crate) fn mass_diag(&self) -> Vec<V::Scalar> {
        self.mass
            .inv()
            .iter()
            .map(|&inv| V::Scalar::one() / inv)
            .collect()
    }
}