use std::collections::{HashMap, HashSet};
use ndarray::{Array1, Array2, Axis, s};
use rayon::prelude::*;
use crate::lasso::fista::{FistaProblem, IterInfo, minimise as fista_minimise};
use crate::lasso::prox::l1_l2_prox;
use crate::lasso::singular_values::find_largest_singular_value;
use crate::lasso::subsampling::SubsamplingScheme;
/// How the base `group_reg` penalty is scaled by the number of features in a group.
///
/// With `s` features in a group, the effective group penalty is `group_reg * sqrt(s)`
/// for [`ScaleReg::GroupSize`], `group_reg` for [`ScaleReg::None`], and
/// `group_reg / sqrt(s)` for [`ScaleReg::InverseGroupSize`].
// Fieldless enum: `Copy`/`Eq`/`Hash` are free and let callers pass it by value or key on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ScaleReg {
    /// Multiply the penalty by `sqrt(group size)` (the default).
    #[default]
    GroupSize,
    /// Use the base penalty unchanged.
    None,
    /// Divide the penalty by `sqrt(group size)`.
    InverseGroupSize,
}
/// Hyper-parameters shared by [`GroupLasso`] and [`ClusteredGroupLasso`].
#[derive(Debug, Clone)]
pub struct GroupLassoParams {
    /// Group id of every feature (one entry per column of `X`). An empty vector puts
    /// each feature in its own group; features with a negative id belong to no group
    /// and receive no group penalty.
    pub groups: Vec<i64>,
    /// Base group-L2 penalty, scaled per group according to `scale_reg`.
    pub group_reg: f64,
    /// L1 penalty applied to every coefficient (the intercept is not penalised).
    pub l1_reg: f64,
    /// Iteration limit passed to the FISTA solver.
    pub n_iter: usize,
    /// Convergence tolerance passed to the FISTA solver.
    pub tol: f64,
    /// How `group_reg` is scaled by group size.
    pub scale_reg: ScaleReg,
    /// Subsampling scheme passed to `find_largest_singular_value` when estimating the
    /// initial Lipschitz constant.
    pub subsampling: SubsamplingScheme,
    /// Whether to fit an intercept; when true, features are mean-centred during fitting.
    pub fit_intercept: bool,
    /// Use the squared Frobenius norm instead of a singular-value estimate for the
    /// initial Lipschitz constant.
    pub frobenius_lipschitz: bool,
    /// Seed passed to `find_largest_singular_value`.
    pub seed: u64,
    /// Start `fit` from the previously fitted coefficients when they are available.
    pub warm_start: bool,
}
impl Default for GroupLassoParams {
    /// Light penalties (0.05 for both the group and L1 terms), 100 FISTA iterations
    /// with tolerance 1e-5, group-size scaling, an intercept, one group per feature,
    /// and no subsampling, Frobenius bound or warm starting.
    fn default() -> Self {
        Self {
            // Model structure.
            groups: vec![],
            fit_intercept: true,
            // Penalties.
            l1_reg: 0.05,
            group_reg: 0.05,
            scale_reg: ScaleReg::GroupSize,
            // Solver settings.
            tol: 1.0e-5,
            n_iter: 100,
            warm_start: false,
            // Lipschitz estimation.
            frobenius_lipschitz: false,
            subsampling: SubsamplingScheme::None,
            seed: 0,
        }
    }
}
/// Errors reported by the group-lasso estimators.
#[derive(Debug)]
pub enum GroupLassoError {
    /// A method needing fitted coefficients was called before the model was fitted.
    NotFitted,
    /// Inputs have incompatible dimensions; the payload describes the mismatch.
    ShapeMismatch(String),
    /// A parameter has an invalid value; the payload describes which one.
    InvalidParam(String),
    /// FISTA did not report convergence. Callers treat this as a warning: the
    /// coefficients from the run are still stored.
    ConvergenceWarning,
}

impl std::fmt::Display for GroupLassoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GroupLassoError::ShapeMismatch(detail) => write!(f, "Shape mismatch: {}", detail),
            GroupLassoError::InvalidParam(detail) => write!(f, "Invalid parameter: {}", detail),
            GroupLassoError::NotFitted => f.write_str("Model has not been fitted yet"),
            GroupLassoError::ConvergenceWarning => {
                f.write_str("FISTA did not converge; try increasing n_iter or decreasing tol")
            }
        }
    }
}

impl std::error::Error for GroupLassoError {}
/// Coefficients produced by `GroupLasso::fit`.
#[derive(Debug, Clone)]
pub struct FittedCoefficients {
    /// Feature weights: one row per input feature, one column per target.
    pub coef: Array2<f64>,
    /// Single-row intercept (one column per target), expressed for the original
    /// uncentred features so that predictions are `x · coef + intercept`.
    pub intercept: Array2<f64>,
}
/// Linear regression with a group-L2 penalty plus an L1 penalty (sparse-group lasso),
/// minimised with FISTA.
#[derive(Debug, Clone)]
pub struct GroupLasso {
    /// Hyper-parameters used by `fit`.
    pub params: GroupLassoParams,
    /// Explicit per-group penalties set by `new_with_regs`; when present they replace
    /// the `group_reg`/`scale_reg`-derived values.
    group_reg_vector: Option<Vec<f64>>,
    /// One feature mask per distinct non-negative group id (ascending id order), set by `fit`.
    groups_masks: Option<Vec<Vec<bool>>>,
    /// Sorted distinct non-negative group ids, set by `fit`.
    group_ids: Option<Vec<i64>>,
    /// Group id of each feature as used by the last `fit`.
    feature_group_ids: Option<Vec<i64>>,
    /// Coefficients from the last `fit`; `None` until the model has been fitted.
    pub fitted: Option<FittedCoefficients>,
    /// Number of iterations FISTA reported for the last `fit`.
    pub last_fista_iterations: usize,
    /// Column means subtracted from `X` in the last `fit` (zeros without an intercept).
    /// Stored by `fit`; not read elsewhere in this file.
    x_means: Option<Array2<f64>>,
    /// Lipschitz constant returned by the most recent FISTA run; `fit` may reuse it as
    /// the starting value for the next run.
    lipschitz: Option<f64>,
}
impl GroupLasso {
    /// Creates an unfitted model with the given hyper-parameters.
    pub fn new(params: GroupLassoParams) -> Self {
        Self {
            params,
            group_reg_vector: None,
            groups_masks: None,
            group_ids: None,
            feature_group_ids: None,
            fitted: None,
            last_fista_iterations: 0,
            x_means: None,
            lipschitz: None,
        }
    }

    /// Creates an unfitted model that uses `group_reg_vec` as the per-group penalties.
    ///
    /// The vector must hold one penalty per distinct non-negative group id, ordered by
    /// ascending group id. The values are used as-is: `group_reg` and `scale_reg` are
    /// ignored. A length mismatch is reported by [`GroupLasso::fit`] as
    /// [`GroupLassoError::InvalidParam`].
    pub fn new_with_regs(params: GroupLassoParams, group_reg_vec: Vec<f64>) -> Self {
        Self {
            group_reg_vector: Some(group_reg_vec),
            ..Self::new(params)
        }
    }

    /// Scales `base_reg` by the size of a group according to `params.scale_reg`.
    fn get_reg_strength(&self, group_size: usize, base_reg: f64) -> f64 {
        let sz = group_size as f64;
        match self.params.scale_reg {
            ScaleReg::GroupSize => base_reg * sz.sqrt(),
            ScaleReg::None => base_reg,
            ScaleReg::InverseGroupSize => base_reg / sz.sqrt(),
        }
    }

    /// Returns one penalty per group mask.
    ///
    /// If an explicit vector was given to `new_with_regs`, that vector is returned.
    /// Otherwise each group gets `group_reg` scaled by its size.
    fn build_reg_vector(&self, masks: &[Vec<bool>]) -> Vec<f64> {
        if let Some(explicit) = &self.group_reg_vector {
            return explicit.clone();
        }
        masks
            .iter()
            .map(|mask| {
                let size = mask.iter().filter(|&&member| member).count();
                self.get_reg_strength(size, self.params.group_reg)
            })
            .collect()
    }

    /// Builds the augmented design matrix `[1 | X - mean(X)]`.
    ///
    /// Returns the augmented matrix, the `(1, p)` column means, and a copy of `y`.
    /// Centring happens only when `fit_intercept` is set; otherwise the means are zeros.
    ///
    /// Requires at least one row when `fit_intercept` is set. `fit` guarantees this.
    fn prepare_dataset(
        &self,
        x: &Array2<f64>,
        y: &Array2<f64>,
    ) -> (Array2<f64>, Array2<f64>, Array2<f64>) {
        let (n, p) = x.dim();
        let mut x_aug = Array2::zeros((n, p + 1));
        x_aug.column_mut(0).fill(1.0);
        x_aug.slice_mut(s![.., 1..]).assign(x);
        let x_means = if self.params.fit_intercept {
            let means = x
                .mean_axis(Axis(0))
                .expect("fit rejects inputs with zero rows")
                .insert_axis(Axis(0));
            // Centre in place; the (1, p) means broadcast over all rows.
            let mut features = x_aug.slice_mut(s![.., 1..]);
            features -= &means;
            means
        } else {
            Array2::zeros((1, p))
        };
        (x_aug, x_means, y.clone())
    }

    /// Resolves feature-to-group assignments and builds one mask per regularised group.
    ///
    /// Features with a negative group id are left out of every mask. Nothing is stored
    /// on `self` unless every check passes.
    ///
    /// # Errors
    /// - [`GroupLassoError::ShapeMismatch`] if `params.groups` is non-empty and its
    ///   length differs from `num_features`.
    /// - [`GroupLassoError::InvalidParam`] if an explicit per-group penalty vector does
    ///   not have one entry per group.
    fn build_groups(&mut self, num_features: usize) -> Result<(), GroupLassoError> {
        let raw: Vec<i64> = if self.params.groups.is_empty() {
            (0..num_features as i64).collect()
        } else if self.params.groups.len() != num_features {
            return Err(GroupLassoError::ShapeMismatch(format!(
                "groups has length {} but X has {} features",
                self.params.groups.len(),
                num_features
            )));
        } else {
            self.params.groups.clone()
        };
        let mut unique: Vec<i64> = raw.iter().copied().filter(|&g| g >= 0).collect();
        unique.sort_unstable();
        unique.dedup();
        // An explicit penalty vector is zipped with the masks later. A length mismatch
        // would silently drop groups or penalties, so reject it here.
        if let Some(regs) = &self.group_reg_vector {
            if regs.len() != unique.len() {
                return Err(GroupLassoError::InvalidParam(format!(
                    "group_reg_vector has {} entries but there are {} regularised groups",
                    regs.len(),
                    unique.len()
                )));
            }
        }
        let masks: Vec<Vec<bool>> = unique
            .iter()
            .map(|&uid| raw.iter().map(|&g| g == uid).collect())
            .collect();
        self.feature_group_ids = Some(raw);
        self.group_ids = Some(unique);
        self.groups_masks = Some(masks);
        Ok(())
    }

    /// Initial Lipschitz constant of the gradient of `0.5 * ||x_aug · w - y||² / n`.
    ///
    /// The exact constant is `σ_max(x_aug)² / n`. With `frobenius_lipschitz`, the upper
    /// bound `||x_aug||_F² / n` is used instead. Otherwise `σ_max` is estimated and
    /// inflated by 1.5, presumably as a safety margin against under-estimation.
    fn estimate_lipschitz(&self, x_aug: &Array2<f64>) -> f64 {
        let n = x_aug.nrows() as f64;
        if self.params.frobenius_lipschitz {
            // Squared Frobenius norm, computed directly rather than via sqrt then square.
            return x_aug.iter().map(|v| v * v).sum::<f64>() / n;
        }
        let s_max = find_largest_singular_value(
            x_aug,
            self.params.seed,
            &self.params.subsampling,
            None,
            None,
        );
        1.5 * s_max * s_max / n
    }

    /// Value of the penalty for `coef` (one row per feature, one column per target).
    ///
    /// The penalty is the sum of `reg_g * ||coef[g, t]||₂` over every group and target,
    /// plus `l1_reg * Σ|coef|`.
    ///
    /// # Panics
    /// Panics if called before [`GroupLasso::fit`] (group masks are unknown), or if
    /// `coef` has fewer rows than the number of features seen by `fit`.
    pub fn regulariser(&self, coef: &Array2<f64>) -> f64 {
        let masks = self
            .groups_masks
            .as_ref()
            .expect("GroupLasso::regulariser called before fit");
        let regs = self.build_reg_vector(masks);
        let mut penalty = 0.0_f64;
        for (mask, &reg) in masks.iter().zip(regs.iter()) {
            // Resolve the group's rows once rather than rescanning the mask per target.
            let rows: Vec<usize> = mask
                .iter()
                .enumerate()
                .filter_map(|(i, &member)| member.then_some(i))
                .collect();
            for col in 0..coef.ncols() {
                let norm = rows
                    .iter()
                    .map(|&i| coef[[i, col]].powi(2))
                    .sum::<f64>()
                    .sqrt();
                penalty += reg * norm;
            }
        }
        let l1: f64 = coef.iter().map(|v| v.abs()).sum();
        penalty + self.params.l1_reg * l1
    }

    /// Mean squared error `0.5 * ||x_aug · w - y||² / n`.
    fn mse_loss(x_aug: &Array2<f64>, y: &Array2<f64>, w: &Array2<f64>) -> f64 {
        let n = x_aug.nrows() as f64;
        let resid = x_aug.dot(w) - y;
        0.5 * resid.iter().map(|v| v * v).sum::<f64>() / n
    }

    /// Gradient of [`GroupLasso::mse_loss`]: `x_augᵀ (x_aug · w - y) / n`.
    fn mse_grad(x_aug: &Array2<f64>, y: &Array2<f64>, w: &Array2<f64>) -> Array2<f64> {
        let n = x_aug.nrows() as f64;
        let resid = x_aug.dot(w) - y;
        x_aug.t().dot(&resid) / n
    }

    /// Stacks the intercept row on top of the coefficient rows.
    fn join(intercept: &Array2<f64>, coef: &Array2<f64>) -> Array2<f64> {
        ndarray::concatenate(Axis(0), &[intercept.view(), coef.view()])
            .expect("intercept and coef must have the same number of columns")
    }

    /// Splits a stacked weight matrix into `(intercept row, coefficient rows)`.
    fn split(w: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
        let intercept = w.slice(s![0..1, ..]).to_owned();
        let coef = w.slice(s![1.., ..]).to_owned();
        (intercept, coef)
    }

    /// Fits the model to `x` (`n × p`) and `y` (`n × num_targets`).
    ///
    /// `lipschitz` overrides the initial step-size constant. Without an override, a
    /// warm-started fit reuses the constant from the previous run, and any other fit
    /// estimates it from the data.
    ///
    /// The optimisation runs on mean-centred features when `fit_intercept` is set. The
    /// stored intercept is afterwards shifted back so `predict` works on raw `x`.
    ///
    /// # Errors
    /// - [`GroupLassoError::ShapeMismatch`] if `x` and `y` differ in row count, are
    ///   empty, or `params.groups` has the wrong length.
    /// - [`GroupLassoError::InvalidParam`] for a non-positive or non-finite `lipschitz`,
    ///   or a per-group penalty vector of the wrong length.
    /// - [`GroupLassoError::ConvergenceWarning`] if FISTA did not converge. The
    ///   coefficients are still stored in this case.
    pub fn fit(
        &mut self,
        x: &Array2<f64>,
        y: &Array2<f64>,
        lipschitz: Option<f64>,
    ) -> Result<bool, GroupLassoError> {
        let (n, p) = x.dim();
        if n != y.nrows() {
            return Err(GroupLassoError::ShapeMismatch(
                "X and y have different numbers of rows".into(),
            ));
        }
        if n == 0 {
            return Err(GroupLassoError::ShapeMismatch(
                "X and y must contain at least one row".into(),
            ));
        }
        if let Some(l) = lipschitz {
            if !(l.is_finite() && l > 0.0) {
                return Err(GroupLassoError::InvalidParam(format!(
                    "lipschitz must be finite and positive, got {}",
                    l
                )));
            }
        }
        self.build_groups(p)?;
        let masks = self
            .groups_masks
            .clone()
            .expect("build_groups sets groups_masks on success");
        let regs = self.build_reg_vector(&masks);
        let (x_aug, x_means, y_prep) = self.prepare_dataset(x, y);
        let num_targets = y_prep.ncols();

        // Warm start only from a previous fit with matching shapes. Otherwise start
        // from zeros instead of panicking inside the solver.
        let warm = self.fitted.as_ref().filter(|f| {
            self.params.warm_start
                && f.coef.dim() == (p, num_targets)
                && f.intercept.dim() == (1, num_targets)
        });
        let (init_intercept, init_coef) = match warm {
            // The stored intercept is for uncentred X. FISTA works on centred X, so map
            // it back: b_centred = b + mean(X) · coef.
            Some(f) if self.params.fit_intercept => {
                (&f.intercept + &x_means.dot(&f.coef), f.coef.clone())
            }
            // Without an intercept it must stay at zero (its gradient is masked out).
            Some(f) => (Array2::zeros((1, num_targets)), f.coef.clone()),
            None => (
                Array2::zeros((1, num_targets)),
                Array2::zeros((p, num_targets)),
            ),
        };

        // A previously returned constant belongs to the previous data. Only reuse it
        // when explicitly warm-starting.
        let cached = if self.params.warm_start {
            self.lipschitz
        } else {
            None
        };
        let l0 = lipschitz
            .or(cached)
            .unwrap_or_else(|| self.estimate_lipschitz(&x_aug));

        let w0 = Self::join(&init_intercept, &init_coef);
        let problem = GroupLassoProblem {
            x_aug,
            y: y_prep,
            masks,
            group_regs: regs,
            l1_reg: self.params.l1_reg,
            fit_intercept: self.params.fit_intercept,
        };
        let result = fista_minimise(
            &problem,
            w0,
            l0,
            self.params.n_iter,
            self.params.tol,
            None::<fn(&IterInfo)>,
        );
        self.last_fista_iterations = result.iterations;
        self.lipschitz = Some(result.lipschitz);

        let (centred_intercept, coef) = Self::split(&result.coef);
        let intercept = centred_intercept - x_means.dot(&coef);
        self.fitted = Some(FittedCoefficients { coef, intercept });
        self.x_means = Some(x_means);

        if result.converged {
            Ok(true)
        } else {
            Err(GroupLassoError::ConvergenceWarning)
        }
    }

    /// Predicts `x · coef + intercept` for every row of `x`.
    ///
    /// # Errors
    /// - [`GroupLassoError::NotFitted`] before a fit.
    /// - [`GroupLassoError::ShapeMismatch`] if `x` has a different number of columns
    ///   than the training data.
    pub fn predict(&self, x: &Array2<f64>) -> Result<Array2<f64>, GroupLassoError> {
        let fitted = self.fitted.as_ref().ok_or(GroupLassoError::NotFitted)?;
        let p = fitted.coef.nrows();
        if x.ncols() != p {
            return Err(GroupLassoError::ShapeMismatch(format!(
                "X has {} features but model was fitted with {}",
                x.ncols(),
                p
            )));
        }
        // The (1, t) intercept broadcasts over all rows. This avoids materialising [1 | X].
        let mut out = x.dot(&fitted.coef);
        out += &fitted.intercept;
        Ok(out)
    }

    /// Marks which features are used for prediction.
    ///
    /// A feature is active when its mean *absolute* coefficient across targets exceeds
    /// `1e-10` times the mean absolute coefficient of the whole matrix. Absolute values
    /// are averaged so that opposite-signed weights for different targets cannot cancel.
    ///
    /// # Errors
    /// [`GroupLassoError::NotFitted`] before a fit.
    pub fn sparsity_mask(&self) -> Result<Array1<bool>, GroupLassoError> {
        let fitted = self.fitted.as_ref().ok_or(GroupLassoError::NotFitted)?;
        let abs_coef = fitted.coef.mapv(f64::abs);
        let threshold = 1e-10 * abs_coef.mean().unwrap_or(0.0);
        let per_feature = abs_coef
            .mean_axis(Axis(1))
            .unwrap_or_else(|| Array1::zeros(abs_coef.nrows()));
        Ok(per_feature.mapv(|v| v > threshold))
    }

    /// Returns the ids of groups (non-negative ids only) with at least one active feature.
    ///
    /// # Errors
    /// [`GroupLassoError::NotFitted`] before a fit.
    pub fn chosen_groups(&self) -> Result<HashSet<i64>, GroupLassoError> {
        let mask = self.sparsity_mask()?;
        let feature_ids = self
            .feature_group_ids
            .as_ref()
            .ok_or(GroupLassoError::NotFitted)?;
        Ok(mask
            .iter()
            .zip(feature_ids.iter())
            .filter_map(|(&active, &g)| (active && g >= 0).then_some(g))
            .collect())
    }

    /// Keeps only the columns of `x` that [`GroupLasso::sparsity_mask`] marks as active.
    ///
    /// # Errors
    /// - [`GroupLassoError::NotFitted`] before a fit.
    /// - [`GroupLassoError::ShapeMismatch`] if `x` has a different number of columns
    ///   than the training data.
    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, GroupLassoError> {
        let mask = self.sparsity_mask()?;
        if x.ncols() != mask.len() {
            return Err(GroupLassoError::ShapeMismatch(format!(
                "X has {} features but model was fitted with {}",
                x.ncols(),
                mask.len()
            )));
        }
        let cols: Vec<usize> = mask
            .iter()
            .enumerate()
            .filter_map(|(i, &keep)| keep.then_some(i))
            .collect();
        Ok(x.select(Axis(1), &cols))
    }
}
/// Least-squares loss plus the L1 / group-L2 proximal operator, packaged for the
/// FISTA solver.
struct GroupLassoProblem {
    /// Design matrix whose first column is all ones (intercept), followed by the
    /// (possibly centred) features.
    x_aug: Array2<f64>,
    /// Targets, one column per output.
    y: Array2<f64>,
    /// One boolean feature mask per regularised group.
    masks: Vec<Vec<bool>>,
    /// Penalty for each mask, before division by the Lipschitz constant in `prox`.
    group_regs: Vec<f64>,
    /// L1 penalty, before division by the Lipschitz constant in `prox`.
    l1_reg: f64,
    /// When false, the intercept row of the gradient is zeroed so gradient steps do
    /// not move it.
    fit_intercept: bool,
}
impl FistaProblem for GroupLassoProblem {
    /// Smooth part of the objective: `0.5 * ||x_aug · w - y||² / n`.
    fn smooth_loss(&self, w: &Array2<f64>) -> f64 {
        GroupLasso::mse_loss(&self.x_aug, &self.y, w)
    }

    /// Gradient of the smooth loss. Without an intercept, row 0 is zeroed so that
    /// gradient steps leave the intercept untouched.
    fn smooth_grad(&self, w: &Array2<f64>) -> Array2<f64> {
        let mut grad = GroupLasso::mse_grad(&self.x_aug, &self.y, w);
        if self.fit_intercept {
            return grad;
        }
        grad.row_mut(0).fill(0.0);
        grad
    }

    /// Proximal step for step size `1 / lipschitz`.
    ///
    /// The intercept row passes through unchanged. The coefficient rows go through the
    /// combined L1 / group-L2 shrinkage with penalties divided by `lipschitz`.
    fn prox(&self, w: &Array2<f64>, lipschitz: f64) -> Array2<f64> {
        let (intercept_row, coef_rows) = GroupLasso::split(w);
        let l1_threshold = self.l1_reg / lipschitz;
        let group_thresholds: Vec<f64> = self
            .group_regs
            .iter()
            .map(|&reg| reg / lipschitz)
            .collect();
        let shrunk = l1_l2_prox(&coef_rows, l1_threshold, &group_thresholds, &self.masks);
        GroupLasso::join(&intercept_row, &shrunk)
    }
}
/// One independent `GroupLasso` per cluster label, all sharing the same hyper-parameters.
#[derive(Debug, Clone)]
pub struct ClusteredGroupLasso {
    /// Hyper-parameters given to every per-cluster model.
    pub params: GroupLassoParams,
    /// Model for each cluster label seen by `fit`.
    pub models: HashMap<i64, GroupLasso>,
}
impl ClusteredGroupLasso {
    /// Creates an unfitted clustered model; no per-cluster models exist until `fit`.
    pub fn new(params: GroupLassoParams) -> Self {
        Self {
            params,
            models: HashMap::new(),
        }
    }

    /// Maps each cluster label to the row indices carrying it, in ascending row order.
    ///
    /// Built in one pass over `clusters` instead of one rescan per distinct label.
    fn rows_by_cluster(clusters: &Array1<i64>) -> HashMap<i64, Vec<usize>> {
        let mut rows: HashMap<i64, Vec<usize>> = HashMap::new();
        for (row, &label) in clusters.iter().enumerate() {
            rows.entry(label).or_default().push(row);
        }
        rows
    }

    /// Fits one `GroupLasso` per distinct cluster label, in parallel.
    ///
    /// On success (including a convergence warning), `models` is *replaced* by the new
    /// per-cluster models, so clusters absent from this call are dropped. When
    /// `params.warm_start` is set, a cluster that already has a model starts from that
    /// model's coefficients.
    ///
    /// # Errors
    /// - [`GroupLassoError::ShapeMismatch`] if `x`, `y` and `clusters` differ in length.
    /// - Any hard error from a per-cluster fit. In that case `models` is left unchanged.
    /// - [`GroupLassoError::ConvergenceWarning`] if any cluster failed to converge. The
    ///   models are still stored in this case.
    pub fn fit(
        &mut self,
        x: &Array2<f64>,
        y: &Array2<f64>,
        clusters: &Array1<i64>,
    ) -> Result<bool, GroupLassoError> {
        if x.nrows() != y.nrows() || x.nrows() != clusters.len() {
            return Err(GroupLassoError::ShapeMismatch(
                "X, y, and clusters must have the same number of rows".into(),
            ));
        }
        let params = &self.params;
        let previous = &self.models;
        let jobs: Vec<(i64, GroupLasso, Array2<f64>, Array2<f64>)> =
            Self::rows_by_cluster(clusters)
                .into_iter()
                .map(|(id, rows)| {
                    let model = match previous.get(&id) {
                        Some(prev) if params.warm_start => {
                            let mut warm = prev.clone();
                            warm.params = params.clone();
                            warm
                        }
                        _ => GroupLasso::new(params.clone()),
                    };
                    (id, model, x.select(Axis(0), &rows), y.select(Axis(0), &rows))
                })
                .collect();

        let results: Vec<(i64, GroupLasso, Result<bool, GroupLassoError>)> = jobs
            .into_par_iter()
            .map(|(id, mut model, x_c, y_c)| {
                let res = model.fit(&x_c, &y_c, None);
                (id, model, res)
            })
            .collect();

        // Build the new map off to the side so a hard error leaves `self.models` intact.
        let mut models = HashMap::with_capacity(results.len());
        let mut all_converged = true;
        for (id, model, res) in results {
            match res {
                Ok(_) => {}
                Err(GroupLassoError::ConvergenceWarning) => all_converged = false,
                Err(e) => return Err(e),
            }
            models.insert(id, model);
        }
        self.models = models;

        if all_converged {
            Ok(true)
        } else {
            Err(GroupLassoError::ConvergenceWarning)
        }
    }

    /// Predicts every row with the model of its cluster.
    ///
    /// # Errors
    /// - [`GroupLassoError::ShapeMismatch`] if `x` and `clusters` differ in length, or a
    ///   per-cluster model rejects the feature count.
    /// - [`GroupLassoError::NotFitted`] if a cluster label has no fitted model.
    pub fn predict(
        &self,
        x: &Array2<f64>,
        clusters: &Array1<i64>,
    ) -> Result<Array2<f64>, GroupLassoError> {
        if x.nrows() != clusters.len() {
            return Err(GroupLassoError::ShapeMismatch(
                "X and clusters must have the same number of rows".into(),
            ));
        }
        let num_targets = self
            .models
            .values()
            .next()
            .and_then(|m| m.fitted.as_ref())
            .map(|f| f.coef.ncols())
            .unwrap_or(1);
        let jobs: Vec<(i64, Vec<usize>)> = Self::rows_by_cluster(clusters).into_iter().collect();
        let results: Vec<Result<(Vec<usize>, Array2<f64>), GroupLassoError>> = jobs
            .into_par_iter()
            .map(|(id, rows)| {
                let model = self.models.get(&id).ok_or(GroupLassoError::NotFitted)?;
                let pred = model.predict(&x.select(Axis(0), &rows))?;
                Ok((rows, pred))
            })
            .collect();

        let mut preds = Array2::zeros((x.nrows(), num_targets));
        for res in results {
            let (rows, pred) = res?;
            for (local, &global) in rows.iter().enumerate() {
                preds.row_mut(global).assign(&pred.row(local));
            }
        }
        Ok(preds)
    }

    /// Returns `(coef, intercept)` for every cluster whose model has been fitted.
    pub fn coefficients(&self) -> HashMap<i64, (Array2<f64>, Array2<f64>)> {
        self.models
            .iter()
            .filter_map(|(&id, model)| {
                model
                    .fitted
                    .as_ref()
                    .map(|f| (id, (f.coef.clone(), f.intercept.clone())))
            })
            .collect()
    }

    /// Returns each cluster's feature sparsity mask.
    ///
    /// # Errors
    /// The first error from a per-cluster `sparsity_mask` call.
    pub fn sparsity_mask(&self) -> Result<HashMap<i64, Array1<bool>>, GroupLassoError> {
        self.models
            .iter()
            .map(|(&id, model)| model.sparsity_mask().map(|mask| (id, mask)))
            .collect()
    }

    /// Returns each cluster's set of chosen group ids.
    ///
    /// # Errors
    /// The first error from a per-cluster `chosen_groups` call.
    pub fn chosen_groups(&self) -> Result<HashMap<i64, HashSet<i64>>, GroupLassoError> {
        self.models
            .iter()
            .map(|(&id, model)| model.chosen_groups().map(|groups| (id, groups)))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// 50 rows, two features with `x0 + x1 == 1` (up to rounding), `y = 2·x0 + 3·x1`.
    fn simple_xy() -> (Array2<f64>, Array2<f64>) {
        let n = 50_usize;
        let x = Array2::from_shape_fn((n, 2), |(i, j)| {
            if j == 0 {
                i as f64 / n as f64
            } else {
                (n - i) as f64 / n as f64
            }
        });
        let y = Array2::from_shape_fn((n, 1), |(i, _)| 2.0 * x[[i, 0]] + 3.0 * x[[i, 1]]);
        (x, y)
    }

    /// 80 rows, three features, `y = 5 + 2·x0 + 3·x1 + x2`.
    fn xy_with_intercept() -> (Array2<f64>, Array2<f64>) {
        let n = 80_usize;
        let x = Array2::from_shape_fn((n, 3), |(i, j)| match j {
            0 => (i as f64) / n as f64,
            1 => ((n - i) as f64) / n as f64,
            _ => 0.5 * (i as f64) / n as f64,
        });
        let y = Array2::from_shape_fn((n, 1), |(i, _)| {
            5.0 + 2.0 * x[[i, 0]] + 3.0 * x[[i, 1]] + 1.0 * x[[i, 2]]
        });
        (x, y)
    }

    /// 100 rows, five pseudo-random features; only features 0 and 1 drive `y`.
    fn sparse_xy() -> (Array2<f64>, Array2<f64>) {
        let n = 100_usize;
        let x = Array2::from_shape_fn((n, 5), |(i, j)| ((i * 7 + j * 13) % 97) as f64 / 97.0);
        let y = Array2::from_shape_fn((n, 1), |(i, _)| 2.0 * x[[i, 0]] - 1.5 * x[[i, 1]]);
        (x, y)
    }

    /// Coefficient of determination of `pred` against `y`.
    fn r_squared(y: &Array2<f64>, pred: &Array2<f64>) -> f64 {
        let y_mean = y.mean().unwrap();
        let ss_tot: f64 = y.iter().map(|v| (v - y_mean).powi(2)).sum();
        let ss_res: f64 = y
            .iter()
            .zip(pred.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum();
        1.0 - ss_res / ss_tot
    }

    /// Cluster labels for `n` rows: rows `0..25` are cluster 0, the rest cluster 1.
    fn two_cluster_labels(n: usize) -> Array1<i64> {
        Array1::from_shape_fn(n, |i| if i < 25 { 0 } else { 1 })
    }

    /// Hyper-parameters shared by the clustered-model tests.
    fn clustered_params() -> GroupLassoParams {
        GroupLassoParams {
            groups: vec![0, 0],
            group_reg: 0.001,
            l1_reg: 0.001,
            n_iter: 200,
            ..Default::default()
        }
    }

    #[test]
    fn fit_predict_shape() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 0],
            group_reg: 0.001,
            l1_reg: 0.001,
            n_iter: 200,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let pred = model.predict(&x).unwrap();
        assert_eq!(pred.shape(), &[50, 1]);
    }

    #[test]
    fn unregularised_group_recovers_coefficients() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1],
            group_reg: 1e-6,
            l1_reg: 1e-6,
            n_iter: 500,
            tol: 1e-8,
            fit_intercept: false,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let coef = &model.fitted.as_ref().unwrap().coef;
        assert_abs_diff_eq!(coef[[0, 0]], 2.0, epsilon = 0.2);
        assert_abs_diff_eq!(coef[[1, 0]], 3.0, epsilon = 0.2);
    }

    #[test]
    fn high_group_reg_drives_coef_to_zero() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 0],
            group_reg: 100.0,
            l1_reg: 0.0,
            n_iter: 300,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let coef = &model.fitted.as_ref().unwrap().coef;
        for &v in coef.iter() {
            assert_abs_diff_eq!(v, 0.0, epsilon = 1e-3);
        }
    }

    #[test]
    fn high_l1_reg_drives_coef_to_zero() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1],
            group_reg: 0.0,
            l1_reg: 100.0,
            n_iter: 300,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let coef = &model.fitted.as_ref().unwrap().coef;
        for &v in coef.iter() {
            assert_abs_diff_eq!(v, 0.0, epsilon = 1e-3);
        }
    }

    #[test]
    fn intercept_recovery() {
        let (x, y) = xy_with_intercept();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1, 2],
            group_reg: 1e-8,
            l1_reg: 1e-8,
            n_iter: 5000,
            tol: 1e-12,
            fit_intercept: true,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let pred = model.predict(&x).unwrap();
        let r2 = r_squared(&y, &pred);
        assert!(
            r2 > 0.99,
            "R² should be very high with intercept, got {}",
            r2
        );
    }

    #[test]
    fn no_intercept_zero() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1],
            group_reg: 1e-6,
            l1_reg: 1e-6,
            n_iter: 500,
            fit_intercept: false,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let fitted = model.fitted.as_ref().unwrap();
        assert_abs_diff_eq!(fitted.intercept[[0, 0]], 0.0, epsilon = 1e-6);
    }

    #[test]
    fn sparsity_mask_identifies_active_features() {
        let (x, y) = sparse_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1, 2, 3, 4],
            group_reg: 0.01,
            l1_reg: 0.01,
            n_iter: 1000,
            tol: 1e-10,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let mask = model.sparsity_mask().unwrap();
        let active_count = mask.iter().filter(|&&b| b).count();
        assert!(active_count >= 1, "At least one feature should be active");
        assert!(active_count <= 5, "Active count should be ≤ total features");
    }

    #[test]
    fn chosen_groups_subset() {
        let (x, y) = sparse_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 0, 1, 1, 2],
            group_reg: 0.01,
            l1_reg: 0.01,
            n_iter: 1000,
            tol: 1e-10,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let groups = model.chosen_groups().unwrap();
        assert!(!groups.is_empty(), "At least one group should be chosen");
        assert!(groups.len() <= 3, "At most 3 groups should be chosen");
    }

    #[test]
    fn transform_reduces_columns() {
        let (x, y) = sparse_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1, 2, 3, 4],
            group_reg: 0.01,
            l1_reg: 0.01,
            n_iter: 1000,
            tol: 1e-10,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let x_t = model.transform(&x).unwrap();
        assert!(
            x_t.ncols() <= x.ncols(),
            "Transform should reduce or maintain columns"
        );
        assert!(x_t.ncols() > 0, "At least one feature should survive");
        assert_eq!(x_t.nrows(), x.nrows(), "Row count should be preserved");
    }

    #[test]
    fn predict_before_fit_errors() {
        let model = GroupLasso::new(GroupLassoParams::default());
        let x = Array2::zeros((5, 3));
        assert!(model.predict(&x).is_err());
    }

    #[test]
    fn shape_mismatch_errors() {
        let x = Array2::zeros((10, 3));
        let y = Array2::zeros((5, 1));
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1, 2],
            ..Default::default()
        });
        assert!(model.fit(&x, &y, None).is_err());
    }

    #[test]
    fn groups_length_mismatch_errors() {
        let x = Array2::zeros((10, 3));
        let y = Array2::zeros((10, 1));
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1],
            ..Default::default()
        });
        assert!(model.fit(&x, &y, None).is_err());
    }

    #[test]
    fn warm_start_reuses_coefficients() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 0],
            group_reg: 0.001,
            l1_reg: 0.001,
            n_iter: 100,
            warm_start: true,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let coef1 = model.fitted.as_ref().unwrap().coef.clone();
        let _ = model.fit(&x, &y, None);
        let coef2 = model.fitted.as_ref().unwrap().coef.clone();
        for i in 0..coef1.nrows() {
            assert_abs_diff_eq!(coef1[[i, 0]], coef2[[i, 0]], epsilon = 0.1);
        }
    }

    #[test]
    fn scale_reg_group_size_increases_penalty() {
        let (x, y) = simple_xy();
        let mut m1 = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 0],
            group_reg: 0.1,
            l1_reg: 0.0,
            scale_reg: ScaleReg::GroupSize,
            n_iter: 300,
            ..Default::default()
        });
        let _ = m1.fit(&x, &y, None);
        let mut m2 = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 0],
            group_reg: 0.1,
            l1_reg: 0.0,
            scale_reg: ScaleReg::None,
            n_iter: 300,
            ..Default::default()
        });
        let _ = m2.fit(&x, &y, None);
        let norm1: f64 = m1.fitted.as_ref().unwrap().coef.iter().map(|v| v * v).sum();
        let norm2: f64 = m2.fitted.as_ref().unwrap().coef.iter().map(|v| v * v).sum();
        assert!(
            norm1 <= norm2 + 1e-6,
            "GroupSize scaling (larger penalty) should shrink coefficients at least as much as unscaled"
        );
    }

    #[test]
    fn regulariser_nonnegative() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1],
            group_reg: 0.1,
            l1_reg: 0.1,
            n_iter: 200,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let coef = &model.fitted.as_ref().unwrap().coef;
        let pen = model.regulariser(coef);
        assert!(pen >= 0.0, "Regulariser must be non-negative");
    }

    #[test]
    fn regulariser_zero_for_zero_coefs() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1],
            group_reg: 100.0,
            l1_reg: 100.0,
            n_iter: 300,
            ..Default::default()
        });
        // Fitting is needed so the model knows its group masks.
        let _ = model.fit(&x, &y, None);
        let zero_coef = Array2::zeros((2, 1));
        let pen = model.regulariser(&zero_coef);
        assert_abs_diff_eq!(pen, 0.0, epsilon = 1e-15);
    }

    #[test]
    fn r2_positive_for_good_fit() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1],
            group_reg: 1e-6,
            l1_reg: 1e-6,
            n_iter: 500,
            tol: 1e-8,
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let pred = model.predict(&x).unwrap();
        let r2 = r_squared(&y, &pred);
        assert!(r2 > 0.9, "R² should be high for noiseless data, got {}", r2);
    }

    #[test]
    fn frobenius_lipschitz_mode() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 1],
            group_reg: 0.01,
            l1_reg: 0.01,
            frobenius_lipschitz: true,
            n_iter: 300,
            ..Default::default()
        });
        let result = model.fit(&x, &y, None);
        assert!(result.is_ok() || matches!(result, Err(GroupLassoError::ConvergenceWarning)));
        assert!(model.fitted.is_some());
    }

    #[test]
    fn new_with_regs_per_group() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new_with_regs(
            GroupLassoParams {
                groups: vec![0, 1],
                group_reg: 0.0,
                l1_reg: 0.0,
                n_iter: 300,
                ..Default::default()
            },
            vec![0.001, 100.0],
        );
        let _ = model.fit(&x, &y, None);
        let coef = &model.fitted.as_ref().unwrap().coef;
        assert!(
            coef[[0, 0]].abs() > coef[[1, 0]].abs(),
            "Feature 0 (low reg) should be larger than feature 1 (high reg)"
        );
    }

    #[test]
    fn predict_shape_mismatch_errors() {
        let (x, y) = simple_xy();
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![0, 0],
            ..Default::default()
        });
        let _ = model.fit(&x, &y, None);
        let bad_x = Array2::zeros((5, 5));
        assert!(model.predict(&bad_x).is_err());
    }

    #[test]
    fn clustered_fit_predict() {
        let (x, y) = simple_xy();
        let n = x.nrows();
        let clusters = two_cluster_labels(n);
        let mut model = ClusteredGroupLasso::new(clustered_params());
        model.fit(&x, &y, &clusters).unwrap();
        assert_eq!(model.models.len(), 2);
        let pred = model.predict(&x, &clusters).unwrap();
        assert_eq!(pred.shape(), &[n, 1]);
        for i in 0..n {
            assert_abs_diff_eq!(pred[[i, 0]], y[[i, 0]], epsilon = 0.5);
        }
    }

    #[test]
    fn clustered_coefficients_per_cluster() {
        let (x, y) = simple_xy();
        let clusters = two_cluster_labels(x.nrows());
        let mut model = ClusteredGroupLasso::new(clustered_params());
        model.fit(&x, &y, &clusters).unwrap();
        let coeffs = model.coefficients();
        assert!(coeffs.contains_key(&0));
        assert!(coeffs.contains_key(&1));
        let (coef0, int0) = &coeffs[&0];
        assert_eq!(coef0.nrows(), 2);
        assert_eq!(int0.ncols(), 1);
    }

    #[test]
    fn clustered_shape_mismatch_errors() {
        let x = Array2::zeros((10, 2));
        let y = Array2::zeros((10, 1));
        let clusters = Array1::zeros(5);
        let mut model = ClusteredGroupLasso::new(GroupLassoParams::default());
        assert!(model.fit(&x, &y, &clusters).is_err());
    }

    #[test]
    fn clustered_sparsity_mask() {
        let (x, y) = simple_xy();
        let clusters = two_cluster_labels(x.nrows());
        let mut model = ClusteredGroupLasso::new(clustered_params());
        model.fit(&x, &y, &clusters).unwrap();
        let masks = model.sparsity_mask().unwrap();
        assert_eq!(masks.len(), 2);
    }

    #[test]
    fn default_groups_one_per_feature() {
        let x = Array2::from_shape_fn((30, 3), |(i, j)| (i + j) as f64 / 30.0);
        let y = Array2::from_shape_fn((30, 1), |(i, _)| x[[i, 0]] + x[[i, 1]]);
        let mut model = GroupLasso::new(GroupLassoParams {
            groups: vec![],
            group_reg: 0.01,
            l1_reg: 0.01,
            n_iter: 200,
            ..Default::default()
        });
        let result = model.fit(&x, &y, None);
        assert!(result.is_ok() || matches!(result, Err(GroupLassoError::ConvergenceWarning)));
        assert!(model.fitted.is_some());
    }
}