Skip to main content

lace/geweke/
tester.rs

1use std::collections::BTreeMap;
2use std::fmt;
3use std::fs::File;
4use std::io::prelude::Write;
5use std::path::Path;
6
7use indicatif::ProgressBar;
8use rand::Rng;
9use serde::Serialize;
10
11use crate::geweke::traits::*;
12use crate::stats::EmpiricalCdf;
13use crate::utils::transpose_mapvec;
14
15/// Verifies the correctness of MCMC algorithms by way of the "joint
16/// distribution test
17pub struct GewekeTester<G>
18where
19    G: GewekeModel + GewekeResampleData + GewekeSummarize,
20    G::Summary: Into<BTreeMap<String, f64>> + Clone,
21{
22    settings: G::Settings,
23    pub verbose: bool,
24    pub f_chain_out: Vec<G::Summary>,
25    pub p_chain_out: Vec<G::Summary>,
26}
27
28#[derive(Serialize, Debug, Clone)]
29pub struct GewekeResult {
30    pub forward: BTreeMap<String, Vec<f64>>,
31    pub posterior: BTreeMap<String, Vec<f64>>,
32}
33
34impl fmt::Display for GewekeResult {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        writeln!(f, "Geweke Errors")?;
37        write!(f, "━━━━━━━━━━━━━")?;
38        let errs: BTreeMap<String, f64> = self.aucs().collect();
39        let width = errs.keys().fold(0_usize, |len, k| len.max(k.len()));
40        write!(f, "\n{:width$}  Value", "Stat", width = width)?;
41        write!(f, "\n{:width$}  ━━━━━", "━━━━", width = width)?;
42        errs.iter()
43            .try_for_each(|(k, auc)| write!(f, "\n{k:width$}  {auc}"))
44    }
45}
46
47impl GewekeResult {
48    pub fn aucs<'a>(&'a self) -> Box<dyn Iterator<Item = (String, f64)> + 'a> {
49        let iter = self.forward.keys().map(move |k| {
50            let cdf_f = EmpiricalCdf::new(self.forward.get(k).unwrap());
51            let cdf_p = EmpiricalCdf::new(self.posterior.get(k).unwrap());
52            (String::from(k), cdf_f.auc(&cdf_p))
53        });
54
55        Box::new(iter)
56    }
57
58    pub fn ks(&self) -> BTreeMap<String, f64> {
59        use rv::misc::ks_two_sample;
60        use rv::misc::KsAlternative;
61        use rv::misc::KsMode;
62
63        self.forward
64            .keys()
65            .map(|k| {
66                let (_, p) = ks_two_sample(
67                    self.forward.get(k).unwrap(),
68                    self.posterior.get(k).unwrap(),
69                    KsMode::Auto,
70                    KsAlternative::TwoSided,
71                )
72                .unwrap();
73                // TODO: return p value instead
74                (k.clone(), p)
75            })
76            .collect()
77    }
78
79    pub fn report(&self) {
80        println!("{self}")
81    }
82}
83
84impl<G> GewekeTester<G>
85where
86    G: GewekeModel + GewekeResampleData + GewekeSummarize,
87    G::Summary: Into<BTreeMap<String, f64>> + Clone,
88{
89    pub fn new(settings: G::Settings) -> Self {
90        GewekeTester {
91            settings,
92            f_chain_out: vec![],
93            p_chain_out: vec![],
94            verbose: false,
95        }
96    }
97
98    pub fn set_verbose(&mut self, verbose: bool) {
99        self.verbose = verbose;
100    }
101
102    pub fn result(&self) -> GewekeResult {
103        // TODO: would be nice if we didn't have to clone the summaries here
104        let forward = transpose_mapvec(
105            &self
106                .f_chain_out
107                .iter()
108                .map(|val| val.to_owned().into())
109                .collect::<Vec<_>>(),
110        );
111
112        let posterior = transpose_mapvec(
113            &self
114                .p_chain_out
115                .iter()
116                .map(|val| val.to_owned().into())
117                .collect::<Vec<_>>(),
118        );
119
120        GewekeResult { forward, posterior }
121    }
122
123    /// Output results as json
124    pub fn save(&self, path: &Path) {
125        let res = self.result();
126        let j = serde_yaml::to_string(&res).unwrap();
127        let mut file = File::create(path).unwrap();
128        let _nbytes = file.write(j.as_bytes()).unwrap();
129    }
130
131    pub fn run<R: Rng>(
132        &mut self,
133        n_iter: usize,
134        lag: Option<usize>,
135        mut rng: &mut R,
136    ) {
137        self.run_forward_chain(n_iter, &mut rng);
138        self.run_posterior_chain(n_iter, lag.unwrap_or(1), &mut rng);
139        if self.verbose {
140            self.result().report()
141        }
142    }
143
144    fn run_forward_chain<R: Rng>(&mut self, n_iter: usize, mut rng: &mut R) {
145        let pb = ProgressBar::new(n_iter as u64);
146        self.f_chain_out.reserve(n_iter);
147
148        for _ in 0..n_iter {
149            let mut model = G::geweke_from_prior(&self.settings, &mut rng);
150            model.geweke_resample_data(Some(&self.settings), &mut rng);
151            self.f_chain_out
152                .push(model.geweke_summarize(&self.settings));
153            pb.inc(1);
154        }
155        pb.finish_and_clear();
156    }
157
158    fn run_posterior_chain<R: Rng>(
159        &mut self,
160        n_iter: usize,
161        lag: usize,
162        mut rng: &mut R,
163    ) {
164        let pb = ProgressBar::new(n_iter as u64);
165        self.p_chain_out.reserve(n_iter);
166
167        let mut model = G::geweke_from_prior(&self.settings, &mut rng);
168        model.geweke_resample_data(Some(&self.settings), &mut rng);
169        for _ in 0..n_iter {
170            for _ in 0..lag {
171                model.geweke_step(&self.settings, &mut rng);
172                model.geweke_resample_data(Some(&self.settings), &mut rng);
173            }
174            self.p_chain_out
175                .push(model.geweke_summarize(&self.settings));
176            pb.inc(1);
177        }
178        pb.finish_and_clear();
179    }
180}