mini_mcmc/distributions.rs

/*!
Traits for defining continuous and discrete proposal and target distributions.
Includes implementations of commonly used distributions.

This module is generic over the floating-point precision (e.g. `f32` or `f64`)
using [`num_traits::Float`]. It also defines several traits:
- [`Target`] for densities we want to sample from,
- [`Proposal`] for proposal mechanisms,
- [`Normalized`] for distributions that can compute a fully normalized log-prob,
- [`Discrete`] for distributions over finite sets.

# Examples

### Continuous Distributions

```rust
use mini_mcmc::distributions::{
    Gaussian2D, IsotropicGaussian, Proposal,
    Target, Normalized
};
use ndarray::{arr1, arr2};

// ----------------------
// Example: Gaussian2D (2D with full covariance)
// ----------------------
let mean = arr1(&[0.0, 0.0]);
let cov = arr2(&[[1.0, 0.0],
                [0.0, 1.0]]);
let gauss: Gaussian2D<f64> = Gaussian2D { mean, cov };

// Compute the fully normalized log-prob at (0.5, -0.5):
let logp = gauss.logp(&[0.5, -0.5]);
println!("Normalized log-density (2D Gaussian): {}", logp);

// ----------------------
// Example: IsotropicGaussian (any dimension)
// ----------------------
let mut proposal: IsotropicGaussian<f64> = IsotropicGaussian::new(1.0);
let current = vec![0.0, 0.0];  // dimension = 2 in this example
let candidate = proposal.sample(&current);
println!("Candidate state: {:?}", candidate);
```
*/

use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use ndarray::{arr1, arr2, Array1, Array2, NdFloat};
use num_traits::Float;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Normal};
use std::f64::consts::PI;
use std::ops::AddAssign;

/// A batched target trait for computing the unnormalized log density (and gradients) for a
/// collection of positions.
///
/// Implement this trait for your target distribution to enable gradient-based sampling.
///
/// # Type Parameters
///
/// * `T`: The floating-point type (e.g., `f32` or `f64`).
/// * `B`: The autodiff backend from the `burn` crate.
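///
/// # Example
///
/// A minimal sketch of a hand-written implementation (the `StdNormal` type
/// below is hypothetical, not part of this crate):
///
/// ```rust
/// use burn::prelude::*;
/// use burn::tensor::backend::AutodiffBackend;
/// use mini_mcmc::distributions::GradientTarget;
/// use num_traits::Float;
///
/// struct StdNormal;
///
/// impl<T: Float, B: AutodiffBackend> GradientTarget<T, B> for StdNormal {
///     fn unnorm_logp(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
///         // log p(x) = -0.5 * sum_d x_d^2 (up to an additive constant),
///         // evaluated for each chain (row) in the batch.
///         (positions.clone() * positions)
///             .sum_dim(1)
///             .squeeze(1)
///             .mul_scalar(-0.5)
///     }
/// }
/// ```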
pub trait GradientTarget<T: Float, B: AutodiffBackend> {
    /// Compute the log density for a batch of positions.
    ///
    /// # Parameters
    ///
    /// * `positions`: A tensor of shape `[n_chains, D]` representing the current positions for each chain.
    ///
    /// # Returns
    ///
    /// A 1D tensor of shape `[n_chains]` containing the log density for each chain.
    fn unnorm_logp(&self, positions: Tensor<B, 2>) -> Tensor<B, 1>;
}

/// A trait for generating proposals for Metropolis–Hastings-like algorithms.
/// States are slices of `T`, typically continuous values.
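///
/// # Example
///
/// A minimal sketch of a symmetric random-walk proposal (the `UniformWalk`
/// type is hypothetical, for illustration only):
///
/// ```rust
/// use mini_mcmc::distributions::Proposal;
/// use rand::rngs::SmallRng;
/// use rand::{Rng, SeedableRng};
///
/// struct UniformWalk {
///     width: f64,
///     rng: SmallRng,
/// }
///
/// impl Proposal<f64, f64> for UniformWalk {
///     fn sample(&mut self, current: &[f64]) -> Vec<f64> {
///         // Perturb each coordinate by uniform noise in (-width, width).
///         current
///             .iter()
///             .map(|x| x + self.rng.gen_range(-self.width..self.width))
///             .collect()
///     }
///
///     fn logp(&self, _from: &[f64], _to: &[f64]) -> f64 {
///         // The proposal is symmetric, so any constant works for
///         // Metropolis–Hastings acceptance ratios.
///         0.0
///     }
///
///     fn set_seed(mut self, seed: u64) -> Self {
///         self.rng = SmallRng::seed_from_u64(seed);
///         self
///     }
/// }
///
/// let mut walk = UniformWalk {
///     width: 0.5,
///     rng: SmallRng::seed_from_u64(42),
/// };
/// let candidate = walk.sample(&[0.0, 0.0]);
/// assert_eq!(candidate.len(), 2);
/// ```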
pub trait Proposal<T, F: Float> {
    /// Samples a new point from q(x' | x).
    fn sample(&mut self, current: &[T]) -> Vec<T>;

    /// Evaluates log q(x' | x).
    fn logp(&self, from: &[T], to: &[T]) -> F;

    /// Returns a new instance of this proposal distribution seeded with `seed`.
    fn set_seed(self, seed: u64) -> Self;
}

/// A trait for continuous target distributions from which we want to sample.
/// States are slices of `T`, typically continuous values.
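///
/// # Example
///
/// A minimal sketch of a custom target (the `Rosenbrock` type is
/// hypothetical, for illustration only):
///
/// ```rust
/// use mini_mcmc::distributions::Target;
///
/// struct Rosenbrock;
///
/// impl Target<f64, f64> for Rosenbrock {
///     // log p(x, y) = -((1 - x)^2 + 100 (y - x^2)^2), up to normalization.
///     fn unnorm_logp(&self, theta: &[f64]) -> f64 {
///         let (x, y) = (theta[0], theta[1]);
///         -((1.0 - x).powi(2) + 100.0 * (y - x * x).powi(2))
///     }
/// }
///
/// // The density peaks at (1, 1), where the unnormalized log-prob is 0.
/// assert_eq!(Rosenbrock.unnorm_logp(&[1.0, 1.0]), 0.0);
/// ```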
pub trait Target<T, F: Float> {
    /// Returns the log of the unnormalized density for state `theta`.
    fn unnorm_logp(&self, theta: &[T]) -> F;
}

/// A trait for distributions that provide a normalized log-density (e.g. for diagnostics).
pub trait Normalized<T, F: Float> {
    /// Returns the normalized log-density for state `theta`.
    fn logp(&self, theta: &[T]) -> F;
}

/** A trait for discrete distributions whose state is represented as an index.

```rust
use mini_mcmc::distributions::{Categorical, Discrete};

// Create a categorical distribution over three categories.
let mut cat = Categorical::new(vec![0.2f64, 0.3, 0.5]);
let sample = cat.sample();
println!("Sampled category: {}", sample); // E.g. 1

let logp = cat.logp(sample);
println!("Log-probability of sampled category: {}", logp); // E.g. ln(0.3) ≈ -1.204
```
*/
pub trait Discrete<T: Float> {
    /// Samples an index from the distribution.
    fn sample(&mut self) -> usize;
    /// Evaluates the log-probability of the given index.
    fn logp(&self, index: usize) -> T;
}

/**
A 2D Gaussian distribution parameterized by a mean vector and a 2×2 covariance matrix.

- The generic type `T` is typically `f32` or `f64`.
- Implements both [`Target`] (for unnormalized log-prob) and
  [`Normalized`] (for fully normalized log-prob).

# Example

```rust
use mini_mcmc::distributions::{Gaussian2D, Normalized};
use ndarray::{arr1, arr2};

let mean = arr1(&[0.0, 0.0]);
let cov = arr2(&[[1.0, 0.0],
                [0.0, 1.0]]);
let gauss: Gaussian2D<f64> = Gaussian2D { mean, cov };

let lp = gauss.logp(&[0.5, -0.5]);
println!("Normalized log probability: {}", lp);
```
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Gaussian2D<T: Float> {
    pub mean: Array1<T>,
    pub cov: Array2<T>,
}

impl<T> Normalized<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    /// Computes the fully normalized log-density of a 2D Gaussian.
    fn logp(&self, theta: &[T]) -> T {
        let term_1 = -(T::from(2.0).unwrap() * T::from(PI).unwrap()).ln();
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let half = T::from(0.5).unwrap();
        let term_2 = -half * det.abs().ln();

        let x = arr1(theta);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        let term_3 = -half * diff.dot(&inv_cov).dot(&diff);
        term_1 + term_2 + term_3
    }
}

impl<T> Target<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    fn unnorm_logp(&self, theta: &[T]) -> T {
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let x = arr1(theta);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        -T::from(0.5).unwrap() * diff.dot(&inv_cov).dot(&diff)
    }
}

/// A 2D Gaussian target distribution, parameterized by mean and covariance.
///
/// This struct also precomputes the inverse covariance and a log-normalization
/// constant so we can quickly evaluate log-density and gradients in `unnorm_logp`.
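///
/// # Example
///
/// A small sanity check of the precomputed normalization constant for the
/// identity covariance:
///
/// ```rust
/// use mini_mcmc::distributions::DiffableGaussian2D;
///
/// let gauss = DiffableGaussian2D::new([0.0_f64, 0.0], [[1.0, 0.0], [0.0, 1.0]]);
/// // With det(Sigma) = 1, norm_const = -0.5 * 2 * ln(2 * pi) = -ln(2 * pi).
/// assert!((gauss.norm_const + (2.0 * std::f64::consts::PI).ln()).abs() < 1e-12);
/// ```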
#[derive(Debug, Clone)]
pub struct DiffableGaussian2D<T: Float> {
    pub mean: [T; 2],
    pub cov: [[T; 2]; 2],
    pub inv_cov: [[T; 2]; 2],
    pub logdet_cov: T,
    pub norm_const: T,
}

impl<T> DiffableGaussian2D<T>
where
    T: Float + std::fmt::Debug + num_traits::FloatConst,
{
    /// Create a new 2D Gaussian with the specified mean and covariance.
    /// We automatically compute the covariance inverse and log-determinant.
    pub fn new(mean: [T; 2], cov: [[T; 2]; 2]) -> Self {
        // Compute determinant
        let det_cov = cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0];
        // Inverse of a 2x2:
        // [a, b; c, d]^-1 = (1/det) [ d, -b; -c, a ]
        let inv_det = T::one() / det_cov;
        let inv_cov = [
            [cov[1][1] * inv_det, -cov[0][1] * inv_det],
            [-cov[1][0] * inv_det, cov[0][0] * inv_det],
        ];
        let logdet_cov = det_cov.ln(); // `ln` is available since `T: Float`.
        // Normalization constant for the log pdf in 2 dimensions:
        //   -(1/2) * (dim * ln(2 pi) + ln(|Sigma|))
        //   = -(1/2) * (2 * ln(2 pi) + ln(det_cov))
        let two = T::one() + T::one();
        let norm_const = -(two * (two * T::PI()).ln() + logdet_cov) / two;

        Self {
            mean,
            cov,
            inv_cov,
            logdet_cov,
            norm_const,
        }
    }
}

impl<T, B> GradientTarget<T, B> for DiffableGaussian2D<T>
where
    T: Float + burn::tensor::ElementConversion + std::fmt::Debug + burn::tensor::Element,
    B: AutodiffBackend,
{
    /// Evaluates the log probability for a batch of positions of shape `[n_chains, 2]`.
    /// Returns a tensor of shape `[n_chains]`.
    /// Note: including the normalization constant is not strictly necessary for an
    /// unnormalized log probability, but we add it anyway to make debugging easier.
    fn unnorm_logp(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
        let (n_chains, dim) = (positions.dims()[0], positions.dims()[1]);
        assert_eq!(dim, 2, "Gaussian2D: expected dimension=2.");

        // 1) Subtract the mean => shape [n_chains, 2].
        //    We broadcast self.mean (a [T; 2]) onto all rows as a [1, 2] tensor.
        let mean_tensor =
            Tensor::<B, 2>::from_floats([[self.mean[0], self.mean[1]]], &B::Device::default())
                .reshape([1, 2])
                .expand([n_chains, 2]);

        let delta = positions.clone() - mean_tensor;

        // 2) Define the 2x2 inverse covariance as a tensor for matmul.
        let inv_cov_data = [
            self.inv_cov[0][0],
            self.inv_cov[0][1],
            self.inv_cov[1][0],
            self.inv_cov[1][1],
        ];
        let inv_cov_t =
            Tensor::<B, 2>::from_floats([inv_cov_data], &B::Device::default()).reshape([2, 2]);

        // 3) The quadratic form is: delta^T * inv_cov * delta.
        // For each chain, shape is [1,2] * [2,2] * [2,1] => scalar.
        // We do it in a batched style:
        //   z = delta matmul inv_cov => shape [n_chains, 2],
        //   then z * delta => shape [n_chains, 2], sum over dim=1 => shape [n_chains].
        let z = delta.clone().matmul(inv_cov_t); // shape [n_chains, 2]
        let quad = (z * delta).sum_dim(1).squeeze(1); // shape [n_chains]

        // 4) The log density for each chain i is:
        //      log p(x_i) = norm_const - 0.5 * quad[i]
        //    where norm_const is -0.5 * (2 ln(2 pi) + ln det(Sigma)).
        //    We broadcast norm_const to shape [n_chains].
        let shape = Shape::new([n_chains]);
        let norm_c = Tensor::<B, 1>::ones(shape, &B::Device::default()).mul_scalar(self.norm_const);
        let half = T::from(0.5).unwrap();
        norm_c - quad.mul_scalar(half)
    }
}

/**
An *isotropic* Gaussian distribution usable as either a target or a proposal
in MCMC. It works for **any dimension** because it applies independent
Gaussian noise (`mean = 0`, `std = self.std`) to each coordinate.

- Implements [`Proposal`] so it can propose new states
  from a current state.
- Also implements [`Target`] for an unnormalized log-prob,
  which might be useful if you want to treat it as a target distribution
  in simplified scenarios.

# Examples

```rust
use mini_mcmc::distributions::{IsotropicGaussian, Proposal};

let mut proposal: IsotropicGaussian<f64> = IsotropicGaussian::new(1.0);
let current = vec![0.0, 0.0, 0.0]; // dimension = 3
let candidate = proposal.sample(&current);
println!("Candidate state: {:?}", candidate);

// Evaluate log q(candidate | current):
let logq = proposal.logp(&current, &candidate);
println!("Log of the proposal density: {}", logq);
```
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IsotropicGaussian<T: Float> {
    pub std: T,
    rng: SmallRng,
}

impl<T: Float> IsotropicGaussian<T> {
    /// Creates a new isotropic Gaussian proposal distribution with the specified standard deviation.
    pub fn new(std: T) -> Self {
        Self {
            std,
            rng: SmallRng::from_entropy(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Proposal<T, T> for IsotropicGaussian<T>
where
    rand_distr::StandardNormal: rand_distr::Distribution<T>,
{
    fn sample(&mut self, current: &[T]) -> Vec<T> {
        let normal = Normal::new(T::zero(), self.std)
            .expect("Expecting creation of normal distribution to succeed.");
        normal
            .sample_iter(&mut self.rng)
            .zip(current)
            .map(|(noise, x)| *x + noise)
            .collect()
    }

    fn logp(&self, from: &[T], to: &[T]) -> T {
        let mut lp = T::zero();
        let d = T::from(from.len()).unwrap();
        let two = T::from(2).unwrap();
        let var = self.std * self.std;
        for (&f, &t) in from.iter().zip(to.iter()) {
            let diff = t - f;
            let exponent = -(diff * diff) / (two * var);
            lp += exponent;
        }
        // Normalization: -(d/2) * ln(2 * pi * sigma^2).
        lp += -d * T::from(0.5).unwrap() * (two * T::from(PI).unwrap() * var).ln();
        lp
    }

    fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }
}

impl<T: Float> Target<T, T> for IsotropicGaussian<T> {
    fn unnorm_logp(&self, theta: &[T]) -> T {
        let mut sum = T::zero();
        for &x in theta.iter() {
            sum = sum + x * x;
        }
        -T::from(0.5).unwrap() * sum / (self.std * self.std)
    }
}

/**
A categorical distribution represents a discrete probability distribution over a finite set of categories.

The probabilities in `probs` should sum to 1 (or they will be normalized automatically).

# Examples

```rust
use mini_mcmc::distributions::{Categorical, Discrete};

let mut cat = Categorical::new(vec![0.2f64, 0.3, 0.5]);
let sample = cat.sample();
println!("Sampled category: {}", sample);
let logp = cat.logp(sample);
println!("Log probability of category {}: {}", sample, logp);
```
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Categorical<T>
where
    T: Float + std::ops::AddAssign,
{
    pub probs: Vec<T>,
    rng: SmallRng,
}

impl<T: Float + std::ops::AddAssign> Categorical<T> {
    /// Creates a new categorical distribution from a vector of probabilities.
    /// The probabilities will be normalized so that they sum to 1.
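    ///
    /// For example, unnormalized weights are accepted and rescaled:
    ///
    /// ```rust
    /// use mini_mcmc::distributions::Categorical;
    ///
    /// // [2, 3, 5] is normalized to [0.2, 0.3, 0.5].
    /// let cat = Categorical::new(vec![2.0_f64, 3.0, 5.0]);
    /// assert!((cat.probs[1] - 0.3).abs() < 1e-12);
    /// ```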
    pub fn new(probs: Vec<T>) -> Self {
        let sum: T = probs.iter().cloned().fold(T::zero(), |acc, x| acc + x);
        let normalized: Vec<T> = probs.into_iter().map(|p| p / sum).collect();
        Self {
            probs: normalized,
            rng: SmallRng::from_entropy(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Discrete<T> for Categorical<T>
where
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    fn sample(&mut self) -> usize {
        let r: T = self.rng.gen();
        let mut cum: T = T::zero();
        // Fall back to the last index so floating-point rounding in the
        // cumulative sum cannot produce an out-of-range sample.
        let mut k = self.probs.len() - 1;
        for (i, &p) in self.probs.iter().enumerate() {
            cum += p;
            if r <= cum {
                k = i;
                break;
            }
        }
        k
    }

    fn logp(&self, index: usize) -> T {
        if index < self.probs.len() {
            self.probs[index].ln()
        } else {
            T::neg_infinity()
        }
    }
}

impl<T: Float + AddAssign> Target<usize, T> for Categorical<T>
where
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    fn unnorm_logp(&self, theta: &[usize]) -> T {
        <Self as Discrete<T>>::logp(self, theta[0])
    }
}

/**
A trait for conditional distributions.

This trait specifies how to sample a single coordinate of a state given the entire current state.
It is primarily used in Gibbs sampling to update one coordinate at a time.
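
# Example

A minimal sketch of Gibbs conditionals for a standard bivariate normal with
correlation `rho` (the `BivariateConditional` type is hypothetical, for
illustration only):

```rust
use mini_mcmc::distributions::Conditional;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rand_distr::{Distribution, Normal};

struct BivariateConditional {
    rho: f64,
    rng: SmallRng,
}

impl Conditional<f64> for BivariateConditional {
    // For a standard bivariate normal, x_i | x_j ~ N(rho * x_j, 1 - rho^2).
    fn sample(&mut self, index: usize, given: &[f64]) -> f64 {
        let other = given[1 - index];
        let std = (1.0 - self.rho * self.rho).sqrt();
        Normal::new(self.rho * other, std).unwrap().sample(&mut self.rng)
    }
}

let mut cond = BivariateConditional { rho: 0.8, rng: SmallRng::seed_from_u64(42) };
let x = cond.sample(0, &[0.0, 1.0]); // a draw from N(0.8, 0.36)
```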
*/
pub trait Conditional<S> {
    fn sample(&mut self, index: usize, given: &[S]) -> S;
}

#[cfg(test)]
mod continuous_tests {
    use super::*;

    /**
    A helper function to normalize the unnormalized log probability of an isotropic Gaussian
    into a proper probability value (by applying the appropriate constant).

    # Arguments

    * `x` - The unnormalized log probability.
    * `d` - The dimensionality of the state.
    * `std` - The standard deviation used in the isotropic Gaussian.

    # Returns

    Returns the normalized probability as an `f64`.
    */
    fn normalize_isogauss(x: f64, d: usize, std: f64) -> f64 {
        let log_normalizer = -((d as f64) / 2.0) * ((2.0_f64).ln() + PI.ln() + 2.0 * std.ln());
        (x + log_normalizer).exp()
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_1() {
        let distr = IsotropicGaussian::new(1.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[1.0]), 1, distr.std);
        let true_p = 0.24197072451914337;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-7,
            "Expected diff < 1e-7, got {diff} with p={p} (expected ~{true_p})."
        );
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_2() {
        let distr = IsotropicGaussian::new(2.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[0.42, 9.6]), 2, distr.std);
        let true_p = 3.864661987252467e-7;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-15,
            "Expected diff < 1e-15, got {diff} with p={p} (expected ~{true_p})"
        );
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_3() {
        let distr = IsotropicGaussian::new(3.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[1.0, 2.0, 3.0]), 3, distr.std);
        let true_p = 0.001080393185560214;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-8,
            "Expected diff < 1e-8, got {diff} with p={p} (expected ~{true_p})"
        );
    }

    #[test]
    fn test_gaussian2d_logp() {
        let mean = arr1(&[0.0, 0.0]);
        let cov = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
        let gauss = Gaussian2D { mean, cov };

        let theta = vec![0.5, -0.5];
        let computed_logp = gauss.logp(&theta);

        let expected_logp = -2.0878770664093453;

        let tol = 1e-10;
        assert!(
            (computed_logp - expected_logp).abs() < tol,
            "Computed log density ({}) differs from expected ({}) by more than tolerance ({})",
            computed_logp,
            expected_logp,
            tol
        );
    }
}

#[cfg(test)]
mod categorical_tests {
    use super::*;

    /// A helper function to compare floating-point values with a given tolerance.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    // ------------------------------------------------------
    // 1) Test logp correctness for f64
    // ------------------------------------------------------
    #[test]
    fn test_categorical_logp_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let cat = Categorical::<f64>::new(probs.clone());

        // Check log probabilities for each index
        let logp_0 = cat.logp(0);
        let logp_1 = cat.logp(1);
        let logp_2 = cat.logp(2);

        // Expected values
        let expected_0 = 0.2_f64.ln();
        let expected_1 = 0.3_f64.ln();
        let expected_2 = 0.5_f64.ln();

        let tol = 1e-7;
        assert!(
            approx_eq(logp_0, expected_0, tol),
            "Log prob mismatch at index 0: got {}, expected {}",
            logp_0,
            expected_0
        );
        assert!(
            approx_eq(logp_1, expected_1, tol),
            "Log prob mismatch at index 1: got {}, expected {}",
            logp_1,
            expected_1
        );
        assert!(
            approx_eq(logp_2, expected_2, tol),
            "Log prob mismatch at index 2: got {}, expected {}",
            logp_2,
            expected_2
        );

        // Out-of-bounds index should be NEG_INFINITY
        let logp_out = cat.logp(3);
        assert_eq!(
            logp_out,
            f64::NEG_INFINITY,
            "Out-of-bounds index did not return NEG_INFINITY"
        );
    }

    // ------------------------------------------------------
    // 2) Test sampling frequencies for f64
    // ------------------------------------------------------
    #[test]
    fn test_categorical_sampling_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let mut cat = Categorical::<f64>::new(probs.clone());

        let num_samples = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        // Draw samples and tally outcomes
        for _ in 0..num_samples {
            let sample = cat.sample();
            counts[sample] += 1;
        }

        // Check empirical frequencies
        let tol = 0.01; // 1% absolute tolerance
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f64 / num_samples as f64;
            let expected = probs[i];
            assert!(
                approx_eq(freq, expected, tol),
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    // ------------------------------------------------------
    // 3) Test logp correctness for f32
    // ------------------------------------------------------
    #[test]
    fn test_categorical_logp_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let cat = Categorical::<f32>::new(probs.clone());

        let logp_0: f32 = cat.logp(0);
        let logp_1 = cat.logp(1);
        let logp_2 = cat.logp(2);

        // For comparison, cast to f64
        let expected_0 = (0.1_f64).ln();
        let expected_1 = (0.4_f64).ln();
        let expected_2 = (0.5_f64).ln();

        let tol = 1e-6;
        assert!(
            approx_eq(logp_0.into(), expected_0, tol),
            "Log prob mismatch at index 0 (f32 -> f64 cast)"
        );
        assert!(
            approx_eq(logp_1.into(), expected_1, tol),
            "Log prob mismatch at index 1"
        );
        assert!(
            approx_eq(logp_2.into(), expected_2, tol),
            "Log prob mismatch at index 2"
        );

        // Out-of-bounds
        let logp_out = cat.logp(3);
        assert_eq!(logp_out, f32::NEG_INFINITY);
    }

    // ------------------------------------------------------
    // 4) Test sampling frequencies for f32
    // ------------------------------------------------------
    #[test]
    fn test_categorical_sampling_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let mut cat = Categorical::<f32>::new(probs.clone());

        let num_samples = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        for _ in 0..num_samples {
            let sample = cat.sample();
            counts[sample] += 1;
        }

        // Compare frequencies with expected probabilities
        let tol = 0.02; // looser tolerance for f32
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f32 / num_samples as f32;
            let expected = probs[i];
            assert!(
                (freq - expected).abs() < tol,
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    #[test]
    fn test_categorical_sample_single_value() {
        let mut cat = Categorical {
            probs: vec![1.0_f64],
            rng: rand::rngs::SmallRng::from_seed(Default::default()),
        };

        let sampled_index = cat.sample();

        assert_eq!(
            sampled_index, 0,
            "Should return index 0 for a single-element probability vector"
        );
    }

    #[test]
    fn test_target_for_categorical_in_range() {
        // Create a categorical distribution with known probabilities.
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs.clone());
        // Call unnorm_logp with a valid index (say, index 1).
        let logp = cat.unnorm_logp(&[1]);
        // The expected log probability is ln(0.3).
        let expected = 0.3_f64.ln();
        let tol = 1e-7;
        assert!(
            (logp - expected).abs() < tol,
            "For index 1, expected ln(0.3) ~ {}, got {}",
            expected,
            logp
        );
    }

    #[test]
    fn test_target_for_categorical_out_of_range() {
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs);
        // Calling unnorm_logp with an index that's out of bounds (e.g. 3)
        // should return negative infinity.
        let logp = cat.unnorm_logp(&[3]);
        assert_eq!(
            logp,
            f64::NEG_INFINITY,
            "Expected negative infinity for out-of-range index, got {}",
            logp
        );
    }
}