//! Entropy regularization (optirs_core/regularizers/entropy.rs).

use std::fmt::Debug;

use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};

use crate::error::Result;
use crate::regularizers::Regularizer;
7
/// Entropy regularization
///
/// Entropy regularization encourages a model to produce more confident outputs (lower entropy),
/// or more uncertain outputs (higher entropy), depending on the settings. This is often used
/// in reinforcement learning, semi-supervised learning, and some classifier applications.
///
/// # Types
///
/// * `MaximizeEntropy`: Encourages high entropy (uniform, uncertain predictions)
/// * `MinimizeEntropy`: Encourages low entropy (confident, peaked predictions)
///
/// # Parameters
///
/// * `lambda`: Regularization strength coefficient, controls the amount of regularization applied
/// * `epsilon`: Small value for numerical stability (prevents log(0))
///
// PartialEq/Eq derived so callers can compare configurations directly instead of
// matching on variants; the enum is a plain C-like enum, so Eq is sound.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EntropyRegularizerType {
    /// Maximize entropy (encourages uniform distributions)
    MaximizeEntropy,
    /// Minimize entropy (encourages confident predictions)
    MinimizeEntropy,
}
31
32/// Entropy regularization for probability distributions
33///
34/// This regularizer can either encourage high entropy (more uniform distributions) or
35/// low entropy (more peaked distributions) depending on the selected regularizer type.
36///
37/// It's commonly used in reinforcement learning algorithms, semi-supervised learning,
38/// and some classification tasks where controlling the certainty of outputs is desired.
39#[derive(Debug, Clone, Copy)]
40pub struct EntropyRegularization<A: Float + FromPrimitive + Debug> {
41    /// Regularization strength
42    pub lambda: A,
43    /// Small value for numerical stability
44    pub epsilon: A,
45    /// Type of entropy regularization
46    pub reg_type: EntropyRegularizerType,
47}
48
49impl<A: Float + FromPrimitive + Debug + Send + Sync> EntropyRegularization<A> {
50    /// Create a new entropy regularization
51    ///
52    /// # Arguments
53    ///
54    /// * `lambda` - Regularization strength coefficient
55    /// * `reg_type` - Type of entropy regularization (maximize or minimize)
56    ///
57    /// # Returns
58    ///
59    /// An entropy regularization with default epsilon
60    pub fn new(lambda: A, regtype: EntropyRegularizerType) -> Self {
61        let epsilon = A::from_f64(1e-8).unwrap();
62        Self {
63            lambda,
64            epsilon,
65            reg_type: regtype,
66        }
67    }
68
69    /// Create a new entropy regularization with custom epsilon
70    ///
71    /// # Arguments
72    ///
73    /// * `lambda` - Regularization strength coefficient
74    /// * `epsilon` - Small value for numerical stability
75    /// * `reg_type` - Type of entropy regularization (maximize or minimize)
76    ///
77    /// # Returns
78    ///
79    /// An entropy regularization with custom epsilon
80    pub fn new_with_epsilon(lambda: A, epsilon: A, regtype: EntropyRegularizerType) -> Self {
81        Self {
82            lambda,
83            epsilon,
84            reg_type: regtype,
85        }
86    }
87
88    /// Calculate the entropy of a probability distribution
89    ///
90    /// # Arguments
91    ///
92    /// * `probs` - Probability distribution (should sum to 1 along the appropriate axis)
93    ///
94    /// # Returns
95    ///
96    /// The entropy value
97    pub fn calculate_entropy<S, D>(&self, probs: &ArrayBase<S, D>) -> A
98    where
99        S: Data<Elem = A>,
100        D: Dimension,
101    {
102        // Clip probabilities to avoid log(0)
103        let safe_probs = probs.mapv(|p| {
104            if p < self.epsilon {
105                self.epsilon
106            } else if p > (A::one() - self.epsilon) {
107                A::one() - self.epsilon
108            } else {
109                p
110            }
111        });
112
113        // Calculate entropy: -sum(p * log(p))
114        let neg_entropy = safe_probs.mapv(|p| p * p.ln()).sum();
115        -neg_entropy
116    }
117
118    /// Calculate gradient of entropy with respect to input probabilities
119    ///
120    /// # Arguments
121    ///
122    /// * `probs` - Probability distribution
123    ///
124    /// # Returns
125    ///
126    /// The gradient of entropy with respect to probabilities
127    fn entropy_gradient<S, D>(&self, probs: &ArrayBase<S, D>) -> Array<A, D>
128    where
129        S: Data<Elem = A>,
130        D: Dimension,
131    {
132        // Clip probabilities to avoid log(0)
133        let safe_probs = probs.mapv(|p| {
134            if p < self.epsilon {
135                self.epsilon
136            } else if p > (A::one() - self.epsilon) {
137                A::one() - self.epsilon
138            } else {
139                p
140            }
141        });
142
143        // Gradient of entropy: -(1 + log(p))
144        let gradient = safe_probs.mapv(|p| -(A::one() + p.ln()));
145
146        // For minimizing entropy, we negate the gradient
147        match self.reg_type {
148            EntropyRegularizerType::MaximizeEntropy => gradient,
149            EntropyRegularizerType::MinimizeEntropy => gradient.mapv(|g| -g),
150        }
151    }
152}
153
154impl<A, D> Regularizer<A, D> for EntropyRegularization<A>
155where
156    A: Float + ScalarOperand + Debug + FromPrimitive + Send + Sync,
157    D: Dimension,
158{
159    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
160        // Calculate entropy penalty
161        let entropy = self.calculate_entropy(params);
162
163        // Calculate entropy gradients
164        let entropy_grads = self.entropy_gradient(params);
165
166        // Scale gradients by lambda and add to input gradients
167        gradients.zip_mut_with(&entropy_grads, |g, &e| *g = *g + self.lambda * e);
168
169        // Return the regularization term to be added to the loss:
170        // For maximizing entropy, we return -lambda * entropy (to minimize -entropy)
171        // For minimizing entropy, we return lambda * entropy (to minimize entropy)
172        let penalty = match self.reg_type {
173            EntropyRegularizerType::MaximizeEntropy => -self.lambda * entropy,
174            EntropyRegularizerType::MinimizeEntropy => self.lambda * entropy,
175        };
176
177        Ok(penalty)
178    }
179
180    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
181        // Calculate entropy penalty
182        let entropy = self.calculate_entropy(params);
183
184        // For maximizing entropy, we return -lambda * entropy (to minimize -entropy)
185        // For minimizing entropy, we return lambda * entropy (to minimize entropy)
186        let penalty = match self.reg_type {
187            EntropyRegularizerType::MaximizeEntropy => -self.lambda * entropy,
188            EntropyRegularizerType::MinimizeEntropy => self.lambda * entropy,
189        };
190
191        Ok(penalty)
192    }
193}
194
#[cfg(test)]
mod tests {
    // Unit tests for entropy regularization: construction, entropy/gradient math,
    // penalty signs for both regularizer types, and the Regularizer trait contract.
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_entropy_regularization_creation() {
        // `new` uses the default epsilon of 1e-8.
        let er = EntropyRegularization::new(0.1f64, EntropyRegularizerType::MaximizeEntropy);
        assert_eq!(er.lambda, 0.1);
        assert_eq!(er.epsilon, 1e-8);
        match er.reg_type {
            EntropyRegularizerType::MaximizeEntropy => (),
            _ => panic!("Wrong regularizer type"),
        }

        // `new_with_epsilon` stores the caller-supplied epsilon verbatim.
        let er = EntropyRegularization::new_with_epsilon(
            0.2f64,
            1e-10,
            EntropyRegularizerType::MinimizeEntropy,
        );
        assert_eq!(er.lambda, 0.2);
        assert_eq!(er.epsilon, 1e-10);
        match er.reg_type {
            EntropyRegularizerType::MinimizeEntropy => (),
            _ => panic!("Wrong regularizer type"),
        }
    }

    #[test]
    fn test_calculate_entropy() {
        // Uniform distribution (maximum entropy)
        let uniform = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
        let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MaximizeEntropy);
        let entropy = er.calculate_entropy(&uniform);

        // Entropy of uniform distribution should be ln(n)
        let expected = (4.0f64).ln();
        assert_abs_diff_eq!(entropy, expected, epsilon = 1e-6);

        // Peaked distribution (low entropy)
        let peaked = Array1::from_vec(vec![0.01f64, 0.01, 0.97, 0.01]);
        let entropy = er.calculate_entropy(&peaked);
        assert!(entropy < expected); // Should be less than uniform entropy
    }

    #[test]
    fn test_entropy_gradient() {
        let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MaximizeEntropy);

        // For uniform distribution, gradients should be approximately equal
        let uniform = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
        let grads = er.entropy_gradient(&uniform);

        // Expected gradient: -(1 + ln(0.25))
        let expected = -(1.0 + 0.25f64.ln());
        for &g in grads.iter() {
            assert_abs_diff_eq!(g, expected, epsilon = 1e-6);
        }

        // For peaked distribution, gradients should be different for different probabilities
        let peaked = Array1::from_vec(vec![0.1f64, 0.1, 0.7, 0.1]);
        let grads = er.entropy_gradient(&peaked);

        // The gradient for larger probability should have a smaller absolute value
        // because ln(0.7) is greater (less negative) than ln(0.1)
        // So -(1 + ln(0.7)) has smaller magnitude than -(1 + ln(0.1))
        assert!(grads[2].abs() < grads[0].abs());
    }

    #[test]
    fn test_maximize_entropy_penalty() {
        // For maximizing entropy, we want to minimize -entropy
        let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MaximizeEntropy);

        // Uniform distribution (high entropy)
        let uniform = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
        let penalty = er.penalty(&uniform).unwrap();

        // Peaked distribution (low entropy)
        let peaked = Array1::from_vec(vec![0.01f64, 0.01, 0.97, 0.01]);
        let peaked_penalty = er.penalty(&peaked).unwrap();

        // The penalty for peaked should be greater than for uniform
        // because we're trying to maximize entropy
        assert!(peaked_penalty > penalty);
    }

    #[test]
    fn test_minimize_entropy_penalty() {
        // For minimizing entropy, we want to minimize entropy
        let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MinimizeEntropy);

        // Uniform distribution (high entropy)
        let uniform = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
        let penalty = er.penalty(&uniform).unwrap();

        // Peaked distribution (low entropy)
        let peaked = Array1::from_vec(vec![0.01f64, 0.01, 0.97, 0.01]);
        let peaked_penalty = er.penalty(&peaked).unwrap();

        // The penalty for uniform should be greater than for peaked
        // because we're trying to minimize entropy
        assert!(penalty > peaked_penalty);
    }

    #[test]
    fn test_apply_gradients() {
        let lambda = 0.5f64;
        let er = EntropyRegularization::new(lambda, EntropyRegularizerType::MaximizeEntropy);

        let probs = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
        let mut gradients = Array1::zeros(4);

        let penalty = er.apply(&probs, &mut gradients).unwrap();

        // Check that gradients have been modified
        assert!(gradients.iter().all(|&g| g != 0.0));

        // For uniform distribution, all gradients should be equal
        let first = gradients[0];
        assert!(gradients.iter().all(|&g| (g - first).abs() < 1e-6));

        // Expected gradient: -lambda * (1 + ln(0.25))
        // (This pins the current sign orientation of `apply`'s gradient accumulation.)
        let expected_grad = -lambda * (1.0 + 0.25f64.ln());
        assert_abs_diff_eq!(gradients[0], expected_grad, epsilon = 1e-6);

        // Check penalty matches expected value
        let entropy = (4.0f64).ln(); // Entropy of uniform distribution
        let expected_penalty = -lambda * entropy; // For maximizing entropy
        assert_abs_diff_eq!(penalty, expected_penalty, epsilon = 1e-6);
    }

    #[test]
    fn test_regularizer_trait() {
        // Test that EntropyRegularization implements Regularizer trait correctly
        let er = EntropyRegularization::new(0.1f64, EntropyRegularizerType::MaximizeEntropy);

        let probs = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
        let mut gradients = Array1::zeros(4);

        // Both methods should return the same penalty for the same input
        let penalty1 = er.apply(&probs, &mut gradients).unwrap();
        let penalty2 = er.penalty(&probs).unwrap();

        assert_abs_diff_eq!(penalty1, penalty2, epsilon = 1e-10);
    }
}