1use std::borrow::Borrow;
2use std::collections::BTreeMap;
3use std::convert::{TryFrom, TryInto};
4use std::f64::{INFINITY, NEG_INFINITY};
5use std::fs::File;
6use std::hash::Hash;
7use std::io::Read;
8use std::path::Path;
9
10use crate::cc::feature::{ColModel, FType, Feature};
11use crate::cc::state::State;
12use crate::codebook::Codebook;
13use crate::stats::rv::dist::{
14 Bernoulli, Categorical, Gaussian, Mixture, Poisson,
15};
16use crate::stats::rv::traits::{Entropy, Mode, QuadBounds, Rv, Variance};
17use crate::stats::MixtureType;
18use lace_consts::rv::misc::logsumexp;
19use lace_data::{Category, Datum};
20use lace_utils::{argmax, transpose};
21
22use super::error::IndexError;
23use crate::interface::Given;
24use crate::optimize::{fmin_bounded, fmin_brute};
25
26pub(crate) fn u8_to_category(
27 x: u8,
28 col_ix: usize,
29 codebook: &Codebook,
30) -> Option<Category> {
31 codebook.col_metadata[col_ix]
32 .coltype
33 .value_map()
34 .map(|vm| vm.category(x as usize))
35}
36
37pub(crate) fn pre_process_datum(
38 x: Datum,
39 col_ix: usize,
40 codebook: &Codebook,
41) -> Result<Datum, IndexError> {
42 let n_cols = codebook.col_metadata.len();
43 if col_ix >= n_cols {
44 return Err(IndexError::ColumnIndexOutOfBounds { n_cols, col_ix });
45 }
46
47 if let Datum::Categorical(cat) = x {
48 let value_map = codebook.col_metadata[col_ix]
49 .coltype
50 .value_map()
51 .ok_or_else(|| IndexError::InvalidDatumForColumn {
52 col_ix,
53 ftype_req: FType::Categorical,
54 ftype: FType::from_coltype(
55 &codebook.col_metadata[col_ix].coltype,
56 ),
57 })?;
58
59 value_map
60 .ix(&cat)
61 .map(|u| Datum::Categorical(Category::U8(u as u8)))
62 .ok_or(IndexError::CategoryIndexNotFound { col_ix, cat })
63 } else {
64 Ok(x)
65 }
66}
67
68pub(crate) fn pre_process_row(
69 row: &[Datum],
70 col_ixs: &[usize],
71 codebook: &Codebook,
72) -> Vec<Datum> {
73 row.iter()
74 .zip(col_ixs.iter())
75 .map(|(x, col_ix)| {
76 pre_process_datum(x.clone(), *col_ix, codebook).unwrap()
77 })
78 .collect()
79}
80
81pub(crate) fn post_process_datum(
82 x: Datum,
83 col_ix: usize,
84 codebook: &Codebook,
85) -> Datum {
86 if let Datum::Categorical(Category::U8(x_u8)) = x {
87 codebook
88 .value_map(col_ix)
89 .map(|map| map.category(x_u8 as usize))
90 .map(Datum::Categorical)
91 .unwrap_or(x)
92 } else {
93 x
94 }
95}
96
97pub(crate) fn post_process_row(
98 mut row: Vec<Datum>,
99 col_ixs: &[usize],
100 codebook: &Codebook,
101) -> Vec<Datum> {
102 row.drain(..)
103 .zip(col_ixs.iter())
104 .map(|(x, col_ix)| post_process_datum(x, *col_ix, codebook))
105 .collect()
106}
107
108pub(crate) fn select_states<'s>(
109 states: &'s [State],
110 states_ixs_opt: Option<&[usize]>,
111) -> Vec<&'s State> {
112 match states_ixs_opt {
113 Some(state_ixs) => state_ixs.iter().map(|&ix| &states[ix]).collect(),
114 None => states.iter().collect(),
115 }
116}
117
118pub struct Simulator<'s, R: rand::Rng> {
120 rng: &'s mut R,
121 states: &'s [&'s State],
123 weights: &'s [BTreeMap<usize, Vec<f64>>],
125 state_ixer: Categorical,
127 state_ixs: Vec<usize>,
129 col_ixs: &'s [usize],
131 component_ixers: BTreeMap<usize, Vec<Categorical>>,
132}
133
134impl<'s, R: rand::Rng> Simulator<'s, R> {
135 pub fn new(
136 states: &'s [&'s State],
137 weights: &'s [BTreeMap<usize, Vec<f64>>],
138 state_ixs: Option<Vec<usize>>,
139 col_ixs: &'s [usize],
140 rng: &'s mut R,
141 ) -> Self {
142 Simulator {
143 rng,
144 weights,
145 state_ixer: match state_ixs {
146 Some(ref ixs) => Categorical::uniform(ixs.len()),
147 None => Categorical::uniform(states.len()),
148 },
149 state_ixs: match state_ixs {
150 Some(ixs) => ixs,
151 None => (0..states.len()).collect(),
152 },
153 states,
154 col_ixs,
155 component_ixers: BTreeMap::new(),
156 }
157 }
158}
159
160impl<'s, R: rand::Rng> Iterator for Simulator<'s, R> {
161 type Item = Vec<Datum>;
162
163 fn next(&mut self) -> Option<Self::Item> {
164 let mut rng = &mut self.rng;
165
166 let draw_ix: usize = self.state_ixer.draw(&mut rng);
168 let state_ix: usize = self.state_ixs[draw_ix];
169 let state = &self.states[draw_ix];
170
171 let weights = &self.weights;
172
173 self.component_ixers.entry(state_ix).or_insert_with(|| {
176 weights[draw_ix]
179 .values()
180 .map(|view_weights| {
181 Categorical::from_ln_weights(view_weights.clone()).unwrap()
182 })
183 .collect()
184 });
185
186 let cpnt_ixs: BTreeMap<usize, usize> = self.weights[draw_ix]
187 .keys()
188 .zip(self.component_ixers[&state_ix].iter())
189 .map(|(&view_ix, cpnt_ixer)| (view_ix, cpnt_ixer.draw(&mut rng)))
190 .collect();
191
192 let xs: Vec<_> = self
193 .col_ixs
194 .iter()
195 .map(|col_ix| {
196 let view_ix = state.asgn().asgn[*col_ix];
197 let k = cpnt_ixs[&view_ix];
198 state.views[view_ix].ftrs[col_ix].draw(k, &mut rng)
199 })
200 .collect();
201
202 Some(xs)
203 }
204}
205
206pub struct Calculator<'s, Xs>
208where
209 Xs: Iterator,
210 Xs::Item: Borrow<Vec<Datum>>,
211{
212 states: &'s [&'s State],
214 codebook: Option<&'s Codebook>,
216 weights: &'s [BTreeMap<usize, Vec<f64>>],
218 col_ixs: &'s [usize],
220 values: &'s mut Xs,
221 state_logps: Vec<f64>,
224 scaled: bool,
226}
227
228impl<'s, Xs> Calculator<'s, Xs>
229where
230 Xs: Iterator,
231 Xs::Item: Borrow<Vec<Datum>>,
232{
233 pub fn new(
234 values: &'s mut Xs,
235 states: &'s [&'s State],
236 codebook: Option<&'s Codebook>,
237 weights: &'s [BTreeMap<usize, Vec<f64>>],
238 col_ixs: &'s [usize],
239 ) -> Self {
240 Self {
241 values,
242 weights,
243 states,
244 codebook,
245 col_ixs,
246 state_logps: vec![0.0; states.len()],
247 scaled: false,
248 }
249 }
250
251 pub fn new_scaled(
252 values: &'s mut Xs,
253 states: &'s [&'s State],
254 codebook: Option<&'s Codebook>,
255 weights: &'s [BTreeMap<usize, Vec<f64>>],
256 col_ixs: &'s [usize],
257 ) -> Self {
258 Self {
259 values,
260 weights,
261 states,
262 codebook,
263 col_ixs,
264 state_logps: vec![0.0; states.len()],
265 scaled: true,
266 }
267 }
268
269 fn calculate<X: Borrow<Vec<Datum>>>(&mut self, xs: X) -> Option<f64> {
270 let ln_n = (self.states.len() as f64).ln();
271 let col_ixs = self.col_ixs;
272 self.states
273 .iter()
274 .zip(self.weights.iter())
275 .enumerate()
276 .for_each(|(i, (state, weights))| {
277 let logp = single_val_logp(
278 state,
279 col_ixs,
280 xs.borrow(),
281 weights.clone(),
282 self.scaled,
283 );
284 self.state_logps[i] = logp;
285 });
286 let logp = logsumexp(&self.state_logps) - ln_n;
287 if self.scaled {
288 Some(logp / self.col_ixs.len() as f64)
290 } else {
291 Some(logp)
292 }
293 }
294}
295
296impl<'s, Xs> Iterator for Calculator<'s, Xs>
297where
298 Xs: Iterator,
299 Xs::Item: Borrow<Vec<Datum>>,
300{
301 type Item = f64;
302
303 fn next(&mut self) -> Option<f64> {
304 match self.values.next() {
305 Some(xs) => {
306 if let Some(codebook) = self.codebook {
307 let row =
308 pre_process_row(xs.borrow(), self.col_ixs, codebook);
309 self.calculate(row)
310 } else {
311 self.calculate(xs)
312 }
313 }
314 None => None,
315 }
316 }
317}
318
319pub fn load_states<P: AsRef<Path>>(filenames: Vec<P>) -> Vec<State> {
320 filenames
321 .iter()
322 .map(|path| {
323 let mut file = File::open(path).unwrap();
324 let mut yaml = String::new();
325 let res = file.read_to_string(&mut yaml);
326 match res {
327 Ok(_) => serde_yaml::from_str(&yaml).unwrap(),
328 Err(err) => panic!("Error: {:?}", err),
329 }
330 })
331 .collect()
332}
333
334pub fn gen_sobol_samples(
337 col_ixs: &[usize],
338 state: &State,
339 n: usize,
340) -> (Vec<Vec<Datum>>, f64) {
341 use lace_stats::seq::SobolSeq;
342 use lace_stats::QmcEntropy;
343
344 let features: Vec<_> =
345 col_ixs.iter().map(|&ix| state.feature(ix)).collect();
346 let us_needed: usize = features.iter().map(|ftr| ftr.us_needed()).sum();
347 let sobol = SobolSeq::new(us_needed);
348
349 let samples: Vec<Vec<Datum>> = sobol
350 .take(n)
351 .map(|mut us| {
352 let mut drain = us.drain(..);
353 features
354 .iter()
355 .map(|ftr| ftr.us_to_datum(&mut drain))
356 .collect()
357 })
358 .collect();
359
360 let q_recip: f64 = features
361 .iter()
362 .fold(1_f64, |prod, ftr| prod * ftr.q_recip());
363
364 (samples, q_recip)
365}
366
367#[inline]
370pub fn given_weights(
371 states: &[&State],
372 col_ixs: &[usize],
373 given: &Given<usize>,
374) -> Vec<BTreeMap<usize, Vec<f64>>> {
375 states
376 .iter()
377 .map(|state| single_state_weights(state, col_ixs, given))
378 .collect()
379}
380
381#[inline]
382pub fn given_exp_weights(
383 states: &[&State],
384 col_ixs: &[usize],
385 given: &Given<usize>,
386) -> Vec<BTreeMap<usize, Vec<f64>>> {
387 states
388 .iter()
389 .map(|state| single_state_exp_weights(state, col_ixs, given))
390 .collect()
391}
392
393#[inline]
394pub fn state_weights(
395 states: &[&State],
396 col_ixs: &[usize],
397 given: &Given<usize>,
398) -> Vec<BTreeMap<usize, Vec<f64>>> {
399 states
400 .iter()
401 .map(|state| single_state_weights(state, col_ixs, given))
402 .collect()
403}
404
405#[inline]
406pub fn state_exp_weights(
407 states: &[State],
408 col_ixs: &[usize],
409 given: &Given<usize>,
410) -> Vec<BTreeMap<usize, Vec<f64>>> {
411 states
412 .iter()
413 .map(|state| single_state_exp_weights(state, col_ixs, given))
414 .collect()
415}
416
417#[inline]
418pub fn single_state_weights(
419 state: &State,
420 col_ixs: &[usize],
421 given: &Given<usize>,
422) -> BTreeMap<usize, Vec<f64>> {
423 let mut view_weights: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
424 col_ixs
425 .iter()
426 .map(|&ix| state.asgn().asgn[ix])
427 .for_each(|view_ix| {
428 view_weights
429 .entry(view_ix)
430 .or_insert_with(|| single_view_weights(state, view_ix, given));
431 });
432
433 view_weights
434}
435
436#[inline]
437pub fn single_state_exp_weights(
438 state: &State,
439 col_ixs: &[usize],
440 given: &Given<usize>,
441) -> BTreeMap<usize, Vec<f64>> {
442 let mut view_weights: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
443 col_ixs
444 .iter()
445 .map(|&ix| state.asgn().asgn[ix])
446 .for_each(|view_ix| {
447 view_weights.entry(view_ix).or_insert_with(|| {
448 single_view_exp_weights(state, view_ix, given)
449 });
450 });
451
452 view_weights
453}
454
455#[inline]
456fn single_view_weights(
457 state: &State,
458 target_view_ix: usize,
459 given: &Given<usize>,
460) -> Vec<f64> {
461 let view = &state.views[target_view_ix];
462 let mut weights: Vec<_> = view.weights.iter().map(|w| w.ln()).collect();
463
464 match given {
465 Given::Conditions(ref conditions) => {
466 for &(col_ix, ref datum) in conditions {
467 let in_target_view =
468 state.asgn().asgn[col_ix] == target_view_ix;
469 if in_target_view {
470 view.ftrs[&col_ix].accum_weights(
471 datum,
472 &mut weights,
473 false,
474 );
475 }
476 }
477 let z = logsumexp(&weights);
478 weights.iter_mut().for_each(|w| *w -= z);
479 }
480 Given::Nothing => (),
481 }
482 weights
483}
484
485#[inline]
486fn single_view_exp_weights(
487 state: &State,
488 target_view_ix: usize,
489 given: &Given<usize>,
490) -> Vec<f64> {
491 let view = &state.views[target_view_ix];
492 let mut weights = view.weights.clone();
493
494 match given {
495 Given::Conditions(ref conditions) => {
496 conditions.iter().for_each(|(ix, datum)| {
497 let in_target_view = state.asgn().asgn[*ix] == target_view_ix;
498 if in_target_view {
499 view.ftrs[ix].accum_exp_weights(datum, &mut weights);
500 }
501 });
502 let z = weights.iter().sum::<f64>();
503 weights.iter_mut().for_each(|w| *w /= z);
504 }
505 Given::Nothing => (),
506 }
507 weights
508}
509
510pub fn state_logp(
538 state: &State,
539 col_ixs: &[usize],
540 vals: &[Vec<Datum>],
541 given: &Given<usize>,
542 view_weights_opt: Option<&BTreeMap<usize, Vec<f64>>>,
543 scaled: bool,
544) -> Vec<f64> {
545 match view_weights_opt {
546 Some(view_weights) => vals
547 .iter()
548 .map(|val| {
549 single_val_logp(
550 state,
551 col_ixs,
552 val,
553 view_weights.clone(),
554 scaled,
555 )
556 })
557 .collect(),
558 None => {
559 let mut view_weights = single_state_weights(state, col_ixs, given);
560
561 for weights in view_weights.values_mut() {
563 let logz = logsumexp(weights);
564 weights.iter_mut().for_each(|w| *w -= logz);
565 }
566 vals.iter()
567 .map(|val| {
568 single_val_logp(
569 state,
570 col_ixs,
571 val,
572 view_weights.clone(),
573 scaled,
574 )
575 })
576 .collect()
577 }
578 }
579}
580
581fn single_val_logp(
582 state: &State,
583 col_ixs: &[usize],
584 val: &[Datum],
585 mut view_weights: BTreeMap<usize, Vec<f64>>,
586 scaled: bool,
587) -> f64 {
588 col_ixs
590 .iter()
591 .zip(val)
592 .map(|(col_ix, datum)| (col_ix, state.asgn().asgn[*col_ix], datum))
593 .for_each(|(col_ix, view_ix, datum)| {
594 state.views[view_ix].ftrs[col_ix].accum_weights(
595 datum,
596 view_weights.get_mut(&view_ix).unwrap(),
597 scaled,
598 );
599 });
600
601 view_weights.values().map(|logps| logsumexp(logps)).sum()
602}
603pub fn state_likelihood(
604 state: &State,
605 col_ixs: &[usize],
606 vals: &[Vec<Datum>],
607 given: &Given<usize>,
608 view_exp_weights_opt: Option<&BTreeMap<usize, Vec<f64>>>,
609) -> Vec<f64> {
610 match view_exp_weights_opt {
611 Some(view_exp_weights) => vals
612 .iter()
613 .map(|val| {
614 single_val_likelihood(state, col_ixs, val, view_exp_weights)
615 })
616 .collect(),
617 None => {
618 let mut view_exp_weights =
619 single_state_exp_weights(state, col_ixs, given);
620
621 for weights in view_exp_weights.values_mut() {
623 let z = weights.iter().sum::<f64>();
624 weights.iter_mut().for_each(|w| *w /= z);
625 }
626
627 vals.iter()
628 .map(|val| {
629 single_val_likelihood(
630 state,
631 col_ixs,
632 val,
633 &view_exp_weights,
634 )
635 })
636 .collect()
637 }
638 }
639}
640
641fn single_val_likelihood(
642 state: &State,
643 col_ixs: &[usize],
644 val: &[Datum],
645 view_exp_weights: &BTreeMap<usize, Vec<f64>>,
646) -> f64 {
647 view_exp_weights
648 .iter()
649 .fold(1.0, |prod, (&view_ix, weights)| {
650 let view = &state.views[view_ix];
651 let view_data: Vec<(usize, Datum)> = col_ixs
653 .iter()
654 .zip(val.iter())
655 .filter(|(ix, _)| view.ftrs.contains_key(ix))
656 .map(|(ix, val)| (*ix, val.clone()))
657 .collect();
658
659 prod * weights
660 .iter()
661 .enumerate()
662 .map(|(k, &w)| {
663 view_data.iter().fold(w, |acc, (col_ix, val)| {
664 acc * view.ftrs[col_ix].cpnt_likelihood(val, k)
665 })
666 })
667 .sum::<f64>()
668 })
669}
670
671fn impute_bounds(states: &[&State], col_ix: usize) -> (f64, f64) {
674 states
675 .iter()
676 .map(|state| state.impute_bounds(col_ix).unwrap())
677 .fold((INFINITY, NEG_INFINITY), |(min, max), (lower, upper)| {
678 (min.min(lower), max.max(upper))
679 })
680}
681
682pub fn continuous_impute(
683 states: &[&State],
684 row_ix: usize,
685 col_ix: usize,
686) -> f64 {
687 let cpnts: Vec<Gaussian> = states
688 .iter()
689 .map(|state| {
690 state
691 .component(row_ix, col_ix)
692 .try_into()
693 .expect("Unexpected column type")
694 })
695 .collect();
696
697 if cpnts.len() == 1 {
698 cpnts[0].mu()
699 } else {
700 let f = |x: f64| {
701 let logfs: Vec<f64> =
702 cpnts.iter().map(|cpnt| cpnt.ln_f(&x)).collect();
703 -logsumexp(&logfs)
704 };
705
706 let bounds = impute_bounds(states, col_ix);
707 let n_grid = 100;
708 let step_size = (bounds.1 - bounds.0) / (n_grid as f64);
709 let x0 = fmin_brute(&f, bounds, n_grid);
710 fmin_bounded(f, (x0 - step_size, x0 + step_size), None, None)
711 }
712}
713
714pub fn categorical_impute(
715 states: &[&State],
716 row_ix: usize,
717 col_ix: usize,
718) -> u8 {
719 let cpnts: Vec<Categorical> = states
720 .iter()
721 .map(|state| {
722 state
723 .component(row_ix, col_ix)
724 .try_into()
725 .expect("Unexpected column type")
726 })
727 .collect();
728
729 let k = cpnts[0].k();
730 let fs: Vec<f64> = (0..k)
731 .map(|x| {
732 let logfs: Vec<f64> =
733 cpnts.iter().map(|cpnt| cpnt.ln_f(&x)).collect();
734 logsumexp(&logfs)
735 })
736 .collect();
737 argmax(&fs) as u8
738}
739
740pub fn count_impute(states: &[&State], row_ix: usize, col_ix: usize) -> u32 {
741 use lace_stats::rv::traits::Mean;
742 use lace_utils::MinMax;
743
744 let cpnts: Vec<Poisson> = states
745 .iter()
746 .map(|state| {
747 state
748 .component(row_ix, col_ix)
749 .try_into()
750 .expect("Unexpected column type")
751 })
752 .collect();
753
754 let (lower, upper) = {
755 let (lower, upper) = cpnts
756 .iter()
757 .map(|cpnt| {
758 let mean: f64 = cpnt.mean().expect("Poisson always has a mean");
759 mean
760 })
761 .minmax()
762 .unwrap();
763 ((lower.ceil() - 1.0) as u32, upper.floor() as u32)
764 };
765
766 let fx = |x: u32| cpnts.iter().map(|cpnt| cpnt.f(&x)).sum::<f64>();
770
771 (lower..=upper)
772 .skip(1)
773 .fold((lower, fx(lower)), |(argmax, max), xi| {
774 let fxi = fx(xi);
775 if fxi > max {
776 (xi, fxi)
777 } else {
778 (argmax, max)
779 }
780 })
781 .0
782}
783
784pub fn entropy_single(col_ix: usize, states: &[State]) -> f64 {
785 let mixtures = states
786 .iter()
787 .map(|state| state.feature_as_mixture(col_ix))
788 .collect();
789 let mixture = MixtureType::combine(mixtures);
790 mixture.entropy()
791}
792
793fn sort_mixture_by_mode<Fx>(mm: Mixture<Fx>) -> Mixture<Fx>
794where
795 Fx: Mode<f64>,
796{
797 let mut components: Vec<(f64, Fx)> = mm.into();
798 components.sort_by(|a, b| {
799 a.1.mode()
800 .partial_cmp(&b.1.mode())
801 .unwrap_or(std::cmp::Ordering::Less)
802 });
803 Mixture::<Fx>::try_from(components).unwrap()
804}
805
806fn continuous_mixture_quad_points<Fx>(mm: &Mixture<Fx>) -> Vec<f64>
807where
808 Fx: Mode<f64> + Variance<f64>,
809{
810 let mut state: (Option<f64>, Option<f64>) = (None, None);
811 let m = 2.0;
812 mm.components()
813 .iter()
814 .filter_map(|cpnt| {
815 let mode = cpnt.mode();
816 let std = cpnt.variance().map(|v| v.sqrt());
817 match (&state, (mode, std)) {
818 ((Some(m1), s1), (Some(m2), s2)) => {
819 if (m2 - *m1)
820 > (m * s1.unwrap_or(INFINITY))
821 .min(m * s2.unwrap_or(INFINITY))
822 {
823 state = (mode, std);
824 Some(m2)
825 } else {
826 None
827 }
828 }
829 ((None, _), (Some(m2), _)) => {
830 state = (mode, std);
831 Some(m2)
832 }
833 _ => None,
834 }
835 })
836 .collect()
837}
838
/// Splits the states' `$fx` mixtures for column `$col_a` by whether
/// `$col_a` and `$col_b` share a view.
///
/// Expands to `(dep_weight, dep_mixture, ind_mixture)`, where `dep_weight`
/// is the fraction of states in which the two columns share a view. The
/// other two values are `Mixture::combine` of the dependent and independent
/// states' mixtures respectively.
macro_rules! dep_ind_col_mixtures {
    ($states: ident, $col_a: ident, $col_b: ident, $fx: ident) => {{
        let mut dep_mixtures = Vec::new();
        let mut ind_mixtures = Vec::new();
        let mut n_dep_states = 0.0_f64;

        for state in $states.iter() {
            let mm = match state.feature_as_mixture($col_a) {
                MixtureType::$fx(mm) => mm,
                _ => panic!("Unexpected MixtureType"),
            };

            let asgn = &state.asgn().asgn;
            if asgn[$col_a] == asgn[$col_b] {
                n_dep_states += 1.0;
                dep_mixtures.push(mm);
            } else {
                ind_mixtures.push(mm);
            }
        }

        let dep_weight = n_dep_states / $states.len() as f64;
        (
            dep_weight,
            Mixture::combine(dep_mixtures),
            Mixture::combine(ind_mixtures),
        )
    }};
}
871
/// Joint entropy of a categorical column and a Gaussian (continuous) column.
///
/// Computes `-sum_k integral p(k, y) ln p(k, y) dy`. The sum runs over the
/// categories `k` of `col_cat`; the integral is approximated with cached
/// Gauss-Legendre quadrature over `[lower, upper]` (from `quad_bounds`),
/// split at mode points from `continuous_mixture_quad_points`.
///
/// States are split (via `dep_ind_col_mixtures!`) into those where the two
/// columns share a view ("dep") and those where they do not ("ind").
/// `p(k, y)` is a `dep_weight`-weighted mixture of the dependent joint and
/// the independent product.
///
/// # Panics
///
/// * Panics if `col_cat` is not (MNAR-wrapped) categorical in the first state.
/// * Panics if its `k` does not fit in a `u8`.
/// * Panics if the categorical and Gaussian dependent mixtures disagree on
///   their weights.
pub fn categorical_gaussian_entropy_dual(
    col_cat: usize,
    col_gauss: usize,
    states: &[State],
) -> f64 {
    use crate::cc::feature::MissingNotAtRandom;
    use lace_stats::rv::misc::{
        gauss_legendre_quadrature_cached, gauss_legendre_table,
    };
    use std::cell::RefCell;
    use std::collections::HashMap;

    let (dep_weight, gm_dep, gm_ind) =
        dep_ind_col_mixtures!(states, col_gauss, col_cat, Gaussian);
    let (_, cm_dep, cm_ind) =
        dep_ind_col_mixtures!(states, col_cat, col_gauss, Categorical);

    let cat_k = match states[0].feature(col_cat) {
        ColModel::Categorical(cm) => u8::try_from(cm.prior.k())
            .expect("Categorical k exceeded u8 max value"),
        ColModel::MissingNotAtRandom(MissingNotAtRandom { fx, .. }) => {
            if let ColModel::Categorical(cm) = &**fx {
                u8::try_from(cm.prior.k())
                    .expect("Categorical k exceeded u8 max value")
            } else {
                panic!("Expected MissingNotAtRandom Categorical")
            }
        }
        _ => panic!("Expected ColModel::Categorical"),
    };

    // Integration range and split points come from the pooled Gaussian
    // mixture. They are only used for placement of quadrature intervals,
    // not for the density itself.
    let (points, lower, upper) = {
        let gmm = Mixture::combine(vec![gm_ind.clone(), gm_dep.clone()]);
        let gmm = sort_mixture_by_mode(gmm);
        let points = continuous_mixture_quad_points(&gmm);
        let (lower, upper) = gmm.quad_bounds();
        (points, lower, upper)
    };

    // The dependent categorical and Gaussian mixtures are built from the same
    // states, so their component weights must line up one-to-one.
    assert_eq!(cm_dep.k(), gm_dep.k());
    cm_dep
        .weights()
        .iter()
        .zip(gm_dep.weights().iter())
        .for_each(|(wc, wg)| assert!((wc - wg).abs() < 1e-12));

    let has_dep_states = gm_dep.k() > 0;
    let has_ind_states = gm_ind.k() > 0;
    let has_ind_and_dep_states = has_ind_states && has_dep_states;

    let quad_level = 16;

    // Bit-pattern wrapper so `f64` quadrature nodes can key a HashMap
    // (`f64` itself is neither `Hash` nor `Eq`).
    #[derive(Hash, Clone, Copy, PartialEq, Eq)]
    struct F64(u64);

    impl F64 {
        fn new(x: f64) -> Self {
            Self(x.to_bits())
        }
    }

    // The same quadrature nodes are evaluated for every category `k`, so the
    // Gaussian log densities at each node are cached across iterations.
    let ind_cache: RefCell<HashMap<F64, f64>> = RefCell::new(HashMap::new());
    let dep_cache: RefCell<HashMap<F64, Vec<f64>>> =
        RefCell::new(HashMap::new());

    let gl_cache = gauss_legendre_table(quad_level);

    -(0..cat_k)
        .map(|k| {
            // ln p_ind(k). The 1.0 is a placeholder that is never read when
            // there are no independent states.
            let ind_cat_f = if has_ind_states {
                cm_ind.ln_f(&k)
            } else {
                1.0
            };

            // ln(w_j) + ln Cat_j(k) for each dependent component j.
            let dep_cat_fs: Vec<f64> = cm_dep
                .weights()
                .iter()
                .zip(cm_dep.components().iter())
                .map(|(w, cpnt)| w.ln() + cpnt.ln_f(&k))
                .collect();

            // Integrand: p(k, y) * ln p(k, y).
            let quad_fn = |y: f64| {
                let dep_cpnt = if has_dep_states {
                    let mut m = dep_cache.borrow_mut();
                    let ln_fys = m.entry(F64::new(y)).or_insert_with(|| {
                        gm_dep
                            .components()
                            .iter()
                            .map(|cpnt| cpnt.ln_f(&y))
                            .collect()
                    });
                    let cpnts: Vec<f64> = dep_cat_fs
                        .iter()
                        .zip(ln_fys)
                        .map(|(w, g)| w + *g)
                        .collect();

                    // ln sum_j w_j Cat_j(k) N_j(y)
                    let ln_f = logsumexp(&cpnts);

                    // Only dependent states: this is the whole joint.
                    if !has_ind_and_dep_states {
                        return ln_f * ln_f.exp();
                    } else {
                        ln_f
                    }
                } else {
                    // Placeholder; the single-kind case returns early in the
                    // branch below, so this never reaches the logsumexp.
                    1.0
                };

                let ind_cpnt = if has_ind_states {
                    let mut m = ind_cache.borrow_mut();
                    let ln_fy =
                        m.entry(F64::new(y)).or_insert_with(|| gm_ind.ln_f(&y));
                    // NOTE(review): this is the product of the pooled
                    // independent marginals, not a per-state product
                    // averaged over states; confirm the approximation is
                    // intended.
                    let ln_f = ind_cat_f + *ln_fy;

                    // Only independent states: this is the whole joint.
                    if !has_ind_and_dep_states {
                        return ln_f * ln_f.exp();
                    } else {
                        ln_f
                    }
                } else {
                    assert_eq!(dep_weight, 1.0);
                    1.0
                };

                // Both kinds present: mix by the fraction of dependent states.
                let ln_f = logsumexp(&[
                    dep_weight.ln() + dep_cpnt,
                    (1.0 - dep_weight).ln() + ind_cpnt,
                ]);

                ln_f * ln_f.exp()
            };

            // NOTE(review): underflows (panics) if `points` is empty, i.e.
            // no component reported a mode.
            let last_ix = points.len() - 1;

            // Tail below the first split point.
            let q_a = gauss_legendre_quadrature_cached(
                quad_fn,
                (lower, points[0]),
                &gl_cache.0,
                &gl_cache.1,
            );

            // Tail above the last split point.
            let q_b = gauss_legendre_quadrature_cached(
                quad_fn,
                (points[last_ix], upper),
                &gl_cache.0,
                &gl_cache.1,
            );

            // Intervals between consecutive split points.
            let q_m = if points.len() == 1 {
                0.0
            } else {
                let mut left = points[0];
                points
                    .iter()
                    .skip(1)
                    .map(|&x| {
                        let q = gauss_legendre_quadrature_cached(
                            quad_fn,
                            (left, x),
                            &gl_cache.0,
                            &gl_cache.1,
                        );

                        left = x;
                        q
                    })
                    .sum::<f64>()
            };

            q_a + q_m + q_b
        })
        .sum::<f64>()
}
1092
1093pub fn categorical_joint_entropy(col_ixs: &[usize], states: &[State]) -> f64 {
1095 let ranges = col_ixs
1096 .iter()
1097 .map(|&ix| {
1098 let cpnt: Categorical = states[0]
1099 .component(0, ix)
1100 .try_into()
1101 .expect("Unexpected column type");
1102 cpnt.k() as u8
1103 })
1104 .collect();
1105
1106 let vals: Vec<_> = lace_utils::CategoricalCartProd::new(ranges)
1107 .map(|mut xs| {
1108 let vals: Vec<_> = xs
1109 .drain(..)
1110 .map(|x| Datum::Categorical(Category::U8(x)))
1111 .collect();
1112 vals
1113 })
1114 .collect();
1115
1116 let logps: Vec<Vec<f64>> = states
1118 .iter()
1119 .map(|state| {
1120 state_logp(state, col_ixs, &vals, &Given::Nothing, None, false)
1121 })
1122 .collect();
1123
1124 let ln_n_states = (states.len() as f64).ln();
1125
1126 transpose(&logps)
1127 .iter()
1128 .map(|lps| logsumexp(lps) - ln_n_states)
1129 .fold(0.0, |acc, lp| lp.mul_add(-lp.exp(), acc))
1130}
1131
1132pub fn categorical_entropy_dual(
1134 col_a: usize,
1135 col_b: usize,
1136 states: &[State],
1137) -> f64 {
1138 use crate::cc::feature::MissingNotAtRandom;
1139 if col_a == col_b {
1143 return entropy_single(col_a, states);
1144 }
1145
1146 let k_a = match states[0].feature(col_a) {
1147 ColModel::Categorical(cm) => cm.prior.k(),
1148 ColModel::MissingNotAtRandom(MissingNotAtRandom { fx, .. }) => {
1149 if let ColModel::Categorical(cm) = &**fx {
1150 cm.prior.k()
1151 } else {
1152 panic!("Expected MissingNotAtRandom Categorical")
1153 }
1154 }
1155 _ => panic!("Expected ColModel::Categorical"),
1156 };
1157
1158 let k_b = match states[0].feature(col_b) {
1159 ColModel::Categorical(cm) => cm.prior.k(),
1160 ColModel::MissingNotAtRandom(MissingNotAtRandom { fx, .. }) => {
1161 if let ColModel::Categorical(cm) = &**fx {
1162 cm.prior.k()
1163 } else {
1164 panic!("Expected MissingNotAtRandom Categorical")
1165 }
1166 }
1167 _ => panic!("Expected ColModel::Categorical"),
1168 };
1169
1170 let mut vals: Vec<Vec<Datum>> = Vec::with_capacity(k_a * k_b);
1171 for i in 0..k_a {
1172 for j in 0..k_b {
1173 vals.push(vec![
1174 Datum::Categorical(Category::U8(i as u8)),
1175 Datum::Categorical(Category::U8(j as u8)),
1176 ]);
1177 }
1178 }
1179
1180 let view_weights =
1181 state_exp_weights(states, &[col_a, col_b], &Given::Nothing);
1182
1183 let ps = {
1184 let mut ps = vec![0_f64; vals.len()];
1185 states
1186 .iter()
1187 .zip(view_weights.iter())
1188 .for_each(|(state, weights)| {
1189 state_likelihood(
1190 state,
1191 &[col_a, col_b],
1192 &vals,
1193 &Given::Nothing,
1194 Some(weights),
1195 )
1196 .drain(..)
1197 .enumerate()
1198 .for_each(|(ix, p)| {
1199 ps[ix] += p;
1200 });
1201 });
1202
1203 let sf = states.len() as f64;
1204 ps.iter_mut().for_each(|p| *p /= sf);
1205 ps
1206 };
1207
1208 ps.iter().map(|p| -p * p.ln()).sum::<f64>()
1209}
1210
/// Integer range `(lower, upper)` covering (approximately) `mass` of the
/// probability of count column `col`.
///
/// The bounds are found under the Poisson mixture combined across `states`.
///
/// * `lower` is one below the first value whose CDF exceeds `(1 - mass) / 2`
///   (saturating at 0).
/// * `upper` is the first value at or above the largest (rounded) component
///   mean whose CDF exceeds the upper threshold.
///
/// # Panics
///
/// * Panics if the column is not a count (Poisson) column.
/// * Panics if `lower >= upper`.
fn count_pr_limit(col: usize, mass: f64, states: &[State]) -> (u32, u32) {
    use lace_stats::rv::traits::{Cdf, Mean};

    let lower_threshold = (1.0 - mass) / 2.0;
    // NOTE(review): the symmetric upper tail would be 1 - (1 - mass) / 2.
    // For `mass = 1 - 1e-16` (the value `count_entropy_dual` passes) that
    // expression rounds to exactly 1.0 in f64, so the `cdf > threshold`
    // search below would never terminate. The present formula stays strictly
    // below 1.0. Do not "fix" one without the other.
    let upper_threshold = mass - (1.0 - mass) / 2.0;

    let mixtures = states
        .iter()
        .map(|state| {
            if let MixtureType::Poisson(mm) = state.feature_as_mixture(col) {
                mm
            } else {
                panic!("expected count type for column {}", col);
            }
        })
        .collect::<Vec<_>>();

    let mm = Mixture::combine(mixtures);
    // The upper search starts at the largest component mean; the CDF is
    // below the upper threshold there in any realistic case.
    let max_mean = mm
        .components()
        .iter()
        .map(|cpnt| {
            let mean: u32 = cpnt.mean().unwrap().round() as u32;
            mean
        })
        .max()
        .unwrap();

    let lower = (0_u32..)
        .find_map(|x| {
            if mm.cdf(&x) > lower_threshold {
                Some(x.saturating_sub(1))
            } else {
                None
            }
        })
        .unwrap();

    // find_map (rather than find) kept for symmetry with the lower search.
    #[allow(clippy::unnecessary_find_map)]
    let upper = (max_mean..)
        .find_map(|x| {
            if mm.cdf(&x) > upper_threshold {
                Some(x)
            } else {
                None
            }
        })
        .unwrap();

    assert!(lower < upper);

    (lower, upper)
}
1266
1267pub fn count_entropy_dual(col_a: usize, col_b: usize, states: &[State]) -> f64 {
1269 if col_a == col_b {
1270 return entropy_single(col_a, states);
1271 }
1272
1273 let mass: f64 = 1_f64 - 1E-16;
1274 let (a_lower, a_upper) = count_pr_limit(col_a, mass, states);
1275 let (b_lower, b_upper) = count_pr_limit(col_b, mass, states);
1276
1277 let nx = (a_upper - a_lower) * (b_upper - b_lower);
1278 let mut vals: Vec<Vec<Datum>> = Vec::with_capacity(nx as usize);
1279
1280 for a in a_lower..a_upper {
1282 for b in b_lower..b_upper {
1283 vals.push(vec![Datum::Count(a), Datum::Count(b)]);
1284 }
1285 }
1286
1287 let logps: Vec<Vec<f64>> = states
1288 .iter()
1289 .map(|state| {
1290 state_logp(
1291 state,
1292 &[col_a, col_b],
1293 &vals,
1294 &Given::Nothing,
1295 None,
1296 false,
1297 )
1298 })
1299 .collect();
1300
1301 let ln_n_states = (states.len() as f64).ln();
1302
1303 transpose(&logps)
1304 .iter()
1305 .map(|lps| logsumexp(lps) - ln_n_states)
1306 .fold(0.0, |acc, lp| lp.mul_add(-lp.exp(), acc))
1307}
1308
1309pub fn continuous_predict(
1335 states: &[&State],
1336 col_ix: usize,
1337 given: &Given<usize>,
1338) -> f64 {
1339 let mm = {
1340 let mixtures = states
1341 .iter()
1342 .map(|state| {
1343 let view_ix = state.asgn().asgn[col_ix];
1344 let weights = &given_weights(&[state], &[col_ix], given)[0];
1349 let mut mm_weights: Vec<f64> = state.views[view_ix]
1350 .weights
1351 .iter()
1352 .zip(weights[&view_ix].iter())
1353 .map(|(&w1, &w2)| w1 + w2)
1354 .collect();
1355
1356 let z: f64 = logsumexp(&mm_weights);
1357 mm_weights.iter_mut().for_each(|w| *w = (*w - z).exp());
1358
1359 match state.views[view_ix].ftrs[&col_ix].to_mixture(mm_weights)
1360 {
1361 MixtureType::Gaussian(m) => m,
1362 _ => panic!("invalid MixtureType for continuous predict"),
1363 }
1364 })
1365 .collect();
1366
1367 let mm = Mixture::combine(mixtures);
1368
1369 sort_mixture_by_mode(mm)
1371 };
1372
1373 let f = |x: f64| -mm.f(&x);
1374
1375 let eval_points = continuous_mixture_quad_points(&mm);
1378 let n_eval_points = eval_points.len();
1379
1380 if n_eval_points == 1 {
1381 return eval_points[0];
1382 }
1383
1384 let min_ix = eval_points
1385 .iter()
1386 .enumerate()
1387 .map(|(ix, &x)| (ix, f(x)))
1388 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
1389 .unwrap()
1390 .0;
1391
1392 let (ix_left, ix_right) = if min_ix == 0 {
1394 (0, 1)
1395 } else if min_ix == n_eval_points - 1 {
1396 (n_eval_points - 2, n_eval_points - 1)
1397 } else {
1398 (min_ix - 1, min_ix + 1)
1399 };
1400
1401 let left = eval_points[ix_left];
1402 let right = eval_points[ix_right];
1403 let n_steps = 20;
1404 let step_size = (right - left) / n_steps as f64;
1405
1406 let x0 = fmin_brute(&f, (left, right), n_steps);
1408 fmin_bounded(f, (x0 - step_size, x0 + step_size), None, None)
1409}
1410
1411pub fn categorical_predict(
1412 states: &[&State],
1413 col_ix: usize,
1414 given: &Given<usize>,
1415) -> u8 {
1416 use crate::cc::feature::MissingNotAtRandom;
1417 let col_ixs: Vec<usize> = vec![col_ix];
1418
1419 let state_weights = state_weights(states, &col_ixs, given);
1420
1421 let f = |x: u8| {
1422 let y: Vec<Vec<Datum>> =
1423 vec![vec![Datum::Categorical(Category::U8(x))]];
1424 let scores: Vec<f64> = states
1425 .iter()
1426 .zip(state_weights.iter())
1427 .map(|(state, view_weights)| {
1428 state_logp(
1429 state,
1430 &col_ixs,
1431 &y,
1432 given,
1433 Some(view_weights),
1434 false,
1435 )[0]
1436 })
1437 .collect();
1438 logsumexp(&scores)
1439 };
1440
1441 let k: u8 = match states[0].feature(col_ix) {
1442 ColModel::Categorical(ftr) => ftr.prior.k() as u8,
1443 ColModel::MissingNotAtRandom(MissingNotAtRandom { fx, .. }) => {
1444 if let ColModel::Categorical(ref ftr) = **fx {
1445 ftr.prior.k() as u8
1446 } else {
1447 panic!("FType mismatch for categorical MNAR prediction")
1448 }
1449 }
1450 _ => panic!("FType mismatch for categorical prediction"),
1451 };
1452
1453 let fs: Vec<f64> = (0..k).map(f).collect();
1454 argmax(&fs) as u8
1455}
1456
1457pub fn count_predict(
1458 states: &[&State],
1459 col_ix: usize,
1460 given: &Given<usize>,
1461) -> u32 {
1462 let col_ixs: Vec<usize> = vec![col_ix];
1463
1464 let state_weights = state_weights(states, &col_ixs, given);
1465
1466 let ln_fx = |x: u32| {
1467 let y: Vec<Vec<Datum>> = vec![vec![Datum::Count(x)]];
1468 let scores: Vec<f64> = states
1469 .iter()
1470 .zip(state_weights.iter())
1471 .map(|(state, view_weights)| {
1472 state_logp(
1473 state,
1474 &col_ixs,
1475 &y,
1476 given,
1477 Some(view_weights),
1478 false,
1479 )[0]
1480 })
1481 .collect();
1482 logsumexp(&scores)
1483 };
1484
1485 let (lower, upper) = {
1486 let (lower, upper) = impute_bounds(states, col_ix);
1487 ((lower + 0.5) as u32, (upper + 0.5) as u32)
1488 };
1489
1490 (lower..=upper)
1491 .skip(1)
1492 .fold((lower, ln_fx(lower)), |(argmax, max), xi| {
1493 let ln_fxi = ln_fx(xi);
1494 if ln_fxi > max {
1495 (xi, ln_fxi)
1496 } else {
1497 (argmax, max)
1498 }
1499 })
1500 .0
1501}
1502
/// Prediction uncertainty for one component type.
///
/// Builds, for every state, the column's mixture reweighted by the view's
/// conditioned (and normalized) component weights. Expands to the result of
/// `mixture_normed_tvd` over those mixtures.
macro_rules! predunc_arm {
    ($states: expr, $col_ix: expr, $given_opt: expr, $cpnt_type: ty) => {{
        let mut mix_models: Vec<Mixture<$cpnt_type>> = Vec::new();

        for state in $states.iter() {
            let view_ix = state.asgn().asgn[$col_ix];
            let ln_weights = single_view_weights(&state, view_ix, $given_opt);
            let ln_z = logsumexp(&ln_weights);
            let normed: Vec<f64> =
                ln_weights.iter().map(|w| (w - ln_z).exp()).collect();

            let mut mixture: Mixture<$cpnt_type> =
                state.feature_as_mixture($col_ix).into();
            mixture.set_weights_unchecked(normed);
            mix_models.push(mixture);
        }

        $crate::stats::uncertainty::mixture_normed_tvd(&mix_models)
    }};
}
1530
1531pub fn predict_uncertainty(
1532 states: &[State],
1533 col_ix: usize,
1534 given: &Given<usize>,
1535 states_ixs_opt: Option<&[usize]>,
1536) -> f64 {
1537 let ftype = {
1538 let view_ix = states[0].asgn().asgn[col_ix];
1539 states[0].views[view_ix].ftrs[&col_ix].ftype()
1540 };
1541 let states = select_states(states, states_ixs_opt);
1542 match ftype {
1543 FType::Continuous => predunc_arm!(states, col_ix, given, Gaussian),
1544 FType::Categorical => predunc_arm!(states, col_ix, given, Categorical),
1545 FType::Count => predunc_arm!(states, col_ix, given, Poisson),
1546 FType::Binary => unimplemented!(),
1547 }
1548}
1549
1550pub(crate) fn mnar_uncertainty(
1551 states: &[&State],
1552 col_ix: usize,
1553 given: &Given<usize>,
1554) -> f64 {
1555 use crate::cc::feature::MissingNotAtRandom;
1556
1557 let components = states
1559 .iter()
1560 .map(|state| match state.feature(col_ix) {
1561 ColModel::MissingNotAtRandom(MissingNotAtRandom {
1562 present,
1563 ..
1564 }) => {
1565 let view_ix = state.asgn().asgn[col_ix];
1567 let weights = {
1569 let mut weights =
1570 single_view_weights(state, view_ix, given);
1571 let z = logsumexp(&weights);
1573 weights
1574 .drain(..)
1575 .map(|ln_w| (ln_w - z).exp())
1576 .collect::<Vec<f64>>()
1577 };
1578 let mixture = if let MixtureType::Bernoulli(m) =
1580 present.to_mixture(weights)
1581 {
1582 m
1583 } else {
1584 panic!("invalid mixture type for MNAR")
1585 };
1586
1587 let p = mixture.f(&true);
1589 Bernoulli::new(p).unwrap()
1593 }
1594 _ => panic!("Expected MNAR ColModel in MNAR uncertainty fn"),
1595 })
1596 .collect::<Vec<Bernoulli>>();
1597
1598 let mixture = Mixture::uniform(components).unwrap();
1602
1603 let h_mix = Bernoulli::new_unchecked(mixture.f(&true)).entropy();
1607
1608 let h_cpnt = mixture
1610 .components()
1611 .iter()
1612 .map(|cpnt| cpnt.entropy())
1613 .sum::<f64>();
1614
1615 let kf = mixture.k() as f64;
1617
1618 h_mix - h_cpnt / kf
1620}
1621
// For every state, take the component that row `$row_ix` is assigned to in
// column `$col_ix`, wrap it in a single-component uniform mixture, and return
// `mixture_normed_tvd` over those mixtures.
//
// `$variant` is the expected `ColModel` variant. A `MissingNotAtRandom`
// wrapper is looked through to its inner `fx` model. Any other variant
// panics with the offending feature type.
macro_rules! impunc_arm {
    ($row_ix: ident, $col_ix: ident, $states: ident, $variant: ident) => {{
        let mixtures = $states
            .iter()
            .map(|state| {
                let view = &state.views[state.asgn().asgn[$col_ix]];
                let cpnt_ix = view.asgn().asgn[$row_ix];
                let cpnt = match &view.ftrs[&$col_ix] {
                    ColModel::$variant(ftr) => {
                        ftr.components[cpnt_ix].fx.clone()
                    }
                    ColModel::MissingNotAtRandom(
                        $crate::cc::feature::MissingNotAtRandom { fx, .. },
                    ) => match &**fx {
                        ColModel::$variant(inner) => {
                            inner.components[cpnt_ix].fx.clone()
                        }
                        cm => panic!(
                            "Mismatched MNAR feature type: {}",
                            cm.ftype()
                        ),
                    },
                    cm => panic!("Mismatched feature type: {}", cm.ftype()),
                };
                Mixture::uniform(vec![cpnt]).unwrap()
            })
            .collect::<Vec<_>>();

        $crate::stats::uncertainty::mixture_normed_tvd(&mixtures)
    }};
}
1654
1655pub fn impute_uncertainty(
1656 states: &[State],
1657 row_ix: usize,
1658 col_ix: usize,
1659) -> f64 {
1660 let ftype = states[0].ftype(col_ix);
1661 match ftype {
1662 FType::Continuous => {
1663 impunc_arm!(row_ix, col_ix, states, Continuous)
1664 }
1665 FType::Categorical => {
1666 impunc_arm!(row_ix, col_ix, states, Categorical)
1667 }
1668 FType::Count => {
1669 impunc_arm!(row_ix, col_ix, states, Count)
1670 }
1671 f => {
1672 panic!("Unsupported ftype: {:?}", f)
1673 }
1674 }
1675}
1676
#[cfg(test)]
mod tests {
    //! Unit tests for the oracle utility functions: view/state weights,
    //! log-likelihoods, imputation, prediction, and entropy estimators.
    //! Expected values are regression values computed from the YAML fixtures
    //! under `resources/test`.

    use super::*;
    use approx::*;

    const TOL: f64 = 1E-8;

    fn get_single_continuous_state_from_yaml() -> State {
        let filenames = vec!["resources/test/single-continuous.yaml"];
        load_states(filenames).remove(0)
    }

    fn get_single_categorical_state_from_yaml() -> State {
        let filenames = vec!["resources/test/single-categorical.yaml"];
        load_states(filenames).remove(0)
    }

    fn get_single_count_state_from_yaml() -> State {
        let filenames = vec!["resources/test/single-count.yaml"];
        load_states(filenames).remove(0)
    }

    fn get_states_from_yaml() -> Vec<State> {
        let filenames = vec![
            "resources/test/small/small-state-1.yaml",
            "resources/test/small/small-state-2.yaml",
            "resources/test/small/small-state-3.yaml",
        ];
        load_states(filenames)
    }

    fn get_entropy_states_from_yaml() -> Vec<State> {
        let filenames = vec![
            "resources/test/entropy/entropy-state-1.yaml",
            "resources/test/entropy/entropy-state-2.yaml",
        ];
        load_states(filenames)
    }

    /// Brute-force reference for single-column categorical entropy.
    ///
    /// Enumerates every category of `col_ix`, averages the per-state
    /// probabilities in log space, and sums `-p * ln(p)`. Used to
    /// cross-check `entropy_single`.
    pub fn old_categorical_entropy_single(
        col_ix: usize,
        states: &[State],
    ) -> f64 {
        let cpnt: Categorical =
            states[0].component(0, col_ix).try_into().unwrap();
        let k = cpnt.k();

        let mut vals: Vec<Vec<Datum>> = Vec::with_capacity(k);
        for i in 0..k {
            vals.push(vec![Datum::Categorical((i as u8).into())]);
        }

        let logps: Vec<Vec<f64>> = states
            .iter()
            .map(|state| {
                state_logp(
                    state,
                    &[col_ix],
                    &vals,
                    &Given::Nothing,
                    None,
                    false,
                )
            })
            .collect();

        let ln_n_states = (states.len() as f64).ln();

        transpose(&logps)
            .iter()
            .map(|lps| logsumexp(lps) - ln_n_states)
            .fold(0.0, |acc, lp| lp.mul_add(-lp.exp(), acc))
    }

    #[test]
    fn single_continuous_column_weights_no_given() {
        let state = get_single_continuous_state_from_yaml();

        let weights = single_view_weights(&state, 0, &Given::Nothing);

        assert_relative_eq!(weights[0], -std::f64::consts::LN_2, epsilon = TOL);
        assert_relative_eq!(weights[1], -std::f64::consts::LN_2, epsilon = TOL);
    }

    #[test]
    fn single_continuous_column_weights_given() {
        let state = get_single_continuous_state_from_yaml();
        let given = Given::Conditions(vec![(0, Datum::Continuous(0.5))]);

        let weights = single_view_weights(&state, 0, &given);
        let target = {
            let mut unnormed_targets =
                vec![-2.857_054_917_013_031_5, -16.598_938_533_204_67];
            let z = logsumexp(&unnormed_targets);
            unnormed_targets.iter_mut().for_each(|w| *w -= z);
            unnormed_targets
        };

        assert_relative_eq!(weights[0], target[0], epsilon = TOL);
        assert_relative_eq!(weights[1], target[1], epsilon = TOL);
    }

    #[test]
    fn continuous_predict_with_spread_out_modes() {
        let states = {
            let filenames =
                vec!["resources/test/spread-out-continuous-modes.yaml"];
            load_states(filenames)
        };
        let states: Vec<&State> = states.iter().collect();

        let x = continuous_predict(&states, 0, &Given::Nothing);
        assert_relative_eq!(x, -0.12, epsilon = 1E-5);
    }

    #[test]
    fn single_view_weights_state_0_no_given() {
        let states = get_states_from_yaml();

        let weights_0 = single_view_weights(&states[0], 0, &Given::Nothing);

        assert_relative_eq!(
            weights_0[0],
            -std::f64::consts::LN_2,
            epsilon = TOL
        );
        assert_relative_eq!(
            weights_0[1],
            -std::f64::consts::LN_2,
            epsilon = TOL
        );

        let weights_1 = single_view_weights(&states[0], 1, &Given::Nothing);

        assert_relative_eq!(
            weights_1[0],
            -1.386_294_361_119_890_6,
            epsilon = TOL
        );
        assert_relative_eq!(
            weights_1[1],
            -0.287_682_072_451_780_9,
            epsilon = TOL
        );
    }

    #[test]
    fn single_view_weights_vs_exp() {
        let states = get_states_from_yaml();
        let weights_0 = single_view_weights(&states[0], 0, &Given::Nothing);
        let weights_1 = single_view_weights(&states[0], 1, &Given::Nothing);
        let exp_weights_0 =
            single_view_exp_weights(&states[0], 0, &Given::Nothing);
        let exp_weights_1 =
            single_view_exp_weights(&states[0], 1, &Given::Nothing);

        weights_0
            .iter()
            .zip(exp_weights_0.iter())
            .for_each(|(&w, &e)| assert_relative_eq!(w, e.ln(), epsilon = TOL));

        weights_1
            .iter()
            .zip(exp_weights_1.iter())
            .for_each(|(&w, &e)| assert_relative_eq!(w, e.ln(), epsilon = TOL));
    }

    #[test]
    fn single_view_weights_state_0_with_one_given() {
        let states = get_states_from_yaml();

        let given = Given::Conditions(vec![
            (0, Datum::Continuous(0.0)),
            (1, Datum::Continuous(-1.0)),
        ]);

        let weights_0 = single_view_weights(&states[0], 0, &given);
        let weights_1 = single_view_weights(&states[0], 1, &given);
        {
            let unnormed_targets =
                vec![-3.158_958_368_120_129, -1.926_578_447_516_985];
            let z = logsumexp(&unnormed_targets);
            let targets: Vec<_> =
                unnormed_targets.iter().map(|&w| w - z).collect();
            assert_relative_eq!(weights_0[0], targets[0], epsilon = TOL);
            assert_relative_eq!(weights_0[1], targets[1], epsilon = TOL);
        }

        {
            let unnormed_targets =
                vec![-4.095_863_302_766_923, -0.417_781_136_933_142_9];
            let z = logsumexp(&unnormed_targets);
            let targets: Vec<_> =
                unnormed_targets.iter().map(|&w| w - z).collect();
            assert_relative_eq!(weights_1[0], targets[0], epsilon = TOL);
            assert_relative_eq!(weights_1[1], targets[1], epsilon = TOL);
        }
    }

    #[test]
    fn single_view_weights_vs_exp_one_given() {
        let given = Given::Conditions(vec![
            (0, Datum::Continuous(0.0)),
            (1, Datum::Continuous(-1.0)),
        ]);

        let states = get_states_from_yaml();
        let weights_0 = single_view_weights(&states[0], 0, &given);
        let weights_1 = single_view_weights(&states[0], 1, &given);
        let exp_weights_0 = single_view_exp_weights(&states[0], 0, &given);
        let exp_weights_1 = single_view_exp_weights(&states[0], 1, &given);

        weights_0
            .iter()
            .zip(exp_weights_0.iter())
            .for_each(|(&w, &e)| assert_relative_eq!(w, e.ln(), epsilon = TOL));

        weights_1
            .iter()
            .zip(exp_weights_1.iter())
            .for_each(|(&w, &e)| assert_relative_eq!(w, e.ln(), epsilon = TOL));
    }

    #[test]
    fn single_view_weights_state_0_with_added_given() {
        let states = get_states_from_yaml();

        let given = Given::Conditions(vec![
            (0, Datum::Continuous(0.0)),
            (2, Datum::Continuous(-1.0)),
        ]);

        let weights_0 = single_view_weights(&states[0], 0, &given);

        {
            let unnormed_targets =
                vec![-5.669_175_767_690_254, -9.304_554_786_193_446];
            let z = logsumexp(&unnormed_targets);
            let targets: Vec<_> =
                unnormed_targets.iter().map(|&w| w - z).collect();
            assert_relative_eq!(weights_0[0], targets[0], epsilon = TOL);
            assert_relative_eq!(weights_0[1], targets[1], epsilon = TOL);
        }
    }

    #[test]
    fn single_state_weights_value_check() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0, 1];
        let given = Given::Conditions(vec![
            (0, Datum::Continuous(0.0)),
            (1, Datum::Continuous(-1.0)),
            (2, Datum::Continuous(-1.0)),
        ]);

        let weights = single_state_weights(&states[0], &col_ixs, &given);

        assert_eq!(weights.len(), 2);
        assert_eq!(weights[&0].len(), 2);
        assert_eq!(weights[&1].len(), 2);

        {
            let unnormed_targets =
                vec![-5.669_175_767_690_254, -9.304_554_786_193_446];
            let z = logsumexp(&unnormed_targets);
            let targets: Vec<_> =
                unnormed_targets.iter().map(|&w| w - z).collect();
            assert_relative_eq!(weights[&0][0], targets[0], epsilon = TOL);
            assert_relative_eq!(weights[&0][1], targets[1], epsilon = TOL);
        }

        {
            let unnormed_targets =
                vec![-4.095_863_302_766_923, -0.417_781_136_933_142_9];
            let z = logsumexp(&unnormed_targets);
            let targets: Vec<_> =
                unnormed_targets.iter().map(|&w| w - z).collect();
            assert_relative_eq!(weights[&1][0], targets[0], epsilon = TOL);
            assert_relative_eq!(weights[&1][1], targets[1], epsilon = TOL);
        }
    }

    #[test]
    fn given_weights_size_check_single_target_column() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0];
        let state_weights = given_weights(
            states.iter().collect::<Vec<_>>().as_slice(),
            &col_ixs,
            &Given::Nothing,
        );

        assert_eq!(state_weights.len(), 3);

        assert_eq!(state_weights[0].len(), 1);
        assert_eq!(state_weights[1].len(), 1);
        assert_eq!(state_weights[2].len(), 1);

        assert_eq!(state_weights[0][&0].len(), 2);
        assert_eq!(state_weights[1][&0].len(), 3);
        assert_eq!(state_weights[2][&0].len(), 2);
    }

    // Check that `state_logp` (log space) agrees with `state_likelihood`
    // (linear space) for the same query. `$precomp` selects whether the view
    // weights are precomputed and passed in, or computed internally.
    macro_rules! state_logp_vs_exp {
        ($precomp: expr, $state: expr, $col_ixs: expr, $vals: expr, $given: expr) => {{
            let state_weights = single_state_weights($state, $col_ixs, $given);
            let logp = state_logp(
                $state,
                $col_ixs,
                $vals,
                $given,
                if $precomp { Some(&state_weights) } else { None },
                false,
            );

            let state_exp_weights =
                single_state_exp_weights($state, $col_ixs, $given);
            let likelihood = state_likelihood(
                $state,
                $col_ixs,
                $vals,
                $given,
                if $precomp {
                    Some(&state_exp_weights)
                } else {
                    None
                },
            );

            for (&ln_f, &f) in logp.iter().zip(likelihood.iter()) {
                assert_relative_eq!(ln_f, f.ln(), epsilon = TOL)
            }
        }};
    }

    #[test]
    fn state_logp_values_single_col_single_view() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0];
        let vals = vec![vec![Datum::Continuous(1.2)]];
        let logp = state_logp(
            &states[0],
            &col_ixs,
            &vals,
            &Given::Nothing,
            None,
            false,
        );

        assert_relative_eq!(logp[0], -2.939_618_577_673_343_7, epsilon = TOL);
    }

    #[test]
    fn state_logp_vs_exp_values_single_col_single_view() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0];
        let vals = vec![vec![Datum::Continuous(1.2)]];
        state_logp_vs_exp!(false, &states[0], &col_ixs, &vals, &Given::Nothing);
    }

    #[test]
    fn state_logp_values_multi_col_single_view() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0, 2];
        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(-0.3)]];
        let logp = state_logp(
            &states[0],
            &col_ixs,
            &vals,
            &Given::Nothing,
            None,
            false,
        );

        assert_relative_eq!(logp[0], -4.277_889_544_469_348, epsilon = TOL);
    }

    #[test]
    fn state_logp_vs_exp_values_multi_col_single_view() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0, 2];
        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(-0.3)]];
        state_logp_vs_exp!(false, &states[0], &col_ixs, &vals, &Given::Nothing)
    }

    #[test]
    fn state_logp_values_multi_col_single_view_precomp() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0, 2];
        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(-0.3)]];
        let view_weights =
            single_state_weights(&states[0], &col_ixs, &Given::Nothing);

        let logp = state_logp(
            &states[0],
            &col_ixs,
            &vals,
            &Given::Nothing,
            Some(&view_weights),
            false,
        );

        assert_relative_eq!(logp[0], -4.277_889_544_469_348, epsilon = TOL);
    }

    #[test]
    fn state_logp_vs_exp_values_multi_col_single_view_precomp() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0, 2];
        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(-0.3)]];

        state_logp_vs_exp!(true, &states[0], &col_ixs, &vals, &Given::Nothing);
    }

    #[test]
    fn state_logp_values_multi_col_multi_view() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0, 1];
        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(0.2)]];
        let logp = state_logp(
            &states[0],
            &col_ixs,
            &vals,
            &Given::Nothing,
            None,
            false,
        );

        assert_relative_eq!(logp[0], -4.718_619_899_900_069, epsilon = TOL);
    }

    #[test]
    fn state_logp_vs_exp_values_multi_col_multi_view() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0, 1];
        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(0.2)]];
        state_logp_vs_exp!(false, &states[0], &col_ixs, &vals, &Given::Nothing);
    }

    #[test]
    fn state_logp_values_multi_col_multi_view_precomp() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0, 1];
        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(0.2)]];
        let view_weights =
            single_state_weights(&states[0], &col_ixs, &Given::Nothing);
        let logp = state_logp(
            &states[0],
            &col_ixs,
            &vals,
            &Given::Nothing,
            Some(&view_weights),
            false,
        );

        assert_relative_eq!(logp[0], -4.718_619_899_900_069, epsilon = TOL);
    }

    #[test]
    fn state_logp_vs_exp_values_multi_col_multi_view_precomp() {
        let states = get_states_from_yaml();

        let col_ixs = vec![0, 1];
        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(0.2)]];
        state_logp_vs_exp!(true, &states[0], &col_ixs, &vals, &Given::Nothing);
    }

    #[test]
    fn single_state_continuous_impute_1() {
        let mut all_states = get_states_from_yaml();
        let states = [&all_states.remove(0)];
        let x: f64 = continuous_impute(&states, 1, 0);
        assert_relative_eq!(x, 1.683_113_796_266_261_7, epsilon = 10E-6);
    }

    #[test]
    fn single_state_continuous_impute_2() {
        let mut all_states = get_states_from_yaml();
        let states = [&all_states.remove(0)];
        let x: f64 = continuous_impute(&states, 3, 0);
        assert_relative_eq!(x, -0.824_416_188_399_796_6, epsilon = 10E-6);
    }

    #[test]
    fn multi_state_continuous_impute_1() {
        let mut all_states = get_states_from_yaml();
        let states = [&all_states.remove(0), &all_states.remove(0)];
        let x: f64 = continuous_impute(&states, 1, 2);
        assert_relative_eq!(x, 0.554_604_492_187_499_9, epsilon = 10E-6);
    }

    #[test]
    fn multi_state_continuous_impute_2() {
        let states = get_states_from_yaml();
        let states: Vec<&State> = states.iter().collect();
        let x: f64 = continuous_impute(&states, 1, 2);
        assert_relative_eq!(x, -0.250_584_379_015_657_5, epsilon = 10E-6);
    }

    #[test]
    fn single_state_categorical_impute_1() {
        let state: State = get_single_categorical_state_from_yaml();
        let x: u8 = categorical_impute(&[&state], 0, 0);
        assert_eq!(x, 2);
    }

    #[test]
    fn single_state_categorical_impute_2() {
        let state: State = get_single_categorical_state_from_yaml();
        let x: u8 = categorical_impute(&[&state], 2, 0);
        assert_eq!(x, 0);
    }

    #[test]
    fn single_state_categorical_predict_1() {
        let state: State = get_single_categorical_state_from_yaml();
        let x: u8 = categorical_predict(&[&state], 0, &Given::Nothing);
        assert_eq!(x, 2);
    }

    #[test]
    fn single_state_categorical_entropy() {
        let state: State = get_single_categorical_state_from_yaml();
        let h = entropy_single(0, &vec![state]);
        assert_relative_eq!(h, 1.368_541_708_152_32, epsilon = 10E-6);
    }

    #[test]
    fn single_state_categorical_self_entropy() {
        let state: State = get_single_categorical_state_from_yaml();
        let states = vec![state];
        let h_x = entropy_single(0, &states);
        let h_xx = categorical_entropy_dual(0, 0, &states);
        assert_relative_eq!(h_xx, h_x, epsilon = 1E-12);
    }

    // H(X, X) must equal H(X) when averaging over several states too.
    // Column 2 of the entropy fixtures is categorical.
    #[test]
    fn multi_state_categorical_self_entropy() {
        let states = get_entropy_states_from_yaml();
        let h_x = entropy_single(2, &states);
        let h_xx = categorical_entropy_dual(2, 2, &states);
        assert_relative_eq!(h_xx, h_x, epsilon = 1E-12);
    }

    #[test]
    fn multi_state_categorical_single_entropy() {
        let states = get_entropy_states_from_yaml();
        let h_x = entropy_single(2, &states);
        assert_relative_eq!(h_x, 1.368_715_500_467_195_1, epsilon = 1E-12);
    }

    #[cfg(feature = "examples")]
    #[test]
    fn multi_state_categorical_single_entropy_vs_old() {
        use crate::examples::Example;
        use crate::HasStates;
        let oracle = Example::Animals.oracle().unwrap();

        for col_ix in 0..oracle.n_cols() {
            let h_x_new = entropy_single(col_ix, &oracle.states);
            let h_x_old =
                old_categorical_entropy_single(col_ix, &oracle.states);
            assert_relative_eq!(h_x_new, h_x_old, epsilon = 1E-12);
        }
    }

    #[test]
    fn single_state_count_impute_1() {
        let states = [&get_single_count_state_from_yaml()];
        let x: u32 = count_impute(&states, 1, 0);
        assert_eq!(x, 1);
    }

    // NOTE(review): this is identical to `single_state_count_impute_1`
    // (same row and column). It was probably meant to exercise a different
    // row; the expected value for that row needs to be derived from the
    // fixture.
    #[test]
    fn single_state_count_impute_2() {
        let states = [&get_single_count_state_from_yaml()];
        let x: u32 = count_impute(&states, 1, 0);
        assert_eq!(x, 1);
    }

    #[test]
    fn single_state_count_predict() {
        let states = [&get_single_count_state_from_yaml()];
        let x: u32 = count_predict(&states, 0, &Given::<usize>::Nothing);
        assert_eq!(x, 1);
    }

    #[test]
    fn single_state_dual_categorical_entropy_0() {
        let mut states = get_entropy_states_from_yaml();
        let state = states.drain(..).next().unwrap();
        let hxy = categorical_entropy_dual(2, 3, &vec![state]);
        assert_relative_eq!(hxy, 2.050_396_319_359_273_4, epsilon = 1E-14);
    }

    #[test]
    fn single_state_dual_categorical_entropy_1() {
        let mut states = get_entropy_states_from_yaml();
        let state = states.pop().unwrap();
        let hxy = categorical_entropy_dual(2, 3, &vec![state]);
        assert_relative_eq!(hxy, 2.035_433_971_709_626, epsilon = 1E-14);
    }

    #[test]
    fn single_state_dual_categorical_entropy_vs_joint_equiv() {
        let states = {
            let mut states = get_entropy_states_from_yaml();
            let state = states.pop().unwrap();
            vec![state]
        };
        let hxy_dual = categorical_entropy_dual(2, 3, &states);
        let hxy_joint = categorical_joint_entropy(&[2, 3], &states);

        assert_relative_eq!(hxy_dual, hxy_joint, epsilon = 1E-14);
    }

    #[test]
    fn multi_state_dual_categorical_entropy_1() {
        let states = get_entropy_states_from_yaml();
        let hxy = categorical_entropy_dual(2, 3, &states);
        assert_relative_eq!(hxy, 2.050_402_245_628_641_5, epsilon = 1E-14);
    }

    #[test]
    fn multi_state_dual_categorical_entropy_vs_joint_equiv() {
        let states = get_entropy_states_from_yaml();
        let hxy_dual = categorical_entropy_dual(2, 3, &states);
        let hxy_joint = categorical_joint_entropy(&[2, 3], &states);
        assert_relative_eq!(hxy_dual, hxy_joint, epsilon = 1E-14);
    }

    #[test]
    fn single_state_categorical_gaussian_entropy_0() {
        let mut states = get_entropy_states_from_yaml();
        let state = states.drain(..).next().unwrap();
        let hxy = categorical_gaussian_entropy_dual(2, 0, &vec![state]);
        assert_relative_eq!(hxy, 2.726_163_712_601_034, epsilon = 1E-7);
    }

    #[test]
    fn single_state_categorical_gaussian_entropy_1() {
        let mut states = get_entropy_states_from_yaml();
        let state = states.pop().unwrap();
        let hxy = categorical_gaussian_entropy_dual(2, 0, &vec![state]);
        assert_relative_eq!(hxy, 2.735_457_532_371_074_6, epsilon = 1E-7);
    }

    #[test]
    fn multi_state_categorical_gaussian_entropy_0() {
        let states = get_entropy_states_from_yaml();
        let hxy = categorical_gaussian_entropy_dual(2, 0, &states);
        assert_relative_eq!(hxy, 2.744_356_173_055_859, epsilon = 1E-7);
    }

    #[test]
    fn sobol_samples() {
        let mut states = get_entropy_states_from_yaml();
        let state = states.pop().unwrap();
        let (samples, _) = gen_sobol_samples(&[0, 2, 3], &state, 102);

        assert_eq!(samples.len(), 102);

        for vals in samples {
            assert_eq!(vals.len(), 3);
            assert!(vals[0].is_continuous());
            assert!(vals[1].is_categorical());
            assert!(vals[2].is_categorical());
        }
    }

    /// Returns `(exact, sobol_estimate)` of the entropy of column `col_ix`
    /// under the last entropy-fixture state, using `n` Sobol samples for the
    /// estimate.
    fn sobol_vs_exact_entropy(col_ix: usize, n: usize) -> (f64, f64) {
        let mut states = get_entropy_states_from_yaml();
        let state = states.pop().unwrap();

        let h_sobol = {
            let (samples, q_recip) = gen_sobol_samples(&[col_ix], &state, n);

            let logps = state_logp(
                &state,
                &[col_ix],
                &samples,
                &Given::Nothing,
                None,
                false,
            );

            let h: f64 = logps.iter().map(|logp| -logp * logp.exp()).sum();

            h * q_recip / (n as f64)
        };

        let h_exact = entropy_single(col_ix, &vec![state]);

        (h_exact, h_sobol)
    }

    #[test]
    fn sobol_single_categorical_entropy_vs_exact() {
        let (h_exact, h_sobol) = sobol_vs_exact_entropy(2, 10_000);
        assert_relative_eq!(h_exact, h_sobol, epsilon = 1E-12);
    }

    #[test]
    fn sobol_single_gaussian_entropy_vs_exact() {
        let (h_exact, h_sobol) = sobol_vs_exact_entropy(0, 10_000);
        assert_relative_eq!(h_exact, h_sobol, epsilon = 1E-7);
    }
}