use crate::error::{PgmError, Result};
/// Common interface for a distribution written in exponential-family form
/// `p(x) = h(x) · exp(η · T(x) − A(η))`, where `η` are the natural parameters,
/// `T` the sufficient statistics, `A` the log-partition function and `h` the
/// base measure.
pub trait ExponentialFamily: Clone {
    /// Identifier of the family (static string chosen by the implementor).
    fn family_name(&self) -> &'static str;

    /// Number of natural parameters; `update_natural` checks `delta` against this.
    fn natural_dim(&self) -> usize;

    /// Current natural parameters `η`.
    fn natural_params(&self) -> Vec<f64>;

    /// Alias for [`natural_params`](Self::natural_params).
    fn to_natural(&self) -> Vec<f64> {
        self.natural_params()
    }

    /// Replace the natural parameters with `new_eta`.
    ///
    /// # Errors
    /// Implementation-defined; presumably rejects vectors of the wrong length or
    /// values outside the natural-parameter space — confirm per implementor.
    fn set_natural(&mut self, new_eta: &[f64]) -> Result<()>;

    /// Shift the natural parameters by `delta` (`η ← η + δ`) and store the result
    /// through [`set_natural`](Self::set_natural).
    ///
    /// # Errors
    /// [`PgmError::DimensionMismatch`] if `delta.len() != self.natural_dim()`
    /// (parameters are left untouched); otherwise whatever `set_natural` returns.
    fn update_natural(&mut self, delta: &[f64]) -> Result<()> {
        let dim = self.natural_dim();
        if delta.len() != dim {
            return Err(PgmError::DimensionMismatch {
                expected: vec![dim],
                got: vec![delta.len()],
            });
        }
        let mut eta = self.natural_params();
        eta.iter_mut().zip(delta).for_each(|(e, d)| *e += *d);
        self.set_natural(&eta)
    }

    /// Sufficient statistics `T(value)` for a single observation.
    fn sufficient_statistics(&self, value: f64) -> Vec<f64>;

    /// Log-partition function `A(η)` evaluated at `natural_params`.
    ///
    /// # Errors
    /// Implementation-defined (e.g. `natural_params` outside the valid domain).
    fn log_partition(&self, natural_params: &[f64]) -> Result<f64>;

    /// Expected sufficient statistics `E[T(X)]` under the current distribution.
    /// Must have the same length as [`natural_params`](Self::natural_params).
    fn expected_sufficient_statistics(&self) -> Vec<f64>;

    /// Expectation `E[log h(X)]` of the log base measure under the current
    /// distribution, used by [`entropy`](Self::entropy).
    ///
    /// Defaults to `0.0`, which is exact when `h(x) = 1` or when the implementor
    /// has folded the `log h` contribution into `log_partition`. Families whose
    /// base measure is not identically 1 (e.g. Poisson with `h(x) = 1/x!`) should
    /// override this; otherwise the default `entropy` is off by exactly this term.
    ///
    /// # Errors
    /// Implementation-defined.
    fn expected_log_base_measure(&self) -> Result<f64> {
        Ok(0.0)
    }

    /// Entropy of the current distribution:
    /// `H = A(η) − η · E[T(X)] − E[log h(X)]`.
    ///
    /// # Errors
    /// Propagates errors from `log_partition` and `expected_log_base_measure`.
    /// Returns [`PgmError::DimensionMismatch`] if `expected_sufficient_statistics()`
    /// and `natural_params()` have different lengths.
    fn entropy(&self) -> Result<f64> {
        let eta = self.natural_params();
        let a = self.log_partition(&eta)?;
        let ess = self.expected_sufficient_statistics();
        // Checked in every build: a mismatched `zip` would silently truncate and
        // produce a plausible-looking but wrong entropy in release mode.
        if eta.len() != ess.len() {
            return Err(PgmError::DimensionMismatch {
                expected: vec![eta.len()],
                got: vec![ess.len()],
            });
        }
        let dot: f64 = eta.iter().zip(&ess).map(|(e, s)| e * s).sum();
        Ok(a - dot - self.expected_log_base_measure()?)
    }
}