use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::Element;
use ndarray::{arr1, arr2, Array1, Array2, NdFloat};
use num_traits::Float;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Normal, StandardUniform};
use std::f64::consts::PI;
use std::ops::AddAssign;

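/// A target distribution whose unnormalized log-density can be evaluated for a
/// whole batch of positions at once, one row per chain.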
pub trait BatchedGradientTarget<T: Float, B: AutodiffBackend> {
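    /// Computes the unnormalized log-density for each row of `positions`
    /// (shape `[n_chains, dim]`), returning one value per chain.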
    fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1>;
}

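/// A target distribution whose unnormalized log-density is differentiable with
/// respect to a single position via the backend's autodiff.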
pub trait GradientTarget<T: Float, B: AutodiffBackend> {
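    /// Computes the unnormalized log-density at `position`, returned as a
    /// one-element tensor.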
    fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1>;

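    /// Computes the unnormalized log-density together with its gradient at
    /// `position`, using reverse-mode autodiff on the backend.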
    fn unnorm_logp_and_grad(&self, position: Tensor<B, 1>) -> (Tensor<B, 1>, Tensor<B, 1>) {
        // Mark the position as requiring gradients, evaluate the log-density,
        // then read the gradient back out of the backward pass.
        let pos = position.clone().detach().require_grad();
        let ulogp = self.unnorm_logp(pos.clone());
        let grad_inner = pos.grad(&ulogp.backward()).unwrap();
        let grad = Tensor::<B, 1>::from_inner(grad_inner);
        (ulogp, grad)
    }
}

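/// A proposal distribution for Metropolis-style samplers: it can generate a
/// candidate state from the current one and evaluate the proposal log-density.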
pub trait Proposal<T, F: Float> {
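    /// Draws a proposed state conditioned on the `current` state.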
    fn sample(&mut self, current: &[T]) -> Vec<T>;

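    /// Evaluates the log-density of proposing `to` when starting from `from`.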
    fn logp(&self, from: &[T], to: &[T]) -> F;

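    /// Reseeds the proposal's internal RNG, returning the reseeded proposal.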
    fn set_seed(self, seed: u64) -> Self;
}

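/// A target distribution known only up to an additive constant in log-space.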
pub trait Target<T, F: Float> {
    /// Computes the unnormalized log-density at `position`.
    fn unnorm_logp(&self, position: &[T]) -> F;
}

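/// A target distribution with a tractable, fully normalized log-density.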
pub trait Normalized<T, F: Float> {
    /// Computes the normalized log-density at `position`.
    fn logp(&self, position: &[T]) -> F;
}

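/// A discrete distribution over indices `0..n` that supports sampling and
/// log-probability evaluation.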
pub trait Discrete<T: Float> {
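    /// Draws an index according to the distribution's probabilities.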
    fn sample(&mut self) -> usize;

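    /// Returns the log-probability of `index`, or negative infinity when the
    /// index is out of range.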
    fn logp(&self, index: usize) -> T;
}

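/// A 2D Gaussian distribution parameterized by a mean vector and a 2x2
/// covariance matrix, backed by `ndarray`.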
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Gaussian2D<T: Float> {
    pub mean: Array1<T>,
    pub cov: Array2<T>,
}

impl<T> Normalized<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    fn logp(&self, position: &[T]) -> T {
        // -(d / 2) * ln(2 * pi) with d = 2.
        let term_1 = -(T::from(2.0).unwrap() * T::from(PI).unwrap()).ln();
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let half = T::from(0.5).unwrap();
        let term_2 = -half * det.abs().ln();

        // Quadratic form, inverting the 2x2 covariance via its adjugate.
        let x = arr1(position);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        let term_3 = -half * diff.dot(&inv_cov).dot(&diff);
        term_1 + term_2 + term_3
    }
}

impl<T> Target<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    fn unnorm_logp(&self, position: &[T]) -> T {
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let x = arr1(position);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        -T::from(0.5).unwrap() * diff.dot(&inv_cov).dot(&diff)
    }
}

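/// A 2D Gaussian for gradient-based samplers. The inverse covariance,
/// log-determinant, and Gaussian normalizing constant are precomputed in
/// [`DiffableGaussian2D::new`], so every log-density evaluation reduces to a
/// quadratic form.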
#[derive(Debug, Clone)]
pub struct DiffableGaussian2D<T: Float> {
    pub mean: [T; 2],
    pub cov: [[T; 2]; 2],
    pub inv_cov: [[T; 2]; 2],
    pub logdet_cov: T,
    pub norm_const: T,
}

impl<T> DiffableGaussian2D<T>
where
    T: Float + std::fmt::Debug + num_traits::FloatConst,
{
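    /// Builds the distribution from `mean` and `cov`, precomputing the inverse
    /// covariance (via the 2x2 adjugate), `ln|cov|`, and the normalizing
    /// constant `-(2 * ln(2 * pi) + ln|cov|) / 2`.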
    pub fn new(mean: [T; 2], cov: [[T; 2]; 2]) -> Self {
        let det_cov = cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0];
        let inv_det = T::one() / det_cov;
        // Invert the 2x2 covariance via its adjugate.
        let inv_cov = [
            [cov[1][1] * inv_det, -cov[0][1] * inv_det],
            [-cov[1][0] * inv_det, cov[0][0] * inv_det],
        ];
        let logdet_cov = det_cov.ln();
        let two = T::one() + T::one();
        let norm_const = -(two * (two * T::PI()).ln() + logdet_cov) / two;

        Self {
            mean,
            cov,
            inv_cov,
            logdet_cov,
            norm_const,
        }
    }
}

impl<T, B> BatchedGradientTarget<T, B> for DiffableGaussian2D<T>
where
    T: Float + burn::tensor::ElementConversion + std::fmt::Debug + burn::tensor::Element,
    B: AutodiffBackend,
{
    fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
        let (n_chains, dim) = (positions.dims()[0], positions.dims()[1]);
        assert_eq!(dim, 2, "Gaussian2D: expected dimension=2.");

        let mean_tensor =
            Tensor::<B, 2>::from_floats([[self.mean[0], self.mean[1]]], &B::Device::default())
                .reshape([1, 2])
                .expand([n_chains, 2]);

        let delta = positions.clone() - mean_tensor;

        let inv_cov_data = [
            self.inv_cov[0][0],
            self.inv_cov[0][1],
            self.inv_cov[1][0],
            self.inv_cov[1][1],
        ];
        let inv_cov_t =
            Tensor::<B, 2>::from_floats([inv_cov_data], &B::Device::default()).reshape([2, 2]);

        // Per-chain quadratic form: (x - mean)^T inv_cov (x - mean).
        let z = delta.clone().matmul(inv_cov_t);
        let quad = (z * delta).sum_dim(1).squeeze(1);
        let shape = Shape::new([n_chains]);
        let norm_c = Tensor::<B, 1>::ones(shape, &B::Device::default()).mul_scalar(self.norm_const);
        let half = T::from(0.5).unwrap();
        norm_c - quad.mul_scalar(half)
    }
}

impl<T, B> GradientTarget<T, B> for DiffableGaussian2D<T>
where
    T: Float + burn::tensor::ElementConversion + std::fmt::Debug + burn::tensor::Element,
    B: AutodiffBackend,
{
    fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1> {
        let dim = position.dims()[0];
        assert_eq!(dim, 2, "Gaussian2D: expected dimension=2.");

        let mean_tensor =
            Tensor::<B, 1>::from_floats([self.mean[0], self.mean[1]], &B::Device::default());

        let delta = position.clone() - mean_tensor;

        let inv_cov_data = [
            [self.inv_cov[0][0], self.inv_cov[0][1]],
            [self.inv_cov[1][0], self.inv_cov[1][1]],
        ];
        let inv_cov_t = Tensor::<B, 2>::from_floats(inv_cov_data, &B::Device::default());

        let z = delta.clone().reshape([1_i32, 2_i32]).matmul(inv_cov_t);
        let quad = (z.reshape([2_i32]) * delta).sum();
        let half = T::from(0.5).unwrap();
        -quad.mul_scalar(half) + self.norm_const
    }
}

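// A minimal usage sketch for the gradient-based targets above (assuming burn's
// `ndarray` and `autodiff` features are enabled, so `burn::backend::{Autodiff,
// NdArray}` are available; variable names here are illustrative only):
//
//     type B = Autodiff<NdArray>;
//     let target = DiffableGaussian2D::new([0.0_f64, 0.0], [[1.0, 0.0], [0.0, 1.0]]);
//     let pos = Tensor::<B, 1>::from_floats([0.5, -0.5], &Default::default());
//     let (ulogp, grad) = GradientTarget::<f64, B>::unnorm_logp_and_grad(&target, pos);

/// An isotropic (spherical) Gaussian with a shared standard deviation per
/// coordinate. It doubles as a random-walk proposal and as an unnormalized
/// target centered at the origin.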
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IsotropicGaussian<T: Float> {
    pub std: T,
    rng: SmallRng,
}

impl<T: Float> IsotropicGaussian<T> {
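    /// Creates an isotropic Gaussian with standard deviation `std`, seeding
    /// the internal RNG from the operating system.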
    pub fn new(std: T) -> Self {
        Self {
            std,
            rng: SmallRng::from_os_rng(),
        }
    }
}

impl<T: Float + AddAssign> Proposal<T, T> for IsotropicGaussian<T>
where
    rand_distr::StandardNormal: Distribution<T>,
{
    fn sample(&mut self, current: &[T]) -> Vec<T> {
        // Perturb every coordinate of the current state with zero-mean
        // Gaussian noise of standard deviation `self.std`.
        let normal = Normal::new(T::zero(), self.std)
            .expect("Expecting creation of normal distribution to succeed.");
        normal
            .sample_iter(&mut self.rng)
            .zip(current)
            .map(|(noise, cur)| noise + *cur)
            .collect()
    }

    fn logp(&self, from: &[T], to: &[T]) -> T {
        let mut lp = T::zero();
        let d = T::from(from.len()).unwrap();
        let two = T::from(2).unwrap();
        let var = self.std * self.std;
        for (&f, &t) in from.iter().zip(to.iter()) {
            let diff = t - f;
            let exponent = -(diff * diff) / (two * var);
            lp += exponent;
        }
        // Normalizing constant of a d-dimensional isotropic Gaussian:
        // -(d / 2) * ln(2 * pi * sigma^2).
        lp += -d * T::from(0.5).unwrap() * (two * T::from(PI).unwrap() * var).ln();
        lp
    }

    fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }
}

impl<T: Float> Target<T, T> for IsotropicGaussian<T> {
    fn unnorm_logp(&self, position: &[T]) -> T {
        let mut sum = T::zero();
        for &x in position.iter() {
            sum = sum + x * x;
        }
        -T::from(0.5).unwrap() * sum / (self.std * self.std)
    }
}

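/// A categorical distribution over `0..probs.len()`, sampled by inverse-CDF
/// lookup against a uniform draw.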
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Categorical<T>
where
    T: Float + AddAssign,
{
    pub probs: Vec<T>,
    rng: SmallRng,
}

impl<T: Float + AddAssign> Categorical<T> {
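    /// Creates a categorical distribution from (possibly unnormalized)
    /// weights, rescaling them so they sum to one.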
    pub fn new(probs: Vec<T>) -> Self {
        // Rescale the weights so they sum to one.
        let sum: T = probs.iter().cloned().fold(T::zero(), |acc, x| acc + x);
        let normalized: Vec<T> = probs.into_iter().map(|p| p / sum).collect();
        Self {
            probs: normalized,
            rng: SmallRng::from_os_rng(),
        }
    }
}

impl<T: Float + AddAssign> Discrete<T> for Categorical<T>
where
    StandardUniform: Distribution<T>,
{
    fn sample(&mut self) -> usize {
        // Inverse-CDF sampling: draw u ~ Uniform(0, 1) and return the first
        // index whose cumulative probability reaches u.
        let r: T = self.rng.random();
        let mut cum: T = T::zero();
        let mut k = self.probs.len() - 1;
        for (i, &p) in self.probs.iter().enumerate() {
            cum += p;
            if r <= cum {
                k = i;
                break;
            }
        }
        k
    }

    fn logp(&self, index: usize) -> T {
        if index < self.probs.len() {
            self.probs[index].ln()
        } else {
            T::neg_infinity()
        }
    }
}

impl<T: Float + AddAssign> Target<usize, T> for Categorical<T>
where
    StandardUniform: Distribution<T>,
{
    fn unnorm_logp(&self, position: &[usize]) -> T {
        <Self as Discrete<T>>::logp(self, position[0])
    }
}

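/// A conditional distribution for Gibbs-style updates: samples a new value for
/// coordinate `index` given the remaining coordinates.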
pub trait Conditional<S> {
    fn sample(&mut self, index: usize, given: &[S]) -> S;
}

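/// The 2D Rosenbrock "banana" target with unnormalized log-density
/// `-((a - x)^2 + b * (y - x^2)^2)`, a standard stress test for samplers.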
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Rosenbrock2D<T: Float> {
    pub a: T,
    pub b: T,
}

impl<T, B> BatchedGradientTarget<T, B> for Rosenbrock2D<T>
where
    T: Float + Element,
    B: AutodiffBackend,
{
    fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
        let n = positions.dims()[0];
        let x = positions.clone().slice([0..n, 0..1]);
        let y = positions.slice([0..n, 1..2]);
        // (a - x)^2 and b * (y - x^2)^2, evaluated per chain.
        let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);
        let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);
        -(term_1 + term_2).flatten(0, 1)
    }
}

impl<T, B> GradientTarget<T, B> for Rosenbrock2D<T>
where
    T: Float + Element,
    B: AutodiffBackend,
{
    fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1> {
        let x = position.clone().slice([0..1]);
        let y = position.slice([1..2]);
        let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);
        let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);
        -(term_1 + term_2)
    }
}

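/// An n-dimensional Rosenbrock target with unnormalized log-density
/// `-sum_i [100 * (x_{i+1} - x_i^2)^2 + (1 - x_i)^2]` over consecutive
/// coordinate pairs.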
pub struct RosenbrockND {}

impl<T, B> BatchedGradientTarget<T, B> for RosenbrockND
where
    T: Float + Element,
    B: AutodiffBackend,
{
    fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
        let k = positions.dims()[0];
        let n = positions.dims()[1];
        // Pair consecutive coordinates: `low` holds x_0..x_{n-2} and `high`
        // holds x_1..x_{n-1} for each of the k chains.
        let low = positions.clone().slice([0..k, 0..(n - 1)]);
        let high = positions.slice([0..k, 1..n]);
        let term_1 = (high - low.clone().powi_scalar(2))
            .powi_scalar(2)
            .mul_scalar(100);
        let term_2 = low.neg().add_scalar(1).powi_scalar(2);
        -(term_1 + term_2).sum_dim(1).squeeze(1)
    }
}

#[cfg(test)]
mod continuous_tests {
    use super::*;

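    /// Converts an unnormalized isotropic-Gaussian log-density `x` into a
    /// normalized density by adding the log-normalizer
    /// `-(d / 2) * ln(2 * pi * std^2)` and exponentiating.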
    fn normalize_isogauss(x: f64, d: usize, std: f64) -> f64 {
        let log_normalizer = -((d as f64) / 2.0) * ((2.0_f64).ln() + PI.ln() + 2.0 * std.ln());
        (x + log_normalizer).exp()
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_1() {
        let distr = IsotropicGaussian::new(1.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[1.0]), 1, distr.std);
        let true_p = 0.24197072451914337;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-7,
            "Expected diff < 1e-7, got {diff} with p={p} (expected ~{true_p})."
        );
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_2() {
        let distr = IsotropicGaussian::new(2.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[0.42, 9.6]), 2, distr.std);
        let true_p = 3.864661987252467e-7;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-15,
            "Expected diff < 1e-15, got {diff} with p={p} (expected ~{true_p})"
        );
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_3() {
        let distr = IsotropicGaussian::new(3.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[1.0, 2.0, 3.0]), 3, distr.std);
        let true_p = 0.001080393185560214;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-8,
            "Expected diff < 1e-8, got {diff} with p={p} (expected ~{true_p})"
        );
    }
}

#[cfg(test)]
mod categorical_tests {
    use super::*;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_categorical_logp_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let cat = Categorical::<f64>::new(probs.clone());

        let logp_0 = cat.logp(0);
        let logp_1 = cat.logp(1);
        let logp_2 = cat.logp(2);

        let expected_0 = 0.2_f64.ln();
        let expected_1 = 0.3_f64.ln();
        let expected_2 = 0.5_f64.ln();

        let tol = 1e-7;
        assert!(
            approx_eq(logp_0, expected_0, tol),
            "Log prob mismatch at index 0: got {}, expected {}",
            logp_0,
            expected_0
        );
        assert!(
            approx_eq(logp_1, expected_1, tol),
            "Log prob mismatch at index 1: got {}, expected {}",
            logp_1,
            expected_1
        );
        assert!(
            approx_eq(logp_2, expected_2, tol),
            "Log prob mismatch at index 2: got {}, expected {}",
            logp_2,
            expected_2
        );

        let logp_out = cat.logp(3);
        assert_eq!(
            logp_out,
            f64::NEG_INFINITY,
            "Out-of-bounds index did not return NEG_INFINITY"
        );
    }

    #[test]
    fn test_categorical_sampling_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let mut cat = Categorical::<f64>::new(probs.clone());

        let sample_size = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        for _ in 0..sample_size {
            let observation = cat.sample();
            counts[observation] += 1;
        }

        // Loose tolerance for empirical frequencies over 100k draws.
        let tol = 0.01;
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f64 / sample_size as f64;
            let expected = probs[i];
            assert!(
                approx_eq(freq, expected, tol),
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    #[test]
    fn test_categorical_logp_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let cat = Categorical::<f32>::new(probs.clone());

        let logp_0: f32 = cat.logp(0);
        let logp_1 = cat.logp(1);
        let logp_2 = cat.logp(2);

        let expected_0 = (0.1_f64).ln();
        let expected_1 = (0.4_f64).ln();
        let expected_2 = (0.5_f64).ln();

        let tol = 1e-6;
        assert!(
            approx_eq(logp_0.into(), expected_0, tol),
            "Log prob mismatch at index 0 (f32 -> f64 cast)"
        );
        assert!(
            approx_eq(logp_1.into(), expected_1, tol),
            "Log prob mismatch at index 1"
        );
        assert!(
            approx_eq(logp_2.into(), expected_2, tol),
            "Log prob mismatch at index 2"
        );

        let logp_out = cat.logp(3);
        assert_eq!(logp_out, f32::NEG_INFINITY);
    }

    #[test]
    fn test_categorical_sampling_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let mut cat = Categorical::<f32>::new(probs.clone());

        let sample_size = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        for _ in 0..sample_size {
            let observation = cat.sample();
            counts[observation] += 1;
        }

        // Slightly looser tolerance than the f64 test to allow for f32 rounding.
        let tol = 0.02;
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f32 / sample_size as f32;
            let expected = probs[i];
            assert!(
                (freq - expected).abs() < tol,
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    #[test]
    fn test_categorical_sample_single_value() {
        let mut cat = Categorical {
            probs: vec![1.0_f64],
            rng: rand::rngs::SmallRng::from_seed(Default::default()),
        };

        let sampled_index = cat.sample();

        assert_eq!(
            sampled_index, 0,
            "Should return the only index (0) for a single-element vector"
        );
    }

    #[test]
    fn test_target_for_categorical_in_range() {
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs.clone());
        let logp = cat.unnorm_logp(&[1]);
        let expected = 0.3_f64.ln();
        let tol = 1e-7;
        assert!(
            (logp - expected).abs() < tol,
            "For index 1, expected ln(0.3) ~ {}, got {}",
            expected,
            logp
        );
    }

    #[test]
    fn test_target_for_categorical_out_of_range() {
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs);
        let logp = cat.unnorm_logp(&[3]);
        assert_eq!(
            logp,
            f64::NEG_INFINITY,
            "Expected negative infinity for out-of-range index, got {}",
            logp
        );
    }

    #[test]
    fn test_gaussian2d_logp() {
        let mean = arr1(&[0.0, 0.0]);
        let cov = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
        let gauss = Gaussian2D { mean, cov };

        let position = vec![0.5, -0.5];
        let computed_logp = gauss.logp(&position);

        // Standard normal in 2D: log p = -ln(2 * pi) - 0.5 * (0.25 + 0.25).
        let expected_logp = -2.0878770664093453;

        let tol = 1e-10;
        assert!(
            (computed_logp - expected_logp).abs() < tol,
            "Computed log density ({}) differs from expected ({}) by more than tolerance ({})",
            computed_logp,
            expected_logp,
            tol
        );
    }
}