pub struct NUTSChain<T, B, GTarget>
where
    B: AutodiffBackend,
{
    pub position: Tensor<B, 1>,
    /* private fields */
}
Single-chain state and adaptation for NUTS.
Manages dynamic trajectory building, dual-averaging adaptation of the step size, and the current position for one chain.
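Trajectory building in NUTS rests on two ingredients: a leapfrog integrator that extends the path, and a U-turn test that decides when to stop doubling it. The following is a minimal sketch of both, assuming an identity mass matrix and the original U-turn criterion of Hoffman and Gelman (2014). `leapfrog` and `is_u_turn` are hypothetical helpers on `&[f64]`, not part of this crate's API (which works on `Tensor<B, 1>` values), and the crate may use a different stopping rule.

```rust
/// One leapfrog step with an identity mass matrix.
/// `grad_log_p` returns the gradient of the log density at a position.
fn leapfrog<G>(theta: &[f64], r: &[f64], eps: f64, grad_log_p: G) -> (Vec<f64>, Vec<f64>)
where
    G: Fn(&[f64]) -> Vec<f64>,
{
    // Half step for the momentum using the gradient at the current position.
    let g0 = grad_log_p(theta);
    let r_half: Vec<f64> = r.iter().zip(&g0).map(|(ri, gi)| ri + 0.5 * eps * gi).collect();
    // Full step for the position using the half-updated momentum.
    let theta_new: Vec<f64> = theta.iter().zip(&r_half).map(|(ti, ri)| ti + eps * ri).collect();
    // Second half step for the momentum using the gradient at the new position.
    let g1 = grad_log_p(&theta_new);
    let r_new: Vec<f64> = r_half.iter().zip(&g1).map(|(ri, gi)| ri + 0.5 * eps * gi).collect();
    (theta_new, r_new)
}

/// Original NUTS stopping rule: the trajectory from (theta_minus, r_minus) to
/// (theta_plus, r_plus) is turning back if either end's momentum has a negative
/// dot product with the span theta_plus - theta_minus.
fn is_u_turn(theta_minus: &[f64], theta_plus: &[f64], r_minus: &[f64], r_plus: &[f64]) -> bool {
    let span_dot = |r: &[f64]| -> f64 {
        theta_plus
            .iter()
            .zip(theta_minus)
            .zip(r)
            .map(|((p, m), ri)| (p - m) * ri)
            .sum()
    };
    span_dot(r_minus) < 0.0 || span_dot(r_plus) < 0.0
}
```

For a standard normal target, `grad_log_p` is simply `x -> -x`; one step from `theta = [1.0]`, `r = [0.0]` with `eps = 0.1` gives `theta ≈ [0.995]` and approximately conserves the Hamiltonian `theta²/2 + r²/2`.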
Fields
position: Tensor<B, 1>
Current position in parameter space.
Implementations
impl<T, B, GTarget> NUTSChain<T, B, GTarget>
where
    T: Float + ElementConversion + Element + SampleUniform + FromPrimitive,
    B: AutodiffBackend,
    GTarget: GradientTarget<T, B> + Sync,
    StandardNormal: Distribution<T>,
    StandardUniform: Distribution<T>,
    Exp1: Distribution<T>,
pub fn new(
    target: GTarget,
    initial_position: Vec<T>,
    target_accept_p: T,
) -> Self
Constructs a NUTSChain for a single chain, starting at the given initial position.
Parameters
target: The target distribution, which must implement GradientTarget.
initial_position: Initial position vector of length D.
target_accept_p: Desired average acceptance probability used by step-size adaptation.
Returns
An initialized NUTSChain.
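NUTS needs the log density and its gradient at arbitrary positions, which is presumably what a `GradientTarget` supplies. As an illustration only, here is a hypothetical, simplified stand-in written against `&[f64]`: the real trait is generic over `T` and the backend `B`, and its method names are not shown on this page, so `LogDensityGradient`, `log_prob_and_grad`, and `IsotropicGaussian` are all assumptions.

```rust
/// Hypothetical, simplified stand-in for `GradientTarget<T, B>`.
/// The real trait works with backend tensors and autodiff; this one uses plain slices.
trait LogDensityGradient {
    /// Returns the unnormalized log density at `x` and its gradient.
    fn log_prob_and_grad(&self, x: &[f64]) -> (f64, Vec<f64>);
}

/// Isotropic Gaussian with mean `mu` and common standard deviation `sigma`.
struct IsotropicGaussian {
    mu: Vec<f64>,
    sigma: f64,
}

impl LogDensityGradient for IsotropicGaussian {
    fn log_prob_and_grad(&self, x: &[f64]) -> (f64, Vec<f64>) {
        let var = self.sigma * self.sigma;
        let mut log_p = 0.0;
        let mut grad = Vec::with_capacity(x.len());
        for (xi, mi) in x.iter().zip(&self.mu) {
            let d = xi - mi;
            // log N(x | mu, sigma^2) up to an additive constant.
            log_p -= 0.5 * d * d / var;
            // d/dx of the term above.
            grad.push(-d / var);
        }
        (log_p, grad)
    }
}
```

The length of `mu` plays the role of D here, so an `initial_position` of the same length would be expected. For `target_accept_p`, a value around 0.8 (Stan's default `adapt_delta`) is a common starting point.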
pub fn run(&mut self, n_collect: usize, n_discard: usize) -> Tensor<B, 2>
Runs the chain for n_discard + n_collect steps. The first n_discard burn-in steps adapt the step size and are discarded; the following n_collect steps are returned as samples.
Parameters
n_collect: Number of samples to collect after adaptation.
n_discard: Number of burn-in steps used for adaptation.
Returns
A 2D tensor of shape [n_collect, D] containing the collected samples.
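The burn-in and collection phases of `run` can be sketched as follows, assuming the standard dual-averaging scheme of Hoffman and Gelman (2014) with that paper's constants (γ = 0.05, t0 = 10, κ = 0.75, μ = ln(10 ε0)). `DualAveraging`, `run_chain`, the initial step size of 1.0, and the `step` closure that stands in for one NUTS transition are hypothetical; the crate's actual adaptation details are not documented on this page. The result is a `Vec` of `n_collect` rows of length D, mirroring the `[n_collect, D]` tensor that `run` returns.

```rust
/// Dual-averaging step-size adaptation (Hoffman & Gelman 2014, Sec. 3.2.1).
struct DualAveraging {
    mu: f64,          // shrinkage target for the log step size, ln(10 * eps0)
    target: f64,      // desired mean acceptance statistic (delta)
    h_bar: f64,       // running average of (target - acceptance statistic)
    log_eps: f64,     // current log step size
    log_eps_bar: f64, // weighted average of the log step sizes
    m: f64,           // number of adaptation iterations so far
}

const GAMMA: f64 = 0.05;
const T0: f64 = 10.0;
const KAPPA: f64 = 0.75;

impl DualAveraging {
    fn new(initial_step: f64, target_accept: f64) -> Self {
        Self {
            mu: (10.0 * initial_step).ln(),
            target: target_accept,
            h_bar: 0.0,
            log_eps: initial_step.ln(),
            log_eps_bar: 0.0,
            m: 0.0,
        }
    }

    /// Feeds the acceptance statistic of one burn-in iteration; returns the next step size.
    fn update(&mut self, accept_stat: f64) -> f64 {
        self.m += 1.0;
        let w = 1.0 / (self.m + T0);
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.target - accept_stat);
        self.log_eps = self.mu - self.m.sqrt() / GAMMA * self.h_bar;
        let eta = self.m.powf(-KAPPA);
        self.log_eps_bar = eta * self.log_eps + (1.0 - eta) * self.log_eps_bar;
        self.log_eps.exp()
    }

    /// Step size to use after burn-in: the averaged iterate.
    fn final_step(&self) -> f64 {
        self.log_eps_bar.exp()
    }
}

/// Hypothetical driver mirroring `run(n_collect, n_discard)`.
/// `step(position, eps)` stands in for one NUTS transition and returns
/// (new position, acceptance statistic).
fn run_chain<F>(
    mut position: Vec<f64>,
    n_collect: usize,
    n_discard: usize,
    target_accept: f64,
    mut step: F,
) -> Vec<Vec<f64>>
where
    F: FnMut(&[f64], f64) -> (Vec<f64>, f64),
{
    let initial_step = 1.0; // assumption: the real initial step size is not documented here
    let mut adapt = DualAveraging::new(initial_step, target_accept);
    let mut eps = initial_step;

    // Burn-in: move the chain and adapt eps; these states are discarded.
    for _ in 0..n_discard {
        let (next, accept_stat) = step(position.as_slice(), eps);
        position = next;
        eps = adapt.update(accept_stat);
    }
    if n_discard > 0 {
        eps = adapt.final_step(); // freeze the averaged step size
    }

    // Collection: fixed eps, keep every state. Result has shape [n_collect][D].
    let mut samples = Vec::with_capacity(n_collect);
    for _ in 0..n_collect {
        let (next, _) = step(position.as_slice(), eps);
        position = next;
        samples.push(position.clone());
    }
    samples
}
```

Switching to the averaged iterate `log_eps_bar` after burn-in, rather than the last `log_eps`, follows the paper's recommendation, since the individual iterates can still fluctuate at the end of adaptation.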
Auto Trait Implementations
impl<T, B, GTarget> Freeze for NUTSChain<T, B, GTarget>
where
    GTarget: Freeze,
    T: Freeze,
    <B as Backend>::FloatTensorPrimitive: Freeze,
    <B as Backend>::QuantizedTensorPrimitive: Freeze,
impl<T, B, GTarget> RefUnwindSafe for NUTSChain<T, B, GTarget>
where
    GTarget: RefUnwindSafe,
    T: RefUnwindSafe,
    <B as Backend>::FloatTensorPrimitive: RefUnwindSafe,
    <B as Backend>::QuantizedTensorPrimitive: RefUnwindSafe,
impl<T, B, GTarget> Send for NUTSChain<T, B, GTarget>
impl<T, B, GTarget> Sync for NUTSChain<T, B, GTarget>
impl<T, B, GTarget> Unpin for NUTSChain<T, B, GTarget>
where
    GTarget: Unpin,
    T: Unpin,
    <B as Backend>::FloatTensorPrimitive: Unpin,
    <B as Backend>::QuantizedTensorPrimitive: Unpin,
impl<T, B, GTarget> UnwindSafe for NUTSChain<T, B, GTarget>
where
    GTarget: UnwindSafe,
    T: UnwindSafe,
    <B as Backend>::FloatTensorPrimitive: UnwindSafe,
    <B as Backend>::QuantizedTensorPrimitive: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.