use crate::euclidean::BatchVector;
use ndarray::Array3;
use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};
use rand::distr::Distribution as RandDistribution;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rand_distr::StandardNormal;
/// Log-density target evaluated for a whole batch of chains at once.
///
/// Implemented by the model being sampled by [`BatchedGenericHMC`].
pub trait BatchedHamiltonianTarget<V>
where
    V: BatchVector,
{
    /// Evaluates the target at `position`, returning its log-density and
    /// writing the gradient of the log-density into `grad`.
    ///
    /// The sampler uses the returned value as log p in its acceptance test and
    /// `grad` as the force in the momentum updates. Presumably the returned
    /// `V::Energy` holds one value per chain (TODO confirm).
    ///
    /// NOTE(review): confirm whether implementations must fully overwrite
    /// `grad` or are allowed to accumulate into it.
    fn logp_and_grad(&self, position: &V, grad: &mut V) -> V::Energy;
}
/// Batched Hamiltonian Monte Carlo sampler over a generic batched vector type `V`.
///
/// Holds the current position of every chain together with momentum,
/// gradient and proposal buffers that are reused across transitions.
#[derive(Debug)]
pub struct BatchedGenericHMC<V: BatchVector, Target: BatchedHamiltonianTarget<V>> {
    /// Log-density target being sampled.
    target: Target,
    /// Leapfrog integration step size.
    step_size: V::Scalar,
    /// Number of leapfrog steps per transition.
    n_leapfrog: usize,
    /// Current state of all chains.
    position: V,
    /// Momentum drawn at the start of each transition.
    momentum: V,
    /// Gradient buffer handed to the target for filling.
    grad: V,
    /// Trajectory position during leapfrog integration.
    proposal_pos: V,
    /// Trajectory momentum during leapfrog integration.
    proposal_mom: V,
    /// Random source for momenta and acceptance draws.
    rng: SmallRng,
    /// Number of chains, taken from the initial position.
    n_chains: usize,
    /// Dimension of each chain, taken from the initial position.
    dim: usize,
}
impl<V, Target> BatchedGenericHMC<V, Target>
where
    V: BatchVector,
    V::Scalar: Float + FromPrimitive + ToPrimitive + Zero,
    Target: BatchedHamiltonianTarget<V>,
    StandardNormal: RandDistribution<V::Scalar>,
{
    /// Creates a sampler whose chains start at `initial_position`.
    ///
    /// The chain count and per-chain dimension are read from
    /// `initial_position`, and every scratch buffer is created with
    /// `zeros_like` so it shares that layout. The RNG is seeded from the
    /// thread-local generator; use [`Self::set_seed`] for reproducible runs.
    pub fn new(
        target: Target,
        initial_position: V,
        step_size: V::Scalar,
        n_leapfrog: usize,
    ) -> Self {
        let n_chains = initial_position.n_chains();
        let dim = initial_position.dim_per_chain();
        Self {
            target,
            step_size,
            n_leapfrog,
            momentum: initial_position.zeros_like(),
            grad: initial_position.zeros_like(),
            proposal_pos: initial_position.zeros_like(),
            proposal_mom: initial_position.zeros_like(),
            position: initial_position,
            rng: SmallRng::from_rng(&mut rand::rng()),
            n_chains,
            dim,
        }
    }

    /// Replaces the internal RNG with one deterministically seeded from `seed`.
    pub fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }

    /// Runs `n_discard` burn-in transitions, then `n_collect` transitions,
    /// recording the position of every chain after each collected transition.
    ///
    /// Returns an array of shape `(n_chains, n_collect, dim)`.
    pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Array3<V::Scalar> {
        self.burn_in(n_discard);
        let (n_chains, dim) = (self.n_chains, self.dim);
        let mut out = Array3::<V::Scalar>::zeros((n_chains, n_collect, dim));
        // Flat buffer reused for every draw. The indexing below assumes
        // `write_to_slice` writes chain-major order (chain_idx * dim + d).
        let mut scratch = vec![V::Scalar::zero(); n_chains * dim];
        for step_idx in 0..n_collect {
            self.step();
            self.position.write_to_slice(&mut scratch);
            for chain_idx in 0..n_chains {
                let row = &scratch[chain_idx * dim..(chain_idx + 1) * dim];
                for (d, &value) in row.iter().enumerate() {
                    out[[chain_idx, step_idx, d]] = value;
                }
            }
        }
        out
    }

    /// Like [`Self::run`], but returns a clone of the full batched position
    /// after each collected transition instead of a dense array.
    pub fn run_positions(&mut self, n_collect: usize, n_discard: usize) -> Vec<V> {
        self.burn_in(n_discard);
        let mut samples = Vec::with_capacity(n_collect);
        for _ in 0..n_collect {
            self.step();
            samples.push(self.position.clone());
        }
        samples
    }

    /// Advances the sampler `n` transitions without recording anything.
    fn burn_in(&mut self, n: usize) {
        for _ in 0..n {
            self.step();
        }
    }

    /// Performs one HMC transition for every chain.
    ///
    /// Draws fresh momenta, integrates a leapfrog trajectory of `n_leapfrog`
    /// steps from the current position, and then lets `V::accept_mask` decide,
    /// per chain, whether to move to the proposal. The decision is based on
    /// `ln u` and `(logp' - logp) + (K - K')`.
    pub fn step(&mut self) {
        self.momentum.fill_random_normal(&mut self.rng);
        let ke_current = self.momentum.kinetic_energy();

        self.grad.fill_zero();
        let logp_current = self.target.logp_and_grad(&self.position, &mut self.grad);

        self.proposal_pos.assign(&self.position);
        self.proposal_mom.assign(&self.momentum);
        // `self.grad` now holds the gradient at `proposal_pos` (== position),
        // which is exactly what `leapfrog` needs for its first half kick.
        let logp_end = self.leapfrog();
        // With zero leapfrog steps the proposal is the current point itself.
        let logp_proposed = logp_end.as_ref().unwrap_or(&logp_current);
        let ke_proposed = self.proposal_mom.kinetic_energy();

        let delta_logp = V::energy_sub(logp_proposed, &logp_current);
        let delta_ke = V::energy_sub(&ke_current, &ke_proposed);
        let log_accept = V::energy_add(&delta_logp, &delta_ke);

        let u = self.position.sample_uniform(&mut self.rng);
        let ln_u = V::energy_ln(&u);
        let mask = V::accept_mask(&log_accept, &ln_u);
        self.position.masked_assign(&self.proposal_pos, &mask);
    }

    /// Integrates `proposal_pos` / `proposal_mom` in place for `n_leapfrog`
    /// leapfrog steps of size `step_size`.
    ///
    /// On entry, `self.grad` must already contain the gradient at
    /// `proposal_pos`. Reusing it avoids a redundant target evaluation.
    /// Returns the log-density at the end of the trajectory, or `None` when
    /// `n_leapfrog == 0` because no new evaluation was made.
    fn leapfrog(&mut self) -> Option<V::Energy> {
        let half = V::Scalar::from_f64(0.5)
            .expect("0.5 must be representable in the scalar type")
            * self.step_size;
        let mut logp = None;
        for _ in 0..self.n_leapfrog {
            // Half kick, full drift, re-evaluate the force, half kick.
            self.proposal_mom.add_scaled_assign(&self.grad, half);
            self.proposal_pos
                .add_scaled_assign(&self.proposal_mom, self.step_size);
            // Zero before evaluating, as `step` does, so a target that
            // accumulates into `grad` yields only the gradient at the new point.
            self.grad.fill_zero();
            logp = Some(
                self.target
                    .logp_and_grad(&self.proposal_pos, &mut self.grad),
            );
            self.proposal_mom.add_scaled_assign(&self.grad, half);
        }
        logp
    }

    /// Current position of every chain.
    pub fn positions(&self) -> &V {
        &self.position
    }

    /// The target being sampled.
    pub fn target(&self) -> &Target {
        &self.target
    }

    /// Leapfrog step size.
    pub fn step_size(&self) -> &V::Scalar {
        &self.step_size
    }

    /// Number of leapfrog steps per transition.
    pub fn n_leapfrog(&self) -> usize {
        self.n_leapfrog
    }

    /// A copy of the internal RNG in its current state.
    pub fn rng_clone(&self) -> SmallRng {
        self.rng.clone()
    }
}