nuts_rs/lib.rs
//! Sample from posterior distributions using the No U-turn Sampler (NUTS).
//! For details see the original [NUTS paper](https://arxiv.org/abs/1111.4246)
//! and the more recent [introduction](https://arxiv.org/abs/1701.02434).
//!
//! This crate was developed as a faster replacement for the sampler in PyMC,
//! to be used with the new numba backend of PyTensor. The Python wrapper
//! for this sampler is [nutpie](https://github.com/pymc-devs/nutpie).
//!
//! ## Usage
//!
//! ```
//! use nuts_rs::{
//!     Chain, CpuLogpFunc, CpuMath, DiagGradNutsSettings, LogpError, Progress, Settings,
//! };
//! use rand::thread_rng;
//! use thiserror::Error;
//!
//! // Define a function that computes the unnormalized posterior density
//! // and its gradient.
//! #[derive(Debug)]
//! struct PosteriorDensity {}
//!
//! // The density might fail in a recoverable or non-recoverable manner...
//! #[derive(Debug, Error)]
//! enum PosteriorLogpError {}
//! impl LogpError for PosteriorLogpError {
//!     fn is_recoverable(&self) -> bool { false }
//! }
//!
//! impl CpuLogpFunc for PosteriorDensity {
//!     type LogpError = PosteriorLogpError;
//!
//!     // Only used for transform adaptation.
//!     type TransformParams = ();
//!
//!     // We define a 10-dimensional normal distribution.
//!     fn dim(&self) -> usize { 10 }
//!
//!     // The log-density of a normal distribution with mean 3 (up to an
//!     // additive constant) and its gradient.
//!     fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
//!         let mu = 3f64;
//!         let logp = position
//!             .iter()
//!             .copied()
//!             .zip(grad.iter_mut())
//!             .map(|(x, grad)| {
//!                 let diff = x - mu;
//!                 *grad = -diff;
//!                 -diff * diff / 2f64
//!             })
//!             .sum();
//!         Ok(logp)
//!     }
//! }
//!
//! // We get the default sampler arguments
//! let mut settings = DiagGradNutsSettings::default();
//!
//! // and modify as we like
//! settings.num_tune = 1000;
//! settings.maxdepth = 3; // small value just for testing...
//!
//! // We instantiate our posterior density function
//! let logp_func = PosteriorDensity {};
//! let math = CpuMath::new(logp_func);
//!
//! let chain = 0;
//! let mut rng = thread_rng();
//! let mut sampler = settings.new_chain(chain, math, &mut rng);
//!
//! // Set to some initial position and start drawing samples.
//! sampler.set_position(&vec![0f64; 10]).expect("Unrecoverable error during init");
//! let mut trace = vec![]; // Collection of all draws
//! for _ in 0..2000 {
//!     let (draw, _info) = sampler.draw().expect("Unrecoverable error during sampling");
//!     trace.push(draw);
//! }
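//!
//! // The first `settings.num_tune` draws are warmup. As a rough sanity check
//! // (a sketch, not a convergence diagnostic), we can average the first
//! // coordinate of the remaining draws; it should be close to the posterior
//! // mean of 3.
//! let tuned = &trace[settings.num_tune as usize..];
//! let mean = tuned.iter().map(|draw| draw[0]).sum::<f64>() / tuned.len() as f64;
//! println!("Estimated mean of first coordinate: {}", mean);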
//! ```
//!
//! Users can also implement the `Model` trait for more control and parallel sampling.
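//!
//! As a sketch of the latter (marked `ignore`, so it is not compiled as a
//! doctest), several chains can also be run concurrently with the low-level
//! `Chain` API by giving each thread its own math context and RNG. This reuses
//! the `PosteriorDensity` type from the example above and assumes the settings
//! type implements `Clone`; the seed choice is arbitrary:
//!
//! ```ignore
//! use rand::{rngs::StdRng, SeedableRng};
//! use std::thread;
//!
//! let settings = DiagGradNutsSettings::default();
//!
//! let handles: Vec<_> = (0..4u64)
//!     .map(|chain| {
//!         let settings = settings.clone();
//!         thread::spawn(move || {
//!             // Each chain owns its own density, math context and RNG.
//!             let math = CpuMath::new(PosteriorDensity {});
//!             let mut rng = StdRng::seed_from_u64(42 + chain);
//!             let mut sampler = settings.new_chain(chain, math, &mut rng);
//!             sampler.set_position(&vec![0f64; 10]).expect("Unrecoverable error during init");
//!             (0..2000)
//!                 .map(|_| sampler.draw().expect("Unrecoverable error during sampling").0)
//!                 .collect::<Vec<_>>()
//!         })
//!     })
//!     .collect();
//!
//! // One trace per chain.
//! let traces: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
//! ```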
//!
//! ## Implementation details
//!
//! This crate mostly follows the implementation of NUTS in [Stan](https://mc-stan.org) and
//! [PyMC](https://docs.pymc.io/en/v3/); only the tuning of the mass matrix and the step
//! size differs somewhat.

mod adapt_strategy;
mod chain;
mod cpu_math;
mod euclidean_hamiltonian;
mod hamiltonian;
mod low_rank_mass_matrix;
mod mass_matrix;
mod mass_matrix_adapt;
mod math;
mod math_base;
mod nuts;
mod sampler;
mod sampler_stats;
mod state;
mod stepsize;
mod stepsize_adapt;
mod transform_adapt_strategy;
mod transformed_hamiltonian;

pub use adapt_strategy::EuclideanAdaptOptions;
pub use chain::Chain;
pub use cpu_math::{CpuLogpFunc, CpuMath};
pub use hamiltonian::DivergenceInfo;
pub use math_base::{LogpError, Math};
pub use nuts::NutsError;
pub use sampler::{
    sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage,
    LowRankNutsSettings, Model, NutsSettings, Progress, ProgressCallback, Sampler,
    SamplerWaitResult, Settings, Trace, TransformedNutsSettings,
};

pub use low_rank_mass_matrix::LowRankSettings;
pub use mass_matrix_adapt::DiagAdaptExpSettings;
pub use stepsize_adapt::DualAverageSettings;
pub use transform_adapt_strategy::TransformedSettings;