use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::Element;
use ndarray::{arr1, arr2, Array1, Array2, NdFloat};
use num_traits::Float;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Normal, StandardUniform};
use std::f64::consts::PI;
use std::ops::AddAssign;
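/// A differentiable target over a batch of positions: maps an
/// `[n_chains, dim]` tensor to a length-`n_chains` tensor of unnormalized
/// log-densities, one per chain.
///
/// A usage sketch (not compiled here; the backend choice and the `NdArray`
/// feature are assumptions):
///
/// ```ignore
/// type B = burn::backend::Autodiff<burn::backend::NdArray>;
/// let target = Rosenbrock2D { a: 1.0_f32, b: 100.0 };
/// // Two chains, each at a 2-D position.
/// let positions = Tensor::<B, 2>::from_floats([[0.0, 0.0], [1.0, 1.0]], &Default::default());
/// let logp = BatchedGradientTarget::<f32, B>::unnorm_logp_batch(&target, positions);
/// assert_eq!(logp.dims(), [2]);
/// ```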
pub trait BatchedGradientTarget<T: Float, B: AutodiffBackend> {
fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1>;
}
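/// A differentiable target at a single position. The provided
/// `unnorm_logp_and_grad` obtains the gradient through the backend's
/// autodiff: it detaches the position, marks it as requiring gradients,
/// runs a backward pass on the log-density, and lifts the gradient back
/// from the inner (non-autodiff) backend.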
pub trait GradientTarget<T: Float, B: AutodiffBackend> {
fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1>;
fn unnorm_logp_and_grad(&self, position: Tensor<B, 1>) -> (Tensor<B, 1>, Tensor<B, 1>) {
let pos = position.clone().detach().require_grad();
let ulogp = self.unnorm_logp(pos.clone());
let grad_inner = pos.grad(&ulogp.backward()).unwrap();
let grad = Tensor::<B, 1>::from_inner(grad_inner);
(ulogp, grad)
}
}
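/// A proposal kernel over flat `Vec<T>` states: `sample` draws a candidate
/// given the current state, and `logp` scores the transition density
/// `q(to | from)`, which Metropolis–Hastings needs in both directions for
/// asymmetric kernels. A minimal acceptance-step sketch (illustrative only,
/// not compiled; `target`, `proposal`, `current`, and `rng` are assumed to
/// exist with the obvious types):
///
/// ```ignore
/// let proposed = proposal.sample(&current);
/// let log_alpha = target.unnorm_logp(&proposed) - target.unnorm_logp(&current)
///     + proposal.logp(&proposed, &current)
///     - proposal.logp(&current, &proposed);
/// if rng.random::<f64>().ln() < log_alpha {
///     current = proposed; // accept
/// }
/// ```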
pub trait Proposal<T, F: Float> {
fn sample(&mut self, current: &[T]) -> Vec<T>;
fn logp(&self, from: &[T], to: &[T]) -> F;
fn set_seed(self, seed: u64) -> Self;
}
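/// A target distribution known only up to an additive constant in log-space.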
pub trait Target<T, F: Float> {
fn unnorm_logp(&self, position: &[T]) -> F;
}
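/// A target distribution whose log-density includes its normalizing constant.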
pub trait Normalized<T, F: Float> {
fn logp(&self, position: &[T]) -> F;
}
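/// A discrete distribution over indices `0..n`: `sample` draws an index and
/// `logp` scores one.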
pub trait Discrete<T: Float> {
fn sample(&mut self) -> usize;
fn logp(&self, index: usize) -> T;
}
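/// A bivariate Gaussian over `ndarray` types. The 2x2 covariance is inverted
/// in closed form on every evaluation; see `DiffableGaussian2D` for a
/// tensor-based variant that caches these quantities.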
#[derive(Debug, Clone, PartialEq)]
pub struct Gaussian2D<T: Float> {
pub mean: Array1<T>,
pub cov: Array2<T>,
}
impl<T> Normalized<T, T> for Gaussian2D<T>
where
T: NdFloat,
{
fn logp(&self, position: &[T]) -> T {
let term_1 = -(T::from(2.0).unwrap() * T::from(PI).unwrap()).ln();
let (a, b, c, d) = (
self.cov[(0, 0)],
self.cov[(0, 1)],
self.cov[(1, 0)],
self.cov[(1, 1)],
);
let det = a * d - b * c;
let half = T::from(0.5).unwrap();
let term_2 = -half * det.abs().ln();
let x = arr1(position);
let diff = x - self.mean.clone();
let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
let term_3 = -half * diff.dot(&inv_cov).dot(&diff);
term_1 + term_2 + term_3
}
}
impl<T> Target<T, T> for Gaussian2D<T>
where
T: NdFloat,
{
fn unnorm_logp(&self, position: &[T]) -> T {
let (a, b, c, d) = (
self.cov[(0, 0)],
self.cov[(0, 1)],
self.cov[(1, 0)],
self.cov[(1, 1)],
);
let det = a * d - b * c;
let x = arr1(position);
let diff = x - self.mean.clone();
let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
-T::from(0.5).unwrap() * diff.dot(&inv_cov).dot(&diff)
}
}
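/// A bivariate Gaussian for the tensor-based targets. `new` precomputes the
/// inverse covariance, `ln |Σ|`, and the normalizing constant
/// `-(2 ln(2π) + ln |Σ|) / 2`, so each evaluation only forms the quadratic
/// form `-(x - μ)ᵀ Σ⁻¹ (x - μ) / 2` plus that constant.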
#[derive(Debug, Clone)]
pub struct DiffableGaussian2D<T: Float> {
pub mean: [T; 2],
pub cov: [[T; 2]; 2],
pub inv_cov: [[T; 2]; 2],
pub logdet_cov: T,
pub norm_const: T,
}
impl<T> DiffableGaussian2D<T>
where
T: Float + std::fmt::Debug + num_traits::FloatConst,
{
pub fn new(mean: [T; 2], cov: [[T; 2]; 2]) -> Self {
let det_cov = cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0];
let inv_det = T::one() / det_cov;
let inv_cov = [
[cov[1][1] * inv_det, -cov[0][1] * inv_det],
[-cov[1][0] * inv_det, cov[0][0] * inv_det],
];
        let logdet_cov = det_cov.ln();
        let two = T::one() + T::one();
let norm_const = -(two * (two * T::PI()).ln() + logdet_cov) / two;
Self {
mean,
cov,
inv_cov,
logdet_cov,
norm_const,
}
}
}
impl<T, B> BatchedGradientTarget<T, B> for DiffableGaussian2D<T>
where
T: Float + burn::tensor::ElementConversion + std::fmt::Debug + burn::tensor::Element,
B: AutodiffBackend,
{
fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
let (n_chains, dim) = (positions.dims()[0], positions.dims()[1]);
        assert_eq!(dim, 2, "DiffableGaussian2D: expected dimension=2.");
let mean_tensor =
Tensor::<B, 2>::from_floats([[self.mean[0], self.mean[1]]], &B::Device::default())
.reshape([1, 2])
.expand([n_chains, 2]);
let delta = positions.clone() - mean_tensor;
let inv_cov_data = [
self.inv_cov[0][0],
self.inv_cov[0][1],
self.inv_cov[1][0],
self.inv_cov[1][1],
];
let inv_cov_t =
Tensor::<B, 2>::from_floats([inv_cov_data], &B::Device::default()).reshape([2, 2]);
        let z = delta.clone().matmul(inv_cov_t);
        let quad = (z * delta).sum_dim(1).squeeze(1);
        let shape = Shape::new([n_chains]);
let norm_c = Tensor::<B, 1>::ones(shape, &B::Device::default()).mul_scalar(self.norm_const);
let half = T::from(0.5).unwrap();
norm_c - quad.mul_scalar(half)
}
}
impl<T, B> GradientTarget<T, B> for DiffableGaussian2D<T>
where
T: Float + burn::tensor::ElementConversion + std::fmt::Debug + burn::tensor::Element,
B: AutodiffBackend,
{
fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1> {
let dim = position.dims()[0];
        assert_eq!(dim, 2, "DiffableGaussian2D: expected dimension=2.");
let mean_tensor =
Tensor::<B, 1>::from_floats([self.mean[0], self.mean[1]], &B::Device::default());
let delta = position.clone() - mean_tensor;
let inv_cov_data = [
[self.inv_cov[0][0], self.inv_cov[0][1]],
[self.inv_cov[1][0], self.inv_cov[1][1]],
];
let inv_cov_t = Tensor::<B, 2>::from_floats(inv_cov_data, &B::Device::default());
let z = delta.clone().reshape([1_i32, 2_i32]).matmul(inv_cov_t);
let quad = (z.reshape([2_i32]) * delta).sum();
let half = T::from(0.5).unwrap();
-quad.mul_scalar(half) + self.norm_const
}
}
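/// An isotropic zero-mean Gaussian, usable both as a random-walk proposal
/// (centered at the current state) and as a standalone target.
///
/// A usage sketch (not compiled here):
///
/// ```ignore
/// let mut prop = IsotropicGaussian::new(0.5_f64).set_seed(42);
/// let next = prop.sample(&[0.0, 0.0]);
/// assert_eq!(next.len(), 2);
/// ```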
#[derive(Debug, Clone)]
pub struct IsotropicGaussian<T: Float> {
pub std: T,
rng: SmallRng,
}
impl<T: Float> IsotropicGaussian<T> {
pub fn new(std: T) -> Self {
Self {
std,
rng: SmallRng::from_os_rng(),
}
}
}
impl<T: Float + std::ops::AddAssign> Proposal<T, T> for IsotropicGaussian<T>
where
rand_distr::StandardNormal: rand_distr::Distribution<T>,
{
fn sample(&mut self, current: &[T]) -> Vec<T> {
let normal = Normal::new(T::zero(), self.std)
.expect("Expecting creation of normal distribution to succeed.");
normal
.sample_iter(&mut self.rng)
.zip(current)
            .map(|(noise, &coord)| coord + noise)
.collect()
}
fn logp(&self, from: &[T], to: &[T]) -> T {
let mut lp = T::zero();
let d = T::from(from.len()).unwrap();
let two = T::from(2).unwrap();
let var = self.std * self.std;
for (&f, &t) in from.iter().zip(to.iter()) {
let diff = t - f;
let exponent = -(diff * diff) / (two * var);
lp += exponent;
}
        // Normalizer of the Gaussian transition density: -(d/2) * ln(2π σ²).
        lp += -d * T::from(0.5).unwrap() * (two * T::from(PI).unwrap() * var).ln();
lp
}
fn set_seed(mut self, seed: u64) -> Self {
self.rng = SmallRng::seed_from_u64(seed);
self
}
}
impl<T: Float> Target<T, T> for IsotropicGaussian<T> {
fn unnorm_logp(&self, position: &[T]) -> T {
let mut sum = T::zero();
for &x in position.iter() {
sum = sum + x * x
}
-T::from(0.5).unwrap() * sum / (self.std * self.std)
}
}
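/// A categorical distribution over `0..probs.len()`. `new` normalizes the
/// provided weights, so they need not sum to one.
///
/// A usage sketch (not compiled here):
///
/// ```ignore
/// let mut cat = Categorical::new(vec![1.0_f64, 2.0, 1.0]); // normalized to [0.25, 0.5, 0.25]
/// let idx = cat.sample();
/// let lp = cat.logp(idx);
/// ```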
#[derive(Debug, Clone)]
pub struct Categorical<T>
where
T: Float + std::ops::AddAssign,
{
pub probs: Vec<T>,
rng: SmallRng,
}
impl<T: Float + std::ops::AddAssign> Categorical<T> {
pub fn new(probs: Vec<T>) -> Self {
let sum: T = probs.iter().cloned().fold(T::zero(), |acc, x| acc + x);
let normalized: Vec<T> = probs.into_iter().map(|p| p / sum).collect();
Self {
probs: normalized,
rng: SmallRng::from_os_rng(),
}
}
}
impl<T: Float + std::ops::AddAssign> Discrete<T> for Categorical<T>
where
StandardUniform: rand::distr::Distribution<T>,
{
fn sample(&mut self) -> usize {
let r: T = self.rng.random();
let mut cum: T = T::zero();
let mut k = self.probs.len() - 1;
for (i, &p) in self.probs.iter().enumerate() {
cum += p;
if r <= cum {
k = i;
break;
}
}
k
}
fn logp(&self, index: usize) -> T {
if index < self.probs.len() {
self.probs[index].ln()
} else {
T::neg_infinity()
}
}
}
impl<T: Float + AddAssign> Target<usize, T> for Categorical<T>
where
rand_distr::StandardUniform: rand_distr::Distribution<T>,
{
fn unnorm_logp(&self, position: &[usize]) -> T {
<Self as Discrete<T>>::logp(self, position[0])
}
}
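/// A single-coordinate conditional distribution, e.g. for Gibbs-style
/// updates: draws a new value for coordinate `index` given the current
/// state `given`.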
pub trait Conditional<S> {
fn sample(&mut self, index: usize, given: &[S]) -> S;
}
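/// The 2-D Rosenbrock density in log-space,
/// `unnorm_logp(x, y) = -[(a - x)² + b (y - x²)²]`,
/// with its mode at `(a, a²)`.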
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct Rosenbrock2D<T: Float> {
pub a: T,
pub b: T,
}
impl<T, B> BatchedGradientTarget<T, B> for Rosenbrock2D<T>
where
T: Float + Element,
B: AutodiffBackend,
{
fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
let n = positions.dims()[0];
let x = positions.clone().slice([0..n, 0..1]);
let y = positions.slice([0..n, 1..2]);
let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);
let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);
-(term_1 + term_2).flatten(0, 1)
}
}
impl<T, B> GradientTarget<T, B> for Rosenbrock2D<T>
where
T: Float + Element,
B: AutodiffBackend,
{
fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1> {
        let x = position.clone().slice([0..1]);
        let y = position.slice([1..2]);
let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);
let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);
-(term_1 + term_2)
}
}
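/// The n-dimensional Rosenbrock density in log-space with the conventional
/// coefficients fixed at `a = 1` and `b = 100`:
/// `unnorm_logp(x) = -Σᵢ [100 (xᵢ₊₁ - xᵢ²)² + (1 - xᵢ)²]`.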
pub struct RosenbrockND;
impl<T, B> BatchedGradientTarget<T, B> for RosenbrockND
where
T: Float + Element,
B: AutodiffBackend,
{
fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
let k = positions.dims()[0];
let n = positions.dims()[1];
let low = positions.clone().slice([0..k, 0..(n - 1)]);
let high = positions.slice([0..k, 1..n]);
let term_1 = (high - low.clone().powi_scalar(2))
.powi_scalar(2)
.mul_scalar(100);
let term_2 = low.neg().add_scalar(1).powi_scalar(2);
-(term_1 + term_2).sum_dim(1).squeeze(1)
}
}
#[cfg(test)]
mod continuous_tests {
use super::*;
fn normalize_isogauss(x: f64, d: usize, std: f64) -> f64 {
let log_normalizer = -((d as f64) / 2.0) * ((2.0_f64).ln() + PI.ln() + 2.0 * std.ln());
(x + log_normalizer).exp()
}
#[test]
fn iso_gauss_unnorm_logp_test_1() {
let distr = IsotropicGaussian::new(1.0);
let p = normalize_isogauss(distr.unnorm_logp(&[1.0]), 1, distr.std);
let true_p = 0.24197072451914337;
let diff = (p - true_p).abs();
assert!(
diff < 1e-7,
"Expected diff < 1e-7, got {diff} with p={p} (expected ~{true_p})."
);
}
#[test]
fn iso_gauss_unnorm_logp_test_2() {
let distr = IsotropicGaussian::new(2.0);
let p = normalize_isogauss(distr.unnorm_logp(&[0.42, 9.6]), 2, distr.std);
let true_p = 3.864661987252467e-7;
let diff = (p - true_p).abs();
assert!(
diff < 1e-15,
"Expected diff < 1e-15, got {diff} with p={p} (expected ~{true_p})"
);
}
#[test]
fn iso_gauss_unnorm_logp_test_3() {
let distr = IsotropicGaussian::new(3.0);
let p = normalize_isogauss(distr.unnorm_logp(&[1.0, 2.0, 3.0]), 3, distr.std);
let true_p = 0.001080393185560214;
let diff = (p - true_p).abs();
assert!(
diff < 1e-8,
"Expected diff < 1e-8, got {diff} with p={p} (expected ~{true_p})"
);
}
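    // Sanity checks on the proposal side of `IsotropicGaussian`: the kernel
    // is symmetric, q(x -> y) = q(y -> x), and seeding makes it reproducible.
    #[test]
    fn iso_gauss_proposal_symmetry_and_seeding() {
        let prop = IsotropicGaussian::new(1.5_f64);
        let from = [0.1, -0.4, 2.0];
        let to = [1.0, 0.0, -0.3];
        let fwd = prop.logp(&from, &to);
        let bwd = prop.logp(&to, &from);
        assert!(
            (fwd - bwd).abs() < 1e-12,
            "Expected a symmetric kernel, got {fwd} vs {bwd}."
        );
        let mut a = IsotropicGaussian::new(1.5_f64).set_seed(7);
        let mut b = IsotropicGaussian::new(1.5_f64).set_seed(7);
        assert_eq!(a.sample(&from), b.sample(&from), "Same seed, same draw.");
    }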
}
#[cfg(test)]
mod categorical_tests {
use super::*;
fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
(a - b).abs() < tol
}
#[test]
fn test_categorical_logp_f64() {
let probs = vec![0.2, 0.3, 0.5];
let cat = Categorical::<f64>::new(probs.clone());
let logp_0 = cat.logp(0);
let logp_1 = cat.logp(1);
let logp_2 = cat.logp(2);
let expected_0 = 0.2_f64.ln();
let expected_1 = 0.3_f64.ln();
let expected_2 = 0.5_f64.ln();
let tol = 1e-7;
assert!(
approx_eq(logp_0, expected_0, tol),
"Log prob mismatch at index 0: got {}, expected {}",
logp_0,
expected_0
);
assert!(
approx_eq(logp_1, expected_1, tol),
"Log prob mismatch at index 1: got {}, expected {}",
logp_1,
expected_1
);
assert!(
approx_eq(logp_2, expected_2, tol),
"Log prob mismatch at index 2: got {}, expected {}",
logp_2,
expected_2
);
let logp_out = cat.logp(3);
assert_eq!(
logp_out,
f64::NEG_INFINITY,
"Out-of-bounds index did not return NEG_INFINITY"
);
}
#[test]
fn test_categorical_sampling_f64() {
let probs = vec![0.2, 0.3, 0.5];
let mut cat = Categorical::<f64>::new(probs.clone());
let sample_size = 100_000;
let mut counts = vec![0_usize; probs.len()];
for _ in 0..sample_size {
let observation = cat.sample();
counts[observation] += 1;
}
        let tol = 0.01;
        for (i, &count) in counts.iter().enumerate() {
let freq = count as f64 / sample_size as f64;
let expected = probs[i];
assert!(
approx_eq(freq, expected, tol),
"Empirical freq for index {} is off: got {:.3}, expected {:.3}",
i,
freq,
expected
);
}
}
#[test]
fn test_categorical_logp_f32() {
let probs = vec![0.1_f32, 0.4, 0.5];
let cat = Categorical::<f32>::new(probs.clone());
let logp_0: f32 = cat.logp(0);
let logp_1 = cat.logp(1);
let logp_2 = cat.logp(2);
let expected_0 = (0.1_f64).ln();
let expected_1 = (0.4_f64).ln();
let expected_2 = (0.5_f64).ln();
let tol = 1e-6;
assert!(
approx_eq(logp_0.into(), expected_0, tol),
"Log prob mismatch at index 0 (f32 -> f64 cast)"
);
assert!(
approx_eq(logp_1.into(), expected_1, tol),
"Log prob mismatch at index 1"
);
assert!(
approx_eq(logp_2.into(), expected_2, tol),
"Log prob mismatch at index 2"
);
let logp_out = cat.logp(3);
assert_eq!(logp_out, f32::NEG_INFINITY);
}
#[test]
fn test_categorical_sampling_f32() {
let probs = vec![0.1_f32, 0.4, 0.5];
let mut cat = Categorical::<f32>::new(probs.clone());
let sample_size = 100_000;
let mut counts = vec![0_usize; probs.len()];
for _ in 0..sample_size {
let observation = cat.sample();
counts[observation] += 1;
}
        let tol = 0.02;
        for (i, &count) in counts.iter().enumerate() {
let freq = count as f32 / sample_size as f32;
let expected = probs[i];
assert!(
(freq - expected).abs() < tol,
"Empirical freq for index {} is off: got {:.3}, expected {:.3}",
i,
freq,
expected
);
}
}
#[test]
fn test_categorical_sample_single_value() {
let mut cat = Categorical {
probs: vec![1.0_f64],
rng: rand::rngs::SmallRng::from_seed(Default::default()),
};
let sampled_index = cat.sample();
assert_eq!(
sampled_index, 0,
"Should return the last index (0) for a single-element vector"
);
}
#[test]
fn test_target_for_categorical_in_range() {
let probs = vec![0.2_f64, 0.3, 0.5];
let cat = Categorical::new(probs.clone());
let logp = cat.unnorm_logp(&[1]);
let expected = 0.3_f64.ln();
let tol = 1e-7;
assert!(
(logp - expected).abs() < tol,
"For index 1, expected ln(0.3) ~ {}, got {}",
expected,
logp
);
}
#[test]
fn test_target_for_categorical_out_of_range() {
let probs = vec![0.2_f64, 0.3, 0.5];
let cat = Categorical::new(probs);
let logp = cat.unnorm_logp(&[3]);
assert_eq!(
logp,
f64::NEG_INFINITY,
"Expected negative infinity for out-of-range index, got {}",
logp
);
}
#[test]
fn test_gaussian2d_logp() {
let mean = arr1(&[0.0, 0.0]);
let cov = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
let gauss = Gaussian2D { mean, cov };
let position = vec![0.5, -0.5];
let computed_logp = gauss.logp(&position);
let expected_logp = -2.0878770664093453;
let tol = 1e-10;
assert!(
(computed_logp - expected_logp).abs() < tol,
"Computed log density ({}) differs from expected ({}) by more than tolerance ({})",
computed_logp,
expected_logp,
tol
);
}
}