optirs_core/regularizers/
entropy.rs1use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
2use scirs2_core::numeric::{Float, FromPrimitive};
3use std::fmt::Debug;
4
5use crate::error::Result;
6use crate::regularizers::Regularizer;
7
/// Direction in which [`EntropyRegularization`] pushes the entropy of its parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntropyRegularizerType {
    /// Favour spread-out (high-entropy) values; the reported penalty is `-λ·H(p)`.
    MaximizeEntropy,
    /// Favour peaked (low-entropy) values; the reported penalty is `λ·H(p)`.
    MinimizeEntropy,
}
31
32#[derive(Debug, Clone, Copy)]
40pub struct EntropyRegularization<A: Float + FromPrimitive + Debug> {
41 pub lambda: A,
43 pub epsilon: A,
45 pub reg_type: EntropyRegularizerType,
47}
48
49impl<A: Float + FromPrimitive + Debug + Send + Sync> EntropyRegularization<A> {
50 pub fn new(lambda: A, regtype: EntropyRegularizerType) -> Self {
61 let epsilon = A::from_f64(1e-8).unwrap();
62 Self {
63 lambda,
64 epsilon,
65 reg_type: regtype,
66 }
67 }
68
69 pub fn new_with_epsilon(lambda: A, epsilon: A, regtype: EntropyRegularizerType) -> Self {
81 Self {
82 lambda,
83 epsilon,
84 reg_type: regtype,
85 }
86 }
87
88 pub fn calculate_entropy<S, D>(&self, probs: &ArrayBase<S, D>) -> A
98 where
99 S: Data<Elem = A>,
100 D: Dimension,
101 {
102 let safe_probs = probs.mapv(|p| {
104 if p < self.epsilon {
105 self.epsilon
106 } else if p > (A::one() - self.epsilon) {
107 A::one() - self.epsilon
108 } else {
109 p
110 }
111 });
112
113 let neg_entropy = safe_probs.mapv(|p| p * p.ln()).sum();
115 -neg_entropy
116 }
117
118 fn entropy_gradient<S, D>(&self, probs: &ArrayBase<S, D>) -> Array<A, D>
128 where
129 S: Data<Elem = A>,
130 D: Dimension,
131 {
132 let safe_probs = probs.mapv(|p| {
134 if p < self.epsilon {
135 self.epsilon
136 } else if p > (A::one() - self.epsilon) {
137 A::one() - self.epsilon
138 } else {
139 p
140 }
141 });
142
143 let gradient = safe_probs.mapv(|p| -(A::one() + p.ln()));
145
146 match self.reg_type {
148 EntropyRegularizerType::MaximizeEntropy => gradient,
149 EntropyRegularizerType::MinimizeEntropy => gradient.mapv(|g| -g),
150 }
151 }
152}
153
154impl<A, D> Regularizer<A, D> for EntropyRegularization<A>
155where
156 A: Float + ScalarOperand + Debug + FromPrimitive + Send + Sync,
157 D: Dimension,
158{
159 fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
160 let entropy = self.calculate_entropy(params);
162
163 let entropy_grads = self.entropy_gradient(params);
165
166 gradients.zip_mut_with(&entropy_grads, |g, &e| *g = *g + self.lambda * e);
168
169 let penalty = match self.reg_type {
173 EntropyRegularizerType::MaximizeEntropy => -self.lambda * entropy,
174 EntropyRegularizerType::MinimizeEntropy => self.lambda * entropy,
175 };
176
177 Ok(penalty)
178 }
179
180 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
181 let entropy = self.calculate_entropy(params);
183
184 let penalty = match self.reg_type {
187 EntropyRegularizerType::MaximizeEntropy => -self.lambda * entropy,
188 EntropyRegularizerType::MinimizeEntropy => self.lambda * entropy,
189 };
190
191 Ok(penalty)
192 }
193}
194
195#[cfg(test)]
196mod tests {
197 use super::*;
198 use approx::assert_abs_diff_eq;
199 use scirs2_core::ndarray::Array1;
200
201 #[test]
202 fn test_entropy_regularization_creation() {
203 let er = EntropyRegularization::new(0.1f64, EntropyRegularizerType::MaximizeEntropy);
204 assert_eq!(er.lambda, 0.1);
205 assert_eq!(er.epsilon, 1e-8);
206 match er.reg_type {
207 EntropyRegularizerType::MaximizeEntropy => (),
208 _ => panic!("Wrong regularizer type"),
209 }
210
211 let er = EntropyRegularization::new_with_epsilon(
212 0.2f64,
213 1e-10,
214 EntropyRegularizerType::MinimizeEntropy,
215 );
216 assert_eq!(er.lambda, 0.2);
217 assert_eq!(er.epsilon, 1e-10);
218 match er.reg_type {
219 EntropyRegularizerType::MinimizeEntropy => (),
220 _ => panic!("Wrong regularizer type"),
221 }
222 }
223
224 #[test]
225 fn test_calculate_entropy() {
226 let uniform = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
228 let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MaximizeEntropy);
229 let entropy = er.calculate_entropy(&uniform);
230
231 let expected = (4.0f64).ln();
233 assert_abs_diff_eq!(entropy, expected, epsilon = 1e-6);
234
235 let peaked = Array1::from_vec(vec![0.01f64, 0.01, 0.97, 0.01]);
237 let entropy = er.calculate_entropy(&peaked);
238 assert!(entropy < expected); }
240
241 #[test]
242 fn test_entropy_gradient() {
243 let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MaximizeEntropy);
244
245 let uniform = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
247 let grads = er.entropy_gradient(&uniform);
248
249 let expected = -(1.0 + 0.25f64.ln());
251 for &g in grads.iter() {
252 assert_abs_diff_eq!(g, expected, epsilon = 1e-6);
253 }
254
255 let peaked = Array1::from_vec(vec![0.1f64, 0.1, 0.7, 0.1]);
257 let grads = er.entropy_gradient(&peaked);
258
259 assert!(grads[2].abs() < grads[0].abs());
263 }
264
265 #[test]
266 fn test_maximize_entropy_penalty() {
267 let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MaximizeEntropy);
269
270 let uniform = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
272 let penalty = er.penalty(&uniform).unwrap();
273
274 let peaked = Array1::from_vec(vec![0.01f64, 0.01, 0.97, 0.01]);
276 let peaked_penalty = er.penalty(&peaked).unwrap();
277
278 assert!(peaked_penalty > penalty);
281 }
282
283 #[test]
284 fn test_minimize_entropy_penalty() {
285 let er = EntropyRegularization::new(1.0f64, EntropyRegularizerType::MinimizeEntropy);
287
288 let uniform = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
290 let penalty = er.penalty(&uniform).unwrap();
291
292 let peaked = Array1::from_vec(vec![0.01f64, 0.01, 0.97, 0.01]);
294 let peaked_penalty = er.penalty(&peaked).unwrap();
295
296 assert!(penalty > peaked_penalty);
299 }
300
301 #[test]
302 fn test_apply_gradients() {
303 let lambda = 0.5f64;
304 let er = EntropyRegularization::new(lambda, EntropyRegularizerType::MaximizeEntropy);
305
306 let probs = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
307 let mut gradients = Array1::zeros(4);
308
309 let penalty = er.apply(&probs, &mut gradients).unwrap();
310
311 assert!(gradients.iter().all(|&g| g != 0.0));
313
314 let first = gradients[0];
316 assert!(gradients.iter().all(|&g| (g - first).abs() < 1e-6));
317
318 let expected_grad = -lambda * (1.0 + 0.25f64.ln());
320 assert_abs_diff_eq!(gradients[0], expected_grad, epsilon = 1e-6);
321
322 let entropy = (4.0f64).ln(); let expected_penalty = -lambda * entropy; assert_abs_diff_eq!(penalty, expected_penalty, epsilon = 1e-6);
326 }
327
328 #[test]
329 fn test_regularizer_trait() {
330 let er = EntropyRegularization::new(0.1f64, EntropyRegularizerType::MaximizeEntropy);
332
333 let probs = Array1::from_vec(vec![0.25f64, 0.25, 0.25, 0.25]);
334 let mut gradients = Array1::zeros(4);
335
336 let penalty1 = er.apply(&probs, &mut gradients).unwrap();
338 let penalty2 = er.penalty(&probs).unwrap();
339
340 assert_abs_diff_eq!(penalty1, penalty2, epsilon = 1e-10);
341 }
342}