lace_stats/
mh.rs
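//! Metropolis-Hastings (MH) and slice samplers for drawing from univariate
//! and low-dimensional posterior distributions.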

use rand::Rng;
use std::f64;

/// Information from the last step of a Metropolis-Hastings (MH) update
pub struct MhResult<T> {
    /// The final value of the Markov chain
    pub x: T,
    /// The final score of `x`. What this score represents depends on the type
    /// of sampler being used.
    pub score_x: f64,
}

impl<T> From<(T, f64)> for MhResult<T> {
    fn from(tuple: (T, f64)) -> MhResult<T> {
        MhResult {
            x: tuple.0,
            score_x: tuple.1,
        }
    }
}

/// Draw posterior samples from f(x|y) ∝ f(y|x)π(x) by taking proposals from
/// the prior
///
/// # Arguments
/// - x_start: the starting value
/// - loglike: the likelihood function, f(y|x)
/// - prior_draw: the draw function of the prior on `x`
/// - n_iters: the number of MH steps
/// - rng: the random number generator
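///
/// # Example
///
/// A minimal sketch (hypothetical usage, not from the original docs): a flat
/// log-likelihood with Uniform(0, 1) proposals from the prior, so the chain
/// targets the prior itself.
///
/// ```ignore
/// use rand::Rng;
///
/// let mut rng = rand::thread_rng();
/// let res = mh_prior(
///     0.5_f64,
///     |_x: &f64| 0.0,                                 // flat log-likelihood
///     |r: &mut rand::rngs::ThreadRng| r.gen::<f64>(), // draw from the prior
///     100,
///     &mut rng,
/// );
/// assert!(0.0 <= res.x && res.x <= 1.0);
/// ```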
pub fn mh_prior<T, F, D, R: Rng>(
    x_start: T,
    loglike: F,
    prior_draw: D,
    n_iters: usize,
    rng: &mut R,
) -> MhResult<T>
where
    F: Fn(&T) -> f64,
    D: Fn(&mut R) -> T,
{
    let x = x_start;
    let fx = loglike(&x);
    (0..n_iters)
        .fold((x, fx), |(x, fx), _| {
            let y = prior_draw(rng);
            let fy = loglike(&y);

            assert!(fy.is_finite(), "Non-finite proposal likelihood");

            // Because proposals come from the prior, the prior terms cancel
            // and the acceptance probability reduces to the likelihood ratio,
            // min(1, f(y|x') / f(y|x)), computed in log space.
            let r: f64 = rng.gen::<f64>();
            if r.ln() < fy - fx {
                (y, fy)
            } else {
                (x, fx)
            }
        })
        .into()
}

// TODO: rename this to mc_importance, because importance sampling is a Monte
// Carlo technique, not a Metropolis-Hastings one
/// Draw posterior samples from f(x|y) ∝ f(y|x)π(x) by taking proposals from a
/// static importance distribution, Q.
///
/// # Arguments
/// - x_start: the starting value
/// - ln_f: the log of a function proportional to the posterior
/// - q_draw: function that takes an Rng and draws from Q
/// - q_ln_f: function that evaluates the log density of Q at x
/// - n_iters: the number of MH iterations
/// - rng: the random number generator
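///
/// # Example
///
/// A minimal sketch (hypothetical usage, not from the original docs):
/// targeting a Beta(6, 4) posterior with Beta(2, 1) proposals, mirroring the
/// unit test below.
///
/// ```ignore
/// use crate::rv::dist::Beta;
/// use crate::rv::traits::Rv;
///
/// let post = Beta::new(6.0, 4.0).unwrap();
/// let q = Beta::new(2.0, 1.0).unwrap();
///
/// let mut rng = rand::thread_rng();
/// let res = mh_importance(
///     0.5_f64,
///     |x: &f64| post.ln_f(x),                    // log posterior, up to a constant
///     |r: &mut rand::rngs::ThreadRng| q.draw(r), // draw from Q
///     |x: &f64| q.ln_f(x),                       // log density of Q
///     10,
///     &mut rng,
/// );
/// assert!(0.0 < res.x && res.x < 1.0);
/// ```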
pub fn mh_importance<T, Fx, Dq, Fq, R: Rng>(
    x_start: T,
    ln_f: Fx,
    q_draw: Dq,
    q_ln_f: Fq,
    n_iters: usize,
    rng: &mut R,
) -> MhResult<T>
where
    Fx: Fn(&T) -> f64,
    Dq: Fn(&mut R) -> T,
    Fq: Fn(&T) -> f64,
{
    let x = x_start;
    // The score is the log importance weight: ln f(x) - ln q(x)
    let fx = ln_f(&x) - q_ln_f(&x);
    (0..n_iters)
        .fold((x, fx), |(x, fx), _| {
            let y = q_draw(rng);
            let fy = ln_f(&y) - q_ln_f(&y);

            assert!(fy.is_finite(), "Non-finite proposal likelihood");

            let r: f64 = rng.gen::<f64>();
            if r.ln() < fy - fx {
                (y, fy)
            } else {
                (x, fx)
            }
        })
        .into()
}

/// Symmetric random walk MCMC
///
/// # Arguments
/// - x_start: the starting value
/// - score_fn: the log score function. For Bayesian inference: ln[f(x|θ)π(θ)]
/// - walk_fn: a symmetric transition function, q(x -> x') = q(x' -> x). Should
///   enforce the domain bounds.
/// - n_iters: the number of MH steps
/// - rng: the random number generator
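///
/// # Example
///
/// A minimal sketch (hypothetical usage, not from the original docs): a
/// standard normal target with Gaussian random-walk proposals.
///
/// ```ignore
/// use crate::rv::dist::Gaussian;
/// use crate::rv::traits::Rv;
///
/// let gauss = Gaussian::standard();
/// let mut rng = rand::thread_rng();
/// let res = mh_symrw(
///     0.0_f64,
///     |x: &f64| gauss.ln_f(x), // log score: the target log density
///     |x: &f64, r: &mut rand::rngs::ThreadRng| {
///         // symmetric proposal: x' ~ N(x, 0.5²)
///         Gaussian::new_unchecked(*x, 0.5).draw(r)
///     },
///     20,
///     &mut rng,
/// );
/// assert!(res.score_x.is_finite());
/// ```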
pub fn mh_symrw<T, F, Q, R>(
    x_start: T,
    score_fn: F,
    walk_fn: Q,
    n_iters: usize,
    rng: &mut R,
) -> MhResult<T>
where
    F: Fn(&T) -> f64,
    Q: Fn(&T, &mut R) -> T,
    R: Rng,
{
    let score_x = score_fn(&x_start);
    let x = x_start;
    (0..n_iters)
        .fold((x, score_x), |(x, fx), _| {
            let y = walk_fn(&x, rng);
            let fy = score_fn(&y);

            assert!(fy.is_finite(), "Non-finite proposal likelihood");

            let r: f64 = rng.gen::<f64>();
            if r.ln() < fy - fx {
                (y, fy)
            } else {
                (x, fx)
            }
        })
        .into()
}

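/// Find the slice interval (x_left, x_right) by stepping out from `x` with
/// geometrically growing steps until each endpoint falls below the slice
/// height `ln_height` or hits `bounds`. `r` is a Uniform(0, 1) draw that
/// positions `x` within the initial interval of width `step_size`.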
fn slice_stepping_out<F>(
    ln_height: f64,
    x: f64,
    step_size: f64,
    score_fn: &F,
    r: f64,
    bounds: (f64, f64),
) -> (f64, f64)
where
    F: Fn(f64) -> f64,
{
    let step_limit = 15_usize;

    let x_left = {
        let mut x_left = r.mul_add(-step_size, x);
        let mut loop_counter: usize = 0;
        let mut step = step_size;
        loop {
            let ln_fx_left = score_fn(x_left);
            if x_left < bounds.0 {
                break bounds.0;
            } else if ln_fx_left < ln_height {
                break x_left;
            }

            x_left -= step;
            step *= 2.0;

            if loop_counter == step_limit {
                panic!(
                    "x_left step ({}/{}) limit ({}) hit. x = {}, height = {}, fx = {}",
                    step_size, step, step_limit, x, ln_height, ln_fx_left,
                )
            }
            loop_counter += 1;
        }
    };

    let x_right = {
        let mut x_right = (1.0 - r).mul_add(step_size, x);
        let mut loop_counter: usize = 0;
        let mut step = step_size;
        loop {
            let ln_fx_right = score_fn(x_right);
            if x_right > bounds.1 {
                break bounds.1;
            } else if ln_fx_right < ln_height {
                break x_right;
            }

            x_right += step;
            step *= 2.0;

            if loop_counter == step_limit {
                panic!("x_right step limit ({}) hit", step_limit)
            }
            loop_counter += 1;
        }
    };

    (x_left, x_right)
}

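/// One slice-sampling update: draw a slice height under the density at
/// `x_start`, step out to find the slice interval, then sample uniformly from
/// the interval, shrinking it toward `x_start` on rejection.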
fn mh_slice_step<F, R>(
    x_start: f64,
    step_size: f64,
    score_fn: &F,
    bounds: (f64, f64),
    mut rng: &mut R,
) -> MhResult<f64>
where
    F: Fn(f64) -> f64,
    R: Rng,
{
    use crate::rv::dist::Uniform;
    use crate::rv::traits::Rv;

    let ln_fx = score_fn(x_start);
    // Slice height: ln u + ln f(x), with u ~ Uniform(0, 1)
    let ln_u = rng.gen::<f64>().ln() + ln_fx;
    let (mut x_left, mut x_right) = slice_stepping_out(
        ln_u,
        x_start,
        step_size,
        score_fn,
        rng.gen::<f64>(),
        bounds,
    );

    let step_limit = 50;
    let mut loop_counter = 0;
    loop {
        let x: f64 = Uniform::new_unchecked(x_left, x_right).draw(&mut rng);
        let ln_fx = score_fn(x);
        if ln_fx > ln_u {
            break MhResult { x, score_x: ln_fx };
        }

        if loop_counter == step_limit {
            panic!("Slice interval tuning limit ({}) hit", step_limit)
        }

        if x > x_start {
            x_right = x;
        } else {
            x_left = x;
        };

        loop_counter += 1;
    }
}

/// Uses a slice sampler with the stepping-out method to draw from a
/// univariate posterior distribution.
///
/// # Notes
/// Under some circumstances, the stepping out will hit the max iterations and
/// cause a panic. You might want to stay away from this sampler if you don't
/// know that your posterior is well behaved.
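///
/// # Example
///
/// A minimal sketch (hypothetical usage, not from the original docs): slice
/// sampling a standard normal with unbounded support.
///
/// ```ignore
/// use crate::rv::dist::Gaussian;
/// use crate::rv::traits::Rv;
///
/// let gauss = Gaussian::standard();
/// let mut rng = rand::thread_rng();
/// let res = mh_slice(
///     0.0,                     // x_start
///     1.0,                     // step_size
///     5,                       // n_iters
///     |x: f64| gauss.ln_f(&x), // log score
///     (std::f64::NEG_INFINITY, std::f64::INFINITY),
///     &mut rng,
/// );
/// assert!(res.score_x.is_finite());
/// ```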
pub fn mh_slice<F, R>(
    x_start: f64,
    step_size: f64,
    n_iters: usize,
    score_fn: F,
    bounds: (f64, f64),
    mut rng: &mut R,
) -> MhResult<f64>
where
    F: Fn(f64) -> f64,
    R: Rng,
{
    (0..n_iters).fold(
        mh_slice_step(x_start, step_size, &score_fn, bounds, &mut rng),
        |acc, _| mh_slice_step(acc.x, step_size, &score_fn, bounds, &mut rng),
    )
}

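/// Scalar symmetric random walk MH with an adaptive proposal variance.
///
/// Maintains running estimates of the target mean and variance with a
/// decaying adaptation rate, and proposes from a Gaussian whose variance is
/// the current estimate scaled by 2.38².
///
/// # Arguments
/// - x_start: the starting value
/// - mu_guess: an initial guess at the posterior mean
/// - var_guess: an initial guess at the posterior variance
/// - n_steps: the number of MH steps
/// - score_fn: the log score function
/// - bounds: the domain bounds on `x`
/// - rng: the random number generator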
pub fn mh_symrw_adaptive<F, R>(
    x_start: f64,
    mut mu_guess: f64,
    mut var_guess: f64,
    n_steps: usize,
    score_fn: F,
    bounds: (f64, f64),
    mut rng: &mut R,
) -> MhResult<f64>
where
    F: Fn(f64) -> f64,
    R: Rng,
{
    use crate::rv::dist::Gaussian;
    use crate::rv::traits::Rv;

    // FIXME: initialize this properly
    let gamma_init = 0.9;

    let mut x = x_start;
    let mut fx = score_fn(x);
    let mut x_sum = x;
    // 2.38² is the standard asymptotically optimal random-walk scaling
    let lambda: f64 = 2.38 * 2.38;

    for n in 0..n_steps {
        let y: f64 = Gaussian::new_unchecked(x, (lambda * var_guess).sqrt())
            .draw(&mut rng);
        // Only consider proposals inside the domain bounds
        if bounds.0 < y && y < bounds.1 {
            let fy = score_fn(y);

            assert!(fy.is_finite(), "Non-finite proposal likelihood");

            if rng.gen::<f64>().ln() < fy - fx {
                x = y;
                fx = fy;
            }
        }
        x_sum += x;
        let x_bar = x_sum / (n + 1) as f64;
        let gamma = gamma_init / (n + 1) as f64;
        let mu_next = (x_bar - mu_guess).mul_add(gamma, mu_guess);
        var_guess = (x - mu_guess)
            .mul_add(x - mu_guess, -var_guess)
            .mul_add(gamma, var_guess);
        mu_guess = mu_next;
    }

    MhResult { x, score_x: fx }
}

use crate::mat::{MeanVector, ScaleMatrix, SquareT};
use std::ops::Mul;

/// Multivariate adaptive Metropolis-Hastings sampler using globally adaptive
/// symmetric random walk.
///
/// # Notes
///
/// This sampler is slow and unstable due to all the matrix math, and often does
/// not achieve the correct stationary distribution, but most often achieves the
/// correct posterior mean -- if you care about that.
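///
/// # Example
///
/// A minimal sketch (hypothetical usage, not from the original docs),
/// mirroring the unit test below: a 1-dimensional chain expressed through the
/// crate's `Matrix1x1` wrapper.
///
/// ```ignore
/// use crate::mat::Matrix1x1;
///
/// let mut rng = rand::thread_rng();
/// let res = mh_symrw_adaptive_mv(
///     Matrix1x1([0.5]),               // x_start
///     Matrix1x1([0.0]),               // initial mean guess
///     Matrix1x1([1.0]),               // initial variance guess
///     20,
///     |x: &[f64]| -x[0] * x[0] / 2.0, // standard normal log score
///     &[(std::f64::NEG_INFINITY, std::f64::INFINITY)],
///     &mut rng,
/// );
/// let x: Vec<f64> = res.x;
/// ```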
pub fn mh_symrw_adaptive_mv<F, R, M, S>(
    x_start: M,
    mut mu_guess: M,
    mut var_guess: S,
    n_steps: usize,
    score_fn: F,
    bounds: &[(f64, f64)],
    mut rng: &mut R,
) -> MhResult<Vec<f64>>
where
    F: Fn(&[f64]) -> f64,
    R: Rng,
    M: MeanVector + SquareT<Output = S> + Mul<f64, Output = M>,
    S: ScaleMatrix + Mul<f64, Output = S>,
{
    use crate::rv::dist::MvGaussian;
    use crate::rv::nalgebra::{DMatrix, DVector};
    use crate::rv::traits::Rv;

    // TODO: initialize this properly
    // let gamma = (n_steps as f64).recip();
    let gamma = 0.5;

    let mut x = x_start;
    let mut fx = score_fn(x.values());
    let mut x_sum = M::zeros().mv_add(&x);
    // 2.38²/d is the standard optimal random-walk scaling in d dimensions
    let mut ln_lambda: f64 = (2.38 * 2.38 / x.len() as f64).ln();

    let n_rows = x.len();

    for n in 0..n_steps {
        var_guess.diagonalize();
        let cov = DMatrix::from_row_slice(n_rows, n_rows, var_guess.values());
        let mu = DVector::from_row_slice(x.values());

        let y: DVector<f64> =
            MvGaussian::new_unchecked(mu, ln_lambda.exp() * cov).draw(&mut rng);
        let y = M::from_dvector(y);

        let in_bounds = y
            .values()
            .iter()
            .zip(bounds.iter())
            .all(|(&y_i, bounds_i)| bounds_i.0 < y_i && y_i < bounds_i.1);

        let alpha = if in_bounds {
            let fy = score_fn(y.values());

            assert!(fy.is_finite(), "Non-finite proposal likelihood");

            let ln_alpha = (fy - fx).min(0.0);
            if rng.gen::<f64>().ln() < ln_alpha {
                x = y;
                fx = fy;
            }
            ln_alpha.exp()
        } else {
            0.0
        };

        x_sum = x_sum.mv_add(&x);
        // Adapt the global scale toward the 0.234 target acceptance rate
        ln_lambda += gamma * (alpha - 0.234);

        let x_bar = M::zeros().mv_add(&x_sum) * (n as f64 + 1.0).recip();
        let mu_next = (x_bar.mv_sub(&mu_guess) * gamma).mv_add(&mu_guess);
        var_guess = (x.clone().mv_sub(&mu_guess).square_t().mv_sub(&var_guess)
            * gamma)
            .mv_add(&var_guess);
        mu_guess = mu_next;
    }

    MhResult {
        x: Vec::from(x.values()),
        score_x: fx,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::rv::dist::{Bernoulli, Beta, Gaussian};
    use crate::rv::misc::ks_test;
    use crate::rv::traits::{Cdf, Rv};
    use rand_distr::Normal;

    const KS_PVAL: f64 = 0.2;
    const N_FLAKY_TEST: usize = 10;

    fn mh_chain<F, X, R>(
        x_start: X,
        mh_fn: F,
        n_steps: usize,
        rng: &mut R,
    ) -> Vec<X>
    where
        X: Clone,
        F: Fn(&X, &mut R) -> X,
        R: Rng,
    {
        let mut x = x_start;
        let mut samples: Vec<X> = Vec::with_capacity(n_steps);
        for _ in 0..n_steps {
            let y = mh_fn(&x, rng);
            samples.push(y.clone());
            x = y;
        }

        samples
    }

    #[test]
    fn test_mh_prior_uniform() {
        let loglike = |_x: &f64| 0.0;
        fn prior_draw<R: Rng>(r: &mut R) -> f64 {
            r.gen()
        }

        let mut rng = rand::thread_rng();
        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let xs = mh_chain(
                0.5,
                |&x, mut rng| mh_prior(x, loglike, prior_draw, 1, &mut rng).x,
                500,
                &mut rng,
            );
            let (_, p) = ks_test(&xs, |x| x);

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_prior_gaussian() {
        let gauss = Gaussian::standard();
        let loglike = |_x: &f64| 0.0;
        fn prior_draw<R: Rng>(r: &mut R) -> f64 {
            let norm = Normal::new(0.0, 1.0).unwrap();
            r.sample(norm)
        }

        let mut rng = rand::thread_rng();
        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let xs = mh_chain(
                0.5,
                |&x, mut rng| mh_prior(x, loglike, prior_draw, 1, &mut rng).x,
                500,
                &mut rng,
            );
            let (_, p) = ks_test(&xs, |x| gauss.cdf(&x));

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_importance_beta() {
        let xs: Vec<u8> = vec![0, 0, 1, 1, 1, 1];
        let prior = Beta::new(2.0, 2.0).unwrap();

        // Proportional to the posterior
        let ln_fn = |theta: &f64| {
            let likelihood = Bernoulli::new(*theta).unwrap();
            let f: f64 = xs.iter().map(|x| likelihood.ln_f(x)).sum();
            f + prior.ln_f(theta)
        };

        fn q_draw<R: Rng>(mut rng: &mut R) -> f64 {
            let q = Beta::new(2.0, 1.0).unwrap();
            q.draw(&mut rng)
        }

        fn q_ln_f(theta: &f64) -> f64 {
            let q = Beta::new(2.0, 1.0).unwrap();
            q.ln_f(theta)
        }

        let mut rng = rand::thread_rng();
        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let xs = mh_chain(
                0.5,
                |&x, mut rng| {
                    mh_importance(x, ln_fn, q_draw, q_ln_f, 2, &mut rng).x
                },
                250,
                &mut rng,
            );

            // Conjugate posterior: Beta(2 + 4 successes, 2 + 2 failures)
            let true_posterior = Beta::new(2.0 + 4.0, 2.0 + 2.0).unwrap();

            let (_, p) = ks_test(&xs, |x| true_posterior.cdf(&x));

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_symrw_uniform() {
        let score_fn = |_x: &f64| 0.0;
        fn walk_fn<R: Rng>(x: &f64, r: &mut R) -> f64 {
            let norm = Normal::new(*x, 0.2).unwrap();
            r.sample(norm).rem_euclid(1.0)
        }

        let mut rng = rand::thread_rng();
        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let xs = mh_chain(
                0.5,
                |&x, mut rng| mh_symrw(x, score_fn, walk_fn, 1, &mut rng).x,
                500,
                &mut rng,
            );
            let (_, p) = ks_test(&xs, |x| x);

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_symrw_gaussian() {
        let gauss = Gaussian::new(1.0, 1.5).unwrap();

        let score_fn = |x: &f64| gauss.ln_f(x);
        fn walk_fn<R: Rng>(x: &f64, r: &mut R) -> f64 {
            let norm = Normal::new(*x, 0.5).unwrap();
            r.sample(norm)
        }

        let mut rng = rand::thread_rng();
        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let xs = mh_chain(
                1.0,
                |&x, mut rng| mh_symrw(x, score_fn, walk_fn, 10, &mut rng).x,
                250,
                &mut rng,
            );
            let (_, p) = ks_test(&xs, |x| gauss.cdf(&x));

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_slice_uniform() {
        let loglike = |x: f64| {
            if 0.0 < x && x < 1.0 {
                0.0
            } else {
                std::f64::NEG_INFINITY
            }
        };

        let mut rng = rand::thread_rng();
        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let xs = mh_chain(
                0.5,
                |&x, mut rng| {
                    mh_slice(x, 0.2, 1, loglike, (0.0, 1.0), &mut rng).x
                },
                500,
                &mut rng,
            );
            let (_, p) = ks_test(&xs, |x| x);

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_slice_gaussian() {
        use std::f64::{INFINITY, NEG_INFINITY};

        let gauss = Gaussian::new(1.0, 1.5).unwrap();

        let score_fn = |x: f64| gauss.ln_f(&x);

        let mut rng = rand::thread_rng();
        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let xs = mh_chain(
                1.0,
                |&x, mut rng| {
                    mh_slice(
                        x,
                        1.0,
                        1,
                        score_fn,
                        (NEG_INFINITY, INFINITY),
                        &mut rng,
                    )
                    .x
                },
                250,
                &mut rng,
            );
            let (_, p) = ks_test(&xs, |x| gauss.cdf(&x));

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_adaptive_gaussian() {
        use std::f64::{INFINITY, NEG_INFINITY};

        let gauss = Gaussian::new(1.0, 1.5).unwrap();

        let score_fn = |x: f64| gauss.ln_f(&x);

        let mut rng = rand::thread_rng();
        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let xs = mh_chain(
                1.0,
                |&x, mut rng| {
                    mh_symrw_adaptive(
                        x,
                        0.1,
                        0.1,
                        100,
                        score_fn,
                        (NEG_INFINITY, INFINITY),
                        &mut rng,
                    )
                    .x
                },
                250,
                &mut rng,
            );
            let (_, p) = ks_test(&xs, |x| gauss.cdf(&x));

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_adaptive_normal_gamma() {
        use std::f64::{INFINITY, NEG_INFINITY};

        let mut rng = rand::thread_rng();
        let sigma: f64 = 1.5;
        let m0: f64 = 0.0;
        let s0: f64 = 0.5;
        let gauss = Gaussian::new(1.0, sigma).unwrap();
        let prior = Gaussian::new(m0, s0).unwrap();

        let xs: Vec<f64> = gauss.sample(20, &mut rng);
        let sum_x = xs.iter().sum::<f64>();

        let score_fn = |mu: f64| {
            let g = Gaussian::new_unchecked(mu, sigma);
            let fx: f64 = xs.iter().map(|x| g.ln_f(x)).sum();
            fx + prior.ln_f(&mu)
        };

        let posterior = {
            let nf = xs.len() as f64;
            let s2 = sigma * sigma;
            let s02 = s0 * s0;
            let sn = ((nf / s2) + s02.recip()).recip();
            let mn = sn * (m0 / s02 + sum_x / s2);
            Gaussian::new(mn, sn.sqrt()).unwrap()
        };

        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let ys = mh_chain(
                1.0,
                |&x, mut rng| {
                    mh_symrw_adaptive(
                        x,
                        0.1,
                        0.1,
                        100,
                        score_fn,
                        (NEG_INFINITY, INFINITY),
                        &mut rng,
                    )
                    .x
                },
                250,
                &mut rng,
            );
            let (_, p) = ks_test(&ys, |y| posterior.cdf(&y));

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_mv_normal_gamma_known_var() {
        let mut rng = rand::thread_rng();
        let sigma: f64 = 1.5;
        let m0: f64 = 0.0;
        let s0: f64 = 0.5;
        let gauss = Gaussian::new(1.0, sigma).unwrap();
        let prior = Gaussian::new(m0, s0).unwrap();

        let xs: Vec<f64> = gauss.sample(20, &mut rng);
        let sum_x = xs.iter().sum::<f64>();

        let score_fn = |mu: &f64| {
            let g = Gaussian::new_unchecked(*mu, sigma);
            let fx: f64 = xs.iter().map(|x| g.ln_f(x)).sum();
            fx + prior.ln_f(mu)
        };

        fn walk_fn<R: Rng>(x: &f64, mut r: &mut R) -> f64 {
            Gaussian::new_unchecked(*x, 0.2).draw(&mut r)
        }

        let posterior = {
            let nf = xs.len() as f64;
            let s2 = sigma * sigma;
            let s02 = s0 * s0;

            let sn = ((nf / s2) + s02.recip()).recip();

            let mn = sn * (m0 / s02 + sum_x / s2);
            println!("Posterior mean: {}", mn);

            Gaussian::new(mn, sn.sqrt()).unwrap()
        };

        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let ys = mh_chain(
                1.0,
                |x, mut rng| mh_symrw(*x, score_fn, walk_fn, 50, &mut rng).x,
                250,
                &mut rng,
            );
            let (_, p) = ks_test(&ys, |y| posterior.cdf(&y));
            println!("p: {}, m: {}", p, lace_utils::mean(&ys));

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_adaptive_mv_normal_gamma_known_var() {
        use crate::mat::Matrix1x1;
        use std::f64::{INFINITY, NEG_INFINITY};

        let mut rng = rand::thread_rng();
        let sigma: f64 = 1.5;
        let m0: f64 = 0.0;
        let s0: f64 = 0.5;
        let gauss = Gaussian::new(1.0, sigma).unwrap();
        let prior = Gaussian::new(m0, s0).unwrap();

        let xs: Vec<f64> = gauss.sample(20, &mut rng);
        let sum_x = xs.iter().sum::<f64>();

        let score_fn = |mu: &[f64]| {
            let g = Gaussian::new_unchecked(mu[0], sigma);
            let fx: f64 = xs.iter().map(|x| g.ln_f(x)).sum();
            fx + prior.ln_f(&mu[0])
        };

        let posterior = {
            let nf = xs.len() as f64;
            let s2 = sigma * sigma;
            let s02 = s0 * s0;

            let sn = ((nf / s2) + s02.recip()).recip();

            let mn = sn * (m0 / s02 + sum_x / s2);
            println!("Posterior mean: {}", mn);

            Gaussian::new(mn, sn.sqrt()).unwrap()
        };

        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let ys = mh_chain(
                1.0,
                |x, mut rng| {
                    mh_symrw_adaptive_mv(
                        Matrix1x1([*x]),
                        Matrix1x1([0.1]),
                        Matrix1x1([0.1]),
                        10,
                        score_fn,
                        &[(NEG_INFINITY, INFINITY)],
                        &mut rng,
                    )
                    .x[0]
                },
                250,
                &mut rng,
            );
            let (_, p) = ks_test(&ys, |y| posterior.cdf(&y));
            println!("p: {}, m: {}", p, lace_utils::mean(&ys));

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }

    #[test]
    fn test_mh_symrw_adaptive_mv_normal_gamma_unknown() {
        use crate::mat::{Matrix2x2, Vector2};
        use crate::rv::dist::InvGamma;
        use crate::test::gauss_perm_test;
        use std::f64::{INFINITY, NEG_INFINITY};

        let n = 20;
        let n_samples = 250;

        let mut rng = rand::thread_rng();

        // Prior parameters
        let m0: f64 = 0.0;
        let v0: f64 = 0.5;
        let a0: f64 = 1.5;
        let b0: f64 = 1.0;

        // True distribution
        let gauss = Gaussian::new(1.0, 1.5).unwrap();

        // Prior on the variance
        let prior_var = InvGamma::new(a0, b0).unwrap();

        // Generate data and get sufficient statistics
        let xs: Vec<f64> = gauss.sample(n, &mut rng);
        let sum_x = xs.iter().sum::<f64>();
        let sum_x_sq = xs.iter().map(|&x| x * x).sum::<f64>();

        println!("Mean(x): {}", sum_x / n as f64);

        // The proportional posterior for MCMC
        let score_fn = |mu_var: &[f64]| {
            let mu = mu_var[0];
            let var = mu_var[1];
            let sigma = var.sqrt();
            let g = Gaussian::new_unchecked(mu, sigma);
            let fx: f64 = xs.iter().map(|x| g.ln_f(x)).sum();
            let prior_mu = Gaussian::new(m0, v0.sqrt() * sigma).unwrap();
            fx + prior_mu.ln_f(&mu) + prior_var.ln_f(&var)
        };

        // Compute the normal inverse-gamma posterior according to Kevin
        // Murphy's conjugate-priors whitepaper
        let posterior_samples: Vec<(f64, f64)> = {
            let nf = n as f64;

            let v0_inv = v0.recip();
            let vn_inv = v0_inv + nf;
            let mn_over_vn = v0_inv.mul_add(m0, sum_x);
            let mn = mn_over_vn * vn_inv.recip();
            let an = a0 + nf / 2.0;
            let bn = 0.5_f64.mul_add(
                (mn * mn).mul_add(-vn_inv, (m0 * m0).mul_add(v0_inv, sum_x_sq)),
                b0,
            );
            let vn_sqrt = vn_inv.recip().sqrt();

            let post_var = InvGamma::new(an, bn).unwrap();
            (0..n_samples)
                .map(|_| {
                    let var: f64 = post_var.draw(&mut rng);
                    let mu: f64 = Gaussian::new(mn, vn_sqrt * var.sqrt())
                        .unwrap()
                        .draw(&mut rng);
                    (mu, var)
                })
                .collect()
        };

        let (mean_mu, var_mu) = {
            use lace_utils::{mean, var};
            let mus: Vec<f64> =
                posterior_samples.iter().map(|xy| xy.0).collect();
            (mean(&mus), var(&mus))
        };
        println!("Posterior Mean/Var: {}/{}", mean_mu, var_mu);

        let n_passes = (0..N_FLAKY_TEST).fold(0, |acc, _| {
            let mcmc_samples: Vec<(f64, f64)> = mh_chain(
                (1.0, 1.0),
                |x, mut rng| {
                    let x = mh_symrw_adaptive_mv(
                        Vector2([x.0, x.1]),
                        Vector2([0.1, 1.0]),
                        Matrix2x2::from_diag([1.0, 1.0]),
                        100,
                        score_fn,
                        &[(NEG_INFINITY, INFINITY), (0.0, INFINITY)],
                        &mut rng,
                    )
                    .x;
                    (x[0], x[1])
                },
                n_samples,
                &mut rng,
            );

            let (mean_mu_mh, var_mu_mh) = {
                use lace_utils::{mean, var};
                let mus: Vec<f64> =
                    mcmc_samples.iter().map(|xy| xy.0).collect();
                (mean(&mus), var(&mus))
            };

            let p = gauss_perm_test(
                posterior_samples.clone(),
                mcmc_samples,
                500,
                &mut rng,
            );

            println!("p: {}", p);
            println!("MCMC Mean/Var: {}/{}", mean_mu_mh, var_mu_mh);

            if p > KS_PVAL {
                acc + 1
            } else {
                acc
            }
        });

        assert!(n_passes > 0);
    }
}