use super::*;
use std::collections::HashMap;
use std::fmt::Debug;
mod diag_gaussian;
pub use diag_gaussian::*;
mod categorical;
pub use categorical::*;
mod dirichlet;
pub use dirichlet::*;
/// A discrete distribution whose outcomes are optional node addresses.
pub trait DiscreteDistribution: Clone + 'static {
    /// Divergence between `self` and `other`.
    ///
    /// Returns `None` when an implementor cannot produce a value. The argument order
    /// (which side is the reference) is defined by each implementor.
    fn kl_divergence(&self, other: &Self) -> Option<f64>;

    /// Natural-log probability of the outcome `child`, or `None` when undefined.
    ///
    /// NOTE(review): what a `None` child represents is left to implementors — confirm.
    fn ln_prob(&self, child: Option<&NodeAddress>) -> Option<f64>;
}
/// A distribution over points of a point cloud.
///
/// (The trait name's spelling is kept as-is for API compatibility.)
pub trait ContinousDistribution: Clone + 'static {
    /// Divergence between `self` and `other`.
    ///
    /// Returns `None` when an implementor cannot produce a value. The argument order
    /// (which side is the reference) is defined by each implementor.
    fn kl_divergence(&self, other: &Self) -> Option<f64>;

    /// Natural-log probability (density) of `point`, or `None` when undefined.
    fn ln_prob(&self, point: &PointRef) -> Option<f64>;
}
/// A [`DiscreteDistribution`] that can be updated one observation at a time.
///
/// `Clone + 'static` are not restated here: they are already implied by the
/// `DiscreteDistribution` supertrait.
pub trait DiscreteBayesianDistribution: DiscreteDistribution {
    /// Incorporates a single observed outcome `loc` into the distribution.
    fn add_observation(&mut self, loc: Option<NodeAddress>);
}
/// A [`ContinousDistribution`] that can be updated one observed point at a time.
///
/// `Clone + 'static` are not restated here: they are already implied by the
/// `ContinousDistribution` supertrait.
pub trait ContinousBayesianDistribution: ContinousDistribution {
    /// Incorporates a single observed `point` into the distribution.
    fn add_observation(&mut self, point: &PointRef);
}
/// Tracks per-node distributions built from a sequence of dry inserts and compares
/// them with the distributions stored as node plugins on a cover tree.
pub trait DiscreteBayesianSequenceTracker<D: PointCloud>: Debug {
    /// Per-node distribution type. It is used both for the running (sequence) distributions
    /// and as the node plugin read back from the tree.
    type Distribution: DiscreteBayesianDistribution + NodePlugin<D> + 'static;

    /// Adds the trace of one dry insert to the tracker.
    ///
    /// NOTE(review): the `f32` in each pair is presumably a distance to the node — confirm
    /// with implementors.
    fn add_dry_insert(&mut self, trace: Vec<(f32, NodeAddress)>);

    /// Running distributions keyed by node address; presumably built from the dry inserts.
    fn running_distributions(&self) -> &HashMap<NodeAddress, Self::Distribution>;

    /// Reader for the cover tree whose node plugins are the reference distributions.
    fn tree_reader(&self) -> &CoverTreeReader<D>;

    /// Length of the tracked sequence (reported as `KLDivergenceStats::sequence_len`).
    fn sequence_len(&self) -> usize;

    /// Summarises the KL divergences returned by [`all_node_kl`](Self::all_node_kl).
    ///
    /// Only nodes whose divergence is strictly greater than `1.0e-10` are counted as
    /// "non-zero"; NaN divergences are also skipped. For those nodes this records:
    /// - the max/min divergence and the first and second raw moments;
    /// - per layer (indexed by `parameters().internal_index(address.0)`), the node count
    ///   and the sum of each node's `cover_count` divided by the layer's largest such count.
    ///   A layer with no such nodes, or whose counts are all zero, gets `0.0` rather than NaN.
    ///
    /// When no node is non-zero, `max` stays `f64::MIN` and `min` stays `f64::MAX`.
    ///
    /// # Panics
    /// Same conditions as [`all_node_kl`](Self::all_node_kl). Also panics if a non-zero
    /// node cannot be read back from the tree.
    fn current_stats(&self) -> KLDivergenceStats {
        // Divergences at or below this value are treated as "no divergence".
        const NZ_THRESHOLD: f64 = 1.0e-10;

        let reader = self.tree_reader();
        let parameters = reader.parameters();
        let layer_count = reader.len();

        let mut max = f64::MIN;
        let mut min = f64::MAX;
        let mut nz_count: u64 = 0;
        let mut moment1_nz = 0.0;
        let mut moment2_nz = 0.0;
        let mut layer_totals: Vec<u64> = vec![0; layer_count];
        let mut layer_node_counts = vec![Vec::<usize>::new(); layer_count];

        // `kl > NZ_THRESHOLD` (not `!(kl <= ..)`) so NaN divergences are excluded.
        for (kl, address) in self
            .all_node_kl()
            .into_iter()
            .filter(|(kl, _)| *kl > NZ_THRESHOLD)
        {
            let layer = parameters.internal_index(address.0);
            layer_totals[layer] += 1;
            layer_node_counts[layer].push(
                reader
                    .get_node_and(address, |n| n.cover_count())
                    .expect("tracked node is missing from the tree"),
            );
            moment1_nz += kl;
            moment2_nz += kl * kl;
            // `kl` is finite here (it passed the threshold), so `f64::max`/`min` match
            // plain comparisons.
            max = max.max(kl);
            min = min.min(kl);
            nz_count += 1;
        }

        let weighted_layer_totals: Vec<f32> = layer_node_counts
            .iter()
            .map(|counts| {
                let largest = counts.iter().copied().max().unwrap_or(0);
                if largest == 0 {
                    // Empty layer, or all cover counts are zero: avoid 0.0 / 0.0 == NaN.
                    0.0
                } else {
                    // Lossy usize -> f32 casts are acceptable for a normalised weight.
                    let largest = largest as f32;
                    counts.iter().fold(0.0, |acc, &c| acc + (c as f32) / largest)
                }
            })
            .collect();

        KLDivergenceStats {
            max,
            min,
            nz_count,
            moment1_nz,
            moment2_nz,
            sequence_len: self.sequence_len() as u64,
            layer_totals,
            weighted_layer_totals,
        }
    }

    /// Returns `(kl, address)` for every node in
    /// [`running_distributions`](Self::running_distributions).
    ///
    /// `kl` is computed as `tree_plugin.kl_divergence(sequence_distribution)`, where
    /// `tree_plugin` is the node's `Self::Distribution` plugin on the tree. Order follows
    /// the hash map's iteration order.
    ///
    /// # Panics
    /// Panics if a tracked address has no `Self::Distribution` plugin on the tree, or if
    /// `kl_divergence` returns `None`.
    fn all_node_kl(&self) -> Vec<(f64, NodeAddress)> {
        let reader = self.tree_reader();
        self.running_distributions()
            .iter()
            .map(|(address, sequence_pdf)| {
                let kl = reader
                    .get_node_plugin_and::<Self::Distribution, _, _>(*address, |tree_pdf| {
                        tree_pdf
                            .kl_divergence(sequence_pdf)
                            .expect("KL divergence between tree and sequence distributions is undefined")
                    })
                    .expect("tracked node has no distribution plugin on the tree");
                (kl, *address)
            })
            .collect()
    }
}
/// Summary statistics of per-node KL divergences between sequence-built distributions
/// and the tree's own node distributions, as produced by
/// `DiscreteBayesianSequenceTracker::current_stats`.
///
/// "Non-zero" nodes are those whose divergence exceeds `1.0e-10`. All fields except
/// `sequence_len` are computed over those nodes only.
#[derive(Debug, Clone, PartialEq)]
pub struct KLDivergenceStats {
    /// Largest non-zero divergence; `f64::MIN` when `nz_count == 0`.
    pub max: f64,
    /// Smallest non-zero divergence; `f64::MAX` when `nz_count == 0`.
    pub min: f64,
    /// Number of non-zero nodes.
    pub nz_count: u64,
    /// Sum of the non-zero divergences (first raw moment, unnormalised).
    pub moment1_nz: f64,
    /// Sum of the squared non-zero divergences (second raw moment, unnormalised).
    pub moment2_nz: f64,
    /// Length of the tracked sequence.
    pub sequence_len: u64,
    /// Per-layer count of non-zero nodes.
    pub layer_totals: Vec<u64>,
    /// Per-layer sum of each non-zero node's cover count divided by the largest cover
    /// count among that layer's non-zero nodes.
    pub weighted_layer_totals: Vec<f32>,
}