//! Symbolic information-theoretic measures: entropies, divergences, and
//! Gini impurity, built over the crate's symbolic `Expr` type.
use crate::symbolic::core::Expr;
use crate::symbolic::simplify_dag::simplify;
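/// Builds the Shannon entropy `H(p) = -Σ pᵢ · log₂(pᵢ)` of a discrete
/// distribution as a symbolic expression, measured in bits.
///
/// The probabilities may be arbitrary symbolic expressions; nothing checks
/// that they are non-negative or sum to one. For an empty slice the inner
/// sum defaults to `Expr::Constant(0.0)`.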
#[must_use]
pub fn shannon_entropy(probs: &[Expr]) -> Expr {
    // log(2): dividing the output of `new_log` by this converts it to base 2
    // via the change-of-base identity, so the entropy is measured in bits.
    let log2 = Expr::new_log(Expr::Constant(2.0));
let sum = probs
.iter()
.map(|p| {
let log2_p = Expr::new_div(Expr::new_log(p.clone()), log2.clone());
Expr::new_mul(p.clone(), log2_p)
})
.reduce(|acc, e| simplify(&Expr::new_add(acc, e)))
.unwrap_or(Expr::Constant(0.0));
simplify(&Expr::new_neg(sum))
}
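/// Builds the Kullback-Leibler divergence
/// `D_KL(P ‖ Q) = Σ pᵢ · log₂(pᵢ / qᵢ)` as a symbolic expression, in bits.
///
/// # Errors
///
/// Returns an error if the two distributions have different lengths.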
pub fn kl_divergence(
p_dist: &[Expr],
q_dist: &[Expr],
) -> Result<Expr, String> {
if p_dist.len() != q_dist.len() {
return Err("Distributions \
must have the \
same length"
.to_string());
}
    let log2 = Expr::new_log(Expr::Constant(2.0)); // log(2), for base-2 conversion
let sum = p_dist
.iter()
.zip(q_dist.iter())
.map(|(p, q)| {
let ratio = Expr::new_div(p.clone(), q.clone());
let log2_ratio = Expr::new_div(Expr::new_log(ratio), log2.clone());
Expr::new_mul(p.clone(), log2_ratio)
})
.reduce(|acc, e| simplify(&Expr::new_add(acc, e)))
.unwrap_or(Expr::Constant(0.0));
Ok(simplify(&sum))
}
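/// Builds the cross-entropy `H(P, Q) = -Σ pᵢ · log₂(qᵢ)` as a symbolic
/// expression, in bits.
///
/// # Errors
///
/// Returns an error if the two distributions have different lengths.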
pub fn cross_entropy(
p_dist: &[Expr],
q_dist: &[Expr],
) -> Result<Expr, String> {
if p_dist.len() != q_dist.len() {
return Err("Distributions \
must have the \
same length"
.to_string());
}
    let log2 = Expr::new_log(Expr::Constant(2.0)); // log(2), for base-2 conversion
let sum = p_dist
.iter()
.zip(q_dist.iter())
.map(|(p, q)| {
let log2_q = Expr::new_div(Expr::new_log(q.clone()), log2.clone());
Expr::new_mul(p.clone(), log2_q)
})
.reduce(|acc, e| simplify(&Expr::new_add(acc, e)))
.unwrap_or(Expr::Constant(0.0));
Ok(simplify(&Expr::new_neg(sum)))
}
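/// Builds the joint entropy `H(X, Y)` of a joint distribution supplied as
/// an `Expr::Matrix` whose entries are the probabilities `p(x, y)`.
///
/// # Errors
///
/// Returns an error if `joint_probs` is not an `Expr::Matrix`.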
pub fn joint_entropy(joint_probs: &Expr) -> Result<Expr, String> {
if let Expr::Matrix(rows) = joint_probs {
let flat_probs: Vec<Expr> = rows.iter().flatten().cloned().collect();
Ok(shannon_entropy(&flat_probs))
} else {
Err("Input must be a matrix \
of joint probabilities."
.to_string())
}
}
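/// Builds the conditional entropy `H(Y | X) = H(X, Y) - H(X)`, where the
/// marginal `p(x)` is obtained by summing each row of the joint matrix.
///
/// # Errors
///
/// Returns an error if `joint_probs` is not an `Expr::Matrix`.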
pub fn conditional_entropy(joint_probs: &Expr) -> Result<Expr, String> {
if let Expr::Matrix(rows) = joint_probs {
        // Marginal p(x): sum each row of the joint matrix.
        let p_x: Vec<Expr> = rows
.iter()
.map(|row| {
row.iter()
.cloned()
.reduce(|a, b| simplify(&Expr::new_add(a, b)))
.unwrap_or(Expr::Constant(0.0))
})
.collect();
let h_x = shannon_entropy(&p_x);
let h_xy = joint_entropy(joint_probs)?;
Ok(simplify(&Expr::new_sub(h_xy, h_x)))
} else {
Err("Input must be a matrix \
of joint probabilities."
.to_string())
}
}
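/// Builds the mutual information `I(X; Y) = H(X) + H(Y) - H(X, Y)`, taking
/// the marginals as the row sums (`p(x)`) and column sums (`p(y)`) of the
/// joint matrix.
///
/// # Errors
///
/// Returns an error if `joint_probs` is not an `Expr::Matrix`.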
pub fn mutual_information(joint_probs: &Expr) -> Result<Expr, String> {
if let Expr::Matrix(rows) = joint_probs {
        // Marginal p(x): sum each row of the joint matrix.
        let p_x: Vec<Expr> = rows
.iter()
.map(|row| {
row.iter()
.cloned()
.reduce(|a, b| simplify(&Expr::new_add(a, b)))
.unwrap_or(Expr::Constant(0.0))
})
.collect();
        // Marginal p(y): sum down each column of the joint matrix.
        let num_cols = rows.first().map_or(0, std::vec::Vec::len);
        let mut p_y = vec![Expr::Constant(0.0); num_cols];
for row in rows {
for (j, p_ij) in row.iter().enumerate() {
p_y[j] = simplify(&Expr::new_add(p_y[j].clone(), p_ij.clone()));
}
}
let h_x = shannon_entropy(&p_x);
let h_y = shannon_entropy(&p_y);
let h_xy = joint_entropy(joint_probs)?;
Ok(simplify(&Expr::new_sub(Expr::new_add(h_x, h_y), h_xy)))
} else {
Err("Input must be a matrix \
of joint probabilities."
.to_string())
}
}
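/// Builds the Gini impurity `1 - Σ pᵢ²` of a discrete distribution as a
/// symbolic expression.
///
/// For an empty slice the sum of squares defaults to `Expr::Constant(0.0)`.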
#[must_use]
pub fn gini_impurity(probs: &[Expr]) -> Expr {
let sum_of_squares = probs
.iter()
.map(|p| Expr::new_pow(p.clone(), Expr::Constant(2.0)))
.reduce(|acc, e| simplify(&Expr::new_add(acc, e)))
.unwrap_or(Expr::Constant(0.0));
simplify(&Expr::new_sub(Expr::Constant(1.0), sum_of_squares))
}
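// A minimal usage sketch, written as tests. It relies only on what this
// module itself guarantees: `Expr::Constant` and `Expr::Matrix` can be
// constructed directly, and the length/shape checks above return `Err`.
// It deliberately avoids asserting exact simplified forms, since those
// depend on the rewrite rules in `simplify_dag`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn entropies_build_for_constant_distributions() {
        // H(p) and Gini impurity for a uniform two-outcome distribution;
        // we only check that construction succeeds.
        let probs = [Expr::Constant(0.5), Expr::Constant(0.5)];
        let _h = shannon_entropy(&probs);
        let _g = gini_impurity(&probs);
    }

    #[test]
    fn divergences_reject_mismatched_lengths() {
        let p = [Expr::Constant(0.5), Expr::Constant(0.5)];
        let q = [Expr::Constant(1.0)];
        assert!(kl_divergence(&p, &q).is_err());
        assert!(cross_entropy(&p, &q).is_err());
    }

    #[test]
    fn joint_measures_require_a_matrix() {
        // A scalar is not a valid joint distribution.
        let scalar = Expr::Constant(1.0);
        assert!(joint_entropy(&scalar).is_err());
        assert!(conditional_entropy(&scalar).is_err());
        assert!(mutual_information(&scalar).is_err());

        // A 2x2 uniform joint matrix is accepted.
        let joint = Expr::Matrix(vec![
            vec![Expr::Constant(0.25), Expr::Constant(0.25)],
            vec![Expr::Constant(0.25), Expr::Constant(0.25)],
        ]);
        assert!(mutual_information(&joint).is_ok());
    }
}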