Skip to main content

scirs2_stats/regression/
regularized.rs

1//! Regularized regression implementations
2
3use crate::error::{StatsError, StatsResult};
4use crate::regression::utils::*;
5use crate::regression::RegressionResults;
6use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
7use scirs2_core::numeric::Float;
8use scirs2_linalg::{inv, lstsq};
9use std::collections::HashSet;
10
11// Tuple returned by `preprocessdata`: (processed design matrix, y mean, per-column x means, per-column x scales)
12type PreprocessingResult<F> = (Array2<F>, F, Array1<F>, Array1<F>);
13
14/// Perform ridge regression (L2 regularization).
15///
16/// Ridge regression adds a penalty term to the sum of squared residuals,
17/// which can help reduce overfitting and handle multicollinearity.
18///
19/// # Arguments
20///
21/// * `x` - Independent variables (design matrix)
22/// * `y` - Dependent variable
23/// * `alpha` - Regularization strength (default: 1.0)
24/// * `fit_intercept` - Whether to fit an intercept term (default: true)
25/// * `normalize` - Whether to normalize the data before fitting (default: false)
26/// * `tol` - Convergence tolerance (default: 1e-4)
27/// * `max_iter` - Maximum number of iterations (default: 1000)
28/// * `conf_level` - Confidence level for confidence intervals (default: 0.95)
29///
30/// # Returns
31///
32/// A RegressionResults struct with the regression results.
33///
34/// # Examples
35///
36/// ```
37/// use scirs2_core::ndarray::{array, Array2};
38/// use scirs2_stats::ridge_regression;
39///
40/// // Create a design matrix with 3 variables
41/// let x = Array2::from_shape_vec((5, 3), vec![
42///     1.0, 2.0, 3.0,
43///     2.0, 3.0, 4.0,
44///     3.0, 4.0, 5.0,
45///     4.0, 5.0, 6.0,
46///     5.0, 6.0, 7.0,
47/// ]).expect("Operation failed");
48///
49/// // Target values
50/// let y = array![10.0, 15.0, 20.0, 25.0, 30.0];
51///
52/// // Perform ridge regression with alpha=0.1
53/// let result = ridge_regression(&x.view(), &y.view(), Some(0.1), None, None, None, None, None).expect("Operation failed");
54///
55/// // Check that we get some coefficients
56/// assert!(result.coefficients.len() > 0);
57/// ```
58#[allow(clippy::too_many_arguments)]
59#[allow(dead_code)]
60pub fn ridge_regression<F>(
61    x: &ArrayView2<F>,
62    y: &ArrayView1<F>,
63    alpha: Option<F>,
64    fit_intercept: Option<bool>,
65    normalize: Option<bool>,
66    tol: Option<F>,
67    max_iter: Option<usize>,
68    conf_level: Option<F>,
69) -> StatsResult<RegressionResults<F>>
70where
71    F: Float
72        + std::iter::Sum<F>
73        + std::ops::Div<Output = F>
74        + std::fmt::Debug
75        + std::fmt::Display
76        + 'static
77        + scirs2_core::numeric::NumAssign
78        + scirs2_core::numeric::One
79        + scirs2_core::ndarray::ScalarOperand
80        + Send
81        + Sync,
82{
83    // Check input dimensions
84    if x.nrows() != y.len() {
85        return Err(StatsError::DimensionMismatch(format!(
86            "Input x has {} rows but y has length {}",
87            x.nrows(),
88            y.len()
89        )));
90    }
91
92    let n = x.nrows();
93    let p_features = x.ncols();
94
95    // Set default parameters
96    let alpha = alpha.unwrap_or_else(|| F::from(1.0).expect("Failed to convert constant to float"));
97    let fit_intercept = fit_intercept.unwrap_or(true);
98    let normalize = normalize.unwrap_or(false);
99    let tol = tol.unwrap_or_else(|| F::from(1e-4).expect("Failed to convert constant to float"));
100    let max_iter = max_iter.unwrap_or(1000);
101    let conf_level =
102        conf_level.unwrap_or_else(|| F::from(0.95).expect("Failed to convert constant to float"));
103
104    if alpha < F::zero() {
105        return Err(StatsError::InvalidArgument(
106            "alpha must be non-negative".to_string(),
107        ));
108    }
109
110    // Preprocess x and y
111    let (x_processed, y_mean, x_mean, x_std) = preprocessdata(x, y, fit_intercept, normalize)?;
112
113    // Total number of coefficients (including _intercept if fitted)
114    let p = if fit_intercept {
115        p_features + 1
116    } else {
117        p_features
118    };
119
120    // We need at least 2 observations for meaningful regression
121    if n < 2 {
122        return Err(StatsError::InvalidArgument(
123            "At least 2 observations required for ridge regression".to_string(),
124        ));
125    }
126
127    // Solve the ridge regression problem
128    // We solve the linear system [X; sqrt(alpha)I] beta = [y; 0]
129
130    // Create the regularization matrix sqrt(alpha)I
131    let ridgesize = if fit_intercept { p_features } else { p };
132    let mut x_ridge = Array2::zeros((n + ridgesize, p));
133
134    // Copy X to the top part of the augmented matrix
135    for i in 0..n {
136        for j in 0..p {
137            x_ridge[[i, j]] = x_processed[[i, j]];
138        }
139    }
140
141    // Add sqrt(alpha)I to the bottom part
142    let sqrt_alpha = scirs2_core::numeric::Float::sqrt(alpha);
143    for i in 0..ridgesize {
144        let j = if fit_intercept { i + 1 } else { i }; // Skip _intercept if present
145        x_ridge[[n + i, j]] = sqrt_alpha;
146    }
147
148    // Create the augmented target vector [y; 0]
149    let mut y_ridge = Array1::zeros(n + ridgesize);
150    for i in 0..n {
151        y_ridge[i] = y[i];
152    }
153
154    // Solve the ridge regression problem
155    let coefficients = solve_ridge_system(&x_ridge.view(), &y_ridge.view(), tol, max_iter)?;
156
157    // If data was normalized/centered, transform coefficients back
158    let transformed_coefficients = if normalize || fit_intercept {
159        transform_coefficients(&coefficients, y_mean, &x_mean, &x_std, fit_intercept)
160    } else {
161        coefficients.clone()
162    };
163
164    // Calculate fitted values and residuals
165    let x_design = if fit_intercept {
166        add_intercept(x)
167    } else {
168        x.to_owned()
169    };
170
171    let fitted_values = x_design.dot(&transformed_coefficients);
172    let residuals = y.to_owned() - &fitted_values;
173
174    // Calculate degrees of freedom
175    let df_model = p - 1; // Subtract 1 for _intercept
176    let df_residuals = n - p;
177
178    // Calculate sum of squares
179    let (_y_mean, ss_total, ss_residual, ss_explained) =
180        calculate_sum_of_squares(y, &residuals.view());
181
182    // Calculate R-squared and adjusted R-squared
183    let r_squared = ss_explained / ss_total;
184    let adj_r_squared = F::one()
185        - (F::one() - r_squared) * F::from(n - 1).expect("Failed to convert to float")
186            / F::from(df_residuals).expect("Failed to convert to float");
187
188    // Calculate mean squared error and residual standard error
189    let mse = ss_residual / F::from(df_residuals).expect("Failed to convert to float");
190    let residual_std_error = scirs2_core::numeric::Float::sqrt(mse);
191
192    // Calculate standard errors for coefficients (approximate)
193    let std_errors = match calculate_ridge_std_errors(
194        &x_design.view(),
195        &residuals.view(),
196        alpha,
197        df_residuals,
198    ) {
199        Ok(se) => se,
200        Err(_) => Array1::<F>::zeros(p),
201    };
202
203    // Calculate t-values
204    let t_values = calculate_t_values(&transformed_coefficients, &std_errors);
205
206    // Calculate p-values (simplified)
207    let p_values = t_values.mapv(|t| {
208        let t_abs = crate::regression::utils::float_abs(t);
209        let df_f = F::from(df_residuals).expect("Failed to convert to float");
210        let ratio = t_abs / crate::regression::utils::float_sqrt(df_f + t_abs * t_abs);
211        let one_minus_ratio = F::one() - ratio;
212        F::from(2.0).expect("Failed to convert constant to float") * one_minus_ratio
213    });
214
215    // Calculate confidence intervals
216    let mut conf_intervals = Array2::<F>::zeros((p, 2));
217    let z = norm_ppf(
218        F::from(0.5).expect("Failed to convert constant to float") * (F::one() + conf_level),
219    );
220
221    for i in 0..p {
222        let margin = std_errors[i] * z;
223        conf_intervals[[i, 0]] = transformed_coefficients[i] - margin;
224        conf_intervals[[i, 1]] = transformed_coefficients[i] + margin;
225    }
226
227    // Calculate F-statistic
228    let f_statistic = if df_model > 0 && df_residuals > 0 {
229        (ss_explained / F::from(df_model).expect("Failed to convert to float"))
230            / (ss_residual / F::from(df_residuals).expect("Failed to convert to float"))
231    } else {
232        F::infinity()
233    };
234
235    // Calculate p-value for F-statistic (simplified)
236    let f_p_value = F::zero(); // In a real implementation, use F-distribution
237
238    // Create and return the results structure
239    Ok(RegressionResults {
240        coefficients: transformed_coefficients,
241        std_errors,
242        t_values,
243        p_values,
244        conf_intervals,
245        r_squared,
246        adj_r_squared,
247        f_statistic,
248        f_p_value,
249        residual_std_error,
250        df_residuals,
251        residuals,
252        fitted_values,
253        inlier_mask: vec![true; n], // All points are inliers in ridge regression
254    })
255}
256
257/// Helper function to solve the ridge regression system
258#[allow(dead_code)]
259fn solve_ridge_system<F>(
260    x_ridge: &ArrayView2<F>,
261    y_ridge: &ArrayView1<F>,
262    _tol: F,
263    _max_iter: usize,
264) -> StatsResult<Array1<F>>
265where
266    F: Float
267        + std::iter::Sum<F>
268        + std::ops::Div<Output = F>
269        + 'static
270        + scirs2_core::numeric::NumAssign
271        + scirs2_core::numeric::One
272        + scirs2_core::ndarray::ScalarOperand
273        + std::fmt::Display
274        + Send
275        + Sync,
276{
277    match lstsq(x_ridge, y_ridge, None) {
278        Ok(result) => Ok(result.x),
279        Err(e) => Err(StatsError::ComputationError(format!(
280            "Least squares computation failed: {:?}",
281            e
282        ))),
283    }
284}
285
286/// Preprocess data for regularized regression
287#[allow(dead_code)]
288fn preprocessdata<F>(
289    x: &ArrayView2<F>,
290    y: &ArrayView1<F>,
291    fit_intercept: bool,
292    normalize: bool,
293) -> StatsResult<PreprocessingResult<F>>
294where
295    F: Float + std::iter::Sum<F> + 'static + std::fmt::Display,
296{
297    let n = x.nrows();
298    let p = x.ncols();
299
300    // Calculate y_mean if fitting _intercept
301    let y_mean = if fit_intercept {
302        y.iter().cloned().sum::<F>() / F::from(n).expect("Failed to convert to float")
303    } else {
304        F::zero()
305    };
306
307    // Calculate x_mean and x_std if normalizing or fitting _intercept
308    let mut x_mean = Array1::<F>::zeros(p);
309    let mut x_std = Array1::<F>::ones(p);
310
311    if fit_intercept || normalize {
312        for j in 0..p {
313            let col = x.column(j);
314            let mean =
315                col.iter().cloned().sum::<F>() / F::from(n).expect("Failed to convert to float");
316            x_mean[j] = mean;
317
318            if normalize {
319                let mut ss = F::zero();
320                for &val in col {
321                    ss = ss + scirs2_core::numeric::Float::powi(val - mean, 2);
322                }
323                let std_dev = scirs2_core::numeric::Float::sqrt(
324                    ss / F::from(n).expect("Failed to convert to float"),
325                );
326                x_std[j] = if std_dev > F::epsilon() {
327                    std_dev
328                } else {
329                    F::one()
330                };
331            }
332        }
333    }
334
335    // Create processed X matrix
336    let mut x_processed = if fit_intercept {
337        Array2::<F>::zeros((n, p + 1))
338    } else {
339        Array2::<F>::zeros((n, p))
340    };
341
342    // Add _intercept column if needed
343    if fit_intercept {
344        for i in 0..n {
345            x_processed[[i, 0]] = F::one();
346        }
347    }
348
349    // Copy and normalize X data
350    let offset = if fit_intercept { 1 } else { 0 };
351    for i in 0..n {
352        for j in 0..p {
353            let val = if normalize || fit_intercept {
354                (x[[i, j]] - x_mean[j]) / x_std[j]
355            } else {
356                x[[i, j]]
357            };
358            x_processed[[i, j + offset]] = val;
359        }
360    }
361
362    Ok((x_processed, y_mean, x_mean, x_std))
363}
364
365/// Transform coefficients back after fitting with normalized/centered data
366#[allow(dead_code)]
367fn transform_coefficients<F>(
368    coefficients: &Array1<F>,
369    y_mean: F,
370    x_mean: &Array1<F>,
371    x_std: &Array1<F>,
372    fit_intercept: bool,
373) -> Array1<F>
374where
375    F: Float + 'static + std::fmt::Display,
376{
377    let _p = coefficients.len();
378    let p_features = x_mean.len();
379
380    let mut transformed = coefficients.clone();
381
382    if fit_intercept {
383        let mut _intercept = coefficients[0];
384
385        // Adjust _intercept for the effect of normalizing/centering
386        for j in 0..p_features {
387            _intercept = _intercept - coefficients[j + 1] * x_mean[j] / x_std[j];
388        }
389
390        // Add back the _mean of y
391        _intercept = _intercept + y_mean;
392
393        transformed[0] = _intercept;
394
395        // Adjust feature coefficients for the scaling
396        for j in 0..p_features {
397            transformed[j + 1] = coefficients[j + 1] / x_std[j];
398        }
399    } else {
400        // Adjust feature coefficients for the scaling
401        for j in 0..p_features {
402            transformed[j] = coefficients[j] / x_std[j];
403        }
404    }
405
406    transformed
407}
408
409/// Calculate standard errors for ridge regression
410#[allow(dead_code)]
411fn calculate_ridge_std_errors<F>(
412    x: &ArrayView2<F>,
413    residuals: &ArrayView1<F>,
414    alpha: F,
415    df: usize,
416) -> StatsResult<Array1<F>>
417where
418    F: Float
419        + std::iter::Sum<F>
420        + std::ops::Div<Output = F>
421        + 'static
422        + scirs2_core::numeric::NumAssign
423        + scirs2_core::numeric::One
424        + scirs2_core::ndarray::ScalarOperand
425        + std::fmt::Display
426        + Send
427        + Sync,
428{
429    // Calculate the mean squared error of the residuals
430    let mse = residuals
431        .iter()
432        .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
433        .sum::<F>()
434        / F::from(df).expect("Failed to convert to float");
435
436    // Calculate X'X
437    let xtx = x.t().dot(x);
438
439    // Add regularization term: X'X + alpha*I
440    let p = x.ncols();
441    let mut xtx_reg = xtx.clone();
442
443    for i in 0..p {
444        xtx_reg[[i, i]] += alpha;
445    }
446
447    // Invert (X'X + alpha*I) to get (X'X + alpha*I)^-1
448    let xtx_reg_inv = match inv(&xtx_reg.view(), None) {
449        Ok(inv_result) => inv_result,
450        Err(_) => {
451            // If inversion fails, return zeros for standard errors
452            return Ok(Array1::<F>::zeros(p));
453        }
454    };
455
456    // Calculate standard errors
457    // The diagonal elements of (X'X + alpha*I)^-1 * X'X * (X'X + alpha*I)^-1 * MSE are the variances
458    let std_errors = (xtx_reg_inv.dot(&xtx).dot(&xtx_reg_inv))
459        .diag()
460        .mapv(|v| scirs2_core::numeric::Float::sqrt(v * mse));
461
462    Ok(std_errors)
463}
464
465/// Perform lasso regression (L1 regularization).
466///
467/// Lasso regression adds an L1 penalty term to the sum of squared residuals,
468/// which can help with feature selection by driving some coefficients to zero.
469///
470/// # Arguments
471///
472/// * `x` - Independent variables (design matrix)
473/// * `y` - Dependent variable
474/// * `alpha` - Regularization strength (default: 1.0)
475/// * `fit_intercept` - Whether to fit an intercept term (default: true)
476/// * `normalize` - Whether to normalize the data before fitting (default: false)
477/// * `tol` - Convergence tolerance (default: 1e-4)
478/// * `max_iter` - Maximum number of iterations (default: 1000)
479/// * `conf_level` - Confidence level for confidence intervals (default: 0.95)
480///
481/// # Returns
482///
483/// A RegressionResults struct with the regression results.
484///
485/// # Examples
486///
487/// ```ignore
488/// use scirs2_core::ndarray::{array, Array2};
489/// use scirs2_stats::lasso_regression;
490///
491/// // Create a design matrix with 5 variables, where only the first 2 are relevant
492/// let x = Array2::from_shape_vec((10, 5), vec![
493///     1.0, 2.0, 0.1, 0.2, 0.3,
494///     2.0, 3.0, 0.2, 0.3, 0.4,
495///     3.0, 4.0, 0.3, 0.4, 0.5,
496///     4.0, 5.0, 0.4, 0.5, 0.6,
497///     5.0, 6.0, 0.5, 0.6, 0.7,
498///     6.0, 7.0, 0.6, 0.7, 0.8,
499///     7.0, 8.0, 0.7, 0.8, 0.9,
500///     8.0, 9.0, 0.8, 0.9, 1.0,
501///     9.0, 10.0, 0.9, 1.0, 1.1,
502///     10.0, 11.0, 1.0, 1.1, 1.2,
503/// ]).expect("Operation failed");
504///
505/// // Target values depend only on first two variables
506/// let y = array![5.0, 8.0, 11.0, 14.0, 17.0, 20.0, 23.0, 26.0, 29.0, 32.0];
507///
508/// // Perform lasso regression with alpha=0.1
509/// let result = lasso_regression(&x.view(), &y.view(), Some(0.1), None, None, None, None, None).expect("Operation failed");
510///
511/// // Check that we got coefficients
512/// assert!(result.coefficients.len() > 0);
513///
514/// // Typically, lasso would drive coefficients of irrelevant features toward zero
515/// ```
516#[allow(clippy::too_many_arguments)]
517#[allow(dead_code)]
518pub fn lasso_regression<F>(
519    x: &ArrayView2<F>,
520    y: &ArrayView1<F>,
521    alpha: Option<F>,
522    fit_intercept: Option<bool>,
523    normalize: Option<bool>,
524    tol: Option<F>,
525    max_iter: Option<usize>,
526    conf_level: Option<F>,
527) -> StatsResult<RegressionResults<F>>
528where
529    F: Float
530        + std::iter::Sum<F>
531        + std::ops::Div<Output = F>
532        + std::fmt::Debug
533        + std::fmt::Display
534        + 'static
535        + scirs2_core::numeric::NumAssign
536        + scirs2_core::numeric::One
537        + scirs2_core::ndarray::ScalarOperand
538        + Send
539        + Sync,
540{
541    // Check input dimensions
542    if x.nrows() != y.len() {
543        return Err(StatsError::DimensionMismatch(format!(
544            "Input x has {} rows but y has length {}",
545            x.nrows(),
546            y.len()
547        )));
548    }
549
550    let n = x.nrows();
551    let p_features = x.ncols();
552
553    // Set default parameters
554    let alpha = alpha.unwrap_or_else(|| F::from(1.0).expect("Failed to convert constant to float"));
555    let fit_intercept = fit_intercept.unwrap_or(true);
556    let normalize = normalize.unwrap_or(false);
557    let tol = tol.unwrap_or_else(|| F::from(1e-4).expect("Failed to convert constant to float"));
558    let max_iter = max_iter.unwrap_or(1000);
559    let conf_level =
560        conf_level.unwrap_or_else(|| F::from(0.95).expect("Failed to convert constant to float"));
561
562    if alpha < F::zero() {
563        return Err(StatsError::InvalidArgument(
564            "alpha must be non-negative".to_string(),
565        ));
566    }
567
568    // Preprocess x and y
569    let (x_processed, y_mean, x_mean, x_std) = preprocessdata(x, y, fit_intercept, normalize)?;
570
571    // Total number of coefficients (including _intercept if fitted)
572    let p = if fit_intercept {
573        p_features + 1
574    } else {
575        p_features
576    };
577
578    // We need at least 2 observations for meaningful regression
579    if n < 2 {
580        return Err(StatsError::InvalidArgument(
581            "At least 2 observations required for lasso regression".to_string(),
582        ));
583    }
584
585    // Initialize coefficients
586    let mut coefficients = Array1::<F>::zeros(p);
587
588    // Calculate X'X and X'y for faster computations
589    let xtx = x_processed.t().dot(&x_processed);
590    let xty = x_processed.t().dot(y);
591
592    // Coordinate descent algorithm for lasso
593    let mut converged = false;
594    let mut _iter = 0;
595
596    while !converged && _iter < max_iter {
597        converged = true;
598
599        // Save old coefficients for convergence check
600        let old_coefs = coefficients.clone();
601
602        // Update each coefficient in turn
603        for j in 0..p {
604            // Calculate partial residual
605            let r_partial = xty[j]
606                - xtx
607                    .row(j)
608                    .iter()
609                    .zip(coefficients.iter())
610                    .enumerate()
611                    .filter(|&(i_, _)| i_ != j)
612                    .map(|(_, (&xtx_ij, &coef_i))| xtx_ij * coef_i)
613                    .sum::<F>();
614
615            // Apply soft thresholding
616            let xtx_jj = xtx[[j, j]];
617            if xtx_jj < F::epsilon() {
618                coefficients[j] = F::zero();
619                continue;
620            }
621
622            if j == 0 && fit_intercept {
623                // No penalty for _intercept
624                coefficients[j] = r_partial / xtx_jj;
625            } else {
626                // Apply soft thresholding for L1 penalty
627                if crate::regression::utils::float_abs(r_partial) <= alpha {
628                    coefficients[j] = F::zero();
629                } else if r_partial > F::zero() {
630                    coefficients[j] = (r_partial - alpha) / xtx_jj;
631                } else {
632                    coefficients[j] = (r_partial + alpha) / xtx_jj;
633                }
634            }
635        }
636
637        // Check for convergence
638        let coef_diff = (&coefficients - &old_coefs)
639            .mapv(|x| scirs2_core::numeric::Float::abs(x))
640            .sum();
641        let coef_norm = old_coefs
642            .mapv(|x| scirs2_core::numeric::Float::abs(x))
643            .sum()
644            .max(F::epsilon());
645
646        if coef_diff / coef_norm < tol {
647            converged = true;
648        }
649
650        _iter += 1;
651    }
652
653    // If data was normalized/centered, transform coefficients back
654    let transformed_coefficients = if normalize || fit_intercept {
655        transform_coefficients(&coefficients, y_mean, &x_mean, &x_std, fit_intercept)
656    } else {
657        coefficients.clone()
658    };
659
660    // Calculate fitted values and residuals
661    let x_design = if fit_intercept {
662        add_intercept(x)
663    } else {
664        x.to_owned()
665    };
666
667    let fitted_values = x_design.dot(&transformed_coefficients);
668    let residuals = y.to_owned() - &fitted_values;
669
670    // Calculate degrees of freedom
671    // For lasso, df = number of non-zero coefficients
672    let nonzero_coefs = transformed_coefficients
673        .iter()
674        .filter(|&&x| crate::regression::utils::float_abs(x) > F::epsilon())
675        .count();
676    let df_model = nonzero_coefs - if fit_intercept { 1 } else { 0 };
677    let df_residuals = n - nonzero_coefs;
678
679    // Calculate sum of squares
680    let (_y_mean, ss_total, ss_residual, ss_explained) =
681        calculate_sum_of_squares(y, &residuals.view());
682
683    // Calculate R-squared and adjusted R-squared
684    let r_squared = ss_explained / ss_total;
685    let adj_r_squared = F::one()
686        - (F::one() - r_squared) * F::from(n - 1).expect("Failed to convert to float")
687            / F::from(df_residuals).expect("Failed to convert to float");
688
689    // Calculate mean squared error and residual standard error
690    let mse = ss_residual / F::from(df_residuals).expect("Failed to convert to float");
691    let residual_std_error = scirs2_core::numeric::Float::sqrt(mse);
692
693    // Calculate standard errors for coefficients (approximate)
694    let std_errors = match calculate_lasso_std_errors(
695        &x_design.view(),
696        &residuals.view(),
697        &transformed_coefficients,
698        df_residuals,
699    ) {
700        Ok(se) => se,
701        Err(_) => Array1::<F>::zeros(p),
702    };
703
704    // Calculate t-values
705    let t_values = calculate_t_values(&transformed_coefficients, &std_errors);
706
707    // Calculate p-values (simplified)
708    let p_values = t_values.mapv(|t| {
709        let t_abs = crate::regression::utils::float_abs(t);
710        let df_f = F::from(df_residuals).expect("Failed to convert to float");
711        let ratio = t_abs / crate::regression::utils::float_sqrt(df_f + t_abs * t_abs);
712        let one_minus_ratio = F::one() - ratio;
713        F::from(2.0).expect("Failed to convert constant to float") * one_minus_ratio
714    });
715
716    // Calculate confidence intervals
717    let mut conf_intervals = Array2::<F>::zeros((p, 2));
718    let z = norm_ppf(
719        F::from(0.5).expect("Failed to convert constant to float") * (F::one() + conf_level),
720    );
721
722    for i in 0..p {
723        let margin = std_errors[i] * z;
724        conf_intervals[[i, 0]] = transformed_coefficients[i] - margin;
725        conf_intervals[[i, 1]] = transformed_coefficients[i] + margin;
726    }
727
728    // Calculate F-statistic
729    let f_statistic = if df_model > 0 && df_residuals > 0 {
730        (ss_explained / F::from(df_model).expect("Failed to convert to float"))
731            / (ss_residual / F::from(df_residuals).expect("Failed to convert to float"))
732    } else {
733        F::infinity()
734    };
735
736    // Calculate p-value for F-statistic (simplified)
737    let f_p_value = F::zero(); // In a real implementation, use F-distribution
738
739    // Create and return the results structure
740    Ok(RegressionResults {
741        coefficients: transformed_coefficients,
742        std_errors,
743        t_values,
744        p_values,
745        conf_intervals,
746        r_squared,
747        adj_r_squared,
748        f_statistic,
749        f_p_value,
750        residual_std_error,
751        df_residuals,
752        residuals,
753        fitted_values,
754        inlier_mask: vec![true; n], // All points are inliers in lasso regression
755    })
756}
757
758/// Calculate standard errors for lasso regression
759#[allow(dead_code)]
760fn calculate_lasso_std_errors<F>(
761    x: &ArrayView2<F>,
762    residuals: &ArrayView1<F>,
763    coefficients: &Array1<F>,
764    df: usize,
765) -> StatsResult<Array1<F>>
766where
767    F: Float
768        + std::iter::Sum<F>
769        + std::ops::Div<Output = F>
770        + 'static
771        + scirs2_core::numeric::NumAssign
772        + scirs2_core::numeric::One
773        + scirs2_core::ndarray::ScalarOperand
774        + std::fmt::Display
775        + Send
776        + Sync,
777{
778    // Calculate the mean squared error of the residuals
779    let mse = residuals
780        .iter()
781        .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
782        .sum::<F>()
783        / F::from(df).expect("Failed to convert to float");
784
785    // Find non-zero coefficients
786    let p = coefficients.len();
787    let mut active_set = Vec::new();
788
789    for j in 0..p {
790        if crate::regression::utils::float_abs(coefficients[j]) > F::epsilon() {
791            active_set.push(j);
792        }
793    }
794
795    // If no active features, return zeros
796    if active_set.is_empty() {
797        return Ok(Array1::<F>::zeros(p));
798    }
799
800    // Calculate X_active'X_active for active features
801    let n_active = active_set.len();
802    let mut xtx_active = Array2::<F>::zeros((n_active, n_active));
803
804    for (i, &idx_i) in active_set.iter().enumerate() {
805        for (j, &idx_j) in active_set.iter().enumerate() {
806            let x_i = x.column(idx_i);
807            let x_j = x.column(idx_j);
808
809            xtx_active[[i, j]] = x_i.iter().zip(x_j.iter()).map(|(&xi, &xj)| xi * xj).sum();
810        }
811    }
812
813    // Invert X_active'X_active
814    let xtx_active_inv = match inv(&xtx_active.view(), None) {
815        Ok(inv_result) => inv_result,
816        Err(_) => {
817            // If inversion fails, return zeros for standard errors
818            return Ok(Array1::<F>::zeros(p));
819        }
820    };
821
822    // Create full standard error vector
823    let mut std_errors = Array1::<F>::zeros(p);
824
825    for (i, &idx) in active_set.iter().enumerate() {
826        std_errors[idx] = scirs2_core::numeric::Float::sqrt(xtx_active_inv[[i, i]] * mse);
827    }
828
829    Ok(std_errors)
830}
831
832/// Perform elastic net regression (L1 + L2 regularization).
833///
834/// Elastic net combines L1 and L2 penalties, offering a compromise between
835/// lasso and ridge regression.
836///
837/// # Arguments
838///
839/// * `x` - Independent variables (design matrix)
840/// * `y` - Dependent variable
841/// * `alpha` - Total regularization strength (default: 1.0)
842/// * `l1_ratio` - Ratio of L1 penalty (default: 0.5, 0 = ridge, 1 = lasso)
843/// * `fit_intercept` - Whether to fit an intercept term (default: true)
844/// * `normalize` - Whether to normalize the data before fitting (default: false)
845/// * `tol` - Convergence tolerance (default: 1e-4)
846/// * `max_iter` - Maximum number of iterations (default: 1000)
847/// * `conf_level` - Confidence level for confidence intervals (default: 0.95)
848///
849/// # Returns
850///
851/// A RegressionResults struct with the regression results.
852///
853/// # Examples
854///
855/// ```ignore
856/// use scirs2_core::ndarray::{array, Array2};
857/// use scirs2_stats::elastic_net;
858///
859/// // Create a design matrix with 5 variables
860/// let x = Array2::from_shape_vec((10, 5), vec![
861///     1.0, 2.0, 0.1, 0.2, 0.3,
862///     2.0, 3.0, 0.2, 0.3, 0.4,
863///     3.0, 4.0, 0.3, 0.4, 0.5,
864///     4.0, 5.0, 0.4, 0.5, 0.6,
865///     5.0, 6.0, 0.5, 0.6, 0.7,
866///     6.0, 7.0, 0.6, 0.7, 0.8,
867///     7.0, 8.0, 0.7, 0.8, 0.9,
868///     8.0, 9.0, 0.8, 0.9, 1.0,
869///     9.0, 10.0, 0.9, 1.0, 1.1,
870///     10.0, 11.0, 1.0, 1.1, 1.2,
871/// ]).expect("Operation failed");
872///
873/// // Target values
874/// let y = array![5.0, 8.0, 11.0, 14.0, 17.0, 20.0, 23.0, 26.0, 29.0, 32.0];
875///
876/// // Perform elastic net regression with alpha=0.1 and l1_ratio=0.5
877/// let result = elastic_net(&x.view(), &y.view(), Some(0.1), Some(0.5), None, None, None, None, None).expect("Operation failed");
878///
879/// // Check that we got coefficients
880/// assert!(result.coefficients.len() > 0);
881/// ```
882#[allow(clippy::too_many_arguments)]
883#[allow(dead_code)]
884pub fn elastic_net<F>(
885    x: &ArrayView2<F>,
886    y: &ArrayView1<F>,
887    alpha: Option<F>,
888    l1_ratio: Option<F>,
889    fit_intercept: Option<bool>,
890    normalize: Option<bool>,
891    tol: Option<F>,
892    max_iter: Option<usize>,
893    conf_level: Option<F>,
894) -> StatsResult<RegressionResults<F>>
895where
896    F: Float
897        + std::iter::Sum<F>
898        + std::ops::Div<Output = F>
899        + std::fmt::Debug
900        + std::fmt::Display
901        + 'static
902        + scirs2_core::numeric::NumAssign
903        + scirs2_core::numeric::One
904        + scirs2_core::ndarray::ScalarOperand
905        + Send
906        + Sync,
907{
908    // Check input dimensions
909    if x.nrows() != y.len() {
910        return Err(StatsError::DimensionMismatch(format!(
911            "Input x has {} rows but y has length {}",
912            x.nrows(),
913            y.len()
914        )));
915    }
916
917    let n = x.nrows();
918    let p_features = x.ncols();
919
920    // Set default parameters
921    let alpha = alpha.unwrap_or_else(|| F::from(1.0).expect("Failed to convert constant to float"));
922    let l1_ratio =
923        l1_ratio.unwrap_or_else(|| F::from(0.5).expect("Failed to convert constant to float"));
924    let fit_intercept = fit_intercept.unwrap_or(true);
925    let normalize = normalize.unwrap_or(false);
926    let tol = tol.unwrap_or_else(|| F::from(1e-4).expect("Failed to convert constant to float"));
927    let max_iter = max_iter.unwrap_or(1000);
928    let conf_level =
929        conf_level.unwrap_or_else(|| F::from(0.95).expect("Failed to convert constant to float"));
930
931    if alpha < F::zero() {
932        return Err(StatsError::InvalidArgument(
933            "alpha must be non-negative".to_string(),
934        ));
935    }
936
937    if l1_ratio < F::zero() || l1_ratio > F::one() {
938        return Err(StatsError::InvalidArgument(
939            "l1_ratio must be between 0 and 1".to_string(),
940        ));
941    }
942
943    // If l1_ratio is 0, it's ridge regression
944    if l1_ratio < F::epsilon() {
945        return ridge_regression(
946            x,
947            y,
948            Some(alpha),
949            Some(fit_intercept),
950            Some(normalize),
951            Some(tol),
952            Some(max_iter),
953            Some(conf_level),
954        );
955    }
956
957    // If l1_ratio is 1, it's lasso regression
958    if crate::regression::utils::float_abs(l1_ratio - F::one()) < F::epsilon() {
959        return lasso_regression(
960            x,
961            y,
962            Some(alpha),
963            Some(fit_intercept),
964            Some(normalize),
965            Some(tol),
966            Some(max_iter),
967            Some(conf_level),
968        );
969    }
970
971    // Preprocess x and y
972    let (x_processed, y_mean, x_mean, x_std) = preprocessdata(x, y, fit_intercept, normalize)?;
973
974    // Total number of coefficients (including _intercept if fitted)
975    let p = if fit_intercept {
976        p_features + 1
977    } else {
978        p_features
979    };
980
981    // We need at least 2 observations for meaningful regression
982    if n < 2 {
983        return Err(StatsError::InvalidArgument(
984            "At least 2 observations required for elastic net regression".to_string(),
985        ));
986    }
987
988    // Initialize coefficients
989    let mut coefficients = Array1::<F>::zeros(p);
990
991    // Calculate X'X and X'y for faster computations
992    let xtx = x_processed.t().dot(&x_processed);
993    let xty = x_processed.t().dot(y);
994
995    // Elastic net parameters
996    let alpha_l1 = alpha * l1_ratio;
997    let one_minus_l1_ratio = F::one() - l1_ratio;
998    let alpha_l2 = alpha * one_minus_l1_ratio;
999
1000    // Coordinate descent algorithm for elastic net
1001    let mut converged = false;
1002    let mut _iter = 0;
1003
1004    while !converged && _iter < max_iter {
1005        converged = true;
1006
1007        // Save old coefficients for convergence check
1008        let old_coefs = coefficients.clone();
1009
1010        // Update each coefficient in turn
1011        for j in 0..p {
1012            // Calculate partial residual
1013            let r_partial = xty[j]
1014                - xtx
1015                    .row(j)
1016                    .iter()
1017                    .zip(coefficients.iter())
1018                    .enumerate()
1019                    .filter(|&(i_, _)| i_ != j)
1020                    .map(|(_, (&xtx_ij, &coef_i))| xtx_ij * coef_i)
1021                    .sum::<F>();
1022
1023            // Apply soft thresholding with L2 adjustment
1024            let xtx_jj = xtx[[j, j]] + alpha_l2;
1025            if xtx_jj < F::epsilon() {
1026                coefficients[j] = F::zero();
1027                continue;
1028            }
1029
1030            if j == 0 && fit_intercept {
1031                // No L1 penalty for _intercept
1032                coefficients[j] = r_partial / xtx_jj;
1033            } else {
1034                // Apply soft thresholding for L1 penalty
1035                if crate::regression::utils::float_abs(r_partial) <= alpha_l1 {
1036                    coefficients[j] = F::zero();
1037                } else if r_partial > F::zero() {
1038                    coefficients[j] = (r_partial - alpha_l1) / xtx_jj;
1039                } else {
1040                    coefficients[j] = (r_partial + alpha_l1) / xtx_jj;
1041                }
1042            }
1043        }
1044
1045        // Check for convergence
1046        let coef_diff = (&coefficients - &old_coefs)
1047            .mapv(|x| scirs2_core::numeric::Float::abs(x))
1048            .sum();
1049        let coef_norm = old_coefs
1050            .mapv(|x| scirs2_core::numeric::Float::abs(x))
1051            .sum()
1052            .max(F::epsilon());
1053
1054        if coef_diff / coef_norm < tol {
1055            converged = true;
1056        }
1057
1058        _iter += 1;
1059    }
1060
1061    // If data was normalized/centered, transform coefficients back
1062    let transformed_coefficients = if normalize || fit_intercept {
1063        transform_coefficients(&coefficients, y_mean, &x_mean, &x_std, fit_intercept)
1064    } else {
1065        coefficients.clone()
1066    };
1067
1068    // Calculate fitted values and residuals
1069    let x_design = if fit_intercept {
1070        add_intercept(x)
1071    } else {
1072        x.to_owned()
1073    };
1074
1075    let fitted_values = x_design.dot(&transformed_coefficients);
1076    let residuals = y.to_owned() - &fitted_values;
1077
1078    // Calculate degrees of freedom
1079    // For elastic net, df = number of non-zero coefficients, adjusted for L2 penalty
1080    let nonzero_coefs = transformed_coefficients
1081        .iter()
1082        .filter(|&&x| crate::regression::utils::float_abs(x) > F::epsilon())
1083        .count();
1084    let df_model = nonzero_coefs - if fit_intercept { 1 } else { 0 };
1085    let df_residuals = n - nonzero_coefs;
1086
1087    // Calculate sum of squares
1088    let (_y_mean, ss_total, ss_residual, ss_explained) =
1089        calculate_sum_of_squares(y, &residuals.view());
1090
1091    // Calculate R-squared and adjusted R-squared
1092    let r_squared = ss_explained / ss_total;
1093    let adj_r_squared = F::one()
1094        - (F::one() - r_squared) * F::from(n - 1).expect("Failed to convert to float")
1095            / F::from(df_residuals).expect("Failed to convert to float");
1096
1097    // Calculate mean squared error and residual standard error
1098    let mse = ss_residual / F::from(df_residuals).expect("Failed to convert to float");
1099    let residual_std_error = scirs2_core::numeric::Float::sqrt(mse);
1100
1101    // Calculate standard errors for coefficients (approximate)
1102    let std_errors = match calculate_elastic_net_std_errors(
1103        &x_design.view(),
1104        &residuals.view(),
1105        &transformed_coefficients,
1106        alpha_l2,
1107        df_residuals,
1108    ) {
1109        Ok(se) => se,
1110        Err(_) => Array1::<F>::zeros(p),
1111    };
1112
1113    // Calculate t-values
1114    let t_values = calculate_t_values(&transformed_coefficients, &std_errors);
1115
1116    // Calculate p-values (simplified)
1117    let p_values = t_values.mapv(|t| {
1118        let t_abs = crate::regression::utils::float_abs(t);
1119        let df_f = F::from(df_residuals).expect("Failed to convert to float");
1120        let _ratio = t_abs / crate::regression::utils::float_sqrt(df_f + t_abs * t_abs);
1121        let one_minus_ratio = F::one() - _ratio;
1122        F::from(2.0).expect("Failed to convert constant to float") * one_minus_ratio
1123    });
1124
1125    // Calculate confidence intervals
1126    let mut conf_intervals = Array2::<F>::zeros((p, 2));
1127    let z = norm_ppf(
1128        F::from(0.5).expect("Failed to convert constant to float") * (F::one() + conf_level),
1129    );
1130
1131    for i in 0..p {
1132        let margin = std_errors[i] * z;
1133        conf_intervals[[i, 0]] = transformed_coefficients[i] - margin;
1134        conf_intervals[[i, 1]] = transformed_coefficients[i] + margin;
1135    }
1136
1137    // Calculate F-statistic
1138    let f_statistic = if df_model > 0 && df_residuals > 0 {
1139        (ss_explained / F::from(df_model).expect("Failed to convert to float"))
1140            / (ss_residual / F::from(df_residuals).expect("Failed to convert to float"))
1141    } else {
1142        F::infinity()
1143    };
1144
1145    // Calculate p-value for F-statistic (simplified)
1146    let f_p_value = F::zero(); // In a real implementation, use F-distribution
1147
1148    // Create and return the results structure
1149    Ok(RegressionResults {
1150        coefficients: transformed_coefficients,
1151        std_errors,
1152        t_values,
1153        p_values,
1154        conf_intervals,
1155        r_squared,
1156        adj_r_squared,
1157        f_statistic,
1158        f_p_value,
1159        residual_std_error,
1160        df_residuals,
1161        residuals,
1162        fitted_values,
1163        inlier_mask: vec![true; n], // All points are inliers in elastic net regression
1164    })
1165}
1166
1167/// Calculate standard errors for elastic net regression
1168#[allow(dead_code)]
1169fn calculate_elastic_net_std_errors<F>(
1170    x: &ArrayView2<F>,
1171    residuals: &ArrayView1<F>,
1172    coefficients: &Array1<F>,
1173    alpha_l2: F,
1174    df: usize,
1175) -> StatsResult<Array1<F>>
1176where
1177    F: Float
1178        + std::iter::Sum<F>
1179        + std::ops::Div<Output = F>
1180        + 'static
1181        + scirs2_core::numeric::NumAssign
1182        + scirs2_core::numeric::One
1183        + scirs2_core::ndarray::ScalarOperand
1184        + std::fmt::Display
1185        + Send
1186        + Sync,
1187{
1188    // Calculate the mean squared error of the residuals
1189    let mse = residuals
1190        .iter()
1191        .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
1192        .sum::<F>()
1193        / F::from(df).expect("Failed to convert to float");
1194
1195    // Find non-zero coefficients
1196    let p = coefficients.len();
1197    let mut active_set = Vec::new();
1198
1199    for j in 0..p {
1200        if crate::regression::utils::float_abs(coefficients[j]) > F::epsilon() {
1201            active_set.push(j);
1202        }
1203    }
1204
1205    // If no active features, return zeros
1206    if active_set.is_empty() {
1207        return Ok(Array1::<F>::zeros(p));
1208    }
1209
1210    // Calculate X_active'X_active for active features
1211    let n_active = active_set.len();
1212    let mut xtx_active = Array2::<F>::zeros((n_active, n_active));
1213
1214    for (i, &idx_i) in active_set.iter().enumerate() {
1215        for (j, &idx_j) in active_set.iter().enumerate() {
1216            let x_i = x.column(idx_i);
1217            let x_j = x.column(idx_j);
1218
1219            xtx_active[[i, j]] = x_i.iter().zip(x_j.iter()).map(|(&xi, &xj)| xi * xj).sum();
1220
1221            // Add L2 penalty to diagonal
1222            if i == j {
1223                xtx_active[[i, j]] += alpha_l2;
1224            }
1225        }
1226    }
1227
1228    // Invert (X_active'X_active + alpha_l2*I)
1229    let xtx_active_inv = match inv(&xtx_active.view(), None) {
1230        Ok(inv_result) => inv_result,
1231        Err(_) => {
1232            // If inversion fails, return zeros for standard errors
1233            return Ok(Array1::<F>::zeros(p));
1234        }
1235    };
1236
1237    // Create full standard error vector
1238    let mut std_errors = Array1::<F>::zeros(p);
1239
1240    for (i, &idx) in active_set.iter().enumerate() {
1241        std_errors[idx] = scirs2_core::numeric::Float::sqrt(xtx_active_inv[[i, i]] * mse);
1242    }
1243
1244    Ok(std_errors)
1245}
1246
1247/// Perform group lasso regression (L1/L2 regularization with grouped variables).
1248///
1249/// Group lasso allows variables to be grouped together such that they are
1250/// either all included or all excluded from the model.
1251///
1252/// # Arguments
1253///
1254/// * `x` - Independent variables (design matrix)
1255/// * `y` - Dependent variable
1256/// * `groups` - Vector of group indices for each feature (0-based)
1257/// * `alpha` - Regularization strength (default: 1.0)
1258/// * `fit_intercept` - Whether to fit an intercept term (default: true)
1259/// * `normalize` - Whether to normalize the data before fitting (default: false)
1260/// * `tol` - Convergence tolerance (default: 1e-4)
1261/// * `max_iter` - Maximum number of iterations (default: 1000)
1262/// * `conf_level` - Confidence level for confidence intervals (default: 0.95)
1263///
1264/// # Returns
1265///
1266/// A RegressionResults struct with the regression results.
1267///
1268/// # Examples
1269///
1270/// ```ignore
1271/// use scirs2_core::ndarray::{array, Array2};
1272/// use scirs2_stats::group_lasso;
1273///
1274/// // Create a design matrix with 6 variables in 2 groups
1275/// let x = Array2::from_shape_vec((10, 6), vec![
1276///     1.0, 2.0, 3.0, 0.1, 0.2, 0.3,
1277///     2.0, 3.0, 4.0, 0.2, 0.3, 0.4,
1278///     3.0, 4.0, 5.0, 0.3, 0.4, 0.5,
1279///     4.0, 5.0, 6.0, 0.4, 0.5, 0.6,
1280///     5.0, 6.0, 7.0, 0.5, 0.6, 0.7,
1281///     6.0, 7.0, 8.0, 0.6, 0.7, 0.8,
1282///     7.0, 8.0, 9.0, 0.7, 0.8, 0.9,
1283///     8.0, 9.0, 10.0, 0.8, 0.9, 1.0,
1284///     9.0, 10.0, 11.0, 0.9, 1.0, 1.1,
1285///     10.0, 11.0, 12.0, 1.0, 1.1, 1.2,
1286/// ]).expect("Operation failed");
1287///
1288/// // Target values depend only on the first group (first 3 variables)
1289/// let y = array![10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0];
1290///
1291/// // Define groups: first 3 variables in group 0, next 3 in group 1
1292/// let groups = vec![0, 0, 0, 1, 1, 1];
1293///
1294/// // Perform group lasso regression with alpha=0.1
1295/// let result = group_lasso(&x.view(), &y.view(), &groups, Some(0.1), None, None, None, None, None).expect("Operation failed");
1296///
1297/// // Check that we got coefficients
1298/// assert!(result.coefficients.len() > 0);
1299///
1300/// // Group lasso should ideally set all coefficients in group 1 to zero or near-zero
1301/// ```
1302#[allow(clippy::too_many_arguments)]
1303#[allow(dead_code)]
1304pub fn group_lasso<F>(
1305    x: &ArrayView2<F>,
1306    y: &ArrayView1<F>,
1307    groups: &[usize],
1308    alpha: Option<F>,
1309    fit_intercept: Option<bool>,
1310    normalize: Option<bool>,
1311    tol: Option<F>,
1312    max_iter: Option<usize>,
1313    conf_level: Option<F>,
1314) -> StatsResult<RegressionResults<F>>
1315where
1316    F: Float
1317        + std::iter::Sum<F>
1318        + std::ops::Div<Output = F>
1319        + std::fmt::Debug
1320        + std::fmt::Display
1321        + 'static
1322        + scirs2_core::numeric::NumAssign
1323        + scirs2_core::numeric::One
1324        + scirs2_core::ndarray::ScalarOperand
1325        + Send
1326        + Sync,
1327{
1328    // Check input dimensions
1329    if x.nrows() != y.len() {
1330        return Err(StatsError::DimensionMismatch(format!(
1331            "Input x has {} rows but y has length {}",
1332            x.nrows(),
1333            y.len()
1334        )));
1335    }
1336
1337    if x.ncols() != groups.len() {
1338        return Err(StatsError::DimensionMismatch(format!(
1339            "Number of columns in x ({}) must match length of groups ({})",
1340            x.ncols(),
1341            groups.len()
1342        )));
1343    }
1344
1345    let n = x.nrows();
1346    let p_features = x.ncols();
1347
1348    // Set default parameters
1349    let alpha = alpha.unwrap_or_else(|| F::from(1.0).expect("Failed to convert constant to float"));
1350    let fit_intercept = fit_intercept.unwrap_or(true);
1351    let normalize = normalize.unwrap_or(false);
1352    let tol = tol.unwrap_or_else(|| F::from(1e-4).expect("Failed to convert constant to float"));
1353    let max_iter = max_iter.unwrap_or(1000);
1354    let conf_level =
1355        conf_level.unwrap_or_else(|| F::from(0.95).expect("Failed to convert constant to float"));
1356
1357    if alpha < F::zero() {
1358        return Err(StatsError::InvalidArgument(
1359            "alpha must be non-negative".to_string(),
1360        ));
1361    }
1362
1363    // Preprocess x and y
1364    let (x_processed, y_mean, x_mean, x_std) = preprocessdata(x, y, fit_intercept, normalize)?;
1365
1366    // Total number of coefficients (including _intercept if fitted)
1367    let p = if fit_intercept {
1368        p_features + 1
1369    } else {
1370        p_features
1371    };
1372
1373    // We need at least 2 observations for meaningful regression
1374    if n < 2 {
1375        return Err(StatsError::InvalidArgument(
1376            "At least 2 observations required for group lasso regression".to_string(),
1377        ));
1378    }
1379
1380    // Determine unique groups and group sizes
1381    let mut unique_groups = HashSet::new();
1382    for &g in groups {
1383        unique_groups.insert(g);
1384    }
1385
1386    let mut group_indices = Vec::new();
1387    for &g in &unique_groups {
1388        let mut indices = Vec::new();
1389        for (i, &group) in groups.iter().enumerate() {
1390            if group == g {
1391                indices.push(if fit_intercept { i + 1 } else { i });
1392            }
1393        }
1394        group_indices.push(indices);
1395    }
1396
1397    // Initialize coefficients
1398    let mut coefficients = Array1::<F>::zeros(p);
1399
1400    // Block coordinate descent algorithm for group lasso
1401    let mut converged = false;
1402    let mut _iter = 0;
1403
1404    while !converged && _iter < max_iter {
1405        converged = true;
1406
1407        // Save old coefficients for convergence check
1408        let old_coefs = coefficients.clone();
1409
1410        // Update _intercept if fitting
1411        if fit_intercept {
1412            let r = y - &x_processed
1413                .slice(s![.., 1..])
1414                .dot(&coefficients.slice(s![1..]));
1415            let r_sum: F = r.iter().cloned().sum();
1416            coefficients[0] = r_sum / F::from(r.len()).expect("Operation failed");
1417        }
1418
1419        // Update each group in turn
1420        for group in &group_indices {
1421            // Skip empty groups
1422            if group.is_empty() {
1423                continue;
1424            }
1425
1426            // Calculate partial residual for this group
1427            let mut r = y.to_owned();
1428
1429            // Subtract contribution of other variables
1430            for j in 0..p {
1431                if !group.contains(&j) {
1432                    let x_j = x_processed.column(j);
1433                    let beta_j = coefficients[j];
1434
1435                    for i in 0..n {
1436                        r[i] -= x_j[i] * beta_j;
1437                    }
1438                }
1439            }
1440
1441            // Extract group variables
1442            let mut x_group = Array2::<F>::zeros((n, group.len()));
1443            for (i, &idx) in group.iter().enumerate() {
1444                x_group.column_mut(i).assign(&x_processed.column(idx));
1445            }
1446
1447            // Calculate X_g'r
1448            let xtr = x_group.t().dot(&r);
1449
1450            // Calculate X_g'X_g
1451            let xtx = x_group.t().dot(&x_group);
1452
1453            // Calculate the group norm of X_g'r
1454            let xtr_norm = scirs2_core::numeric::Float::sqrt(
1455                xtr.iter()
1456                    .map(|&x| scirs2_core::numeric::Float::powi(x, 2))
1457                    .sum::<F>(),
1458            );
1459
1460            // Skip if the norm is too small
1461            if xtr_norm < alpha {
1462                for &idx in group {
1463                    coefficients[idx] = F::zero();
1464                }
1465                continue;
1466            }
1467
1468            // Solve for group coefficients
1469            let mut beta_group = match solve_group(xtr, xtx, alpha, tol, max_iter) {
1470                Ok(beta) => beta,
1471                Err(_) => Array1::<F>::zeros(group.len()),
1472            };
1473
1474            // Apply group shrinkage
1475            let beta_norm = scirs2_core::numeric::Float::sqrt(
1476                beta_group
1477                    .iter()
1478                    .map(|&x| scirs2_core::numeric::Float::powi(x, 2))
1479                    .sum::<F>(),
1480            );
1481            if beta_norm > F::epsilon() {
1482                let shrinkage = F::one().max((beta_norm - alpha) / beta_norm);
1483                beta_group = beta_group.mapv(|x| x * shrinkage);
1484            } else {
1485                beta_group.fill(F::zero());
1486            }
1487
1488            // Update coefficients
1489            for (i, &idx) in group.iter().enumerate() {
1490                coefficients[idx] = beta_group[i];
1491            }
1492        }
1493
1494        // Check for convergence
1495        let coef_diff = (&coefficients - &old_coefs)
1496            .mapv(|x| scirs2_core::numeric::Float::abs(x))
1497            .sum();
1498        let coef_norm = old_coefs
1499            .mapv(|x| scirs2_core::numeric::Float::abs(x))
1500            .sum()
1501            .max(F::epsilon());
1502
1503        if coef_diff / coef_norm < tol {
1504            converged = true;
1505        }
1506
1507        _iter += 1;
1508    }
1509
1510    // If data was normalized/centered, transform coefficients back
1511    let transformed_coefficients = if normalize || fit_intercept {
1512        transform_coefficients(&coefficients, y_mean, &x_mean, &x_std, fit_intercept)
1513    } else {
1514        coefficients.clone()
1515    };
1516
1517    // Calculate fitted values and residuals
1518    let x_design = if fit_intercept {
1519        add_intercept(x)
1520    } else {
1521        x.to_owned()
1522    };
1523
1524    let fitted_values = x_design.dot(&transformed_coefficients);
1525    let residuals = y.to_owned() - &fitted_values;
1526
1527    // Calculate degrees of freedom
1528    // For group lasso, df = sum of group sizes for non-zero groups
1529    let mut nonzero_coefs = 0;
1530    let mut nonzero_groups = HashSet::new();
1531
1532    for (i, &g) in groups.iter().enumerate() {
1533        let idx = if fit_intercept { i + 1 } else { i };
1534        if crate::regression::utils::float_abs(transformed_coefficients[idx]) > F::epsilon() {
1535            nonzero_groups.insert(g);
1536        }
1537    }
1538
1539    for &g in &nonzero_groups {
1540        let groupsize = groups.iter().filter(|&&group| group == g).count();
1541        nonzero_coefs += groupsize;
1542    }
1543
1544    if fit_intercept
1545        && crate::regression::utils::float_abs(transformed_coefficients[0]) > F::epsilon()
1546    {
1547        nonzero_coefs += 1;
1548    }
1549
1550    let df_model = nonzero_coefs - if fit_intercept { 1 } else { 0 };
1551    let df_residuals = n - nonzero_coefs;
1552
1553    // Calculate sum of squares
1554    let (_y_mean, ss_total, ss_residual, ss_explained) =
1555        calculate_sum_of_squares(y, &residuals.view());
1556
1557    // Calculate R-squared and adjusted R-squared
1558    let r_squared = ss_explained / ss_total;
1559    let adj_r_squared = F::one()
1560        - (F::one() - r_squared) * F::from(n - 1).expect("Failed to convert to float")
1561            / F::from(df_residuals).expect("Failed to convert to float");
1562
1563    // Calculate mean squared error and residual standard error
1564    let mse = ss_residual / F::from(df_residuals).expect("Failed to convert to float");
1565    let residual_std_error = scirs2_core::numeric::Float::sqrt(mse);
1566
1567    // Calculate standard errors for coefficients (approximate)
1568    let std_errors = match calculate_group_lasso_std_errors(
1569        &x_design.view(),
1570        &residuals.view(),
1571        &transformed_coefficients,
1572        groups,
1573        fit_intercept,
1574        df_residuals,
1575    ) {
1576        Ok(se) => se,
1577        Err(_) => Array1::<F>::zeros(p),
1578    };
1579
1580    // Calculate t-values
1581    let t_values = calculate_t_values(&transformed_coefficients, &std_errors);
1582
1583    // Calculate p-values (simplified)
1584    let p_values = t_values.mapv(|t| {
1585        let t_abs = crate::regression::utils::float_abs(t);
1586        let df_f = F::from(df_residuals).expect("Failed to convert to float");
1587        let ratio = t_abs / crate::regression::utils::float_sqrt(df_f + t_abs * t_abs);
1588        let one_minus_ratio = F::one() - ratio;
1589        F::from(2.0).expect("Failed to convert constant to float") * one_minus_ratio
1590    });
1591
1592    // Calculate confidence intervals
1593    let mut conf_intervals = Array2::<F>::zeros((p, 2));
1594    let z = norm_ppf(
1595        F::from(0.5).expect("Failed to convert constant to float") * (F::one() + conf_level),
1596    );
1597
1598    for i in 0..p {
1599        let margin = std_errors[i] * z;
1600        conf_intervals[[i, 0]] = transformed_coefficients[i] - margin;
1601        conf_intervals[[i, 1]] = transformed_coefficients[i] + margin;
1602    }
1603
1604    // Calculate F-statistic
1605    let f_statistic = if df_model > 0 && df_residuals > 0 {
1606        (ss_explained / F::from(df_model).expect("Failed to convert to float"))
1607            / (ss_residual / F::from(df_residuals).expect("Failed to convert to float"))
1608    } else {
1609        F::infinity()
1610    };
1611
1612    // Calculate p-value for F-statistic (simplified)
1613    let f_p_value = F::zero(); // In a real implementation, use F-distribution
1614
1615    // Create and return the results structure
1616    Ok(RegressionResults {
1617        coefficients: transformed_coefficients,
1618        std_errors,
1619        t_values,
1620        p_values,
1621        conf_intervals,
1622        r_squared,
1623        adj_r_squared,
1624        f_statistic,
1625        f_p_value,
1626        residual_std_error,
1627        df_residuals,
1628        residuals,
1629        fitted_values,
1630        inlier_mask: vec![true; n], // All points are inliers in group lasso regression
1631    })
1632}
1633
1634/// Solve group lasso subproblem for a single group
1635#[allow(dead_code)]
1636fn solve_group<F>(
1637    xtr: Array1<F>,
1638    xtx: Array2<F>,
1639    _alpha: F,
1640    tol: F,
1641    max_iter: usize,
1642) -> StatsResult<Array1<F>>
1643where
1644    F: Float
1645        + std::iter::Sum<F>
1646        + std::ops::Div<Output = F>
1647        + 'static
1648        + scirs2_core::numeric::NumAssign
1649        + scirs2_core::numeric::One
1650        + scirs2_core::ndarray::ScalarOperand
1651        + std::fmt::Display
1652        + Send
1653        + Sync,
1654{
1655    let p = xtr.len();
1656
1657    // Initialize beta to zero
1658    let mut beta = Array1::<F>::zeros(p);
1659
1660    // Try to solve directly if possible
1661    match inv(&xtx.view(), None) {
1662        Ok(xtx_inv) => {
1663            beta = xtx_inv.dot(&xtr);
1664            return Ok(beta);
1665        }
1666        Err(_) => {
1667            // If direct solution fails, use iterative method
1668        }
1669    }
1670
1671    // Iterative method: gradient descent
1672    let mut _iter = 0;
1673    let mut converged = false;
1674
1675    // Learning rate
1676    let lr = F::from(0.01).expect("Failed to convert constant to float");
1677
1678    while !converged && _iter < max_iter {
1679        let old_beta = beta.clone();
1680
1681        // Gradient of squared loss: -X'r + X'X * beta
1682        let xtx_beta = xtx.dot(&beta);
1683        let grad = &xtx_beta - &xtr;
1684
1685        // Update beta
1686        let lr_grad = grad.mapv(|g| g * lr);
1687        beta = &beta - &lr_grad;
1688
1689        // Check for convergence
1690        let beta_diff = (&beta - &old_beta)
1691            .mapv(|x| scirs2_core::numeric::Float::abs(x))
1692            .sum();
1693        let beta_norm = old_beta
1694            .mapv(|x| scirs2_core::numeric::Float::abs(x))
1695            .sum()
1696            .max(F::epsilon());
1697
1698        if beta_diff / beta_norm < tol {
1699            converged = true;
1700        }
1701
1702        _iter += 1;
1703    }
1704
1705    Ok(beta)
1706}
1707
1708/// Calculate standard errors for group lasso regression
1709#[allow(dead_code)]
1710fn calculate_group_lasso_std_errors<F>(
1711    x: &ArrayView2<F>,
1712    residuals: &ArrayView1<F>,
1713    coefficients: &Array1<F>,
1714    groups: &[usize],
1715    fit_intercept: bool,
1716    df: usize,
1717) -> StatsResult<Array1<F>>
1718where
1719    F: Float
1720        + std::iter::Sum<F>
1721        + std::ops::Div<Output = F>
1722        + 'static
1723        + scirs2_core::numeric::NumAssign
1724        + scirs2_core::numeric::One
1725        + scirs2_core::ndarray::ScalarOperand
1726        + std::fmt::Display
1727        + Send
1728        + Sync,
1729{
1730    // Calculate the mean squared error of the residuals
1731    let mse = residuals
1732        .iter()
1733        .map(|&r| scirs2_core::numeric::Float::powi(r, 2))
1734        .sum::<F>()
1735        / F::from(df).expect("Failed to convert to float");
1736
1737    // Find non-zero groups
1738    let p = coefficients.len();
1739    let mut active_groups = HashSet::new();
1740
1741    for (i, &g) in groups.iter().enumerate() {
1742        let idx = if fit_intercept { i + 1 } else { i };
1743        if crate::regression::utils::float_abs(coefficients[idx]) > F::epsilon() {
1744            active_groups.insert(g);
1745        }
1746    }
1747
1748    // Create active set of indices
1749    let mut active_set = Vec::new();
1750
1751    if fit_intercept && crate::regression::utils::float_abs(coefficients[0]) > F::epsilon() {
1752        active_set.push(0);
1753    }
1754
1755    for (i, &g) in groups.iter().enumerate() {
1756        if active_groups.contains(&g) {
1757            let idx = if fit_intercept { i + 1 } else { i };
1758            active_set.push(idx);
1759        }
1760    }
1761
1762    // If no active features, return zeros
1763    if active_set.is_empty() {
1764        return Ok(Array1::<F>::zeros(p));
1765    }
1766
1767    // Calculate X_active'X_active for active features
1768    let n_active = active_set.len();
1769    let mut xtx_active = Array2::<F>::zeros((n_active, n_active));
1770
1771    for (i, &idx_i) in active_set.iter().enumerate() {
1772        for (j, &idx_j) in active_set.iter().enumerate() {
1773            let x_i = x.column(idx_i);
1774            let x_j = x.column(idx_j);
1775
1776            xtx_active[[i, j]] = x_i.iter().zip(x_j.iter()).map(|(&xi, &xj)| xi * xj).sum();
1777        }
1778    }
1779
1780    // Invert X_active'X_active
1781    let xtx_active_inv = match inv(&xtx_active.view(), None) {
1782        Ok(inv_result) => inv_result,
1783        Err(_) => {
1784            // If inversion fails, return zeros for standard errors
1785            return Ok(Array1::<F>::zeros(p));
1786        }
1787    };
1788
1789    // Create full standard error vector
1790    let mut std_errors = Array1::<F>::zeros(p);
1791
1792    for (i, &idx) in active_set.iter().enumerate() {
1793        std_errors[idx] = scirs2_core::numeric::Float::sqrt(xtx_active_inv[[i, i]] * mse);
1794    }
1795
1796    Ok(std_errors)
1797}