use std::{fmt::Debug, sync::Arc};
use nuts_derive::Storable;
use rand::{
Rng, RngExt,
distr::{Distribution, StandardUniform},
};
use crate::{
Math, NutsError,
dynamics::{State, StatePool},
nuts::Collector,
sampler_stats::SamplerStats,
};
/// Information recorded about a single divergence.
///
/// Every field is optional, so a value may describe a divergence only
/// partially. The `From` conversion into [`DivergenceStats`] in this file reads
/// every field except the two trajectory indices.
#[derive(Debug, Clone)]
pub struct DivergenceInfo {
/// Momentum at the start; copied into `DivergenceStats::divergence_momentum`.
///
/// NOTE(review): "start" presumably refers to the state the diverging
/// leapfrog step began from — confirm against the code that fills this in.
pub start_momentum: Option<Box<[f64]>>,
/// Location at the start; copied into `DivergenceStats::divergence_start`.
pub start_location: Option<Box<[f64]>>,
/// Gradient at the start; copied into
/// `DivergenceStats::divergence_start_gradient`.
pub start_gradient: Option<Box<[f64]>>,
/// Location at the end; copied into `DivergenceStats::divergence_end`.
pub end_location: Option<Box<[f64]>>,
/// Energy error associated with the divergence. May be NaN; the conversion
/// into [`DivergenceStats`] reports a NaN value with its own message.
pub energy_error: Option<f64>,
/// Trajectory index of the end point.
///
/// NOTE(review): presumably matches `Point::index_in_trajectory` — confirm.
pub end_idx_in_trajectory: Option<i64>,
/// Trajectory index of the start point.
pub start_idx_in_trajectory: Option<i64>,
/// Error reported by the logp function, if that is what caused the
/// divergence. When present, its `Display` text takes precedence over the
/// energy error as the divergence message.
pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
}
/// Divergence statistics, stored through the derived `Storable` implementation.
///
/// Values are built from an optional [`DivergenceInfo`], a
/// [`DivergenceStatsOptions`] and a draw number by the `From` impl in this
/// file.
///
/// NOTE(review): fields tagged `event = "divergence"` are presumably only
/// written for draws that actually diverged — confirm against `nuts_derive`.
#[derive(Debug, Storable)]
pub struct DivergenceStats {
/// `true` when a [`DivergenceInfo`] was supplied to the conversion.
pub diverging: bool,
/// Draw number passed to the conversion; `None` when there was no divergence.
#[storable(event = "divergence")]
pub divergence_draw: Option<u64>,
/// Human-readable cause: the logp error text if there is one, otherwise a
/// description of the energy error, otherwise a generic "unknown cause" text.
#[storable(event = "divergence")]
pub divergence_message: Option<String>,
/// Copy of `DivergenceInfo::start_location`; only filled when
/// `DivergenceStatsOptions::store_divergences` is set.
#[storable(event = "divergence", dims("unconstrained_parameter"))]
pub divergence_start: Option<Vec<f64>>,
/// Copy of `DivergenceInfo::start_gradient`; only filled when
/// `DivergenceStatsOptions::store_divergences` is set.
#[storable(event = "divergence", dims("unconstrained_parameter"))]
pub divergence_start_gradient: Option<Vec<f64>>,
/// Copy of `DivergenceInfo::end_location`; only filled when
/// `DivergenceStatsOptions::store_divergences` is set.
#[storable(event = "divergence", dims("unconstrained_parameter"))]
pub divergence_end: Option<Vec<f64>>,
/// Copy of `DivergenceInfo::start_momentum`; only filled when
/// `DivergenceStatsOptions::store_divergences` is set.
#[storable(event = "divergence", dims("unconstrained_parameter"))]
pub divergence_momentum: Option<Vec<f64>>,
/// `DivergenceInfo::energy_error`, passed through regardless of
/// `store_divergences`.
#[storable(event = "divergence")]
pub divergence_energy_error: Option<f64>,
}
/// Options controlling how much divergence detail ends up in [`DivergenceStats`].
#[derive(Debug, Clone, Copy)]
pub struct DivergenceStatsOptions {
/// When `false`, the conversion into [`DivergenceStats`] leaves the start,
/// start-gradient, end and momentum vectors as `None`. The flag, draw
/// number, message and energy error are recorded either way.
pub store_divergences: bool,
}
impl From<(Option<&DivergenceInfo>, DivergenceStatsOptions, u64)> for DivergenceStats {
    /// Builds divergence statistics from an optional divergence record.
    ///
    /// The tuple holds the divergence (if any), the storage options and the
    /// draw number.
    ///
    /// * Without a divergence, `diverging` is `false` and every optional field
    ///   is `None`.
    /// * With a divergence, the draw number, message and energy error are
    ///   always recorded. The location, gradient and momentum vectors are
    ///   copied only when `options.store_divergences` is set.
    fn from((info, options, draw): (Option<&DivergenceInfo>, DivergenceStatsOptions, u64)) -> Self {
        // `detail` is the divergence record only when vectors should be stored.
        let detail = info.filter(|_| options.store_divergences);

        // Turns an optional boxed slice into an optional owned vector.
        let copy_vec = |field: &Option<Box<[f64]>>| -> Option<Vec<f64>> {
            field.as_deref().map(|values| values.to_vec())
        };

        // A logp error wins over the energy error when choosing the message.
        let describe = |d: &DivergenceInfo| -> String {
            match (&d.logp_function_error, d.energy_error) {
                (Some(err), _) => err.to_string(),
                (None, Some(energy_err)) if energy_err.is_nan() => {
                    "Divergence due to NaN energy error".to_string()
                }
                (None, Some(energy_err)) => {
                    format!("Divergence due to large energy error: {:.4}", energy_err)
                }
                (None, None) => "Divergence (unknown cause)".to_string(),
            }
        };

        DivergenceStats {
            diverging: info.is_some(),
            divergence_draw: info.map(|_| draw),
            divergence_message: info.map(describe),
            divergence_start: detail.and_then(|d| copy_vec(&d.start_location)),
            divergence_start_gradient: detail.and_then(|d| copy_vec(&d.start_gradient)),
            divergence_end: detail.and_then(|d| copy_vec(&d.end_location)),
            divergence_momentum: detail.and_then(|d| copy_vec(&d.start_momentum)),
            divergence_energy_error: info.and_then(|d| d.energy_error),
        }
    }
}
/// Direction of a leapfrog step.
///
/// Passed as `dir` to `Hamiltonian::leapfrog`. A random direction can be drawn
/// through the `Distribution<Direction>` impl for `StandardUniform` in this
/// file.
#[derive(Debug, Copy, Clone)]
pub enum Direction {
/// NOTE(review): presumably integration forward in time — confirm against
/// the `leapfrog` implementations.
Forward,
/// NOTE(review): presumably integration backward in time — confirm against
/// the `leapfrog` implementations.
Backward,
}
impl Distribution<Direction> for StandardUniform {
    /// Draws a [`Direction`] from one random `bool`.
    ///
    /// `true` maps to [`Direction::Forward`] and `false` to
    /// [`Direction::Backward`].
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Direction {
        match rng.random::<bool>() {
            true => Direction::Forward,
            false => Direction::Backward,
        }
    }
}
/// Outcome of `Hamiltonian::leapfrog`.
pub enum LeapfrogResult<M: Math, P: Point<M>> {
/// The step succeeded; carries the resulting state.
Ok(State<M, P>),
/// The step diverged; carries the recorded details.
Divergence(DivergenceInfo),
/// The step failed with the math backend's `LogpErr` error type.
///
/// NOTE(review): presumably reserved for errors that cannot be reported as
/// a divergence — confirm against the `leapfrog` implementations.
Err(M::LogpErr),
}
/// A point used as the `P` parameter of `State<M, P>`.
///
/// Implementors must also provide `SamplerStats<M>` and `Debug`.
pub trait Point<M: Math>: Sized + SamplerStats<M> + Debug {
/// Position of this point.
fn position(&self) -> &M::Vector;
/// Gradient stored with this point.
///
/// NOTE(review): presumably the gradient of the log density at
/// `position()` — confirm against implementors.
fn gradient(&self) -> &M::Vector;
/// Index of this point within its trajectory.
fn index_in_trajectory(&self) -> i64;
/// Energy of this point.
fn energy(&self) -> f64;
/// Log density value stored with this point.
fn logp(&self) -> f64;
/// Energy difference relative to the initial energy.
///
/// The default implementation returns `energy() - initial_energy()`.
fn energy_error(&self) -> f64 {
self.energy() - self.initial_energy()
}
/// Reference energy that `energy_error` subtracts from `energy()`.
///
/// NOTE(review): presumably the energy at the start of the trajectory —
/// confirm against implementors.
fn initial_energy(&self) -> f64;
/// Creates a new point using the math backend.
fn new(math: &mut M) -> Self;
/// Copies this point into `other`.
///
/// NOTE(review): presumably overwrites all of `other`'s contents — confirm
/// against implementors.
fn copy_into(&self, math: &mut M, other: &mut Self);
}
/// Interface of a Hamiltonian that can be stepped with leapfrog and queried
/// by the sampler.
///
/// Implementors must also provide `SamplerStats<M>`. The methods from
/// `update_stats_options` onward have default implementations and only need to
/// be overridden by Hamiltonians that support the corresponding feature.
pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
    /// Point type stored inside the states this Hamiltonian works on.
    type Point: Point<M>;

    /// Takes a leapfrog step from `start` in direction `dir`.
    ///
    /// Returns `LeapfrogResult::Ok` with the new state,
    /// `LeapfrogResult::Divergence` with the recorded details, or
    /// `LeapfrogResult::Err` with the math backend's error.
    ///
    /// NOTE(review): the exact use of `step_size_factor`, `energy_baseline`,
    /// `max_energy_error` and `collector` is up to the implementor. Presumably
    /// the step size is scaled by `step_size_factor`, the energy error is
    /// measured against `energy_baseline` and compared with
    /// `max_energy_error` to detect divergences, and `collector` is notified
    /// about the step — confirm against the implementations.
    fn leapfrog<C: Collector<M, Self::Point>>(
        &mut self,
        math: &mut M,
        start: &State<M, Self::Point>,
        dir: Direction,
        step_size_factor: f64,
        energy_baseline: f64,
        max_energy_error: f64,
        collector: &mut C,
    ) -> LeapfrogResult<M, Self::Point>;

    /// Reports whether the trajectory spanned by `state1` and `state2` is
    /// turning.
    ///
    /// NOTE(review): presumably the NUTS U-turn criterion — confirm against
    /// the implementations.
    fn is_turning(
        &self,
        math: &mut M,
        state1: &State<M, Self::Point>,
        state2: &State<M, Self::Point>,
    ) -> bool;

    /// Creates a state from the initial position `init`.
    ///
    /// # Errors
    ///
    /// Returns a `NutsError` when the state cannot be initialized.
    fn init_state(
        &mut self,
        math: &mut M,
        init: &[f64],
    ) -> Result<State<M, Self::Point>, NutsError>;

    /// Creates a state from the initial position `init`, in the
    /// "untransformed" variant.
    ///
    /// NOTE(review): how this differs from `init_state` is not visible here;
    /// presumably `init` is interpreted without applying a transformation —
    /// confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Returns a `NutsError` when the state cannot be initialized.
    fn init_state_untransformed(
        &mut self,
        math: &mut M,
        init: &[f64],
    ) -> Result<State<M, Self::Point>, NutsError>;

    /// Prepares `state` as the starting point of a new trajectory.
    ///
    /// NOTE(review): `resample_velocity` presumably selects whether a fresh
    /// velocity is drawn from `rng` — confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Returns a `NutsError` when the trajectory cannot be initialized.
    fn initialize_trajectory<R: rand::Rng + ?Sized>(
        &self,
        math: &mut M,
        state: &mut State<M, Self::Point>,
        resample_velocity: bool,
        rng: &mut R,
    ) -> Result<(), NutsError>;

    /// Gives mutable access to the state pool of this Hamiltonian.
    fn pool(&mut self) -> &mut StatePool<M, Self::Point>;

    /// Returns a new state that is a copy of `state`.
    fn copy_state(&mut self, math: &mut M, state: &State<M, Self::Point>) -> State<M, Self::Point>;

    /// Current step size.
    fn step_size(&self) -> f64;

    /// Mutable access to the step size, for callers that adjust it.
    fn step_size_mut(&mut self) -> &mut f64;

    /// Lets the Hamiltonian adjust its stats options.
    ///
    /// The default implementation returns `current` unchanged.
    fn update_stats_options(
        &mut self,
        _math: &mut M,
        current: <Self as SamplerStats<M>>::StatsOptions,
    ) -> <Self as SamplerStats<M>>::StatsOptions {
        current
    }

    /// Momentum decoherence length, if this Hamiltonian has one.
    ///
    /// The default implementation returns `None`.
    fn momentum_decoherence_length(&self) -> Option<f64> {
        None
    }

    /// Mutable access to the momentum decoherence length, if this Hamiltonian
    /// has one.
    ///
    /// The default implementation returns `None`.
    fn momentum_decoherence_length_mut(&mut self) -> Option<&mut f64> {
        None
    }

    /// Partially refreshes the momentum of `state`.
    ///
    /// The default implementation ignores all arguments, leaves `state`
    /// untouched and returns `Ok(())`.
    ///
    /// NOTE(review): overriding implementations presumably mix `noise` into
    /// the momentum with a weight derived from `factor` — confirm against
    /// the implementations.
    ///
    /// # Errors
    ///
    /// The default implementation never fails; overriding implementations may
    /// return a `NutsError`.
    fn partial_momentum_refresh<R: rand::Rng + ?Sized>(
        &mut self,
        math: &mut M,
        state: &mut State<M, Self::Point>,
        noise: &M::Vector,
        rng: &mut R,
        factor: f64,
    ) -> Result<(), NutsError> {
        // Explicitly mark the arguments as used so the default no-op does not
        // trigger unused-variable warnings.
        let _ = (math, state, noise, rng, factor);
        Ok(())
    }
}