pub struct NUTS<T, B, GTarget>where
T: Float + ElementConversion + Element + SampleUniform + FromPrimitive,
B: AutodiffBackend,
GTarget: GradientTarget<T, B> + Sync,
StandardNormal: Distribution<T>,
StandardUniform: Distribution<T>,
Exp1: Distribution<T>,{ /* private fields */ }
No-U-Turn Sampler (NUTS).
Encapsulates multiple independent Markov chains using the NUTS algorithm. Utilizes dual-averaging step size adaptation and dynamic trajectory lengths to efficiently explore complex posterior geometries. Chains are executed concurrently via Rayon, each evolving independently.
§Type Parameters
T: Floating-point type for numerical calculations.
B: Autodiff backend from the burn crate.
GTarget: Target distribution type implementing the GradientTarget trait.
Implementations§
Source§impl<T, B, GTarget> NUTS<T, B, GTarget>where
T: Float + ElementConversion + Element + SampleUniform + FromPrimitive + Send,
B: AutodiffBackend + Send,
GTarget: GradientTarget<T, B> + Sync + Clone + Send,
StandardNormal: Distribution<T>,
StandardUniform: Distribution<T>,
Exp1: Distribution<T>,
impl<T, B, GTarget> NUTS<T, B, GTarget>where
T: Float + ElementConversion + Element + SampleUniform + FromPrimitive + Send,
B: AutodiffBackend + Send,
GTarget: GradientTarget<T, B> + Sync + Clone + Send,
StandardNormal: Distribution<T>,
StandardUniform: Distribution<T>,
Exp1: Distribution<T>,
Sourcepub fn new(
target: GTarget,
initial_positions: Vec<Vec<T>>,
target_accept_p: T,
) -> Self
pub fn new( target: GTarget, initial_positions: Vec<Vec<T>>, target_accept_p: T, ) -> Self
Creates a new NUTS sampler with the given target distribution and initial state for each chain.
§Parameters
target: The target distribution implementing GradientTarget.
initial_positions: A vector of initial positions, one per chain, with shape [n_chains, D].
target_accept_p: Desired average acceptance probability for the dual-averaging step-size adaptation. Try values between 0.6 and 0.95.
§Returns
A newly initialized NUTS instance.
§Example
type B = Autodiff<NdArray>;
// Create a 2D Gaussian with mean [0,0] and identity covariance
let gauss = DiffableGaussian2D::new([0.0_f64, 0.0], [[1.0, 0.0], [0.0, 1.0]]);
// Initialize 3 chains in 2D at different starting points
let init_positions = vec![
vec![-1.0, -1.0],
vec![ 0.0, 0.0],
vec![ 1.0, 1.0],
];
// Build the sampler targeting 85% acceptance probability
let sampler: NUTS<f64, B, _> = NUTS::new(gauss, init_positions, 0.85);
Sourcepub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 3>
pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 3>
Runs all chains for a total of n_collect + n_discard steps and collects samples.
First discards n_discard warm-up steps for each chain (during which adaptation occurs),
then collects n_collect samples per chain.
§Parameters
n_collect: Number of samples to collect per chain after warm-up.
n_discard: Number of warm-up (burn-in) steps to discard per chain.
§Returns
A 3D tensor of shape [n_chains, n_collect, D] containing the collected samples.
§Example
type B = Autodiff<NdArray>;
// As above, construct the sampler
let gauss = DiffableGaussian2D::new([0.0_f32, 0.0], [[1.0,0.0],[0.0,1.0]]);
let mut sampler = NUTS::new(gauss, init::<f32>(2, 2), 0.8);
// Discard 50 warm-up steps, then collect 150 observations per chain
let sample: Tensor<B, 3> = sampler.run(150, 50);
// sample.dims() == [2 chains, 150 observations, 2 dimensions]
assert_eq!(sample.dims(), [2, 150, 2]);
Sourcepub fn run_progress(
&mut self,
n_collect: usize,
n_discard: usize,
) -> Result<(Tensor<B, 3>, RunStats), Box<dyn Error>>
pub fn run_progress( &mut self, n_collect: usize, n_discard: usize, ) -> Result<(Tensor<B, 3>, RunStats), Box<dyn Error>>
Runs all chains with live progress bars and collects summary statistics.
Spawns a background thread to render per-chain and global bars,
then returns (samples, RunStats) when done.
§Example
use burn::backend::{Autodiff, NdArray};
use mini_mcmc::distributions::Rosenbrock2D;
use mini_mcmc::nuts::NUTS;
use mini_mcmc::core::init;
type B = Autodiff<NdArray>;
let target = Rosenbrock2D { a: 1.0, b: 100.0 };
let init = init::<f64>(4, 2); // 4 chains in 2D
let mut sampler = NUTS::<f64, B, Rosenbrock2D<f64>>::new(target, init, 0.9);
let (samples, stats) = sampler.run_progress(100, 20).unwrap();
You can swap in any other GradientTarget just as easily.
Trait Implementations§
Source§impl<T, B, GTarget> Clone for NUTS<T, B, GTarget>where
T: Float + ElementConversion + Element + SampleUniform + FromPrimitive + Clone,
B: AutodiffBackend + Clone,
GTarget: GradientTarget<T, B> + Sync + Clone,
StandardNormal: Distribution<T>,
StandardUniform: Distribution<T>,
Exp1: Distribution<T>,
impl<T, B, GTarget> Clone for NUTS<T, B, GTarget>where
T: Float + ElementConversion + Element + SampleUniform + FromPrimitive + Clone,
B: AutodiffBackend + Clone,
GTarget: GradientTarget<T, B> + Sync + Clone,
StandardNormal: Distribution<T>,
StandardUniform: Distribution<T>,
Exp1: Distribution<T>,
Source§impl<T, B, GTarget> Debug for NUTS<T, B, GTarget>where
T: Float + ElementConversion + Element + SampleUniform + FromPrimitive + Debug,
B: AutodiffBackend + Debug,
GTarget: GradientTarget<T, B> + Sync + Debug,
StandardNormal: Distribution<T>,
StandardUniform: Distribution<T>,
Exp1: Distribution<T>,
impl<T, B, GTarget> Debug for NUTS<T, B, GTarget>where
T: Float + ElementConversion + Element + SampleUniform + FromPrimitive + Debug,
B: AutodiffBackend + Debug,
GTarget: GradientTarget<T, B> + Sync + Debug,
StandardNormal: Distribution<T>,
StandardUniform: Distribution<T>,
Exp1: Distribution<T>,
Auto Trait Implementations§
impl<T, B, GTarget> Freeze for NUTS<T, B, GTarget>
impl<T, B, GTarget> RefUnwindSafe for NUTS<T, B, GTarget>where
GTarget: RefUnwindSafe,
T: RefUnwindSafe,
<B as Backend>::FloatTensorPrimitive: RefUnwindSafe,
<B as Backend>::QuantizedTensorPrimitive: RefUnwindSafe,
impl<T, B, GTarget> Send for NUTS<T, B, GTarget>where
GTarget: Send,
impl<T, B, GTarget> Sync for NUTS<T, B, GTarget>
impl<T, B, GTarget> Unpin for NUTS<T, B, GTarget>where
GTarget: Unpin,
T: Unpin,
<B as Backend>::FloatTensorPrimitive: Unpin,
<B as Backend>::QuantizedTensorPrimitive: Unpin,
impl<T, B, GTarget> UnwindSafe for NUTS<T, B, GTarget>where
GTarget: UnwindSafe,
T: UnwindSafe,
<B as Backend>::FloatTensorPrimitive: UnwindSafe,
<B as Backend>::QuantizedTensorPrimitive: UnwindSafe,
Blanket Implementations§
Source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> CloneToUninit for Twhere
T: Clone,
impl<T> CloneToUninit for Twhere
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more