use std::{
cell::{Ref, RefCell},
fmt::Debug,
marker::PhantomData,
ops::DerefMut,
};
use anyhow::{Result, bail};
use nuts_derive::Storable;
use nuts_storable::{HasDims, Storable};
use rand_distr::num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
use crate::{
Math, NutsError,
chain::{AdaptStrategy, Chain, StatOptions},
dynamics::{
Direction, DivergenceInfo, DivergenceStats, Hamiltonian, KineticEnergyKind, Point, State,
TransformedHamiltonian, TransformedPoint,
},
nuts::{Collector, NutsOptions},
sampler::Progress,
sampler_stats::{SamplerStats, StatsDims},
transform::Transformation,
};
/// Which kind of dynamics the MCLMC chain uses for its trajectories.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum MclmcTrajectoryKind {
    /// Microcanonical kinetic energy for the whole run (the default).
    #[default]
    Microcanonical,
    /// Euclidean kinetic energy for the whole run.
    Euclidean,
    /// Start with Euclidean dynamics, then switch the kinetic energy to
    /// `KineticEnergyKind::Microcanonical` once the configured switch draw
    /// is reached (see `Chain::draw`).
    EuclideanEarlyThenMicrocanonical,
}
/// Summary of a single MCLMC trajectory, produced by `mclmc_kernel` and kept
/// around as `last_info` for stats extraction.
#[derive(Debug, Clone)]
pub struct MclmcInfo {
    /// Energy difference accumulated over the trajectory.
    pub energy_change: f64,
    /// True if the trajectory ended in a divergence.
    pub diverging: bool,
    /// Details of the divergence, if one occurred.
    pub divergence_info: Option<DivergenceInfo>,
    /// Number of leapfrog steps actually taken.
    pub num_steps: u64,
    /// Integrated trajectory time divided by the number of steps taken.
    pub average_step_size: f64,
}
/// Per-draw statistics reported by an MCLMC chain.
///
/// `H`, `A` and `Pt` are the storable stats types of the Hamiltonian, the
/// adaptation strategy and the current point; they are flattened into this
/// record when stored.
#[derive(Debug, Storable)]
pub struct MclmcStats<P: HasDims, H: Storable<P>, A: Storable<P>, Pt: Storable<P>> {
    /// Index of the chain that produced this draw.
    pub chain: u64,
    /// Index of the draw within the chain.
    pub draw: u64,
    /// Number of leapfrog steps taken for this draw.
    pub num_steps: u64,
    /// Energy change over the trajectory.
    pub energy_change: f64,
    /// Log weight of the draw. NOTE(review): `extract_stats` currently sets
    /// this to `energy_change` — confirm the intended sign/definition.
    pub log_weight: f64,
    /// Whether the chain was still tuning when this draw was made.
    pub tuning: bool,
    /// Hamiltonian stats, flattened into this record.
    #[storable(flatten)]
    pub hamiltonian: H,
    /// Adaptation-strategy stats, flattened into this record.
    #[storable(flatten)]
    pub adapt: A,
    /// Current-point stats, flattened into this record.
    #[storable(flatten)]
    pub point: Pt,
    /// Average integration step size used for this draw.
    pub average_step_size: f64,
    /// Divergence information, flattened into this record.
    #[storable(flatten)]
    pub divergence: DivergenceStats,
    // Ties the dims parameter `P` to the type without storing it; the
    // `fn() -> P` form avoids imposing `P`'s auto traits on the struct.
    #[storable(ignore)]
    _phantom: PhantomData<fn() -> P>,
}
/// A single MCMC chain running microcanonical Langevin Monte Carlo (MCLMC)
/// on a transformed Hamiltonian.
pub struct MclmcChain<M, R, A, T>
where
    M: Math,
    R: rand::Rng,
    T: Transformation<M>,
    A: AdaptStrategy<M, Hamiltonian = TransformedHamiltonian<M, T>>,
{
    /// The Hamiltonian whose dynamics the chain integrates.
    hamiltonian: TransformedHamiltonian<M, T>,
    /// Collector gathering per-trajectory information for the adapt strategy;
    /// replaced with a fresh one after every draw.
    collector: A::Collector,
    /// Adaptation strategy.
    adapt: A,
    /// Current state of the chain.
    state: State<M, TransformedPoint<M>>,
    /// Random number generator driving the chain.
    rng: R,
    /// Index of this chain.
    chain: u64,
    /// Number of draws made so far.
    draw_count: u64,
    /// Scales the number of integration steps per draw relative to the
    /// momentum decoherence length and step size.
    subsample_frequency: f64,
    /// If true, diverging steps are retried with a halved step size
    /// (up to a fixed number of halvings).
    dynamic_step_size: bool,
    /// Which trajectory / kinetic-energy scheme to use.
    trajectory_kind: MclmcTrajectoryKind,
    /// Draw index at which `EuclideanEarlyThenMicrocanonical` switches the
    /// kinetic energy to microcanonical.
    switch_draw: u64,
    /// Energy-error budget handed to the integrator (scaled per step).
    max_energy_error: f64,
    /// Options shared with the NUTS machinery; passed to the adapt strategy.
    nuts_options: NutsOptions,
    /// Math backend; RefCell so `&self` methods can hand out `&mut M`.
    math: RefCell<M>,
    /// Options controlling which statistics are extracted.
    stats_options: StatOptions<M, A>,
    /// Info about the most recent trajectory; `None` before the first draw.
    last_info: Option<MclmcInfo>,
    /// Scratch buffer holding the pre-step velocity so a diverged step can be
    /// rolled back.
    tmp_velocity: M::Vector,
}
impl<M, R, A, T> MclmcChain<M, R, A, T>
where
    M: Math,
    R: rand::Rng,
    T: Transformation<M>,
    A: AdaptStrategy<M, Hamiltonian = TransformedHamiltonian<M, T>>,
{
    /// Create a new MCLMC chain.
    ///
    /// `chain` is the chain index reported in stats, `subsample_frequency`
    /// scales the number of integration steps per draw, `dynamic_step_size`
    /// enables step-size halving on divergence, and `switch_draw` is the draw
    /// at which `EuclideanEarlyThenMicrocanonical` switches dynamics.
    pub fn new(
        mut math: M,
        mut hamiltonian: TransformedHamiltonian<M, T>,
        adapt: A,
        rng: R,
        chain: u64,
        subsample_frequency: f64,
        dynamic_step_size: bool,
        trajectory_kind: MclmcTrajectoryKind,
        switch_draw: u64,
        max_energy_error: f64,
        stats_options: StatOptions<M, A>,
    ) -> Self {
        let state = hamiltonian.pool().new_state(&mut math);
        let collector = adapt.new_collector(&mut math);
        let tmp_velocity = math.new_array();
        Self {
            hamiltonian,
            collector,
            adapt,
            state,
            rng,
            chain,
            draw_count: 0,
            subsample_frequency,
            dynamic_step_size,
            trajectory_kind,
            switch_draw,
            nuts_options: NutsOptions::default(),
            math: math.into(),
            stats_options,
            last_info: None,
            tmp_velocity,
            max_energy_error,
        }
    }

    /// Integrate one MCLMC trajectory starting from `self.state` and return
    /// the proposed end state together with trajectory info.
    ///
    /// On divergence the trajectory is discarded: the previous state is
    /// copied, its velocity is freshly resampled, and that state is returned
    /// instead, with `info.diverging == true`.
    ///
    /// # Errors
    /// Fails if the step count is not finite, if trajectory initialization or
    /// momentum refreshing fails, or if the logp evaluation errors out.
    fn mclmc_kernel(
        &mut self,
        resample_velocity: bool,
    ) -> Result<(State<M, TransformedPoint<M>>, MclmcInfo)> {
        let math = self.math.get_mut();
        let base_step_size = self.hamiltonian.step_size();
        // Number of base integration steps for this draw, derived from the
        // momentum decoherence length (1 if no length is configured).
        let num_base_steps: u64 = self
            .hamiltonian
            .momentum_decoherence_length()
            .map(|length| {
                let raw_steps = self.subsample_frequency * length / base_step_size;
                // Check finiteness BEFORE clamping: `max`/`min` discard NaN
                // and `min(1e6)` would silently clamp infinity, which made
                // this check unreachable when it ran after the clamp.
                if !raw_steps.is_finite() {
                    bail!("Invalid number of integration steps");
                }
                Ok(raw_steps.round().clamp(1.0, 1e6) as u64)
            })
            .unwrap_or(Ok(1))?;
        // Maximum number of nested step-size halvings after divergences.
        let max_halvings: usize = if self.dynamic_step_size { 10 } else { 0 };
        use crate::dynamics::LeapfrogResult;
        let mut current = self.hamiltonian.copy_state(math, &self.state);
        self.hamiltonian.initialize_trajectory(
            math,
            &mut current,
            resample_velocity,
            &mut self.rng,
        )?;
        // Unit scale vector for drawing standard-normal momentum noise.
        let ones = {
            let mut ones = math.new_array();
            math.fill_array(&mut ones, 1.0);
            ones
        };
        let mut momentum_noise = math.new_array();
        math.array_gaussian(&mut self.rng, &mut momentum_noise, &ones);
        let draw_start_energy = current.point().energy();
        let mut divergence_info: Option<DivergenceInfo> = None;
        let mut steps_taken = 0u64;
        // Current step-size multiplier: halved on divergence, doubled again
        // when the finer level is fully integrated.
        let mut factor = 1.0_f64;
        // Step counts suspended while integrating at a finer level.
        let mut remaining_stack: Vec<u64> = Vec::with_capacity(max_halvings);
        let mut remaining = num_base_steps;
        let mut time = 0.0;
        while remaining > 0 {
            // Save the velocity so a diverged step can be rolled back.
            math.copy_into(&current.point().velocity, &mut self.tmp_velocity);
            self.hamiltonian.partial_momentum_refresh(
                math,
                &mut current,
                &momentum_noise,
                &mut self.rng,
                factor,
            )?;
            let step_baseline = current.point().energy();
            match self.hamiltonian.leapfrog(
                math,
                &current,
                Direction::Forward,
                factor,
                step_baseline,
                // Per-step share of the energy-error budget.
                self.max_energy_error * factor / num_base_steps.to_f64().unwrap(),
                &mut self.collector,
            ) {
                LeapfrogResult::Ok(mut next) => {
                    math.array_gaussian(&mut self.rng, &mut momentum_noise, &ones);
                    self.hamiltonian.partial_momentum_refresh(
                        math,
                        &mut next,
                        &momentum_noise,
                        &mut self.rng,
                        factor,
                    )?;
                    math.array_gaussian(&mut self.rng, &mut momentum_noise, &ones);
                    current = next;
                    steps_taken += 1;
                    remaining -= 1;
                    time += factor * base_step_size;
                    // Finished all sub-steps at this level: pop back up,
                    // doubling the step factor for each level we leave. One
                    // completed finer level counts as one step of the coarser
                    // level, hence the `- 1`.
                    while remaining == 0 {
                        if let Some(prev_remaining) = remaining_stack.pop() {
                            remaining = prev_remaining - 1;
                            factor *= 2.0;
                        } else {
                            break;
                        }
                    }
                }
                LeapfrogResult::Divergence(info) => {
                    if remaining_stack.len() >= max_halvings {
                        divergence_info = Some(info);
                        break;
                    }
                    // Retry the failed step as two half-steps with half the
                    // step size, after restoring the pre-step velocity.
                    factor *= 0.5;
                    remaining_stack.push(remaining);
                    remaining = 2;
                    math.copy_into(
                        &self.tmp_velocity,
                        &mut current.try_point_mut().unwrap().velocity,
                    );
                }
                LeapfrogResult::Err(e) => {
                    return Err(NutsError::LogpFailure(e.into()).into());
                }
            }
        }
        // Guard against 0/0 = NaN when the very first step diverges
        // (possible when `max_halvings == 0`).
        let average_step_size = if steps_taken == 0 {
            0.0
        } else {
            time / steps_taken.to_f64().unwrap()
        };
        if divergence_info.is_some() {
            // Divergent trajectory: discard it and restart from the previous
            // state with a freshly resampled velocity.
            let mut next_state = self.hamiltonian.copy_state(math, &self.state);
            self.hamiltonian
                .initialize_trajectory(math, &mut next_state, true, &mut self.rng)?;
            let energy_change = current.point().energy() - draw_start_energy;
            let info = MclmcInfo {
                energy_change,
                diverging: true,
                divergence_info: divergence_info.clone(),
                num_steps: steps_taken,
                average_step_size,
            };
            let sample_info = crate::nuts::SampleInfo {
                depth: steps_taken,
                divergence_info,
                reached_maxdepth: false,
            };
            self.collector.register_draw(math, &current, &sample_info);
            return Ok((next_state, info));
        }
        assert!(steps_taken >= num_base_steps);
        let sample_info = crate::nuts::SampleInfo {
            depth: steps_taken,
            divergence_info: None,
            reached_maxdepth: false,
        };
        self.collector.register_draw(math, &current, &sample_info);
        let energy_change = current.point().energy_error();
        let info = MclmcInfo {
            energy_change,
            diverging: false,
            divergence_info: None,
            num_steps: steps_taken,
            average_step_size,
        };
        Ok((current, info))
    }
}
impl<M, R, A, T> SamplerStats<M> for MclmcChain<M, R, A, T>
where
    M: Math,
    R: rand::Rng,
    T: Transformation<M>,
    A: AdaptStrategy<M, Hamiltonian = TransformedHamiltonian<M, T>>,
{
    /// Combined stats record: chain-level fields plus the flattened stats of
    /// the Hamiltonian, the adapt strategy and the current point.
    type Stats = MclmcStats<
        StatsDims,
        <TransformedHamiltonian<M, T> as SamplerStats<M>>::Stats,
        A::Stats,
        <TransformedPoint<M> as SamplerStats<M>>::Stats,
    >;
    type StatsOptions = StatOptions<M, A>;
    /// Assemble the statistics record for the most recent draw.
    ///
    /// # Panics
    /// Panics if called before the first draw has been made (no `last_info`).
    fn extract_stats(&self, math: &mut M, options: Self::StatsOptions) -> Self::Stats {
        let info = self
            .last_info
            .as_ref()
            .expect("Sampler has not started yet");
        let hamiltonian_stats = self.hamiltonian.extract_stats(math, options.hamiltonian);
        let adapt_stats = self.adapt.extract_stats(math, options.adapt);
        let point_stats = self.state.point().extract_stats(math, options.point);
        MclmcStats {
            chain: self.chain,
            draw: self.draw_count,
            num_steps: info.num_steps,
            energy_change: info.energy_change,
            // NOTE(review): log_weight mirrors energy_change here — confirm
            // the intended sign/definition of the importance weight.
            log_weight: info.energy_change,
            tuning: self.adapt.is_tuning(),
            hamiltonian: hamiltonian_stats,
            adapt: adapt_stats,
            point: point_stats,
            average_step_size: info.average_step_size,
            // (Option<&DivergenceInfo>, divergence options, draw index)
            // converts into DivergenceStats via its From impl.
            divergence: (
                info.divergence_info.as_ref(),
                options.divergence,
                self.draw_count,
            )
                .into(),
            _phantom: PhantomData,
        }
    }
}
impl<M, R, A, T> Chain<M> for MclmcChain<M, R, A, T>
where
    M: Math,
    R: rand::Rng,
    T: Transformation<M>,
    A: AdaptStrategy<M, Hamiltonian = TransformedHamiltonian<M, T>>,
{
    type AdaptStrategy = A;
    /// Initialize the chain at `position`: set up the adapt strategy, build a
    /// fresh state from the position, and resample its velocity.
    fn set_position(&mut self, position: &[f64]) -> Result<()> {
        let mut math_ = self.math.borrow_mut();
        let math = math_.deref_mut();
        self.adapt.init(
            math,
            &mut self.nuts_options,
            &mut self.hamiltonian,
            position,
            &mut self.rng,
        )?;
        self.state = self.hamiltonian.init_state(math, position)?;
        self.hamiltonian
            .initialize_trajectory(math, &mut self.state, true, &mut self.rng)?;
        Ok(())
    }
    /// Produce one draw: run the MCLMC kernel, report progress, run the adapt
    /// step, and advance the chain state.
    fn draw(&mut self) -> Result<(Box<[f64]>, Progress)> {
        // Switch the kinetic energy to microcanonical exactly once, at the
        // configured switch draw; force a full velocity resample on that draw
        // so the trajectory starts consistently under the new dynamics.
        let resample_velocity = if self.trajectory_kind
            == MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical
            && self.draw_count == self.switch_draw
            && self.hamiltonian.kinetic_energy_kind != KineticEnergyKind::Microcanonical
        {
            self.hamiltonian
                .set_kinetic_energy_kind(KineticEnergyKind::Microcanonical);
            true
        } else {
            false
        };
        let (state, info) = self.mclmc_kernel(resample_velocity)?;
        // Copy the new position out into a plain slice for the caller.
        let position: Box<[f64]> = {
            let mut math_ = self.math.borrow_mut();
            let math = math_.deref_mut();
            let mut pos = vec![0f64; math.dim()];
            state.write_position(math, &mut pos);
            pos.into()
        };
        let progress = Progress {
            draw: self.draw_count,
            chain: self.chain,
            diverging: info.diverging,
            tuning: self.adapt.is_tuning(),
            step_size: self.hamiltonian.step_size(),
            num_steps: info.num_steps,
        };
        {
            let mut math_ = self.math.borrow_mut();
            let math = math_.deref_mut();
            self.adapt.adapt(
                math,
                &mut self.nuts_options,
                &mut self.hamiltonian,
                self.draw_count,
                &self.collector,
                &state,
                &mut self.rng,
            )?;
            // Start the next draw with a fresh collector.
            self.collector = self.adapt.new_collector(math);
        }
        self.draw_count += 1;
        self.state = state;
        self.last_info = Some(info);
        Ok((position, progress))
    }
    /// Dimensionality of the sampled space.
    fn dim(&self) -> usize {
        self.math.borrow().dim()
    }
    /// Like `draw`, but also returns the expanded vector and the full stats
    /// record for the draw.
    fn expanded_draw(&mut self) -> Result<(Box<[f64]>, M::ExpandedVector, Self::Stats, Progress)> {
        let (position, progress) = self.draw()?;
        let mut math_ = self.math.borrow_mut();
        let math = math_.deref_mut();
        // Extract stats with the current options first, then let the
        // Hamiltonian update the options for the next draw.
        let stats = self.extract_stats(math, self.stats_options);
        self.stats_options.hamiltonian = self
            .hamiltonian
            .update_stats_options(math, self.stats_options.hamiltonian);
        let expanded = math.expand_vector(&mut self.rng, self.state.point().position())?;
        Ok((position, expanded, stats, progress))
    }
    /// Borrow the math backend.
    fn math(&self) -> Ref<'_, M> {
        self.math.borrow()
    }
}
#[cfg(test)]
mod tests {
    use rand::rng;
    use crate::{
        Chain, DiagMclmcSettings, MclmcSettings, adapt_strategy::test_logps::NormalLogp,
        math::CpuMath, sampler::Settings,
    };

    /// Shared smoke-test driver: run 500 draws of a chain configured by
    /// `settings` on a 10-dimensional normal with mean 3.0, assert that no
    /// draw diverges, and check that the final draw's mean is within a loose
    /// window of the true mean. (Deliberately loose — this is a smoke test,
    /// not a convergence test.)
    fn run_and_check_normal(settings: DiagMclmcSettings) {
        let ndim = 10;
        let func = NormalLogp::new(ndim, 3.0);
        let math = CpuMath::new(func);
        let mut rng = rng();
        let mut chain = settings.new_chain(0, math, &mut rng);
        chain.set_position(&vec![0.0f64; ndim]).unwrap();
        let mut last_pos = vec![0.0f64; ndim];
        for _ in 0..500 {
            let (draw, progress) = chain.draw().unwrap();
            assert!(!progress.diverging, "unexpected divergence");
            last_pos.copy_from_slice(&draw);
        }
        let mean: f64 = last_pos.iter().sum::<f64>() / ndim as f64;
        assert!(
            (mean - 3.0).abs() < 3.0,
            "mean {mean} too far from expected 3.0"
        );
    }

    /// Default (microcanonical) trajectory kind.
    #[test]
    fn mclmc_draws_normal() {
        run_and_check_normal(DiagMclmcSettings {
            step_size: 0.5,
            momentum_decoherence_length: 3.0,
            num_tune: 200,
            num_draws: 500,
            ..MclmcSettings::default()
        });
    }

    /// Euclidean trajectory kind for the whole run.
    #[test]
    fn mclmc_euclidean_trajectory() {
        use crate::mclmc::MclmcTrajectoryKind;
        run_and_check_normal(DiagMclmcSettings {
            step_size: 0.3,
            momentum_decoherence_length: 3.0,
            num_tune: 200,
            num_draws: 500,
            trajectory_kind: MclmcTrajectoryKind::Euclidean,
            ..MclmcSettings::default()
        });
    }

    /// Euclidean dynamics early, switching to microcanonical at 30% of tuning.
    #[test]
    fn mclmc_euclidean_early_then_microcanonical() {
        use crate::mclmc::MclmcTrajectoryKind;
        run_and_check_normal(DiagMclmcSettings {
            step_size: 0.5,
            momentum_decoherence_length: 3.0,
            num_tune: 200,
            num_draws: 500,
            trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
            trajectory_switch_fraction: 0.3,
            ..MclmcSettings::default()
        });
    }
}