
Sample from posterior distributions using the No-U-Turn Sampler (NUTS).
For details, see the original NUTS paper by Hoffman and Gelman
and the more recent introduction.
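The "U-turn" in the name is the stopping rule. NUTS extends a trajectory forward and backward in time and stops once its two ends start moving back toward each other. The sketch below checks the criterion from the original paper for the endpoints (theta_minus, p_minus) and (theta_plus, p_plus): stop when (theta_plus - theta_minus) · p_minus < 0 or (theta_plus - theta_minus) · p_plus < 0. It is a standalone illustration, not code from this crate. The helper names `dot` and `is_u_turn` are hypothetical, and production samplers such as Stan use a generalized variant of this check.

```rust
/// Dot product of two equal-length slices.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Hypothetical helper: the U-turn criterion from the original NUTS paper.
/// `theta_minus`/`p_minus` are position and momentum at the backward end of
/// the trajectory, `theta_plus`/`p_plus` at the forward end.
fn is_u_turn(theta_minus: &[f64], theta_plus: &[f64], p_minus: &[f64], p_plus: &[f64]) -> bool {
    // Vector spanning the trajectory from its backward end to its forward end.
    let span: Vec<f64> = theta_plus
        .iter()
        .zip(theta_minus)
        .map(|(plus, minus)| plus - minus)
        .collect();
    // A U-turn has occurred if the momentum at either end points back along the span.
    dot(&span, p_minus) < 0.0 || dot(&span, p_plus) < 0.0
}

fn main() {
    // Both ends still move apart along the span: keep extending (prints false).
    println!("{}", is_u_turn(&[0.0, 0.0], &[1.0, 1.0], &[1.0, 0.5], &[0.5, 1.0]));
    // The forward end now moves back toward the start: stop (prints true).
    println!("{}", is_u_turn(&[0.0, 0.0], &[1.0, 1.0], &[1.0, 0.5], &[-1.0, -0.5]));
}
```

Checking only the two endpoints keeps the rule cheap; the sampler does not need to look at every intermediate state to decide when to stop.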
This crate was developed as a faster replacement for the sampler in PyMC,
for use with the new Numba backend of PyTensor. The Python wrapper
for this sampler is nutpie.
Usage

    use nuts_rs::{CpuLogpFunc, CpuMath, LogpError, DiagGradNutsSettings, Chain, SampleStats,
        Settings};
    use thiserror::Error;
    use rand::thread_rng;

    // The posterior density: computes the unnormalized log density and its gradient.
    #[derive(Debug)]
    struct PosteriorDensity {}

    // Errors the density can return. This example has none, so nothing is recoverable.
    #[derive(Debug, Error)]
    enum PosteriorLogpError {}

    impl LogpError for PosteriorLogpError {
        fn is_recoverable(&self) -> bool { false }
    }

    impl CpuLogpFunc for PosteriorDensity {
        type LogpError = PosteriorLogpError;
        type TransformParams = ();

        // A 10-dimensional distribution.
        fn dim(&self) -> usize { 10 }

        // Log density (up to a constant) of independent normals with mean 3 and
        // unit variance; the gradient is written into `grad`.
        fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mu = 3f64;
            let logp = position
                .iter()
                .copied()
                .zip(grad.iter_mut())
                .map(|(x, grad)| {
                    let diff = x - mu;
                    *grad = -diff;
                    -diff * diff / 2f64
                })
                .sum();
            Ok(logp)
        }
    }

    fn main() {
        // Start from the default sampler settings and adjust them.
        let mut settings = DiagGradNutsSettings::default();
        settings.num_tune = 1000;
        settings.maxdepth = 3; // small value, just for testing

        let logp_func = PosteriorDensity {};
        let math = CpuMath::new(logp_func);

        let chain = 0;
        let mut rng = thread_rng();
        let mut sampler = settings.new_chain(chain, math, &mut rng);

        // Set an initial position, then draw samples.
        sampler
            .set_position(&vec![0f64; 10])
            .expect("Unrecoverable error during init");
        let mut trace = vec![]; // collection of all draws
        for _ in 0..2000 {
            let (draw, _info) = sampler.draw().expect("Unrecoverable error during sampling");
            trace.push(draw.clone());
            println!("Draw: {:?}", draw);
        }
    }

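The `logp` method above must both return the log density and write its gradient into `grad`, and a mistake in the gradient is easy to make and hard to spot from the draws alone. One common safeguard, sketched below, is to compare the analytic gradient against central finite differences before sampling. The sketch works on plain functions with the same arguments as `logp` (without the `Result` wrapper or `&mut self`); `max_grad_error`, `normal_logp`, and `wrong_logp` are hypothetical names and not part of `nuts_rs`.

```rust
/// Hypothetical helper: largest absolute difference between the analytic
/// gradient written by `logp` and a central finite-difference estimate.
fn max_grad_error<F>(logp: F, position: &[f64], eps: f64) -> f64
where
    F: Fn(&[f64], &mut [f64]) -> f64,
{
    let n = position.len();
    let mut grad = vec![0.0_f64; n];
    logp(position, &mut grad[..]);

    let mut scratch = vec![0.0_f64; n]; // gradient output we ignore during probing
    let mut x = position.to_vec();
    let mut max_err: f64 = 0.0;
    for i in 0..n {
        let orig = x[i];
        x[i] = orig + eps;
        let up = logp(&x[..], &mut scratch[..]);
        x[i] = orig - eps;
        let down = logp(&x[..], &mut scratch[..]);
        x[i] = orig; // restore the coordinate before probing the next one
        let finite_diff = (up - down) / (2.0 * eps);
        max_err = max_err.max((finite_diff - grad[i]).abs());
    }
    max_err
}

/// Same density as the usage example: independent normals, mean 3, unit variance.
fn normal_logp(position: &[f64], grad: &mut [f64]) -> f64 {
    let mu = 3.0_f64;
    position
        .iter()
        .zip(grad.iter_mut())
        .map(|(&x, g)| {
            let diff = x - mu;
            *g = -diff;
            -diff * diff / 2.0
        })
        .sum()
}

/// Same log density, but with a deliberately wrong (sign-flipped) gradient.
fn wrong_logp(position: &[f64], grad: &mut [f64]) -> f64 {
    let mu = 3.0_f64;
    position
        .iter()
        .zip(grad.iter_mut())
        .map(|(&x, g)| {
            let diff = x - mu;
            *g = diff;
            -diff * diff / 2.0
        })
        .sum()
}

fn main() {
    let x = vec![0.5_f64; 10];
    println!("correct gradient, max error: {:e}", max_grad_error(normal_logp, &x, 1e-5));
    println!("wrong gradient, max error: {:e}", max_grad_error(wrong_logp, &x, 1e-5));
}
```

For a quadratic log density like this one, the central difference is exact up to rounding, so the correct gradient gives an error near machine precision while the sign-flipped one is off by several units.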
Users can also implement the Model
trait for more control and parallel sampling.
Implementation details
This crate mostly follows the implementation of NUTS in Stan and
PyMC; only the tuning of the mass matrix and step size differs
somewhat.
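For comparison, Stan tunes the step size with the dual-averaging scheme from the original NUTS paper: during tuning, the log step size rises when the acceptance statistic exceeds a target and falls when it falls short, and an averaged step size is used once tuning ends. The sketch below implements that scheme with the paper's constants (gamma = 0.05, t0 = 10, kappa = 0.75, mu = log(10 * initial step)) and a target acceptance of 0.8, Stan's default. It illustrates the Stan baseline, not this crate's own tuning; `DualAverage` is a hypothetical name, and the acceptance model in `main` is a toy assumption.

```rust
/// Hypothetical sketch of dual-averaging step-size adaptation
/// (Hoffman and Gelman); not this crate's implementation.
struct DualAverage {
    mu: f64,            // shrinkage target: log(10 * initial step size)
    gamma: f64,         // controls how strongly the step reacts to errors
    t0: f64,            // damps the first few iterations
    kappa: f64,         // decay rate of the averaging weights
    target_accept: f64, // desired mean acceptance statistic
    m: f64,             // iteration counter
    h_bar: f64,         // running average of (target - accept)
    log_step: f64,      // current log step size, used during tuning
    log_step_bar: f64,  // averaged log step size, used after tuning
}

impl DualAverage {
    fn new(initial_step: f64, target_accept: f64) -> Self {
        DualAverage {
            mu: (10.0 * initial_step).ln(),
            gamma: 0.05,
            t0: 10.0,
            kappa: 0.75,
            target_accept,
            m: 0.0,
            h_bar: 0.0,
            log_step: initial_step.ln(),
            log_step_bar: 0.0,
        }
    }

    /// Feed the acceptance statistic of the latest draw; returns the step size
    /// to use for the next draw during tuning.
    fn update(&mut self, accept_stat: f64) -> f64 {
        self.m += 1.0;
        let w = 1.0 / (self.m + self.t0);
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target_accept - accept_stat);
        self.log_step = self.mu - self.m.sqrt() / self.gamma * self.h_bar;
        let eta = self.m.powf(-self.kappa);
        self.log_step_bar = eta * self.log_step + (1.0 - eta) * self.log_step_bar;
        self.log_step.exp()
    }

    /// Step size to keep fixed after tuning.
    fn final_step(&self) -> f64 {
        self.log_step_bar.exp()
    }
}

fn main() {
    // Toy acceptance model (an assumption): larger steps are accepted less often.
    let mut adapt = DualAverage::new(1.0, 0.8);
    let mut step: f64 = 1.0;
    for _ in 0..1000 {
        let accept_stat = (-step).exp();
        step = adapt.update(accept_stat);
    }
    println!("adapted step size: {}", adapt.final_step());
}
```

Averaging the log step size with weights that decay like m^(-kappa) smooths out the oscillations of the raw iterates, which is why the averaged value rather than the last raw step is kept after tuning.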