1use rand::Rng;
2use std::f64;
3
/// Final state of a Metropolis–Hastings (or slice) sampling run.
#[derive(Clone, Debug, PartialEq)]
pub struct MhResult<T> {
    /// The last state of the chain.
    pub x: T,
    /// The score of `x` as computed by the sampler that produced it
    /// (a log-likelihood, log-density, or log importance weight depending
    /// on the sampler).
    pub score_x: f64,
}
12
13impl<T> From<(T, f64)> for MhResult<T> {
14 fn from(tuple: (T, f64)) -> MhResult<T> {
15 MhResult {
16 x: tuple.0,
17 score_x: tuple.1,
18 }
19 }
20}
21
22pub fn mh_prior<T, F, D, R: Rng>(
31 x_start: T,
32 loglike: F,
33 prior_draw: D,
34 n_iters: usize,
35 rng: &mut R,
36) -> MhResult<T>
37where
38 F: Fn(&T) -> f64,
39 D: Fn(&mut R) -> T,
40{
41 let x = x_start;
42 let fx = loglike(&x);
43 (0..n_iters)
44 .fold((x, fx), |(x, fx), _| {
45 let y = prior_draw(rng);
46 let fy = loglike(&y);
47
48 assert!(fy.is_finite(), "Non finite proposal likelihood");
49
50 let r: f64 = rng.gen::<f64>();
51 if r.ln() < fy - fx {
52 (y, fy)
53 } else {
54 (x, fx)
55 }
56 })
57 .into()
58}
59
60pub fn mh_importance<T, Fx, Dq, Fq, R: Rng>(
73 x_start: T,
74 ln_f: Fx,
75 q_draw: Dq,
76 q_ln_f: Fq,
77 n_iters: usize,
78 rng: &mut R,
79) -> MhResult<T>
80where
81 Fx: Fn(&T) -> f64,
82 Dq: Fn(&mut R) -> T,
83 Fq: Fn(&T) -> f64,
84{
85 let x = x_start;
86 let fx = ln_f(&x) - q_ln_f(&x);
87 (0..n_iters)
88 .fold((x, fx), |(x, fx), _| {
89 let y = q_draw(rng);
90 let fy = ln_f(&y) - q_ln_f(&y);
91
92 assert!(fy.is_finite(), "Non finite proposal likelihood");
93
94 let r: f64 = rng.gen::<f64>();
95 if r.ln() < fy - fx {
96 (y, fy)
97 } else {
98 (x, fx)
99 }
100 })
101 .into()
102}
103
104pub fn mh_symrw<T, F, Q, R>(
114 x_start: T,
115 score_fn: F,
116 walk_fn: Q,
117 n_iters: usize,
118 rng: &mut R,
119) -> MhResult<T>
120where
121 F: Fn(&T) -> f64,
122 Q: Fn(&T, &mut R) -> T,
123 R: Rng,
124{
125 let score_x = score_fn(&x_start);
126 let x = x_start;
127 (0..n_iters)
128 .fold((x, score_x), |(x, fx), _| {
129 let y = walk_fn(&x, rng);
130 let fy = score_fn(&y);
131
132 assert!(fy.is_finite(), "Non finite proposal likelihood");
133
134 let r: f64 = rng.gen::<f64>();
135 if r.ln() < fy - fx {
136 (y, fy)
137 } else {
138 (x, fx)
139 }
140 })
141 .into()
142}
143
/// Finds a slice interval around `x` by stepping out from an initial window.
///
/// The initial window has width `step_size` and is placed so that `x` sits at
/// fraction `r` from its left edge. `r` is expected in `[0, 1)`; the slice
/// step passes `rng.gen::<f64>()`. Each side is then pushed outward until
/// `score_fn` drops below `ln_height` or the side crosses its bound. The step
/// doubles after every expansion. A side that crosses its bound is clamped to
/// that bound. The bound check happens *before* scoring, so `score_fn` is
/// never evaluated outside `bounds`.
///
/// # Returns
/// `(x_left, x_right)`, the interval to sample from.
///
/// # Panics
/// If either side fails to terminate within its expansion limit
/// (`step_limit + 1` expansions).
fn slice_stepping_out<F>(
    ln_height: f64,
    x: f64,
    step_size: f64,
    score_fn: &F,
    r: f64,
    bounds: (f64, f64),
) -> (f64, f64)
where
    F: Fn(f64) -> f64,
{
    let step_limit = 15_usize;

    let x_left = {
        let mut x_left = r.mul_add(-step_size, x);
        let mut loop_counter: usize = 0;
        let mut step = step_size;
        loop {
            // Check the support first so `score_fn` is only ever evaluated
            // inside `bounds` (it may be undefined or panic outside).
            if x_left < bounds.0 {
                break bounds.0;
            }
            let ln_fx_left = score_fn(x_left);
            if ln_fx_left < ln_height {
                break x_left;
            }

            x_left -= step;
            step *= 2.0;

            if loop_counter == step_limit {
                panic!(
                    "x_left step ({}/{}) limit ({}) hit. x = {}, height = {}, fx = {}",
                    step_size, step, step_limit, x, ln_height, ln_fx_left,
                )
            }
            loop_counter += 1;
        }
    };

    let x_right = {
        let mut x_right = (1.0 - r).mul_add(step_size, x);
        let mut loop_counter: usize = 0;
        let mut step = step_size;
        loop {
            // Same ordering as the left side: bound check before scoring.
            if x_right > bounds.1 {
                break bounds.1;
            }
            let ln_fx_right = score_fn(x_right);
            if ln_fx_right < ln_height {
                break x_right;
            }

            x_right += step;
            step *= 2.0;

            if loop_counter == step_limit {
                panic!("x_right step limit ({}) hit", step_limit)
            }
            loop_counter += 1;
        }
    };

    (x_left, x_right)
}
206
207fn mh_slice_step<F, R>(
208 x_start: f64,
209 step_size: f64,
210 score_fn: &F,
211 bounds: (f64, f64),
212 mut rng: &mut R,
213) -> MhResult<f64>
214where
215 F: Fn(f64) -> f64,
216 R: Rng,
217{
218 use crate::rv::dist::Uniform;
219 use crate::rv::traits::Rv;
220
221 let ln_fx = score_fn(x_start);
222 let ln_u = rng.gen::<f64>().ln() + ln_fx;
223 let (mut x_left, mut x_right) = slice_stepping_out(
224 ln_u,
225 x_start,
226 step_size,
227 &score_fn,
228 rng.gen::<f64>(),
229 bounds,
230 );
231
232 let step_limit = 50;
233 let mut loop_counter = 0;
234 loop {
235 let x: f64 = Uniform::new_unchecked(x_left, x_right).draw(&mut rng);
236 let ln_fx = score_fn(x);
237 if ln_fx > ln_u {
239 break MhResult { x, score_x: ln_fx };
240 }
241
242 if loop_counter == step_limit {
243 panic!("Slice interval tuning limit ({}) hit", step_limit)
244 }
245
246 if x > x_start {
247 x_right = x;
248 } else {
249 x_left = x;
250 };
251
252 loop_counter += 1;
253 }
254}
255
256pub fn mh_slice<F, R>(
264 x_start: f64,
265 step_size: f64,
266 n_iters: usize,
267 score_fn: F,
268 bounds: (f64, f64),
269 mut rng: &mut R,
270) -> MhResult<f64>
271where
272 F: Fn(f64) -> f64,
273 R: Rng,
274{
275 (0..n_iters).fold(
276 mh_slice_step(x_start, step_size, &score_fn, bounds, &mut rng),
277 |acc, _| mh_slice_step(acc.x, step_size, &score_fn, bounds, &mut rng),
278 )
279}
280
281pub fn mh_symrw_adaptive<F, R>(
282 x_start: f64,
283 mut mu_guess: f64,
284 mut var_guess: f64,
285 n_steps: usize,
286 score_fn: F,
287 bounds: (f64, f64),
288 mut rng: &mut R,
289) -> MhResult<f64>
290where
291 F: Fn(f64) -> f64,
292 R: Rng,
293{
294 use crate::rv::dist::Gaussian;
295 use crate::rv::traits::Rv;
296
297 let gamma_init = 0.9;
299
300 let mut x = x_start;
301 let mut fx = score_fn(x);
302 let mut x_sum = x;
303 let lambda: f64 = 2.38 * 2.38;
304
305 for n in 0..n_steps {
306 let y: f64 = Gaussian::new_unchecked(x, (lambda * var_guess).sqrt())
307 .draw(&mut rng);
308 if bounds.0 < x || x < bounds.1 {
309 let fy = score_fn(y);
310
311 assert!(fy.is_finite(), "Non finite proposal likelihood");
312
313 if rng.gen::<f64>().ln() < fy - fx {
314 x = y;
315 fx = fy;
316 }
317 }
318 x_sum += x;
319 let x_bar = x_sum / (n + 1) as f64;
320 let gamma = gamma_init / (n + 1) as f64;
321 let mu_next = (x_bar - mu_guess).mul_add(gamma, mu_guess);
322 var_guess = (x - mu_guess)
323 .mul_add(x - mu_guess, -var_guess)
324 .mul_add(gamma, var_guess);
325 mu_guess = mu_next;
326 }
327
328 MhResult { x, score_x: fx }
331}
332
333use crate::mat::{MeanVector, ScaleMatrix, SquareT};
334use std::ops::Mul;
335
336pub fn mh_symrw_adaptive_mv<F, R, M, S>(
345 x_start: M,
346 mut mu_guess: M,
347 mut var_guess: S,
348 n_steps: usize,
349 score_fn: F,
350 bounds: &[(f64, f64)],
351 mut rng: &mut R,
352) -> MhResult<Vec<f64>>
353where
354 F: Fn(&[f64]) -> f64,
355 R: Rng,
356 M: MeanVector + SquareT<Output = S> + Mul<f64, Output = M>,
357 S: ScaleMatrix + Mul<f64, Output = S>,
358{
359 use crate::rv::dist::MvGaussian;
360 use crate::rv::nalgebra::{DMatrix, DVector};
361 use crate::rv::traits::Rv;
362
363 let gamma = 0.5;
366
367 let mut x = x_start;
368 let mut fx = score_fn(x.values());
369 let mut x_sum = M::zeros().mv_add(&x);
370 let mut ln_lambda: f64 = (2.38 * 2.38 / x.len() as f64).ln();
371
372 let n_rows = x.len();
373
374 for n in 0..n_steps {
375 var_guess.diagonalize();
376 let cov = DMatrix::from_row_slice(n_rows, n_rows, var_guess.values());
377 let mu = DVector::from_row_slice(x.values());
378
379 let y: DVector<f64> =
380 MvGaussian::new_unchecked(mu, ln_lambda.exp() * cov).draw(&mut rng);
381 let y = M::from_dvector(y);
382
383 let in_bounds = y
384 .values()
385 .iter()
386 .zip(bounds.iter())
387 .all(|(&y_i, bounds_i)| bounds_i.0 < y_i && y_i < bounds_i.1);
388
389 let alpha = if in_bounds {
390 let fy = score_fn(y.values());
391
392 assert!(fy.is_finite(), "Non finite proposal likelihood");
393
394 let ln_alpha = (fy - fx).min(0.0);
395 if rng.gen::<f64>().ln() < ln_alpha {
396 x = y;
397 fx = fy;
398 }
399 ln_alpha.exp()
400 } else {
401 0.0
402 };
403
404 x_sum = x_sum.mv_add(&x);
405 ln_lambda += gamma * (alpha - 0.234);
406
407 let x_bar = M::zeros().mv_add(&x_sum) * (n as f64 + 1.0).recip();
408 let mu_next = (x_bar.mv_sub(&mu_guess) * gamma).mv_add(&mu_guess);
409 var_guess = (x.clone().mv_sub(&mu_guess).square_t().mv_sub(&var_guess)
410 * gamma)
411 .mv_add(&var_guess);
412 mu_guess = mu_next;
413 }
414
415 MhResult {
416 x: Vec::from(x.values()),
417 score_x: fx,
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rv::dist::{Bernoulli, Beta, Gaussian};
    use crate::rv::misc::ks_test;
    use crate::rv::traits::{Cdf, Rv};
    use rand_distr::Normal;

    /// A single trial passes when its p-value exceeds this threshold.
    const KS_PVAL: f64 = 0.2;
    /// Number of independent trials per stochastic test. A test passes if
    /// at least one trial passes.
    const N_FLAKY_TEST: usize = 10;

    /// Runs `trial` `N_FLAKY_TEST` times and counts the trials whose
    /// returned p-value exceeds `KS_PVAL`.
    fn count_passes<P>(mut trial: P) -> usize
    where
        P: FnMut() -> f64,
    {
        (0..N_FLAKY_TEST).filter(|_| trial() > KS_PVAL).count()
    }

    /// Iterates `mh_fn` `n_steps` times from `x_start` and returns every
    /// state it produced (the start state itself is not included).
    fn mh_chain<F, X, R>(
        x_start: X,
        mh_fn: F,
        n_steps: usize,
        rng: &mut R,
    ) -> Vec<X>
    where
        X: Clone,
        F: Fn(&X, &mut R) -> X,
        R: Rng,
    {
        let mut state = x_start;
        (0..n_steps)
            .map(|_| {
                state = mh_fn(&state, &mut *rng);
                state.clone()
            })
            .collect()
    }

    /// Conjugate posterior `(mean, variance)` of the mean of a Gaussian with
    /// known std. dev. `sigma`, under a `N(m0, s0^2)` prior and data `xs`.
    fn known_var_mean_posterior(
        xs: &[f64],
        sigma: f64,
        m0: f64,
        s0: f64,
    ) -> (f64, f64) {
        let nf = xs.len() as f64;
        let sum_x = xs.iter().sum::<f64>();
        let s2 = sigma * sigma;
        let s02 = s0 * s0;
        let sn = ((nf / s2) + s02.recip()).recip();
        let mn = sn * (m0 / s02 + sum_x / s2);
        (mn, sn)
    }

    #[test]
    fn test_mh_prior_uniform() {
        let loglike = |_x: &f64| 0.0;
        fn prior_draw<R: Rng>(r: &mut R) -> f64 {
            r.gen()
        }

        let mut rng = rand::thread_rng();
        let n_passes = count_passes(|| {
            let samples = mh_chain(
                0.5,
                |&x, mut rng| mh_prior(x, loglike, prior_draw, 1, &mut rng).x,
                500,
                &mut rng,
            );
            ks_test(&samples, |x| x).1
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_prior_gaussian() {
        let gauss = Gaussian::standard();
        let loglike = |_x: &f64| 0.0;
        fn prior_draw<R: Rng>(r: &mut R) -> f64 {
            let norm = Normal::new(0.0, 1.0).unwrap();
            r.sample(norm)
        }

        let mut rng = rand::thread_rng();
        let n_passes = count_passes(|| {
            let samples = mh_chain(
                0.5,
                |&x, mut rng| mh_prior(x, loglike, prior_draw, 1, &mut rng).x,
                500,
                &mut rng,
            );
            ks_test(&samples, |x| gauss.cdf(&x)).1
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_importance_beta() {
        let xs: Vec<u8> = vec![0, 0, 1, 1, 1, 1];
        let prior = Beta::new(2.0, 2.0).unwrap();

        // Unnormalized log posterior: Bernoulli likelihood + Beta(2, 2) prior.
        let ln_fn = |theta: &f64| {
            let likelihood = Bernoulli::new(*theta).unwrap();
            let f: f64 = xs.iter().map(|x| likelihood.ln_f(x)).sum();
            f + prior.ln_f(theta)
        };

        fn q_draw<R: Rng>(mut rng: &mut R) -> f64 {
            let q = Beta::new(2.0, 1.0).unwrap();
            q.draw(&mut rng)
        }

        fn q_ln_f(theta: &f64) -> f64 {
            let q = Beta::new(2.0, 1.0).unwrap();
            q.ln_f(theta)
        }

        // Conjugate posterior: Beta(2 + #ones, 2 + #zeros).
        let true_posterior = Beta::new(2.0 + 4.0, 2.0 + 2.0).unwrap();

        let mut rng = rand::thread_rng();
        let n_passes = count_passes(|| {
            let samples = mh_chain(
                0.5,
                |&x, mut rng| {
                    mh_importance(x, ln_fn, q_draw, q_ln_f, 2, &mut rng).x
                },
                250,
                &mut rng,
            );
            ks_test(&samples, |x| true_posterior.cdf(&x)).1
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_symrw_uniform() {
        let score_fn = |_x: &f64| 0.0;
        // Gaussian step wrapped onto [0, 1), which is symmetric on the circle.
        fn walk_fn<R: Rng>(x: &f64, r: &mut R) -> f64 {
            let norm = Normal::new(*x, 0.2).unwrap();

            r.sample(norm).rem_euclid(1.0)
        }

        let mut rng = rand::thread_rng();
        let n_passes = count_passes(|| {
            let samples = mh_chain(
                0.5,
                |&x, mut rng| mh_symrw(x, score_fn, walk_fn, 1, &mut rng).x,
                500,
                &mut rng,
            );
            ks_test(&samples, |x| x).1
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_symrw_gaussian() {
        let gauss = Gaussian::new(1.0, 1.5).unwrap();

        let score_fn = |x: &f64| gauss.ln_f(x);
        fn walk_fn<R: Rng>(x: &f64, r: &mut R) -> f64 {
            let norm = Normal::new(*x, 0.5).unwrap();
            r.sample(norm)
        }

        let mut rng = rand::thread_rng();
        let n_passes = count_passes(|| {
            let samples = mh_chain(
                1.0,
                |&x, mut rng| mh_symrw(x, score_fn, walk_fn, 10, &mut rng).x,
                250,
                &mut rng,
            );
            ks_test(&samples, |x| gauss.cdf(&x)).1
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_slice_uniform() {
        let loglike = |x: f64| {
            if 0.0 < x && x < 1.0 {
                0.0
            } else {
                std::f64::NEG_INFINITY
            }
        };

        let mut rng = rand::thread_rng();
        let n_passes = count_passes(|| {
            let samples = mh_chain(
                0.5,
                |&x, mut rng| {
                    mh_slice(x, 0.2, 1, loglike, (0.0, 1.0), &mut rng).x
                },
                500,
                &mut rng,
            );
            ks_test(&samples, |x| x).1
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_slice_gaussian() {
        use std::f64::{INFINITY, NEG_INFINITY};

        let gauss = Gaussian::new(1.0, 1.5).unwrap();

        let score_fn = |x: f64| gauss.ln_f(&x);

        let mut rng = rand::thread_rng();
        let n_passes = count_passes(|| {
            let samples = mh_chain(
                1.0,
                |&x, mut rng| {
                    mh_slice(
                        x,
                        1.0,
                        1,
                        score_fn,
                        (NEG_INFINITY, INFINITY),
                        &mut rng,
                    )
                    .x
                },
                250,
                &mut rng,
            );
            ks_test(&samples, |x| gauss.cdf(&x)).1
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_adaptive_gaussian() {
        use std::f64::{INFINITY, NEG_INFINITY};

        let gauss = Gaussian::new(1.0, 1.5).unwrap();

        let score_fn = |x: f64| gauss.ln_f(&x);

        let mut rng = rand::thread_rng();
        let n_passes = count_passes(|| {
            let samples = mh_chain(
                1.0,
                |&x, mut rng| {
                    mh_symrw_adaptive(
                        x,
                        0.1,
                        0.1,
                        100,
                        score_fn,
                        (NEG_INFINITY, INFINITY),
                        &mut rng,
                    )
                    .x
                },
                250,
                &mut rng,
            );
            ks_test(&samples, |x| gauss.cdf(&x)).1
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_adaptive_normal_gamma() {
        use std::f64::{INFINITY, NEG_INFINITY};

        let mut rng = rand::thread_rng();
        let sigma: f64 = 1.5;
        let m0: f64 = 0.0;
        let s0: f64 = 0.5;
        let gauss = Gaussian::new(1.0, sigma).unwrap();
        let prior = Gaussian::new(m0, s0).unwrap();

        let xs: Vec<f64> = gauss.sample(20, &mut rng);

        let score_fn = |mu: f64| {
            let g = Gaussian::new_unchecked(mu, sigma);
            let fx: f64 = xs.iter().map(|x| g.ln_f(x)).sum();
            fx + prior.ln_f(&mu)
        };

        let (mn, sn) = known_var_mean_posterior(&xs, sigma, m0, s0);
        let posterior = Gaussian::new(mn, sn.sqrt()).unwrap();

        let n_passes = count_passes(|| {
            let ys = mh_chain(
                1.0,
                |&x, mut rng| {
                    mh_symrw_adaptive(
                        x,
                        0.1,
                        0.1,
                        100,
                        score_fn,
                        (NEG_INFINITY, INFINITY),
                        &mut rng,
                    )
                    .x
                },
                250,
                &mut rng,
            );
            ks_test(&ys, |y| posterior.cdf(&y)).1
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_mv_normal_gamma_known_var() {
        let mut rng = rand::thread_rng();
        let sigma: f64 = 1.5;
        let m0: f64 = 0.0;
        let s0: f64 = 0.5;
        let gauss = Gaussian::new(1.0, sigma).unwrap();
        let prior = Gaussian::new(m0, s0).unwrap();

        let xs: Vec<f64> = gauss.sample(20, &mut rng);

        let score_fn = |mu: &f64| {
            let g = Gaussian::new_unchecked(*mu, sigma);
            let fx: f64 = xs.iter().map(|x| g.ln_f(x)).sum();
            fx + prior.ln_f(mu)
        };

        fn walk_fn<R: Rng>(x: &f64, mut r: &mut R) -> f64 {
            Gaussian::new_unchecked(*x, 0.2).draw(&mut r)
        }

        let (mn, sn) = known_var_mean_posterior(&xs, sigma, m0, s0);
        println!("Posterior mean: {}", mn);
        let posterior = Gaussian::new(mn, sn.sqrt()).unwrap();

        let n_passes = count_passes(|| {
            let ys = mh_chain(
                1.0,
                |x, mut rng| mh_symrw(*x, score_fn, walk_fn, 50, &mut rng).x,
                250,
                &mut rng,
            );
            let p = ks_test(&ys, |y| posterior.cdf(&y)).1;
            println!("p: {}, m: {}", p, lace_utils::mean(&ys));
            p
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_adaptive_mv_normal_gamma_known_var() {
        use crate::mat::Matrix1x1;
        use std::f64::{INFINITY, NEG_INFINITY};

        let mut rng = rand::thread_rng();
        let sigma: f64 = 1.5;
        let m0: f64 = 0.0;
        let s0: f64 = 0.5;
        let gauss = Gaussian::new(1.0, sigma).unwrap();
        let prior = Gaussian::new(m0, s0).unwrap();

        let xs: Vec<f64> = gauss.sample(20, &mut rng);

        let score_fn = |mu: &[f64]| {
            let g = Gaussian::new_unchecked(mu[0], sigma);
            let fx: f64 = xs.iter().map(|x| g.ln_f(x)).sum();
            fx + prior.ln_f(&mu[0])
        };

        let (mn, sn) = known_var_mean_posterior(&xs, sigma, m0, s0);
        println!("Posterior mean: {}", mn);
        let posterior = Gaussian::new(mn, sn.sqrt()).unwrap();

        let n_passes = count_passes(|| {
            let ys = mh_chain(
                1.0,
                |x, mut rng| {
                    mh_symrw_adaptive_mv(
                        Matrix1x1([*x]),
                        Matrix1x1([0.1]),
                        Matrix1x1([0.1]),
                        10,
                        score_fn,
                        &[(NEG_INFINITY, INFINITY)],
                        &mut rng,
                    )
                    .x[0]
                },
                250,
                &mut rng,
            );
            let p = ks_test(&ys, |y| posterior.cdf(&y)).1;
            println!("p: {}, m: {}", p, lace_utils::mean(&ys));
            p
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_adaptive_mv_normal_gamma_unknown() {
        use crate::mat::{Matrix2x2, Vector2};
        use crate::rv::dist::InvGamma;
        use crate::test::gauss_perm_test;
        use std::f64::{INFINITY, NEG_INFINITY};

        /// Sample mean and variance of the `mu` component of `(mu, var)` pairs.
        fn mu_mean_var(samples: &[(f64, f64)]) -> (f64, f64) {
            use lace_utils::{mean, var};
            let mus: Vec<f64> = samples.iter().map(|xy| xy.0).collect();
            (mean(&mus), var(&mus))
        }

        let n = 20;
        let n_samples = 250;

        let mut rng = rand::thread_rng();

        // Normal-inverse-gamma prior hyperparameters.
        let m0: f64 = 0.0;
        let v0: f64 = 0.5;
        let a0: f64 = 1.5;
        let b0: f64 = 1.0;

        let gauss = Gaussian::new(1.0, 1.5).unwrap();

        let prior_var = InvGamma::new(a0, b0).unwrap();

        let xs: Vec<f64> = gauss.sample(n, &mut rng);
        let sum_x = xs.iter().sum::<f64>();
        let sum_x_sq = xs.iter().map(|&x| x * x).sum::<f64>();

        println!("Mean(x): {}", sum_x / n as f64);

        // Joint log posterior over (mu, var).
        let score_fn = |mu_var: &[f64]| {
            let mu = mu_var[0];
            let var = mu_var[1];
            let sigma = var.sqrt();
            let g = Gaussian::new_unchecked(mu, sigma);
            let fx: f64 = xs.iter().map(|x| g.ln_f(x)).sum();
            let prior_mu = Gaussian::new(m0, v0.sqrt() * sigma).unwrap();
            fx + prior_mu.ln_f(&mu) + prior_var.ln_f(&var)
        };

        // Exact draws from the conjugate normal-inverse-gamma posterior.
        let posterior_samples: Vec<(f64, f64)> = {
            let nf = n as f64;

            let v0_inv = v0.recip();
            let vn_inv = v0_inv + nf;
            let mn_over_vn = v0_inv.mul_add(m0, sum_x);
            let mn = mn_over_vn * vn_inv.recip();
            let an = a0 + nf / 2.0;
            let bn = 0.5_f64.mul_add(
                (mn * mn).mul_add(-vn_inv, (m0 * m0).mul_add(v0_inv, sum_x_sq)),
                b0,
            );
            let vn_sqrt = vn_inv.recip().sqrt();

            let post_var = InvGamma::new(an, bn).unwrap();
            (0..n_samples)
                .map(|_| {
                    let var: f64 = post_var.draw(&mut rng);
                    let mu: f64 = Gaussian::new(mn, vn_sqrt * var.sqrt())
                        .unwrap()
                        .draw(&mut rng);
                    (mu, var)
                })
                .collect()
        };

        let (mean_mu, var_mu) = mu_mean_var(&posterior_samples);
        println!("Posterior Mean/Var: {}/{}", mean_mu, var_mu);

        let n_passes = count_passes(|| {
            let mcmc_samples: Vec<(f64, f64)> = mh_chain(
                (1.0, 1.0),
                |x, mut rng| {
                    let x = mh_symrw_adaptive_mv(
                        Vector2([x.0, x.1]),
                        Vector2([0.1, 1.0]),
                        Matrix2x2::from_diag([1.0, 1.0]),
                        100,
                        score_fn,
                        &[(NEG_INFINITY, INFINITY), (0.0, INFINITY)],
                        &mut rng,
                    )
                    .x;
                    (x[0], x[1])
                },
                n_samples,
                &mut rng,
            );

            let (mean_mu_mh, var_mu_mh) = mu_mean_var(&mcmc_samples);

            let p = gauss_perm_test(
                posterior_samples.clone(),
                mcmc_samples,
                500,
                &mut rng,
            );

            println!("p: {}", p);
            println!("MCMC Mean/Var: {}/{}", mean_mu_mh, var_mu_mh);
            p
        });

        assert!(n_passes > 0);
    }
}