mini_mcmc/
distributions.rs

/*!
Traits for defining continuous and discrete proposal and target distributions.
Includes implementations of commonly used distributions.

This module is generic over the floating-point precision (e.g. `f32` or `f64`)
using [`num_traits::Float`]. It also defines several traits:
- [`Target`] for densities we want to sample from,
- [`Proposal`] for proposal mechanisms,
- [`Normalized`] for distributions that can compute a fully normalized log-prob,
- [`Discrete`] for distributions over finite sets.

# Examples

### Continuous Distributions
```rust
use mini_mcmc::distributions::{
    Gaussian2D, IsotropicGaussian, Proposal,
    Target, Normalized
};
use ndarray::{arr1, arr2};

// ----------------------
// Example: Gaussian2D (2D with full covariance)
// ----------------------
let mean = arr1(&[0.0, 0.0]);
let cov = arr2(&[[1.0, 0.0],
                 [0.0, 1.0]]);
let gauss: Gaussian2D<f64> = Gaussian2D { mean, cov };

// Compute the fully normalized log-prob at (0.5, -0.5):
let logp = gauss.log_prob(&[0.5, -0.5]);
println!("Normalized log-probability (2D Gaussian): {}", logp);

// ----------------------
// Example: IsotropicGaussian (any dimension)
// ----------------------
let mut proposal: IsotropicGaussian<f64> = IsotropicGaussian::new(1.0);
let current = vec![0.0, 0.0];  // dimension = 2 in this example
let candidate = proposal.sample(&current);
println!("Candidate state: {:?}", candidate);
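
// Re-seeding makes proposals reproducible. A small illustrative addition
// using `Proposal::set_seed`, which consumes the distribution and returns
// a re-seeded instance:
let mut seeded = proposal.set_seed(42);
let reproducible = seeded.sample(&current);
println!("Reproducible candidate: {:?}", reproducible);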
```
*/

use ndarray::{arr1, arr2, Array1, Array2, NdFloat};
use num_traits::Float;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Normal};
use std::f64::consts::PI;
use std::ops::AddAssign;

/// A trait for generating proposals for Metropolis–Hastings-like algorithms.
/// The state type `T` is typically a vector of continuous values.
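///
/// # Example
///
/// A minimal sketch (illustrative only, not part of the crate) of implementing
/// this trait: a deterministic "identity" proposal that always proposes the
/// current state.
///
/// ```rust
/// use mini_mcmc::distributions::Proposal;
///
/// struct Identity;
///
/// impl Proposal<f64, f64> for Identity {
///     fn sample(&mut self, current: &[f64]) -> Vec<f64> {
///         current.to_vec() // q(x' | x) is a point mass at x
///     }
///     fn log_prob(&self, _from: &[f64], _to: &[f64]) -> f64 {
///         0.0 // ln(1): the only reachable state has probability one
///     }
///     fn set_seed(self, _seed: u64) -> Self {
///         self // deterministic, so the seed is ignored
///     }
/// }
///
/// let mut proposal = Identity;
/// assert_eq!(proposal.sample(&[1.0, 2.0]), vec![1.0, 2.0]);
/// ```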
pub trait Proposal<T, F: Float> {
    /// Samples a new point from q(x' | x).
    fn sample(&mut self, current: &[T]) -> Vec<T>;

    /// Evaluates log q(x' | x).
    fn log_prob(&self, from: &[T], to: &[T]) -> F;

    /// Returns a new instance of this proposal distribution seeded with `seed`.
    fn set_seed(self, seed: u64) -> Self;
}

/// A trait for continuous target distributions from which we want to sample.
/// The state type `T` is typically a vector of continuous values.
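///
/// # Example
///
/// A minimal sketch (illustrative only) of a custom target: a standard normal
/// in any dimension, known only up to its normalizing constant.
///
/// ```rust
/// use mini_mcmc::distributions::Target;
///
/// struct StdNormal;
///
/// impl Target<f64, f64> for StdNormal {
///     fn unnorm_log_prob(&self, theta: &[f64]) -> f64 {
///         // -0.5 * ||theta||^2; the -(d/2) * ln(2 * pi) constant is dropped.
///         -0.5 * theta.iter().map(|x| x * x).sum::<f64>()
///     }
/// }
///
/// let target = StdNormal;
/// assert_eq!(target.unnorm_log_prob(&[0.0, 0.0]), 0.0);
/// ```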
pub trait Target<T, F: Float> {
    /// Returns the log of the unnormalized density for state `theta`.
    fn unnorm_log_prob(&self, theta: &[T]) -> F;
}

/// A trait for distributions that provide a normalized log-density (e.g. for diagnostics).
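///
/// # Example
///
/// [`Gaussian2D`] (defined below) implements this trait; for a standard 2D
/// Gaussian the normalized log-density at the mean is `-ln(2π)`:
///
/// ```rust
/// use mini_mcmc::distributions::{Gaussian2D, Normalized};
/// use ndarray::{arr1, arr2};
///
/// let gauss = Gaussian2D {
///     mean: arr1(&[0.0, 0.0]),
///     cov: arr2(&[[1.0, 0.0], [0.0, 1.0]]),
/// };
/// let lp = gauss.log_prob(&[0.0, 0.0]);
/// assert!((lp + (2.0 * std::f64::consts::PI).ln()).abs() < 1e-12);
/// ```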
pub trait Normalized<T, F: Float> {
    /// Returns the normalized log-density for state `theta`.
    fn log_prob(&self, theta: &[T]) -> F;
}

/** A trait for discrete distributions whose state is represented as an index.
 ```rust
 use mini_mcmc::distributions::{Categorical, Discrete};

 // Create a categorical distribution over three categories.
 let mut cat = Categorical::new(vec![0.2f64, 0.3, 0.5]);
 let sample = cat.sample();
 println!("Sampled category: {}", sample); // E.g. 1usize

 let logp = cat.log_prob(sample);
 println!("Log-probability of sampled category: {}", logp); // E.g. ln(0.3) ≈ -1.204
 ```
*/
pub trait Discrete<T: Float> {
    /// Samples an index from the distribution.
    fn sample(&mut self) -> usize;
    /// Evaluates the log-probability of the given index.
    fn log_prob(&self, index: usize) -> T;
}

/**
A 2D Gaussian distribution parameterized by a mean vector and a 2×2 covariance matrix.

- The generic type `T` is typically `f32` or `f64`.
- Implements both [`Target`] (for unnormalized log-prob) and
  [`Normalized`] (for fully normalized log-prob).

# Example

```rust
use mini_mcmc::distributions::{Gaussian2D, Normalized};
use ndarray::{arr1, arr2};

let mean = arr1(&[0.0, 0.0]);
let cov = arr2(&[[1.0, 0.0],
                 [0.0, 1.0]]);
let gauss: Gaussian2D<f64> = Gaussian2D { mean, cov };

let lp = gauss.log_prob(&[0.5, -0.5]);
println!("Normalized log probability: {}", lp);
```
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Gaussian2D<T: Float> {
    pub mean: Array1<T>,
    pub cov: Array2<T>,
}

impl<T> Normalized<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    /// Computes the fully normalized log-density of a 2D Gaussian.
    fn log_prob(&self, theta: &[T]) -> T {
        let term_1 = -(T::from(2.0).unwrap() * T::from(PI).unwrap()).ln();
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let half = T::from(0.5).unwrap();
        let term_2 = -half * det.abs().ln();

        let x = arr1(theta);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        let term_3 = -half * diff.dot(&inv_cov).dot(&diff);
        term_1 + term_2 + term_3
    }
}

impl<T> Target<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    fn unnorm_log_prob(&self, theta: &[T]) -> T {
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let x = arr1(theta);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        -T::from(0.5).unwrap() * diff.dot(&inv_cov).dot(&diff)
    }
}

/**
An *isotropic* Gaussian distribution usable as either a target or a proposal
in MCMC. It works for **any dimension** because it applies independent
Gaussian noise (`mean = 0`, `std = self.std`) to each coordinate.

- Implements [`Proposal`] so it can propose new states
  from a current state.
- Also implements [`Target`] for an unnormalized log-prob,
  which might be useful if you want to treat it as a target distribution
  in simplified scenarios.

# Examples

```rust
use mini_mcmc::distributions::{IsotropicGaussian, Proposal};

let mut proposal: IsotropicGaussian<f64> = IsotropicGaussian::new(1.0);
let current = vec![0.0, 0.0, 0.0]; // dimension = 3
let candidate = proposal.sample(&current);
println!("Candidate state: {:?}", candidate);

// Evaluate log q(candidate | current):
let logq = proposal.log_prob(&current, &candidate);
println!("Log of the proposal density: {}", logq);
```
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IsotropicGaussian<T: Float> {
    pub std: T,
    rng: SmallRng,
}

impl<T: Float> IsotropicGaussian<T> {
    /// Creates a new isotropic Gaussian proposal distribution with the specified standard deviation.
    pub fn new(std: T) -> Self {
        Self {
            std,
            rng: SmallRng::from_entropy(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Proposal<T, T> for IsotropicGaussian<T>
where
    rand_distr::StandardNormal: rand_distr::Distribution<T>,
{
    fn sample(&mut self, current: &[T]) -> Vec<T> {
        let normal = Normal::new(T::zero(), self.std)
            .expect("Expecting creation of normal distribution to succeed.");
        normal
            .sample_iter(&mut self.rng)
            .zip(current)
            .map(|(noise, coord)| *coord + noise)
            .collect()
    }

    fn log_prob(&self, from: &[T], to: &[T]) -> T {
        let mut lp = T::zero();
        let d = T::from(from.len()).unwrap();
        let two = T::from(2).unwrap();
        let var = self.std * self.std;
        for (&f, &t) in from.iter().zip(to.iter()) {
            let diff = t - f;
            let exponent = -(diff * diff) / (two * var);
            lp += exponent;
        }
        // Normalization constant: -(d/2) * ln(2 * pi * sigma^2).
        lp += -d * T::from(0.5).unwrap() * (two * T::from(PI).unwrap() * var).ln();
        lp
    }

    fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }
}

impl<T: Float> Target<T, T> for IsotropicGaussian<T> {
    fn unnorm_log_prob(&self, theta: &[T]) -> T {
        let mut sum = T::zero();
        for &x in theta.iter() {
            sum = sum + x * x
        }
        -T::from(0.5).unwrap() * sum / (self.std * self.std)
    }
}

/**
A categorical distribution represents a discrete probability distribution over a finite set of categories.

The probabilities in `probs` should sum to 1 (or they will be normalized automatically).

# Examples

```rust
use mini_mcmc::distributions::{Categorical, Discrete};

let mut cat = Categorical::new(vec![0.2f64, 0.3, 0.5]);
let sample = cat.sample();
println!("Sampled category: {}", sample);
let logp = cat.log_prob(sample);
println!("Log probability of category {}: {}", sample, logp);
```
*/
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Categorical<T>
where
    T: Float + std::ops::AddAssign,
{
    pub probs: Vec<T>,
    rng: SmallRng,
}

impl<T: Float + std::ops::AddAssign> Categorical<T> {
    /// Creates a new categorical distribution from a vector of probabilities.
    /// The probabilities will be normalized so that they sum to 1.
    pub fn new(probs: Vec<T>) -> Self {
        let sum: T = probs.iter().cloned().fold(T::zero(), |acc, x| acc + x);
        let normalized: Vec<T> = probs.into_iter().map(|p| p / sum).collect();
        Self {
            probs: normalized,
            rng: SmallRng::from_entropy(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Discrete<T> for Categorical<T>
where
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    fn sample(&mut self) -> usize {
        let r: T = self.rng.gen();
        let mut cum: T = T::zero();
        // Default to the last index so that floating-point round-off in the
        // cumulative sum can never leave the loop without a selection.
        let mut k = self.probs.len() - 1;
        for (i, &p) in self.probs.iter().enumerate() {
            cum += p;
            if r <= cum {
                k = i;
                break;
            }
        }
        k
    }

    fn log_prob(&self, index: usize) -> T {
        if index < self.probs.len() {
            self.probs[index].ln()
        } else {
            T::neg_infinity()
        }
    }
}

impl<T: Float + AddAssign> Target<usize, T> for Categorical<T>
where
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    fn unnorm_log_prob(&self, theta: &[usize]) -> T {
        <Self as Discrete<T>>::log_prob(self, theta[0])
    }
}

/**
A trait for conditional distributions.

This trait specifies how to sample a single coordinate of a state given the entire current state.
It is primarily used in Gibbs sampling to update one coordinate at a time.
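
# Example

A minimal sketch (illustrative only, not part of the crate): a "conditional"
that ignores the other coordinates and draws each coordinate uniformly from
[0, 1). A real Gibbs conditional would instead use `given` to condition.

```rust
use mini_mcmc::distributions::Conditional;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

struct IndependentUniform {
    rng: SmallRng,
}

impl Conditional<f64> for IndependentUniform {
    fn sample(&mut self, _index: usize, _given: &[f64]) -> f64 {
        self.rng.gen() // draw uniformly from [0, 1)
    }
}

let mut cond = IndependentUniform { rng: SmallRng::seed_from_u64(42) };
let state = vec![0.3, 0.7];
let new_first = cond.sample(0, &state);
assert!((0.0..1.0).contains(&new_first));
```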
*/
pub trait Conditional<S> {
    /// Samples a new value for coordinate `index` given the full current state `given`.
    fn sample(&mut self, index: usize, given: &[S]) -> S;
}

#[cfg(test)]
mod continuous_tests {
    use super::*;

    /**
    A helper function to normalize the unnormalized log probability of an isotropic Gaussian
    into a proper probability value (by applying the appropriate constant).

    # Arguments

    * `x` - The unnormalized log probability.
    * `d` - The dimensionality of the state.
    * `std` - The standard deviation used in the isotropic Gaussian.

    # Returns

    Returns the normalized probability as an `f64`.
    */
    fn normalize_isogauss(x: f64, d: usize, std: f64) -> f64 {
        let log_normalizer = -((d as f64) / 2.0) * ((2.0_f64).ln() + PI.ln() + 2.0 * std.ln());
        (x + log_normalizer).exp()
    }

    #[test]
    fn iso_gauss_unnorm_log_prob_test_1() {
        let distr = IsotropicGaussian::new(1.0);
        let p = normalize_isogauss(distr.unnorm_log_prob(&[1.0]), 1, distr.std);
        let true_p = 0.24197072451914337;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-7,
            "Expected diff < 1e-7, got {diff} with p={p} (expected ~{true_p})."
        );
    }

    #[test]
    fn iso_gauss_unnorm_log_prob_test_2() {
        let distr = IsotropicGaussian::new(2.0);
        let p = normalize_isogauss(distr.unnorm_log_prob(&[0.42, 9.6]), 2, distr.std);
        let true_p = 3.864661987252467e-7;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-15,
            "Expected diff < 1e-15, got {diff} with p={p} (expected ~{true_p})"
        );
    }

    #[test]
    fn iso_gauss_unnorm_log_prob_test_3() {
        let distr = IsotropicGaussian::new(3.0);
        let p = normalize_isogauss(distr.unnorm_log_prob(&[1.0, 2.0, 3.0]), 3, distr.std);
        let true_p = 0.001080393185560214;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-8,
            "Expected diff < 1e-8, got {diff} with p={p} (expected ~{true_p})"
        );
    }
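
    /// Checks the proposal density's normalization: for a 1D standard-normal
    /// proposal, log q(x' | x) should equal the closed-form normal log-density
    /// of the increment x' - x.
    #[test]
    fn iso_gauss_proposal_log_prob_matches_closed_form() {
        let proposal = IsotropicGaussian::new(1.0);
        let logq = proposal.log_prob(&[0.0], &[1.0]);
        // Closed form: -0.5 * diff^2 / var - 0.5 * ln(2 * pi * var),
        // with diff = 1 and var = 1.
        let expected = -0.5 - 0.5 * (2.0 * PI).ln();
        let diff = (logq - expected).abs();
        assert!(
            diff < 1e-12,
            "Expected |logq - expected| < 1e-12, got {diff} (logq={logq}, expected={expected})"
        );
    }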
}

#[cfg(test)]
mod categorical_tests {
    use super::*;

    /// A helper function to compare floating-point values with a given tolerance.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    // ------------------------------------------------------
    // 1) Test log_prob correctness for f64
    // ------------------------------------------------------
    #[test]
    fn test_categorical_log_prob_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let cat = Categorical::<f64>::new(probs.clone());

        // Check log probabilities for each index
        let log_prob_0 = cat.log_prob(0);
        let log_prob_1 = cat.log_prob(1);
        let log_prob_2 = cat.log_prob(2);

        // Expected values
        let expected_0 = 0.2_f64.ln();
        let expected_1 = 0.3_f64.ln();
        let expected_2 = 0.5_f64.ln();

        let tol = 1e-7;
        assert!(
            approx_eq(log_prob_0, expected_0, tol),
            "Log prob mismatch at index 0: got {}, expected {}",
            log_prob_0,
            expected_0
        );
        assert!(
            approx_eq(log_prob_1, expected_1, tol),
            "Log prob mismatch at index 1: got {}, expected {}",
            log_prob_1,
            expected_1
        );
        assert!(
            approx_eq(log_prob_2, expected_2, tol),
            "Log prob mismatch at index 2: got {}, expected {}",
            log_prob_2,
            expected_2
        );

        // Out-of-bounds index should be NEG_INFINITY
        let log_prob_out = cat.log_prob(3);
        assert_eq!(
            log_prob_out,
            f64::NEG_INFINITY,
            "Out-of-bounds index did not return NEG_INFINITY"
        );
    }

    // ------------------------------------------------------
    // 2) Test sampling frequencies for f64
    // ------------------------------------------------------
    #[test]
    fn test_categorical_sampling_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let mut cat = Categorical::<f64>::new(probs.clone());

        let num_samples = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        // Draw samples and tally outcomes
        for _ in 0..num_samples {
            let sample = cat.sample();
            counts[sample] += 1;
        }

        // Check empirical frequencies
        let tol = 0.01; // 1% absolute tolerance
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f64 / num_samples as f64;
            let expected = probs[i];
            assert!(
                approx_eq(freq, expected, tol),
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    // ------------------------------------------------------
    // 3) Test log_prob correctness for f32
    // ------------------------------------------------------
    #[test]
    fn test_categorical_log_prob_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let cat = Categorical::<f32>::new(probs.clone());

        let log_prob_0: f32 = cat.log_prob(0);
        let log_prob_1 = cat.log_prob(1);
        let log_prob_2 = cat.log_prob(2);

        // For comparison, cast to f64
        let expected_0 = (0.1_f64).ln();
        let expected_1 = (0.4_f64).ln();
        let expected_2 = (0.5_f64).ln();

        let tol = 1e-6;
        assert!(
            approx_eq(log_prob_0.into(), expected_0, tol),
            "Log prob mismatch at index 0 (f32 -> f64 cast)"
        );
        assert!(
            approx_eq(log_prob_1.into(), expected_1, tol),
            "Log prob mismatch at index 1"
        );
        assert!(
            approx_eq(log_prob_2.into(), expected_2, tol),
            "Log prob mismatch at index 2"
        );

        // Out-of-bounds
        let log_prob_out = cat.log_prob(3);
        assert_eq!(log_prob_out, f32::NEG_INFINITY);
    }

    // ------------------------------------------------------
    // 4) Test sampling frequencies for f32
    // ------------------------------------------------------
    #[test]
    fn test_categorical_sampling_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let mut cat = Categorical::<f32>::new(probs.clone());

        let num_samples = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        for _ in 0..num_samples {
            let sample = cat.sample();
            counts[sample] += 1;
        }

        // Compare frequencies with expected probabilities
        let tol = 0.02; // slightly looser tolerance for f32
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f32 / num_samples as f32;
            let expected = probs[i];
            assert!(
                (freq - expected).abs() < tol,
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    #[test]
    fn test_categorical_sample_single_value() {
        let mut cat = Categorical {
            probs: vec![1.0_f64],
            rng: rand::rngs::SmallRng::from_seed(Default::default()),
        };

        let sampled_index = cat.sample();

        assert_eq!(
            sampled_index, 0,
            "Sampling a single-category distribution should always return index 0"
        );
    }

    #[test]
    fn test_target_for_categorical_in_range() {
        // Create a categorical distribution with known probabilities.
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs.clone());
        // Call unnorm_log_prob with a valid index (say, index 1).
        let logp = cat.unnorm_log_prob(&[1]);
        // The expected log probability is ln(0.3).
        let expected = 0.3_f64.ln();
        let tol = 1e-7;
        assert!(
            (logp - expected).abs() < tol,
            "For index 1, expected ln(0.3) ~ {}, got {}",
            expected,
            logp
        );
    }

    #[test]
    fn test_target_for_categorical_out_of_range() {
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs);
        // Calling unnorm_log_prob with an index that's out of bounds (e.g. 3)
        // should return negative infinity.
        let logp = cat.unnorm_log_prob(&[3]);
        assert_eq!(
            logp,
            f64::NEG_INFINITY,
            "Expected negative infinity for out-of-range index, got {}",
            logp
        );
    }

    #[test]
    fn test_gaussian2d_log_prob() {
        let mean = arr1(&[0.0, 0.0]);
        let cov = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
        let gauss = Gaussian2D { mean, cov };

        let theta = vec![0.5, -0.5];
        let computed_logp = gauss.log_prob(&theta);

        // For a standard 2D Gaussian: -ln(2 * pi) - 0.5 * (0.5^2 + 0.5^2).
        let expected_logp = -2.0878770664093453;

        let tol = 1e-10;
        assert!(
            (computed_logp - expected_logp).abs() < tol,
            "Computed log probability ({}) differs from expected ({}) by more than tolerance ({})",
            computed_logp,
            expected_logp,
            tol
        );
    }
}
624}