optirs_core/regularizers/activity.rs

use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::Result;
use crate::regularizers::Regularizer;
/// Norm used to measure the magnitude of the activations.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ActivityNorm {
    /// Sum of absolute values (encourages sparse activations).
    L1,
    /// Euclidean norm.
    L2,
    /// Squared Euclidean norm.
    L2Squared,
}
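
/// Activity regularizer: penalizes large layer activations (rather than weights),
/// encouraging small or sparse outputs depending on the chosen [`ActivityNorm`].
///
/// A minimal usage sketch (illustrative only, hence marked `ignore`; it assumes
/// the `Regularizer` trait is in scope and uses the `array!` macro re-exported
/// through `scirs2_core`):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// // L1 activity regularizer with strength 0.01.
/// let reg = ActivityRegularization::l1(0.01f64);
///
/// let activations = array![0.5f64, -1.0, 0.0];
/// let mut gradients = array![0.0f64, 0.0, 0.0];
///
/// // `apply` adds the regularization gradient in place and returns the penalty.
/// let penalty = reg.apply(&activations, &mut gradients).unwrap();
/// assert!(penalty > 0.0);
/// ```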
#[derive(Debug, Clone, Copy)]
pub struct ActivityRegularization<A: Float + FromPrimitive + Debug> {
    /// Regularization strength.
    pub lambda: A,
    /// Norm used to compute the penalty.
    pub norm: ActivityNorm,
}

impl<A: Float + FromPrimitive + Debug + Send + Sync> ActivityRegularization<A> {
    /// Creates an L1 activity regularizer with strength `lambda`.
    pub fn l1(lambda: A) -> Self {
        Self {
            lambda,
            norm: ActivityNorm::L1,
        }
    }

    /// Creates an L2 (Euclidean norm) activity regularizer with strength `lambda`.
    pub fn l2(lambda: A) -> Self {
        Self {
            lambda,
            norm: ActivityNorm::L2,
        }
    }

    /// Creates a squared-L2 activity regularizer with strength `lambda`.
    pub fn l2_squared(lambda: A) -> Self {
        Self {
            lambda,
            norm: ActivityNorm::L2Squared,
        }
    }

    /// Creates an activity regularizer with the given strength and norm.
    pub fn new(lambda: A, norm: ActivityNorm) -> Self {
        Self { lambda, norm }
    }

    /// Computes the regularization penalty for the given activations.
    fn calculate_penalty<S, D>(&self, activations: &ArrayBase<S, D>) -> A
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        match self.norm {
            ActivityNorm::L1 => {
                // lambda * sum(|x|)
                let sum_abs = activations.mapv(|x| x.abs()).sum();
                self.lambda * sum_abs
            }
            ActivityNorm::L2 => {
                // lambda * sqrt(sum(x^2))
                let sum_squared = activations.mapv(|x| x * x).sum();
                self.lambda * sum_squared.sqrt()
            }
            ActivityNorm::L2Squared => {
                // lambda * sum(x^2)
                let sum_squared = activations.mapv(|x| x * x).sum();
                self.lambda * sum_squared
            }
        }
    }

    /// Computes the gradient of the penalty with respect to the activations.
    fn calculate_gradients<S, D>(&self, activations: &ArrayBase<S, D>) -> Array<A, D>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        match self.norm {
            ActivityNorm::L1 => {
                // d/dx lambda * |x| = lambda * sign(x); use 0 as the subgradient at x = 0.
                activations.mapv(|x| {
                    if x > A::zero() {
                        self.lambda
                    } else if x < A::zero() {
                        -self.lambda
                    } else {
                        A::zero()
                    }
                })
            }
            ActivityNorm::L2 => {
                let sum_squared = activations.mapv(|x| x * x).sum();

                // Avoid division by zero when all activations are (near) zero.
                if sum_squared <= A::epsilon() {
                    return Array::zeros(activations.raw_dim());
                }

                // d/dx lambda * ||x||_2 = lambda * x / ||x||_2
                let norm = sum_squared.sqrt();
                activations.mapv(|x| self.lambda * x / norm)
            }
            ActivityNorm::L2Squared => {
                // d/dx lambda * ||x||_2^2 = 2 * lambda * x
                let two = A::one() + A::one();
                activations.mapv(|x| self.lambda * two * x)
            }
        }
    }
}

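// For this regularizer, the `params` passed to `Regularizer::apply` and
// `Regularizer::penalty` are interpreted as the layer's activations, not its
// weights: the penalty and gradient are computed directly from those values.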
impl<A, D> Regularizer<A, D> for ActivityRegularization<A>
where
    A: Float + ScalarOperand + Debug + FromPrimitive + Send + Sync,
    D: Dimension,
{
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        let penalty = self.calculate_penalty(params);

        // Add the activity gradient to the incoming gradients in place.
        let activity_grads = self.calculate_gradients(params);
        gradients.zip_mut_with(&activity_grads, |g, &a| *g = *g + a);

        Ok(penalty)
    }

    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        Ok(self.calculate_penalty(params))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array1, Array2};

    #[test]
    fn test_activity_regularization_creation() {
        let ar = ActivityRegularization::l1(0.1f64);
        assert_eq!(ar.lambda, 0.1);
        assert_eq!(ar.norm, ActivityNorm::L1);

        let ar = ActivityRegularization::l2(0.2f64);
        assert_eq!(ar.lambda, 0.2);
        assert_eq!(ar.norm, ActivityNorm::L2);

        let ar = ActivityRegularization::l2_squared(0.3f64);
        assert_eq!(ar.lambda, 0.3);
        assert_eq!(ar.norm, ActivityNorm::L2Squared);

        let ar = ActivityRegularization::new(0.4f64, ActivityNorm::L1);
        assert_eq!(ar.lambda, 0.4);
        assert_eq!(ar.norm, ActivityNorm::L1);
    }

    #[test]
    fn test_l1_penalty() {
        let lambda = 0.1f64;
        let ar = ActivityRegularization::l1(lambda);

        let activations = Array1::from_vec(vec![1.0f64, -2.0, 3.0]);
        let penalty = ar.penalty(&activations).unwrap();

        // |1| + |-2| + |3| = 6
        assert_abs_diff_eq!(penalty, lambda * 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_penalty() {
        let lambda = 0.1f64;
        let ar = ActivityRegularization::l2(lambda);

        let activations = Array1::from_vec(vec![3.0f64, 4.0]);
        let penalty = ar.penalty(&activations).unwrap();

        // sqrt(3^2 + 4^2) = 5
        assert_abs_diff_eq!(penalty, lambda * 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_squared_penalty() {
        let lambda = 0.1f64;
        let ar = ActivityRegularization::l2_squared(lambda);

        let activations = Array1::from_vec(vec![1.0f64, 2.0, 3.0]);
        let penalty = ar.penalty(&activations).unwrap();

        // 1 + 4 + 9 = 14
        assert_abs_diff_eq!(penalty, lambda * 14.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l1_gradients() {
        let lambda = 0.1f64;
        let ar = ActivityRegularization::l1(lambda);

        let activations = Array1::from_vec(vec![1.0f64, -2.0, 0.0]);
        let mut gradients = Array1::zeros(3);

        let penalty = ar.apply(&activations, &mut gradients).unwrap();

        assert_abs_diff_eq!(gradients[0], lambda, epsilon = 1e-10);
        assert_abs_diff_eq!(gradients[1], -lambda, epsilon = 1e-10);
        assert_abs_diff_eq!(gradients[2], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(penalty, lambda * 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_gradients() {
        let lambda = 0.1f64;
        let ar = ActivityRegularization::l2(lambda);

        let activations = Array1::from_vec(vec![3.0f64, 4.0]);
        let mut gradients = Array1::zeros(2);

        let penalty = ar.apply(&activations, &mut gradients).unwrap();

        // Gradient is lambda * x / ||x||_2 with ||x||_2 = 5.
        assert_abs_diff_eq!(gradients[0], lambda * 3.0 / 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(gradients[1], lambda * 4.0 / 5.0, epsilon = 1e-10);

        assert_abs_diff_eq!(penalty, lambda * 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_gradients_zero_activations() {
        let lambda = 0.1f64;
        let ar = ActivityRegularization::l2(lambda);

        let activations = Array1::from_vec(vec![0.0f64, 0.0]);
        let mut gradients = Array1::zeros(2);

        let penalty = ar.apply(&activations, &mut gradients).unwrap();

        assert_abs_diff_eq!(gradients[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(gradients[1], 0.0, epsilon = 1e-10);

        assert_abs_diff_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_l2_squared_gradients() {
        let lambda = 0.1f64;
        let ar = ActivityRegularization::l2_squared(lambda);

        let activations = Array1::from_vec(vec![2.0f64, 3.0]);
        let mut gradients = Array1::zeros(2);

        let penalty = ar.apply(&activations, &mut gradients).unwrap();

        // Gradient is 2 * lambda * x.
        assert_abs_diff_eq!(gradients[0], lambda * 2.0 * 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(gradients[1], lambda * 2.0 * 3.0, epsilon = 1e-10);

        // Penalty is lambda * (4 + 9).
        assert_abs_diff_eq!(penalty, lambda * 13.0, epsilon = 1e-10);
    }

    #[test]
    fn test_2d_activations() {
        let lambda = 0.1f64;
        let ar = ActivityRegularization::l1(lambda);

        let activations = Array2::from_shape_vec((2, 2), vec![1.0f64, 2.0, -3.0, 4.0]).unwrap();
        let penalty = ar.penalty(&activations).unwrap();

        // |1| + |2| + |-3| + |4| = 10
        assert_abs_diff_eq!(penalty, lambda * 10.0, epsilon = 1e-10);
    }

    #[test]
    fn test_regularizer_trait() {
        let lambda = 0.1f64;
        let ar = ActivityRegularization::l1(lambda);

        let activations = array![1.0f64, 2.0, 3.0];
        let mut gradients = Array1::zeros(3);

        let penalty1 = ar.penalty(&activations).unwrap();
        let penalty2 = ar.apply(&activations, &mut gradients).unwrap();

        assert_abs_diff_eq!(penalty1, penalty2, epsilon = 1e-10);

        assert_abs_diff_eq!(gradients[0], lambda, epsilon = 1e-10);
        assert_abs_diff_eq!(gradients[1], lambda, epsilon = 1e-10);
        assert_abs_diff_eq!(gradients[2], lambda, epsilon = 1e-10);
    }
}