1use std::collections::BTreeMap;
2use std::fmt;
3use std::fs::File;
4use std::io::prelude::Write;
5use std::path::Path;
6
7use indicatif::ProgressBar;
8use rand::Rng;
9use serde::Serialize;
10
11use crate::geweke::traits::*;
12use crate::stats::EmpiricalCdf;
13use crate::utils::transpose_mapvec;
14
15pub struct GewekeTester<G>
18where
19 G: GewekeModel + GewekeResampleData + GewekeSummarize,
20 G::Summary: Into<BTreeMap<String, f64>> + Clone,
21{
22 settings: G::Settings,
23 pub verbose: bool,
24 pub f_chain_out: Vec<G::Summary>,
25 pub p_chain_out: Vec<G::Summary>,
26}
27
28#[derive(Serialize, Debug, Clone)]
29pub struct GewekeResult {
30 pub forward: BTreeMap<String, Vec<f64>>,
31 pub posterior: BTreeMap<String, Vec<f64>>,
32}
33
34impl fmt::Display for GewekeResult {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 writeln!(f, "Geweke Errors")?;
37 write!(f, "━━━━━━━━━━━━━")?;
38 let errs: BTreeMap<String, f64> = self.aucs().collect();
39 let width = errs.keys().fold(0_usize, |len, k| len.max(k.len()));
40 write!(f, "\n{:width$} Value", "Stat", width = width)?;
41 write!(f, "\n{:width$} ━━━━━", "━━━━", width = width)?;
42 errs.iter()
43 .try_for_each(|(k, auc)| write!(f, "\n{k:width$} {auc}"))
44 }
45}
46
47impl GewekeResult {
48 pub fn aucs<'a>(&'a self) -> Box<dyn Iterator<Item = (String, f64)> + 'a> {
49 let iter = self.forward.keys().map(move |k| {
50 let cdf_f = EmpiricalCdf::new(self.forward.get(k).unwrap());
51 let cdf_p = EmpiricalCdf::new(self.posterior.get(k).unwrap());
52 (String::from(k), cdf_f.auc(&cdf_p))
53 });
54
55 Box::new(iter)
56 }
57
58 pub fn ks(&self) -> BTreeMap<String, f64> {
59 use rv::misc::ks_two_sample;
60 use rv::misc::KsAlternative;
61 use rv::misc::KsMode;
62
63 self.forward
64 .keys()
65 .map(|k| {
66 let (_, p) = ks_two_sample(
67 self.forward.get(k).unwrap(),
68 self.posterior.get(k).unwrap(),
69 KsMode::Auto,
70 KsAlternative::TwoSided,
71 )
72 .unwrap();
73 (k.clone(), p)
75 })
76 .collect()
77 }
78
79 pub fn report(&self) {
80 println!("{self}")
81 }
82}
83
84impl<G> GewekeTester<G>
85where
86 G: GewekeModel + GewekeResampleData + GewekeSummarize,
87 G::Summary: Into<BTreeMap<String, f64>> + Clone,
88{
89 pub fn new(settings: G::Settings) -> Self {
90 GewekeTester {
91 settings,
92 f_chain_out: vec![],
93 p_chain_out: vec![],
94 verbose: false,
95 }
96 }
97
98 pub fn set_verbose(&mut self, verbose: bool) {
99 self.verbose = verbose;
100 }
101
102 pub fn result(&self) -> GewekeResult {
103 let forward = transpose_mapvec(
105 &self
106 .f_chain_out
107 .iter()
108 .map(|val| val.to_owned().into())
109 .collect::<Vec<_>>(),
110 );
111
112 let posterior = transpose_mapvec(
113 &self
114 .p_chain_out
115 .iter()
116 .map(|val| val.to_owned().into())
117 .collect::<Vec<_>>(),
118 );
119
120 GewekeResult { forward, posterior }
121 }
122
123 pub fn save(&self, path: &Path) {
125 let res = self.result();
126 let j = serde_yaml::to_string(&res).unwrap();
127 let mut file = File::create(path).unwrap();
128 let _nbytes = file.write(j.as_bytes()).unwrap();
129 }
130
131 pub fn run<R: Rng>(
132 &mut self,
133 n_iter: usize,
134 lag: Option<usize>,
135 mut rng: &mut R,
136 ) {
137 self.run_forward_chain(n_iter, &mut rng);
138 self.run_posterior_chain(n_iter, lag.unwrap_or(1), &mut rng);
139 if self.verbose {
140 self.result().report()
141 }
142 }
143
144 fn run_forward_chain<R: Rng>(&mut self, n_iter: usize, mut rng: &mut R) {
145 let pb = ProgressBar::new(n_iter as u64);
146 self.f_chain_out.reserve(n_iter);
147
148 for _ in 0..n_iter {
149 let mut model = G::geweke_from_prior(&self.settings, &mut rng);
150 model.geweke_resample_data(Some(&self.settings), &mut rng);
151 self.f_chain_out
152 .push(model.geweke_summarize(&self.settings));
153 pb.inc(1);
154 }
155 pb.finish_and_clear();
156 }
157
158 fn run_posterior_chain<R: Rng>(
159 &mut self,
160 n_iter: usize,
161 lag: usize,
162 mut rng: &mut R,
163 ) {
164 let pb = ProgressBar::new(n_iter as u64);
165 self.p_chain_out.reserve(n_iter);
166
167 let mut model = G::geweke_from_prior(&self.settings, &mut rng);
168 model.geweke_resample_data(Some(&self.settings), &mut rng);
169 for _ in 0..n_iter {
170 for _ in 0..lag {
171 model.geweke_step(&self.settings, &mut rng);
172 model.geweke_resample_data(Some(&self.settings), &mut rng);
173 }
174 self.p_chain_out
175 .push(model.geweke_summarize(&self.settings));
176 pb.inc(1);
177 }
178 pb.finish_and_clear();
179 }
180}