1use crate::modular_framework::Regularization;
8use scirs2_core::ndarray::Array1;
9use sklears_core::{
10 error::{Result, SklearsError},
11 types::Float,
12};
13
/// Ridge (L2) regularization with penalty `0.5 * alpha * ||w||^2`.
#[derive(Debug, Clone)]
pub struct L2Regularization {
    /// Regularization strength; validated non-negative in [`L2Regularization::new`].
    pub alpha: Float,
}
20
21impl L2Regularization {
22 pub fn new(alpha: Float) -> Result<Self> {
24 if alpha < 0.0 {
25 return Err(SklearsError::InvalidParameter {
26 name: "alpha".to_string(),
27 reason: format!(
28 "Regularization strength must be non-negative, got {}",
29 alpha
30 ),
31 });
32 }
33 Ok(Self { alpha })
34 }
35}
36
37impl Regularization for L2Regularization {
38 fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
39 let norm_squared = coefficients.mapv(|x| x * x).sum();
40 Ok(0.5 * self.alpha * norm_squared)
41 }
42
43 fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
44 Ok(self.alpha * coefficients)
45 }
46
47 fn proximal_operator(
48 &self,
49 coefficients: &Array1<Float>,
50 step_size: Float,
51 ) -> Result<Array1<Float>> {
52 let shrinkage_factor = 1.0 / (1.0 + self.alpha * step_size);
54 Ok(coefficients * shrinkage_factor)
55 }
56
57 fn strength(&self) -> Float {
58 self.alpha
59 }
60
61 fn name(&self) -> &'static str {
62 "L2Regularization"
63 }
64}
65
/// Lasso (L1) regularization with penalty `alpha * ||w||_1`.
#[derive(Debug, Clone)]
pub struct L1Regularization {
    /// Regularization strength; validated non-negative in [`L1Regularization::new`].
    pub alpha: Float,
}
72
73impl L1Regularization {
74 pub fn new(alpha: Float) -> Result<Self> {
76 if alpha < 0.0 {
77 return Err(SklearsError::InvalidParameter {
78 name: "alpha".to_string(),
79 reason: format!(
80 "Regularization strength must be non-negative, got {}",
81 alpha
82 ),
83 });
84 }
85 Ok(Self { alpha })
86 }
87}
88
89impl Regularization for L1Regularization {
90 fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
91 let l1_norm = coefficients.mapv(|x| x.abs()).sum();
92 Ok(self.alpha * l1_norm)
93 }
94
95 fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
96 let subgradient = coefficients.mapv(|x| {
98 if x > 0.0 {
99 self.alpha
100 } else if x < 0.0 {
101 -self.alpha
102 } else {
103 0.0 }
105 });
106 Ok(subgradient)
107 }
108
109 fn proximal_operator(
110 &self,
111 coefficients: &Array1<Float>,
112 step_size: Float,
113 ) -> Result<Array1<Float>> {
114 let threshold = self.alpha * step_size;
116 let result = coefficients.mapv(|x| {
117 if x > threshold {
118 x - threshold
119 } else if x < -threshold {
120 x + threshold
121 } else {
122 0.0
123 }
124 });
125 Ok(result)
126 }
127
128 fn is_non_smooth(&self) -> bool {
129 true
130 }
131
132 fn strength(&self) -> Float {
133 self.alpha
134 }
135
136 fn name(&self) -> &'static str {
137 "L1Regularization"
138 }
139}
140
/// Elastic-net regularization: a convex mix of L1 and L2 penalties,
/// `alpha * (l1_ratio * ||w||_1 + 0.5 * (1 - l1_ratio) * ||w||^2)`.
#[derive(Debug, Clone)]
pub struct ElasticNetRegularization {
    /// Overall regularization strength; validated non-negative in `new`.
    pub alpha: Float,
    /// Mixing parameter in [0, 1]; 1 is pure L1 (lasso), 0 is pure L2 (ridge).
    pub l1_ratio: Float,
}
149
150impl ElasticNetRegularization {
151 pub fn new(alpha: Float, l1_ratio: Float) -> Result<Self> {
153 if alpha < 0.0 {
154 return Err(SklearsError::InvalidParameter {
155 name: "alpha".to_string(),
156 reason: format!(
157 "Regularization strength must be non-negative, got {}",
158 alpha
159 ),
160 });
161 }
162 if !(0.0..=1.0).contains(&l1_ratio) {
163 return Err(SklearsError::InvalidParameter {
164 name: "l1_ratio".to_string(),
165 reason: format!("L1 ratio must be between 0 and 1, got {}", l1_ratio),
166 });
167 }
168 Ok(Self { alpha, l1_ratio })
169 }
170
171 pub fn l1_strength(&self) -> Float {
173 self.alpha * self.l1_ratio
174 }
175
176 pub fn l2_strength(&self) -> Float {
178 self.alpha * (1.0 - self.l1_ratio)
179 }
180}
181
182impl Regularization for ElasticNetRegularization {
183 fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
184 let l1_norm = coefficients.mapv(|x| x.abs()).sum();
185 let l2_norm_squared = coefficients.mapv(|x| x * x).sum();
186
187 let l1_penalty = self.l1_strength() * l1_norm;
188 let l2_penalty = 0.5 * self.l2_strength() * l2_norm_squared;
189
190 Ok(l1_penalty + l2_penalty)
191 }
192
193 fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
194 let l1_strength = self.l1_strength();
195 let l2_strength = self.l2_strength();
196
197 let gradient = coefficients.mapv(|x| {
198 let l1_subgrad = if x > 0.0 {
199 l1_strength
200 } else if x < 0.0 {
201 -l1_strength
202 } else {
203 0.0
204 };
205 let l2_grad = l2_strength * x;
206 l1_subgrad + l2_grad
207 });
208
209 Ok(gradient)
210 }
211
212 fn proximal_operator(
213 &self,
214 coefficients: &Array1<Float>,
215 step_size: Float,
216 ) -> Result<Array1<Float>> {
217 let l1_strength = self.l1_strength();
218 let l2_strength = self.l2_strength();
219
220 let threshold = l1_strength * step_size;
222 let shrinkage_factor = 1.0 / (1.0 + l2_strength * step_size);
223
224 let result = coefficients.mapv(|x| {
225 let soft_thresholded = if x > threshold {
226 x - threshold
227 } else if x < -threshold {
228 x + threshold
229 } else {
230 0.0
231 };
232 soft_thresholded * shrinkage_factor
233 });
234
235 Ok(result)
236 }
237
238 fn is_non_smooth(&self) -> bool {
239 self.l1_ratio > 0.0
240 }
241
242 fn strength(&self) -> Float {
243 self.alpha
244 }
245
246 fn name(&self) -> &'static str {
247 "ElasticNetRegularization"
248 }
249}
250
/// Group lasso regularization: penalty `alpha * sum_g ||w_g||_2`, which
/// drives whole groups of coefficients to zero together.
#[derive(Debug, Clone)]
pub struct GroupLassoRegularization {
    /// Regularization strength; validated non-negative in `new`.
    pub alpha: Float,
    /// Group id for each coefficient: `groups[i]` is the group of
    /// coefficient `i` (length must match the coefficient vector).
    pub groups: Vec<usize>,
}
259
260impl GroupLassoRegularization {
261 pub fn new(alpha: Float, groups: Vec<usize>) -> Result<Self> {
263 if alpha < 0.0 {
264 return Err(SklearsError::InvalidParameter {
265 name: "alpha".to_string(),
266 reason: format!(
267 "Regularization strength must be non-negative, got {}",
268 alpha
269 ),
270 });
271 }
272 Ok(Self { alpha, groups })
273 }
274}
275
276impl Regularization for GroupLassoRegularization {
277 fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
278 if coefficients.len() != self.groups.len() {
279 return Err(SklearsError::DimensionMismatch {
280 expected: self.groups.len(),
281 actual: coefficients.len(),
282 });
283 }
284
285 let max_group = *self.groups.iter().max().unwrap_or(&0);
287 let mut group_norms = vec![0.0; max_group + 1];
288
289 for (i, &group_id) in self.groups.iter().enumerate() {
290 group_norms[group_id] += coefficients[i] * coefficients[i];
291 }
292
293 let penalty = group_norms
294 .iter()
295 .map(|&norm_sq| norm_sq.sqrt())
296 .sum::<Float>();
297 Ok(self.alpha * penalty)
298 }
299
300 fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
301 if coefficients.len() != self.groups.len() {
302 return Err(SklearsError::DimensionMismatch {
303 expected: self.groups.len(),
304 actual: coefficients.len(),
305 });
306 }
307
308 let max_group = *self.groups.iter().max().unwrap_or(&0);
309 let mut group_norms = vec![0.0; max_group + 1];
310
311 for (i, &group_id) in self.groups.iter().enumerate() {
313 group_norms[group_id] += coefficients[i] * coefficients[i];
314 }
315
316 for norm_sq in &mut group_norms {
318 *norm_sq = norm_sq.sqrt();
319 }
320
321 let mut gradient = Array1::zeros(coefficients.len());
323 for (i, &group_id) in self.groups.iter().enumerate() {
324 if group_norms[group_id] > 0.0 {
325 gradient[i] = self.alpha * coefficients[i] / group_norms[group_id];
326 } else {
327 gradient[i] = 0.0; }
329 }
330
331 Ok(gradient)
332 }
333
334 fn proximal_operator(
335 &self,
336 coefficients: &Array1<Float>,
337 step_size: Float,
338 ) -> Result<Array1<Float>> {
339 if coefficients.len() != self.groups.len() {
340 return Err(SklearsError::DimensionMismatch {
341 expected: self.groups.len(),
342 actual: coefficients.len(),
343 });
344 }
345
346 let max_group = *self.groups.iter().max().unwrap_or(&0);
347 let mut group_norms = vec![0.0; max_group + 1];
348
349 for (i, &group_id) in self.groups.iter().enumerate() {
351 group_norms[group_id] += coefficients[i] * coefficients[i];
352 }
353
354 for norm_sq in &mut group_norms {
355 *norm_sq = norm_sq.sqrt();
356 }
357
358 let threshold = self.alpha * step_size;
360 let mut result = coefficients.clone();
361
362 for (i, &group_id) in self.groups.iter().enumerate() {
363 let group_norm = group_norms[group_id];
364 if group_norm > threshold {
365 let shrinkage_factor = (group_norm - threshold) / group_norm;
366 result[i] *= shrinkage_factor;
367 } else {
368 result[i] = 0.0;
369 }
370 }
371
372 Ok(result)
373 }
374
375 fn is_non_smooth(&self) -> bool {
376 true
377 }
378
379 fn strength(&self) -> Float {
380 self.alpha
381 }
382
383 fn name(&self) -> &'static str {
384 "GroupLassoRegularization"
385 }
386}
387
/// Weighted sum of several regularization terms.
#[derive(Debug)]
pub struct CompositeRegularization {
    // (weight, regularizer) pairs; each term contributes
    // `weight * penalty` to the total.
    regularizations: Vec<(Float, Box<dyn Regularization>)>,
}
394
impl Default for CompositeRegularization {
    /// Equivalent to [`CompositeRegularization::new`]: an empty composite
    /// with no regularization terms.
    fn default() -> Self {
        Self::new()
    }
}
400
401impl CompositeRegularization {
402 pub fn new() -> Self {
404 Self {
405 regularizations: Vec::new(),
406 }
407 }
408
409 pub fn add_regularization(
411 mut self,
412 weight: Float,
413 regularization: Box<dyn Regularization>,
414 ) -> Self {
415 self.regularizations.push((weight, regularization));
416 self
417 }
418
419 pub fn add_l1(self, alpha: Float) -> Result<Self> {
421 Ok(self.add_regularization(1.0, Box::new(L1Regularization::new(alpha)?)))
422 }
423
424 pub fn add_l2(self, alpha: Float) -> Result<Self> {
426 Ok(self.add_regularization(1.0, Box::new(L2Regularization::new(alpha)?)))
427 }
428
429 pub fn add_group_lasso(self, alpha: Float, groups: Vec<usize>) -> Result<Self> {
431 Ok(self.add_regularization(1.0, Box::new(GroupLassoRegularization::new(alpha, groups)?)))
432 }
433
434 pub fn is_any_non_smooth(&self) -> bool {
436 self.regularizations
437 .iter()
438 .any(|(_, reg)| reg.is_non_smooth())
439 }
440}
441
442impl Regularization for CompositeRegularization {
443 fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
444 let mut total_penalty = 0.0;
445 for (weight, regularization) in &self.regularizations {
446 total_penalty += weight * regularization.penalty(coefficients)?;
447 }
448 Ok(total_penalty)
449 }
450
451 fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
452 let mut total_gradient = Array1::zeros(coefficients.len());
453 for (weight, regularization) in &self.regularizations {
454 let grad = regularization.penalty_gradient(coefficients)?;
455 total_gradient = total_gradient + *weight * grad;
456 }
457 Ok(total_gradient)
458 }
459
460 fn proximal_operator(
461 &self,
462 coefficients: &Array1<Float>,
463 step_size: Float,
464 ) -> Result<Array1<Float>> {
465 let mut result = coefficients.clone();
468 for (weight, regularization) in &self.regularizations {
469 result = regularization.proximal_operator(&result, weight * step_size)?;
470 }
471 Ok(result)
472 }
473
474 fn is_non_smooth(&self) -> bool {
475 self.is_any_non_smooth()
476 }
477
478 fn strength(&self) -> Float {
479 self.regularizations
481 .iter()
482 .map(|(weight, reg)| weight * reg.strength())
483 .sum()
484 }
485
486 fn name(&self) -> &'static str {
487 "CompositeRegularization"
488 }
489}
490
/// Convenience constructors producing boxed `dyn Regularization` values.
pub struct RegularizationFactory;
493
494impl RegularizationFactory {
495 pub fn l1(alpha: Float) -> Result<Box<dyn Regularization>> {
497 Ok(Box::new(L1Regularization::new(alpha)?))
498 }
499
500 pub fn l2(alpha: Float) -> Result<Box<dyn Regularization>> {
502 Ok(Box::new(L2Regularization::new(alpha)?))
503 }
504
505 pub fn elastic_net(alpha: Float, l1_ratio: Float) -> Result<Box<dyn Regularization>> {
507 Ok(Box::new(ElasticNetRegularization::new(alpha, l1_ratio)?))
508 }
509
510 pub fn group_lasso(alpha: Float, groups: Vec<usize>) -> Result<Box<dyn Regularization>> {
512 Ok(Box::new(GroupLassoRegularization::new(alpha, groups)?))
513 }
514
515 pub fn composite() -> CompositeRegularization {
517 CompositeRegularization::new()
518 }
519}
520
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_l2_regularization() {
        let reg = L2Regularization::new(0.5).unwrap();
        let coefs = Array::from_vec(vec![1.0, -2.0, 3.0]);

        // Penalty: 0.5 * alpha * (1 + 4 + 9).
        let penalty = reg.penalty(&coefs).unwrap();
        let expected = 0.5 * 0.5 * (1.0 + 4.0 + 9.0);
        assert!((penalty - expected).abs() < 1e-10);

        // Gradient: alpha * w.
        let gradient = reg.penalty_gradient(&coefs).unwrap();
        let expected_grad = Array::from_vec(vec![0.5, -1.0, 1.5]);
        for (got, want) in gradient.iter().zip(expected_grad.iter()) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_l1_regularization() {
        let reg = L1Regularization::new(0.3).unwrap();
        let coefs = Array::from_vec(vec![1.0, -2.0, 3.0]);

        // Penalty: alpha * (|1| + |-2| + |3|).
        let penalty = reg.penalty(&coefs).unwrap();
        let expected = 0.3 * (1.0 + 2.0 + 3.0);
        assert!((penalty - expected).abs() < 1e-10);

        assert!(reg.is_non_smooth());
    }

    #[test]
    fn test_l1_proximal_operator() {
        let reg = L1Regularization::new(1.0).unwrap();
        let coefs = Array::from_vec(vec![2.0, -1.0, 0.5]);

        // Soft-thresholding with threshold alpha * step = 1.0.
        let shrunk = reg.proximal_operator(&coefs, 1.0).unwrap();
        let expected = Array::from_vec(vec![1.0, 0.0, 0.0]);
        for (got, want) in shrunk.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_elastic_net_regularization() {
        let reg = ElasticNetRegularization::new(1.0, 0.7).unwrap();
        let coefs = Array::from_vec(vec![1.0, -1.0]);

        // Sum of the L1 part (ratio 0.7) and the L2 part (ratio 0.3).
        let penalty = reg.penalty(&coefs).unwrap();
        let l1_part = 0.7 * (1.0 + 1.0);
        let l2_part = 0.5 * 0.3 * (1.0 + 1.0);
        assert!((penalty - (l1_part + l2_part)).abs() < 1e-10);

        assert!(reg.is_non_smooth());
    }

    #[test]
    fn test_group_lasso_regularization() {
        // Two groups of two coefficients each.
        let groups = vec![0, 0, 1, 1];
        let reg = GroupLassoRegularization::new(1.0, groups).unwrap();
        let coefs = Array::from_vec(vec![3.0, 4.0, 0.0, 0.0]);

        // Group 0 norm is 5 (3-4-5 triangle); group 1 is all zeros.
        let penalty = reg.penalty(&coefs).unwrap();
        assert!((penalty - 5.0).abs() < 1e-10);

        assert!(reg.is_non_smooth());
    }

    #[test]
    fn test_composite_regularization() {
        let composite = CompositeRegularization::new()
            .add_l1(0.1)
            .unwrap()
            .add_l2(0.2)
            .unwrap();

        let coefs = Array::from_vec(vec![1.0, -2.0]);

        let penalty = composite.penalty(&coefs).unwrap();
        let l1_part = 0.1 * (1.0 + 2.0);
        let l2_part = 0.5 * 0.2 * (1.0 + 4.0);
        assert!((penalty - (l1_part + l2_part)).abs() < 1e-10);

        // The L1 term makes the composite non-smooth.
        assert!(composite.is_non_smooth());
    }

    #[test]
    fn test_regularization_factory() {
        let l1 = RegularizationFactory::l1(0.5).unwrap();
        assert_eq!(l1.name(), "L1Regularization");

        let l2 = RegularizationFactory::l2(0.3).unwrap();
        assert_eq!(l2.name(), "L2Regularization");

        let elastic_net = RegularizationFactory::elastic_net(1.0, 0.8).unwrap();
        assert_eq!(elastic_net.name(), "ElasticNetRegularization");
    }

    #[test]
    fn test_invalid_parameters() {
        // Negative strengths are rejected everywhere.
        assert!(L1Regularization::new(-1.0).is_err());
        assert!(L2Regularization::new(-0.1).is_err());

        // l1_ratio must stay within [0, 1].
        assert!(ElasticNetRegularization::new(1.0, -0.1).is_err());
        assert!(ElasticNetRegularization::new(1.0, 1.5).is_err());
    }
}