use std::{
sync::Arc,
time::{Duration, Instant},
};
use anyhow::Context;
use nuts_rs::{
Chain, CpuLogpFunc, CpuMath, DiagAdaptExpSettings, DiagNutsSettings, EuclideanAdaptOptions,
LogpError, LowRankNutsSettings, Model, Sampler, SamplerWaitResult, Settings, ZarrConfig,
};
use nuts_storable::HasDims;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, StandardNormal};
use thiserror::Error;
use zarrs::{
array::{Array, ArraySubset},
storage::{ReadableListableStorageTraits, store::MemoryStore},
};
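/// Gaussian target with equicorrelated components: the covariance is
/// Sigma = I + s * 1 1^T, whose precision by the Sherman-Morrison formula is
/// Lambda = I - c * 1 1^T with c = s / (1 + s * dim).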
struct CorrelatedNormalLogp {
dim: usize,
mu: Vec<f64>,
prec_rank1_coeff: f64,
}
impl CorrelatedNormalLogp {
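/// `rank1_scale` is `s` in Sigma = I + s * 1 1^T; only the Sherman-Morrison
/// coefficient c = s / (1 + s * dim) is needed to apply the precision.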
fn new(dim: usize, rank1_scale: f64) -> Self {
let prec_rank1_coeff = rank1_scale / (1.0 + rank1_scale * dim as f64);
Self {
dim,
mu: vec![0.0; dim],
prec_rank1_coeff,
}
}
}
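// The logp functions below cannot fail, so their error enums have no variants.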
#[derive(Error, Debug)]
enum CorrelatedNormalLogpError {}
impl LogpError for CorrelatedNormalLogpError {
fn is_recoverable(&self) -> bool {
true
}
}
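// Dimension sizes advertised to the trace storage layer.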
impl HasDims for CorrelatedNormalLogp {
fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
std::collections::HashMap::from([
("unconstrained_parameter".to_string(), self.dim as u64),
("dim".to_string(), self.dim as u64),
])
}
}
impl CpuLogpFunc for CorrelatedNormalLogp {
type LogpError = CorrelatedNormalLogpError;
type FlowParameters = ();
type ExpandedVector = Vec<f64>;
fn dim(&self) -> usize {
self.dim
}
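// logp(x) = -0.5 * (x - mu)^T Lambda (x - mu) with Lambda = I - c * 1 1^T;
// Lambda * (x - mu) is applied in O(n) as diff_i - c * sum_j diff_j.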
fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
let n = position.len();
assert_eq!(grad.len(), n);
let sum_diff: f64 = position
.iter()
.zip(self.mu.iter())
.map(|(p, m)| p - m)
.sum();
let rank1_term = self.prec_rank1_coeff * sum_diff;
let mut logp = 0f64;
for i in 0..n {
let diff = position[i] - self.mu[i];
let prec_times_diff = diff - rank1_term;
grad[i] = -prec_times_diff;
logp -= 0.5 * diff * prec_times_diff;
}
Ok(logp)
}
fn expand_vector<R>(
&mut self,
_rng: &mut R,
array: &[f64],
) -> Result<Self::ExpandedVector, nuts_rs::CpuMathError>
where
R: rand::Rng + ?Sized,
{
Ok(array.to_vec())
}
}
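/// Standard normal target with independent components centered at `mu`.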
struct NormalLogp<'a> {
dim: usize,
mu: &'a [f64],
}
#[derive(Error, Debug)]
enum NormalLogpError {}
impl LogpError for NormalLogpError {
fn is_recoverable(&self) -> bool {
true
}
}
impl HasDims for NormalLogp<'_> {
fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
std::collections::HashMap::from([
("unconstrained_parameter".to_string(), self.dim as u64),
("dim".to_string(), self.dim as u64),
])
}
}
impl<'a> CpuLogpFunc for NormalLogp<'a> {
type LogpError = NormalLogpError;
type FlowParameters = ();
type ExpandedVector = Vec<f64>;
fn dim(&self) -> usize {
self.dim
}
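// Identity covariance: logp = -0.5 * sum_i (x_i - mu_i)^2, grad = -(x - mu).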
fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
let n = position.len();
assert_eq!(grad.len(), n);
let mut logp = 0f64;
position
.iter()
.zip(self.mu.iter())
.zip(grad.iter_mut())
.for_each(|((&p, &mu), grad)| {
let diff = p - mu;
logp -= diff * diff / 2.;
*grad = -diff;
});
Ok(logp)
}
fn expand_vector<R>(
&mut self,
_rng: &mut R,
array: &[f64],
) -> Result<Self::ExpandedVector, nuts_rs::CpuMathError>
where
R: rand::Rng + ?Sized,
{
Ok(array.to_vec())
}
}
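/// Wraps `NormalLogp` as a nuts-rs `Model` so a `Sampler` can drive it.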
struct NormalModel {
mu: Box<[f64]>,
}
impl NormalModel {
fn new(mu: Box<[f64]>) -> Self {
NormalModel { mu }
}
}
impl Model for NormalModel {
type Math<'model>
= CpuMath<NormalLogp<'model>>
where
Self: 'model;
fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> anyhow::Result<Self::Math<'_>> {
Ok(CpuMath::new(NormalLogp {
dim: self.mu.len(),
mu: &self.mu,
}))
}
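// Start each chain from an independent standard-normal draw.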
fn init_position<R: Rng + ?Sized>(
&self,
rng: &mut R,
position: &mut [f64],
) -> anyhow::Result<()> {
let normal = StandardNormal;
position.iter_mut().for_each(|x| *x = normal.sample(rng));
Ok(())
}
}
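/// Sample a 100-dimensional normal with 6 chains of diagonal-mass-matrix NUTS,
/// writing the trace to an in-memory Zarr store.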
fn sample() -> anyhow::Result<Arc<MemoryStore>> {
let mu = vec![0.5; 100];
let model = NormalModel::new(mu.into());
let settings = DiagNutsSettings {
seed: 42,
num_chains: 6,
..Default::default()
};
let store = Arc::new(MemoryStore::new());
let trace_config = ZarrConfig::new(store.clone());
let mut sampler = Sampler::new(model, settings, trace_config, 6, None)?;
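// Poll once per second; on a timeout the sampler is handed back so we can keep waiting.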
let _ = loop {
match sampler.wait_timeout(Duration::from_secs(1)) {
SamplerWaitResult::Trace(trace) => break trace,
SamplerWaitResult::Timeout(new_sampler) => sampler = new_sampler,
SamplerWaitResult::Err(err, _trace) => return Err(err),
};
};
Ok(store)
}
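/// Like `sample`, but with extra per-draw statistics stored (gradients,
/// divergence info, unconstrained draws, and the adapted mass matrix), then
/// re-opens two of the stored arrays to check their dimension names.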
fn sample_debug_stats() -> anyhow::Result<Arc<dyn ReadableListableStorageTraits>> {
let mu = vec![0.5; 100];
let model = NormalModel::new(mu.into());
let settings = DiagNutsSettings {
seed: 42,
num_chains: 6,
store_gradient: true,
store_divergences: true,
store_unconstrained: true,
adapt_options: EuclideanAdaptOptions {
mass_matrix_options: DiagAdaptExpSettings {
store_mass_matrix: true,
use_grad_based_estimate: true,
..Default::default()
},
..Default::default()
},
..Default::default()
};
let store = Arc::new(MemoryStore::new());
let trace_config = ZarrConfig::new(store.clone());
let mut sampler = Sampler::new(model, settings, trace_config, 6, None)?;
let _ = loop {
match sampler.wait_timeout(Duration::from_secs(1)) {
SamplerWaitResult::Trace(trace) => break trace,
SamplerWaitResult::Timeout(new_sampler) => sampler = new_sampler,
SamplerWaitResult::Err(err, _trace) => return Err(err),
};
};
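// Re-open the trace through a `dyn` storage handle and sanity-check the arrays.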
let store_dyn: Arc<dyn ReadableListableStorageTraits> = store.clone();
let diverging = Array::open(store_dyn.clone(), "/sample_stats/diverging")
.context("Could not read diverging array")?;
assert_eq!(
diverging.dimension_names().as_ref().unwrap(),
&[Some("chain".to_string()), Some("draw".to_string())]
);
let _: Vec<bool> = diverging
.retrieve_array_subset(&ArraySubset::new_with_shape(diverging.shape().to_vec()))?;
let logp = Array::open(store_dyn.clone(), "/sample_stats/logp")
.context("Could not read logp array")?;
assert_eq!(
logp.dimension_names().as_ref().unwrap(),
&[Some("chain".to_string()), Some("draw".to_string())]
);
let _: Vec<f64> =
logp.retrieve_array_subset(&ArraySubset::new_with_shape(logp.shape().to_vec()))?;
Ok(store)
}
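/// Run a single low-rank-mass-matrix NUTS chain with per-draw debug statistics
/// enabled (the adapted mass matrix itself is not stored here).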
fn sample_eigs_debug_stats() -> anyhow::Result<Arc<MemoryStore>> {
let mu = vec![0.5; 10];
let model = NormalModel::new(mu.into());
let settings = LowRankNutsSettings {
seed: 42,
num_chains: 1,
num_tune: 200,
num_draws: 10,
store_gradient: true,
store_divergences: true,
store_unconstrained: true,
adapt_options: EuclideanAdaptOptions {
mass_matrix_options: nuts_rs::LowRankSettings {
store_mass_matrix: false,
..Default::default()
},
..Default::default()
},
..Default::default()
};
let store = Arc::new(MemoryStore::new());
let trace_config = ZarrConfig::new(store.clone());
let mut sampler = Sampler::new(model, settings, trace_config, 1, None)?;
let _trace = loop {
match sampler.wait_timeout(Duration::from_secs(1)) {
SamplerWaitResult::Trace(trace) => break trace,
SamplerWaitResult::Timeout(new_sampler) => sampler = new_sampler,
SamplerWaitResult::Err(err, _trace) => return Err(err),
};
};
Ok(store)
}
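/// For an exactly Gaussian target the low-rank adaptation should recover the
/// true covariance, so the per-draw Fisher distance should be numerically zero
/// once tuning has finished.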
#[test]
fn low_rank_exact_gaussian() -> anyhow::Result<()> {
let dim = 10;
let logp = CorrelatedNormalLogp::new(dim, 0.5);
let math = CpuMath::new(logp);
let mut settings = LowRankNutsSettings {
num_tune: 500,
num_draws: 100,
num_chains: 1,
seed: 42,
..Default::default()
};
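// Presumably a cutoff barely above 1 keeps essentially every eigenvalue in the
// low-rank part, letting the adapted metric match the Gaussian target exactly.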
settings.adapt_options.mass_matrix_options.eigval_cutoff = 1.00001;
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let mut chain = settings.new_chain(0, math, &mut rng);
chain.set_position(&vec![1.0f64; dim])?;
for _ in 0..(settings.num_tune + settings.num_draws) {
let (_, _, stats, progress) = chain.expanded_draw()?;
if !progress.tuning {
assert!(
stats.point.fisher_distance < 1e-10,
"fisher_distance = {} should be ~0 after low-rank adaptation converges",
stats.point.fisher_distance,
);
}
}
Ok(())
}
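// Smoke tests: run each sampler configuration end to end and print elapsed time.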
#[test]
fn run() -> anyhow::Result<()> {
let start = Instant::now();
let _ = sample()?;
dbg!(start.elapsed());
Ok(())
}
#[test]
fn run_debug_stats() -> anyhow::Result<()> {
let start = Instant::now();
let _ = sample_debug_stats()?;
dbg!(start.elapsed());
Ok(())
}
#[test]
fn run_debug_stats_eigs() -> anyhow::Result<()> {
let start = Instant::now();
let _ = sample_eigs_debug_stats()?;
dbg!(start.elapsed());
Ok(())
}