Skip to main content

optirs_core/regularizers/
group_lasso.rs

1// Group Lasso regularization with structured sparsity
2//
3// This module provides Group Lasso regularization which encourages entire groups
4// of parameters to be zeroed out, enabling structured sparsity. It also provides
5// structured sparsity patterns (column, row, block) for matrix parameters.
6
7use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
8use scirs2_core::numeric::Float;
9use std::fmt::Debug;
10
11use crate::error::{OptimError, Result};
12use crate::regularizers::Regularizer;
13
/// Group Lasso regularizer for structured sparsity
///
/// The Group Lasso penalty encourages entire groups of parameters to be zero,
/// rather than individual parameters (as in standard L1/Lasso). This is useful
/// when parameters have a natural grouping structure (e.g., features belonging
/// to the same category, filters in a convolutional layer).
///
/// Penalty: `lambda * sum_g(w_g * ||params[group_g]||_2)`
///
/// where `w_g` is the weight for group `g` and `||.||_2` is the L2 norm.
/// With no groups configured, the penalty is zero.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::regularizers::{GroupLasso, Regularizer};
///
/// // Create a Group Lasso regularizer with two groups
/// let regularizer = GroupLasso::new(0.1_f64)
///     .with_groups(vec![vec![0, 1, 2], vec![3, 4, 5]]);
///
/// let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 0.0, 0.0, 0.0]);
/// let penalty = regularizer.penalty(&params).expect("penalty computation failed");
/// ```
#[derive(Debug, Clone)]
pub struct GroupLasso<A: Float + ScalarOperand + Debug> {
    /// Regularization strength
    lambda: A,
    /// Index sets per group - each inner Vec contains parameter indices for that group
    /// (indices are linear positions into the flattened parameter array)
    groups: Vec<Vec<usize>>,
    /// Per-group weights (default: all 1.0); when present, must have one entry per group
    group_weights: Option<Vec<A>>,
}
47
48impl<A: Float + ScalarOperand + Debug> GroupLasso<A> {
49    /// Create a new Group Lasso regularizer
50    ///
51    /// # Arguments
52    ///
53    /// * `lambda` - Regularization strength (must be non-negative)
54    pub fn new(lambda: A) -> Self {
55        Self {
56            lambda,
57            groups: Vec::new(),
58            group_weights: None,
59        }
60    }
61
62    /// Set the groups for the regularizer (builder pattern)
63    ///
64    /// # Arguments
65    ///
66    /// * `groups` - A vector of index sets, where each inner vector contains the
67    ///   parameter indices belonging to that group
68    pub fn with_groups(mut self, groups: Vec<Vec<usize>>) -> Self {
69        self.groups = groups;
70        self
71    }
72
73    /// Set per-group weights (builder pattern)
74    ///
75    /// # Arguments
76    ///
77    /// * `weights` - A vector of weights, one per group. Must have the same
78    ///   length as the number of groups.
79    pub fn with_group_weights(mut self, weights: Vec<A>) -> Self {
80        self.group_weights = Some(weights);
81        self
82    }
83
84    /// Automatically create equal-sized groups (builder pattern)
85    ///
86    /// Partitions parameter indices `[0, param_size)` into groups of `group_size`.
87    /// The last group may be smaller if `param_size` is not evenly divisible.
88    ///
89    /// # Arguments
90    ///
91    /// * `param_size` - Total number of parameters
92    /// * `group_size` - Number of parameters per group
93    pub fn auto_groups(mut self, param_size: usize, group_size: usize) -> Self {
94        let mut groups = Vec::new();
95        let mut start = 0;
96        while start < param_size {
97            let end = (start + group_size).min(param_size);
98            groups.push((start..end).collect());
99            start = end;
100        }
101        self.groups = groups;
102        self
103    }
104
105    /// Get the regularization strength
106    pub fn lambda(&self) -> A {
107        self.lambda
108    }
109
110    /// Get the groups
111    pub fn groups(&self) -> &[Vec<usize>] {
112        &self.groups
113    }
114
115    /// Get the number of groups
116    pub fn num_groups(&self) -> usize {
117        self.groups.len()
118    }
119
120    /// Get the weight for a specific group
121    fn group_weight(&self, group_idx: usize) -> A {
122        self.group_weights
123            .as_ref()
124            .and_then(|w| w.get(group_idx).copied())
125            .unwrap_or_else(A::one)
126    }
127
128    /// Compute the L2 norm of parameters at the given indices
129    ///
130    /// Flattens the array and accesses elements by their linear index.
131    fn group_l2_norm(&self, params: &Array<A, impl Dimension>, indices: &[usize]) -> A {
132        let flat = params.as_slice_memory_order();
133        let sum_sq = indices.iter().fold(A::zero(), |acc, &idx| {
134            if let Some(slice) = flat {
135                if idx < slice.len() {
136                    acc + slice[idx] * slice[idx]
137                } else {
138                    acc
139                }
140            } else {
141                // Fallback for non-contiguous arrays
142                let mut iter = params.iter();
143                if let Some(&val) = iter.nth(idx) {
144                    acc + val * val
145                } else {
146                    acc
147                }
148            }
149        });
150        sum_sq.sqrt()
151    }
152
153    /// Validate that group indices are within bounds for the given parameter array
154    fn validate_groups(&self, param_len: usize) -> Result<()> {
155        for (g_idx, group) in self.groups.iter().enumerate() {
156            for &idx in group {
157                if idx >= param_len {
158                    return Err(OptimError::InvalidParameter(format!(
159                        "Group {} contains index {} which exceeds parameter size {}",
160                        g_idx, idx, param_len
161                    )));
162                }
163            }
164        }
165        if let Some(ref weights) = self.group_weights {
166            if weights.len() != self.groups.len() {
167                return Err(OptimError::InvalidConfig(format!(
168                    "Number of group weights ({}) does not match number of groups ({})",
169                    weights.len(),
170                    self.groups.len()
171                )));
172            }
173        }
174        Ok(())
175    }
176}
177
178impl<A, D> Regularizer<A, D> for GroupLasso<A>
179where
180    A: Float + ScalarOperand + Debug,
181    D: Dimension,
182{
183    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
184        let param_len = params.len();
185        self.validate_groups(param_len)?;
186
187        let epsilon = A::from(1e-8).unwrap_or_else(|| A::epsilon());
188
189        // Get mutable slice for gradients
190        let grad_slice = gradients.as_slice_memory_order_mut().ok_or_else(|| {
191            OptimError::InvalidParameter("Gradients array is not contiguous in memory".to_string())
192        })?;
193
194        let param_slice = params.as_slice_memory_order().ok_or_else(|| {
195            OptimError::InvalidParameter("Parameters array is not contiguous in memory".to_string())
196        })?;
197
198        for (g_idx, group) in self.groups.iter().enumerate() {
199            let w_g = self.group_weight(g_idx);
200
201            // Compute L2 norm for this group
202            let sum_sq = group.iter().fold(A::zero(), |acc, &idx| {
203                if idx < param_len {
204                    acc + param_slice[idx] * param_slice[idx]
205                } else {
206                    acc
207                }
208            });
209            let norm = sum_sq.sqrt();
210
211            // Gradient: lambda * w_g * param[i] / (||group||_2 + epsilon)
212            let scale = self.lambda * w_g / (norm + epsilon);
213
214            for &idx in group {
215                if idx < param_len {
216                    grad_slice[idx] = grad_slice[idx] + scale * param_slice[idx];
217                }
218            }
219        }
220
221        self.penalty(params)
222    }
223
224    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
225        let param_len = params.len();
226        self.validate_groups(param_len)?;
227
228        let mut total = A::zero();
229
230        for (g_idx, group) in self.groups.iter().enumerate() {
231            let w_g = self.group_weight(g_idx);
232            let norm = self.group_l2_norm(params, group);
233            total = total + w_g * norm;
234        }
235
236        Ok(self.lambda * total)
237    }
238}
239
/// Sparsity pattern for structured sparsity regularization
///
/// Defines how parameters are grouped for structured sparsity.
/// Different patterns encourage different structural zeros in
/// the parameter matrix. Parameters are assumed to be stored in
/// row-major order when interpreted as a matrix.
#[derive(Debug, Clone)]
pub enum SparsityPattern {
    /// Column-wise sparsity: entire columns of the parameter matrix are zeroed
    ///
    /// Groups parameters by column index, encouraging entire columns to be sparse.
    Column {
        /// Number of columns in the parameter matrix (must be > 0 and divide
        /// the total parameter count)
        num_columns: usize,
    },
    /// Row-wise sparsity: entire rows of the parameter matrix are zeroed
    ///
    /// Groups parameters by row index, encouraging entire rows to be sparse.
    Row {
        /// Number of rows in the parameter matrix (must be > 0 and divide
        /// the total parameter count)
        num_rows: usize,
    },
    /// Block-wise sparsity: rectangular blocks of parameters are zeroed
    ///
    /// Groups parameters into rectangular blocks, encouraging entire blocks to be sparse.
    /// The matrix width is inferred from the total parameter count (see
    /// `StructuredSparsity::infer_matrix_columns`).
    Block {
        /// Height of each block (number of rows, must be > 0)
        block_height: usize,
        /// Width of each block (number of columns, must be > 0)
        block_width: usize,
    },
}
271
/// Structured sparsity regularizer
///
/// Applies Group Lasso regularization with groups defined by a structural
/// pattern (columns, rows, or blocks). This is useful for encouraging
/// structured sparsity in weight matrices, e.g., pruning entire neurons
/// (row sparsity) or features (column sparsity).
///
/// Groups are derived from the pattern at penalty/apply time based on the
/// parameter count, so the same regularizer can be reused across parameter
/// arrays of compatible sizes.
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::regularizers::{StructuredSparsity, SparsityPattern, Regularizer};
///
/// // Create column-wise structured sparsity for a 3x4 matrix (stored as 12-element vector)
/// let regularizer = StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 4 });
///
/// let params = Array1::from_vec(vec![1.0, 0.0, 3.0, 0.0, 5.0, 0.0, 7.0, 0.0, 9.0, 0.0, 11.0, 0.0]);
/// let penalty = regularizer.penalty(&params).expect("penalty computation failed");
/// ```
#[derive(Debug, Clone)]
pub struct StructuredSparsity<A: Float + ScalarOperand + Debug> {
    /// Regularization strength
    lambda: A,
    /// The structural sparsity pattern used to derive parameter groups
    pattern: SparsityPattern,
}
298
impl<A: Float + ScalarOperand + Debug> StructuredSparsity<A> {
    /// Create a new structured sparsity regularizer
    ///
    /// # Arguments
    ///
    /// * `lambda` - Regularization strength
    /// * `pattern` - The structural pattern defining how parameters are grouped
    pub fn new(lambda: A, pattern: SparsityPattern) -> Self {
        Self { lambda, pattern }
    }

    /// Get the regularization strength
    pub fn lambda(&self) -> A {
        self.lambda
    }

    /// Get the sparsity pattern
    pub fn pattern(&self) -> &SparsityPattern {
        &self.pattern
    }

    /// Build groups from the sparsity pattern for a given parameter count
    ///
    /// For a matrix with `total_params` elements stored in row-major order:
    /// - Column pattern: groups parameters sharing the same column index
    /// - Row pattern: groups parameters sharing the same row index
    /// - Block pattern: groups parameters into rectangular blocks
    ///
    /// # Errors
    ///
    /// Fails when the pattern's dimensions are zero or `total_params` cannot
    /// be decomposed consistently with the pattern.
    fn build_groups(&self, total_params: usize) -> Result<Vec<Vec<usize>>> {
        match &self.pattern {
            SparsityPattern::Column { num_columns } => {
                if *num_columns == 0 {
                    return Err(OptimError::InvalidConfig(
                        "Number of columns must be greater than 0".to_string(),
                    ));
                }
                let num_rows = total_params / num_columns;
                // Integer division truncates; multiplying back detects non-divisible sizes.
                if num_rows * num_columns != total_params {
                    return Err(OptimError::InvalidConfig(format!(
                        "Total parameters ({}) is not evenly divisible by num_columns ({})",
                        total_params, num_columns
                    )));
                }

                // Column `col` collects linear indices row*num_columns + col (row-major layout).
                let mut groups = Vec::with_capacity(*num_columns);
                for col in 0..*num_columns {
                    let group: Vec<usize> =
                        (0..num_rows).map(|row| row * num_columns + col).collect();
                    groups.push(group);
                }
                Ok(groups)
            }
            SparsityPattern::Row { num_rows } => {
                if *num_rows == 0 {
                    return Err(OptimError::InvalidConfig(
                        "Number of rows must be greater than 0".to_string(),
                    ));
                }
                let num_columns = total_params / num_rows;
                if num_rows * num_columns != total_params {
                    return Err(OptimError::InvalidConfig(format!(
                        "Total parameters ({}) is not evenly divisible by num_rows ({})",
                        total_params, num_rows
                    )));
                }

                // Row `row` is the contiguous linear range [row*num_columns, (row+1)*num_columns).
                let mut groups = Vec::with_capacity(*num_rows);
                for row in 0..*num_rows {
                    let start = row * num_columns;
                    let group: Vec<usize> = (start..start + num_columns).collect();
                    groups.push(group);
                }
                Ok(groups)
            }
            SparsityPattern::Block {
                block_height,
                block_width,
            } => {
                if *block_height == 0 || *block_width == 0 {
                    return Err(OptimError::InvalidConfig(
                        "Block dimensions must be greater than 0".to_string(),
                    ));
                }

                // Infer the total number of columns from block_width
                // We need to figure out the matrix dimensions. We assume the matrix
                // width is a multiple of block_width. We try to find a valid decomposition.
                // For block sparsity, we need the user to provide a compatible total_params.
                // We compute the number of columns as the smallest multiple of block_width
                // such that total_params / num_cols is a multiple of block_height.
                // NOTE(review): infer_matrix_columns actually picks the most
                // square-like candidate, not the smallest — see its doc comment.
                let num_cols =
                    self.infer_matrix_columns(total_params, *block_height, *block_width)?;
                let num_rows = total_params / num_cols;

                // Both divisions are exact by construction of num_cols.
                let blocks_per_row = num_cols / block_width;
                let blocks_per_col = num_rows / block_height;

                // Enumerate blocks in row-major block order; within each block,
                // enumerate cells in row-major order and record their linear index.
                let mut groups = Vec::with_capacity(blocks_per_row * blocks_per_col);
                for block_row in 0..blocks_per_col {
                    for block_col in 0..blocks_per_row {
                        let mut group = Vec::with_capacity(block_height * block_width);
                        for r in 0..*block_height {
                            for c in 0..*block_width {
                                let row = block_row * block_height + r;
                                let col = block_col * block_width + c;
                                group.push(row * num_cols + col);
                            }
                        }
                        groups.push(group);
                    }
                }
                Ok(groups)
            }
        }
    }

    /// Infer the number of matrix columns for block sparsity
    ///
    /// Tries to find the candidate number of columns (a multiple of block_width)
    /// that produces the most square-like matrix decomposition.
    ///
    /// Candidates are multiples of `block_width` that divide `total_params`
    /// and whose implied row count is a multiple of `block_height`. Among
    /// valid candidates, the one closest to sqrt(total_params) wins; on a
    /// distance tie, the strict `<` comparison keeps the smaller (earlier)
    /// candidate.
    fn infer_matrix_columns(
        &self,
        total_params: usize,
        block_height: usize,
        block_width: usize,
    ) -> Result<usize> {
        // Aim for the most square-like shape: columns nearest sqrt(total_params).
        let target = (total_params as f64).sqrt();
        let mut best_candidate: Option<usize> = None;
        let mut best_distance = f64::MAX;

        let mut candidate = block_width;
        while candidate <= total_params {
            if total_params.is_multiple_of(candidate) {
                let rows = total_params / candidate;
                if rows.is_multiple_of(block_height) {
                    let distance = (candidate as f64 - target).abs();
                    if distance < best_distance {
                        best_distance = distance;
                        best_candidate = Some(candidate);
                    }
                }
            }
            candidate += block_width;
        }

        best_candidate.ok_or_else(|| {
            OptimError::InvalidConfig(format!(
                "Cannot decompose {} parameters into blocks of {}x{}",
                total_params, block_height, block_width
            ))
        })
    }
}
451
452impl<A, D> Regularizer<A, D> for StructuredSparsity<A>
453where
454    A: Float + ScalarOperand + Debug,
455    D: Dimension,
456{
457    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
458        let total_params = params.len();
459        let groups = self.build_groups(total_params)?;
460
461        let epsilon = A::from(1e-8).unwrap_or_else(|| A::epsilon());
462
463        let grad_slice = gradients.as_slice_memory_order_mut().ok_or_else(|| {
464            OptimError::InvalidParameter("Gradients array is not contiguous in memory".to_string())
465        })?;
466
467        let param_slice = params.as_slice_memory_order().ok_or_else(|| {
468            OptimError::InvalidParameter("Parameters array is not contiguous in memory".to_string())
469        })?;
470
471        for group in &groups {
472            // Compute L2 norm for this group
473            let sum_sq = group.iter().fold(A::zero(), |acc, &idx| {
474                if idx < total_params {
475                    acc + param_slice[idx] * param_slice[idx]
476                } else {
477                    acc
478                }
479            });
480            let norm = sum_sq.sqrt();
481
482            let scale = self.lambda / (norm + epsilon);
483
484            for &idx in group {
485                if idx < total_params {
486                    grad_slice[idx] = grad_slice[idx] + scale * param_slice[idx];
487                }
488            }
489        }
490
491        self.penalty(params)
492    }
493
494    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
495        let total_params = params.len();
496        let groups = self.build_groups(total_params)?;
497
498        let param_slice = params.as_slice_memory_order().ok_or_else(|| {
499            OptimError::InvalidParameter("Parameters array is not contiguous in memory".to_string())
500        })?;
501
502        let mut total = A::zero();
503
504        for group in &groups {
505            let sum_sq = group.iter().fold(A::zero(), |acc, &idx| {
506                if idx < total_params {
507                    acc + param_slice[idx] * param_slice[idx]
508                } else {
509                    acc
510                }
511            });
512            total = total + sum_sq.sqrt();
513        }
514
515        Ok(self.lambda * total)
516    }
517}
518
#[cfg(test)]
mod tests {
    //! Unit tests for GroupLasso and StructuredSparsity: penalty values,
    //! gradient application, builder helpers, and validation failures.
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_group_lasso_basic_penalty() {
        // Two groups: [0,1,2] and [3,4,5]
        // Group 0: params = [1, 2, 3], ||group||_2 = sqrt(1+4+9) = sqrt(14)
        // Group 1: params = [0, 0, 0], ||group||_2 = 0
        // Penalty = 0.1 * (sqrt(14) + 0) = 0.1 * sqrt(14)
        let regularizer = GroupLasso::new(0.1_f64).with_groups(vec![vec![0, 1, 2], vec![3, 4, 5]]);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 0.0, 0.0, 0.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        let expected = 0.1 * (14.0_f64).sqrt();
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_group_lasso_with_weights() {
        // Two groups with different weights
        let regularizer = GroupLasso::new(0.5_f64)
            .with_groups(vec![vec![0, 1], vec![2, 3]])
            .with_group_weights(vec![2.0, 0.5]);

        let params = Array1::from_vec(vec![3.0, 4.0, 1.0, 0.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        // Group 0: w=2.0, norm=sqrt(9+16)=5.0 => 2.0*5.0=10.0
        // Group 1: w=0.5, norm=sqrt(1+0)=1.0 => 0.5*1.0=0.5
        // Total: 0.5 * (10.0 + 0.5) = 5.25
        let expected = 0.5 * (2.0 * 5.0 + 0.5 * 1.0);
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_group_lasso_auto_groups() {
        let regularizer = GroupLasso::new(0.1_f64).auto_groups(9, 3);

        // Should create 3 groups: [0,1,2], [3,4,5], [6,7,8]
        assert_eq!(regularizer.num_groups(), 3);
        assert_eq!(regularizer.groups()[0], vec![0, 1, 2]);
        assert_eq!(regularizer.groups()[1], vec![3, 4, 5]);
        assert_eq!(regularizer.groups()[2], vec![6, 7, 8]);

        // Test with non-evenly-divisible size
        let regularizer2 = GroupLasso::new(0.1_f64).auto_groups(7, 3);
        assert_eq!(regularizer2.num_groups(), 3);
        assert_eq!(regularizer2.groups()[0], vec![0, 1, 2]);
        assert_eq!(regularizer2.groups()[1], vec![3, 4, 5]);
        assert_eq!(regularizer2.groups()[2], vec![6]); // Remainder group
    }

    #[test]
    fn test_group_lasso_gradient_application() {
        let regularizer = GroupLasso::new(1.0_f64).with_groups(vec![vec![0, 1], vec![2, 3]]);

        let params = Array1::from_vec(vec![3.0, 4.0, 0.0, 0.0]);
        let mut gradients = Array1::zeros(4);

        let penalty = regularizer
            .apply(&params, &mut gradients)
            .expect("apply failed");

        // Group 0: norm = sqrt(9+16) = 5.0
        // Gradient for idx 0: 1.0 * 1.0 * 3.0 / (5.0 + 1e-8) ~ 0.6
        // Gradient for idx 1: 1.0 * 1.0 * 4.0 / (5.0 + 1e-8) ~ 0.8
        let epsilon = 1e-8_f64;
        let norm0 = 5.0_f64;
        assert_abs_diff_eq!(gradients[0], 3.0 / (norm0 + epsilon), epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[1], 4.0 / (norm0 + epsilon), epsilon = 1e-6);

        // Group 1: norm = 0.0, so gradient ~ lambda * 0 / epsilon ~ 0
        assert_abs_diff_eq!(gradients[2], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[3], 0.0, epsilon = 1e-6);

        // Penalty: 1.0 * (5.0 + 0.0) = 5.0
        assert_abs_diff_eq!(penalty, 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_sparsity_column() {
        // 3x4 matrix stored as 12 elements (row-major)
        // Columns: col0=[0,4,8], col1=[1,5,9], col2=[2,6,10], col3=[3,7,11]
        let regularizer =
            StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 4 });

        // Matrix (row-major):
        // [1, 0, 3, 0]
        // [5, 0, 7, 0]
        // [9, 0, 11, 0]
        let params = Array1::from_vec(vec![
            1.0, 0.0, 3.0, 0.0, 5.0, 0.0, 7.0, 0.0, 9.0, 0.0, 11.0, 0.0,
        ]);

        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        // Col 0: [1,5,9] => norm = sqrt(1+25+81) = sqrt(107)
        // Col 1: [0,0,0] => norm = 0
        // Col 2: [3,7,11] => norm = sqrt(9+49+121) = sqrt(179)
        // Col 3: [0,0,0] => norm = 0
        let expected = 0.1 * (107.0_f64.sqrt() + 179.0_f64.sqrt());
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_sparsity_row() {
        // 3x2 matrix stored as 6 elements
        let regularizer = StructuredSparsity::new(0.5_f64, SparsityPattern::Row { num_rows: 3 });

        // [1, 2]
        // [0, 0]
        // [3, 4]
        let params = Array1::from_vec(vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0]);

        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        // Row 0: [1,2] => norm = sqrt(5)
        // Row 1: [0,0] => norm = 0
        // Row 2: [3,4] => norm = sqrt(25) = 5
        let expected = 0.5 * (5.0_f64.sqrt() + 5.0);
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_sparsity_block() {
        // 4x4 matrix with 2x2 blocks => 4 blocks
        // (infer_matrix_columns picks 4 columns as the most square-like decomposition)
        let regularizer = StructuredSparsity::new(
            0.2_f64,
            SparsityPattern::Block {
                block_height: 2,
                block_width: 2,
            },
        );

        // [1, 1, 0, 0]
        // [1, 1, 0, 0]
        // [0, 0, 2, 2]
        // [0, 0, 2, 2]
        let params = Array1::from_vec(vec![
            1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0, 2.0, 2.0,
        ]);

        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        // Block (0,0): [1,1,1,1] => norm = 2.0
        // Block (0,1): [0,0,0,0] => norm = 0.0
        // Block (1,0): [0,0,0,0] => norm = 0.0
        // Block (1,1): [2,2,2,2] => norm = 4.0
        let expected = 0.2 * (2.0 + 4.0);
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_structured_sparsity_gradient_application() {
        let regularizer = StructuredSparsity::new(1.0_f64, SparsityPattern::Row { num_rows: 2 });

        // [3, 4]
        // [0, 0]
        let params = Array1::from_vec(vec![3.0, 4.0, 0.0, 0.0]);
        let mut gradients = Array1::zeros(4);

        let _penalty = regularizer
            .apply(&params, &mut gradients)
            .expect("apply failed");

        // Row 0: norm = 5.0
        // grad[0] = 1.0 * 3.0 / (5.0 + eps) ~ 0.6
        // grad[1] = 1.0 * 4.0 / (5.0 + eps) ~ 0.8
        let epsilon = 1e-8_f64;
        assert_abs_diff_eq!(gradients[0], 3.0 / (5.0 + epsilon), epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[1], 4.0 / (5.0 + epsilon), epsilon = 1e-6);

        // Row 1: norm ~ 0, so gradient ~ 0 / eps ~ 0
        assert_abs_diff_eq!(gradients[2], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[3], 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_group_lasso_empty_groups() {
        // No groups => penalty should be zero
        let regularizer = GroupLasso::<f64>::new(0.1);

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");

        assert_abs_diff_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_group_lasso_out_of_bounds_index() {
        let regularizer = GroupLasso::new(0.1_f64).with_groups(vec![vec![0, 1, 100]]); // index 100 is out of bounds

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = regularizer.penalty(&params);

        assert!(result.is_err());
    }

    #[test]
    fn test_group_lasso_weight_mismatch() {
        let regularizer = GroupLasso::new(0.1_f64)
            .with_groups(vec![vec![0, 1], vec![2, 3]])
            .with_group_weights(vec![1.0]); // Only 1 weight for 2 groups

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let result = regularizer.penalty(&params);

        assert!(result.is_err());
    }

    #[test]
    fn test_structured_sparsity_invalid_dimensions() {
        // 7 params cannot be divided into columns of 3
        let regularizer =
            StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 3 });

        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let result = regularizer.penalty(&params);

        assert!(result.is_err());
    }

    #[test]
    fn test_structured_sparsity_zero_columns() {
        // num_columns = 0 is rejected by build_groups
        let regularizer =
            StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 0 });

        let params = Array1::from_vec(vec![1.0, 2.0]);
        let result = regularizer.penalty(&params);

        assert!(result.is_err());
    }

    #[test]
    fn test_group_lasso_builder_pattern() {
        // Builder methods chain and the accessors reflect the configuration.
        let regularizer = GroupLasso::new(0.5_f64)
            .with_groups(vec![vec![0, 1], vec![2, 3]])
            .with_group_weights(vec![1.0, 2.0]);

        assert_eq!(regularizer.lambda(), 0.5);
        assert_eq!(regularizer.num_groups(), 2);
        assert_eq!(regularizer.groups()[0], vec![0, 1]);
        assert_eq!(regularizer.groups()[1], vec![2, 3]);
    }
}
779}