use anyhow::Result;
use nuts_storable::{HasDims, Storable, Value};
use rand::{Rng, SeedableRng, rngs::ChaCha8Rng};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::{collections::HashMap, fmt::Debug, time::Duration};
#[cfg(feature = "parallel")]
use anyhow::{Context, bail};
#[cfg(feature = "parallel")]
use itertools::Itertools;
#[cfg(feature = "parallel")]
use std::ops::Deref;
#[cfg(feature = "parallel")]
use rayon::{ScopeFifo, ThreadPoolBuilder};
#[cfg(feature = "parallel")]
use std::{
sync::{
Arc, Mutex,
mpsc::{
Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
},
},
thread::{JoinHandle, spawn},
time::Instant,
};
use crate::{
DiagAdaptExpSettings, Math, StepSizeAdaptMethod,
adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
dynamics::{KineticEnergyKind, TransformedHamiltonian, TransformedPointStatsOptions},
external_adapt_strategy::{ExternalTransformAdaptation, FlowSettings},
mclmc::MclmcTrajectoryKind,
nuts::NutsOptions,
sampler_stats::{SamplerStats, StatsDims},
transform::{
DiagAdaptStrategy, DiagMassMatrix, ExternalTransformation, LowRankMassMatrix,
LowRankMassMatrixStrategy, LowRankSettings,
},
};
#[cfg(feature = "parallel")]
use crate::{model::Model, storage::{ChainStorage, StorageConfig, TraceStorage}};
/// Sealed configuration trait implemented by every sampler settings type.
///
/// A `Settings` value chooses the sampler (via the associated `Chain`
/// type), constructs new chains, and describes the layout of the
/// per-draw sampler statistics and expanded draw data (names, item
/// types, dimensions, coordinates) so that storage backends can set up
/// their schemas before sampling starts.
pub trait Settings:
    private::Sealed + Clone + Copy + Default + Sync + Send + Serialize + DeserializeOwned + 'static
{
    /// The concrete chain implementation produced by `new_chain`.
    type Chain<M: Math>: Chain<M>;

    /// Create a new chain with id `chain`, seeding its internal rng from `rng`.
    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        math: M,
        rng: &mut R,
    ) -> Self::Chain<M>;

    /// Number of tuning draws as `usize` (panics if the value does not fit).
    fn hint_num_tune(&self) -> usize;

    /// Number of post-tuning draws as `usize` (panics if the value does not fit).
    fn hint_num_draws(&self) -> usize;

    /// Number of chains to run.
    fn num_chains(&self) -> usize;

    /// Base rng seed shared by all chains.
    fn seed(&self) -> u64;

    /// Options selecting which sampler statistics the chain records.
    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions;

    /// Short name of the sampling algorithm (e.g. "nuts", "mclmc").
    fn sampler_name(&self) -> &'static str;

    /// Short name of the adaptation scheme (e.g. "diagonal", "low_rank", "flow").
    fn adaptation_name(&self) -> &'static str;

    /// Names of all per-draw sampler statistics.
    fn stat_names<M: Math>(&self, math: &M) -> Vec<String> {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::names(&dims)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Names of all variables in the expanded draw vector.
    fn data_names<M: Math>(&self, math: &M) -> Vec<String> {
        <M::ExpandedVector as Storable<_>>::names(math)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// `(name, item type)` pairs for every sampler statistic.
    fn stat_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
        self.stat_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.stat_type::<M>(math, &name)))
            .collect()
    }

    /// Item type of a single sampler statistic.
    fn stat_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::item_type(&dims, name)
    }

    /// `(name, item type)` pairs for every draw-data variable.
    fn data_types<M: Math>(&self, math: &M) -> Vec<(String, nuts_storable::ItemType)> {
        self.data_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.data_type(math, &name)))
            .collect()
    }

    /// Item type of a single draw-data variable.
    fn data_type<M: Math>(&self, math: &M, name: &str) -> nuts_storable::ItemType {
        <M::ExpandedVector as Storable<_>>::item_type(math, name)
    }

    /// `(name, dimension names)` pairs for every sampler statistic.
    fn stat_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
        self.stat_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.stat_dims::<M>(math, &name)))
            .collect()
    }

    /// Dimension names of a single sampler statistic.
    fn stat_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
        let dims = StatsDims::from(math);
        <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::dims(&dims, name)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Sizes of all named statistic dimensions.
    fn stat_dim_sizes<M: Math>(&self, math: &M) -> HashMap<String, u64> {
        let dims = StatsDims::from(math);
        dims.dim_sizes()
    }

    /// `(name, dimension names)` pairs for every draw-data variable.
    fn data_dims_all<M: Math>(&self, math: &M) -> Vec<(String, Vec<String>)> {
        self.data_names(math)
            .into_iter()
            .map(|name| (name.clone(), self.data_dims(math, &name)))
            .collect()
    }

    /// Dimension names of a single draw-data variable.
    fn data_dims<M: Math>(&self, math: &M, name: &str) -> Vec<String> {
        <M::ExpandedVector as Storable<_>>::dims(math, name)
            .into_iter()
            .map(String::from)
            .collect()
    }

    /// Coordinate values associated with the statistic dimensions.
    fn stat_coords<M: Math>(&self, math: &M) -> HashMap<String, Value> {
        let dims = StatsDims::from(math);
        dims.coords()
    }

    /// For every statistic, the name of its event dimension (if any).
    fn stat_event_dims<M: Math>(&self, math: &M) -> Vec<(String, Option<String>)> {
        let dims = StatsDims::from(math);
        self.stat_names(math)
            .into_iter()
            .map(|name| {
                let event_dim =
                    <<Self::Chain<M> as SamplerStats<M>>::Stats as Storable<_>>::event_dim(
                        &dims, &name,
                    )
                    .map(String::from);
                (name, event_dim)
            })
            .collect()
    }
}
/// A snapshot of information about the most recent draw of a chain.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Progress {
    /// Index of the draw within its chain.
    pub draw: u64,
    /// Id of the chain that produced the draw.
    pub chain: u64,
    /// Whether the trajectory of this draw diverged.
    pub diverging: bool,
    /// Whether the chain is still in the tuning phase.
    pub tuning: bool,
    /// Step size used for this draw.
    pub step_size: f64,
    /// Number of integrator steps used for this draw.
    pub num_steps: u64,
}
/// Seals the [`Settings`] trait so that only the settings types defined
/// in this module can implement it.
mod private {
    use super::{
        DiagMclmcSettings, DiagNutsSettings, FlowMclmcSettings, FlowNutsSettings,
        LowRankMclmcSettings, LowRankNutsSettings,
    };

    /// Marker trait required by the `Settings` supertrait bound.
    pub trait Sealed {}

    impl Sealed for DiagNutsSettings {}
    impl Sealed for LowRankNutsSettings {}
    impl Sealed for FlowNutsSettings {}
    impl Sealed for DiagMclmcSettings {}
    impl Sealed for LowRankMclmcSettings {}
    impl Sealed for FlowMclmcSettings {}
}
/// Settings for NUTS-based samplers, generic over the adaptation options `A`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct NutsSettings<A: Debug + Copy + Default + Serialize> {
    /// Number of tuning (warmup) draws per chain.
    pub num_tune: u64,
    /// Number of draws after tuning per chain.
    pub num_draws: u64,
    /// Maximum tree depth (number of doublings) of the NUTS trajectory.
    pub maxdepth: u64,
    /// Minimum tree depth of the NUTS trajectory.
    pub mindepth: u64,
    /// Store the gradient of each draw in the sampler stats.
    pub store_gradient: bool,
    /// Store the draw in the unconstrained space.
    pub store_unconstrained: bool,
    /// Store the transformed draw.
    pub store_transformed: bool,
    /// Energy-error threshold above which a step counts as a divergence.
    pub max_energy_error: f64,
    /// Store the locations of divergences in the sampler stats.
    pub store_divergences: bool,
    /// Options for the adaptation scheme (mass matrix, step size, ...).
    pub adapt_options: A,
    /// Whether to apply the U-turn termination criterion.
    pub check_turning: bool,
    /// Optional target integration time, forwarded to `NutsOptions`.
    pub target_integration_time: Option<f64>,
    /// Kind of kinetic energy used for the Hamiltonian.
    pub trajectory_kind: KineticEnergyKind,
    /// Number of chains to run.
    pub num_chains: usize,
    /// Base rng seed.
    pub seed: u64,
    /// Extra tree doublings, forwarded to `NutsOptions`.
    pub extra_doublings: u64,
}
/// NUTS with diagonal mass matrix adaptation.
pub type DiagNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;

/// Old name of [`DiagNutsSettings`], kept for backwards compatibility.
#[deprecated(since = "0.0.0", note = "Use DiagNutsSettings instead")]
pub type DiagGradNutsSettings = DiagNutsSettings;

/// NUTS with low-rank mass matrix adaptation.
pub type LowRankNutsSettings = NutsSettings<EuclideanAdaptOptions<LowRankSettings>>;

/// NUTS with flow-based transformation adaptation.
pub type FlowNutsSettings = NutsSettings<FlowSettings>;

/// Old name of [`FlowNutsSettings`], kept for backwards compatibility.
#[deprecated(since = "0.0.0", note = "Use FlowNutsSettings instead")]
pub type TransformedNutsSettings = FlowNutsSettings;
/// Settings for MCLMC samplers, generic over the adaptation options `A`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct MclmcSettings<A: Debug + Copy + Default + Serialize> {
    /// Step size; the step size adaptation is pinned to this value
    /// (see `new_chain`, which sets `StepSizeAdaptMethod::Fixed`).
    pub step_size: f64,
    /// Momentum decoherence length passed to the Hamiltonian.
    pub momentum_decoherence_length: f64,
    /// Number of tuning (warmup) draws per chain.
    pub num_tune: u64,
    /// Number of draws after tuning per chain.
    pub num_draws: u64,
    /// Number of chains to run.
    pub num_chains: usize,
    /// Base rng seed.
    pub seed: u64,
    /// Energy-error threshold above which a step counts as a divergence.
    pub max_energy_error: f64,
    /// Store the draw in the unconstrained space.
    pub store_unconstrained: bool,
    /// Store the gradient of each draw.
    pub store_gradient: bool,
    /// Store the transformed draw.
    pub store_transformed: bool,
    /// Store the locations of divergences.
    pub store_divergences: bool,
    /// Options for the adaptation scheme.
    pub adapt_options: A,
    /// Subsampling frequency passed to the chain.
    pub subsample_frequency: f64,
    /// Whether the chain adjusts its step size dynamically.
    pub dynamic_step_size: bool,
    /// Which trajectory/kinetic-energy scheme to use.
    pub trajectory_kind: MclmcTrajectoryKind,
    /// Fraction of the tuning phase after which the trajectory kind may
    /// switch (used to compute `switch_draw` in `new_chain`).
    pub trajectory_switch_fraction: f64,
}
/// MCLMC with diagonal mass matrix adaptation.
pub type DiagMclmcSettings = MclmcSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;

/// MCLMC with low-rank mass matrix adaptation.
pub type LowRankMclmcSettings = MclmcSettings<EuclideanAdaptOptions<LowRankSettings>>;

/// MCLMC with flow-based transformation adaptation.
pub type FlowMclmcSettings = MclmcSettings<FlowSettings>;

/// Old name of [`FlowMclmcSettings`], kept for backwards compatibility.
#[deprecated(since = "0.0.0", note = "Use FlowMclmcSettings instead")]
pub type TransformedMclmcSettings = FlowMclmcSettings;
/// Convert a `u64` settings value to `usize`, panicking with a
/// field-specific message if the value does not fit on this platform.
fn usize_hint(value: u64, field: &str) -> usize {
    match usize::try_from(value) {
        Ok(converted) => converted,
        Err(_) => panic!("{field} must be smaller than usize::MAX"),
    }
}
/// Shared defaults for all MCLMC settings variants.
///
/// Only the values that differ between adaptation schemes
/// (`adapt_options`, `num_tune`, `num_chains`, `max_energy_error`) are
/// taken as arguments; everything else uses the common defaults.
fn default_mclmc_settings<A: Debug + Copy + Default + Serialize>(
    adapt_options: A,
    num_tune: u64,
    num_chains: usize,
    max_energy_error: f64,
) -> MclmcSettings<A> {
    MclmcSettings {
        num_tune,
        num_draws: 1000,
        num_chains,
        seed: 0,
        step_size: 0.5,
        momentum_decoherence_length: 3.0,
        max_energy_error,
        adapt_options,
        store_gradient: false,
        store_unconstrained: false,
        store_transformed: false,
        store_divergences: false,
        subsample_frequency: 1.0,
        dynamic_step_size: true,
        trajectory_kind: MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical,
        trajectory_switch_fraction: 0.3,
    }
}
impl Default for DiagMclmcSettings {
    /// Diagonal MCLMC defaults: 400 tuning draws, 6 chains, divergence
    /// threshold 1000, and a fixed step size of 0.5.
    fn default() -> Self {
        let adapt_options = {
            let mut opts = EuclideanAdaptOptions::default();
            opts.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
            opts
        };
        default_mclmc_settings(adapt_options, 400, 6, 1000.0)
    }
}
impl Default for LowRankMclmcSettings {
    /// Low-rank MCLMC defaults: 800 tuning draws, 6 chains, divergence
    /// threshold 1000, early mass matrix switches every 20 draws, and a
    /// fixed step size of 0.5.
    fn default() -> Self {
        let adapt_options = {
            let mut opts = EuclideanAdaptOptions::default();
            opts.early_mass_matrix_switch_freq = 20;
            opts.step_size_settings.adapt_options.method = StepSizeAdaptMethod::Fixed(0.5);
            opts
        };
        default_mclmc_settings(adapt_options, 800, 6, 1000.0)
    }
}
impl Default for FlowMclmcSettings {
    /// Flow-adapted MCLMC defaults: 1500 tuning draws, a single chain,
    /// and a divergence threshold of 20.
    fn default() -> Self {
        let adapt_options = FlowSettings::default();
        default_mclmc_settings(adapt_options, 1500, 1, 20.0)
    }
}
/// MCLMC chain using diagonal mass matrix adaptation.
type DiagMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, DiagAdaptStrategy<M>>,
    DiagMassMatrix<M>,
>;

/// MCLMC chain using low-rank mass matrix adaptation.
type LowRankMclmcChain<M> = crate::mclmc::MclmcChain<
    M,
    ChaCha8Rng,
    GlobalStrategy<M, LowRankMassMatrixStrategy>,
    LowRankMassMatrix<M>,
>;
/// MCLMC sampling with diagonal mass matrix adaptation.
impl Settings for DiagMclmcSettings {
    type Chain<M: Math> = DiagMclmcChain<M>;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        rng: &mut R,
    ) -> Self::Chain<M> {
        use crate::dynamics::KineticEnergyKind;
        use crate::mclmc::MclmcChain;
        use crate::stepsize::StepSizeAdaptMethod;
        let num_tune = self.num_tune;
        // Pin the generic strategy's step size adaptation to the
        // configured MCLMC step size.
        let mut adapt_options = self.adapt_options;
        adapt_options.step_size_settings.adapt_options.method =
            StepSizeAdaptMethod::Fixed(self.step_size);
        let strategy = GlobalStrategy::<M, DiagAdaptStrategy<M>>::new(
            &mut math,
            adapt_options,
            num_tune,
            chain,
        );
        let mass_matrix = DiagMassMatrix::new(
            &mut math,
            self.adapt_options.mass_matrix_options.store_mass_matrix,
        );
        // Start Euclidean unless the trajectory is purely microcanonical;
        // the chain may switch kinds later at `switch_draw`.
        let initial_kind = match self.trajectory_kind {
            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
            MclmcTrajectoryKind::Euclidean
            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
        };
        let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
        // Draw index at which the trajectory kind may switch, as a
        // fraction of the tuning phase.
        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
        let stats_options = self.stats_options::<M>();
        MclmcChain::new(
            math,
            hamiltonian,
            strategy,
            rng,
            chain,
            self.subsample_frequency,
            self.dynamic_step_size,
            self.trajectory_kind,
            switch_draw,
            self.max_energy_error,
            stats_options,
        )
    }

    fn hint_num_tune(&self) -> usize {
        usize_hint(self.num_tune, "num_tune")
    }

    fn hint_num_draws(&self) -> usize {
        usize_hint(self.num_draws, "num_draws")
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            adapt: GlobalStrategyStatsOptions {
                step_size: (),
                mass_matrix: (),
            },
            // NOTE(review): the meaning of this sentinel is defined by the
            // Hamiltonian's SamplerStats implementation — confirm there.
            hamiltonian: -1,
            point: {
                let store_gradient = self.store_gradient;
                let store_unconstrained = self.store_unconstrained;
                let store_transformed = self.store_transformed;
                TransformedPointStatsOptions {
                    store_gradient,
                    store_unconstrained,
                    store_transformed,
                }
            },
            divergence: crate::dynamics::DivergenceStatsOptions {
                store_divergences: self.store_divergences,
            },
        }
    }

    fn sampler_name(&self) -> &'static str {
        "mclmc"
    }

    fn adaptation_name(&self) -> &'static str {
        "diagonal"
    }
}
/// Shared defaults for all NUTS settings variants.
///
/// Only the values that differ between adaptation schemes
/// (`adapt_options`, `num_tune`, `num_chains`, `max_energy_error`) are
/// taken as arguments; everything else uses the common defaults.
fn default_nuts_settings<A: Debug + Copy + Default + Serialize>(
    adapt_options: A,
    num_tune: u64,
    num_chains: usize,
    max_energy_error: f64,
) -> NutsSettings<A> {
    NutsSettings {
        num_tune,
        num_draws: 1000,
        num_chains,
        seed: 0,
        maxdepth: 10,
        mindepth: 0,
        extra_doublings: 0,
        max_energy_error,
        adapt_options,
        check_turning: true,
        target_integration_time: None,
        trajectory_kind: KineticEnergyKind::Euclidean,
        store_gradient: false,
        store_unconstrained: false,
        store_transformed: false,
        store_divergences: false,
    }
}
/// MCLMC sampling with low-rank mass matrix adaptation.
impl Settings for LowRankMclmcSettings {
    type Chain<M: Math> = LowRankMclmcChain<M>;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        rng: &mut R,
    ) -> Self::Chain<M> {
        use crate::dynamics::KineticEnergyKind;
        use crate::mclmc::MclmcChain;
        use crate::stepsize::StepSizeAdaptMethod;
        let num_tune = self.num_tune;
        // Pin the generic strategy's step size adaptation to the
        // configured MCLMC step size.
        let mut adapt_options = self.adapt_options;
        adapt_options.step_size_settings.adapt_options.method =
            StepSizeAdaptMethod::Fixed(self.step_size);
        let strategy = GlobalStrategy::<M, LowRankMassMatrixStrategy>::new(
            &mut math,
            adapt_options,
            num_tune,
            chain,
        );
        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
        // Start Euclidean unless the trajectory is purely microcanonical;
        // the chain may switch kinds later at `switch_draw`.
        let initial_kind = match self.trajectory_kind {
            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
            MclmcTrajectoryKind::Euclidean
            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
        };
        let mut hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, initial_kind);
        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
        // Draw index at which the trajectory kind may switch.
        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
        let stats_options = self.stats_options::<M>();
        MclmcChain::new(
            math,
            hamiltonian,
            strategy,
            rng,
            chain,
            self.subsample_frequency,
            self.dynamic_step_size,
            self.trajectory_kind,
            switch_draw,
            self.max_energy_error,
            stats_options,
        )
    }

    fn hint_num_tune(&self) -> usize {
        usize_hint(self.num_tune, "num_tune")
    }

    fn hint_num_draws(&self) -> usize {
        usize_hint(self.num_draws, "num_draws")
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            adapt: GlobalStrategyStatsOptions {
                step_size: (),
                mass_matrix: (),
            },
            // NOTE(review): sentinel interpreted by the Hamiltonian's
            // SamplerStats implementation — confirm there.
            hamiltonian: -1,
            point: {
                let store_gradient = self.store_gradient;
                let store_unconstrained = self.store_unconstrained;
                let store_transformed = self.store_transformed;
                TransformedPointStatsOptions {
                    store_gradient,
                    store_unconstrained,
                    store_transformed,
                }
            },
            divergence: crate::dynamics::DivergenceStatsOptions {
                store_divergences: self.store_divergences,
            },
        }
    }

    fn sampler_name(&self) -> &'static str {
        "mclmc"
    }

    fn adaptation_name(&self) -> &'static str {
        "low_rank"
    }
}
impl Default for DiagNutsSettings {
    /// Diagonal NUTS defaults: 400 tuning draws, 6 chains, and a
    /// divergence threshold of 1000.
    fn default() -> Self {
        let adapt_options = EuclideanAdaptOptions::default();
        default_nuts_settings(adapt_options, 400, 6, 1000.0)
    }
}
impl Default for LowRankNutsSettings {
    /// Low-rank NUTS defaults: 800 tuning draws, 6 chains, divergence
    /// threshold 1000, and mass matrix updates every 20 draws.
    fn default() -> Self {
        let mut settings = default_nuts_settings(EuclideanAdaptOptions::default(), 800, 6, 1000.0);
        settings.adapt_options.mass_matrix_update_freq = 20;
        settings
    }
}
impl Default for FlowNutsSettings {
    /// Flow-adapted NUTS defaults: 1500 tuning draws, a single chain,
    /// and a divergence threshold of 20.
    fn default() -> Self {
        let adapt_options = FlowSettings::default();
        default_nuts_settings(adapt_options, 1500, 1, 20.0)
    }
}
/// NUTS chain using diagonal mass matrix adaptation.
type DiagNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, DiagAdaptStrategy<M>>>;

/// NUTS chain using low-rank mass matrix adaptation.
type LowRankNutsChain<M> = NutsChain<M, ChaCha8Rng, GlobalStrategy<M, LowRankMassMatrixStrategy>>;
fn nuts_options(settings: &NutsSettings<impl Debug + Copy + Default + Serialize>) -> NutsOptions {
NutsOptions {
maxdepth: settings.maxdepth,
mindepth: settings.mindepth,
store_divergences: settings.store_divergences,
check_turning: settings.check_turning,
target_integration_time: settings.target_integration_time,
extra_doublings: settings.extra_doublings,
max_energy_error: settings.max_energy_error,
}
}
/// NUTS sampling with low-rank mass matrix adaptation.
impl Settings for LowRankNutsSettings {
    type Chain<M: Math> = LowRankNutsChain<M>;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        mut rng: &mut R,
    ) -> Self::Chain<M> {
        let num_tune = self.num_tune;
        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
        let mass_matrix = LowRankMassMatrix::new(&mut math, self.adapt_options.mass_matrix_options);
        let hamiltonian = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
        let options = nuts_options(self);
        // The chain gets its own rng, seeded from the caller's rng.
        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
        NutsChain::new(
            math,
            hamiltonian,
            strategy,
            options,
            rng,
            chain,
            self.stats_options(),
        )
    }

    fn hint_num_tune(&self) -> usize {
        usize_hint(self.num_tune, "num_tune")
    }

    fn hint_num_draws(&self) -> usize {
        usize_hint(self.num_draws, "num_draws")
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            adapt: GlobalStrategyStatsOptions {
                mass_matrix: (),
                step_size: (),
            },
            // NOTE(review): sentinel interpreted by the Hamiltonian's
            // SamplerStats implementation — confirm there.
            hamiltonian: -1,
            point: {
                let store_gradient = self.store_gradient;
                let store_unconstrained = self.store_unconstrained;
                let store_transformed = self.store_transformed;
                TransformedPointStatsOptions {
                    store_gradient,
                    store_unconstrained,
                    store_transformed,
                }
            },
            divergence: crate::dynamics::DivergenceStatsOptions {
                store_divergences: self.store_divergences,
            },
        }
    }

    fn sampler_name(&self) -> &'static str {
        "nuts"
    }

    fn adaptation_name(&self) -> &'static str {
        "low_rank"
    }
}
/// NUTS sampling with diagonal mass matrix adaptation.
impl Settings for DiagNutsSettings {
    type Chain<M: Math> = DiagNutsChain<M>;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        mut rng: &mut R,
    ) -> Self::Chain<M> {
        let num_tune = self.num_tune;
        let strategy = GlobalStrategy::new(&mut math, self.adapt_options, num_tune, chain);
        let mass_matrix = DiagMassMatrix::new(
            &mut math,
            self.adapt_options.mass_matrix_options.store_mass_matrix,
        );
        let potential = TransformedHamiltonian::new(&mut math, mass_matrix, self.trajectory_kind);
        let options = nuts_options(self);
        // The chain gets its own rng, seeded from the caller's rng.
        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
        NutsChain::new(
            math,
            potential,
            strategy,
            options,
            rng,
            chain,
            self.stats_options(),
        )
    }

    fn hint_num_tune(&self) -> usize {
        usize_hint(self.num_tune, "num_tune")
    }

    fn hint_num_draws(&self) -> usize {
        usize_hint(self.num_draws, "num_draws")
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            adapt: GlobalStrategyStatsOptions {
                mass_matrix: (),
                step_size: (),
            },
            // NOTE(review): sentinel interpreted by the Hamiltonian's
            // SamplerStats implementation — confirm there.
            hamiltonian: -1,
            point: {
                let store_gradient = self.store_gradient;
                let store_unconstrained = self.store_unconstrained;
                let store_transformed = self.store_transformed;
                TransformedPointStatsOptions {
                    store_gradient,
                    store_unconstrained,
                    store_transformed,
                }
            },
            divergence: crate::dynamics::DivergenceStatsOptions {
                store_divergences: self.store_divergences,
            },
        }
    }

    fn sampler_name(&self) -> &'static str {
        "nuts"
    }

    fn adaptation_name(&self) -> &'static str {
        "diagonal"
    }
}
/// NUTS sampling with an externally adapted (flow-based) transformation.
impl Settings for FlowNutsSettings {
    type Chain<M: Math> = NutsChain<M, ChaCha8Rng, ExternalTransformAdaptation>;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        mut rng: &mut R,
    ) -> Self::Chain<M> {
        let num_tune = self.num_tune;
        let strategy =
            ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
        // The transformation parameters come from the model-side math
        // implementation (e.g. an externally trained flow).
        let params = math
            .new_transformation(rng, math.dim(), chain)
            .expect("Failed to create external transformation");
        let transform = ExternalTransformation::new(params);
        let hamiltonian = TransformedHamiltonian::new(&mut math, transform, self.trajectory_kind);
        let options = nuts_options(self);
        let rng = ChaCha8Rng::try_from_rng(&mut rng).expect("Could not seed rng");
        NutsChain::new(
            math,
            hamiltonian,
            strategy,
            options,
            rng,
            chain,
            self.stats_options(),
        )
    }

    fn hint_num_tune(&self) -> usize {
        usize_hint(self.num_tune, "num_tune")
    }

    fn hint_num_draws(&self) -> usize {
        usize_hint(self.num_draws, "num_draws")
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            // The flow adaptation and Hamiltonian take no stats options.
            adapt: (),
            hamiltonian: (),
            point: {
                let store_gradient = self.store_gradient;
                let store_unconstrained = self.store_unconstrained;
                let store_transformed = self.store_transformed;
                TransformedPointStatsOptions {
                    store_gradient,
                    store_unconstrained,
                    store_transformed,
                }
            },
            divergence: crate::dynamics::DivergenceStatsOptions {
                store_divergences: self.store_divergences,
            },
        }
    }

    fn sampler_name(&self) -> &'static str {
        "nuts"
    }

    fn adaptation_name(&self) -> &'static str {
        "flow"
    }
}
/// MCLMC sampling with an externally adapted (flow-based) transformation.
impl Settings for FlowMclmcSettings {
    type Chain<M: Math> = crate::mclmc::MclmcChain<
        M,
        ChaCha8Rng,
        ExternalTransformAdaptation,
        ExternalTransformation<M>,
    >;

    fn new_chain<M: Math, R: Rng + ?Sized>(
        &self,
        chain: u64,
        mut math: M,
        rng: &mut R,
    ) -> Self::Chain<M> {
        use crate::dynamics::KineticEnergyKind;
        use crate::mclmc::MclmcChain;
        let num_tune = self.num_tune;
        let strategy =
            ExternalTransformAdaptation::new(&mut math, self.adapt_options, num_tune, chain);
        // The transformation parameters come from the model-side math
        // implementation (e.g. an externally trained flow).
        let params = math
            .new_transformation(rng, math.dim(), chain)
            .expect("Failed to create external transformation");
        let transform = ExternalTransformation::new(params);
        // Start Euclidean unless the trajectory is purely microcanonical;
        // the chain may switch kinds later at `switch_draw`.
        let initial_kind = match self.trajectory_kind {
            MclmcTrajectoryKind::Microcanonical => KineticEnergyKind::Microcanonical,
            MclmcTrajectoryKind::Euclidean
            | MclmcTrajectoryKind::EuclideanEarlyThenMicrocanonical => KineticEnergyKind::Euclidean,
        };
        let mut hamiltonian = TransformedHamiltonian::new(&mut math, transform, initial_kind);
        hamiltonian.set_momentum_decoherence_length(Some(self.momentum_decoherence_length));
        // Draw index at which the trajectory kind may switch.
        let switch_draw = (self.trajectory_switch_fraction * self.num_tune as f64) as u64;
        let rng = ChaCha8Rng::try_from_rng(rng).expect("Could not seed rng");
        let stats_options = self.stats_options::<M>();
        MclmcChain::new(
            math,
            hamiltonian,
            strategy,
            rng,
            chain,
            self.subsample_frequency,
            self.dynamic_step_size,
            self.trajectory_kind,
            switch_draw,
            self.max_energy_error,
            stats_options,
        )
    }

    fn hint_num_tune(&self) -> usize {
        usize_hint(self.num_tune, "num_tune")
    }

    fn hint_num_draws(&self) -> usize {
        usize_hint(self.num_draws, "num_draws")
    }

    fn num_chains(&self) -> usize {
        self.num_chains
    }

    fn seed(&self) -> u64 {
        self.seed
    }

    fn stats_options<M: Math>(&self) -> <Self::Chain<M> as SamplerStats<M>>::StatsOptions {
        StatOptions {
            // The flow adaptation and Hamiltonian take no stats options.
            adapt: (),
            hamiltonian: (),
            point: {
                let store_gradient = self.store_gradient;
                let store_unconstrained = self.store_unconstrained;
                let store_transformed = self.store_transformed;
                TransformedPointStatsOptions {
                    store_gradient,
                    store_unconstrained,
                    store_transformed,
                }
            },
            divergence: crate::dynamics::DivergenceStatsOptions {
                store_divergences: self.store_divergences,
            },
        }
    }

    fn sampler_name(&self) -> &'static str {
        "mclmc"
    }

    fn adaptation_name(&self) -> &'static str {
        "flow"
    }
}
/// Sample one chain lazily, one draw per iterator step.
///
/// The chain is created from `settings`, initialized at `start`, and
/// produces `draws` items; each item is the draw as a boxed slice
/// together with a [`Progress`] snapshot.
///
/// # Errors
/// Fails if the initial position `start` is rejected by the sampler.
pub fn sample_sequentially<'math, M: Math + 'math, R: Rng + ?Sized>(
    math: M,
    settings: DiagNutsSettings,
    start: &[f64],
    draws: u64,
    chain: u64,
    rng: &mut R,
) -> Result<impl Iterator<Item = Result<(Box<[f64]>, Progress)>> + 'math> {
    let mut sampler = settings.new_chain(chain, math, rng);
    sampler.set_position(start)?;
    Ok((0..draws).map(move |_| sampler.draw()))
}
/// Running progress counters for a single chain.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct ChainProgress {
    /// Draws completed so far (tuning + sampling).
    pub finished_draws: usize,
    /// Total draws this chain is expected to produce.
    pub total_draws: usize,
    /// Number of divergent draws after tuning.
    pub divergences: usize,
    /// Whether the chain is still in the tuning phase.
    pub tuning: bool,
    /// Whether the chain has been constructed and begun initialization.
    pub started: bool,
    /// Integrator steps used by the most recent draw.
    pub latest_num_steps: usize,
    /// Total integrator steps across all finished draws.
    pub total_num_steps: usize,
    /// Step size of the most recent draw.
    pub step_size: f64,
    /// Accumulated wall-clock time spent drawing.
    pub runtime: Duration,
    /// Draw indices (after tuning) at which divergences occurred.
    pub divergent_draws: Vec<usize>,
}
impl ChainProgress {
    /// Create a fresh tracker for a chain expected to produce `total`
    /// draws (tuning + sampling).
    fn new(total: usize) -> Self {
        Self {
            finished_draws: 0,
            total_draws: total,
            divergences: 0,
            tuning: true,
            started: false,
            latest_num_steps: 0,
            step_size: 0f64,
            total_num_steps: 0,
            runtime: Duration::ZERO,
            divergent_draws: Vec::new(),
        }
    }

    /// Fold the statistics of one finished draw into the running totals.
    fn update(&mut self, stats: &Progress, draw_duration: Duration) {
        // Only count divergences that happen after tuning. Fixed the
        // bitwise `&` to logical `&&` — equivalent for bools, but `&&`
        // is the idiomatic (short-circuiting) form.
        if stats.diverging && !stats.tuning {
            self.divergences += 1;
            self.divergent_draws.push(self.finished_draws);
        }
        self.finished_draws += 1;
        self.tuning = stats.tuning;
        self.latest_num_steps = stats.num_steps as usize;
        self.total_num_steps += stats.num_steps as usize;
        self.step_size = stats.step_size;
        self.runtime += draw_duration;
    }
}
/// Control messages sent from the controller to a chain worker.
#[cfg(feature = "parallel")]
enum ChainCommand {
    /// Continue sampling.
    Resume,
    /// Stop sampling until the next command arrives.
    Pause,
}
/// Controller-side handle for one chain running on the worker pool.
///
/// Trace and progress are shared with the worker through mutexes;
/// pause/resume commands travel over the `stop_marker` channel.
#[cfg(feature = "parallel")]
struct ChainProcess<T>
where
    T: TraceStorage,
{
    /// Command channel to the worker (pause/resume).
    stop_marker: Sender<ChainCommand>,
    /// The chain's trace storage; `None` once it has been taken for finalization.
    trace: Arc<Mutex<Option<T::ChainStorage>>>,
    /// Progress counters the worker updates after each draw.
    progress: Arc<Mutex<ChainProgress>>,
}
#[cfg(feature = "parallel")]
impl<T: TraceStorage> ChainProcess<T> {
    /// Take the per-chain traces out of all `chains` and combine them
    /// into the finalized trace. Chains whose trace was already taken
    /// are skipped.
    fn finalize_many(trace: T, chains: Vec<Self>) -> Result<(Option<anyhow::Error>, T::Finalized)> {
        let finalized_chain_traces = chains
            .into_iter()
            .filter_map(|chain| chain.trace.lock().expect("Poisoned lock").take())
            .map(|chain| chain.finalize())
            .collect_vec();
        trace.finalize(finalized_chain_traces)
    }

    /// Snapshot of this chain's progress counters.
    fn progress(&self) -> ChainProgress {
        self.progress.lock().expect("Poisoned lock").clone()
    }

    /// Ask the worker to continue sampling.
    fn resume(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Resume)?;
        Ok(())
    }

    /// Ask the worker to pause before its next draw.
    fn pause(&self) -> Result<()> {
        self.stop_marker.send(ChainCommand::Pause)?;
        Ok(())
    }

    /// Spawn the worker task for one chain on the rayon FIFO scope and
    /// return the controller-side handle.
    ///
    /// The worker builds the chain, retries initialization up to 500
    /// times, then loops: check for a pause/resume command, draw, and
    /// record the draw into the chain's trace. The final per-chain
    /// `Result` is reported through `results`.
    fn start<'model, M: Model, S: Settings>(
        model: &'model M,
        chain_trace: T::ChainStorage,
        chain_id: u64,
        seed: u64,
        settings: &'model S,
        scope: &ScopeFifo<'model>,
        results: Sender<Result<()>>,
    ) -> Result<Self> {
        let (stop_marker_tx, stop_marker_rx) = channel();
        // All chains share the seed but use distinct rng streams
        // (stream 0 is used for setup in `Sampler::new`).
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        rng.set_stream(chain_id + 1);
        let chain_trace = Arc::new(Mutex::new(Some(chain_trace)));
        let progress = Arc::new(Mutex::new(ChainProgress::new(
            settings.hint_num_draws() + settings.hint_num_tune(),
        )));
        let trace_inner = chain_trace.clone();
        let progress_inner = progress.clone();
        scope.spawn_fifo(move |_| {
            let chain_trace = trace_inner;
            let progress = progress_inner;
            let mut sample = move || {
                let logp = model
                    .math(&mut rng)
                    .context("Failed to create model density")?;
                let dim = logp.dim();
                let mut sampler = settings.new_chain(chain_id, logp, &mut rng);
                progress.lock().expect("Poisoned mutex").started = true;
                // Retry initialization with fresh positions up to 500
                // times; keep the most recent error for reporting.
                let mut initval = vec![0f64; dim];
                let mut error = None;
                for _ in 0..500 {
                    model
                        .init_position(&mut rng, &mut initval)
                        .context("Failed to generate a new initial position")?;
                    if let Err(err) = sampler.set_position(&initval) {
                        error = Some(err);
                        continue;
                    }
                    error = None;
                    break;
                }
                if let Some(error) = error {
                    return Err(error.context("All initialization points failed"));
                }
                let draws = settings.hint_num_tune() + settings.hint_num_draws();
                let mut msg = stop_marker_rx.try_recv();
                let mut draw = 0;
                loop {
                    // Interpret the latest control message: a closed
                    // channel means the controller is gone; Pause blocks
                    // until the next command arrives.
                    match msg {
                        Err(TryRecvError::Disconnected) => {
                            break;
                        }
                        Err(TryRecvError::Empty) => {}
                        Ok(ChainCommand::Pause) => {
                            // Blocking wait for Resume (or disconnect);
                            // `RecvError` converts into
                            // `TryRecvError::Disconnected` via `into`.
                            msg = stop_marker_rx.recv().map_err(|e| e.into());
                            continue;
                        }
                        Ok(ChainCommand::Resume) => {}
                    }
                    let now = Instant::now();
                    // NOTE(review): a draw error panics the worker here
                    // instead of being reported through `results` —
                    // confirm this is intended.
                    let (_point, mut draw_data, mut stats, info) = sampler.expanded_draw().unwrap();
                    let mut guard = chain_trace
                        .lock()
                        .expect("Could not unlock trace lock. Poisoned mutex");
                    // The trace may already have been taken by
                    // `finalize_many`; stop sampling in that case.
                    let Some(trace_val) = guard.as_mut() else {
                        break;
                    };
                    progress
                        .lock()
                        .expect("Poisoned mutex")
                        .update(&info, now.elapsed());
                    let math = sampler.math();
                    let dims = StatsDims::from(math.deref());
                    trace_val.record_sample(
                        settings,
                        stats.get_all(&dims),
                        draw_data.get_all(math.deref()),
                        &info,
                    )?;
                    draw += 1;
                    if draw == draws {
                        break;
                    }
                    msg = stop_marker_rx.try_recv();
                }
                Ok(())
            };
            // Report success or failure of the whole chain to the
            // controller; ignore send errors if the controller is gone.
            let result = sample();
            let _ = results.send(result);
            drop(results);
        });
        Ok(Self {
            trace: chain_trace,
            stop_marker: stop_marker_tx,
            progress,
        })
    }

    /// Flush the chain's trace storage (no-op if already finalized).
    fn flush(&self) -> Result<()> {
        self.trace
            .lock()
            .map_err(|_| anyhow::anyhow!("Could not lock trace mutex"))
            .context("Could not flush trace")?
            .as_mut()
            .map(|v| v.flush())
            .transpose()?;
        Ok(())
    }
}
/// Commands the user-facing [`Sampler`] sends to its controller thread.
#[cfg(feature = "parallel")]
#[derive(Debug)]
enum SamplerCommand {
    /// Pause all chains.
    Pause,
    /// Resume all chains.
    Continue,
    /// Request a progress snapshot for all chains.
    Progress,
    /// Flush all chain traces.
    Flush,
    /// Request an inspection snapshot of the trace while sampling runs.
    Inspect,
}
/// Responses the controller thread sends back for each [`SamplerCommand`].
#[cfg(feature = "parallel")]
enum SamplerResponse<T: Send + 'static> {
    /// Command acknowledged.
    Ok(),
    /// Progress snapshot for all chains.
    Progress(Box<[ChainProgress]>),
    /// Result of an `Inspect` command.
    Inspect(T),
}
/// Outcome of waiting on a [`Sampler`] with a timeout.
#[cfg(feature = "parallel")]
pub enum SamplerWaitResult<F: Send + 'static> {
    /// Sampling finished; contains the finalized trace.
    Trace(F),
    /// The timeout elapsed; the sampler handle is returned to the caller.
    Timeout(Sampler<F>),
    /// Sampling failed, possibly with a partially finalized trace.
    Err(anyhow::Error, Option<F>),
}
/// Handle to a sampling run executing on a background thread pool.
///
/// Communicates with a controller thread over rendezvous channels; the
/// controller owns the thread pool and the trace storage.
#[cfg(feature = "parallel")]
pub struct Sampler<F: Send + 'static> {
    /// The controller thread; joining yields the collected error (if any)
    /// and the finalized trace.
    main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
    /// Command channel to the controller.
    commands: SyncSender<SamplerCommand>,
    /// Response channel from the controller.
    responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
    /// Per-chain completion results reported by the workers.
    results: Receiver<Result<()>>,
}
/// Callback invoked periodically with the progress of all chains.
#[cfg(feature = "parallel")]
pub struct ProgressCallback {
    /// Called with the elapsed (non-paused) time and a per-chain
    /// progress snapshot.
    pub callback: Box<dyn FnMut(Duration, Box<[ChainProgress]>) + Send>,
    /// Minimum interval between callback invocations.
    pub rate: Duration,
}
#[cfg(feature = "parallel")]
impl<F: Send + 'static> Sampler<F> {
/// Spawn the controller thread and start sampling all chains.
///
/// A rayon pool with `num_cores + 1` threads runs the chain workers;
/// the controller loop itself runs on a dedicated OS thread. If
/// `callback` is given, it is invoked roughly every `rate` with a
/// progress snapshot. The returned [`Sampler`] is the handle used to
/// pause, resume, flush, inspect, wait on or abort the run.
///
/// # Errors
/// Fails if the thread pool, the model density, or the trace objects
/// cannot be created, or if any chain fails to start.
pub fn new<M, S, C, T>(
    model: M,
    settings: S,
    trace_config: C,
    num_cores: usize,
    callback: Option<ProgressCallback>,
) -> Result<Self>
where
    S: Settings,
    C: StorageConfig<Storage = T>,
    M: Model,
    T: TraceStorage<Finalized = F>,
{
    // Rendezvous channels: the controller handles one command at a time
    // and the caller blocks until it is acknowledged.
    let (commands_tx, commands_rx) = sync_channel(0);
    let (responses_tx, responses_rx) = sync_channel(0);
    let (results_tx, results_rx) = channel();
    let main_thread = spawn(move || {
        // One extra thread beyond the requested cores — NOTE(review):
        // presumably so control work never starves the chain workers;
        // confirm the intent.
        let pool = ThreadPoolBuilder::new()
            .num_threads(num_cores + 1)
            .thread_name(|i| format!("nutpie-worker-{i}"))
            .build()
            .context("Could not start thread pool")?;
        let settings_ref = &settings;
        let model_ref = &model;
        let mut callback = callback;
        pool.scope_fifo(move |scope| {
            let results = results_tx;
            let mut chains = Vec::with_capacity(settings.num_chains());
            // Stream 0 is reserved for setup; chain workers use
            // streams 1..=num_chains (see ChainProcess::start).
            let mut rng = ChaCha8Rng::seed_from_u64(settings.seed());
            rng.set_stream(0);
            let math = model_ref
                .math(&mut rng)
                .context("Could not create model density")?;
            let trace = trace_config
                .new_trace(settings_ref, &math)
                .context("Could not create trace object")?;
            drop(math);
            // Start one worker per chain; collect any startup failures.
            for chain_id in 0..settings.num_chains() {
                let chain_trace_val = trace
                    .initialize_trace_for_chain(chain_id as u64)
                    .context("Failed to create trace object")?;
                let chain = ChainProcess::start(
                    model_ref,
                    chain_trace_val,
                    chain_id as u64,
                    settings.seed(),
                    settings_ref,
                    scope,
                    results.clone(),
                );
                chains.push(chain);
            }
            drop(results);
            let (chains, errors): (Vec<_>, Vec<_>) = chains.into_iter().partition_result();
            if let Some(error) = errors.into_iter().next() {
                // Best-effort cleanup of the chains that did start.
                let _ = ChainProcess::finalize_many(trace, chains);
                return Err(error).context("Could not start chains");
            }
            // Control loop: serve commands from the Sampler handle and
            // emit periodic progress callbacks until the handle drops
            // its command channel.
            let mut main_loop = || {
                let start_time = Instant::now();
                let mut pause_start = Instant::now();
                let mut pause_time = Duration::ZERO;
                let mut progress_rate = Duration::MAX;
                if let Some(ProgressCallback { callback, rate }) = &mut callback {
                    let progress = chains.iter().map(|chain| chain.progress()).collect_vec();
                    callback(start_time.elapsed(), progress.into());
                    progress_rate = *rate;
                }
                let mut last_progress = Instant::now();
                let mut is_paused = false;
                loop {
                    // Wait for the next command at most until the next
                    // progress report is due; on expiry, report and
                    // reset the timer.
                    let timeout = progress_rate.checked_sub(last_progress.elapsed());
                    let timeout = timeout.unwrap_or_else(|| {
                        if let Some(ProgressCallback { callback, .. }) = &mut callback {
                            let progress =
                                chains.iter().map(|chain| chain.progress()).collect_vec();
                            // Report elapsed time excluding paused spans.
                            let mut elapsed = start_time.elapsed().saturating_sub(pause_time);
                            if is_paused {
                                elapsed = elapsed.saturating_sub(pause_start.elapsed());
                            }
                            callback(elapsed, progress.into());
                        }
                        last_progress = Instant::now();
                        progress_rate
                    });
                    match commands_rx.recv_timeout(timeout) {
                        Ok(SamplerCommand::Pause) => {
                            for chain in chains.iter() {
                                let _ = chain.pause();
                            }
                            if !is_paused {
                                pause_start = Instant::now();
                            }
                            is_paused = true;
                            responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                anyhow::anyhow!(
                                    "Could not send pause response to controller thread: {e}"
                                )
                            })?;
                        }
                        Ok(SamplerCommand::Continue) => {
                            for chain in chains.iter() {
                                let _ = chain.resume();
                            }
                            // NOTE(review): if Continue arrives without a
                            // preceding Pause, this adds the time since
                            // sampling started (pause_start's initial
                            // value) to pause_time — confirm.
                            pause_time += pause_start.elapsed();
                            is_paused = false;
                            responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                anyhow::anyhow!(
                                    "Could not send continue response to controller thread: {e}"
                                )
                            })?;
                        }
                        Ok(SamplerCommand::Progress) => {
                            let progress =
                                chains.iter().map(|chain| chain.progress()).collect_vec();
                            responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| {
                                anyhow::anyhow!(
                                    "Could not send progress response to controller thread: {e}"
                                )
                            })?;
                        }
                        Ok(SamplerCommand::Inspect) => {
                            // Snapshot every chain trace that has not been
                            // finalized yet and combine them.
                            let traces = chains
                                .iter()
                                .filter_map(|chain| {
                                    chain
                                        .trace
                                        .lock()
                                        .expect("Poisoned lock")
                                        .as_ref()
                                        .map(|v| v.inspect())
                                })
                                .collect_vec();
                            let finalized_trace = trace.inspect(traces)?;
                            responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| {
                                anyhow::anyhow!(
                                    "Could not send inspect response to controller thread: {e}"
                                )
                            })?;
                        }
                        Ok(SamplerCommand::Flush) => {
                            for chain in chains.iter() {
                                chain.flush()?;
                            }
                            responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
                                anyhow::anyhow!(
                                    "Could not send flush response to controller thread: {e}"
                                )
                            })?;
                        }
                        Err(RecvTimeoutError::Timeout) => {}
                        Err(RecvTimeoutError::Disconnected) => {
                            // The Sampler handle dropped its command
                            // channel (abort/drop): emit a final progress
                            // report and leave the control loop.
                            if let Some(ProgressCallback { callback, .. }) = &mut callback {
                                let progress =
                                    chains.iter().map(|chain| chain.progress()).collect_vec();
                                let mut elapsed =
                                    start_time.elapsed().saturating_sub(pause_time);
                                if is_paused {
                                    elapsed = elapsed.saturating_sub(pause_start.elapsed());
                                }
                                callback(elapsed, progress.into());
                            }
                            return Ok(());
                        }
                    };
                }
            };
            // Finalize the traces even if the control loop failed, but
            // report the loop error first.
            let result: Result<()> = main_loop();
            let output = ChainProcess::finalize_many(trace, chains)?;
            result?;
            Ok(output)
        })
    });
    Ok(Self {
        main_thread,
        commands: commands_tx,
        responses: responses_rx,
        results: results_rx,
    })
}
pub fn pause(&mut self) -> Result<()> {
self.commands
.send(SamplerCommand::Pause)
.context("Could not send pause command to controller thread")?;
let response = self
.responses
.recv()
.context("Could not recieve pause response from controller thread")?;
let SamplerResponse::Ok() = response else {
bail!("Got invalid response from sample controller thread");
};
Ok(())
}
pub fn resume(&mut self) -> Result<()> {
self.commands.send(SamplerCommand::Continue)?;
let response = self.responses.recv()?;
let SamplerResponse::Ok() = response else {
bail!("Got invalid response from sample controller thread");
};
Ok(())
}
pub fn flush(&mut self) -> Result<()> {
self.commands.send(SamplerCommand::Flush)?;
let response = self
.responses
.recv()
.context("Could not recieve flush response from controller thread")?;
let SamplerResponse::Ok() = response else {
bail!("Got invalid response from sample controller thread");
};
Ok(())
}
pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
self.commands.send(SamplerCommand::Inspect)?;
let response = self
.responses
.recv()
.context("Could not recieve inspect response from controller thread")?;
let SamplerResponse::Inspect(trace) = response else {
bail!("Got invalid response from sample controller thread");
};
Ok(trace)
}
/// Stop the sampler and wait for the controller thread to shut down.
///
/// Dropping the command sender disconnects the controller's command
/// channel, which it treats as the signal to finalize and exit. Any
/// panic on the controller thread is re-raised here.
pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
    drop(self.commands);
    match self.main_thread.join() {
        // Propagate a controller-thread panic to the caller.
        Err(payload) => std::panic::resume_unwind(payload),
        // The thread returned normally; pass its Result through as-is.
        Ok(outcome) => outcome,
    }
}
/// Wait up to `timeout` for sampling to finish.
///
/// Returns the finalized trace if all chains complete in time, the
/// sampler itself on timeout (so the caller can keep waiting), or an
/// error if sampling failed.
pub fn wait_timeout(self, timeout: Duration) -> SamplerWaitResult<F> {
    let start = Instant::now();
    let mut remaining = timeout;
    loop {
        // BUG FIX: the original passed the full `timeout` to
        // `recv_timeout` on every iteration, so each intermediate
        // per-chain result reset the wait and the total blocking time
        // could greatly exceed the requested timeout. Wait only for the
        // time still left in the overall budget.
        match self.results.recv_timeout(remaining) {
            // One chain finished successfully; keep waiting for the rest
            // with whatever time is left.
            Ok(Ok(_)) => match timeout.checked_sub(start.elapsed()) {
                Some(left) => remaining = left,
                None => break,
            },
            Ok(Err(e)) => return SamplerWaitResult::Err(e, None),
            // All result senders dropped: sampling is done. Join the
            // controller thread and hand back the finalized trace.
            Err(RecvTimeoutError::Disconnected) => match self.abort() {
                Ok((Some(err), trace)) => return SamplerWaitResult::Err(err, Some(trace)),
                Ok((None, trace)) => return SamplerWaitResult::Trace(trace),
                Err(err) => return SamplerWaitResult::Err(err, None),
            },
            Err(RecvTimeoutError::Timeout) => break,
        }
    }
    SamplerWaitResult::Timeout(self)
}
pub fn progress(&mut self) -> Result<Box<[ChainProgress]>> {
self.commands.send(SamplerCommand::Progress)?;
let response = self.responses.recv()?;
let SamplerResponse::Progress(progress) = response else {
bail!("Got invalid response from sample controller thread");
};
Ok(progress)
}
}
#[cfg(test)]
pub mod test_logps {
    //! Minimal model wrappers used by the sampler tests.

    #[cfg(feature = "zarr")]
    use crate::{Model, math::CpuLogpFunc, math::CpuMath};
    #[cfg(feature = "zarr")]
    use anyhow::Result;
    #[cfg(feature = "zarr")]
    use rand::Rng;

    /// A test model backed by a plain CPU log-probability function.
    #[cfg(feature = "zarr")]
    pub struct CpuModel<F> {
        logp: F,
    }

    #[cfg(feature = "zarr")]
    impl<F> CpuModel<F> {
        /// Wrap a log-probability function in a model.
        pub fn new(logp: F) -> Self {
            CpuModel { logp }
        }
    }

    #[cfg(feature = "zarr")]
    impl<F> Model for CpuModel<F>
    where
        F: Send + Sync + 'static,
        for<'a> &'a F: CpuLogpFunc,
    {
        type Math<'model> = CpuMath<&'model F>;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(CpuMath::new(&self.logp))
        }

        /// Deterministic initialization: every coordinate starts at zero,
        /// ignoring the RNG entirely.
        fn init_position<R: rand::prelude::Rng + ?Sized>(
            &self,
            _rng: &mut R,
            position: &mut [f64],
        ) -> Result<()> {
            position.fill(0.);
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::math::test_logps::NormalLogp;
    use crate::{
        Chain, math::CpuMath, sample_sequentially, sampler::DiagMclmcSettings,
        sampler::DiagNutsSettings, sampler::LowRankMclmcSettings, sampler::LowRankNutsSettings,
        sampler::Settings,
    };
    #[cfg(feature = "zarr")]
    use super::test_logps::CpuModel;
    use anyhow::Result;
    use itertools::Itertools;
    use pretty_assertions::assert_eq;
    use rand::{SeedableRng, rngs::StdRng};
    #[cfg(feature = "zarr")]
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };
    #[cfg(feature = "zarr")]
    use crate::{Sampler, ZarrConfig};
    #[cfg(feature = "zarr")]
    use zarrs::storage::store::MemoryStore;

    /// Smoke-test one settings type: stat-name/stat-type metadata must be
    /// non-empty and consistent, and a freshly created chain must be able
    /// to accept a start position and produce one draw without error.
    fn assert_settings_smoke<S: Settings>(settings: S) -> Result<()> {
        let logp = NormalLogp { dim: 4, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let mut rng = StdRng::seed_from_u64(42);
        let stat_names = settings.stat_names(&math);
        let stat_types = settings.stat_types(&math);
        assert!(!stat_names.is_empty());
        // Names and types are parallel arrays; they must line up.
        assert_eq!(stat_names.len(), stat_types.len());
        let mut chain = settings.new_chain(0, math, &mut rng);
        chain.set_position(&vec![0.2; 4])?;
        let (_draw, _info) = chain.draw()?;
        Ok(())
    }

    /// Run the smoke test over every supported settings variant (diagonal
    /// and low-rank mass matrices, for both NUTS and MCLMC samplers).
    #[test]
    fn all_settings_smoke() -> Result<()> {
        assert_settings_smoke(DiagNutsSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(LowRankNutsSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(DiagMclmcSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        assert_settings_smoke(LowRankMclmcSettings {
            num_tune: 10,
            num_draws: 10,
            ..Default::default()
        })?;
        Ok(())
    }

    /// Single-chain sampling: the first draw is a tuning draw with index 0,
    /// and `sample_sequentially` yields exactly tune+draws items with the
    /// expected dimension and chain/draw bookkeeping.
    /// NOTE(review): the second half duplicates `sample_seq` below.
    #[test]
    fn sample_chain() -> Result<()> {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let start = vec![0.2; 10];
        let mut rng = StdRng::seed_from_u64(42);
        let mut chain = settings.new_chain(0, math, &mut rng);
        let (_draw, info) = chain.draw()?;
        // The very first draw happens during warmup.
        assert!(info.tuning);
        assert_eq!(info.draw, 0);
        let math = CpuMath::new(&logp);
        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
        let mut draws = chain.collect_vec();
        assert_eq!(draws.len(), 200);
        // Draw 100 is the first post-warmup draw (num_tune == 100).
        let draw0 = draws.remove(100).unwrap();
        let (vals, stats) = draw0;
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
        Ok(())
    }

    /// End-to-end parallel sampler control: pause/resume, abort while
    /// paused, polling progress, and waiting with a timeout.
    #[cfg(feature = "zarr")]
    #[test]
    fn sample_parallel() -> Result<()> {
        let logp = NormalLogp { dim: 100, mu: 0.1 };
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            seed: 10,
            ..Default::default()
        };
        // Pausing twice then resuming must be well-defined, and abort must
        // surface any error the chains produced.
        let model = CpuModel::new(logp.clone());
        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        sampler.pause()?;
        sampler.resume()?;
        let (ok, _) = sampler.abort()?;
        if let Some(err) = ok {
            Err(err)?;
        }
        // Aborting while still paused must also finalize cleanly.
        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let mut sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        sampler.pause()?;
        if let (Some(err), _) = sampler.abort()? {
            Err(err)?;
        }
        // A 100ns wait should time out (sampling can't finish that fast);
        // afterwards progress polling and a generous wait must succeed.
        let store = MemoryStore::new();
        let zarr_config = ZarrConfig::new(Arc::new(store));
        let model = CpuModel::new(logp.clone());
        let start = Instant::now();
        let sampler = Sampler::new(model, settings, zarr_config, 4, None)?;
        let mut sampler = match sampler.wait_timeout(Duration::from_nanos(100)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
                panic!("finished");
            }
            super::SamplerWaitResult::Timeout(sampler) => sampler,
            super::SamplerWaitResult::Err(_, _) => {
                panic!("error")
            }
        };
        for _ in 0..30 {
            sampler.progress()?;
        }
        match sampler.wait_timeout(Duration::from_secs(1)) {
            super::SamplerWaitResult::Trace(_) => {
                dbg!(start.elapsed());
            }
            super::SamplerWaitResult::Timeout(_) => {
                panic!("timeout")
            }
            super::SamplerWaitResult::Err(err, _) => Err(err)?,
        };
        Ok(())
    }

    /// Sequential sampling bookkeeping: tune+draws items total, correct
    /// dimension, and chain/draw indices on the first post-warmup draw.
    #[test]
    fn sample_seq() {
        let logp = NormalLogp { dim: 10, mu: 0.1 };
        let math = CpuMath::new(&logp);
        let settings = DiagNutsSettings {
            num_tune: 100,
            num_draws: 100,
            ..Default::default()
        };
        let start = vec![0.2; 10];
        let mut rng = StdRng::seed_from_u64(42);
        let chain = sample_sequentially(math, settings, &start, 200, 1, &mut rng).unwrap();
        let mut draws = chain.collect_vec();
        assert_eq!(draws.len(), 200);
        // Draw 100 is the first post-warmup draw (num_tune == 100).
        let draw0 = draws.remove(100).unwrap();
        let (vals, stats) = draw0;
        assert_eq!(vals.len(), 10);
        assert_eq!(stats.chain, 1);
        assert_eq!(stats.draw, 100);
    }
}