use ndarray::{Array1, Array2};
use statrs::function::gamma::digamma;
use crate::estimators::approaches::discrete::discrete_utils::reduce_joint_space_compact;
use crate::estimators::approaches::discrete::discrete_utils::{DiscreteDataset, rows_as_vec};
use crate::estimators::traits::{GlobalValue, JointEntropy, LocalValues, OptionalLocalValues};
/// State for the ANSB entropy estimator over a single discrete series.
pub struct AnsbEntropy {
    /// Discrete dataset the estimate is computed from; its `n` and `k`
    /// fields are read by the estimator.
    dataset: DiscreteDataset,
    /// When `Some`, used in place of the dataset's own `k`.
    k_override: Option<usize>,
    /// `N/K` ratio above which the estimator emits an
    /// "not sufficiently undersampled" warning.
    undersampled_threshold: f64,
}
impl AnsbEntropy {
    /// Builds an estimator for one series of integer codes.
    ///
    /// * `data` - the discrete observations, handed to `DiscreteDataset::from_data`.
    /// * `k_override` - optional value used instead of the dataset's `k`.
    /// * `undersampled_threshold` - `N/K` ratio above which a warning is emitted.
    pub fn new(data: Array1<i32>, k_override: Option<usize>, undersampled_threshold: f64) -> Self {
        Self {
            dataset: DiscreteDataset::from_data(data),
            k_override,
            undersampled_threshold,
        }
    }

    /// Builds one estimator per element yielded by `rows_as_vec(data)`,
    /// all sharing the same `k_override` and `undersampled_threshold`.
    pub fn from_rows(
        data: Array2<i32>,
        k_override: Option<usize>,
        undersampled_threshold: f64,
    ) -> Vec<Self> {
        // Every row gets the same configuration; only the data differs.
        let build = |row: Array1<i32>| Self::new(row, k_override, undersampled_threshold);
        rows_as_vec(data).into_iter().map(build).collect()
    }
}
impl GlobalValue for AnsbEntropy {
    /// Computes the ANSB entropy estimate using natural logarithms:
    ///
    /// `(γ - ln 2) + 2 * ln(N) - ψ(Δ)`
    ///
    /// where `N` is `dataset.n`, `K` is `k_override` if set and `dataset.k`
    /// otherwise, `Δ = N - K`, `γ` is the Euler–Mascheroni constant and `ψ`
    /// is the digamma function.
    ///
    /// Returns `NaN` when `N == 0`, when `K == 0`, or when `Δ <= 0`.
    ///
    /// If `N / K` exceeds `undersampled_threshold`, a warning is written to
    /// standard error and the estimate is still computed.
    fn global_value(&self) -> f64 {
        const EULER_GAMMA: f64 = 0.577_215_664_901_532_9;

        let n = self.dataset.n;
        if n == 0 {
            return f64::NAN;
        }

        let k = self.k_override.unwrap_or(self.dataset.k);
        if k == 0 {
            return f64::NAN;
        }

        let ratio = n as f64 / k as f64;
        if ratio > self.undersampled_threshold {
            // Diagnostics go to stderr so they never mix with program output on stdout.
            eprintln!(
                "Warning: Data is not sufficiently undersampled (N/K = {:.3} > {:.3}), so calculation may diverge...",
                ratio, self.undersampled_threshold
            );
        }

        // Signed arithmetic so that K >= N yields a non-positive count instead of wrapping.
        // NOTE(review): Δ is computed from the possibly overridden `k`, so any
        // `k_override >= n` returns NaN — confirm that the override (rather than
        // the dataset's own `k`) is the intended term here.
        let coincidences = (n as i64) - (k as i64);
        if coincidences <= 0 {
            return f64::NAN;
        }

        (EULER_GAMMA - 2.0_f64.ln()) + 2.0 * (n as f64).ln() - digamma(coincidences as f64)
    }
}
impl LocalValues for AnsbEntropy {
    /// Always returns an empty (length-0) array; this estimator produces no
    /// per-sample values.
    fn local_values(&self) -> Array1<f64> {
        Array1::<f64>::zeros(0)
    }
}
impl JointEntropy for AnsbEntropy {
    type Source = Array1<i32>;
    /// `(k_override, undersampled_threshold)`, forwarded to [`AnsbEntropy::new`].
    type Params = (Option<usize>, f64);

    /// Estimates the joint entropy of several discrete series.
    ///
    /// The series are collapsed into a single code sequence with
    /// `reduce_joint_space_compact`, and the ANSB estimate of that sequence
    /// is returned. An empty `series` slice yields `0.0`.
    fn joint_entropy(series: &[Self::Source], params: Self::Params) -> f64 {
        let (k_override, undersampled_threshold) = params;

        match series {
            [] => 0.0,
            _ => AnsbEntropy::new(
                reduce_joint_space_compact(series),
                k_override,
                undersampled_threshold,
            )
            .global_value(),
        }
    }
}
impl OptionalLocalValues for AnsbEntropy {
    /// Local values are never available for this estimator.
    fn supports_local(&self) -> bool {
        false
    }

    /// Always returns `Err` with a fixed explanation message.
    fn local_values_opt(&self) -> Result<Array1<f64>, &'static str> {
        const UNSUPPORTED: &str =
            "Local values are not supported for ANSB estimator as it averages over Dirichlet priors.";
        Err(UNSUPPORTED)
    }
}