use crate::estimators::approaches::discrete::discrete_utils::reduce_joint_space_compact;
use crate::estimators::approaches::discrete::discrete_utils::{DiscreteDataset, rows_as_vec};
use crate::estimators::traits::{GlobalValue, JointEntropy, LocalValues, OptionalLocalValues};
use ndarray::{Array1, Array2};
use statrs::function::gamma::{digamma, ln_gamma};
pub struct NsbEntropy {
dataset: DiscreteDataset,
k_override: Option<usize>,
tol: f64,
max_recursion: usize,
}
impl NsbEntropy {
pub fn new(data: Array1<i32>, k_override: Option<usize>) -> Self {
let dataset = DiscreteDataset::from_data(data);
Self {
dataset,
k_override,
tol: 1e-6,
max_recursion: 12,
}
}
fn counts_vec(&self) -> Vec<usize> {
self.dataset.counts.values().cloned().collect()
}
fn neg_log_rho(&self, beta: f64, k: usize, n: usize, counts: &[usize]) -> f64 {
let kappa = (k as f64) * beta;
let mut result = -(ln_gamma(kappa) - ln_gamma(n as f64 + kappa));
let ln_g_beta = ln_gamma(beta);
let mut sum_terms = 0.0_f64;
for &ci in counts {
sum_terms += (ci as f64) * (ln_gamma(ci as f64 + beta) - ln_g_beta);
}
result -= sum_terms;
result
}
fn dxi(&self, beta: f64, k: usize) -> f64 {
let kb = (k as f64) * beta;
(k as f64) * trigamma(1.0 + kb) - trigamma(1.0 + beta)
}
fn bayes_expectation(&self, beta: f64, counts: &[usize]) -> f64 {
let total_alpha = (self.dataset.n as f64) + (counts.len() as f64) * beta;
let mut sum_term = 0.0_f64;
for &ci in counts {
let ai = (ci as f64) + beta;
sum_term += ai * digamma(ai + 1.0);
}
digamma(total_alpha + 1.0) - (sum_term / total_alpha)
}
fn find_l0(&self, k: usize, n: usize) -> f64 {
let func = |k0: f64| -> f64 { (k as f64) / k0 - digamma(k0 + n as f64) + digamma(k0) };
let mut best_k0 = 0.1_f64;
let mut best_val = f64::INFINITY;
let steps = 200usize;
let upper = k as f64;
let mut t = 0.1_f64;
let step = (upper - 0.1_f64) / (steps as f64);
for _ in 0..steps {
let v = func(t).abs();
if v < best_val {
best_val = v;
best_k0 = t;
}
t += step;
}
let extremum_beta = best_k0 / (k as f64);
let counts = self.counts_vec();
self.neg_log_rho(extremum_beta, k, n, &counts)
}
fn simpson<F: Fn(f64) -> f64>(f: &F, a: f64, b: f64, tol: f64, depth: usize) -> f64 {
fn simp<F: Fn(f64) -> f64>(f: &F, a: f64, b: f64) -> f64 {
let c = 0.5 * (a + b);
let h = b - a;
(h / 6.0) * (f(a) + 4.0 * f(c) + f(b))
}
let c = 0.5 * (a + b);
let s_ab = simp(f, a, b);
let s_ac = simp(f, a, c);
let s_cb = simp(f, c, b);
if depth == 0 || (s_ac + s_cb - s_ab).abs() < 15.0 * tol {
return s_ac + s_cb + (s_ac + s_cb - s_ab) / 15.0;
}
Self::simpson(f, a, c, tol / 2.0, depth - 1) + Self::simpson(f, c, b, tol / 2.0, depth - 1)
}
}
impl GlobalValue for NsbEntropy {
fn global_value(&self) -> f64 {
let n = self.dataset.n;
let k_obs = self.dataset.k;
let k = self.k_override.unwrap_or(k_obs);
if n == 0 || k == 0 {
return f64::NAN;
}
let counts = self.counts_vec();
let coincidences = (n as i64) - (k as i64);
if coincidences == 0 {
return f64::NAN;
}
let upper = (k as f64).ln();
if !upper.is_finite() || upper <= 0.0 {
return f64::NAN;
}
let l0 = self.find_l0(k, n);
let neg_log_rho = |beta: f64| self.neg_log_rho(beta, k, n, &counts);
let dxi = |beta: f64| self.dxi(beta, k);
let bayes = |beta: f64| self.bayes_expectation(beta, &counts);
let f_num = |beta: f64| ((-(neg_log_rho(beta)) + l0).exp()) * dxi(beta) * bayes(beta);
let f_den = |beta: f64| ((-(neg_log_rho(beta)) + l0).exp()) * dxi(beta);
let a = 1e-8;
let num = Self::simpson(&f_num, a, upper, self.tol, self.max_recursion);
let den = Self::simpson(&f_den, a, upper, self.tol, self.max_recursion);
if den == 0.0 || !den.is_finite() {
return f64::NAN;
}
num / den
}
}
impl LocalValues for NsbEntropy {
fn local_values(&self) -> Array1<f64> {
Array1::zeros(0)
}
}
impl JointEntropy for NsbEntropy {
type Source = Array1<i32>;
type Params = Option<usize>;
fn joint_entropy(series: &[Self::Source], params: Self::Params) -> f64 {
if series.is_empty() {
return 0.0;
}
let joint_codes = reduce_joint_space_compact(series);
let disc = NsbEntropy::new(joint_codes, params);
disc.global_value()
}
}
fn trigamma(mut x: f64) -> f64 {
if !x.is_finite() {
return f64::NAN;
}
let mut acc = 0.0_f64;
while x < 8.0 {
acc += 1.0 / (x * x);
x += 1.0;
}
let z = 1.0 / x;
let z2 = z * z;
let mut series = z + 0.5 * z2 + (1.0 / 6.0) * z2 * z; let z5 = z2 * z2 * z; let z7 = z5 * z2; let z9 = z7 * z2; let z11 = z9 * z2; series += -(1.0 / 30.0) * z5 + (1.0 / 42.0) * z7 - (1.0 / 30.0) * z9 + (5.0 / 66.0) * z11;
acc + series
}
impl NsbEntropy {
pub fn from_rows(data: Array2<i32>, k_override: Option<usize>) -> Vec<Self> {
rows_as_vec(data)
.into_iter()
.map(|row| Self::new(row, k_override))
.collect()
}
}
impl OptionalLocalValues for NsbEntropy {
fn supports_local(&self) -> bool {
false
}
fn local_values_opt(&self) -> Result<Array1<f64>, &'static str> {
Err(
"Local values are not supported for NSB estimator as it averages over Dirichlet priors.",
)
}
}