Trait GradientTarget

pub trait GradientTarget<T: Float, B: AutodiffBackend> {
    // Required method
    fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1>;
}

A trait for batched targets. It computes the unnormalized log probability for a collection of positions, one per chain, on an autodiff backend so that gradients can be obtained by differentiating the result.

Implement this trait for your target distribution to enable gradient-based sampling.

§Type Parameters

  • T: The floating-point type (e.g., f32 or f64).
  • B: The autodiff backend from the burn crate.

Required Methods§

fn log_prob_batch(&self, positions: &Tensor<B, 2>) -> Tensor<B, 1>

Compute the unnormalized log probability for each position in a batch.

§Parameters
  • positions: A tensor of shape [n_chains, D] holding the current position of each chain, where D is the dimension of the sample space.
§Returns

A 1D tensor of shape [n_chains] containing the unnormalized log probability of each chain's position.
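
§Example (sketch)

The shape contract above can be illustrated without the burn crate. The following is a dependency-free sketch, not an implementation of this trait: it uses nested slices in place of Tensor<B, 2> and Tensor<B, 1>, and the function names and the choice of a standard normal target are assumptions made for illustration. A real implementor would express the same arithmetic with burn tensor operations so that the autodiff backend can differentiate it.

```rust
/// Unnormalized log density of a standard normal, evaluated once per chain.
/// `positions` has n_chains rows, each of length D; the result has n_chains entries.
/// (Hypothetical stand-in for `log_prob_batch`, using plain slices instead of tensors.)
fn log_prob_batch_std_normal(positions: &[Vec<f64>]) -> Vec<f64> {
    positions
        .iter()
        // log p(x) = -0.5 * sum_i x_i^2, dropping the normalizing constant
        .map(|x| -0.5 * x.iter().map(|v| v * v).sum::<f64>())
        .collect()
}

/// Analytic gradient of the log density above with respect to each position: -x.
/// With an autodiff backend this would come from differentiation rather than
/// being written by hand; it is included only as a reference value.
fn grad_log_prob_std_normal(positions: &[Vec<f64>]) -> Vec<Vec<f64>> {
    positions
        .iter()
        .map(|x| x.iter().map(|v| -v).collect())
        .collect()
}
```

Calling log_prob_batch_std_normal on two chains in two dimensions returns a vector of length two, one value per chain, which mirrors the [n_chains, D] to [n_chains] mapping described above.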

Implementors§