use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::Result;
use crate::regularizers::Regularizer;
/// Direction in which [`EntropyRegularization`] pushes the entropy of its input.
///
/// The variant selects the sign of the penalty and of the gradient term that
/// the regularizer produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntropyRegularizerType {
    /// The penalty is `-lambda * H`: for positive `lambda`, higher entropy
    /// yields a lower penalty.
    MaximizeEntropy,
    /// The penalty is `lambda * H`: for positive `lambda`, lower entropy
    /// yields a lower penalty.
    MinimizeEntropy,
}
/// Regularizer built on the entropy `-sum(p * ln(p))` of the values it is given.
///
/// The input is treated as a collection of probabilities; each value is clamped
/// with `epsilon` before its natural logarithm is taken.
#[derive(Debug, Clone, Copy)]
pub struct EntropyRegularization<A>
where
    A: Float + FromPrimitive + Debug,
{
    /// Scale factor multiplied into both the penalty and the gradient term.
    pub lambda: A,
    /// Clamp bound: values below `epsilon` are raised to it and values above
    /// `1 - epsilon` are lowered to it before `ln` is evaluated.
    pub epsilon: A,
    /// Whether the penalty favours high or low entropy.
    pub reg_type: EntropyRegularizerType,
}
impl<A: Float + FromPrimitive + Debug + Send + Sync> EntropyRegularization<A> {
    /// Creates a regularizer with the default clamp bound `epsilon = 1e-8`.
    ///
    /// # Panics
    ///
    /// Panics if `1e-8` cannot be converted into `A`.
    pub fn new(lambda: A, regtype: EntropyRegularizerType) -> Self {
        let epsilon = A::from_f64(1e-8)
            .expect("default epsilon 1e-8 must be representable in the float type");
        Self::new_with_epsilon(lambda, epsilon, regtype)
    }

    /// Creates a regularizer with a caller-supplied clamp bound `epsilon`.
    ///
    /// No validation is performed on `lambda` or `epsilon`.
    pub fn new_with_epsilon(lambda: A, epsilon: A, regtype: EntropyRegularizerType) -> Self {
        Self {
            lambda,
            epsilon,
            reg_type: regtype,
        }
    }

    /// Clamps one value so that taking its logarithm stays finite.
    ///
    /// Values below `epsilon` become `epsilon`, values above `1 - epsilon`
    /// become `1 - epsilon`, everything else is returned unchanged. The lower
    /// bound is checked first, matching the order callers have always observed.
    fn clamp_probability(&self, p: A) -> A {
        let upper = A::one() - self.epsilon;
        if p < self.epsilon {
            self.epsilon
        } else if p > upper {
            upper
        } else {
            p
        }
    }

    /// Returns the entropy `-sum(q * ln(q))` over every element of `probs`,
    /// where `q` is the element after clamping with `epsilon`.
    ///
    /// The natural logarithm is used. The input is neither normalised nor
    /// checked to sum to one.
    pub fn calculate_entropy<S, D>(&self, probs: &ArrayBase<S, D>) -> A
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        // Clamp and evaluate `q * ln(q)` in one pass so that only a single
        // temporary array is allocated.
        let neg_entropy = probs
            .mapv(|p| {
                let q = self.clamp_probability(p);
                q * q.ln()
            })
            .sum();
        -neg_entropy
    }

    /// Returns the element-wise entropy derivative used by the regularizer.
    ///
    /// For `MaximizeEntropy` each element is `-(1 + ln(q))`, the derivative of
    /// `-q * ln(q)` at the clamped value `q`; for `MinimizeEntropy` the sign is
    /// flipped. The result has the same shape as `probs` and is not scaled by
    /// `lambda`.
    fn entropy_gradient<S, D>(&self, probs: &ArrayBase<S, D>) -> Array<A, D>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        let entropy_derivative = |p: A| -(A::one() + self.clamp_probability(p).ln());
        match self.reg_type {
            EntropyRegularizerType::MaximizeEntropy => probs.mapv(entropy_derivative),
            EntropyRegularizerType::MinimizeEntropy => probs.mapv(|p| -entropy_derivative(p)),
        }
    }
}
impl<A, D> Regularizer<A, D> for EntropyRegularization<A>
where
    A: Float + ScalarOperand + Debug + FromPrimitive + Send + Sync,
    D: Dimension,
{
    /// Adds `lambda * entropy_gradient(params)` to `gradients` in place and
    /// returns the same value as [`Regularizer::penalty`] for `params`.
    ///
    /// The element-wise update goes through `zip_mut_with`, so `gradients` is
    /// expected to have a shape compatible with `params`.
    ///
    /// NOTE(review): for `MaximizeEntropy` the added term is
    /// `-lambda * (1 + ln(q))` while the returned penalty is `-lambda * H`,
    /// whose derivative is `+lambda * (1 + ln(q))` (and symmetrically for
    /// `MinimizeEntropy`). The gradient term is therefore the negative of the
    /// penalty's derivative — confirm this sign convention is intended.
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        let scale = self.lambda;
        let direction = self.entropy_gradient(params);
        gradients.zip_mut_with(&direction, |slot, &term| {
            *slot = *slot + scale * term;
        });
        <Self as Regularizer<A, D>>::penalty(self, params)
    }

    /// Returns `-lambda * H(params)` for `MaximizeEntropy` and
    /// `lambda * H(params)` for `MinimizeEntropy`, where `H` is the value
    /// computed by `calculate_entropy`.
    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        // Fold the direction into the coefficient, then scale the entropy once.
        let signed_lambda = match self.reg_type {
            EntropyRegularizerType::MaximizeEntropy => -self.lambda,
            EntropyRegularizerType::MinimizeEntropy => self.lambda,
        };
        Ok(signed_lambda * self.calculate_entropy(params))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    /// Four equally weighted entries of `0.25`.
    fn uniform4() -> Array1<f64> {
        Array1::from_elem(4, 0.25)
    }

    /// Four entries with almost all of the mass on the third one.
    fn sharply_peaked() -> Array1<f64> {
        Array1::from_vec(vec![0.01, 0.01, 0.97, 0.01])
    }

    /// Both constructors store exactly the values they are given.
    #[test]
    fn test_entropy_regularization_creation() {
        let defaulted = EntropyRegularization::new(0.1f64, EntropyRegularizerType::MaximizeEntropy);
        assert_eq!(defaulted.lambda, 0.1);
        assert_eq!(defaulted.epsilon, 1e-8);
        assert!(
            matches!(defaulted.reg_type, EntropyRegularizerType::MaximizeEntropy),
            "Wrong regularizer type"
        );

        let custom = EntropyRegularization::new_with_epsilon(
            0.2f64,
            1e-10,
            EntropyRegularizerType::MinimizeEntropy,
        );
        assert_eq!(custom.lambda, 0.2);
        assert_eq!(custom.epsilon, 1e-10);
        assert!(
            matches!(custom.reg_type, EntropyRegularizerType::MinimizeEntropy),
            "Wrong regularizer type"
        );
    }

    /// A uniform input yields `ln(4)`; a peaked input yields less.
    #[test]
    fn test_calculate_entropy() {
        let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MaximizeEntropy);
        let ln_four = 4.0f64.ln();

        let uniform_entropy = er.calculate_entropy(&uniform4());
        assert_abs_diff_eq!(uniform_entropy, ln_four, epsilon = 1e-6);

        let peaked_entropy = er.calculate_entropy(&sharply_peaked());
        assert!(peaked_entropy < ln_four);
    }

    /// The gradient is `-(1 + ln(p))` per element and smaller in magnitude at the peak.
    #[test]
    fn test_entropy_gradient() {
        let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MaximizeEntropy);

        let uniform_grads = er.entropy_gradient(&uniform4());
        let per_element = -(1.0 + 0.25f64.ln());
        for g in &uniform_grads {
            assert_abs_diff_eq!(*g, per_element, epsilon = 1e-6);
        }

        let mildly_peaked = Array1::from_vec(vec![0.1f64, 0.1, 0.7, 0.1]);
        let peaked_grads = er.entropy_gradient(&mildly_peaked);
        assert!(peaked_grads[2].abs() < peaked_grads[0].abs());
    }

    /// When maximizing entropy, the peaked input is penalised more than the uniform one.
    #[test]
    fn test_maximize_entropy_penalty() {
        let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MaximizeEntropy);
        let uniform_penalty = er.penalty(&uniform4()).expect("unwrap failed");
        let peaked_penalty = er.penalty(&sharply_peaked()).expect("unwrap failed");
        assert!(uniform_penalty < peaked_penalty);
    }

    /// When minimizing entropy, the uniform input is penalised more than the peaked one.
    #[test]
    fn test_minimize_entropy_penalty() {
        let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MinimizeEntropy);
        let uniform_penalty = er.penalty(&uniform4()).expect("unwrap failed");
        let peaked_penalty = er.penalty(&sharply_peaked()).expect("unwrap failed");
        assert!(peaked_penalty < uniform_penalty);
    }

    /// `apply` writes a non-zero, uniform gradient term and returns `-lambda * ln(4)`.
    #[test]
    fn test_apply_gradients() {
        let lambda = 0.5f64;
        let er = EntropyRegularization::new(lambda, EntropyRegularizerType::MaximizeEntropy);
        let probs = uniform4();
        let mut gradients = Array1::<f64>::zeros(4);

        let penalty = er.apply(&probs, &mut gradients).expect("unwrap failed");

        // Every slot was updated ...
        assert!(!gradients.iter().any(|&g| g == 0.0));
        // ... and all slots received the same value for a uniform input.
        let reference = gradients[0];
        assert!(gradients.iter().all(|&g| (g - reference).abs() < 1e-6));

        let expected_grad = -lambda * (1.0 + 0.25f64.ln());
        assert_abs_diff_eq!(gradients[0], expected_grad, epsilon = 1e-6);

        let entropy = 4.0f64.ln();
        let expected_penalty = -lambda * entropy;
        assert_abs_diff_eq!(penalty, expected_penalty, epsilon = 1e-6);
    }

    /// `apply` and `penalty` report the same penalty for the same input.
    #[test]
    fn test_regularizer_trait() {
        let er = EntropyRegularization::new(0.1f64, EntropyRegularizerType::MaximizeEntropy);
        let probs = uniform4();
        let mut gradients = Array1::<f64>::zeros(4);

        let from_apply = er.apply(&probs, &mut gradients).expect("unwrap failed");
        let from_penalty = er.penalty(&probs).expect("unwrap failed");
        assert_abs_diff_eq!(from_apply, from_penalty, epsilon = 1e-10);
    }
}