use crate::estimators::approaches::discrete::discrete_utils::reduce_joint_space_compact;
use crate::estimators::approaches::discrete::discrete_utils::{DiscreteDataset, rows_as_vec};
use crate::estimators::traits::{GlobalValue, JointEntropy, LocalValues, OptionalLocalValues};
use ndarray::{Array1, Array2};
/// Zhang's series entropy estimator for integer-coded discrete samples.
///
/// Holds a [`DiscreteDataset`] built from one 1-D series. The per-symbol counts
/// in that dataset drive both the global estimate and the per-sample local values.
pub struct ZhangEntropy {
    dataset: DiscreteDataset,
}

impl ZhangEntropy {
    /// Builds an estimator from a single 1-D series of integer symbols.
    pub fn new(data: Array1<i32>) -> Self {
        Self {
            dataset: DiscreteDataset::from_data(data),
        }
    }

    /// Builds one independent estimator per row of `data`, in the order
    /// produced by `rows_as_vec`.
    pub fn from_rows(data: Array2<i32>) -> Vec<Self> {
        let mut estimators = Vec::new();
        for row in rows_as_vec(data) {
            estimators.push(Self::new(row));
        }
        estimators
    }

    /// Per-symbol series term for a symbol seen `n` times out of `total_samples`.
    ///
    /// Computes `sum_{v=1}^{N-n} (1/v) * prod_{j=1}^{v} (1 - (n-1)/(N-j))` with
    /// `N = total_samples`. Returns `0.0` when `n == 0` (absent symbol) or
    /// `n >= total_samples` (one symbol fills the whole sample).
    #[inline]
    fn t2_for_count(n: usize, total_samples: usize) -> f64 {
        if !(1..total_samples).contains(&n) {
            return 0.0;
        }
        let total = total_samples as f64;
        let numerator = n as f64 - 1.0;
        // Carry the running product through `scan` and emit each `product / v`
        // term. Terms beyond v = N - n would be zero because their factor hits 0.
        (1..=(total_samples - n))
            .scan(1.0_f64, |product, v| {
                let vf = v as f64;
                *product *= 1.0 - numerator / (total - vf);
                Some(*product / vf)
            })
            .fold(0.0_f64, |acc, term| acc + term)
    }
}
impl GlobalValue for ZhangEntropy {
    /// Zhang entropy estimate for the whole series.
    ///
    /// Returns `sum_k (n_k / N) * t2_for_count(n_k, N)` over every symbol count
    /// `n_k` in the dataset, where `N = dataset.n`. An empty count table yields `0.0`.
    fn global_value(&self) -> f64 {
        use std::collections::HashMap;

        let n = self.dataset.n;
        let nf = n as f64;
        // t2 depends only on the count, not on the symbol, and each evaluation
        // costs O(N - count). Many symbols typically share a count (singletons in
        // undersampled data), so compute each distinct count once. Without the
        // cache this is O(K * N) in total.
        let mut t2_by_count: HashMap<usize, f64> = HashMap::new();
        let mut h = 0.0_f64;
        for &cnt in self.dataset.counts.values() {
            let t2 = *t2_by_count
                .entry(cnt)
                .or_insert_with(|| Self::t2_for_count(cnt, n));
            h += (cnt as f64 / nf) * t2;
        }
        h
    }
}
impl LocalValues for ZhangEntropy {
    /// Per-sample local Zhang entropy values.
    ///
    /// Each sample receives `t2_for_count(c, N)`, where `c` is the count of that
    /// sample's symbol and `N = dataset.n`. If `counts` tallies the symbols in
    /// `data`, the mean of the result mathematically equals `global_value()`.
    ///
    /// # Panics
    /// Panics if a value in `data` has no entry in `counts`, which would mean
    /// the dataset's internal invariant is broken.
    fn local_values(&self) -> Array1<f64> {
        use std::collections::HashMap;

        let n = self.dataset.n;
        // t2 depends only on the count, and each evaluation is O(N - count).
        // Cache it per distinct count so that symbols sharing a count
        // (e.g. singletons) do not repeat the series.
        let mut t2_by_count: HashMap<usize, f64> = HashMap::new();
        let mut contrib: HashMap<i32, f64> = HashMap::with_capacity(self.dataset.k);
        for (&val, &cnt) in self.dataset.counts.iter() {
            let t2 = *t2_by_count
                .entry(cnt)
                .or_insert_with(|| Self::t2_for_count(cnt, n));
            contrib.insert(val, t2);
        }
        self.dataset.data.mapv(|v| {
            *contrib
                .get(&v)
                .expect("every value in `data` must have an entry in `counts`")
        })
    }
}
impl JointEntropy for ZhangEntropy {
type Source = Array1<i32>;
type Params = ();
fn joint_entropy(series: &[Self::Source], _params: Self::Params) -> f64 {
if series.is_empty() {
return 0.0;
}
let joint_codes = reduce_joint_space_compact(series);
let disc = ZhangEntropy::new(joint_codes);
GlobalValue::global_value(&disc)
}
}
impl OptionalLocalValues for ZhangEntropy {
    /// Always `true`: this estimator can produce per-sample local values.
    fn supports_local(&self) -> bool {
        true
    }

    /// Returns the per-sample local values. This implementation never returns `Err`.
    fn local_values_opt(&self) -> Result<Array1<f64>, &'static str> {
        let locals = self.local_values();
        Ok(locals)
    }
}