optirs_core/regularizers/
activity.rs1use scirs2_core::ndarray::{Array, ArrayBase, Data, Dimension, ScalarOperand};
2use scirs2_core::numeric::{Float, FromPrimitive};
3use std::fmt::Debug;
4
5use crate::error::Result;
6use crate::regularizers::Regularizer;
7
/// Norm applied to the activations when computing the activity penalty.
///
/// The enum is a plain, fieldless tag, so it derives `Eq` and `Hash` in
/// addition to `PartialEq`; this lets it be used as a map/set key and in
/// contexts that require total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivityNorm {
    /// Sum of the absolute values of the activations.
    L1,
    /// Square root of the sum of the squared activations.
    L2,
    /// Sum of the squared activations (no square root).
    L2Squared,
}
18
19#[derive(Debug, Clone, Copy)]
36pub struct ActivityRegularization<A: Float + FromPrimitive + Debug> {
37 pub lambda: A,
39 pub norm: ActivityNorm,
41}
42
43impl<A: Float + FromPrimitive + Debug + Send + Sync> ActivityRegularization<A> {
44 pub fn l1(lambda: A) -> Self {
54 Self {
55 lambda,
56 norm: ActivityNorm::L1,
57 }
58 }
59
60 pub fn l2(lambda: A) -> Self {
70 Self {
71 lambda,
72 norm: ActivityNorm::L2,
73 }
74 }
75
76 pub fn l2_squared(lambda: A) -> Self {
86 Self {
87 lambda,
88 norm: ActivityNorm::L2Squared,
89 }
90 }
91
92 pub fn new(lambda: A, norm: ActivityNorm) -> Self {
103 Self { lambda, norm }
104 }
105
106 fn calculate_penalty<S, D>(&self, activations: &ArrayBase<S, D>) -> A
116 where
117 S: Data<Elem = A>,
118 D: Dimension,
119 {
120 match self.norm {
121 ActivityNorm::L1 => {
122 let sum_abs = activations.mapv(|x| x.abs()).sum();
124 self.lambda * sum_abs
125 }
126 ActivityNorm::L2 => {
127 let sum_squared = activations.mapv(|x| x * x).sum();
129 self.lambda * sum_squared.sqrt()
130 }
131 ActivityNorm::L2Squared => {
132 let sum_squared = activations.mapv(|x| x * x).sum();
134 self.lambda * sum_squared
135 }
136 }
137 }
138
139 fn calculate_gradients<S, D>(&self, activations: &ArrayBase<S, D>) -> Array<A, D>
149 where
150 S: Data<Elem = A>,
151 D: Dimension,
152 {
153 match self.norm {
154 ActivityNorm::L1 => {
155 activations.mapv(|x| {
157 if x > A::zero() {
158 self.lambda
159 } else if x < A::zero() {
160 -self.lambda
161 } else {
162 A::zero()
163 }
164 })
165 }
166 ActivityNorm::L2 => {
167 let sum_squared = activations.mapv(|x| x * x).sum();
169
170 if sum_squared <= A::epsilon() {
172 return Array::zeros(activations.raw_dim());
173 }
174
175 let norm = sum_squared.sqrt();
176 activations.mapv(|x| self.lambda * x / norm)
177 }
178 ActivityNorm::L2Squared => {
179 let two = A::one() + A::one();
181 activations.mapv(|x| self.lambda * two * x)
182 }
183 }
184 }
185}
186
187impl<A, D> Regularizer<A, D> for ActivityRegularization<A>
188where
189 A: Float + ScalarOperand + Debug + FromPrimitive + Send + Sync,
190 D: Dimension,
191{
192 fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
193 let penalty = self.calculate_penalty(params);
195
196 let activity_grads = self.calculate_gradients(params);
198 gradients.zip_mut_with(&activity_grads, |g, &a| *g = *g + a);
199
200 Ok(penalty)
201 }
202
203 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
204 Ok(self.calculate_penalty(params))
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Regularization strength shared by the numerical tests.
    const LAMBDA: f64 = 0.1;
    /// Absolute tolerance used for every floating-point comparison.
    const TOL: f64 = 1e-10;

    /// Runs `apply` against a zero-initialized gradient buffer and returns the
    /// penalty together with the resulting gradients.
    fn apply_to_zero_grads(
        ar: &ActivityRegularization<f64>,
        activations: &Array1<f64>,
    ) -> (f64, Array1<f64>) {
        let mut gradients = Array1::zeros(activations.len());
        let penalty = ar
            .apply(activations, &mut gradients)
            .expect("unwrap failed");
        (penalty, gradients)
    }

    /// Asserts that every gradient entry matches the expected value within `TOL`.
    fn assert_gradients(actual: &Array1<f64>, expected: &[f64]) {
        assert_eq!(actual.len(), expected.len());
        for (idx, want) in expected.iter().enumerate() {
            assert_abs_diff_eq!(actual[idx], *want, epsilon = TOL);
        }
    }

    #[test]
    fn test_activity_regularization_creation() {
        let cases = [
            (ActivityRegularization::l1(0.1f64), 0.1, ActivityNorm::L1),
            (ActivityRegularization::l2(0.2f64), 0.2, ActivityNorm::L2),
            (
                ActivityRegularization::l2_squared(0.3f64),
                0.3,
                ActivityNorm::L2Squared,
            ),
            (
                ActivityRegularization::new(0.4f64, ActivityNorm::L1),
                0.4,
                ActivityNorm::L1,
            ),
        ];

        for (ar, expected_lambda, expected_norm) in cases {
            assert_eq!(ar.lambda, expected_lambda);
            assert_eq!(ar.norm, expected_norm);
        }
    }

    #[test]
    fn test_l1_penalty() {
        let ar = ActivityRegularization::l1(LAMBDA);
        let activations = array![1.0f64, -2.0, 3.0];

        let penalty = ar.penalty(&activations).expect("unwrap failed");

        // |1| + |-2| + |3| = 6
        assert_abs_diff_eq!(penalty, LAMBDA * 6.0, epsilon = TOL);
    }

    #[test]
    fn test_l2_penalty() {
        let ar = ActivityRegularization::l2(LAMBDA);
        let activations = array![3.0f64, 4.0];

        let penalty = ar.penalty(&activations).expect("unwrap failed");

        // sqrt(3^2 + 4^2) = 5
        assert_abs_diff_eq!(penalty, LAMBDA * 5.0, epsilon = TOL);
    }

    #[test]
    fn test_l2_squared_penalty() {
        let ar = ActivityRegularization::l2_squared(LAMBDA);
        let activations = array![1.0f64, 2.0, 3.0];

        let penalty = ar.penalty(&activations).expect("unwrap failed");

        // 1^2 + 2^2 + 3^2 = 14
        assert_abs_diff_eq!(penalty, LAMBDA * 14.0, epsilon = TOL);
    }

    #[test]
    fn test_l1_gradients() {
        let ar = ActivityRegularization::l1(LAMBDA);
        let activations = array![1.0f64, -2.0, 0.0];

        let (penalty, gradients) = apply_to_zero_grads(&ar, &activations);

        // Sign of each activation scaled by lambda; zero stays zero.
        assert_gradients(&gradients, &[LAMBDA, -LAMBDA, 0.0]);
        assert_abs_diff_eq!(penalty, LAMBDA * 3.0, epsilon = TOL);
    }

    #[test]
    fn test_l2_gradients() {
        let ar = ActivityRegularization::l2(LAMBDA);
        let activations = array![3.0f64, 4.0];

        let (penalty, gradients) = apply_to_zero_grads(&ar, &activations);

        // lambda * x / ||x|| with ||x|| = 5
        assert_gradients(&gradients, &[LAMBDA * 3.0 / 5.0, LAMBDA * 4.0 / 5.0]);
        assert_abs_diff_eq!(penalty, LAMBDA * 5.0, epsilon = TOL);
    }

    #[test]
    fn test_l2_gradients_zero_activations() {
        let ar = ActivityRegularization::l2(LAMBDA);
        let activations = array![0.0f64, 0.0];

        let (penalty, gradients) = apply_to_zero_grads(&ar, &activations);

        // A zero vector must yield zero gradients rather than dividing by zero.
        assert_gradients(&gradients, &[0.0, 0.0]);
        assert_abs_diff_eq!(penalty, 0.0, epsilon = TOL);
    }

    #[test]
    fn test_l2_squared_gradients() {
        let ar = ActivityRegularization::l2_squared(LAMBDA);
        let activations = array![2.0f64, 3.0];

        let (penalty, gradients) = apply_to_zero_grads(&ar, &activations);

        // 2 * lambda * x, and 2^2 + 3^2 = 13 for the penalty
        assert_gradients(&gradients, &[LAMBDA * 2.0 * 2.0, LAMBDA * 2.0 * 3.0]);
        assert_abs_diff_eq!(penalty, LAMBDA * 13.0, epsilon = TOL);
    }

    #[test]
    fn test_2d_activations() {
        let ar = ActivityRegularization::l1(LAMBDA);
        let activations =
            Array2::from_shape_vec((2, 2), vec![1.0f64, 2.0, -3.0, 4.0]).expect("unwrap failed");

        let penalty = ar.penalty(&activations).expect("unwrap failed");

        // |1| + |2| + |-3| + |4| = 10
        assert_abs_diff_eq!(penalty, LAMBDA * 10.0, epsilon = TOL);
    }

    #[test]
    fn test_regularizer_trait() {
        let ar = ActivityRegularization::l1(LAMBDA);
        let activations = array![1.0f64, 2.0, 3.0];

        let penalty_only = ar.penalty(&activations).expect("unwrap failed");
        let (penalty_from_apply, gradients) = apply_to_zero_grads(&ar, &activations);

        // Both trait entry points must report the same penalty.
        assert_abs_diff_eq!(penalty_only, penalty_from_apply, epsilon = TOL);
        assert_gradients(&gradients, &[LAMBDA, LAMBDA, LAMBDA]);
    }
}