use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
/// Group lasso regularizer.
///
/// The penalty is `lambda * sum_g(w_g * ||params[g]||_2)`, where each group `g`
/// is a list of flat element indices and `w_g` is that group's weight.
#[derive(Debug, Clone)]
pub struct GroupLasso<A>
where
    A: Float + ScalarOperand + Debug,
{
    /// Overall regularization strength applied to the summed group norms.
    lambda: A,
    /// Flat element indices belonging to each group.
    groups: Vec<Vec<usize>>,
    /// Optional per-group weights; a group without a weight uses one.
    group_weights: Option<Vec<A>>,
}
impl<A: Float + ScalarOperand + Debug> GroupLasso<A> {
    /// Creates a regularizer with strength `lambda`, no groups and no group weights.
    ///
    /// Groups are supplied afterwards with [`with_groups`](Self::with_groups) or
    /// [`auto_groups`](Self::auto_groups).
    pub fn new(lambda: A) -> Self {
        Self {
            lambda,
            groups: Vec::new(),
            group_weights: None,
        }
    }

    /// Replaces the group definition; each inner vector lists the flat indices of one group.
    pub fn with_groups(mut self, groups: Vec<Vec<usize>>) -> Self {
        self.groups = groups;
        self
    }

    /// Sets one weight per group.
    ///
    /// The length is checked against the number of groups when the regularizer
    /// is evaluated, not here.
    pub fn with_group_weights(mut self, weights: Vec<A>) -> Self {
        self.group_weights = Some(weights);
        self
    }

    /// Partitions the indices `0..param_size` into consecutive groups of `group_size`.
    ///
    /// The last group is shorter when `param_size` is not a multiple of
    /// `group_size`. A `group_size` of zero produces no groups at all.
    pub fn auto_groups(mut self, param_size: usize, group_size: usize) -> Self {
        self.groups = if group_size == 0 {
            // A zero-sized chunk can never advance through the index range, so
            // there is no finite partition to build.
            Vec::new()
        } else {
            (0..param_size)
                .step_by(group_size)
                .map(|start| {
                    // Saturate so that a huge `group_size` cannot overflow `usize`.
                    let end = start.saturating_add(group_size).min(param_size);
                    (start..end).collect::<Vec<usize>>()
                })
                .collect()
        };
        self
    }

    /// Returns the regularization strength.
    pub fn lambda(&self) -> A {
        self.lambda
    }

    /// Returns the configured groups as slices of flat indices.
    pub fn groups(&self) -> &[Vec<usize>] {
        &self.groups
    }

    /// Returns the number of configured groups.
    pub fn num_groups(&self) -> usize {
        self.groups.len()
    }

    /// Returns the weight of group `group_idx`, or one when no weight is available for it.
    fn group_weight(&self, group_idx: usize) -> A {
        match &self.group_weights {
            Some(weights) => weights.get(group_idx).copied().unwrap_or_else(A::one),
            None => A::one(),
        }
    }

    /// Computes the Euclidean norm of the elements of `params` selected by `indices`.
    ///
    /// Indices that fall outside the array are ignored. Contiguous arrays are
    /// indexed through their memory-order slice; other arrays are indexed in
    /// iteration order.
    fn group_l2_norm(&self, params: &Array<A, impl Dimension>, indices: &[usize]) -> A {
        let sum_sq_over = |values: &[A]| {
            indices
                .iter()
                .filter_map(|&idx| values.get(idx))
                .fold(A::zero(), |acc, &val| acc + val * val)
        };
        let sum_sq = match params.as_slice_memory_order() {
            Some(flat) => sum_sq_over(flat),
            None => {
                // Materialize the elements once instead of re-walking the
                // iterator from the start for every index.
                let values: Vec<A> = params.iter().copied().collect();
                sum_sq_over(&values)
            }
        };
        sum_sq.sqrt()
    }

    /// Checks the group configuration against an array of `param_len` elements.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidParameter` for the first group index that is
    /// not below `param_len`, and `OptimError::InvalidConfig` when weights are
    /// set but their count differs from the number of groups.
    fn validate_groups(&self, param_len: usize) -> Result<()> {
        for (g_idx, group) in self.groups.iter().enumerate() {
            if let Some(&idx) = group.iter().find(|&&idx| idx >= param_len) {
                return Err(OptimError::InvalidParameter(format!(
                    "Group {} contains index {} which exceeds parameter size {}",
                    g_idx, idx, param_len
                )));
            }
        }
        if let Some(weights) = &self.group_weights {
            if weights.len() != self.groups.len() {
                return Err(OptimError::InvalidConfig(format!(
                    "Number of group weights ({}) does not match number of groups ({})",
                    weights.len(),
                    self.groups.len()
                )));
            }
        }
        Ok(())
    }
}
impl<A, D> Regularizer<A, D> for GroupLasso<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Adds the group lasso term to `gradients` and returns the penalty value.
    ///
    /// For every group, each member element `p` receives
    /// `lambda * w_g * p / (||group||_2 + epsilon)`, where `epsilon` is `1e-8`
    /// (or the type's machine epsilon if that value cannot be represented).
    /// Elements are addressed through the memory-order slices of both arrays.
    ///
    /// # Errors
    ///
    /// Returns an error when the group configuration is invalid for `params`,
    /// when `gradients` and `params` have different lengths, or when either
    /// array is not contiguous in memory.
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        let param_len = params.len();
        self.validate_groups(param_len)?;
        // Group indices are only validated against `params`; a shorter gradient
        // array would otherwise be indexed out of bounds below.
        if gradients.len() != param_len {
            return Err(OptimError::InvalidParameter(format!(
                "Gradients length ({}) does not match parameters length ({})",
                gradients.len(),
                param_len
            )));
        }
        let epsilon = A::from(1e-8).unwrap_or_else(A::epsilon);
        let grad_slice = gradients.as_slice_memory_order_mut().ok_or_else(|| {
            OptimError::InvalidParameter("Gradients array is not contiguous in memory".to_string())
        })?;
        let param_slice = params.as_slice_memory_order().ok_or_else(|| {
            OptimError::InvalidParameter("Parameters array is not contiguous in memory".to_string())
        })?;
        // The penalty is accumulated alongside the gradient update so each
        // group norm is computed only once.
        let mut total = A::zero();
        for (g_idx, group) in self.groups.iter().enumerate() {
            let w_g = self.group_weight(g_idx);
            // Every index was validated above, so direct indexing is in bounds.
            let norm = group
                .iter()
                .fold(A::zero(), |acc, &idx| {
                    acc + param_slice[idx] * param_slice[idx]
                })
                .sqrt();
            let scale = self.lambda * w_g / (norm + epsilon);
            for &idx in group {
                grad_slice[idx] = grad_slice[idx] + scale * param_slice[idx];
            }
            total = total + w_g * norm;
        }
        Ok(self.lambda * total)
    }

    /// Returns `lambda * sum_g(w_g * ||params[g]||_2)` without touching any gradient.
    ///
    /// # Errors
    ///
    /// Returns an error when the group configuration is invalid for `params`.
    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        let param_len = params.len();
        self.validate_groups(param_len)?;
        let mut total = A::zero();
        for (g_idx, group) in self.groups.iter().enumerate() {
            let w_g = self.group_weight(g_idx);
            let norm = self.group_l2_norm(params, group);
            total = total + w_g * norm;
        }
        Ok(self.lambda * total)
    }
}
/// Describes how a flat parameter array is split into groups, treating it as a
/// row-major matrix.
#[derive(Debug, Clone)]
pub enum SparsityPattern {
    /// One group per matrix column; the parameter count must be divisible by `num_columns`.
    Column { num_columns: usize },
    /// One group per matrix row; the parameter count must be divisible by `num_rows`.
    Row { num_rows: usize },
    /// One group per `block_height` x `block_width` tile; the matrix width is
    /// inferred from the parameter count.
    Block {
        block_height: usize,
        block_width: usize,
    },
}
/// Group-norm regularizer whose groups are derived from a [`SparsityPattern`]
/// each time it is evaluated, rather than listed explicitly.
#[derive(Debug, Clone)]
pub struct StructuredSparsity<A>
where
    A: Float + ScalarOperand + Debug,
{
    /// Regularization strength applied to the summed group norms.
    lambda: A,
    /// Layout used to derive the groups from the parameter count.
    pattern: SparsityPattern,
}
impl<A: Float + ScalarOperand + Debug> StructuredSparsity<A> {
    /// Creates a regularizer with strength `lambda` that groups parameters according to `pattern`.
    pub fn new(lambda: A, pattern: SparsityPattern) -> Self {
        Self { lambda, pattern }
    }

    /// Returns the regularization strength.
    pub fn lambda(&self) -> A {
        self.lambda
    }

    /// Returns the configured sparsity pattern.
    pub fn pattern(&self) -> &SparsityPattern {
        &self.pattern
    }

    /// Builds the index groups for an array of `total_params` elements laid out
    /// as a row-major matrix.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when a pattern dimension is zero or
    /// when `total_params` cannot be arranged to fit the pattern.
    fn build_groups(&self, total_params: usize) -> Result<Vec<Vec<usize>>> {
        match self.pattern {
            SparsityPattern::Column { num_columns } => {
                if num_columns == 0 {
                    return Err(OptimError::InvalidConfig(
                        "Number of columns must be greater than 0".to_string(),
                    ));
                }
                if !total_params.is_multiple_of(num_columns) {
                    return Err(OptimError::InvalidConfig(format!(
                        "Total parameters ({}) is not evenly divisible by num_columns ({})",
                        total_params, num_columns
                    )));
                }
                // Column `col` holds every `num_columns`-th element starting at `col`.
                let column_groups = (0..num_columns)
                    .map(|col| (col..total_params).step_by(num_columns).collect())
                    .collect();
                Ok(column_groups)
            }
            SparsityPattern::Row { num_rows } => {
                if num_rows == 0 {
                    return Err(OptimError::InvalidConfig(
                        "Number of rows must be greater than 0".to_string(),
                    ));
                }
                if !total_params.is_multiple_of(num_rows) {
                    return Err(OptimError::InvalidConfig(format!(
                        "Total parameters ({}) is not evenly divisible by num_rows ({})",
                        total_params, num_rows
                    )));
                }
                let row_len = total_params / num_rows;
                // Each row is one contiguous run of `row_len` indices.
                let row_groups = (0..num_rows)
                    .map(|row| {
                        let first = row * row_len;
                        (first..first + row_len).collect()
                    })
                    .collect();
                Ok(row_groups)
            }
            SparsityPattern::Block {
                block_height,
                block_width,
            } => {
                if block_height == 0 || block_width == 0 {
                    return Err(OptimError::InvalidConfig(
                        "Block dimensions must be greater than 0".to_string(),
                    ));
                }
                let matrix_cols =
                    self.infer_matrix_columns(total_params, block_height, block_width)?;
                let matrix_rows = total_params / matrix_cols;
                let blocks_across = matrix_cols / block_width;
                let blocks_down = matrix_rows / block_height;
                let mut block_groups = Vec::with_capacity(blocks_across * blocks_down);
                // Blocks are emitted left to right, then top to bottom; inside a
                // block the indices run row by row.
                for block_row in 0..blocks_down {
                    let top = block_row * block_height;
                    for block_col in 0..blocks_across {
                        let left = block_col * block_width;
                        let members: Vec<usize> = (top..top + block_height)
                            .flat_map(|row| {
                                (left..left + block_width).map(move |col| row * matrix_cols + col)
                            })
                            .collect();
                        block_groups.push(members);
                    }
                }
                Ok(block_groups)
            }
        }
    }

    /// Chooses a matrix width for the block pattern.
    ///
    /// Candidates are the multiples of `block_width` up to `total_params` that
    /// divide `total_params` and leave a row count divisible by `block_height`.
    /// The candidate closest to `sqrt(total_params)` is returned; on a tie the
    /// smaller width wins. Callers must pass a non-zero `block_width`.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when no candidate exists.
    fn infer_matrix_columns(
        &self,
        total_params: usize,
        block_height: usize,
        block_width: usize,
    ) -> Result<usize> {
        let ideal_width = (total_params as f64).sqrt();
        let closest = (block_width..=total_params)
            .step_by(block_width)
            .filter(|&width| {
                total_params.is_multiple_of(width)
                    && (total_params / width).is_multiple_of(block_height)
            })
            .map(|width| (width, (width as f64 - ideal_width).abs()))
            .fold(None::<(usize, f64)>, |best, (width, distance)| match best {
                // Keep the earlier candidate unless the new one is strictly closer.
                Some((_, best_distance)) if best_distance <= distance => best,
                _ => Some((width, distance)),
            });
        match closest {
            Some((width, _)) => Ok(width),
            None => Err(OptimError::InvalidConfig(format!(
                "Cannot decompose {} parameters into blocks of {}x{}",
                total_params, block_height, block_width
            ))),
        }
    }
}
impl<A, D> Regularizer<A, D> for StructuredSparsity<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Adds the structured-sparsity term to `gradients` and returns the penalty value.
    ///
    /// For every group derived from the pattern, each member element `p`
    /// receives `lambda * p / (||group||_2 + epsilon)`, where `epsilon` is
    /// `1e-8` (or the type's machine epsilon if that value cannot be
    /// represented). Elements are addressed through the memory-order slices of
    /// both arrays.
    ///
    /// # Errors
    ///
    /// Returns an error when the pattern does not fit the parameter count, when
    /// `gradients` and `params` have different lengths, or when either array is
    /// not contiguous in memory.
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        let total_params = params.len();
        let groups = self.build_groups(total_params)?;
        // Group indices are derived from `params` alone; a shorter gradient
        // array would otherwise be indexed out of bounds below.
        if gradients.len() != total_params {
            return Err(OptimError::InvalidParameter(format!(
                "Gradients length ({}) does not match parameters length ({})",
                gradients.len(),
                total_params
            )));
        }
        let epsilon = A::from(1e-8).unwrap_or_else(A::epsilon);
        let grad_slice = gradients.as_slice_memory_order_mut().ok_or_else(|| {
            OptimError::InvalidParameter("Gradients array is not contiguous in memory".to_string())
        })?;
        let param_slice = params.as_slice_memory_order().ok_or_else(|| {
            OptimError::InvalidParameter("Parameters array is not contiguous in memory".to_string())
        })?;
        // The penalty is accumulated alongside the gradient update so the
        // groups are built once and each norm is computed once.
        let mut total = A::zero();
        for group in &groups {
            // `build_groups` only yields indices below `total_params`.
            let norm = group
                .iter()
                .fold(A::zero(), |acc, &idx| {
                    acc + param_slice[idx] * param_slice[idx]
                })
                .sqrt();
            let scale = self.lambda / (norm + epsilon);
            for &idx in group {
                grad_slice[idx] = grad_slice[idx] + scale * param_slice[idx];
            }
            total = total + norm;
        }
        Ok(self.lambda * total)
    }

    /// Returns `lambda * sum_g(||params[g]||_2)` over the groups derived from the pattern.
    ///
    /// # Errors
    ///
    /// Returns an error when the pattern does not fit the parameter count or
    /// when `params` is not contiguous in memory.
    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        let total_params = params.len();
        let groups = self.build_groups(total_params)?;
        let param_slice = params.as_slice_memory_order().ok_or_else(|| {
            OptimError::InvalidParameter("Parameters array is not contiguous in memory".to_string())
        })?;
        let mut total = A::zero();
        for group in &groups {
            // `build_groups` only yields indices below `total_params`.
            let sum_sq = group.iter().fold(A::zero(), |acc, &idx| {
                acc + param_slice[idx] * param_slice[idx]
            });
            total = total + sum_sq.sqrt();
        }
        Ok(self.lambda * total)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array1;

    /// Only the first group is non-zero, so the penalty is `lambda * sqrt(1 + 4 + 9)`.
    #[test]
    fn test_group_lasso_basic_penalty() {
        let regularizer = GroupLasso::new(0.1_f64).with_groups(vec![vec![0, 1, 2], vec![3, 4, 5]]);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 0.0, 0.0, 0.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");
        let expected = 0.1 * (14.0_f64).sqrt();
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    /// Group norms are 5 and 1; each is scaled by its own weight before summing.
    #[test]
    fn test_group_lasso_with_weights() {
        let regularizer = GroupLasso::new(0.5_f64)
            .with_groups(vec![vec![0, 1], vec![2, 3]])
            .with_group_weights(vec![2.0, 0.5]);
        let params = Array1::from_vec(vec![3.0, 4.0, 1.0, 0.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");
        let expected = 0.5 * (2.0 * 5.0 + 0.5 * 1.0);
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    /// Covers both an exact partition and one with a shorter trailing group.
    #[test]
    fn test_group_lasso_auto_groups() {
        let regularizer = GroupLasso::new(0.1_f64).auto_groups(9, 3);
        assert_eq!(regularizer.num_groups(), 3);
        assert_eq!(regularizer.groups()[0], vec![0, 1, 2]);
        assert_eq!(regularizer.groups()[1], vec![3, 4, 5]);
        assert_eq!(regularizer.groups()[2], vec![6, 7, 8]);

        let regularizer2 = GroupLasso::new(0.1_f64).auto_groups(7, 3);
        assert_eq!(regularizer2.num_groups(), 3);
        assert_eq!(regularizer2.groups()[0], vec![0, 1, 2]);
        assert_eq!(regularizer2.groups()[1], vec![3, 4, 5]);
        assert_eq!(regularizer2.groups()[2], vec![6]);
    }

    /// The all-zero group must leave its gradient entries untouched.
    #[test]
    fn test_group_lasso_gradient_application() {
        let regularizer = GroupLasso::new(1.0_f64).with_groups(vec![vec![0, 1], vec![2, 3]]);
        let params = Array1::from_vec(vec![3.0, 4.0, 0.0, 0.0]);
        let mut gradients = Array1::zeros(4);
        let penalty = regularizer
            .apply(&params, &mut gradients)
            .expect("apply failed");
        let epsilon = 1e-8_f64;
        let norm0 = 5.0_f64;
        assert_abs_diff_eq!(gradients[0], 3.0 / (norm0 + epsilon), epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[1], 4.0 / (norm0 + epsilon), epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[2], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[3], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(penalty, 5.0, epsilon = 1e-10);
    }

    /// 3x4 row-major matrix: columns 0 and 2 hold (1, 5, 9) and (3, 7, 11).
    #[test]
    fn test_structured_sparsity_column() {
        let regularizer =
            StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 4 });
        let params = Array1::from_vec(vec![
            1.0, 0.0, 3.0, 0.0, 5.0, 0.0, 7.0, 0.0, 9.0, 0.0, 11.0, 0.0,
        ]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");
        let expected = 0.1 * (107.0_f64.sqrt() + 179.0_f64.sqrt());
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    /// 3x2 row-major matrix: rows are (1, 2), (0, 0) and (3, 4).
    #[test]
    fn test_structured_sparsity_row() {
        let regularizer = StructuredSparsity::new(0.5_f64, SparsityPattern::Row { num_rows: 3 });
        let params = Array1::from_vec(vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");
        let expected = 0.5 * (5.0_f64.sqrt() + 5.0);
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    /// 4x4 matrix in 2x2 blocks: the top-left block has norm 2, the bottom-right norm 4.
    #[test]
    fn test_structured_sparsity_block() {
        let regularizer = StructuredSparsity::new(
            0.2_f64,
            SparsityPattern::Block {
                block_height: 2,
                block_width: 2,
            },
        );
        let params = Array1::from_vec(vec![
            1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0, 2.0, 2.0,
        ]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");
        let expected = 0.2 * (2.0 + 4.0);
        assert_abs_diff_eq!(penalty, expected, epsilon = 1e-10);
    }

    /// The all-zero row must leave its gradient entries untouched.
    #[test]
    fn test_structured_sparsity_gradient_application() {
        let regularizer = StructuredSparsity::new(1.0_f64, SparsityPattern::Row { num_rows: 2 });
        let params = Array1::from_vec(vec![3.0, 4.0, 0.0, 0.0]);
        let mut gradients = Array1::zeros(4);
        let _penalty = regularizer
            .apply(&params, &mut gradients)
            .expect("apply failed");
        let epsilon = 1e-8_f64;
        assert_abs_diff_eq!(gradients[0], 3.0 / (5.0 + epsilon), epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[1], 4.0 / (5.0 + epsilon), epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[2], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(gradients[3], 0.0, epsilon = 1e-6);
    }

    /// With no groups configured there is nothing to penalize.
    #[test]
    fn test_group_lasso_empty_groups() {
        let regularizer = GroupLasso::<f64>::new(0.1);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let penalty = regularizer
            .penalty(&params)
            .expect("penalty computation failed");
        assert_abs_diff_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    /// An index beyond the parameter array must be rejected.
    #[test]
    fn test_group_lasso_out_of_bounds_index() {
        let regularizer = GroupLasso::new(0.1_f64).with_groups(vec![vec![0, 1, 100]]);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = regularizer.penalty(&params);
        assert!(result.is_err());
    }

    /// One weight for two groups must be rejected.
    #[test]
    fn test_group_lasso_weight_mismatch() {
        let regularizer = GroupLasso::new(0.1_f64)
            .with_groups(vec![vec![0, 1], vec![2, 3]])
            .with_group_weights(vec![1.0]);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let result = regularizer.penalty(&params);
        assert!(result.is_err());
    }

    /// Seven parameters cannot be arranged into three columns.
    #[test]
    fn test_structured_sparsity_invalid_dimensions() {
        let regularizer =
            StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 3 });
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let result = regularizer.penalty(&params);
        assert!(result.is_err());
    }

    /// A zero column count must be rejected.
    #[test]
    fn test_structured_sparsity_zero_columns() {
        let regularizer =
            StructuredSparsity::new(0.1_f64, SparsityPattern::Column { num_columns: 0 });
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let result = regularizer.penalty(&params);
        assert!(result.is_err());
    }

    /// Builder calls must be reflected by the accessors.
    #[test]
    fn test_group_lasso_builder_pattern() {
        let regularizer = GroupLasso::new(0.5_f64)
            .with_groups(vec![vec![0, 1], vec![2, 3]])
            .with_group_weights(vec![1.0, 2.0]);
        assert_eq!(regularizer.lambda(), 0.5);
        assert_eq!(regularizer.num_groups(), 2);
        assert_eq!(regularizer.groups()[0], vec![0, 1]);
        assert_eq!(regularizer.groups()[1], vec![2, 3]);
    }
}