nuts-rs 0.18.0

Sample from unnormalized densities using Hamiltonian MCMC
Documentation
use std::{
    sync::Arc,
    time::{Duration, Instant},
};

use anyhow::Context;
use nuts_rs::{
    CpuLogpFunc, CpuMath, DiagAdaptExpSettings, DiagNutsSettings, EuclideanAdaptOptions, LogpError,
    LowRankNutsSettings, Model, Sampler, SamplerWaitResult, ZarrConfig,
};
use nuts_storable::HasDims;
use rand::prelude::Rng;
use rand_distr::{Distribution, StandardNormal};
use thiserror::Error;
use zarrs::{
    array::{Array, ArraySubset},
    storage::{ReadableListableStorageTraits, store::MemoryStore},
};

/// Log-density evaluator for the test model.
///
/// Its `logp` implementation scores a position as `-Σ (x[i] - mu[i])² / 2`,
/// i.e. an unnormalized product of unit-variance normals centred at `mu`.
struct NormalLogp<'mu> {
    /// Dimension reported through `CpuLogpFunc::dim` and `HasDims::dim_sizes`.
    dim: usize,
    /// Borrowed per-coordinate means.
    mu: &'mu [f64],
}

/// Error type named by `NormalLogp`'s `CpuLogpFunc::LogpError`.
///
/// The enum has no variants, so no value of it can ever be constructed.
#[derive(Error, Debug)]
enum NormalLogpError {}

impl LogpError for NormalLogpError {
    /// Never actually runs: `NormalLogpError` has no variants, so there is no
    /// value to call this on.
    fn is_recoverable(&self) -> bool {
        // An empty match is exhaustive for an uninhabited enum. If a variant is
        // ever added this stops compiling, forcing an explicit decision about
        // whether that error is recoverable.
        match *self {}
    }
}

impl HasDims for NormalLogp<'_> {
    /// Maps both dimension names used by this model to the same size, `self.dim`.
    fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
        let size = self.dim as u64;
        ["unconstrained_parameter", "dim"]
            .iter()
            .map(|name| (name.to_string(), size))
            .collect()
    }
}

impl CpuLogpFunc for NormalLogp<'_> {
    type LogpError = NormalLogpError;
    type FlowParameters = ();
    type ExpandedVector = Vec<f64>;

    /// Reports the dimension stored in the `dim` field.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Evaluates `-Σ (position[i] - mu[i])² / 2` and writes the gradient
    /// `mu[i] - position[i]` into `grad`. Always returns `Ok`.
    ///
    /// Coordinates are paired by zipping, so entries beyond the shortest of
    /// `position`, `mu` and `grad` are skipped and their `grad` slots are left
    /// as they were.
    ///
    /// # Panics
    ///
    /// Panics if `grad` and `position` have different lengths.
    fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
        let n = position.len();
        assert!(grad.len() == n);

        let mut total = 0f64;
        let centred = position.iter().zip(self.mu.iter());
        for ((&x, &centre), slot) in centred.zip(grad.iter_mut()) {
            let residual = x - centre;
            total -= residual * residual / 2.;
            *slot = -residual;
        }
        Ok(total)
    }

    /// Returns an owned copy of `array` unchanged; the RNG is not used.
    fn expand_vector<R: rand::Rng + ?Sized>(
        &mut self,
        _rng: &mut R,
        array: &[f64],
    ) -> Result<Self::ExpandedVector, nuts_rs::CpuMathError> {
        Ok(Vec::from(array))
    }
}

/// Test model that owns the mean vector each `NormalLogp` borrows from.
struct NormalModel {
    /// Per-coordinate means; the slice length doubles as the model dimension.
    mu: Box<[f64]>,
}

impl NormalModel {
    /// Takes ownership of `mu` as-is, without copying or validating it.
    fn new(mu: Box<[f64]>) -> Self {
        Self { mu }
    }
}

impl Model for NormalModel {
    /// CPU math backend wrapping a `NormalLogp` that borrows from this model.
    type Math<'model>
        = CpuMath<NormalLogp<'model>>
    where
        Self: 'model;

    /// Builds a new evaluator over the model's means; the RNG is not used.
    fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> anyhow::Result<Self::Math<'_>> {
        let means: &[f64] = &self.mu;
        let density = NormalLogp {
            dim: means.len(),
            mu: means,
        };
        Ok(CpuMath::new(density))
    }

    /// Overwrites every entry of `position`, front to back, with a draw from
    /// `StandardNormal`.
    fn init_position<R: Rng + ?Sized>(
        &self,
        rng: &mut R,
        position: &mut [f64],
    ) -> anyhow::Result<()> {
        for coordinate in position.iter_mut() {
            *coordinate = StandardNormal.sample(rng);
        }
        Ok(())
    }
}

/// Runs the sampler with `DiagNutsSettings` (seed 42, six chains, defaults
/// otherwise) on a 100-dimensional `NormalModel` with every mean at 0.5.
///
/// Returns the in-memory store handed to `ZarrConfig`, or the error reported by
/// the sampler.
fn sample() -> anyhow::Result<Arc<MemoryStore>> {
    let model = NormalModel::new(vec![0.5; 100].into_boxed_slice());
    let settings = DiagNutsSettings {
        seed: 42,
        num_chains: 6,
        ..Default::default()
    };

    // One handle goes to the trace configuration, the other back to the caller.
    let store = Arc::new(MemoryStore::new());
    let trace_config = ZarrConfig::new(store.clone());
    // NOTE(review): the fourth argument (6) is presumably a worker/core count —
    // confirm against `Sampler::new`.
    let mut sampler = Sampler::new(model, settings, trace_config, 6, None)?;

    // `wait_timeout` consumes the sampler and hands it back on timeout, so it is
    // re-bound on every pass until a trace or an error arrives.
    let _ = loop {
        sampler = match sampler.wait_timeout(Duration::from_secs(1)) {
            SamplerWaitResult::Timeout(pending) => pending,
            SamplerWaitResult::Trace(finished) => break finished,
            SamplerWaitResult::Err(err, _rest) => return Err(err),
        };
    };
    Ok(store)
}

/// Runs the `DiagNutsSettings` sampler with the extra storage flags switched on
/// and checks two of the resulting arrays.
///
/// After sampling, `/sample_stats/diverging` and `/sample_stats/logp` are opened
/// from the store, asserted to carry the dimension names `("chain", "draw")`, and
/// read in full as `bool` and `f64` respectively. The store is returned as a
/// trait object.
///
/// # Panics
///
/// Panics if either array has no dimension names or names other than
/// `("chain", "draw")`.
fn sample_debug_stats() -> anyhow::Result<Arc<dyn ReadableListableStorageTraits>> {
    let model = NormalModel::new(vec![0.5; 100].into_boxed_slice());

    // Assemble the nested settings from the inside out.
    let mass_matrix_options = DiagAdaptExpSettings {
        store_mass_matrix: true,
        use_grad_based_estimate: true,
        ..Default::default()
    };
    let adapt_options = EuclideanAdaptOptions {
        mass_matrix_options,
        ..Default::default()
    };
    let settings = DiagNutsSettings {
        seed: 42,
        num_chains: 6,
        store_gradient: true,
        store_divergences: true,
        store_unconstrained: true,
        adapt_options,
        ..Default::default()
    };

    let store = Arc::new(MemoryStore::new());
    let trace_config = ZarrConfig::new(store.clone());
    let mut sampler = Sampler::new(model, settings, trace_config, 6, None)?;

    // `wait_timeout` consumes the sampler and hands it back on timeout, so it is
    // re-bound on every pass until a trace or an error arrives.
    let _ = loop {
        sampler = match sampler.wait_timeout(Duration::from_secs(1)) {
            SamplerWaitResult::Timeout(pending) => pending,
            SamplerWaitResult::Trace(finished) => break finished,
            SamplerWaitResult::Err(err, _rest) => return Err(err),
        };
    };

    // From here on only the read-side trait object is needed.
    let store_dyn: Arc<dyn ReadableListableStorageTraits> = store;

    let diverging = Array::open(store_dyn.clone(), "/sample_stats/diverging")
        .context("Could not read diverging array")?;
    assert!(
        diverging.dimension_names().as_ref().unwrap()
            == &[Some("chain".to_string()), Some("draw".to_string())]
    );
    // Read the whole array to prove it decodes as booleans.
    let diverging_region = ArraySubset::new_with_shape(diverging.shape().to_vec());
    let _: Vec<bool> = diverging.retrieve_array_subset(&diverging_region)?;

    let logp = Array::open(store_dyn.clone(), "/sample_stats/logp")
        .context("Could not read logp array")?;
    assert!(
        logp.dimension_names().as_ref().unwrap()
            == &[Some("chain".to_string()), Some("draw".to_string())]
    );
    // Same full read, this time decoding as f64.
    let logp_region = ArraySubset::new_with_shape(logp.shape().to_vec());
    let _: Vec<f64> = logp.retrieve_array_subset(&logp_region)?;

    Ok(store_dyn)
}

/// Runs a single chain with `LowRankNutsSettings` on a 10-dimensional
/// `NormalModel` (`num_tune: 200`, `num_draws: 10`), with gradient, divergence
/// and unconstrained storage enabled and `store_mass_matrix` disabled.
///
/// Returns the in-memory store handed to `ZarrConfig`, or the error reported by
/// the sampler.
fn sample_eigs_debug_stats() -> anyhow::Result<Arc<MemoryStore>> {
    let model = NormalModel::new(vec![0.5; 10].into_boxed_slice());

    // Assemble the nested settings from the inside out.
    let mass_matrix_options = nuts_rs::LowRankSettings {
        store_mass_matrix: false,
        ..Default::default()
    };
    let adapt_options = EuclideanAdaptOptions {
        mass_matrix_options,
        ..Default::default()
    };
    let settings = LowRankNutsSettings {
        seed: 42,
        num_chains: 1,
        num_tune: 200,
        num_draws: 10,
        store_gradient: true,
        store_divergences: true,
        store_unconstrained: true,
        adapt_options,
        ..Default::default()
    };

    let store = Arc::new(MemoryStore::new());
    let trace_config = ZarrConfig::new(store.clone());
    let mut sampler = Sampler::new(model, settings, trace_config, 1, None)?;

    // `wait_timeout` consumes the sampler and hands it back on timeout, so it is
    // re-bound on every pass until a trace or an error arrives. The trace is
    // held in a named binding until the function returns.
    let _trace = loop {
        sampler = match sampler.wait_timeout(Duration::from_secs(1)) {
            SamplerWaitResult::Timeout(pending) => pending,
            SamplerWaitResult::Trace(finished) => break finished,
            SamplerWaitResult::Err(err, _rest) => return Err(err),
        };
    };

    Ok(store)
}

/// Smoke test: `sample` must finish without error; the wall-clock time is
/// printed through `dbg!`.
#[test]
fn run() -> anyhow::Result<()> {
    let start = Instant::now();
    // Only success matters here, so the returned store is released right away.
    drop(sample()?);
    dbg!(start.elapsed());
    Ok(())
}

/// Smoke test: `sample_debug_stats` must finish without error or failed
/// assertion; the wall-clock time is printed through `dbg!`.
#[test]
fn run_debug_stats() -> anyhow::Result<()> {
    let start = Instant::now();
    // Only success matters here, so the returned store is released right away.
    drop(sample_debug_stats()?);
    dbg!(start.elapsed());
    Ok(())
}

/// Smoke test: `sample_eigs_debug_stats` must finish without error; the
/// wall-clock time is printed through `dbg!`.
#[test]
fn run_debug_stats_eigs() -> anyhow::Result<()> {
    let start = Instant::now();
    // Only success matters here, so the returned store is released right away.
    drop(sample_eigs_debug_stats()?);
    dbg!(start.elapsed());
    Ok(())
}