//! optirs_core/regularizers/activity.rs — activity regularization (penalizes large activation values).

1use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
2use scirs2_core::numeric::{Float, FromPrimitive};
3use std::fmt::Debug;
4
5use crate::error::Result;
6use crate::regularizers::Regularizer;
7
/// Different norms for Activity regularization
///
/// Selects how activation magnitudes are aggregated into a single penalty
/// value (see `ActivityRegularization::calculate_penalty`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivityNorm {
    /// L1 norm (sum of absolute values); tends to drive activations to exactly zero
    L1,
    /// L2 norm (square root of sum of squares)
    L2,
    /// Squared L2 norm (sum of squares); avoids the square root of the plain L2 norm
    L2Squared,
}
18
19/// Activity regularization
20///
21/// Activity regularization penalizes high activation values in neural networks.
22/// It encourages sparse activations by adding a penalty based on the magnitude
23/// of activation values.
24///
25/// # Parameters
26///
27/// * `lambda`: Regularization strength parameter
28/// * `norm`: Type of norm to use for measuring activation magnitudes (L1, L2, or L2Squared)
29///
30/// # References
31///
32/// * Nowlan, S. J., & Hinton, G. E. (1992). Simplifying neural networks by soft
33///   weight-sharing. Neural Computation, 4(4), 473-493.
34///
35#[derive(Debug, Clone, Copy)]
36pub struct ActivityRegularization<A: Float + FromPrimitive + Debug> {
37    /// Regularization strength
38    pub lambda: A,
39    /// Norm to use for activity regularization
40    pub norm: ActivityNorm,
41}
42
43impl<A: Float + FromPrimitive + Debug + Send + Sync> ActivityRegularization<A> {
44    /// Create a new activity regularizer with L1 norm
45    ///
46    /// # Arguments
47    ///
48    /// * `lambda` - Regularization strength parameter
49    ///
50    /// # Returns
51    ///
52    /// A new activity regularizer with L1 norm
53    pub fn l1(lambda: A) -> Self {
54        Self {
55            lambda,
56            norm: ActivityNorm::L1,
57        }
58    }
59
60    /// Create a new activity regularizer with L2 norm
61    ///
62    /// # Arguments
63    ///
64    /// * `lambda` - Regularization strength parameter
65    ///
66    /// # Returns
67    ///
68    /// A new activity regularizer with L2 norm
69    pub fn l2(lambda: A) -> Self {
70        Self {
71            lambda,
72            norm: ActivityNorm::L2,
73        }
74    }
75
76    /// Create a new activity regularizer with squared L2 norm
77    ///
78    /// # Arguments
79    ///
80    /// * `lambda` - Regularization strength parameter
81    ///
82    /// # Returns
83    ///
84    /// A new activity regularizer with squared L2 norm
85    pub fn l2_squared(lambda: A) -> Self {
86        Self {
87            lambda,
88            norm: ActivityNorm::L2Squared,
89        }
90    }
91
92    /// Create a new activity regularizer with custom norm
93    ///
94    /// # Arguments
95    ///
96    /// * `lambda` - Regularization strength parameter
97    /// * `norm` - Norm to use (L1, L2, or L2Squared)
98    ///
99    /// # Returns
100    ///
101    /// A new activity regularizer with specified norm
102    pub fn new(lambda: A, norm: ActivityNorm) -> Self {
103        Self { lambda, norm }
104    }
105
106    /// Calculate the activity penalty
107    ///
108    /// # Arguments
109    ///
110    /// * `activations` - The activations to regularize
111    ///
112    /// # Returns
113    ///
114    /// The regularization penalty value
115    fn calculate_penalty<S, D>(&self, activations: &ArrayBase<S, D>) -> A
116    where
117        S: Data<Elem = A>,
118        D: Dimension,
119    {
120        match self.norm {
121            ActivityNorm::L1 => {
122                // L1 norm: sum of absolute values
123                let sum_abs = activations.mapv(|x| x.abs()).sum();
124                self.lambda * sum_abs
125            }
126            ActivityNorm::L2 => {
127                // L2 norm: sqrt of sum of squares
128                let sum_squared = activations.mapv(|x| x * x).sum();
129                self.lambda * sum_squared.sqrt()
130            }
131            ActivityNorm::L2Squared => {
132                // Squared L2 norm: sum of squares
133                let sum_squared = activations.mapv(|x| x * x).sum();
134                self.lambda * sum_squared
135            }
136        }
137    }
138
139    /// Calculate gradients for activity regularization
140    ///
141    /// # Arguments
142    ///
143    /// * `activations` - The activations to regularize
144    ///
145    /// # Returns
146    ///
147    /// The gradient array with respect to the activations
148    fn calculate_gradients<S, D>(&self, activations: &ArrayBase<S, D>) -> Array<A, D>
149    where
150        S: Data<Elem = A>,
151        D: Dimension,
152    {
153        match self.norm {
154            ActivityNorm::L1 => {
155                // Derivative of L1: sign of the value
156                activations.mapv(|x| {
157                    if x > A::zero() {
158                        self.lambda
159                    } else if x < A::zero() {
160                        -self.lambda
161                    } else {
162                        A::zero()
163                    }
164                })
165            }
166            ActivityNorm::L2 => {
167                // Derivative of L2: x / sqrt(sum(x^2))
168                let sum_squared = activations.mapv(|x| x * x).sum();
169
170                // Handle the case where sum_squared is zero to avoid division by zero
171                if sum_squared <= A::epsilon() {
172                    return Array::zeros(activations.raw_dim());
173                }
174
175                let norm = sum_squared.sqrt();
176                activations.mapv(|x| self.lambda * x / norm)
177            }
178            ActivityNorm::L2Squared => {
179                // Derivative of squared L2: 2 * x
180                let two = A::one() + A::one();
181                activations.mapv(|x| self.lambda * two * x)
182            }
183        }
184    }
185}
186
187impl<A, D> Regularizer<A, D> for ActivityRegularization<A>
188where
189    A: Float + ScalarOperand + Debug + FromPrimitive + Send + Sync,
190    D: Dimension,
191{
192    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
193        // Calculate penalty
194        let penalty = self.calculate_penalty(params);
195
196        // Calculate and apply gradients
197        let activity_grads = self.calculate_gradients(params);
198        gradients.zip_mut_with(&activity_grads, |g, &a| *g = *g + a);
199
200        Ok(penalty)
201    }
202
203    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
204        Ok(self.calculate_penalty(params))
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_activity_regularization_creation() {
        // Each constructor must record both the strength and the chosen norm.
        let reg = ActivityRegularization::l1(0.1f64);
        assert_eq!(reg.lambda, 0.1);
        assert_eq!(reg.norm, ActivityNorm::L1);

        let reg = ActivityRegularization::l2(0.2f64);
        assert_eq!(reg.lambda, 0.2);
        assert_eq!(reg.norm, ActivityNorm::L2);

        let reg = ActivityRegularization::l2_squared(0.3f64);
        assert_eq!(reg.lambda, 0.3);
        assert_eq!(reg.norm, ActivityNorm::L2Squared);

        let reg = ActivityRegularization::new(0.4f64, ActivityNorm::L1);
        assert_eq!(reg.lambda, 0.4);
        assert_eq!(reg.norm, ActivityNorm::L1);
    }

    #[test]
    fn test_l1_penalty() {
        let strength = 0.1f64;
        let reg = ActivityRegularization::l1(strength);
        let acts = array![1.0f64, -2.0, 3.0];

        // lambda * (|1| + |-2| + |3|) = 0.1 * 6 = 0.6
        let penalty = reg.penalty(&acts).unwrap();
        assert_abs_diff_eq!(penalty, strength * 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_penalty() {
        let strength = 0.1f64;
        let reg = ActivityRegularization::l2(strength);
        let acts = array![3.0f64, 4.0];

        // lambda * sqrt(9 + 16) = 0.1 * 5 = 0.5 (classic 3-4-5 triangle)
        let penalty = reg.penalty(&acts).unwrap();
        assert_abs_diff_eq!(penalty, strength * 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_squared_penalty() {
        let strength = 0.1f64;
        let reg = ActivityRegularization::l2_squared(strength);
        let acts = array![1.0f64, 2.0, 3.0];

        // lambda * (1 + 4 + 9) = 0.1 * 14 = 1.4
        let penalty = reg.penalty(&acts).unwrap();
        assert_abs_diff_eq!(penalty, strength * 14.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l1_gradients() {
        let strength = 0.1f64;
        let reg = ActivityRegularization::l1(strength);
        let acts = array![1.0f64, -2.0, 0.0];
        let mut grads = Array1::zeros(3);

        let penalty = reg.apply(&acts, &mut grads).unwrap();

        // Gradient is lambda * sign(x), with 0 at exactly zero.
        assert_abs_diff_eq!(grads[0], strength, epsilon = 1e-10);
        assert_abs_diff_eq!(grads[1], -strength, epsilon = 1e-10);
        assert_abs_diff_eq!(grads[2], 0.0, epsilon = 1e-10);

        // Penalty: lambda * (1 + 2 + 0) = 0.3
        assert_abs_diff_eq!(penalty, strength * 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_gradients() {
        let strength = 0.1f64;
        let reg = ActivityRegularization::l2(strength);
        let acts = array![3.0f64, 4.0];
        let mut grads = Array1::zeros(2);

        let penalty = reg.apply(&acts, &mut grads).unwrap();

        // With norm = 5, each gradient is lambda * x / 5.
        assert_abs_diff_eq!(grads[0], strength * 3.0 / 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grads[1], strength * 4.0 / 5.0, epsilon = 1e-10);

        // Penalty: lambda * 5 = 0.5
        assert_abs_diff_eq!(penalty, strength * 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_gradients_zero_activations() {
        let strength = 0.1f64;
        let reg = ActivityRegularization::l2(strength);
        let acts = array![0.0f64, 0.0];
        let mut grads = Array1::zeros(2);

        let penalty = reg.apply(&acts, &mut grads).unwrap();

        // The all-zero case must not divide by zero: gradients stay zero.
        assert_abs_diff_eq!(grads[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grads[1], 0.0, epsilon = 1e-10);

        // Penalty: lambda * sqrt(0) = 0
        assert_abs_diff_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_squared_gradients() {
        let strength = 0.1f64;
        let reg = ActivityRegularization::l2_squared(strength);
        let acts = array![2.0f64, 3.0];
        let mut grads = Array1::zeros(2);

        let penalty = reg.apply(&acts, &mut grads).unwrap();

        // Gradient is lambda * 2 * x for each element.
        assert_abs_diff_eq!(grads[0], strength * 2.0 * 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grads[1], strength * 2.0 * 3.0, epsilon = 1e-10);

        // Penalty: lambda * (4 + 9) = 1.3
        assert_abs_diff_eq!(penalty, strength * 13.0, epsilon = 1e-10);
    }

    #[test]
    fn test_2d_activations() {
        let strength = 0.1f64;
        let reg = ActivityRegularization::l1(strength);

        // The penalty must reduce over every element regardless of rank.
        let acts: Array2<f64> = array![[1.0, 2.0], [-3.0, 4.0]];
        let penalty = reg.penalty(&acts).unwrap();

        // lambda * (1 + 2 + 3 + 4) = 0.1 * 10 = 1.0
        assert_abs_diff_eq!(penalty, strength * 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_regularizer_trait() {
        let strength = 0.1f64;
        let reg = ActivityRegularization::l1(strength);
        let acts = array![1.0f64, 2.0, 3.0];
        let mut grads = Array1::zeros(3);

        // penalty() and apply() must agree on the penalty value.
        let direct = reg.penalty(&acts).unwrap();
        let via_apply = reg.apply(&acts, &mut grads).unwrap();
        assert_abs_diff_eq!(direct, via_apply, epsilon = 1e-10);

        // All activations are positive, so every gradient entry is +lambda.
        for i in 0..3 {
            assert_abs_diff_eq!(grads[i], strength, epsilon = 1e-10);
        }
    }
}
375}