1use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
8use scirs2_core::numeric::Float;
9use std::fmt::Debug;
10
11use crate::error::{OptimError, Result};
12use crate::regularizers::Regularizer;
13
/// Group Lasso regularizer.
///
/// Penalizes the sum of (optionally weighted) L2 norms of parameter groups,
/// driving whole groups of parameters to zero together instead of individual
/// entries as plain L1 would.
#[derive(Debug, Clone)]
pub struct GroupLasso<A: Float + ScalarOperand + Debug> {
    /// Regularization strength multiplying the total penalty.
    lambda: A,
    /// Flat (memory-order) parameter indices belonging to each group.
    groups: Vec<Vec<usize>>,
    /// Optional per-group weights; a group without a weight defaults to 1.
    group_weights: Option<Vec<A>>,
}
47
48impl<A: Float + ScalarOperand + Debug> GroupLasso<A> {
49 pub fn new(lambda: A) -> Self {
55 Self {
56 lambda,
57 groups: Vec::new(),
58 group_weights: None,
59 }
60 }
61
62 pub fn with_groups(mut self, groups: Vec<Vec<usize>>) -> Self {
69 self.groups = groups;
70 self
71 }
72
73 pub fn with_group_weights(mut self, weights: Vec<A>) -> Self {
80 self.group_weights = Some(weights);
81 self
82 }
83
84 pub fn auto_groups(mut self, param_size: usize, group_size: usize) -> Self {
94 let mut groups = Vec::new();
95 let mut start = 0;
96 while start < param_size {
97 let end = (start + group_size).min(param_size);
98 groups.push((start..end).collect());
99 start = end;
100 }
101 self.groups = groups;
102 self
103 }
104
105 pub fn lambda(&self) -> A {
107 self.lambda
108 }
109
110 pub fn groups(&self) -> &[Vec<usize>] {
112 &self.groups
113 }
114
115 pub fn num_groups(&self) -> usize {
117 self.groups.len()
118 }
119
120 fn group_weight(&self, group_idx: usize) -> A {
122 self.group_weights
123 .as_ref()
124 .and_then(|w| w.get(group_idx).copied())
125 .unwrap_or_else(A::one)
126 }
127
128 fn group_l2_norm(&self, params: &Array<A, impl Dimension>, indices: &[usize]) -> A {
132 let flat = params.as_slice_memory_order();
133 let sum_sq = indices.iter().fold(A::zero(), |acc, &idx| {
134 if let Some(slice) = flat {
135 if idx < slice.len() {
136 acc + slice[idx] * slice[idx]
137 } else {
138 acc
139 }
140 } else {
141 let mut iter = params.iter();
143 if let Some(&val) = iter.nth(idx) {
144 acc + val * val
145 } else {
146 acc
147 }
148 }
149 });
150 sum_sq.sqrt()
151 }
152
153 fn validate_groups(&self, param_len: usize) -> Result<()> {
155 for (g_idx, group) in self.groups.iter().enumerate() {
156 for &idx in group {
157 if idx >= param_len {
158 return Err(OptimError::InvalidParameter(format!(
159 "Group {} contains index {} which exceeds parameter size {}",
160 g_idx, idx, param_len
161 )));
162 }
163 }
164 }
165 if let Some(ref weights) = self.group_weights {
166 if weights.len() != self.groups.len() {
167 return Err(OptimError::InvalidConfig(format!(
168 "Number of group weights ({}) does not match number of groups ({})",
169 weights.len(),
170 self.groups.len()
171 )));
172 }
173 }
174 Ok(())
175 }
176}
177
178impl<A, D> Regularizer<A, D> for GroupLasso<A>
179where
180 A: Float + ScalarOperand + Debug,
181 D: Dimension,
182{
183 fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
184 let param_len = params.len();
185 self.validate_groups(param_len)?;
186
187 let epsilon = A::from(1e-8).unwrap_or_else(|| A::epsilon());
188
189 let grad_slice = gradients.as_slice_memory_order_mut().ok_or_else(|| {
191 OptimError::InvalidParameter("Gradients array is not contiguous in memory".to_string())
192 })?;
193
194 let param_slice = params.as_slice_memory_order().ok_or_else(|| {
195 OptimError::InvalidParameter("Parameters array is not contiguous in memory".to_string())
196 })?;
197
198 for (g_idx, group) in self.groups.iter().enumerate() {
199 let w_g = self.group_weight(g_idx);
200
201 let sum_sq = group.iter().fold(A::zero(), |acc, &idx| {
203 if idx < param_len {
204 acc + param_slice[idx] * param_slice[idx]
205 } else {
206 acc
207 }
208 });
209 let norm = sum_sq.sqrt();
210
211 let scale = self.lambda * w_g / (norm + epsilon);
213
214 for &idx in group {
215 if idx < param_len {
216 grad_slice[idx] = grad_slice[idx] + scale * param_slice[idx];
217 }
218 }
219 }
220
221 self.penalty(params)
222 }
223
224 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
225 let param_len = params.len();
226 self.validate_groups(param_len)?;
227
228 let mut total = A::zero();
229
230 for (g_idx, group) in self.groups.iter().enumerate() {
231 let w_g = self.group_weight(g_idx);
232 let norm = self.group_l2_norm(params, group);
233 total = total + w_g * norm;
234 }
235
236 Ok(self.lambda * total)
237 }
238}
239
/// Sparsity structure imposed on a flat parameter vector, which is
/// interpreted as a row-major matrix.
#[derive(Debug, Clone)]
pub enum SparsityPattern {
    /// Treat each matrix column as one group.
    Column {
        /// Number of columns in the implied matrix; must divide the total
        /// parameter count evenly.
        num_columns: usize,
    },
    /// Treat each matrix row as one group.
    Row {
        /// Number of rows in the implied matrix; must divide the total
        /// parameter count evenly.
        num_rows: usize,
    },
    /// Treat each `block_height` x `block_width` tile as one group. The
    /// matrix shape is inferred automatically from the parameter count and
    /// block dimensions.
    Block {
        /// Height of each block in rows.
        block_height: usize,
        /// Width of each block in columns.
        block_width: usize,
    },
}
271
/// Structured sparsity regularizer: a group lasso whose groups are derived
/// from a [`SparsityPattern`] (columns, rows, or blocks) instead of being
/// listed explicitly.
#[derive(Debug, Clone)]
pub struct StructuredSparsity<A: Float + ScalarOperand + Debug> {
    /// Regularization strength multiplying the total penalty.
    lambda: A,
    /// Pattern from which the index groups are built at apply/penalty time.
    pattern: SparsityPattern,
}
298
299impl<A: Float + ScalarOperand + Debug> StructuredSparsity<A> {
300 pub fn new(lambda: A, pattern: SparsityPattern) -> Self {
307 Self { lambda, pattern }
308 }
309
310 pub fn lambda(&self) -> A {
312 self.lambda
313 }
314
315 pub fn pattern(&self) -> &SparsityPattern {
317 &self.pattern
318 }
319
320 fn build_groups(&self, total_params: usize) -> Result<Vec<Vec<usize>>> {
327 match &self.pattern {
328 SparsityPattern::Column { num_columns } => {
329 if *num_columns == 0 {
330 return Err(OptimError::InvalidConfig(
331 "Number of columns must be greater than 0".to_string(),
332 ));
333 }
334 let num_rows = total_params / num_columns;
335 if num_rows * num_columns != total_params {
336 return Err(OptimError::InvalidConfig(format!(
337 "Total parameters ({}) is not evenly divisible by num_columns ({})",
338 total_params, num_columns
339 )));
340 }
341
342 let mut groups = Vec::with_capacity(*num_columns);
343 for col in 0..*num_columns {
344 let group: Vec<usize> =
345 (0..num_rows).map(|row| row * num_columns + col).collect();
346 groups.push(group);
347 }
348 Ok(groups)
349 }
350 SparsityPattern::Row { num_rows } => {
351 if *num_rows == 0 {
352 return Err(OptimError::InvalidConfig(
353 "Number of rows must be greater than 0".to_string(),
354 ));
355 }
356 let num_columns = total_params / num_rows;
357 if num_rows * num_columns != total_params {
358 return Err(OptimError::InvalidConfig(format!(
359 "Total parameters ({}) is not evenly divisible by num_rows ({})",
360 total_params, num_rows
361 )));
362 }
363
364 let mut groups = Vec::with_capacity(*num_rows);
365 for row in 0..*num_rows {
366 let start = row * num_columns;
367 let group: Vec<usize> = (start..start + num_columns).collect();
368 groups.push(group);
369 }
370 Ok(groups)
371 }
372 SparsityPattern::Block {
373 block_height,
374 block_width,
375 } => {
376 if *block_height == 0 || *block_width == 0 {
377 return Err(OptimError::InvalidConfig(
378 "Block dimensions must be greater than 0".to_string(),
379 ));
380 }
381
382 let num_cols =
389 self.infer_matrix_columns(total_params, *block_height, *block_width)?;
390 let num_rows = total_params / num_cols;
391
392 let blocks_per_row = num_cols / block_width;
393 let blocks_per_col = num_rows / block_height;
394
395 let mut groups = Vec::with_capacity(blocks_per_row * blocks_per_col);
396 for block_row in 0..blocks_per_col {
397 for block_col in 0..blocks_per_row {
398 let mut group = Vec::with_capacity(block_height * block_width);
399 for r in 0..*block_height {
400 for c in 0..*block_width {
401 let row = block_row * block_height + r;
402 let col = block_col * block_width + c;
403 group.push(row * num_cols + col);
404 }
405 }
406 groups.push(group);
407 }
408 }
409 Ok(groups)
410 }
411 }
412 }
413
414 fn infer_matrix_columns(
419 &self,
420 total_params: usize,
421 block_height: usize,
422 block_width: usize,
423 ) -> Result<usize> {
424 let target = (total_params as f64).sqrt();
425 let mut best_candidate: Option<usize> = None;
426 let mut best_distance = f64::MAX;
427
428 let mut candidate = block_width;
429 while candidate <= total_params {
430 if total_params.is_multiple_of(candidate) {
431 let rows = total_params / candidate;
432 if rows.is_multiple_of(block_height) {
433 let distance = (candidate as f64 - target).abs();
434 if distance < best_distance {
435 best_distance = distance;
436 best_candidate = Some(candidate);
437 }
438 }
439 }
440 candidate += block_width;
441 }
442
443 best_candidate.ok_or_else(|| {
444 OptimError::InvalidConfig(format!(
445 "Cannot decompose {} parameters into blocks of {}x{}",
446 total_params, block_height, block_width
447 ))
448 })
449 }
450}
451
452impl<A, D> Regularizer<A, D> for StructuredSparsity<A>
453where
454 A: Float + ScalarOperand + Debug,
455 D: Dimension,
456{
457 fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
458 let total_params = params.len();
459 let groups = self.build_groups(total_params)?;
460
461 let epsilon = A::from(1e-8).unwrap_or_else(|| A::epsilon());
462
463 let grad_slice = gradients.as_slice_memory_order_mut().ok_or_else(|| {
464 OptimError::InvalidParameter("Gradients array is not contiguous in memory".to_string())
465 })?;
466
467 let param_slice = params.as_slice_memory_order().ok_or_else(|| {
468 OptimError::InvalidParameter("Parameters array is not contiguous in memory".to_string())
469 })?;
470
471 for group in &groups {
472 let sum_sq = group.iter().fold(A::zero(), |acc, &idx| {
474 if idx < total_params {
475 acc + param_slice[idx] * param_slice[idx]
476 } else {
477 acc
478 }
479 });
480 let norm = sum_sq.sqrt();
481
482 let scale = self.lambda / (norm + epsilon);
483
484 for &idx in group {
485 if idx < total_params {
486 grad_slice[idx] = grad_slice[idx] + scale * param_slice[idx];
487 }
488 }
489 }
490
491 self.penalty(params)
492 }
493
494 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
495 let total_params = params.len();
496 let groups = self.build_groups(total_params)?;
497
498 let param_slice = params.as_slice_memory_order().ok_or_else(|| {
499 OptimError::InvalidParameter("Parameters array is not contiguous in memory".to_string())
500 })?;
501
502 let mut total = A::zero();
503
504 for group in &groups {
505 let sum_sq = group.iter().fold(A::zero(), |acc, &idx| {
506 if idx < total_params {
507 acc + param_slice[idx] * param_slice[idx]
508 } else {
509 acc
510 }
511 });
512 total = total + sum_sq.sqrt();
513 }
514
515 Ok(self.lambda * total)
516 }
517}
518
#[cfg(test)]
mod tests {
    // NOTE(review): every `&params` in this module had been HTML-entity
    // mangled to `¶ms` (`&para` -> `¶`) and several statements were fused
    // onto one line; restored to valid Rust.
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_group_lasso_basic_penalty() {
        let regularizer = GroupLasso::new(0.1_f64).with_groups(vec![vec![0, 1, 2], vec![3, 4, 5]]);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 0.0, 0.0, 0.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        // Group 0 norm = sqrt(1 + 4 + 9); group 1 is all zeros.
        let expected = 0.1 * (14.0_f64).sqrt();
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_group_lasso_with_weights() {
        let regularizer = GroupLasso::new(0.5_f64)
            .with_groups(vec![vec![0, 1], vec![2, 3]])
            .with_group_weights(vec![2.0, 0.5]);

        let params = Array1::from_vec(vec![3.0, 4.0, 1.0, 0.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        // ||[3, 4]|| = 5 weighted by 2; ||[1, 0]|| = 1 weighted by 0.5.
        let expected = 0.5 * (2.0 * 5.0 + 0.5 * 1.0);
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_group_lasso_auto_groups() {
        let regularizer = GroupLasso::new(0.1_f64).auto_groups(9, 3);

        assert_eq!(regularizer.num_groups(), 3);
        assert_eq!(regularizer.groups()[0], vec![0, 1, 2]);
        assert_eq!(regularizer.groups()[1], vec![3, 4, 5]);
        assert_eq!(regularizer.groups()[2], vec![6, 7, 8]);

        // Sizes that don't divide evenly leave a smaller trailing group.
        let regularizer2 = GroupLasso::new(0.1_f64).auto_groups(7, 3);
        assert_eq!(regularizer2.num_groups(), 3);
        assert_eq!(regularizer2.groups()[0], vec![0, 1, 2]);
        assert_eq!(regularizer2.groups()[1], vec![3, 4, 5]);
        assert_eq!(regularizer2.groups()[2], vec![6]);
    }

    #[test]
    fn test_group_lasso_gradient_application() {
        let regularizer = GroupLasso::new(1.0_f64).with_groups(vec![vec![0, 1], vec![2, 3]]);

        let params = Array1::from_vec(vec![3.0, 4.0, 0.0, 0.0]);
        let mut gradients = Array1::zeros(4);

        let penalty = regularizer
            .apply(&params, &mut gradients)
            .expect("apply failed");

        // Group 0: gradient x / (||x|| + eps) with ||x|| = 5.
        let epsilon = 1e-8_f64;
        let norm0 = 5.0_f64;
        assert_abs_diff_eq!(gradients[0], 3.0 / (norm0 + epsilon), epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[1], 4.0 / (norm0 + epsilon), epsilon = 1e-6);

        // Group 1 is all zeros, so its gradient contribution is zero.
        assert_abs_diff_eq!(gradients[2], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[3], 0.0, epsilon = 1e-6);

        assert_abs_diff_eq!(penalty, 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_sparsity_column() {
        let regularizer =
            StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 4 });

        // 3x4 row-major matrix; odd columns are all zeros.
        let params = Array1::from_vec(vec![
            1.0, 0.0, 3.0, 0.0, 5.0, 0.0, 7.0, 0.0, 9.0, 0.0, 11.0, 0.0,
        ]);

        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        // Column 0: 1,5,9 -> sqrt(107); column 2: 3,7,11 -> sqrt(179).
        let expected = 0.1 * (107.0_f64.sqrt() + 179.0_f64.sqrt());
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_sparsity_row() {
        let regularizer = StructuredSparsity::new(0.5_f64, SparsityPattern::Row { num_rows: 3 });

        // 3x2 row-major matrix; middle row is all zeros.
        let params = Array1::from_vec(vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0]);

        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        // Row 0: sqrt(1 + 4); row 2: sqrt(9 + 16) = 5.
        let expected = 0.5 * (5.0_f64.sqrt() + 5.0);
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_sparsity_block() {
        let regularizer = StructuredSparsity::new(
            0.2_f64,
            SparsityPattern::Block {
                block_height: 2,
                block_width: 2,
            },
        );

        // Inferred as a 4x4 matrix: top-left 2x2 block of ones, bottom-right
        // 2x2 block of twos, the other two blocks zero.
        let params = Array1::from_vec(vec![
            1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0, 2.0, 2.0,
        ]);

        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        // ||ones block|| = 2, ||twos block|| = 4.
        let expected = 0.2 * (2.0 + 4.0);
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_sparsity_gradient_application() {
        let regularizer = StructuredSparsity::new(1.0_f64, SparsityPattern::Row { num_rows: 2 });

        let params = Array1::from_vec(vec![3.0, 4.0, 0.0, 0.0]);
        let mut gradients = Array1::zeros(4);

        let _penalty = regularizer
            .apply(&params, &mut gradients)
            .expect("apply failed");

        // Row 0: gradient x / (||x|| + eps) with ||x|| = 5.
        let epsilon = 1e-8_f64;
        assert_abs_diff_eq!(gradients[0], 3.0 / (5.0 + epsilon), epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[1], 4.0 / (5.0 + epsilon), epsilon = 1e-6);

        // Row 1 is all zeros, so its gradient contribution is zero.
        assert_abs_diff_eq!(gradients[2], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[3], 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_group_lasso_empty_groups() {
        let regularizer = GroupLasso::<f64>::new(0.1);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        assert_abs_diff_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_group_lasso_out_of_bounds_index() {
        let regularizer = GroupLasso::new(0.1_f64).with_groups(vec![vec![0, 1, 100]]);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = regularizer.penalty(&params);

        assert!(result.is_err());
    }

    #[test]
    fn test_group_lasso_weight_mismatch() {
        let regularizer = GroupLasso::new(0.1_f64)
            .with_groups(vec![vec![0, 1], vec![2, 3]])
            .with_group_weights(vec![1.0]);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let result = regularizer.penalty(&params);

        assert!(result.is_err());
    }

    #[test]
    fn test_structured_sparsity_invalid_dimensions() {
        let regularizer =
            StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 3 });

        // 7 parameters are not divisible into 3 columns.
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let result = regularizer.penalty(&params);

        assert!(result.is_err());
    }

    #[test]
    fn test_structured_sparsity_zero_columns() {
        let regularizer =
            StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 0 });

        let params = Array1::from_vec(vec![1.0, 2.0]);
        let result = regularizer.penalty(&params);

        assert!(result.is_err());
    }

    #[test]
    fn test_group_lasso_builder_pattern() {
        let regularizer = GroupLasso::new(0.5_f64)
            .with_groups(vec![vec![0, 1], vec![2, 3]])
            .with_group_weights(vec![1.0, 2.0]);

        assert_eq!(regularizer.lambda(), 0.5);
        assert_eq!(regularizer.num_groups(), 2);
        assert_eq!(regularizer.groups()[0], vec![0, 1]);
        assert_eq!(regularizer.groups()[1], vec![2, 3]);
    }
}