use ndarray::{arr1, arr2, Array1, Array2, NdFloat};
use num_traits::Float;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Normal};
use std::f64::consts::PI;
use std::ops::AddAssign;

/// A proposal distribution used to generate candidate states from the current state.
pub trait Proposal<T, F: Float> {
    /// Draws a new candidate state given the current state.
    fn sample(&mut self, current: &[T]) -> Vec<T>;

    /// Log density of proposing `to` when starting from `from`.
    fn log_prob(&self, from: &[T], to: &[T]) -> F;

    /// Returns the proposal with its internal RNG seeded for reproducibility.
    fn set_seed(self, seed: u64) -> Self;
}

/// A target density known only up to a normalizing constant.
pub trait Target<T, F: Float> {
    /// Unnormalized log density evaluated at `theta`.
    fn unnorm_log_prob(&self, theta: &[T]) -> F;
}

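/// Illustrative sketch only (an addition, not part of the original API): one
/// Metropolis-Hastings update showing how `Proposal` and `Target` are meant to
/// compose. The function name and the uniform draw via `gen_range` are
/// assumptions made for this example.
#[allow(dead_code)]
fn mh_step<T, F, P, Q, R>(
    target: &P,
    proposal: &mut Q,
    rng: &mut R,
    current: Vec<T>,
    current_log_prob: F,
) -> (Vec<T>, F)
where
    F: Float,
    P: Target<T, F>,
    Q: Proposal<T, F>,
    R: Rng,
{
    let candidate = proposal.sample(&current);
    let candidate_log_prob = target.unnorm_log_prob(&candidate);
    // Log acceptance ratio, including the correction for asymmetric proposals.
    let log_alpha = candidate_log_prob - current_log_prob
        + proposal.log_prob(&candidate, &current)
        - proposal.log_prob(&current, &candidate);
    let u: f64 = rng.gen_range(0.0..1.0);
    if F::from(u.ln()).unwrap() < log_alpha {
        (candidate, candidate_log_prob)
    } else {
        (current, current_log_prob)
    }
}
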
/// A fully normalized density, i.e. one whose `log_prob` includes the normalizing constant.
pub trait Normalized<T, F: Float> {
    /// Normalized log density evaluated at `theta`.
    fn log_prob(&self, theta: &[T]) -> F;
}

/// A discrete distribution over the indices `0..n`.
pub trait Discrete<T: Float> {
    /// Draws an index according to the distribution.
    fn sample(&mut self) -> usize;
    /// Log probability of the given index.
    fn log_prob(&self, index: usize) -> T;
}

/// A two-dimensional Gaussian with mean vector `mean` and 2x2 covariance matrix `cov`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Gaussian2D<T: Float> {
    pub mean: Array1<T>,
    pub cov: Array2<T>,
}

impl<T> Normalized<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    fn log_prob(&self, theta: &[T]) -> T {
        // Bivariate normal log density: -ln(2*pi) - 0.5*ln|det(cov)| - 0.5 * diff^T cov^{-1} diff.
        let term_1 = -(T::from(2.0).unwrap() * T::from(PI).unwrap()).ln();
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let half = T::from(0.5).unwrap();
        let term_2 = -half * det.abs().ln();

        let x = arr1(theta);
        let diff = x - self.mean.clone();
        // Explicit 2x2 inverse: adjugate divided by the determinant.
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        let term_3 = -half * diff.dot(&inv_cov).dot(&diff);
        term_1 + term_2 + term_3
    }
}

impl<T> Target<T, T> for Gaussian2D<T>
where
    T: NdFloat,
{
    fn unnorm_log_prob(&self, theta: &[T]) -> T {
        let (a, b, c, d) = (
            self.cov[(0, 0)],
            self.cov[(0, 1)],
            self.cov[(1, 0)],
            self.cov[(1, 1)],
        );
        let det = a * d - b * c;
        let x = arr1(theta);
        let diff = x - self.mean.clone();
        let inv_cov = arr2(&[[d, -b], [-c, a]]) / det;
        -T::from(0.5).unwrap() * diff.dot(&inv_cov).dot(&diff)
    }
}

/// An isotropic Gaussian: independent coordinates sharing a single standard deviation `std`.
/// Doubles as a random-walk proposal centered at the current state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IsotropicGaussian<T: Float> {
    pub std: T,
    rng: SmallRng,
}

impl<T: Float> IsotropicGaussian<T> {
    /// Creates an isotropic Gaussian with the given standard deviation and an entropy-seeded RNG.
    pub fn new(std: T) -> Self {
        Self {
            std,
            rng: SmallRng::from_entropy(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Proposal<T, T> for IsotropicGaussian<T>
where
    rand_distr::StandardNormal: rand_distr::Distribution<T>,
{
    fn sample(&mut self, current: &[T]) -> Vec<T> {
        let normal = Normal::new(T::zero(), self.std)
            .expect("Expecting creation of normal distribution to succeed.");
        // Random-walk proposal: add zero-mean Gaussian noise to each coordinate.
        normal
            .sample_iter(&mut self.rng)
            .zip(current)
            .map(|(noise, x)| *x + noise)
            .collect()
    }

    fn log_prob(&self, from: &[T], to: &[T]) -> T {
        let mut lp = T::zero();
        let d = T::from(from.len()).unwrap();
        let two = T::from(2).unwrap();
        let var = self.std * self.std;
        for (&f, &t) in from.iter().zip(to.iter()) {
            let diff = t - f;
            let exponent = -(diff * diff) / (two * var);
            lp += exponent;
        }
        // Gaussian log-normalizer: -(d/2) * ln(2 * pi * sigma^2).
        lp += -d * T::from(0.5).unwrap() * (two * T::from(PI).unwrap() * var).ln();
        lp
    }

    fn set_seed(mut self, seed: u64) -> Self {
        self.rng = SmallRng::seed_from_u64(seed);
        self
    }
}

impl<T: Float> Target<T, T> for IsotropicGaussian<T> {
    fn unnorm_log_prob(&self, theta: &[T]) -> T {
        let mut sum = T::zero();
        for &x in theta.iter() {
            sum = sum + x * x;
        }
        -T::from(0.5).unwrap() * sum / (self.std * self.std)
    }
}

/// A categorical distribution over the indices `0..probs.len()`, stored as normalized probabilities.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Categorical<T>
where
    T: Float + std::ops::AddAssign,
{
    pub probs: Vec<T>,
    rng: SmallRng,
}

impl<T: Float + std::ops::AddAssign> Categorical<T> {
    /// Creates a categorical distribution, normalizing `probs` so that the entries sum to one.
    pub fn new(probs: Vec<T>) -> Self {
        let sum: T = probs.iter().cloned().fold(T::zero(), |acc, x| acc + x);
        let normalized: Vec<T> = probs.into_iter().map(|p| p / sum).collect();
        Self {
            probs: normalized,
            rng: SmallRng::from_entropy(),
        }
    }
}

impl<T: Float + std::ops::AddAssign> Discrete<T> for Categorical<T>
where
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    fn sample(&mut self) -> usize {
        // Inverse-CDF sampling: draw r ~ Uniform(0, 1) and return the first index
        // whose cumulative probability reaches r.
        let r: T = self.rng.gen();
        let mut cum: T = T::zero();
        let mut k = self.probs.len() - 1;
        for (i, &p) in self.probs.iter().enumerate() {
            cum += p;
            if r <= cum {
                k = i;
                break;
            }
        }
        k
    }

    fn log_prob(&self, index: usize) -> T {
        if index < self.probs.len() {
            self.probs[index].ln()
        } else {
            T::neg_infinity()
        }
    }
}

impl<T: Float + AddAssign> Target<usize, T> for Categorical<T>
where
    rand_distr::Standard: rand_distr::Distribution<T>,
{
    fn unnorm_log_prob(&self, theta: &[usize]) -> T {
        // Treat the first element of `theta` as the category index.
        <Self as Discrete<T>>::log_prob(self, theta[0])
    }
}

/// A conditional distribution of one coordinate given the full current state
/// (e.g. for coordinate-wise, Gibbs-style updates).
pub trait Conditional<S> {
    /// Samples a new value for coordinate `index` given the current state `given`.
    fn sample(&mut self, index: usize, given: &[S]) -> S;
}
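
/// Illustrative sketch only (an addition, not part of the original file): a
/// `Conditional` implementation for a standard bivariate Gaussian with
/// correlation `rho`, using the fact that x_i | x_j ~ N(rho * x_j, 1 - rho^2).
/// Assumes a two-element state, `index` in {0, 1}, and `rho` in (-1, 1).
#[allow(dead_code)]
struct BivariateGaussianConditional {
    rho: f64,
    rng: SmallRng,
}

impl Conditional<f64> for BivariateGaussianConditional {
    fn sample(&mut self, index: usize, given: &[f64]) -> f64 {
        // Condition on the other coordinate of the two-dimensional state.
        let other = given[1 - index];
        let std = (1.0 - self.rho * self.rho).sqrt();
        let normal = Normal::new(self.rho * other, std)
            .expect("conditional standard deviation should be valid");
        normal.sample(&mut self.rng)
    }
}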

#[cfg(test)]
mod continuous_tests {
    use super::*;

    /// Converts an unnormalized isotropic-Gaussian log density into a normalized
    /// density value by adding the log-normalizer -(d/2) * ln(2 * pi * std^2)
    /// and exponentiating.
    fn normalize_isogauss(x: f64, d: usize, std: f64) -> f64 {
        let log_normalizer = -((d as f64) / 2.0) * ((2.0_f64).ln() + PI.ln() + 2.0 * std.ln());
        (x + log_normalizer).exp()
    }

    #[test]
    fn iso_gauss_unnorm_log_prob_test_1() {
        let distr = IsotropicGaussian::new(1.0);
        let p = normalize_isogauss(distr.unnorm_log_prob(&[1.0]), 1, distr.std);
        let true_p = 0.24197072451914337;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-7,
            "Expected diff < 1e-7, got {diff} with p={p} (expected ~{true_p})."
        );
    }

    #[test]
    fn iso_gauss_unnorm_log_prob_test_2() {
        let distr = IsotropicGaussian::new(2.0);
        let p = normalize_isogauss(distr.unnorm_log_prob(&[0.42, 9.6]), 2, distr.std);
        let true_p = 3.864661987252467e-7;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-15,
            "Expected diff < 1e-15, got {diff} with p={p} (expected ~{true_p})"
        );
    }

    #[test]
    fn iso_gauss_unnorm_log_prob_test_3() {
        let distr = IsotropicGaussian::new(3.0);
        let p = normalize_isogauss(distr.unnorm_log_prob(&[1.0, 2.0, 3.0]), 3, distr.std);
        let true_p = 0.001080393185560214;
        let diff = (p - true_p).abs();
        assert!(
            diff < 1e-8,
            "Expected diff < 1e-8, got {diff} with p={p} (expected ~{true_p})"
        );
    }
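
    // Illustrative addition (not in the original test suite): with the same seed,
    // `set_seed` should make the random-walk proposal's draws reproducible.
    #[test]
    fn iso_gauss_proposal_seeded_sampling_is_deterministic() {
        let mut a = IsotropicGaussian::new(1.0).set_seed(42);
        let mut b = IsotropicGaussian::new(1.0).set_seed(42);
        let current = [0.0_f64, 1.0, -1.0];
        assert_eq!(a.sample(&current), b.sample(&current));
    }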
}

#[cfg(test)]
mod categorical_tests {
    use super::*;

    /// Returns true if `a` and `b` differ by less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_categorical_log_prob_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let cat = Categorical::<f64>::new(probs.clone());

        let log_prob_0 = cat.log_prob(0);
        let log_prob_1 = cat.log_prob(1);
        let log_prob_2 = cat.log_prob(2);

        let expected_0 = 0.2_f64.ln();
        let expected_1 = 0.3_f64.ln();
        let expected_2 = 0.5_f64.ln();

        let tol = 1e-7;
        assert!(
            approx_eq(log_prob_0, expected_0, tol),
            "Log prob mismatch at index 0: got {}, expected {}",
            log_prob_0,
            expected_0
        );
        assert!(
            approx_eq(log_prob_1, expected_1, tol),
            "Log prob mismatch at index 1: got {}, expected {}",
            log_prob_1,
            expected_1
        );
        assert!(
            approx_eq(log_prob_2, expected_2, tol),
            "Log prob mismatch at index 2: got {}, expected {}",
            log_prob_2,
            expected_2
        );

        let log_prob_out = cat.log_prob(3);
        assert_eq!(
            log_prob_out,
            f64::NEG_INFINITY,
            "Out-of-bounds index did not return NEG_INFINITY"
        );
    }

    #[test]
    fn test_categorical_sampling_f64() {
        let probs = vec![0.2, 0.3, 0.5];
        let mut cat = Categorical::<f64>::new(probs.clone());

        let num_samples = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        for _ in 0..num_samples {
            let sample = cat.sample();
            counts[sample] += 1;
        }

        let tol = 0.01;
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f64 / num_samples as f64;
            let expected = probs[i];
            assert!(
                approx_eq(freq, expected, tol),
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    #[test]
    fn test_categorical_log_prob_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let cat = Categorical::<f32>::new(probs.clone());

        let log_prob_0: f32 = cat.log_prob(0);
        let log_prob_1 = cat.log_prob(1);
        let log_prob_2 = cat.log_prob(2);

        let expected_0 = (0.1_f64).ln();
        let expected_1 = (0.4_f64).ln();
        let expected_2 = (0.5_f64).ln();

        let tol = 1e-6;
        assert!(
            approx_eq(log_prob_0.into(), expected_0, tol),
            "Log prob mismatch at index 0 (f32 -> f64 cast)"
        );
        assert!(
            approx_eq(log_prob_1.into(), expected_1, tol),
            "Log prob mismatch at index 1"
        );
        assert!(
            approx_eq(log_prob_2.into(), expected_2, tol),
            "Log prob mismatch at index 2"
        );

        let log_prob_out = cat.log_prob(3);
        assert_eq!(log_prob_out, f32::NEG_INFINITY);
    }

    #[test]
    fn test_categorical_sampling_f32() {
        let probs = vec![0.1_f32, 0.4, 0.5];
        let mut cat = Categorical::<f32>::new(probs.clone());

        let num_samples = 100_000;
        let mut counts = vec![0_usize; probs.len()];

        for _ in 0..num_samples {
            let sample = cat.sample();
            counts[sample] += 1;
        }

        let tol = 0.02;
        for (i, &count) in counts.iter().enumerate() {
            let freq = count as f32 / num_samples as f32;
            let expected = probs[i];
            assert!(
                (freq - expected).abs() < tol,
                "Empirical freq for index {} is off: got {:.3}, expected {:.3}",
                i,
                freq,
                expected
            );
        }
    }

    #[test]
    fn test_categorical_sample_single_value() {
        let mut cat = Categorical {
            probs: vec![1.0_f64],
            rng: rand::rngs::SmallRng::from_seed(Default::default()),
        };

        let sampled_index = cat.sample();

        assert_eq!(
            sampled_index, 0,
            "Should return the only valid index (0) for a single-element vector"
        );
    }

    #[test]
    fn test_target_for_categorical_in_range() {
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs.clone());
        let logp = cat.unnorm_log_prob(&[1]);
        let expected = 0.3_f64.ln();
        let tol = 1e-7;
        assert!(
            (logp - expected).abs() < tol,
            "For index 1, expected ln(0.3) ~ {}, got {}",
            expected,
            logp
        );
    }

    #[test]
    fn test_target_for_categorical_out_of_range() {
        let probs = vec![0.2_f64, 0.3, 0.5];
        let cat = Categorical::new(probs);
        let logp = cat.unnorm_log_prob(&[3]);
        assert_eq!(
            logp,
            f64::NEG_INFINITY,
            "Expected negative infinity for out-of-range index, got {}",
            logp
        );
    }

    #[test]
    fn test_gaussian2d_log_prob() {
        let mean = arr1(&[0.0, 0.0]);
        let cov = arr2(&[[1.0, 0.0], [0.0, 1.0]]);
        let gauss = Gaussian2D { mean, cov };

        let theta = vec![0.5, -0.5];
        let computed_logp = gauss.log_prob(&theta);

        // Analytic value for the standard bivariate normal: -ln(2*pi) - 0.5 * (0.5^2 + 0.5^2).
        let expected_logp = -2.0878770664093453;

        let tol = 1e-10;
        assert!(
            (computed_logp - expected_logp).abs() < tol,
            "Computed log probability ({}) differs from expected ({}) by more than tolerance ({})",
            computed_logp,
            expected_logp,
            tol
        );
    }
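
    // Illustrative addition (not in the original test suite): the normalized and
    // unnormalized Gaussian2D log densities should differ only by the
    // theta-independent log-normalizing constant, so the gap is the same at any
    // two evaluation points.
    #[test]
    fn test_gaussian2d_unnorm_log_prob_differs_by_constant() {
        let gauss = Gaussian2D {
            mean: arr1(&[0.0, 0.0]),
            cov: arr2(&[[2.0, 0.3], [0.3, 1.0]]),
        };
        let c1 = gauss.log_prob(&[0.5, -0.5]) - gauss.unnorm_log_prob(&[0.5, -0.5]);
        let c2 = gauss.log_prob(&[1.5, 2.0]) - gauss.unnorm_log_prob(&[1.5, 2.0]);
        assert!(
            (c1 - c2).abs() < 1e-12,
            "Log-normalizer should not depend on theta: got {c1} vs {c2}"
        );
    }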
}