use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use ndarray::{arr1, arr2, Array1, Array2, NdFloat};
use num_traits::Float;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Normal};
use std::f64::consts::PI;
use std::ops::AddAssign;

pub trait GradientTarget<T: Float, B: AutodiffBackend> {
    fn unnorm_logp(&self, positions: Tensor<B, 2>) -> Tensor<B, 1>;
}
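
// The traits below define the rest of the sampling API: `Proposal` draws and
// scores candidate moves, `Target` and `Normalized` expose unnormalized and
// normalized log-densities, and `Discrete` covers integer-supported
// distributions.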

pub trait Proposal<T, F: Float> {
    fn sample(&mut self, current: &[T]) -> Vec<T>;

    fn logp(&self, from: &[T], to: &[T]) -> F;

    fn set_seed(self, seed: u64) -> Self;
}

pub trait Target<T, F: Float> {
    fn unnorm_logp(&self, theta: &[T]) -> F;
}

pub trait Normalized<T, F: Float> {
    fn logp(&self, theta: &[T]) -> F;
}

pub trait Discrete<T: Float> {
    fn sample(&mut self) -> usize;

    fn logp(&self, index: usize) -> T;
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Gaussian2D<T: Float> {
    pub mean: Array1<T>,
    pub cov: Array2<T>,
}

impl<T> Normalized<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    fn logp(&self, theta: &[T]) -> T {
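        // Closed-form bivariate normal density:
        // log N(x; mu, Sigma) = -ln(2 pi) - 0.5 ln|Sigma| - 0.5 (x - mu)^T Sigma^{-1} (x - mu).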
        let term_1 = -(T::from(2.0).unwrap() * T::from(PI).unwrap()).ln();
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let half = T::from(0.5).unwrap();
        let term_2 = -half * det.abs().ln();

        let x = arr1(theta);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        let term_3 = -half * diff.dot(&inv_cov).dot(&diff);
        term_1 + term_2 + term_3
    }
}

impl<T> Target<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    fn unnorm_logp(&self, theta: &[T]) -> T {
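        // Same quadratic form as `Normalized::logp`, with the constant terms dropped.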
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let x = arr1(theta);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        -T::from(0.5).unwrap() * diff.dot(&inv_cov).dot(&diff)
    }
}

#[derive(Debug, Clone)]
pub struct DiffableGaussian2D<T: Float> {
    pub mean: [T; 2],
    pub cov: [[T; 2]; 2],
    pub inv_cov: [[T; 2]; 2],
    pub logdet_cov: T,
    pub norm_const: T,
}

impl<T> DiffableGaussian2D<T>
where
    T: Float + std::fmt::Debug + num_traits::FloatConst,
{
    pub fn new(mean: [T; 2], cov: [[T; 2]; 2]) -> Self {
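        // Precompute the 2x2 inverse via the adjugate; `norm_const` caches
        // -(d/2) ln(2 pi) - (1/2) ln|Sigma| with d = 2.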
        let det_cov = cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0];
        let inv_det = T::one() / det_cov;
        let inv_cov = [
            [cov[1][1] * inv_det, -cov[0][1] * inv_det],
            [-cov[1][0] * inv_det, cov[0][0] * inv_det],
        ];
        let logdet_cov = det_cov.ln();
        let two = T::one() + T::one();
        let norm_const = -(two * (two * T::PI()).ln() + logdet_cov) / two;

        Self {
            mean,
            cov,
            inv_cov,
            logdet_cov,
            norm_const,
        }
    }
}

impl<T, B> GradientTarget<T, B> for DiffableGaussian2D<T>
where
    T: Float + burn::tensor::ElementConversion + std::fmt::Debug + burn::tensor::Element,
    B: AutodiffBackend,
{
    fn unnorm_logp(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
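        // Batched evaluation: `positions` is [n_chains, 2] and the result holds
        // one log-density per chain.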
        let (n_chains, dim) = (positions.dims()[0], positions.dims()[1]);
        assert_eq!(dim, 2, "Gaussian2D: expected dimension=2.");

        let mean_tensor =
            Tensor::<B, 2>::from_floats([[self.mean[0], self.mean[1]]], &B::Device::default())
                .reshape([1, 2])
                .expand([n_chains, 2]);

        let delta = positions.clone() - mean_tensor;

        let inv_cov_data = [
            self.inv_cov[0][0],
            self.inv_cov[0][1],
            self.inv_cov[1][0],
            self.inv_cov[1][1],
        ];
        let inv_cov_t =
            Tensor::<B, 2>::from_floats([inv_cov_data], &B::Device::default()).reshape([2, 2]);

        let z = delta.clone().matmul(inv_cov_t);
        let quad = (z * delta).sum_dim(1).squeeze(1);
        let shape = Shape::new([n_chains]);
        let norm_c = Tensor::<B, 1>::ones(shape, &B::Device::default()).mul_scalar(self.norm_const);
        let half = T::from(0.5).unwrap();
        norm_c - quad.mul_scalar(half)
    }
}
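
// A minimal sketch of recovering the score (the gradient of the log-density)
// from `GradientTarget` through burn's autodiff. It assumes the `Autodiff`
// decorator over the `NdArray` backend is available (burn features `autodiff`
// and `ndarray`); backend choice and data-extraction calls may need adjusting
// to your burn version.
#[cfg(test)]
mod gradient_target_example {
    use super::*;
    use burn::backend::{Autodiff, NdArray};

    type B = Autodiff<NdArray>;

    #[test]
    fn score_of_standard_gaussian() {
        // Standard bivariate normal: the score at x is simply -x.
        let target = DiffableGaussian2D::new([0.0_f32, 0.0], [[1.0, 0.0], [0.0, 1.0]]);
        let pos = Tensor::<B, 2>::from_floats([[1.0, 2.0]], &Default::default()).require_grad();
        let logp = GradientTarget::<f32, B>::unnorm_logp(&target, pos.clone());
        let grads = logp.backward();
        let grad = pos.grad(&grads).expect("positions should have a gradient");

        // Expect the gradient (-1, -2) up to float tolerance.
        let expected = Tensor::<NdArray, 2>::from_floats([[-1.0, -2.0]], &Default::default());
        let max_err = (grad - expected).abs().max().into_scalar();
        assert!(max_err < 1e-5, "score mismatch: max abs error {max_err}");
    }
}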

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IsotropicGaussian<T: Float> {
    pub std: T,
    rng: SmallRng,
}

impl<T: Float> IsotropicGaussian<T> {
    pub fn new(std: T) -> Self {
        Self {
            std,
            rng: SmallRng::from_entropy(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Proposal<T, T> for IsotropicGaussian<T>
where
    rand_distr::StandardNormal: rand_distr::Distribution<T>,
{
    fn sample(&mut self, current: &[T]) -> Vec<T> {
        let normal = Normal::new(T::zero(), self.std)
            .expect("Expecting creation of normal distribution to succeed.");
        normal
            .sample_iter(&mut self.rng)
            .zip(current)
            .map(|(noise, &cur)| cur + noise)
            .collect()
    }

    fn logp(&self, from: &[T], to: &[T]) -> T {
        let mut lp = T::zero();
        let d = T::from(from.len()).unwrap();
        let two = T::from(2).unwrap();
        let var = self.std * self.std;
        for (&f, &t) in from.iter().zip(to.iter()) {
            let diff = t - f;
            let exponent = -(diff * diff) / (two * var);
            lp += exponent;
        }
        // Normalizer: each coordinate contributes -0.5 * ln(2 * pi * sigma^2).
        lp += -d * T::from(0.5).unwrap() * (two * T::from(PI).unwrap() * var).ln();
        lp
    }

    fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }
}
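
// A minimal sketch of one random-walk Metropolis-Hastings step wired through
// the `Proposal` and `Target` traits above. The helper and its names are
// illustrative, not part of this module's API; only the acceptance rule is
// standard.
#[cfg(test)]
mod mh_step_example {
    use super::*;

    fn mh_step<P, D>(
        proposal: &mut P,
        target: &D,
        current: Vec<f64>,
        rng: &mut SmallRng,
    ) -> Vec<f64>
    where
        P: Proposal<f64, f64>,
        D: Target<f64, f64>,
    {
        let candidate = proposal.sample(&current);
        // log alpha = log p(cand) - log p(cur) + log q(cur | cand) - log q(cand | cur);
        // the proposal terms cancel for symmetric proposals but are kept for generality.
        let log_alpha = target.unnorm_logp(&candidate) - target.unnorm_logp(&current)
            + proposal.logp(&candidate, &current)
            - proposal.logp(&current, &candidate);
        if rng.gen::<f64>().ln() < log_alpha {
            candidate
        } else {
            current
        }
    }

    #[test]
    fn one_step_runs() {
        let mut proposal = IsotropicGaussian::new(0.5).set_seed(42);
        let target = IsotropicGaussian::new(1.0);
        let mut rng = SmallRng::seed_from_u64(7);
        let next = mh_step(&mut proposal, &target, vec![0.0, 0.0], &mut rng);
        assert_eq!(next.len(), 2);
    }
}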

impl<T: Float> Target<T, T> for IsotropicGaussian<T> {
    fn unnorm_logp(&self, theta: &[T]) -> T {
        let mut sum = T::zero();
        for &x in theta.iter() {
            sum = sum + x * x;
        }
        -T::from(0.5).unwrap() * sum / (self.std * self.std)
    }
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Categorical<T>
where
    T: Float + std::ops::AddAssign,
{
    pub probs: Vec<T>,
    rng: SmallRng,
}

impl<T: Float + std::ops::AddAssign> Categorical<T> {
    pub fn new(probs: Vec<T>) -> Self {
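        // Inputs need not sum to one; e.g. vec![2.0, 3.0, 5.0] is stored as
        // probs [0.2, 0.3, 0.5] after the normalization below.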
        let sum: T = probs.iter().cloned().fold(T::zero(), |acc, x| acc + x);
        let normalized: Vec<T> = probs.into_iter().map(|p| p / sum).collect();
        Self {
            probs: normalized,
            rng: SmallRng::from_entropy(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Discrete<T> for Categorical<T>
where
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    fn sample(&mut self) -> usize {
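        // Inverse-CDF sampling: draw u ~ U(0, 1) and return the first index whose
        // cumulative probability reaches u; fall back to the last index so
        // floating-point round-off cannot push us out of bounds.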
        let r: T = self.rng.gen();
        let mut cum: T = T::zero();
        let mut k = self.probs.len() - 1;
        for (i, &p) in self.probs.iter().enumerate() {
            cum += p;
            if r <= cum {
                k = i;
                break;
            }
        }
        k
    }

    fn logp(&self, index: usize) -> T {
        if index < self.probs.len() {
            self.probs[index].ln()
        } else {
            T::neg_infinity()
        }
    }
}

impl<T: Float + AddAssign> Target<usize, T> for Categorical<T>
where
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    fn unnorm_logp(&self, theta: &[usize]) -> T {
        <Self as Discrete<T>>::logp(self, theta[0])
    }
}

pub trait Conditional<S> {
    fn sample(&mut self, index: usize, given: &[S]) -> S;
}
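
// A minimal sketch of a `Conditional` implementation for Gibbs-style updates
// on a bivariate normal with unit variances and correlation `rho`, using the
// textbook conditional X_i | X_j = x ~ N(rho * x, 1 - rho^2). The struct and
// its fields are illustrative, not part of this module's API.
#[cfg(test)]
mod conditional_example {
    use super::*;

    struct BivariateNormalConditional {
        rho: f64,
        rng: SmallRng,
    }

    impl Conditional<f64> for BivariateNormalConditional {
        fn sample(&mut self, index: usize, given: &[f64]) -> f64 {
            // For a 2-D state, condition on the other coordinate.
            let other = given[1 - index];
            let sd = (1.0 - self.rho * self.rho).sqrt();
            Normal::new(self.rho * other, sd)
                .expect("valid standard deviation")
                .sample(&mut self.rng)
        }
    }

    #[test]
    fn gibbs_sweep_runs() {
        let mut cond = BivariateNormalConditional {
            rho: 0.8,
            rng: SmallRng::seed_from_u64(1),
        };
        let mut state = vec![0.0_f64, 0.0];
        // One full Gibbs sweep: resample each coordinate given the other.
        for i in 0..state.len() {
            state[i] = cond.sample(i, &state);
        }
        assert!(state.iter().all(|x| x.is_finite()));
    }
}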

#[cfg(test)]
mod continuous_tests {
    use super::*;

    fn normalize_isogauss(x: f64, d: usize, std: f64) -> f64 {
        let log_normalizer = -((d as f64) / 2.0) * ((2.0_f64).ln() + PI.ln() + 2.0 * std.ln());
        (x + log_normalizer).exp()
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_1() {
        let distr = IsotropicGaussian::new(1.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[1.0]), 1, distr.std);
        let true_p = 0.24197072451914337;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-7,
            "Expected diff < 1e-7, got {diff} with p={p} (expected ~{true_p})."
        );
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_2() {
        let distr = IsotropicGaussian::new(2.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[0.42, 9.6]), 2, distr.std);
        let true_p = 3.864661987252467e-7;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-15,
            "Expected diff < 1e-15, got {diff} with p={p} (expected ~{true_p})"
        );
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_3() {
        let distr = IsotropicGaussian::new(3.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[1.0, 2.0, 3.0]), 3, distr.std);
        let true_p = 0.001080393185560214;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-8,
            "Expected diff < 1e-8, got {diff} with p={p} (expected ~{true_p})"
        );
    }
}

#[cfg(test)]
mod categorical_tests {
    use super::*;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_categorical_logp_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let cat = Categorical::<f64>::new(probs.clone());

        let logp_0 = cat.logp(0);
        let logp_1 = cat.logp(1);
        let logp_2 = cat.logp(2);

        let expected_0 = 0.2_f64.ln();
        let expected_1 = 0.3_f64.ln();
        let expected_2 = 0.5_f64.ln();

        let tol = 1e-7;
        assert!(
            approx_eq(logp_0, expected_0, tol),
            "Log prob mismatch at index 0: got {}, expected {}",
            logp_0,
            expected_0
        );
        assert!(
            approx_eq(logp_1, expected_1, tol),
            "Log prob mismatch at index 1: got {}, expected {}",
            logp_1,
            expected_1
        );
        assert!(
            approx_eq(logp_2, expected_2, tol),
            "Log prob mismatch at index 2: got {}, expected {}",
            logp_2,
            expected_2
        );

        let logp_out = cat.logp(3);
        assert_eq!(
            logp_out,
            f64::NEG_INFINITY,
            "Out-of-bounds index did not return NEG_INFINITY"
        );
    }

    #[test]
    fn test_categorical_sampling_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let mut cat = Categorical::<f64>::new(probs.clone());

        let num_samples = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        for _ in 0..num_samples {
            let sample = cat.sample();
            counts[sample] += 1;
        }

        let tol = 0.01;
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f64 / num_samples as f64;
            let expected = probs[i];
            assert!(
                approx_eq(freq, expected, tol),
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    #[test]
    fn test_categorical_logp_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let cat = Categorical::<f32>::new(probs.clone());

        let logp_0: f32 = cat.logp(0);
        let logp_1 = cat.logp(1);
        let logp_2 = cat.logp(2);

        let expected_0 = (0.1_f64).ln();
        let expected_1 = (0.4_f64).ln();
        let expected_2 = (0.5_f64).ln();

        let tol = 1e-6;
        assert!(
            approx_eq(logp_0.into(), expected_0, tol),
            "Log prob mismatch at index 0 (f32 -> f64 cast)"
        );
        assert!(
            approx_eq(logp_1.into(), expected_1, tol),
            "Log prob mismatch at index 1"
        );
        assert!(
            approx_eq(logp_2.into(), expected_2, tol),
            "Log prob mismatch at index 2"
        );

        let logp_out = cat.logp(3);
        assert_eq!(logp_out, f32::NEG_INFINITY);
    }

    #[test]
    fn test_categorical_sampling_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let mut cat = Categorical::<f32>::new(probs.clone());

        let num_samples = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        for _ in 0..num_samples {
            let sample = cat.sample();
            counts[sample] += 1;
        }

        let tol = 0.02;
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f32 / num_samples as f32;
            let expected = probs[i];
            assert!(
                (freq - expected).abs() < tol,
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    #[test]
    fn test_categorical_sample_single_value() {
        let mut cat = Categorical {
            probs: vec![1.0_f64],
            rng: rand::rngs::SmallRng::from_seed(Default::default()),
        };

        let sampled_index = cat.sample();

        assert_eq!(
            sampled_index, 0,
            "Should return the last index (0) for a single-element vector"
        );
    }

    #[test]
    fn test_target_for_categorical_in_range() {
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs.clone());
        let logp = cat.unnorm_logp(&[1]);
        let expected = 0.3_f64.ln();
        let tol = 1e-7;
        assert!(
            (logp - expected).abs() < tol,
            "For index 1, expected ln(0.3) ~ {}, got {}",
            expected,
            logp
        );
    }

    #[test]
    fn test_target_for_categorical_out_of_range() {
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs);
        let logp = cat.unnorm_logp(&[3]);
        assert_eq!(
            logp,
            f64::NEG_INFINITY,
            "Expected negative infinity for out-of-range index, got {}",
            logp
        );
    }

    #[test]
    fn test_gaussian2d_logp() {
        let mean = arr1(&[0.0, 0.0]);
        let cov = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
        let gauss = Gaussian2D { mean, cov };

        let theta = vec![0.5, -0.5];
        let computed_logp = gauss.logp(&theta);

        let expected_logp = -2.0878770664093453;

        let tol = 1e-10;
        assert!(
            (computed_logp - expected_logp).abs() < tol,
            "Computed log density ({}) differs from expected ({}) by more than tolerance ({})",
            computed_logp,
            expected_logp,
            tol
        );
    }
}