lace/interface/oracle/
utils.rs

1use std::borrow::Borrow;
2use std::collections::BTreeMap;
3use std::convert::{TryFrom, TryInto};
4use std::f64::{INFINITY, NEG_INFINITY};
5use std::fs::File;
6use std::hash::Hash;
7use std::io::Read;
8use std::path::Path;
9
10use crate::cc::feature::{ColModel, FType, Feature};
11use crate::cc::state::State;
12use crate::codebook::Codebook;
13use crate::stats::rv::dist::{
14    Bernoulli, Categorical, Gaussian, Mixture, Poisson,
15};
16use crate::stats::rv::traits::{Entropy, Mode, QuadBounds, Rv, Variance};
17use crate::stats::MixtureType;
18use lace_consts::rv::misc::logsumexp;
19use lace_data::{Category, Datum};
20use lace_utils::{argmax, transpose};
21
22use super::error::IndexError;
23use crate::interface::Given;
24use crate::optimize::{fmin_bounded, fmin_brute};
25
26pub(crate) fn u8_to_category(
27    x: u8,
28    col_ix: usize,
29    codebook: &Codebook,
30) -> Option<Category> {
31    codebook.col_metadata[col_ix]
32        .coltype
33        .value_map()
34        .map(|vm| vm.category(x as usize))
35}
36
37pub(crate) fn pre_process_datum(
38    x: Datum,
39    col_ix: usize,
40    codebook: &Codebook,
41) -> Result<Datum, IndexError> {
42    let n_cols = codebook.col_metadata.len();
43    if col_ix >= n_cols {
44        return Err(IndexError::ColumnIndexOutOfBounds { n_cols, col_ix });
45    }
46
47    if let Datum::Categorical(cat) = x {
48        let value_map = codebook.col_metadata[col_ix]
49            .coltype
50            .value_map()
51            .ok_or_else(|| IndexError::InvalidDatumForColumn {
52                col_ix,
53                ftype_req: FType::Categorical,
54                ftype: FType::from_coltype(
55                    &codebook.col_metadata[col_ix].coltype,
56                ),
57            })?;
58
59        value_map
60            .ix(&cat)
61            .map(|u| Datum::Categorical(Category::U8(u as u8)))
62            .ok_or(IndexError::CategoryIndexNotFound { col_ix, cat })
63    } else {
64        Ok(x)
65    }
66}
67
68pub(crate) fn pre_process_row(
69    row: &[Datum],
70    col_ixs: &[usize],
71    codebook: &Codebook,
72) -> Vec<Datum> {
73    row.iter()
74        .zip(col_ixs.iter())
75        .map(|(x, col_ix)| {
76            pre_process_datum(x.clone(), *col_ix, codebook).unwrap()
77        })
78        .collect()
79}
80
81pub(crate) fn post_process_datum(
82    x: Datum,
83    col_ix: usize,
84    codebook: &Codebook,
85) -> Datum {
86    if let Datum::Categorical(Category::U8(x_u8)) = x {
87        codebook
88            .value_map(col_ix)
89            .map(|map| map.category(x_u8 as usize))
90            .map(Datum::Categorical)
91            .unwrap_or(x)
92    } else {
93        x
94    }
95}
96
97pub(crate) fn post_process_row(
98    mut row: Vec<Datum>,
99    col_ixs: &[usize],
100    codebook: &Codebook,
101) -> Vec<Datum> {
102    row.drain(..)
103        .zip(col_ixs.iter())
104        .map(|(x, col_ix)| post_process_datum(x, *col_ix, codebook))
105        .collect()
106}
107
108pub(crate) fn select_states<'s>(
109    states: &'s [State],
110    states_ixs_opt: Option<&[usize]>,
111) -> Vec<&'s State> {
112    match states_ixs_opt {
113        Some(state_ixs) => state_ixs.iter().map(|&ix| &states[ix]).collect(),
114        None => states.iter().collect(),
115    }
116}
117
118/// Generates samples
119pub struct Simulator<'s, R: rand::Rng> {
120    rng: &'s mut R,
121    /// A list of the states
122    states: &'s [&'s State],
123    /// The view weights for each state
124    weights: &'s [BTreeMap<usize, Vec<f64>>],
125    /// Draws state indices at uniform
126    state_ixer: Categorical,
127    /// List of state indices from which to simulate
128    state_ixs: Vec<usize>,
129    /// List of state indices from which to simulate
130    col_ixs: &'s [usize],
131    component_ixers: BTreeMap<usize, Vec<Categorical>>,
132}
133
134impl<'s, R: rand::Rng> Simulator<'s, R> {
135    pub fn new(
136        states: &'s [&'s State],
137        weights: &'s [BTreeMap<usize, Vec<f64>>],
138        state_ixs: Option<Vec<usize>>,
139        col_ixs: &'s [usize],
140        rng: &'s mut R,
141    ) -> Self {
142        Simulator {
143            rng,
144            weights,
145            state_ixer: match state_ixs {
146                Some(ref ixs) => Categorical::uniform(ixs.len()),
147                None => Categorical::uniform(states.len()),
148            },
149            state_ixs: match state_ixs {
150                Some(ixs) => ixs,
151                None => (0..states.len()).collect(),
152            },
153            states,
154            col_ixs,
155            component_ixers: BTreeMap::new(),
156        }
157    }
158}
159
160impl<'s, R: rand::Rng> Iterator for Simulator<'s, R> {
161    type Item = Vec<Datum>;
162
163    fn next(&mut self) -> Option<Self::Item> {
164        let mut rng = &mut self.rng;
165
166        // choose a random state
167        let draw_ix: usize = self.state_ixer.draw(&mut rng);
168        let state_ix: usize = self.state_ixs[draw_ix];
169        let state = &self.states[draw_ix];
170
171        let weights = &self.weights;
172
173        // for each view
174        //   choose a random component from the weights
175        self.component_ixers.entry(state_ix).or_insert_with(|| {
176            // TODO: use Categorical::new_unchecked when rv 0.9.3 drops.
177            // from_ln_weights checks that the input logsumexp's to 0
178            weights[draw_ix]
179                .values()
180                .map(|view_weights| {
181                    Categorical::from_ln_weights(view_weights.clone()).unwrap()
182                })
183                .collect()
184        });
185
186        let cpnt_ixs: BTreeMap<usize, usize> = self.weights[draw_ix]
187            .keys()
188            .zip(self.component_ixers[&state_ix].iter())
189            .map(|(&view_ix, cpnt_ixer)| (view_ix, cpnt_ixer.draw(&mut rng)))
190            .collect();
191
192        let xs: Vec<_> = self
193            .col_ixs
194            .iter()
195            .map(|col_ix| {
196                let view_ix = state.asgn().asgn[*col_ix];
197                let k = cpnt_ixs[&view_ix];
198                state.views[view_ix].ftrs[col_ix].draw(k, &mut rng)
199            })
200            .collect();
201
202        Some(xs)
203    }
204}
205
206/// Computes probabilities from streams of data
207pub struct Calculator<'s, Xs>
208where
209    Xs: Iterator,
210    Xs::Item: Borrow<Vec<Datum>>,
211{
212    /// A list of the states
213    states: &'s [&'s State],
214    /// A codebook
215    codebook: Option<&'s Codebook>,
216    /// The view weights for each state
217    weights: &'s [BTreeMap<usize, Vec<f64>>],
218    /// List of state indices from which to simulate
219    col_ixs: &'s [usize],
220    values: &'s mut Xs,
221    /// Holds the values of logp under each state. Prevents reallocations of
222    /// vectors for every logp computation.
223    state_logps: Vec<f64>,
224    /// Whether to scale the logps to [0, 1]
225    scaled: bool,
226}
227
228impl<'s, Xs> Calculator<'s, Xs>
229where
230    Xs: Iterator,
231    Xs::Item: Borrow<Vec<Datum>>,
232{
233    pub fn new(
234        values: &'s mut Xs,
235        states: &'s [&'s State],
236        codebook: Option<&'s Codebook>,
237        weights: &'s [BTreeMap<usize, Vec<f64>>],
238        col_ixs: &'s [usize],
239    ) -> Self {
240        Self {
241            values,
242            weights,
243            states,
244            codebook,
245            col_ixs,
246            state_logps: vec![0.0; states.len()],
247            scaled: false,
248        }
249    }
250
251    pub fn new_scaled(
252        values: &'s mut Xs,
253        states: &'s [&'s State],
254        codebook: Option<&'s Codebook>,
255        weights: &'s [BTreeMap<usize, Vec<f64>>],
256        col_ixs: &'s [usize],
257    ) -> Self {
258        Self {
259            values,
260            weights,
261            states,
262            codebook,
263            col_ixs,
264            state_logps: vec![0.0; states.len()],
265            scaled: true,
266        }
267    }
268
269    fn calculate<X: Borrow<Vec<Datum>>>(&mut self, xs: X) -> Option<f64> {
270        let ln_n = (self.states.len() as f64).ln();
271        let col_ixs = self.col_ixs;
272        self.states
273            .iter()
274            .zip(self.weights.iter())
275            .enumerate()
276            .for_each(|(i, (state, weights))| {
277                let logp = single_val_logp(
278                    state,
279                    col_ixs,
280                    xs.borrow(),
281                    weights.clone(),
282                    self.scaled,
283                );
284                self.state_logps[i] = logp;
285            });
286        let logp = logsumexp(&self.state_logps) - ln_n;
287        if self.scaled {
288            // Geometric mean
289            Some(logp / self.col_ixs.len() as f64)
290        } else {
291            Some(logp)
292        }
293    }
294}
295
296impl<'s, Xs> Iterator for Calculator<'s, Xs>
297where
298    Xs: Iterator,
299    Xs::Item: Borrow<Vec<Datum>>,
300{
301    type Item = f64;
302
303    fn next(&mut self) -> Option<f64> {
304        match self.values.next() {
305            Some(xs) => {
306                if let Some(codebook) = self.codebook {
307                    let row =
308                        pre_process_row(xs.borrow(), self.col_ixs, codebook);
309                    self.calculate(row)
310                } else {
311                    self.calculate(xs)
312                }
313            }
314            None => None,
315        }
316    }
317}
318
319pub fn load_states<P: AsRef<Path>>(filenames: Vec<P>) -> Vec<State> {
320    filenames
321        .iter()
322        .map(|path| {
323            let mut file = File::open(path).unwrap();
324            let mut yaml = String::new();
325            let res = file.read_to_string(&mut yaml);
326            match res {
327                Ok(_) => serde_yaml::from_str(&yaml).unwrap(),
328                Err(err) => panic!("Error: {:?}", err),
329            }
330        })
331        .collect()
332}
333
334/// Generate uniformly `n` distributed data for specific columns and compute
335/// the reciprocal of the importance function.
336pub fn gen_sobol_samples(
337    col_ixs: &[usize],
338    state: &State,
339    n: usize,
340) -> (Vec<Vec<Datum>>, f64) {
341    use lace_stats::seq::SobolSeq;
342    use lace_stats::QmcEntropy;
343
344    let features: Vec<_> =
345        col_ixs.iter().map(|&ix| state.feature(ix)).collect();
346    let us_needed: usize = features.iter().map(|ftr| ftr.us_needed()).sum();
347    let sobol = SobolSeq::new(us_needed);
348
349    let samples: Vec<Vec<Datum>> = sobol
350        .take(n)
351        .map(|mut us| {
352            let mut drain = us.drain(..);
353            features
354                .iter()
355                .map(|ftr| ftr.us_to_datum(&mut drain))
356                .collect()
357        })
358        .collect();
359
360    let q_recip: f64 = features
361        .iter()
362        .fold(1_f64, |prod, ftr| prod * ftr.q_recip());
363
364    (samples, q_recip)
365}
366
367// Weight Calculation
368// ------------------
369#[inline]
370pub fn given_weights(
371    states: &[&State],
372    col_ixs: &[usize],
373    given: &Given<usize>,
374) -> Vec<BTreeMap<usize, Vec<f64>>> {
375    states
376        .iter()
377        .map(|state| single_state_weights(state, col_ixs, given))
378        .collect()
379}
380
381#[inline]
382pub fn given_exp_weights(
383    states: &[&State],
384    col_ixs: &[usize],
385    given: &Given<usize>,
386) -> Vec<BTreeMap<usize, Vec<f64>>> {
387    states
388        .iter()
389        .map(|state| single_state_exp_weights(state, col_ixs, given))
390        .collect()
391}
392
393#[inline]
394pub fn state_weights(
395    states: &[&State],
396    col_ixs: &[usize],
397    given: &Given<usize>,
398) -> Vec<BTreeMap<usize, Vec<f64>>> {
399    states
400        .iter()
401        .map(|state| single_state_weights(state, col_ixs, given))
402        .collect()
403}
404
405#[inline]
406pub fn state_exp_weights(
407    states: &[State],
408    col_ixs: &[usize],
409    given: &Given<usize>,
410) -> Vec<BTreeMap<usize, Vec<f64>>> {
411    states
412        .iter()
413        .map(|state| single_state_exp_weights(state, col_ixs, given))
414        .collect()
415}
416
417#[inline]
418pub fn single_state_weights(
419    state: &State,
420    col_ixs: &[usize],
421    given: &Given<usize>,
422) -> BTreeMap<usize, Vec<f64>> {
423    let mut view_weights: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
424    col_ixs
425        .iter()
426        .map(|&ix| state.asgn().asgn[ix])
427        .for_each(|view_ix| {
428            view_weights
429                .entry(view_ix)
430                .or_insert_with(|| single_view_weights(state, view_ix, given));
431        });
432
433    view_weights
434}
435
436#[inline]
437pub fn single_state_exp_weights(
438    state: &State,
439    col_ixs: &[usize],
440    given: &Given<usize>,
441) -> BTreeMap<usize, Vec<f64>> {
442    let mut view_weights: BTreeMap<usize, Vec<f64>> = BTreeMap::new();
443    col_ixs
444        .iter()
445        .map(|&ix| state.asgn().asgn[ix])
446        .for_each(|view_ix| {
447            view_weights.entry(view_ix).or_insert_with(|| {
448                single_view_exp_weights(state, view_ix, given)
449            });
450        });
451
452    view_weights
453}
454
455#[inline]
456fn single_view_weights(
457    state: &State,
458    target_view_ix: usize,
459    given: &Given<usize>,
460) -> Vec<f64> {
461    let view = &state.views[target_view_ix];
462    let mut weights: Vec<_> = view.weights.iter().map(|w| w.ln()).collect();
463
464    match given {
465        Given::Conditions(ref conditions) => {
466            for &(col_ix, ref datum) in conditions {
467                let in_target_view =
468                    state.asgn().asgn[col_ix] == target_view_ix;
469                if in_target_view {
470                    view.ftrs[&col_ix].accum_weights(
471                        datum,
472                        &mut weights,
473                        false,
474                    );
475                }
476            }
477            let z = logsumexp(&weights);
478            weights.iter_mut().for_each(|w| *w -= z);
479        }
480        Given::Nothing => (),
481    }
482    weights
483}
484
485#[inline]
486fn single_view_exp_weights(
487    state: &State,
488    target_view_ix: usize,
489    given: &Given<usize>,
490) -> Vec<f64> {
491    let view = &state.views[target_view_ix];
492    let mut weights = view.weights.clone();
493
494    match given {
495        Given::Conditions(ref conditions) => {
496            conditions.iter().for_each(|(ix, datum)| {
497                let in_target_view = state.asgn().asgn[*ix] == target_view_ix;
498                if in_target_view {
499                    view.ftrs[ix].accum_exp_weights(datum, &mut weights);
500                }
501            });
502            let z = weights.iter().sum::<f64>();
503            weights.iter_mut().for_each(|w| *w /= z);
504        }
505        Given::Nothing => (),
506    }
507    weights
508}
509
510// Probability calculation
511// -----------------------
512
513/// Compute the probability of values under the state
514///
515/// # Notes
516///
517/// The mixture likelihood is
518///
519///  f(x) = Σ πᵢ f(x | θᵢ)
520///
521/// The scaled likelihood is
522///
523///  f(x) = Σ πᵢ f(x | θᵢ) / f(mode(θᵢ))
524///
525/// # Arguments
526///
527/// - state: The state
528/// - col_ixs: The column indices that each entry in each vector in `vals`
529///   comes from
530/// - vals: A vector of value rows. `vals[i][j]` is a datum from the column
531///   with index `col_ixs[j]`. The function returns a vector with an entry for
532///   each row in `vals`.
533/// - given: An optional set of conditions on the targets for p(vals | given).
534/// - view_weights_opt: Optional precomputed weights.
535/// - scaled: If supplied, the logp component contributed by each column
536///   will be normalized to [0, 1].
537pub fn state_logp(
538    state: &State,
539    col_ixs: &[usize],
540    vals: &[Vec<Datum>],
541    given: &Given<usize>,
542    view_weights_opt: Option<&BTreeMap<usize, Vec<f64>>>,
543    scaled: bool,
544) -> Vec<f64> {
545    match view_weights_opt {
546        Some(view_weights) => vals
547            .iter()
548            .map(|val| {
549                single_val_logp(
550                    state,
551                    col_ixs,
552                    val,
553                    view_weights.clone(),
554                    scaled,
555                )
556            })
557            .collect(),
558        None => {
559            let mut view_weights = single_state_weights(state, col_ixs, given);
560
561            // normalize view weights
562            for weights in view_weights.values_mut() {
563                let logz = logsumexp(weights);
564                weights.iter_mut().for_each(|w| *w -= logz);
565            }
566            vals.iter()
567                .map(|val| {
568                    single_val_logp(
569                        state,
570                        col_ixs,
571                        val,
572                        view_weights.clone(),
573                        scaled,
574                    )
575                })
576                .collect()
577        }
578    }
579}
580
581fn single_val_logp(
582    state: &State,
583    col_ixs: &[usize],
584    val: &[Datum],
585    mut view_weights: BTreeMap<usize, Vec<f64>>,
586    scaled: bool,
587) -> f64 {
588    // TODO: is there a way to do this without cloning the view_weights?
589    col_ixs
590        .iter()
591        .zip(val)
592        .map(|(col_ix, datum)| (col_ix, state.asgn().asgn[*col_ix], datum))
593        .for_each(|(col_ix, view_ix, datum)| {
594            state.views[view_ix].ftrs[col_ix].accum_weights(
595                datum,
596                view_weights.get_mut(&view_ix).unwrap(),
597                scaled,
598            );
599        });
600
601    view_weights.values().map(|logps| logsumexp(logps)).sum()
602}
603pub fn state_likelihood(
604    state: &State,
605    col_ixs: &[usize],
606    vals: &[Vec<Datum>],
607    given: &Given<usize>,
608    view_exp_weights_opt: Option<&BTreeMap<usize, Vec<f64>>>,
609) -> Vec<f64> {
610    match view_exp_weights_opt {
611        Some(view_exp_weights) => vals
612            .iter()
613            .map(|val| {
614                single_val_likelihood(state, col_ixs, val, view_exp_weights)
615            })
616            .collect(),
617        None => {
618            let mut view_exp_weights =
619                single_state_exp_weights(state, col_ixs, given);
620
621            // normalize view weights
622            for weights in view_exp_weights.values_mut() {
623                let z = weights.iter().sum::<f64>();
624                weights.iter_mut().for_each(|w| *w /= z);
625            }
626
627            vals.iter()
628                .map(|val| {
629                    single_val_likelihood(
630                        state,
631                        col_ixs,
632                        val,
633                        &view_exp_weights,
634                    )
635                })
636                .collect()
637        }
638    }
639}
640
641fn single_val_likelihood(
642    state: &State,
643    col_ixs: &[usize],
644    val: &[Datum],
645    view_exp_weights: &BTreeMap<usize, Vec<f64>>,
646) -> f64 {
647    view_exp_weights
648        .iter()
649        .fold(1.0, |prod, (&view_ix, weights)| {
650            let view = &state.views[view_ix];
651            // lookup for column indices and data assigned to the view
652            let view_data: Vec<(usize, Datum)> = col_ixs
653                .iter()
654                .zip(val.iter())
655                .filter(|(ix, _)| view.ftrs.contains_key(ix))
656                .map(|(ix, val)| (*ix, val.clone()))
657                .collect();
658
659            prod * weights
660                .iter()
661                .enumerate()
662                .map(|(k, &w)| {
663                    view_data.iter().fold(w, |acc, (col_ix, val)| {
664                        acc * view.ftrs[col_ix].cpnt_likelihood(val, k)
665                    })
666                })
667                .sum::<f64>()
668        })
669}
670
671// Imputation
672// ----------
673fn impute_bounds(states: &[&State], col_ix: usize) -> (f64, f64) {
674    states
675        .iter()
676        .map(|state| state.impute_bounds(col_ix).unwrap())
677        .fold((INFINITY, NEG_INFINITY), |(min, max), (lower, upper)| {
678            (min.min(lower), max.max(upper))
679        })
680}
681
682pub fn continuous_impute(
683    states: &[&State],
684    row_ix: usize,
685    col_ix: usize,
686) -> f64 {
687    let cpnts: Vec<Gaussian> = states
688        .iter()
689        .map(|state| {
690            state
691                .component(row_ix, col_ix)
692                .try_into()
693                .expect("Unexpected column type")
694        })
695        .collect();
696
697    if cpnts.len() == 1 {
698        cpnts[0].mu()
699    } else {
700        let f = |x: f64| {
701            let logfs: Vec<f64> =
702                cpnts.iter().map(|cpnt| cpnt.ln_f(&x)).collect();
703            -logsumexp(&logfs)
704        };
705
706        let bounds = impute_bounds(states, col_ix);
707        let n_grid = 100;
708        let step_size = (bounds.1 - bounds.0) / (n_grid as f64);
709        let x0 = fmin_brute(&f, bounds, n_grid);
710        fmin_bounded(f, (x0 - step_size, x0 + step_size), None, None)
711    }
712}
713
714pub fn categorical_impute(
715    states: &[&State],
716    row_ix: usize,
717    col_ix: usize,
718) -> u8 {
719    let cpnts: Vec<Categorical> = states
720        .iter()
721        .map(|state| {
722            state
723                .component(row_ix, col_ix)
724                .try_into()
725                .expect("Unexpected column type")
726        })
727        .collect();
728
729    let k = cpnts[0].k();
730    let fs: Vec<f64> = (0..k)
731        .map(|x| {
732            let logfs: Vec<f64> =
733                cpnts.iter().map(|cpnt| cpnt.ln_f(&x)).collect();
734            logsumexp(&logfs)
735        })
736        .collect();
737    argmax(&fs) as u8
738}
739
740pub fn count_impute(states: &[&State], row_ix: usize, col_ix: usize) -> u32 {
741    use lace_stats::rv::traits::Mean;
742    use lace_utils::MinMax;
743
744    let cpnts: Vec<Poisson> = states
745        .iter()
746        .map(|state| {
747            state
748                .component(row_ix, col_ix)
749                .try_into()
750                .expect("Unexpected column type")
751        })
752        .collect();
753
754    let (lower, upper) = {
755        let (lower, upper) = cpnts
756            .iter()
757            .map(|cpnt| {
758                let mean: f64 = cpnt.mean().expect("Poisson always has a mean");
759                mean
760            })
761            .minmax()
762            .unwrap();
763        ((lower.ceil() - 1.0) as u32, upper.floor() as u32)
764    };
765
766    // use fx instead of x so we can sum in place and not worry about
767    // allocating a vector. Since there is inly one number in the likelihood,
768    // we shouldn't have numerical issues.
769    let fx = |x: u32| cpnts.iter().map(|cpnt| cpnt.f(&x)).sum::<f64>();
770
771    (lower..=upper)
772        .skip(1)
773        .fold((lower, fx(lower)), |(argmax, max), xi| {
774            let fxi = fx(xi);
775            if fxi > max {
776                (xi, fxi)
777            } else {
778                (argmax, max)
779            }
780        })
781        .0
782}
783
784pub fn entropy_single(col_ix: usize, states: &[State]) -> f64 {
785    let mixtures = states
786        .iter()
787        .map(|state| state.feature_as_mixture(col_ix))
788        .collect();
789    let mixture = MixtureType::combine(mixtures);
790    mixture.entropy()
791}
792
793fn sort_mixture_by_mode<Fx>(mm: Mixture<Fx>) -> Mixture<Fx>
794where
795    Fx: Mode<f64>,
796{
797    let mut components: Vec<(f64, Fx)> = mm.into();
798    components.sort_by(|a, b| {
799        a.1.mode()
800            .partial_cmp(&b.1.mode())
801            .unwrap_or(std::cmp::Ordering::Less)
802    });
803    Mixture::<Fx>::try_from(components).unwrap()
804}
805
806fn continuous_mixture_quad_points<Fx>(mm: &Mixture<Fx>) -> Vec<f64>
807where
808    Fx: Mode<f64> + Variance<f64>,
809{
810    let mut state: (Option<f64>, Option<f64>) = (None, None);
811    let m = 2.0;
812    mm.components()
813        .iter()
814        .filter_map(|cpnt| {
815            let mode = cpnt.mode();
816            let std = cpnt.variance().map(|v| v.sqrt());
817            match (&state, (mode, std)) {
818                ((Some(m1), s1), (Some(m2), s2)) => {
819                    if (m2 - *m1)
820                        > (m * s1.unwrap_or(INFINITY))
821                            .min(m * s2.unwrap_or(INFINITY))
822                    {
823                        state = (mode, std);
824                        Some(m2)
825                    } else {
826                        None
827                    }
828                }
829                ((None, _), (Some(m2), _)) => {
830                    state = (mode, std);
831                    Some(m2)
832                }
833                _ => None,
834            }
835        })
836        .collect()
837}
838
/// Split the states into those where `$col_a` and `$col_b` share a view
/// (dependent) and those where they do not (independent).
///
/// Evaluates to `(weight, dep_mixture, ind_mixture)`:
/// - `weight` is the fraction of states in which the columns share a view.
/// - `dep_mixture` combines `$col_a`'s mixtures from the dependent states.
/// - `ind_mixture` combines `$col_a`'s mixtures from the independent states.
///
/// `$fx` names the `MixtureType` variant expected for `$col_a`. The macro
/// panics on any other variant.
macro_rules! dep_ind_col_mixtures {
    ($states: ident, $col_a: ident, $col_b: ident, $fx: ident) => {{
        let mut mms_dep = Vec::new();
        let mut mms_ind = Vec::new();

        for state in $states.iter() {
            let mm = if let MixtureType::$fx(mm) =
                state.feature_as_mixture($col_a)
            {
                mm
            } else {
                panic!("Unexpected MixtureType")
            };

            let asgn = &state.asgn().asgn;
            if asgn[$col_a] == asgn[$col_b] {
                mms_dep.push(mm);
            } else {
                mms_ind.push(mm);
            }
        }

        // Dependence probability: share of states with both columns in the
        // same view.
        let weight = mms_dep.len() as f64 / $states.len() as f64;

        (weight, Mixture::combine(mms_dep), Mixture::combine(mms_ind))
    }};
}
871
872/// Joint entropy H(X, Y) where X is Categorical and Y is Gaussian
873pub fn categorical_gaussian_entropy_dual(
874    col_cat: usize,
875    col_gauss: usize,
876    states: &[State],
877) -> f64 {
878    use crate::cc::feature::MissingNotAtRandom;
879    use lace_stats::rv::misc::{
880        gauss_legendre_quadrature_cached, gauss_legendre_table,
881    };
882    use std::cell::RefCell;
883    use std::collections::HashMap;
884
885    // get a mixture model of the Gaussian component to compute the quad points
886    let (dep_weight, gm_dep, gm_ind) =
887        dep_ind_col_mixtures!(states, col_gauss, col_cat, Gaussian);
888    let (_, cm_dep, cm_ind) =
889        dep_ind_col_mixtures!(states, col_cat, col_gauss, Categorical);
890
891    // Get the number of values the categorical column support. Can never exceed
892    // u8::MAX (255).
893    let cat_k = match states[0].feature(col_cat) {
894        ColModel::Categorical(cm) => u8::try_from(cm.prior.k())
895            .expect("Categorical k exceeded u8 max value"),
896        ColModel::MissingNotAtRandom(MissingNotAtRandom { fx, .. }) => {
897            if let ColModel::Categorical(cm) = &**fx {
898                u8::try_from(cm.prior.k())
899                    .expect("Categorical k exceeded u8 max value")
900            } else {
901                panic!("Expected MissingNotAtRandom Categorical")
902            }
903        }
904        _ => panic!("Expected ColModel::Categorical"),
905    };
906
907    // Divide the function into nicely behaved intervals
908    let (points, lower, upper) = {
909        let gmm = Mixture::combine(vec![gm_ind.clone(), gm_dep.clone()]);
910        let gmm = sort_mixture_by_mode(gmm);
911        let points = continuous_mixture_quad_points(&gmm);
912        let (lower, upper) = gmm.quad_bounds();
913        (points, lower, upper)
914    };
915
916    // Make sure the dependent state mixtures line up
917    assert_eq!(cm_dep.k(), gm_dep.k());
918    cm_dep
919        .weights()
920        .iter()
921        .zip(gm_dep.weights().iter())
922        .for_each(|(wc, wg)| assert!((wc - wg).abs() < 1e-12));
923
924    // If the columns are either always in the same view or never in the same
925    // view across states, we may run into empty container errors, so we keep
926    // track here so we don't compute things we don't need to and potentially
927    // pass empty containers where they're not expected.
928    let has_dep_states = gm_dep.k() > 0;
929    let has_ind_states = gm_ind.k() > 0;
930    let has_ind_and_dep_states = has_ind_states && has_dep_states;
931
932    // order of the polynomial for gauss-legendre quadrature
933    let quad_level = 16;
934
935    // Super aggressive caching. You can't hash a f64, so we create a structure
936    // to transmute it to a u64 so we can use it as an index in our cache
937    #[derive(Hash, Clone, Copy, PartialEq, Eq)]
938    struct F64(u64);
939
940    impl F64 {
941        fn new(x: f64) -> Self {
942            // The quadrature points should be exactly the same each time. If
943            // that doesn't turn out the be the case, we can round x to like 14
944            // decimals or something.
945            Self(x.to_bits())
946        }
947    }
948
949    let ind_cache: RefCell<HashMap<F64, f64>> = RefCell::new(HashMap::new());
950    let dep_cache: RefCell<HashMap<F64, Vec<f64>>> =
951        RefCell::new(HashMap::new());
952
953    // Pre-generate the weights and roots for the quadrate since it never
954    // changes and requires allocating a couple of vecs each time the quadrature
955    // is run.
956    let gl_cache = gauss_legendre_table(quad_level);
957
958    // NOTE: this will take a really long time when k is large
959    -(0..cat_k)
960        .map(|k| {
961            // NOTE: I've chosen to use the logp instead of vanilla 'p'. It
962            // doesn't really change the runtime.
963            let ind_cat_f = if has_ind_states {
964                // TODO: can cache this
965                cm_ind.ln_f(&k)
966            } else {
967                // Note, it shouldn't matter what we return here because the
968                // weight for the independent mixture will be 0
969                1.0 // ln(0)
970            };
971
972            let dep_cat_fs: Vec<f64> = cm_dep
973                .weights()
974                .iter()
975                .zip(cm_dep.components().iter())
976                .map(|(w, cpnt)| w.ln() + cpnt.ln_f(&k))
977                .collect();
978
979            let quad_fn = |y: f64| {
980                // We have to compute things differently for states in which the
981                // two columns are dependent and independent. The dependent
982                // computation is a bit more complicated.
983                let dep_cpnt = if has_dep_states {
984                    let mut m = dep_cache.borrow_mut();
985                    let ln_fys = m.entry(F64::new(y)).or_insert_with(|| {
986                        gm_dep
987                            .components()
988                            .iter()
989                            .map(|cpnt| cpnt.ln_f(&y))
990                            .collect()
991                    });
992                    // This does manually what state_logp does, but this is
993                    // faster because it's less general
994                    let cpnts: Vec<f64> = dep_cat_fs
995                        .iter()
996                        .zip(ln_fys)
997                        .map(|(w, g)| w + *g)
998                        .collect();
999
1000                    let ln_f = logsumexp(&cpnts);
1001
1002                    // If we do not have independent components, we do not have
1003                    // to compute the weighted sum of the two scenarios, so we
1004                    // can return early
1005                    if !has_ind_and_dep_states {
1006                        return ln_f * ln_f.exp();
1007                    } else {
1008                        ln_f
1009                    }
1010                } else {
1011                    // Note, it shouldn't matter what we return here because the
1012                    // weight for the dependent mixture will be 0
1013                    1.0 // ln(0)
1014                };
1015
1016                // We can basically cache the entire independent computation, so
1017                // things will be faster the fewer states that have the columns
1018                // in the same view
1019                let ind_cpnt = if has_ind_states {
1020                    let mut m = ind_cache.borrow_mut();
1021                    let ln_fy =
1022                        m.entry(F64::new(y)).or_insert_with(|| gm_ind.ln_f(&y));
1023                    let ln_f = ind_cat_f + *ln_fy;
1024
1025                    // If we do not have dependent components, we do not have to
1026                    // compute the weighted sum of the two scenarios, so we can
1027                    // return early
1028                    if !has_ind_and_dep_states {
1029                        return ln_f * ln_f.exp();
1030                    } else {
1031                        ln_f
1032                    }
1033                } else {
1034                    assert_eq!(dep_weight, 1.0);
1035                    1.0 // ln(0) = 1
1036                };
1037
1038                // add the weighted sums of the independent-columns mixture and
1039                // the dependent-columns mixture
1040                let ln_f = logsumexp(&[
1041                    dep_weight.ln() + dep_cpnt,
1042                    (1.0 - dep_weight).ln() + ind_cpnt,
1043                ]);
1044
1045                ln_f * ln_f.exp()
1046            };
1047
1048            let last_ix = points.len() - 1;
1049
1050            // right tail of integral
1051            let q_a = gauss_legendre_quadrature_cached(
1052                quad_fn,
1053                (lower, points[0]),
1054                &gl_cache.0,
1055                &gl_cache.1,
1056            );
1057
1058            // left tail of integral
1059            let q_b = gauss_legendre_quadrature_cached(
1060                quad_fn,
1061                (points[last_ix], upper),
1062                &gl_cache.0,
1063                &gl_cache.1,
1064            );
1065
1066            // interior integral points
1067            let q_m = if points.len() == 1 {
1068                0.0
1069            } else {
1070                let mut left = points[0];
1071                points
1072                    .iter()
1073                    .skip(1)
1074                    .map(|&x| {
1075                        let q = gauss_legendre_quadrature_cached(
1076                            quad_fn,
1077                            (left, x),
1078                            &gl_cache.0,
1079                            &gl_cache.1,
1080                        );
1081
1082                        left = x;
1083                        q
1084                    })
1085                    .sum::<f64>()
1086            };
1087
1088            q_a + q_m + q_b
1089        })
1090        .sum::<f64>()
1091}
1092
1093/// Computes entropy among categorical columns exactly via enumeration
1094pub fn categorical_joint_entropy(col_ixs: &[usize], states: &[State]) -> f64 {
1095    let ranges = col_ixs
1096        .iter()
1097        .map(|&ix| {
1098            let cpnt: Categorical = states[0]
1099                .component(0, ix)
1100                .try_into()
1101                .expect("Unexpected column type");
1102            cpnt.k() as u8
1103        })
1104        .collect();
1105
1106    let vals: Vec<_> = lace_utils::CategoricalCartProd::new(ranges)
1107        .map(|mut xs| {
1108            let vals: Vec<_> = xs
1109                .drain(..)
1110                .map(|x| Datum::Categorical(Category::U8(x)))
1111                .collect();
1112            vals
1113        })
1114        .collect();
1115
1116    // TODO: this is a pattern that appears a lot. I should DRY it.
1117    let logps: Vec<Vec<f64>> = states
1118        .iter()
1119        .map(|state| {
1120            state_logp(state, col_ixs, &vals, &Given::Nothing, None, false)
1121        })
1122        .collect();
1123
1124    let ln_n_states = (states.len() as f64).ln();
1125
1126    transpose(&logps)
1127        .iter()
1128        .map(|lps| logsumexp(lps) - ln_n_states)
1129        .fold(0.0, |acc, lp| lp.mul_add(-lp.exp(), acc))
1130}
1131
1132/// Joint entropy H(X, Y) where both X and Y are Categorical
1133pub fn categorical_entropy_dual(
1134    col_a: usize,
1135    col_b: usize,
1136    states: &[State],
1137) -> f64 {
1138    use crate::cc::feature::MissingNotAtRandom;
1139    // TODO: We could probably do a lot of pre-computation and caching like we
1140    // do in categorical_gaussian_entropy_dual, but this function is really fast
1141    // as it is, so it's probably not a good candidate for optimization
1142    if col_a == col_b {
1143        return entropy_single(col_a, states);
1144    }
1145
1146    let k_a = match states[0].feature(col_a) {
1147        ColModel::Categorical(cm) => cm.prior.k(),
1148        ColModel::MissingNotAtRandom(MissingNotAtRandom { fx, .. }) => {
1149            if let ColModel::Categorical(cm) = &**fx {
1150                cm.prior.k()
1151            } else {
1152                panic!("Expected MissingNotAtRandom Categorical")
1153            }
1154        }
1155        _ => panic!("Expected ColModel::Categorical"),
1156    };
1157
1158    let k_b = match states[0].feature(col_b) {
1159        ColModel::Categorical(cm) => cm.prior.k(),
1160        ColModel::MissingNotAtRandom(MissingNotAtRandom { fx, .. }) => {
1161            if let ColModel::Categorical(cm) = &**fx {
1162                cm.prior.k()
1163            } else {
1164                panic!("Expected MissingNotAtRandom Categorical")
1165            }
1166        }
1167        _ => panic!("Expected ColModel::Categorical"),
1168    };
1169
1170    let mut vals: Vec<Vec<Datum>> = Vec::with_capacity(k_a * k_b);
1171    for i in 0..k_a {
1172        for j in 0..k_b {
1173            vals.push(vec![
1174                Datum::Categorical(Category::U8(i as u8)),
1175                Datum::Categorical(Category::U8(j as u8)),
1176            ]);
1177        }
1178    }
1179
1180    let view_weights =
1181        state_exp_weights(states, &[col_a, col_b], &Given::Nothing);
1182
1183    let ps = {
1184        let mut ps = vec![0_f64; vals.len()];
1185        states
1186            .iter()
1187            .zip(view_weights.iter())
1188            .for_each(|(state, weights)| {
1189                state_likelihood(
1190                    state,
1191                    &[col_a, col_b],
1192                    &vals,
1193                    &Given::Nothing,
1194                    Some(weights),
1195                )
1196                .drain(..)
1197                .enumerate()
1198                .for_each(|(ix, p)| {
1199                    ps[ix] += p;
1200                });
1201            });
1202
1203        let sf = states.len() as f64;
1204        ps.iter_mut().for_each(|p| *p /= sf);
1205        ps
1206    };
1207
1208    ps.iter().map(|p| -p * p.ln()).sum::<f64>()
1209}
1210
/// Finds a range of counts `(lower, upper)` that covers most of the
/// probability mass of count column `col`. It uses the Poisson mixture
/// obtained by combining the column's mixtures from every state in `states`.
///
/// * `lower` is one less (saturating at 0) than the first count `x >= 0`
///   whose CDF exceeds `(1 - mass) / 2`.
/// * `upper` is the first count `x` whose CDF exceeds
///   `mass - (1 - mass) / 2`. The search starts at the largest (rounded)
///   component mean and moves upward.
///
/// NOTE(review): a symmetric central interval would use an upper threshold
/// of `1 - (1 - mass) / 2`. The threshold used here is lower than that.
/// This is not a drop-in fix. When `mass` is within an ulp of 1.0 (as the
/// caller `count_entropy_dual` passes), the symmetric threshold can round
/// to exactly 1.0. The upward search below has no bound and only stops
/// once the CDF strictly exceeds the threshold, so it could then fail to
/// terminate. Confirm before changing.
///
/// # Panics
/// Panics in any of these cases:
/// * `col` is not a count (Poisson) column in some state.
/// * The combined mixture has no components.
/// * `lower >= upper`.
fn count_pr_limit(col: usize, mass: f64, states: &[State]) -> (u32, u32) {
    use lace_stats::rv::traits::{Cdf, Mean};

    // CDF level marking the lower edge of the range
    let lower_threshold = (1.0 - mass) / 2.0;
    // CDF level marking the upper edge of the range (see NOTE above)
    let upper_threshold = mass - (1.0 - mass) / 2.0;

    // Every state's Poisson mixture for the column
    let mixtures = states
        .iter()
        .map(|state| {
            if let MixtureType::Poisson(mm) = state.feature_as_mixture(col) {
                mm
            } else {
                panic!("expected count type for column {}", col);
            }
        })
        .collect::<Vec<_>>();

    // Pool the per-state mixtures into one mixture
    let mm = Mixture::combine(mixtures);
    // Largest component mean, rounded to a count; the upper search starts here
    let max_mean = mm
        .components()
        .iter()
        .map(|cpnt| {
            let mean: u32 = cpnt.mean().unwrap().round() as u32;
            mean
        })
        .max()
        .unwrap();

    // Scan upward from 0 for the first count past the lower threshold
    let lower = (0_u32..)
        .find_map(|x| {
            if mm.cdf(&x) > lower_threshold {
                // make sure the lower bound is >= 0
                Some(x.saturating_sub(1))
            } else {
                None
            }
        })
        .unwrap();

    // The closure returns `Some(x)` unchanged, so this is equivalent to `find`
    #[allow(clippy::unnecessary_find_map)]
    let upper = (max_mean..)
        .find_map(|x| {
            if mm.cdf(&x) > upper_threshold {
                Some(x)
            } else {
                None
            }
        })
        .unwrap();

    assert!(lower < upper);

    (lower, upper)
}
1266
1267/// Joint entropy H(X, Y) where both X and Y are Categorical
1268pub fn count_entropy_dual(col_a: usize, col_b: usize, states: &[State]) -> f64 {
1269    if col_a == col_b {
1270        return entropy_single(col_a, states);
1271    }
1272
1273    let mass: f64 = 1_f64 - 1E-16;
1274    let (a_lower, a_upper) = count_pr_limit(col_a, mass, states);
1275    let (b_lower, b_upper) = count_pr_limit(col_b, mass, states);
1276
1277    let nx = (a_upper - a_lower) * (b_upper - b_lower);
1278    let mut vals: Vec<Vec<Datum>> = Vec::with_capacity(nx as usize);
1279
1280    // TODO: make this into an iterator
1281    for a in a_lower..a_upper {
1282        for b in b_lower..b_upper {
1283            vals.push(vec![Datum::Count(a), Datum::Count(b)]);
1284        }
1285    }
1286
1287    let logps: Vec<Vec<f64>> = states
1288        .iter()
1289        .map(|state| {
1290            state_logp(
1291                state,
1292                &[col_a, col_b],
1293                &vals,
1294                &Given::Nothing,
1295                None,
1296                false,
1297            )
1298        })
1299        .collect();
1300
1301    let ln_n_states = (states.len() as f64).ln();
1302
1303    transpose(&logps)
1304        .iter()
1305        .map(|lps| logsumexp(lps) - ln_n_states)
1306        .fold(0.0, |acc, lp| lp.mul_add(-lp.exp(), acc))
1307}
1308
1309// Prediction
1310// ----------
1311// pub(crate) fn predict(
1312//     col_ix: usize,
1313//     ftype: FType,
1314//     given: &Given<usize>,
1315//     states: &[&State],
1316// ) -> Datum {
1317//     match ftype {
1318//         FType::Continuous => {
1319//             let x = continuous_predict(states, col_ix, given);
1320//             Datum::Continuous(x)
1321//         }
1322//         FType::Categorical => {
1323//             let x = categorical_predict(states, col_ix, given);
1324//             Datum::Categorical(x)
1325//         }
1326//         FType::Count => {
1327//             let x = count_predict(states, col_ix, given);
1328//             Datum::Count(x)
1329//         }
1330//         _ => unimplemented!(),
1331//     }
1332// }
1333
1334pub fn continuous_predict(
1335    states: &[&State],
1336    col_ix: usize,
1337    given: &Given<usize>,
1338) -> f64 {
1339    let mm = {
1340        let mixtures = states
1341            .iter()
1342            .map(|state| {
1343                let view_ix = state.asgn().asgn[col_ix];
1344                // NOTE: There is a slight speedup from using given_exp_weights,
1345                // but at the cost of panics when there is a large number of
1346                // conditions in the given: underflow causes all the weights to
1347                // be zero, which causes a constructor error in Mixture::new
1348                let weights = &given_weights(&[state], &[col_ix], given)[0];
1349                let mut mm_weights: Vec<f64> = state.views[view_ix]
1350                    .weights
1351                    .iter()
1352                    .zip(weights[&view_ix].iter())
1353                    .map(|(&w1, &w2)| w1 + w2)
1354                    .collect();
1355
1356                let z: f64 = logsumexp(&mm_weights);
1357                mm_weights.iter_mut().for_each(|w| *w = (*w - z).exp());
1358
1359                match state.views[view_ix].ftrs[&col_ix].to_mixture(mm_weights)
1360                {
1361                    MixtureType::Gaussian(m) => m,
1362                    _ => panic!("invalid MixtureType for continuous predict"),
1363                }
1364            })
1365            .collect();
1366
1367        let mm = Mixture::combine(mixtures);
1368
1369        // sorts the mixture components in ascending order by their means/modes
1370        sort_mixture_by_mode(mm)
1371    };
1372
1373    let f = |x: f64| -mm.f(&x);
1374
1375    // We find the mode in the mixture model with the highest likelihood then
1376    // build everything around that mode
1377    let eval_points = continuous_mixture_quad_points(&mm);
1378    let n_eval_points = eval_points.len();
1379
1380    if n_eval_points == 1 {
1381        return eval_points[0];
1382    }
1383
1384    let min_ix = eval_points
1385        .iter()
1386        .enumerate()
1387        .map(|(ix, &x)| (ix, f(x)))
1388        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
1389        .unwrap()
1390        .0;
1391
1392    // Check whether the first or last modes are the highest likelihood
1393    let (ix_left, ix_right) = if min_ix == 0 {
1394        (0, 1)
1395    } else if min_ix == n_eval_points - 1 {
1396        (n_eval_points - 2, n_eval_points - 1)
1397    } else {
1398        (min_ix - 1, min_ix + 1)
1399    };
1400
1401    let left = eval_points[ix_left];
1402    let right = eval_points[ix_right];
1403    let n_steps = 20;
1404    let step_size = (right - left) / n_steps as f64;
1405
1406    // Use a grid search to narrow down the range
1407    let x0 = fmin_brute(&f, (left, right), n_steps);
1408    fmin_bounded(f, (x0 - step_size, x0 + step_size), None, None)
1409}
1410
1411pub fn categorical_predict(
1412    states: &[&State],
1413    col_ix: usize,
1414    given: &Given<usize>,
1415) -> u8 {
1416    use crate::cc::feature::MissingNotAtRandom;
1417    let col_ixs: Vec<usize> = vec![col_ix];
1418
1419    let state_weights = state_weights(states, &col_ixs, given);
1420
1421    let f = |x: u8| {
1422        let y: Vec<Vec<Datum>> =
1423            vec![vec![Datum::Categorical(Category::U8(x))]];
1424        let scores: Vec<f64> = states
1425            .iter()
1426            .zip(state_weights.iter())
1427            .map(|(state, view_weights)| {
1428                state_logp(
1429                    state,
1430                    &col_ixs,
1431                    &y,
1432                    given,
1433                    Some(view_weights),
1434                    false,
1435                )[0]
1436            })
1437            .collect();
1438        logsumexp(&scores)
1439    };
1440
1441    let k: u8 = match states[0].feature(col_ix) {
1442        ColModel::Categorical(ftr) => ftr.prior.k() as u8,
1443        ColModel::MissingNotAtRandom(MissingNotAtRandom { fx, .. }) => {
1444            if let ColModel::Categorical(ref ftr) = **fx {
1445                ftr.prior.k() as u8
1446            } else {
1447                panic!("FType mismatch for categorical MNAR prediction")
1448            }
1449        }
1450        _ => panic!("FType mismatch for categorical prediction"),
1451    };
1452
1453    let fs: Vec<f64> = (0..k).map(f).collect();
1454    argmax(&fs) as u8
1455}
1456
1457pub fn count_predict(
1458    states: &[&State],
1459    col_ix: usize,
1460    given: &Given<usize>,
1461) -> u32 {
1462    let col_ixs: Vec<usize> = vec![col_ix];
1463
1464    let state_weights = state_weights(states, &col_ixs, given);
1465
1466    let ln_fx = |x: u32| {
1467        let y: Vec<Vec<Datum>> = vec![vec![Datum::Count(x)]];
1468        let scores: Vec<f64> = states
1469            .iter()
1470            .zip(state_weights.iter())
1471            .map(|(state, view_weights)| {
1472                state_logp(
1473                    state,
1474                    &col_ixs,
1475                    &y,
1476                    given,
1477                    Some(view_weights),
1478                    false,
1479                )[0]
1480            })
1481            .collect();
1482        logsumexp(&scores)
1483    };
1484
1485    let (lower, upper) = {
1486        let (lower, upper) = impute_bounds(states, col_ix);
1487        ((lower + 0.5) as u32, (upper + 0.5) as u32)
1488    };
1489
1490    (lower..=upper)
1491        .skip(1)
1492        .fold((lower, ln_fx(lower)), |(argmax, max), xi| {
1493            let ln_fxi = ln_fx(xi);
1494            if ln_fxi > max {
1495                (xi, ln_fxi)
1496            } else {
1497                (argmax, max)
1498            }
1499        })
1500        .0
1501}
1502
1503// Predictive uncertainty helpers
1504// ------------------------------
1505macro_rules! predunc_arm {
1506    ($states: expr, $col_ix: expr, $given_opt: expr, $cpnt_type: ty) => {{
1507        let mix_models: Vec<Mixture<$cpnt_type>> = $states
1508            .iter()
1509            .map(|state| {
1510                let view_ix = state.asgn().asgn[$col_ix];
1511                let weights = single_view_weights(&state, view_ix, $given_opt);
1512
1513                let mut mixture: Mixture<$cpnt_type> =
1514                    state.feature_as_mixture($col_ix).into();
1515
1516                let z = logsumexp(&weights);
1517
1518                let new_weights =
1519                    weights.iter().map(|w| (w - z).exp()).collect();
1520                mixture.set_weights_unchecked(new_weights);
1521
1522                mixture
1523            })
1524            .collect();
1525
1526        $crate::stats::uncertainty::mixture_normed_tvd(&mix_models)
1527        // jsd_mixture(mix_models)
1528    }};
1529}
1530
1531pub fn predict_uncertainty(
1532    states: &[State],
1533    col_ix: usize,
1534    given: &Given<usize>,
1535    states_ixs_opt: Option<&[usize]>,
1536) -> f64 {
1537    let ftype = {
1538        let view_ix = states[0].asgn().asgn[col_ix];
1539        states[0].views[view_ix].ftrs[&col_ix].ftype()
1540    };
1541    let states = select_states(states, states_ixs_opt);
1542    match ftype {
1543        FType::Continuous => predunc_arm!(states, col_ix, given, Gaussian),
1544        FType::Categorical => predunc_arm!(states, col_ix, given, Categorical),
1545        FType::Count => predunc_arm!(states, col_ix, given, Poisson),
1546        FType::Binary => unimplemented!(),
1547    }
1548}
1549
1550pub(crate) fn mnar_uncertainty(
1551    states: &[&State],
1552    col_ix: usize,
1553    given: &Given<usize>,
1554) -> f64 {
1555    use crate::cc::feature::MissingNotAtRandom;
1556
1557    // Extract the state-level missingness distributions
1558    let components = states
1559        .iter()
1560        .map(|state| match state.feature(col_ix) {
1561            ColModel::MissingNotAtRandom(MissingNotAtRandom {
1562                present,
1563                ..
1564            }) => {
1565                // get the index of the view to which this column is assigned
1566                let view_ix = state.asgn().asgn[col_ix];
1567                // Get the weights from the view using the given
1568                let weights = {
1569                    let mut weights =
1570                        single_view_weights(state, view_ix, given);
1571                    // exp and normalize the weights so they sum to 1
1572                    let z = logsumexp(&weights);
1573                    weights
1574                        .drain(..)
1575                        .map(|ln_w| (ln_w - z).exp())
1576                        .collect::<Vec<f64>>()
1577                };
1578                // get a mixture model from the column using the weights above
1579                let mixture = if let MixtureType::Bernoulli(m) =
1580                    present.to_mixture(weights)
1581                {
1582                    m
1583                } else {
1584                    panic!("invalid mixture type for MNAR")
1585                };
1586
1587                // p(true)
1588                let p = mixture.f(&true);
1589                // We can collapse the whole mixture into a single Bernoulli
1590                // distribution. This works for categorical as well. Save
1591                // compute time.
1592                Bernoulli::new(p).unwrap()
1593            }
1594            _ => panic!("Expected MNAR ColModel in MNAR uncertainty fn"),
1595        })
1596        .collect::<Vec<Bernoulli>>();
1597
1598    // Normally a mixture of mixtures, but since we've collapse the state-level
1599    // mixtures, this is a mixture of Bernoulli. Each component represents the
1600    // state-level distribution for missingness.
1601    let mixture = Mixture::uniform(components).unwrap();
1602
1603    // The entropy of the who distribution. Again, we can collapse the mixture
1604    // of Bernoulli's into a single Bernoulli. Her we compute the entropy of the
1605    // grand mixture.
1606    let h_mix = Bernoulli::new_unchecked(mixture.f(&true)).entropy();
1607
1608    // The sum of component entropies
1609    let h_cpnt = mixture
1610        .components()
1611        .iter()
1612        .map(|cpnt| cpnt.entropy())
1613        .sum::<f64>();
1614
1615    // To uniformly weight each component
1616    let kf = mixture.k() as f64;
1617
1618    // Jensen-Shannon divergence
1619    h_mix - h_cpnt / kf
1620}
1621
1622macro_rules! impunc_arm {
1623    ($row_ix: ident, $col_ix: ident, $states: ident, $variant: ident) => {{
1624        let n_states = $states.len();
1625        let mixtures = (0..n_states)
1626            .map(|state_ix| {
1627                let view_ix = $states[state_ix].asgn().asgn[$col_ix];
1628                let view = &$states[state_ix].views[view_ix];
1629                let k = view.asgn().asgn[$row_ix];
1630                match &view.ftrs[&$col_ix] {
1631                    ColModel::$variant(ref ftr) => ftr.components[k].fx.clone(),
1632                    ColModel::MissingNotAtRandom(
1633                        $crate::cc::feature::MissingNotAtRandom { fx, .. },
1634                    ) => match &**fx {
1635                        ColModel::$variant(ref ftr) => {
1636                            ftr.components[k].fx.clone()
1637                        }
1638                        cm => {
1639                            panic!(
1640                                "Mismatched MNAR feature type: {}",
1641                                cm.ftype()
1642                            )
1643                        }
1644                    },
1645                    cm => panic!("Mismatched feature type: {}", cm.ftype()),
1646                }
1647            })
1648            .map(|cpnt| Mixture::uniform(vec![cpnt]).unwrap())
1649            .collect::<Vec<_>>();
1650
1651        $crate::stats::uncertainty::mixture_normed_tvd(&mixtures)
1652    }};
1653}
1654
1655pub fn impute_uncertainty(
1656    states: &[State],
1657    row_ix: usize,
1658    col_ix: usize,
1659) -> f64 {
1660    let ftype = states[0].ftype(col_ix);
1661    match ftype {
1662        FType::Continuous => {
1663            impunc_arm!(row_ix, col_ix, states, Continuous)
1664        }
1665        FType::Categorical => {
1666            impunc_arm!(row_ix, col_ix, states, Categorical)
1667        }
1668        FType::Count => {
1669            impunc_arm!(row_ix, col_ix, states, Count)
1670        }
1671        f => {
1672            panic!("Unsupported ftype: {:?}", f)
1673        }
1674    }
1675}
1676
1677#[cfg(test)]
1678mod tests {
1679    use super::*;
1680    use approx::*;
1681
1682    const TOL: f64 = 1E-8;
1683
1684    fn get_single_continuous_state_from_yaml() -> State {
1685        let filenames = vec!["resources/test/single-continuous.yaml"];
1686        load_states(filenames).remove(0)
1687    }
1688
1689    fn get_single_categorical_state_from_yaml() -> State {
1690        let filenames = vec!["resources/test/single-categorical.yaml"];
1691        load_states(filenames).remove(0)
1692    }
1693
1694    fn get_single_count_state_from_yaml() -> State {
1695        let filenames = vec!["resources/test/single-count.yaml"];
1696        load_states(filenames).remove(0)
1697    }
1698
1699    fn get_states_from_yaml() -> Vec<State> {
1700        let filenames = vec![
1701            "resources/test/small/small-state-1.yaml",
1702            "resources/test/small/small-state-2.yaml",
1703            "resources/test/small/small-state-3.yaml",
1704        ];
1705        load_states(filenames)
1706    }
1707
1708    fn get_entropy_states_from_yaml() -> Vec<State> {
1709        let filenames = vec![
1710            "resources/test/entropy/entropy-state-1.yaml",
1711            "resources/test/entropy/entropy-state-2.yaml",
1712        ];
1713        load_states(filenames)
1714    }
1715
1716    pub fn old_categorical_entropy_single(
1717        col_ix: usize,
1718        states: &[State],
1719    ) -> f64 {
1720        let cpnt: Categorical =
1721            states[0].component(0, col_ix).try_into().unwrap();
1722        let k = cpnt.k();
1723
1724        let mut vals: Vec<Vec<Datum>> = Vec::with_capacity(k);
1725        for i in 0..k {
1726            vals.push(vec![Datum::Categorical((i as u8).into())]);
1727        }
1728
1729        let logps: Vec<Vec<f64>> = states
1730            .iter()
1731            .map(|state| {
1732                state_logp(
1733                    state,
1734                    &[col_ix],
1735                    &vals,
1736                    &Given::Nothing,
1737                    None,
1738                    false,
1739                )
1740            })
1741            .collect();
1742
1743        let ln_n_states = (states.len() as f64).ln();
1744
1745        transpose(&logps)
1746            .iter()
1747            .map(|lps| logsumexp(lps) - ln_n_states)
1748            .fold(0.0, |acc, lp| lp.mul_add(-lp.exp(), acc))
1749    }
1750
1751    #[test]
1752    fn single_continuous_column_weights_no_given() {
1753        let state = get_single_continuous_state_from_yaml();
1754
1755        let weights = single_view_weights(&state, 0, &Given::Nothing);
1756
1757        assert_relative_eq!(weights[0], -std::f64::consts::LN_2, epsilon = TOL);
1758        assert_relative_eq!(weights[1], -std::f64::consts::LN_2, epsilon = TOL);
1759    }
1760
1761    #[test]
1762    fn single_continuous_column_weights_given() {
1763        let state = get_single_continuous_state_from_yaml();
1764        let given = Given::Conditions(vec![(0, Datum::Continuous(0.5))]);
1765
1766        let weights = single_view_weights(&state, 0, &given);
1767        let target = {
1768            let mut unnormed_targets =
1769                vec![-2.857_054_917_013_031_5, -16.598_938_533_204_67];
1770            let z = logsumexp(&unnormed_targets);
1771            unnormed_targets.iter_mut().for_each(|w| *w -= z);
1772            unnormed_targets
1773        };
1774
1775        assert_relative_eq!(weights[0], target[0], epsilon = TOL);
1776        assert_relative_eq!(weights[1], target[1], epsilon = TOL);
1777    }
1778
1779    #[test]
1780    fn continuous_predict_with_spread_out_modes() {
1781        let states = {
1782            let filenames =
1783                vec!["resources/test/spread-out-continuous-modes.yaml"];
1784            load_states(filenames)
1785        };
1786        let states: Vec<&State> = states.iter().collect();
1787
1788        let x = continuous_predict(&states, 0, &Given::Nothing);
1789        assert_relative_eq!(x, -0.12, epsilon = 1E-5);
1790    }
1791
1792    #[test]
1793    fn single_view_weights_state_0_no_given() {
1794        let states = get_states_from_yaml();
1795
1796        let weights_0 = single_view_weights(&states[0], 0, &Given::Nothing);
1797
1798        assert_relative_eq!(
1799            weights_0[0],
1800            -std::f64::consts::LN_2,
1801            epsilon = TOL
1802        );
1803        assert_relative_eq!(
1804            weights_0[1],
1805            -std::f64::consts::LN_2,
1806            epsilon = TOL
1807        );
1808
1809        let weights_1 = single_view_weights(&states[0], 1, &Given::Nothing);
1810
1811        assert_relative_eq!(
1812            weights_1[0],
1813            -1.386_294_361_119_890_6,
1814            epsilon = TOL
1815        );
1816        assert_relative_eq!(
1817            weights_1[1],
1818            -0.287_682_072_451_780_9,
1819            epsilon = TOL
1820        );
1821    }
1822
1823    #[test]
1824    fn single_view_weights_vs_exp() {
1825        let states = get_states_from_yaml();
1826        let weights_0 = single_view_weights(&states[0], 0, &Given::Nothing);
1827        let weights_1 = single_view_weights(&states[0], 1, &Given::Nothing);
1828        let exp_weights_0 =
1829            single_view_exp_weights(&states[0], 0, &Given::Nothing);
1830        let exp_weights_1 =
1831            single_view_exp_weights(&states[0], 1, &Given::Nothing);
1832
1833        weights_0
1834            .iter()
1835            .zip(exp_weights_0.iter())
1836            .for_each(|(&w, &e)| assert_relative_eq!(w, e.ln(), epsilon = TOL));
1837
1838        weights_1
1839            .iter()
1840            .zip(exp_weights_1.iter())
1841            .for_each(|(&w, &e)| assert_relative_eq!(w, e.ln(), epsilon = TOL));
1842    }
1843
1844    #[test]
1845    fn single_view_weights_state_0_with_one_given() {
1846        let states = get_states_from_yaml();
1847
1848        // column 1 should not affect view 0 weights because it is assigned to
1849        // view 1
1850        let given = Given::Conditions(vec![
1851            (0, Datum::Continuous(0.0)),
1852            (1, Datum::Continuous(-1.0)),
1853        ]);
1854
1855        let weights_0 = single_view_weights(&states[0], 0, &given);
1856        let weights_1 = single_view_weights(&states[0], 1, &given);
1857        {
1858            let unnormed_targets =
1859                vec![-3.158_958_368_120_129, -1.926_578_447_516_985];
1860            let z = logsumexp(&unnormed_targets);
1861            let targets: Vec<_> =
1862                unnormed_targets.iter().map(|&w| w - z).collect();
1863            assert_relative_eq!(weights_0[0], targets[0], epsilon = TOL);
1864            assert_relative_eq!(weights_0[1], targets[1], epsilon = TOL);
1865        }
1866
1867        {
1868            let unnormed_targets =
1869                vec![-4.095_863_302_766_923, -0.417_781_136_933_142_9];
1870            let z = logsumexp(&unnormed_targets);
1871            let targets: Vec<_> =
1872                unnormed_targets.iter().map(|&w| w - z).collect();
1873            assert_relative_eq!(weights_1[0], targets[0], epsilon = TOL);
1874            assert_relative_eq!(weights_1[1], targets[1], epsilon = TOL);
1875        }
1876    }
1877
1878    #[test]
1879    fn single_view_weights_vs_exp_one_given() {
1880        let given = Given::Conditions(vec![
1881            (0, Datum::Continuous(0.0)),
1882            (1, Datum::Continuous(-1.0)),
1883        ]);
1884
1885        let states = get_states_from_yaml();
1886        let weights_0 = single_view_weights(&states[0], 0, &given);
1887        let weights_1 = single_view_weights(&states[0], 1, &given);
1888        let exp_weights_0 = single_view_exp_weights(&states[0], 0, &given);
1889        let exp_weights_1 = single_view_exp_weights(&states[0], 1, &given);
1890
1891        weights_0
1892            .iter()
1893            .zip(exp_weights_0.iter())
1894            .for_each(|(&w, &e)| assert_relative_eq!(w, e.ln(), epsilon = TOL));
1895
1896        weights_1
1897            .iter()
1898            .zip(exp_weights_1.iter())
1899            .for_each(|(&w, &e)| assert_relative_eq!(w, e.ln(), epsilon = TOL));
1900    }
1901
1902    #[test]
1903    fn single_view_weights_state_0_with_added_given() {
1904        let states = get_states_from_yaml();
1905
1906        let given = Given::Conditions(vec![
1907            (0, Datum::Continuous(0.0)),
1908            (2, Datum::Continuous(-1.0)),
1909        ]);
1910
1911        let weights_0 = single_view_weights(&states[0], 0, &given);
1912
1913        {
1914            let unnormed_targets =
1915                vec![-5.669_175_767_690_254, -9.304_554_786_193_446];
1916            let z = logsumexp(&unnormed_targets);
1917            let targets: Vec<_> =
1918                unnormed_targets.iter().map(|&w| w - z).collect();
1919            assert_relative_eq!(weights_0[0], targets[0], epsilon = TOL);
1920            assert_relative_eq!(weights_0[1], targets[1], epsilon = TOL);
1921        }
1922    }
1923
1924    #[test]
1925    fn single_state_weights_value_check() {
1926        let states = get_states_from_yaml();
1927
1928        let col_ixs = vec![0, 1];
1929        let given = Given::Conditions(vec![
1930            (0, Datum::Continuous(0.0)),
1931            (1, Datum::Continuous(-1.0)),
1932            (2, Datum::Continuous(-1.0)),
1933        ]);
1934
1935        let weights = single_state_weights(&states[0], &col_ixs, &given);
1936
1937        assert_eq!(weights.len(), 2);
1938        assert_eq!(weights[&0].len(), 2);
1939        assert_eq!(weights[&1].len(), 2);
1940
1941        {
1942            let unnormed_targets =
1943                vec![-5.669_175_767_690_254, -9.304_554_786_193_446];
1944            let z = logsumexp(&unnormed_targets);
1945            let targets: Vec<_> =
1946                unnormed_targets.iter().map(|&w| w - z).collect();
1947            assert_relative_eq!(weights[&0][0], targets[0], epsilon = TOL);
1948            assert_relative_eq!(weights[&0][1], targets[1], epsilon = TOL);
1949        }
1950
1951        {
1952            let unnormed_targets =
1953                vec![-4.095_863_302_766_923, -0.417_781_136_933_142_9];
1954            let z = logsumexp(&unnormed_targets);
1955            let targets: Vec<_> =
1956                unnormed_targets.iter().map(|&w| w - z).collect();
1957            assert_relative_eq!(weights[&1][0], targets[0], epsilon = TOL);
1958            assert_relative_eq!(weights[&1][1], targets[1], epsilon = TOL);
1959        }
1960    }
1961
1962    #[test]
1963    fn give_weights_size_check_single_target_column() {
1964        let states = get_states_from_yaml();
1965
1966        let col_ixs = vec![0];
1967        let state_weights = given_weights(
1968            states.iter().collect::<Vec<_>>().as_slice(),
1969            &col_ixs,
1970            &Given::Nothing,
1971        );
1972
1973        assert_eq!(state_weights.len(), 3);
1974
1975        assert_eq!(state_weights[0].len(), 1);
1976        assert_eq!(state_weights[1].len(), 1);
1977        assert_eq!(state_weights[2].len(), 1);
1978
1979        assert_eq!(state_weights[0][&0].len(), 2);
1980        assert_eq!(state_weights[1][&0].len(), 3);
1981        assert_eq!(state_weights[2][&0].len(), 2);
1982    }
1983
1984    macro_rules! state_logp_vs_exp {
1985        ($precomp: expr, $state: expr, $col_ixs: expr, $vals: expr, $given: expr) => {{
1986            let state_weights = single_state_weights($state, $col_ixs, $given);
1987            let logp = state_logp(
1988                $state,
1989                $col_ixs,
1990                $vals,
1991                $given,
1992                if $precomp { Some(&state_weights) } else { None },
1993                false,
1994            );
1995
1996            let state_exp_weights =
1997                single_state_exp_weights($state, $col_ixs, $given);
1998            let likeihood = state_likelihood(
1999                $state,
2000                $col_ixs,
2001                $vals,
2002                $given,
2003                if $precomp {
2004                    Some(&state_exp_weights)
2005                } else {
2006                    None
2007                },
2008            );
2009
2010            for (&ln_f, &f) in logp.iter().zip(likeihood.iter()) {
2011                assert_relative_eq!(ln_f, f.ln(), epsilon = TOL)
2012            }
2013        }};
2014    }
2015
2016    #[test]
2017    fn state_logp_values_single_col_single_view() {
2018        let states = get_states_from_yaml();
2019
2020        let col_ixs = vec![0];
2021        let vals = vec![vec![Datum::Continuous(1.2)]];
2022        let logp = state_logp(
2023            &states[0],
2024            &col_ixs,
2025            &vals,
2026            &Given::Nothing,
2027            None,
2028            false,
2029        );
2030
2031        assert_relative_eq!(logp[0], -2.939_618_577_673_343_7, epsilon = TOL);
2032    }
2033
2034    #[test]
2035    fn state_logp_vs_exp_values_single_col_single_view() {
2036        let states = get_states_from_yaml();
2037
2038        let col_ixs = vec![0];
2039        let vals = vec![vec![Datum::Continuous(1.2)]];
2040        state_logp_vs_exp!(false, &states[0], &col_ixs, &vals, &Given::Nothing);
2041    }
2042
2043    #[test]
2044    fn state_logp_values_multi_col_single_view() {
2045        let states = get_states_from_yaml();
2046
2047        let col_ixs = vec![0, 2];
2048        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(-0.3)]];
2049        let logp = state_logp(
2050            &states[0],
2051            &col_ixs,
2052            &vals,
2053            &Given::Nothing,
2054            None,
2055            false,
2056        );
2057
2058        assert_relative_eq!(logp[0], -4.277_889_544_469_348, epsilon = TOL);
2059    }
2060
2061    #[test]
2062    fn state_logp_vs_exp_values_multi_col_single_view() {
2063        let states = get_states_from_yaml();
2064
2065        let col_ixs = vec![0, 2];
2066        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(-0.3)]];
2067        state_logp_vs_exp!(false, &states[0], &col_ixs, &vals, &Given::Nothing)
2068    }
2069
2070    #[test]
2071    fn state_logp_values_multi_col_single_view_precomp() {
2072        let states = get_states_from_yaml();
2073
2074        let col_ixs = vec![0, 2];
2075        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(-0.3)]];
2076        let view_weights =
2077            single_state_weights(&states[0], &col_ixs, &Given::Nothing);
2078
2079        let logp = state_logp(
2080            &states[0],
2081            &col_ixs,
2082            &vals,
2083            &Given::Nothing,
2084            Some(&view_weights),
2085            false,
2086        );
2087
2088        assert_relative_eq!(logp[0], -4.277_889_544_469_348, epsilon = TOL);
2089    }
2090
2091    #[test]
2092    fn state_logp_vs_exp_values_multi_col_single_view_precomp() {
2093        let states = get_states_from_yaml();
2094
2095        let col_ixs = vec![0, 2];
2096        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(-0.3)]];
2097
2098        state_logp_vs_exp!(true, &states[0], &col_ixs, &vals, &Given::Nothing);
2099    }
2100
2101    #[test]
2102    fn state_logp_values_multi_col_multi_view() {
2103        let states = get_states_from_yaml();
2104
2105        let col_ixs = vec![0, 1];
2106        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(0.2)]];
2107        let logp = state_logp(
2108            &states[0],
2109            &col_ixs,
2110            &vals,
2111            &Given::Nothing,
2112            None,
2113            false,
2114        );
2115
2116        assert_relative_eq!(logp[0], -4.718_619_899_900_069, epsilon = TOL);
2117    }
2118
2119    #[test]
2120    fn state_logp_vs_exp_values_multi_col_multi_view() {
2121        let states = get_states_from_yaml();
2122
2123        let col_ixs = vec![0, 1];
2124        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(0.2)]];
2125        state_logp_vs_exp!(false, &states[0], &col_ixs, &vals, &Given::Nothing);
2126    }
2127
2128    #[test]
2129    fn state_logp_values_multi_col_multi_view_precomp() {
2130        let states = get_states_from_yaml();
2131
2132        let col_ixs = vec![0, 1];
2133        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(0.2)]];
2134        let view_weights =
2135            single_state_weights(&states[0], &col_ixs, &Given::Nothing);
2136        let logp = state_logp(
2137            &states[0],
2138            &col_ixs,
2139            &vals,
2140            &Given::Nothing,
2141            Some(&view_weights),
2142            false,
2143        );
2144
2145        assert_relative_eq!(logp[0], -4.718_619_899_900_069, epsilon = TOL);
2146    }
2147
2148    #[test]
2149    fn state_logp_vs_exp_values_multi_col_multi_view_precomp() {
2150        let states = get_states_from_yaml();
2151
2152        let col_ixs = vec![0, 1];
2153        let vals = vec![vec![Datum::Continuous(1.2), Datum::Continuous(0.2)]];
2154        state_logp_vs_exp!(true, &states[0], &col_ixs, &vals, &Given::Nothing);
2155    }
2156
2157    #[test]
2158    fn single_state_continuous_impute_1() {
2159        let mut all_states = get_states_from_yaml();
2160        let states = [&all_states.remove(0)];
2161        let x: f64 = continuous_impute(&states, 1, 0);
2162        assert_relative_eq!(x, 1.683_113_796_266_261_7, epsilon = 10E-6);
2163    }
2164
2165    #[test]
2166    fn single_state_continuous_impute_2() {
2167        let mut all_states = get_states_from_yaml();
2168        let states = [&all_states.remove(0)];
2169        let x: f64 = continuous_impute(&states, 3, 0);
2170        assert_relative_eq!(x, -0.824_416_188_399_796_6, epsilon = 10E-6);
2171    }
2172
2173    #[test]
2174    fn multi_state_continuous_impute_1() {
2175        let mut all_states = get_states_from_yaml();
2176        let states = [&all_states.remove(0), &all_states.remove(0)];
2177        let x: f64 = continuous_impute(&states, 1, 2);
2178        assert_relative_eq!(x, 0.554_604_492_187_499_9, epsilon = 10E-6);
2179    }
2180
2181    #[test]
2182    fn multi_state_continuous_impute_2() {
2183        let states = get_states_from_yaml();
2184        let states: Vec<&State> = states.iter().collect();
2185        let x: f64 = continuous_impute(&states, 1, 2);
2186        assert_relative_eq!(x, -0.250_584_379_015_657_5, epsilon = 10E-6);
2187    }
2188
2189    #[test]
2190    fn single_state_categorical_impute_1() {
2191        let state: State = get_single_categorical_state_from_yaml();
2192        let x: u8 = categorical_impute(&[&state], 0, 0);
2193        assert_eq!(x, 2);
2194    }
2195
2196    #[test]
2197    fn single_state_categorical_impute_2() {
2198        let state: State = get_single_categorical_state_from_yaml();
2199        let x: u8 = categorical_impute(&[&state], 2, 0);
2200        assert_eq!(x, 0);
2201    }
2202
2203    #[test]
2204    fn single_state_categorical_predict_1() {
2205        let state: State = get_single_categorical_state_from_yaml();
2206        let x: u8 = categorical_predict(&[&state], 0, &Given::Nothing);
2207        assert_eq!(x, 2);
2208    }
2209
2210    #[test]
2211    fn single_state_categorical_entropy() {
2212        let state: State = get_single_categorical_state_from_yaml();
2213        let h = entropy_single(0, &vec![state]);
2214        assert_relative_eq!(h, 1.368_541_708_152_32, epsilon = 10E-6);
2215    }
2216
2217    #[test]
2218    fn single_state_categorical_self_entropy() {
2219        let state: State = get_single_categorical_state_from_yaml();
2220        let states = vec![state];
2221        let h_x = entropy_single(0, &states);
2222        let h_xx = categorical_entropy_dual(0, 0, &states);
2223        assert_relative_eq!(h_xx, h_x, epsilon = 1E-12);
2224    }
2225
2226    #[test]
2227    fn multi_state_categorical_self_entropy() {
2228        let state: State = get_single_categorical_state_from_yaml();
2229        let states = vec![state];
2230        let h_x = entropy_single(0, &states);
2231        let h_xx = categorical_entropy_dual(0, 0, &states);
2232        assert_relative_eq!(h_xx, h_x, epsilon = 1E-12);
2233    }
2234
2235    #[test]
2236    fn multi_state_categorical_single_entropy() {
2237        let states = get_entropy_states_from_yaml();
2238        let h_x = entropy_single(2, &states);
2239        assert_relative_eq!(h_x, 1.368_715_500_467_195_1, epsilon = 1E-12);
2240    }
2241
2242    #[cfg(feature = "examples")]
2243    #[test]
2244    fn multi_state_categorical_single_entropy_vs_old() {
2245        use crate::examples::Example;
2246        use crate::HasStates;
2247        let oracle = Example::Animals.oracle().unwrap();
2248
2249        for col_ix in 0..oracle.n_cols() {
2250            let h_x_new = entropy_single(col_ix, &oracle.states);
2251            let h_x_old =
2252                old_categorical_entropy_single(col_ix, &oracle.states);
2253            assert_relative_eq!(h_x_new, h_x_old, epsilon = 1E-12);
2254        }
2255    }
2256
2257    #[test]
2258    fn single_state_count_impute_1() {
2259        let states = [&get_single_count_state_from_yaml()];
2260        let x: u32 = count_impute(&states, 1, 0);
2261        assert_eq!(x, 1);
2262    }
2263
2264    #[test]
2265    fn single_state_count_impute_2() {
2266        let states = [&get_single_count_state_from_yaml()];
2267        let x: u32 = count_impute(&states, 1, 0);
2268        assert_eq!(x, 1);
2269    }
2270
2271    #[test]
2272    fn single_state_count_predict() {
2273        let states = [&get_single_count_state_from_yaml()];
2274        let x: u32 = count_predict(&states, 0, &Given::<usize>::Nothing);
2275        assert_eq!(x, 1);
2276    }
2277
2278    #[test]
2279    fn single_state_dual_categorical_entropy_0() {
2280        let mut states = get_entropy_states_from_yaml();
2281        let state = states.drain(..).next().unwrap();
2282        let hxy = categorical_entropy_dual(2, 3, &vec![state]);
2283        assert_relative_eq!(hxy, 2.050_396_319_359_273_4, epsilon = 1E-14);
2284    }
2285
2286    #[test]
2287    fn single_state_dual_categorical_entropy_1() {
2288        let mut states = get_entropy_states_from_yaml();
2289        let state = states.pop().unwrap();
2290        let hxy = categorical_entropy_dual(2, 3, &vec![state]);
2291        assert_relative_eq!(hxy, 2.035_433_971_709_626, epsilon = 1E-14);
2292    }
2293
2294    #[test]
2295    fn single_state_dual_categorical_entropy_vs_joint_equiv() {
2296        let states = {
2297            let mut states = get_entropy_states_from_yaml();
2298            let state = states.pop().unwrap();
2299            vec![state]
2300        };
2301        let hxy_dual = categorical_entropy_dual(2, 3, &states);
2302        let hxy_joint = categorical_joint_entropy(&[2, 3], &states);
2303
2304        assert_relative_eq!(hxy_dual, hxy_joint, epsilon = 1E-14);
2305    }
2306
2307    #[test]
2308    fn multi_state_dual_categorical_entropy_1() {
2309        let states = get_entropy_states_from_yaml();
2310        let hxy = categorical_entropy_dual(2, 3, &states);
2311        assert_relative_eq!(hxy, 2.050_402_245_628_641_5, epsilon = 1E-14);
2312    }
2313
2314    #[test]
2315    fn multi_state_dual_categorical_entropy_vs_joint_equiv() {
2316        let states = get_entropy_states_from_yaml();
2317        let hxy_dual = categorical_entropy_dual(2, 3, &states);
2318        let hxy_joint = categorical_joint_entropy(&[2, 3], &states);
2319        assert_relative_eq!(hxy_dual, hxy_joint, epsilon = 1E-14);
2320    }
2321
2322    #[test]
2323    fn single_state_categorical_gaussian_entropy_0() {
2324        let mut states = get_entropy_states_from_yaml();
2325        // first state
2326        let state = states.drain(..).next().unwrap();
2327        let hxy = categorical_gaussian_entropy_dual(2, 0, &vec![state]);
2328        assert_relative_eq!(hxy, 2.726_163_712_601_034, epsilon = 1E-7);
2329    }
2330
2331    #[test]
2332    fn single_state_categorical_gaussian_entropy_1() {
2333        let mut states = get_entropy_states_from_yaml();
2334        // second (last) state
2335        let state = states.pop().unwrap();
2336        let hxy = categorical_gaussian_entropy_dual(2, 0, &vec![state]);
2337        assert_relative_eq!(hxy, 2.735_457_532_371_074_6, epsilon = 1E-7);
2338    }
2339
2340    #[test]
2341    fn multi_state_categorical_gaussian_entropy_0() {
2342        let states = get_entropy_states_from_yaml();
2343        let hxy = categorical_gaussian_entropy_dual(2, 0, &states);
2344        assert_relative_eq!(hxy, 2.744_356_173_055_859, epsilon = 1E-7);
2345    }
2346
2347    #[test]
2348    fn sobol_samples() {
2349        let mut states = get_entropy_states_from_yaml();
2350        let state = states.pop().unwrap();
2351        let (samples, _) = gen_sobol_samples(&[0, 2, 3], &state, 102);
2352
2353        assert_eq!(samples.len(), 102);
2354
2355        for vals in samples {
2356            assert_eq!(vals.len(), 3);
2357            assert!(vals[0].is_continuous());
2358            assert!(vals[1].is_categorical());
2359            assert!(vals[2].is_categorical());
2360        }
2361    }
2362
2363    fn sobolo_vs_exact_entropy(col_ix: usize, n: usize) -> (f64, f64) {
2364        let mut states = get_entropy_states_from_yaml();
2365        let state = states.pop().unwrap();
2366
2367        let h_sobol = {
2368            let (samples, q_recip) = gen_sobol_samples(&[col_ix], &state, n);
2369
2370            let logps = state_logp(
2371                &state,
2372                &[col_ix],
2373                &samples,
2374                &Given::Nothing,
2375                None,
2376                false,
2377            );
2378
2379            let h: f64 = logps.iter().map(|logp| -logp * logp.exp()).sum();
2380
2381            h * q_recip / (n as f64)
2382        };
2383
2384        let h_exact = entropy_single(col_ix, &vec![state]);
2385
2386        (h_exact, h_sobol)
2387    }
2388
2389    #[test]
2390    fn sobol_single_categorical_entropy_vs_exact() {
2391        let (h_exact, h_sobol) = sobolo_vs_exact_entropy(2, 10_000);
2392        assert_relative_eq!(h_exact, h_sobol, epsilon = 1E-12);
2393    }
2394
2395    #[test]
2396    fn sobol_single_gaussian_entropy_vs_exact() {
2397        let (h_exact, h_sobol) = sobolo_vs_exact_entropy(0, 10_000);
2398        assert_relative_eq!(h_exact, h_sobol, epsilon = 1E-7);
2399    }
2400}