/*!
Traits for defining continuous and discrete proposal and target distributions.
Includes implementations of commonly used distributions.

This module is generic over the floating-point precision (e.g. `f32` or `f64`)
using [`num_traits::Float`]. It also defines several traits:
- [`Target`] for densities or PMFs we want to sample from with **Metropolis-Hastings**,
- [`GradientTarget`] for densities we want to sample from with **NUTS**,
- [`BatchedGradientTarget`] for densities we want to sample from with **HMC**,
- [`Conditional`] for densities or PMFs we want to sample from with **Gibbs Sampling**,
- [`Proposal`] for proposal mechanisms,
- [`Normalized`] for distributions that can compute a fully normalized log probability,
- [`Discrete`] for distributions over finite sets.

## Examples

```rust
use mini_mcmc::distributions::{
    Gaussian2D, IsotropicGaussian, Proposal,
    Target, Normalized
};
use ndarray::{arr1, arr2};

// ----------------------
// Example: Gaussian2D (2D with full covariance)
// ----------------------
let mean = arr1(&[0.0, 0.0]);
let cov = arr2(&[[1.0, 0.0],
                [0.0, 1.0]]);
let gauss: Gaussian2D<f64> = Gaussian2D { mean, cov };

// Compute the fully normalized log-prob at (0.5, -0.5):
let logp = gauss.logp(&vec![0.5, -0.5]);
println!("Normalized log-density (2D Gaussian): {}", logp);

// ----------------------
// Example: IsotropicGaussian (any dimension)
// ----------------------
let mut proposal: IsotropicGaussian<f64> = IsotropicGaussian::new(1.0);
let current = vec![0.0, 0.0];  // dimension = 2 in this example
let candidate = proposal.sample(&current);
println!("Candidate state: {:?}", candidate);
```
*/

use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::Element;
use ndarray::{arr1, arr2, Array1, Array2, NdFloat};
use num_traits::Float;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Normal, StandardUniform};
use std::f64::consts::PI;
use std::ops::AddAssign;

/// A batched target trait for computing the unnormalized log density (and gradients) for a
/// collection of positions.
///
/// Implement this trait for your target distribution to enable gradient-based sampling.
///
/// # Type Parameters
///
/// * `T`: The floating-point type (e.g., f32 or f64).
/// * `B`: The autodiff backend from the `burn` crate.
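///
/// # Example
///
/// A minimal sketch of implementing this trait for an isotropic standard normal.
/// The `StdNormal` type is illustrative and not part of this crate:
///
/// ```ignore
/// use burn::prelude::*;
/// use burn::tensor::backend::AutodiffBackend;
/// use mini_mcmc::distributions::BatchedGradientTarget;
/// use num_traits::Float;
///
/// struct StdNormal;
///
/// impl<T: Float, B: AutodiffBackend> BatchedGradientTarget<T, B> for StdNormal {
///     fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
///         // log p(x) is proportional to -0.5 * sum_i x_i^2, evaluated per row (chain).
///         -positions.powi_scalar(2).sum_dim(1).squeeze(1).mul_scalar(0.5)
///     }
/// }
/// ```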
pub trait BatchedGradientTarget<T: Float, B: AutodiffBackend> {
    /// Computes the unnormalized log density for a batch of positions.
    ///
    /// # Parameters
    ///
    /// * `positions`: A tensor of shape `[n_chains, D]` representing the current positions for each chain.
    ///
    /// # Returns
    ///
    /// A 1D tensor of shape `[n_chains]` containing the unnormalized log density for each chain.
    fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1>;
}

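/// A single-position analogue of [`BatchedGradientTarget`]: computes the unnormalized
/// log density (and, via the provided default method, its gradient) for one position.
///
/// # Example
///
/// A minimal sketch using the [`Rosenbrock2D`] target defined in this module; the
/// backend type is illustrative and assumes burn's `ndarray` backend is available:
///
/// ```ignore
/// use burn::backend::{Autodiff, NdArray};
/// use burn::prelude::*;
/// use mini_mcmc::distributions::{GradientTarget, Rosenbrock2D};
///
/// type B = Autodiff<NdArray>;
///
/// let target = Rosenbrock2D { a: 1.0_f32, b: 100.0_f32 };
/// let position = Tensor::<B, 1>::from_floats([0.0, 0.0], &Default::default());
/// let (logp, grad) = target.unnorm_logp_and_grad(position);
/// println!("log p = {logp}, grad = {grad}");
/// ```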
pub trait GradientTarget<T: Float, B: AutodiffBackend> {
    /// Computes the unnormalized log density at a single position (a 1D tensor of shape `[D]`).
    fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1>;

    /// Returns both the unnormalized log density and its gradient at `position`.
    ///
    /// The default implementation obtains the gradient by back-propagating through
    /// `unnorm_logp` using the autodiff backend `B`.
    fn unnorm_logp_and_grad(&self, position: Tensor<B, 1>) -> (Tensor<B, 1>, Tensor<B, 1>) {
        let pos = position.clone().detach().require_grad();
        let ulogp = self.unnorm_logp(pos.clone());
        let grad_inner = pos.grad(&ulogp.backward()).unwrap();
        let grad = Tensor::<B, 1>::from_inner(grad_inner);
        (ulogp, grad)
    }
}

/// A trait for generating proposals for Metropolis–Hastings-like algorithms.
/// The state type `T` is typically a vector of continuous values.
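///
/// # Example
///
/// Using the [`IsotropicGaussian`] proposal defined in this module:
///
/// ```rust
/// use mini_mcmc::distributions::{IsotropicGaussian, Proposal};
///
/// let mut proposal: IsotropicGaussian<f64> = IsotropicGaussian::new(0.5).set_seed(42);
/// let current = vec![0.0, 0.0];
/// let candidate = proposal.sample(&current);
/// let logq = proposal.logp(&current, &candidate);
/// println!("candidate = {:?}, log q = {}", candidate, logq);
/// ```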
pub trait Proposal<T, F: Float> {
    /// Samples a new point from q(x' | x).
    fn sample(&mut self, current: &[T]) -> Vec<T>;

    /// Evaluates log q(x' | x).
    fn logp(&self, from: &[T], to: &[T]) -> F;

    /// Returns a new instance of this proposal distribution seeded with `seed`.
    fn set_seed(self, seed: u64) -> Self;
}

/// A trait for continuous target distributions from which we want to sample.
/// The state type `T` is typically a vector of continuous values.
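///
/// # Example
///
/// Evaluating the unnormalized log density of the [`Gaussian2D`] target defined in
/// this module:
///
/// ```rust
/// use mini_mcmc::distributions::{Gaussian2D, Target};
/// use ndarray::{arr1, arr2};
///
/// let gauss: Gaussian2D<f64> = Gaussian2D {
///     mean: arr1(&[0.0, 0.0]),
///     cov: arr2(&[[1.0, 0.0], [0.0, 1.0]]),
/// };
/// let ulogp = gauss.unnorm_logp(&[0.5, -0.5]);
/// assert!((ulogp - (-0.25)).abs() < 1e-12);
/// ```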
pub trait Target<T, F: Float> {
    /// Returns the log of the unnormalized density at `position`.
    fn unnorm_logp(&self, position: &[T]) -> F;
}

/// A trait for distributions that provide a normalized log-density (e.g. for diagnostics).
pub trait Normalized<T, F: Float> {
    /// Returns the normalized log-density at `position`.
    fn logp(&self, position: &[T]) -> F;
}

/** A trait for discrete distributions whose state is represented as an index.
 ```rust
 use mini_mcmc::distributions::{Categorical, Discrete};

 // Create a categorical distribution over three categories.
 let mut cat = Categorical::new(vec![0.2f64, 0.3, 0.5]);
 let observation = cat.sample();
 println!("Sampled category: {}", observation); // E.g. 1usize

 let logp = cat.logp(observation);
 println!("Log-probability of sampled category: {}", logp); // E.g. -1.20 (i.e. ln(0.3))
```
*/
pub trait Discrete<T: Float> {
    /// Samples an index from the distribution.
    fn sample(&mut self) -> usize;
    /// Evaluates the log-probability of the given index.
    fn logp(&self, index: usize) -> T;
}

/**
A 2D Gaussian distribution parameterized by a mean vector and a 2×2 covariance matrix.

- The generic type `T` is typically `f32` or `f64`.
- Implements both [`Target`] (for unnormalized log-prob) and
  [`Normalized`] (for fully normalized log-prob).

# Example

```rust
use mini_mcmc::distributions::{Gaussian2D, Normalized};
use ndarray::{arr1, arr2};

let mean = arr1(&[0.0, 0.0]);
let cov = arr2(&[[1.0, 0.0],
                [0.0, 1.0]]);
let gauss: Gaussian2D<f64> = Gaussian2D { mean, cov };

let lp = gauss.logp(&vec![0.5, -0.5]);
println!("Normalized log probability: {}", lp);
```
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Gaussian2D<T: Float> {
    pub mean: Array1<T>,
    pub cov: Array2<T>,
}

impl<T> Normalized<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    /// Computes the fully normalized log-density of a 2D Gaussian.
    fn logp(&self, position: &[T]) -> T {
        let term_1 = -(T::from(2.0).unwrap() * T::from(PI).unwrap()).ln();
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let half = T::from(0.5).unwrap();
        let term_2 = -half * det.abs().ln();

        let x = arr1(position);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        let term_3 = -half * diff.dot(&inv_cov).dot(&diff);
        term_1 + term_2 + term_3
    }
}

impl<T> Target<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    fn unnorm_logp(&self, position: &[T]) -> T {
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let x = arr1(position);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        -T::from(0.5).unwrap() * diff.dot(&inv_cov).dot(&diff)
    }
}

/// A 2D Gaussian target distribution, parameterized by mean and covariance.
///
/// This struct also precomputes the inverse covariance and a log-normalization
/// constant so we can quickly compute log-densities and gradients in `unnorm_logp`.
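///
/// # Example
///
/// Constructing the distribution; the precomputed fields are then available to the
/// gradient-based samplers:
///
/// ```rust
/// use mini_mcmc::distributions::DiffableGaussian2D;
///
/// let gauss = DiffableGaussian2D::new([0.0_f64, 0.0], [[1.0, 0.0], [0.0, 1.0]]);
/// assert_eq!(gauss.inv_cov, [[1.0, 0.0], [0.0, 1.0]]);
/// ```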
#[derive(Debug, Clone)]
pub struct DiffableGaussian2D<T: Float> {
    pub mean: [T; 2],
    pub cov: [[T; 2]; 2],
    pub inv_cov: [[T; 2]; 2],
    pub logdet_cov: T,
    pub norm_const: T,
}

impl<T> DiffableGaussian2D<T>
where
    T: Float + std::fmt::Debug + num_traits::FloatConst,
{
    /// Creates a new 2D Gaussian with the specified mean and covariance,
    /// automatically computing the covariance inverse and log-determinant.
    pub fn new(mean: [T; 2], cov: [[T; 2]; 2]) -> Self {
        // Compute the determinant.
        let det_cov = cov[0][0] * cov[1][1] - cov[0][1] * cov[1][0];
        // Inverse of a 2x2 matrix:
        // [a, b; c, d]^-1 = (1/det) [ d, -b; -c, a ]
        let inv_det = T::one() / det_cov;
        let inv_cov = [
            [cov[1][1] * inv_det, -cov[0][1] * inv_det],
            [-cov[1][0] * inv_det, cov[0][0] * inv_det],
        ];
        let logdet_cov = det_cov.ln();
        // Normalization constant of the log pdf in 2 dimensions:
        //   -(1/2) * (dim * ln(2 * pi) + ln(|Sigma|))
        //   = -(1/2) * (2 * ln(2 * pi) + ln(det_cov))
        let two = T::one() + T::one();
        let norm_const = -(two * (two * T::PI()).ln() + logdet_cov) / two;

        Self {
            mean,
            cov,
            inv_cov,
            logdet_cov,
            norm_const,
        }
    }
}

impl<T, B> BatchedGradientTarget<T, B> for DiffableGaussian2D<T>
where
    T: Float + burn::tensor::ElementConversion + std::fmt::Debug + burn::tensor::Element,
    B: AutodiffBackend,
{
    /// Evaluates the log probability for a batch of positions of shape `[n_chains, 2]`,
    /// returning a tensor of shape `[n_chains]`.
    ///
    /// Note: the normalization constant is not strictly needed here (only the unnormalized
    /// log probability is required), but we include it anyway for easier debugging.
    fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
        let (n_chains, dim) = (positions.dims()[0], positions.dims()[1]);
        assert_eq!(dim, 2, "Gaussian2D: expected dimension=2.");

        let mean_tensor =
            Tensor::<B, 2>::from_floats([[self.mean[0], self.mean[1]]], &B::Device::default())
                .reshape([1, 2])
                .expand([n_chains, 2]);

        let delta = positions.clone() - mean_tensor;

        let inv_cov_data = [
            self.inv_cov[0][0],
            self.inv_cov[0][1],
            self.inv_cov[1][0],
            self.inv_cov[1][1],
        ];
        let inv_cov_t =
            Tensor::<B, 2>::from_floats([inv_cov_data], &B::Device::default()).reshape([2, 2]);

        let z = delta.clone().matmul(inv_cov_t); // shape [n_chains, 2]
        let quad = (z * delta).sum_dim(1).squeeze(1); // shape [n_chains]
        let shape = Shape::new([n_chains]);
        let norm_c = Tensor::<B, 1>::ones(shape, &B::Device::default()).mul_scalar(self.norm_const);
        let half = T::from(0.5).unwrap();
        norm_c - quad.mul_scalar(half)
    }
}

impl<T, B> GradientTarget<T, B> for DiffableGaussian2D<T>
where
    T: Float + burn::tensor::ElementConversion + std::fmt::Debug + burn::tensor::Element,
    B: AutodiffBackend,
{
    fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1> {
        let dim = position.dims()[0];
        assert_eq!(dim, 2, "Gaussian2D: expected dimension=2.");

        let mean_tensor =
            Tensor::<B, 1>::from_floats([self.mean[0], self.mean[1]], &B::Device::default());

        let delta = position.clone() - mean_tensor;

        let inv_cov_data = [
            [self.inv_cov[0][0], self.inv_cov[0][1]],
            [self.inv_cov[1][0], self.inv_cov[1][1]],
        ];
        let inv_cov_t = Tensor::<B, 2>::from_floats(inv_cov_data, &B::Device::default());

        let z = delta.clone().reshape([1_i32, 2_i32]).matmul(inv_cov_t);
        let quad = (z.reshape([2_i32]) * delta).sum();
        let half = T::from(0.5).unwrap();
        -quad.mul_scalar(half) + self.norm_const
    }
}

/**
An *isotropic* Gaussian distribution usable as either a target or a proposal
in MCMC. It works for **any dimension** because it applies independent
Gaussian noise (`mean = 0`, `std = self.std`) to each coordinate.

- Implements [`Proposal`] so it can propose new states
  from a current state.
- Also implements [`Target`] for an unnormalized log-prob,
  which might be useful if you want to treat it as a target distribution
  in simplified scenarios.

# Examples

```rust
use mini_mcmc::distributions::{IsotropicGaussian, Proposal};

let mut proposal: IsotropicGaussian<f64> = IsotropicGaussian::new(1.0);
let current = vec![0.0, 0.0, 0.0]; // dimension = 3
let candidate = proposal.sample(&current);
println!("Candidate state: {:?}", candidate);

// Evaluate log q(candidate | current):
let logq = proposal.logp(&current, &candidate);
println!("Log of the proposal density: {}", logq);
```
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IsotropicGaussian<T: Float> {
    pub std: T,
    rng: SmallRng,
}

impl<T: Float> IsotropicGaussian<T> {
    /// Creates a new isotropic Gaussian proposal distribution with the specified standard deviation.
    pub fn new(std: T) -> Self {
        Self {
            std,
            rng: SmallRng::from_os_rng(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Proposal<T, T> for IsotropicGaussian<T>
where
    rand_distr::StandardNormal: rand_distr::Distribution<T>,
{
    fn sample(&mut self, current: &[T]) -> Vec<T> {
        let normal = Normal::new(T::zero(), self.std)
            .expect("Expecting creation of normal distribution to succeed.");
        normal
            .sample_iter(&mut self.rng)
            .zip(current)
            .map(|(eps, x)| *x + eps)
            .collect()
    }

    fn logp(&self, from: &[T], to: &[T]) -> T {
        // log q(x' | x) = -d/2 * ln(2 * pi * sigma^2) - sum_i (x'_i - x_i)^2 / (2 * sigma^2)
        let mut lp = T::zero();
        let d = T::from(from.len()).unwrap();
        let two = T::from(2).unwrap();
        let var = self.std * self.std;
        for (&f, &t) in from.iter().zip(to.iter()) {
            let diff = t - f;
            let exponent = -(diff * diff) / (two * var);
            lp += exponent;
        }
        lp += -d * T::from(0.5).unwrap() * (two * T::from(PI).unwrap() * var).ln();
        lp
    }

    fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }
}

impl<T: Float> Target<T, T> for IsotropicGaussian<T> {
    fn unnorm_logp(&self, position: &[T]) -> T {
        let mut sum = T::zero();
        for &x in position.iter() {
            sum = sum + x * x;
        }
        -T::from(0.5).unwrap() * sum / (self.std * self.std)
    }
}

/**
A categorical distribution represents a discrete probability distribution over a finite set of categories.

The probabilities in `probs` should sum to 1 (or they will be normalized automatically).

# Examples

```rust
use mini_mcmc::distributions::{Categorical, Discrete};

let mut cat = Categorical::new(vec![0.2f64, 0.3, 0.5]);
let observation = cat.sample();
println!("Sampled category: {}", observation);
let logp = cat.logp(observation);
println!("Log probability of category {}: {}", observation, logp);
```
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Categorical<T>
where
    T: Float + std::ops::AddAssign,
{
    pub probs: Vec<T>,
    rng: SmallRng,
}

impl<T: Float + std::ops::AddAssign> Categorical<T> {
    /// Creates a new categorical distribution from a vector of probabilities.
    /// The probabilities will be normalized so that they sum to 1.
    pub fn new(probs: Vec<T>) -> Self {
        let sum: T = probs.iter().cloned().fold(T::zero(), |acc, x| acc + x);
        let normalized: Vec<T> = probs.into_iter().map(|p| p / sum).collect();
        Self {
            probs: normalized,
            rng: SmallRng::from_os_rng(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Discrete<T> for Categorical<T>
where
    StandardUniform: rand::distr::Distribution<T>,
{
    fn sample(&mut self) -> usize {
        let r: T = self.rng.random();
        let mut cum: T = T::zero();
        let mut k = self.probs.len() - 1;
        for (i, &p) in self.probs.iter().enumerate() {
            cum += p;
            if r <= cum {
                k = i;
                break;
            }
        }
        k
    }

    fn logp(&self, index: usize) -> T {
        if index < self.probs.len() {
            self.probs[index].ln()
        } else {
            T::neg_infinity()
        }
    }
}

impl<T: Float + AddAssign> Target<usize, T> for Categorical<T>
where
    rand_distr::StandardUniform: rand_distr::Distribution<T>,
{
    fn unnorm_logp(&self, position: &[usize]) -> T {
        <Self as Discrete<T>>::logp(self, position[0])
    }
}

/**
A trait for conditional distributions.

This trait specifies how to sample a single coordinate of a state given the entire current state.
It is primarily used in Gibbs sampling to update one coordinate at a time.
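
# Example

A minimal, deterministic sketch just to illustrate the shape of the trait (a real
Gibbs conditional would draw the coordinate from its conditional distribution
given the rest of the state):

```rust
use mini_mcmc::distributions::Conditional;

struct MeanConditional;

impl Conditional<f64> for MeanConditional {
    // "Samples" coordinate `index` as the mean of all other coordinates.
    fn sample(&mut self, index: usize, given: &[f64]) -> f64 {
        let others: Vec<f64> = given
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != index)
            .map(|(_, v)| *v)
            .collect();
        others.iter().sum::<f64>() / others.len() as f64
    }
}

let mut cond = MeanConditional;
assert_eq!(cond.sample(0, &[0.0, 2.0, 4.0]), 3.0);
```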
*/
pub trait Conditional<S> {
    /// Samples a new value for the coordinate at position `index`, conditioned on the
    /// entire current state `given`.
    fn sample(&mut self, index: usize, given: &[S]) -> S;
}

/// The 2D Rosenbrock distribution with parameters `a` and `b`, whose unnormalized
/// log density is `-((a - x)^2 + b * (y - x^2)^2)`.
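///
/// For example, with `a = 1` and `b = 100` the unnormalized log density is maximized at
/// `(x, y) = (1, 1)`, where it equals `0`.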
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Rosenbrock2D<T: Float> {
    pub a: T,
    pub b: T,
}

// For the batched version we need to implement BatchedGradientTarget.
impl<T, B> BatchedGradientTarget<T, B> for Rosenbrock2D<T>
where
    T: Float + Element,
    B: AutodiffBackend,
{
    fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
        let n = positions.dims()[0];
        let x = positions.clone().slice([0..n, 0..1]);
        let y = positions.slice([0..n, 1..2]);
        let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);
        let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);
        -(term_1 + term_2).flatten(0, 1)
    }
}

impl<T, B> GradientTarget<T, B> for Rosenbrock2D<T>
where
    T: Float + Element,
    B: AutodiffBackend,
{
    fn unnorm_logp(&self, position: Tensor<B, 1>) -> Tensor<B, 1> {
        let x = position.clone().slice(s![0..1]);
        let y = position.slice(s![1..2]);
        let term_1 = (-x.clone()).add_scalar(self.a).powi_scalar(2);
        let term_2 = y.sub(x.powi_scalar(2)).powi_scalar(2).mul_scalar(self.b);
        -(term_1 + term_2)
    }
}

/// The n-dimensional Rosenbrock distribution, whose unnormalized log density is
/// `-sum_i [100 * (x_{i+1} - x_i^2)^2 + (1 - x_i)^2]` for `i = 1, ..., n-1`.
///
/// See: <https://arxiv.org/pdf/1903.09556>.
pub struct RosenbrockND {}

// For the batched version we need to implement BatchedGradientTarget.
impl<T, B> BatchedGradientTarget<T, B> for RosenbrockND
where
    T: Float + Element,
    B: AutodiffBackend,
{
    fn unnorm_logp_batch(&self, positions: Tensor<B, 2>) -> Tensor<B, 1> {
        let k = positions.dims()[0];
        let n = positions.dims()[1];
        let low = positions.clone().slice([0..k, 0..(n - 1)]);
        let high = positions.slice([0..k, 1..n]);
        let term_1 = (high - low.clone().powi_scalar(2))
            .powi_scalar(2)
            .mul_scalar(100);
        let term_2 = low.neg().add_scalar(1).powi_scalar(2);
        -(term_1 + term_2).sum_dim(1).squeeze(1)
    }
}

#[cfg(test)]
mod continuous_tests {
    use super::*;

    /**
    A helper function to normalize the unnormalized log probability of an isotropic Gaussian
    into a proper probability value (by applying the appropriate constant).

    # Arguments

    * `x` - The unnormalized log probability.
    * `d` - The dimensionality of the state.
    * `std` - The standard deviation used in the isotropic Gaussian.

    # Returns

    Returns the normalized probability as an `f64`.
    */
    fn normalize_isogauss(x: f64, d: usize, std: f64) -> f64 {
        // log_normalizer = -(d/2) * ln(2 * pi * std^2)
        let log_normalizer = -((d as f64) / 2.0) * ((2.0_f64).ln() + PI.ln() + 2.0 * std.ln());
        (x + log_normalizer).exp()
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_1() {
        let distr = IsotropicGaussian::new(1.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[1.0]), 1, distr.std);
        let true_p = 0.24197072451914337;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-7,
            "Expected diff < 1e-7, got {diff} with p={p} (expected ~{true_p})."
        );
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_2() {
        let distr = IsotropicGaussian::new(2.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[0.42, 9.6]), 2, distr.std);
        let true_p = 3.864661987252467e-7;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-15,
            "Expected diff < 1e-15, got {diff} with p={p} (expected ~{true_p})"
        );
    }

    #[test]
    fn iso_gauss_unnorm_logp_test_3() {
        let distr = IsotropicGaussian::new(3.0);
        let p = normalize_isogauss(distr.unnorm_logp(&[1.0, 2.0, 3.0]), 3, distr.std);
        let true_p = 0.001080393185560214;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-8,
            "Expected diff < 1e-8, got {diff} with p={p} (expected ~{true_p})"
        );
    }
}

#[cfg(test)]
mod categorical_tests {
    use super::*;

    /// A helper function to compare floating-point values with a given tolerance.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    // ------------------------------------------------------
    // 1) Test logp correctness for f64
    // ------------------------------------------------------
    #[test]
    fn test_categorical_logp_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let cat = Categorical::<f64>::new(probs.clone());

        // Check log probabilities for each index
        let logp_0 = cat.logp(0);
        let logp_1 = cat.logp(1);
        let logp_2 = cat.logp(2);

        // Expected values
        let expected_0 = 0.2_f64.ln();
        let expected_1 = 0.3_f64.ln();
        let expected_2 = 0.5_f64.ln();

        let tol = 1e-7;
        assert!(
            approx_eq(logp_0, expected_0, tol),
            "Log prob mismatch at index 0: got {}, expected {}",
            logp_0,
            expected_0
        );
        assert!(
            approx_eq(logp_1, expected_1, tol),
            "Log prob mismatch at index 1: got {}, expected {}",
            logp_1,
            expected_1
        );
        assert!(
            approx_eq(logp_2, expected_2, tol),
            "Log prob mismatch at index 2: got {}, expected {}",
            logp_2,
            expected_2
        );

        // Out-of-bounds index should be NEG_INFINITY
        let logp_out = cat.logp(3);
        assert_eq!(
            logp_out,
            f64::NEG_INFINITY,
            "Out-of-bounds index did not return NEG_INFINITY"
        );
    }

    // ------------------------------------------------------
    // 2) Test sampling frequencies for f64
    // ------------------------------------------------------
    #[test]
    fn test_categorical_sampling_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let mut cat = Categorical::<f64>::new(probs.clone());

        let sample_size = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        // Draw observations and tally outcomes
        for _ in 0..sample_size {
            let observation = cat.sample();
            counts[observation] += 1;
        }

        // Check empirical frequencies
        let tol = 0.01; // 1% absolute tolerance
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f64 / sample_size as f64;
            let expected = probs[i];
            assert!(
                approx_eq(freq, expected, tol),
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    // ------------------------------------------------------
    // 3) Test logp correctness for f32
    // ------------------------------------------------------
    #[test]
    fn test_categorical_logp_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let cat = Categorical::<f32>::new(probs.clone());

        let logp_0: f32 = cat.logp(0);
        let logp_1 = cat.logp(1);
        let logp_2 = cat.logp(2);

        // For comparison, cast to f64
        let expected_0 = (0.1_f64).ln();
        let expected_1 = (0.4_f64).ln();
        let expected_2 = (0.5_f64).ln();

        let tol = 1e-6;
        assert!(
            approx_eq(logp_0.into(), expected_0, tol),
            "Log prob mismatch at index 0 (f32 -> f64 cast)"
        );
        assert!(
            approx_eq(logp_1.into(), expected_1, tol),
            "Log prob mismatch at index 1"
        );
        assert!(
            approx_eq(logp_2.into(), expected_2, tol),
            "Log prob mismatch at index 2"
        );

        // Out-of-bounds
        let logp_out = cat.logp(3);
        assert_eq!(logp_out, f32::NEG_INFINITY);
    }

    // ------------------------------------------------------
    // 4) Test sampling frequencies for f32
    // ------------------------------------------------------
    #[test]
    fn test_categorical_sampling_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let mut cat = Categorical::<f32>::new(probs.clone());

        let sample_size = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        for _ in 0..sample_size {
            let observation = cat.sample();
            counts[observation] += 1;
        }

        // Compare frequencies with expected probabilities
        let tol = 0.02; // might relax tolerance for f32
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f32 / sample_size as f32;
            let expected = probs[i];
            assert!(
                (freq - expected).abs() < tol,
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    #[test]
    fn test_categorical_sample_single_value() {
        let mut cat = Categorical {
            probs: vec![1.0_f64],
            rng: rand::rngs::SmallRng::from_seed(Default::default()),
        };

        let sampled_index = cat.sample();

        assert_eq!(
            sampled_index, 0,
            "Should return the last index (0) for a single-element vector"
        );
    }

    #[test]
    fn test_target_for_categorical_in_range() {
        // Create a categorical distribution with known probabilities.
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs.clone());
        // Call unnorm_logp with a valid index (say, index 1).
        let logp = cat.unnorm_logp(&[1]);
        // The expected log probability is ln(0.3).
        let expected = 0.3_f64.ln();
        let tol = 1e-7;
        assert!(
            (logp - expected).abs() < tol,
            "For index 1, expected ln(0.3) ~ {}, got {}",
            expected,
            logp
        );
    }

    #[test]
    fn test_target_for_categorical_out_of_range() {
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs);
        // Calling unnorm_logp with an index that's out of bounds (e.g. 3)
        // should return negative infinity.
        let logp = cat.unnorm_logp(&[3]);
        assert_eq!(
            logp,
            f64::NEG_INFINITY,
            "Expected negative infinity for out-of-range index, got {}",
            logp
        );
    }

    #[test]
    fn test_gaussian2d_logp() {
        let mean = arr1(&[0.0, 0.0]);
        let cov = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
        let gauss = Gaussian2D { mean, cov };

        let position = vec![0.5, -0.5];
        let computed_logp = gauss.logp(&position);

        let expected_logp = -2.0878770664093453;

        let tol = 1e-10;
        assert!(
            (computed_logp - expected_logp).abs() < tol,
            "Computed log density ({}) differs from expected ({}) by more than tolerance ({})",
            computed_logp,
            expected_logp,
            tol
        );
    }
}