sklears_linear/
utils.rs

//! Utility functions for linear models
//!
//! This module provides standalone utility functions that implement
//! core algorithms used by various linear models.

use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_linalg::compat::{qr, svd, ArrayLinalgExt};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

/// Type alias for rank-revealing QR decomposition result
pub type RankRevealingQrResult = (Array2<Float>, Array2<Float>, Vec<usize>, usize);

/// Orthogonal Matching Pursuit (OMP) algorithm
///
/// Solves the OMP problem: argmin ||y - X @ coef||^2 subject to ||coef||_0 <= n_nonzero_coefs
///
/// # Arguments
/// * `x` - Design matrix of shape (n_samples, n_features)
/// * `y` - Target values of shape (n_samples,)
/// * `n_nonzero_coefs` - Maximum number of non-zero coefficients
/// * `tol` - Tolerance for residual
/// * `precompute` - Whether to precompute the Gram matrix X.T @ X (currently computed but unused)
///
/// # Returns
/// * Coefficient vector of shape (n_features,)
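///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest; assumes `Float = f64`):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
/// let y = array![1.0, 1.0, 2.0];
/// // Recover at most two non-zero coefficients with the default tolerance.
/// let coef = orthogonal_mp(&x, &y, Some(2), None, false)?;
/// assert_eq!(coef.len(), 2);
/// ```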
pub fn orthogonal_mp(
    x: &Array2<Float>,
    y: &Array1<Float>,
    n_nonzero_coefs: Option<usize>,
    tol: Option<Float>,
    precompute: bool,
) -> Result<Array1<Float>> {
    let n_samples = x.nrows();
    let n_features = x.ncols();

    if n_samples != y.len() {
        return Err(SklearsError::InvalidInput(
            "X and y have inconsistent numbers of samples".to_string(),
        ));
    }

    let n_nonzero_coefs = n_nonzero_coefs.unwrap_or(n_features.min(n_samples));
    let tol = tol.unwrap_or(1e-4);

    // Initialize
    let mut coef = Array1::zeros(n_features);
    let mut residual = y.clone();
    let mut selected = Vec::new();
    let mut selected_mask = vec![false; n_features];

    // Precompute the Gram matrix if requested. Note: the Gram-based update
    // path is not implemented yet, so this value is currently unused.
    let _gram = if precompute { Some(x.t().dot(x)) } else { None };

    // Main OMP loop
    for _ in 0..n_nonzero_coefs {
        // Compute correlations with the residual
        let correlations = x.t().dot(&residual);

        // Find the most correlated feature not yet selected
        let mut best_idx = 0;
        let mut best_corr = 0.0;

        for (idx, &corr) in correlations.iter().enumerate() {
            if !selected_mask[idx] && corr.abs() > best_corr {
                best_corr = corr.abs();
                best_idx = idx;
            }
        }

        // Stop when no remaining feature is sufficiently correlated
        if best_corr < tol {
            break;
        }

        // Add to the selected set
        selected.push(best_idx);
        selected_mask[best_idx] = true;

        // Solve least squares on the selected features
        let x_selected = x.select(Axis(1), &selected);
        let coef_selected = solve_least_squares(&x_selected, y)?;

        // Update coefficients
        for (i, &idx) in selected.iter().enumerate() {
            coef[idx] = coef_selected[i];
        }

        // Update the residual
        residual = y - &x.dot(&coef);

        // Stop once the residual norm drops below tolerance
        let residual_norm = residual.dot(&residual).sqrt();
        if residual_norm < tol {
            break;
        }
    }

    Ok(coef)
}

/// Orthogonal Matching Pursuit using a precomputed Gram matrix
///
/// This is more efficient when n_features < n_samples and multiple OMP problems
/// need to be solved with the same design matrix.
///
/// # Arguments
/// * `gram` - Gram matrix X.T @ X of shape (n_features, n_features)
/// * `xy` - X.T @ y of shape (n_features,)
/// * `n_nonzero_coefs` - Maximum number of non-zero coefficients
/// * `tol` - Tolerance for residual
/// * `norms_squared` - Squared norms of each column of X (optional)
///
/// # Returns
/// * Coefficient vector of shape (n_features,)
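///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest) showing how the Gram
/// inputs are formed from a design matrix:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
/// let y = array![1.0, 1.0, 2.0];
/// let gram = x.t().dot(&x); // X.T @ X
/// let xy = x.t().dot(&y); // X.T @ y
/// let coef = orthogonal_mp_gram(&gram, &xy, Some(2), None, None)?;
/// ```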
pub fn orthogonal_mp_gram(
    gram: &Array2<Float>,
    xy: &Array1<Float>,
    n_nonzero_coefs: Option<usize>,
    tol: Option<Float>,
    norms_squared: Option<&Array1<Float>>,
) -> Result<Array1<Float>> {
    let n_features = gram.nrows();

    if gram.ncols() != n_features {
        return Err(SklearsError::InvalidInput(
            "Gram matrix must be square".to_string(),
        ));
    }

    if xy.len() != n_features {
        return Err(SklearsError::InvalidInput(
            "xy must have length n_features".to_string(),
        ));
    }

    let n_nonzero_coefs = n_nonzero_coefs.unwrap_or(n_features);
    let tol = tol.unwrap_or(1e-4);

    // Get squared norms from the diagonal of the Gram matrix if not provided.
    // Note: these norms are not used by the current selection rule.
    let _norms_sq = match norms_squared {
        Some(norms) => norms.clone(),
        None => gram.diag().to_owned(),
    };

    // Initialize
    let mut coef = Array1::zeros(n_features);
    let mut selected = Vec::new();
    let mut selected_mask = vec![false; n_features];
    let mut correlations = xy.clone();

    // Main OMP loop
    for _ in 0..n_nonzero_coefs {
        // Find the most correlated feature not yet selected
        let mut best_idx = 0;
        let mut best_corr = 0.0;

        for (idx, &corr) in correlations.iter().enumerate() {
            if !selected_mask[idx] && corr.abs() > best_corr {
                best_corr = corr.abs();
                best_idx = idx;
            }
        }

        // Stop when no remaining feature is sufficiently correlated
        if best_corr < tol {
            break;
        }

        // Add to the selected set
        selected.push(best_idx);
        selected_mask[best_idx] = true;

        // Solve least squares on the selected features using the Gram matrix
        let gram_selected = gram.select(Axis(0), &selected).select(Axis(1), &selected);
        let xy_selected = xy.select(Axis(0), &selected);
        let coef_selected = solve_gram_least_squares(&gram_selected, &xy_selected)?;

        // Update coefficients
        coef.fill(0.0);
        for (i, &idx) in selected.iter().enumerate() {
            coef[idx] = coef_selected[i];
        }

        // Update correlations
        correlations = xy - &gram.dot(&coef);
    }

    Ok(coef)
}

/// Ridge regression solver
///
/// Solves the ridge regression problem: argmin ||y - X @ coef||^2 + alpha * ||coef||^2
///
/// # Arguments
/// * `x` - Design matrix of shape (n_samples, n_features)
/// * `y` - Target values of shape (n_samples,)
/// * `alpha` - Regularization strength (must be non-negative)
/// * `fit_intercept` - Whether to fit an intercept
/// * `solver` - Solver to use ("auto", "cholesky", or "svd")
///
/// # Returns
/// * Coefficients of shape (n_features,)
/// * Intercept (scalar)
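///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let y = array![1.0, 2.0, 3.0];
/// // Fit with a small L2 penalty and an intercept.
/// let (coef, intercept) = ridge_regression(&x, &y, 0.1, true, "auto")?;
/// ```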
pub fn ridge_regression(
    x: &Array2<Float>,
    y: &Array1<Float>,
    alpha: Float,
    fit_intercept: bool,
    solver: &str,
) -> Result<(Array1<Float>, Float)> {
    let n_samples = x.nrows();
    let n_features = x.ncols();

    if n_samples != y.len() {
        return Err(SklearsError::InvalidInput(
            "X and y have inconsistent numbers of samples".to_string(),
        ));
    }

    if alpha < 0.0 {
        return Err(SklearsError::InvalidInput(
            "alpha must be non-negative".to_string(),
        ));
    }

    // Center data if fitting intercept
    let (x_centered, y_centered, x_mean, y_mean) = if fit_intercept {
        let x_mean = x.mean_axis(Axis(0)).unwrap();
        let y_mean = y.mean().unwrap();
        let x_centered = x - &x_mean;
        let y_centered = y - y_mean;
        (x_centered, y_centered, x_mean, y_mean)
    } else {
        (x.clone(), y.clone(), Array1::zeros(n_features), 0.0)
    };

    // Solve ridge regression based on solver
    let coef = match solver {
        "auto" | "cholesky" => {
            // Solve the normal equations (X.T @ X + alpha * n_samples * I) @ coef = X.T @ y.
            // Note the penalty is scaled by n_samples, which matches an objective
            // defined on the mean squared error rather than the sum of squared errors.
            let mut gram = x_centered.t().dot(&x_centered);

            // Add regularization to the diagonal
            for i in 0..n_features {
                gram[[i, i]] += alpha * n_samples as Float;
            }

            let xy = x_centered.t().dot(&y_centered);
            solve_cholesky(&gram, &xy)?
        }
        "svd" => {
            // Use the SVD-based solver for maximum numerical stability
            solve_svd_ridge(&x_centered, &y_centered, alpha)?
        }
        _ => {
            return Err(SklearsError::InvalidInput(format!(
                "Unknown solver: {}",
                solver
            )));
        }
    };

    // Compute intercept
    let intercept = if fit_intercept {
        y_mean - x_mean.dot(&coef)
    } else {
        0.0
    };

    Ok((coef, intercept))
}

/// Solve least squares using the normal equations
fn solve_least_squares(x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
    let gram = x.t().dot(x);
    let xy = x.t().dot(y);
    solve_cholesky(&gram, &xy)
}

/// Solve least squares given a precomputed Gram matrix
fn solve_gram_least_squares(gram: &Array2<Float>, xy: &Array1<Float>) -> Result<Array1<Float>> {
    solve_cholesky(gram, xy)
}

/// Solve a linear system using Cholesky decomposition
fn solve_cholesky(a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
    let n = a.nrows();
    if n != a.ncols() || n != b.len() {
        return Err(SklearsError::InvalidInput(
            "Invalid dimensions for linear solve".to_string(),
        ));
    }

    // Use scirs2's linear solver, which handles Cholesky decomposition
    a.solve(b)
        .map_err(|e| SklearsError::NumericalError(format!("Cholesky decomposition failed: {}", e)))
}

/// Solve ridge regression using SVD
fn solve_svd_ridge(x: &Array2<Float>, y: &Array1<Float>, alpha: Float) -> Result<Array1<Float>> {
    svd_ridge_regression(x, y, alpha)
}

/// Numerically stable solution to normal equations using QR decomposition
///
/// Solves the least squares problem min ||Ax - b||^2 using QR decomposition
/// instead of forming A^T A explicitly, which improves numerical stability.
///
/// # Arguments
/// * `a` - Design matrix of shape (n_samples, n_features)
/// * `b` - Target values of shape (n_samples,)
/// * `rcond` - Cutoff for small singular values. If None, use machine precision.
///
/// # Returns
/// * Solution vector x of shape (n_features,)
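///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest); the data satisfy
/// y = 1 + x exactly, so the solution is approximately [1.0, 1.0]:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let a = array![[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
/// let b = array![2.0, 3.0, 4.0];
/// let x = stable_normal_equations(&a, &b, None)?;
/// ```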
pub fn stable_normal_equations(
    a: &Array2<Float>,
    b: &Array1<Float>,
    rcond: Option<Float>,
) -> Result<Array1<Float>> {
    let n_samples = a.nrows();
    let n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    if n_samples < n_features {
        return Err(SklearsError::InvalidInput(
            "Underdetermined system: more features than samples".to_string(),
        ));
    }

    // Use QR decomposition via scirs2
    let (q, r) = qr(&a.view())
        .map_err(|e| SklearsError::NumericalError(format!("QR decomposition failed: {}", e)))?;

    // Check for rank deficiency
    let rcond = rcond.unwrap_or(Float::EPSILON * n_features.max(n_samples) as Float);
    let r_diag_abs: Vec<Float> = (0..n_features.min(n_samples))
        .map(|i| r[[i, i]].abs())
        .collect();

    let max_diag = r_diag_abs.iter().fold(0.0 as Float, |a, &b| a.max(b));
    let rank = r_diag_abs.iter().filter(|&&x| x > rcond * max_diag).count();

    if rank < n_features {
        return Err(SklearsError::NumericalError(format!(
            "Matrix is rank deficient: rank {} < {} features",
            rank, n_features
        )));
    }

    // Solve R x = Q^T b
    let qtb = q.t().dot(b);

    // Back substitution to solve R x = qtb
    let mut x = Array1::zeros(n_features);
    for i in (0..n_features).rev() {
        let mut sum = qtb[i];
        for j in (i + 1)..n_features {
            sum -= r[[i, j]] * x[j];
        }

        if r[[i, i]].abs() < rcond * max_diag {
            return Err(SklearsError::NumericalError(
                "Matrix is singular within working precision".to_string(),
            ));
        }

        x[i] = sum / r[[i, i]];
    }

    Ok(x)
}

/// Numerically stable solution to regularized normal equations
///
/// Solves the ridge regression problem min ||Ax - b||^2 + alpha * ||x||^2
/// by forming the regularized normal equations and using a direct solve.
///
/// # Arguments
/// * `a` - Design matrix of shape (n_samples, n_features)
/// * `b` - Target values of shape (n_samples,)
/// * `alpha` - Regularization parameter
/// * `_fit_intercept` - Whether the first column is the intercept; currently
///   ignored, so all coefficients are regularized
///
/// # Returns
/// * Solution vector x of shape (n_features,)
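///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let a = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
/// let b = array![1.0, 1.0, 2.0];
/// let x = stable_ridge_regression(&a, &b, 0.1, false)?;
/// ```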
pub fn stable_ridge_regression(
    a: &Array2<Float>,
    b: &Array1<Float>,
    alpha: Float,
    _fit_intercept: bool,
) -> Result<Array1<Float>> {
    let n_samples = a.nrows();
    let n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    if alpha < 0.0 {
        return Err(SklearsError::InvalidInput(
            "Regularization parameter must be non-negative".to_string(),
        ));
    }

    // Form the regularized normal equations: (A^T A + alpha * I) * x = A^T * b.
    // A QR- or SVD-based path would avoid forming A^T A; this direct solve is a
    // temporary workaround.
    let ata = a.t().dot(a);
    let atb = a.t().dot(b);

    // Add regularization to the diagonal
    let mut regularized_ata = ata;
    for i in 0..n_features {
        regularized_ata[[i, i]] += alpha;
    }

    // Solve using scirs2's linear solver
    let x = regularized_ata
        .solve(&atb)
        .map_err(|e| SklearsError::NumericalError(format!("Linear solve failed: {}", e)))?;

    Ok(x)
}

/// Estimate the condition number of a matrix using fast heuristics
///
/// Returns an estimate of the condition number (ratio of largest to smallest
/// singular value), which indicates numerical stability. Large condition numbers
/// (>1e12) indicate ill-conditioned matrices that may lead to numerical issues.
/// Unlike [`accurate_condition_number`], this function does not compute an SVD;
/// it uses determinant- and diagonal-based heuristics.
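///
/// # Examples
///
/// A minimal sketch (not compiled as a doctest); the identity matrix has
/// condition number 1:
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let a = array![[1.0, 0.0], [0.0, 1.0]];
/// let cond = condition_number(&a)?;
/// assert!((cond - 1.0).abs() < 1e-10);
/// ```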
pub fn condition_number(a: &Array2<Float>) -> Result<Float> {
    let n = a.nrows().min(a.ncols());
    if n == 0 {
        return Ok(1.0);
    }

    // For nearly singular matrices, compute the determinant and use it as a heuristic:
    // a matrix with a very small determinant is likely ill-conditioned.
    if n == a.nrows() && n == a.ncols() {
        // Square matrix - compute determinant heuristic
        if n == 2 {
            let det = a[[0, 0]] * a[[1, 1]] - a[[0, 1]] * a[[1, 0]];
            let frobenius_norm = (a.mapv(|x| x * x).sum()).sqrt();
            if det.abs() < 1e-10 * frobenius_norm * frobenius_norm {
                return Ok(1e15); // Very ill-conditioned
            }
            // Estimate the condition number from the determinant and matrix norm
            let scale = frobenius_norm / (n as Float).sqrt();
            return Ok(scale * scale / det.abs());
        }
    }

    // Fall back to a diagonal-based heuristic for non-square or larger matrices
    let mut diag_max = Float::NEG_INFINITY;
    let mut diag_min = Float::INFINITY;

    for i in 0..n {
        let val = a[[i, i]].abs();
        if val > Float::EPSILON {
            diag_max = diag_max.max(val);
            diag_min = diag_min.min(val);
        }
    }

    if diag_min <= Float::EPSILON || diag_min == Float::INFINITY {
        Ok(Float::INFINITY)
    } else {
        Ok(diag_max / diag_min)
    }
}

/// Solve a linear system with iterative refinement for improved accuracy
///
/// This function solves Ax = b with iterative refinement to improve the accuracy
/// of the solution when dealing with ill-conditioned matrices.
///
/// # Arguments
/// * `a` - Coefficient matrix
/// * `b` - Right-hand side vector
/// * `max_iter` - Maximum number of refinement iterations
/// * `tol` - Convergence tolerance for refinement
///
/// # Returns
/// * Refined solution vector
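///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let a = array![[4.0, 1.0], [1.0, 3.0]];
/// let b = array![1.0, 2.0];
/// // Allow up to 10 refinement iterations with relative tolerance 1e-12.
/// let x = solve_with_iterative_refinement(&a, &b, 10, 1e-12)?;
/// ```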
pub fn solve_with_iterative_refinement(
    a: &Array2<Float>,
    b: &Array1<Float>,
    max_iter: usize,
    tol: Float,
) -> Result<Array1<Float>> {
    let n = a.nrows();
    if n != a.ncols() || n != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix must be square and dimensions must match".to_string(),
        ));
    }

    // Get the initial solution using a direct method
    let mut x = a
        .solve(b)
        .map_err(|e| SklearsError::NumericalError(format!("Initial solve failed: {}", e)))?;

    // Check whether iterative refinement is needed
    let cond = condition_number(a)?;
    if cond < 1e8 {
        // Matrix is well-conditioned, no refinement needed
        return Ok(x);
    }

    // Iterative refinement loop
    for iter in 0..max_iter {
        // Compute the residual: r = b - A*x
        let ax = a.dot(&x);
        let residual = b - &ax;

        // Check convergence
        let residual_norm = residual.iter().map(|&v| v * v).sum::<Float>().sqrt();
        let b_norm = b.iter().map(|&v| v * v).sum::<Float>().sqrt();

        if residual_norm <= tol * b_norm {
            log::debug!("Iterative refinement converged after {} iterations", iter);
            break;
        }

        // Solve A*delta_x = residual
        let delta_x = a.solve(&residual).map_err(|e| {
            SklearsError::NumericalError(format!("Refinement iteration {} failed: {}", iter, e))
        })?;

        // Update the solution: x = x + delta_x
        x += &delta_x;

        log::debug!(
            "Iterative refinement iteration {}: residual norm = {:.2e}",
            iter,
            residual_norm
        );
    }

    Ok(x)
}

/// Enhanced ridge regression with iterative refinement for ill-conditioned problems
///
/// Uses iterative refinement when the condition number is high to improve numerical accuracy.
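///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest); `None` selects the
/// refinement defaults (10 iterations, tolerance 1e-12):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let y = array![1.0, 2.0, 3.0];
/// let (coef, intercept) = enhanced_ridge_regression(&x, &y, 0.1, true, None, None)?;
/// ```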
pub fn enhanced_ridge_regression(
    x: &Array2<Float>,
    y: &Array1<Float>,
    alpha: Float,
    fit_intercept: bool,
    max_iter_refinement: Option<usize>,
    tol_refinement: Option<Float>,
) -> Result<(Array1<Float>, Float)> {
    let n_samples = x.nrows();
    let n_features = x.ncols();

    if n_samples != y.len() {
        return Err(SklearsError::InvalidInput(
            "X and y have inconsistent numbers of samples".to_string(),
        ));
    }

    if alpha < 0.0 {
        return Err(SklearsError::InvalidInput(
            "alpha must be non-negative".to_string(),
        ));
    }

    // Center data if fitting intercept
    let (x_centered, y_centered, x_mean, y_mean) = if fit_intercept {
        let x_mean = x.mean_axis(Axis(0)).unwrap();
        let y_mean = y.mean().unwrap();
        let x_centered = x - &x_mean;
        let y_centered = y - y_mean;
        (x_centered, y_centered, x_mean, y_mean)
    } else {
        (x.clone(), y.clone(), Array1::zeros(n_features), 0.0)
    };

    // Form the regularized normal equations:
    // (X.T @ X + alpha * n_samples * I) @ coef = X.T @ y
    // (the penalty is scaled by n_samples, as in `ridge_regression`)
    let mut gram = x_centered.t().dot(&x_centered);

    // Add regularization to the diagonal
    for i in 0..n_features {
        gram[[i, i]] += alpha * n_samples as Float;
    }

    let xy = x_centered.t().dot(&y_centered);

    // Check the condition number and decide whether to use iterative refinement
    let cond = condition_number(&gram)?;

    let coef = if cond > 1e10 {
        log::warn!("Ill-conditioned matrix detected (condition number: {:.2e}), using iterative refinement", cond);
        let max_iter = max_iter_refinement.unwrap_or(10);
        let tol = tol_refinement.unwrap_or(1e-12);
        solve_with_iterative_refinement(&gram, &xy, max_iter, tol)?
    } else {
        // Standard solve for well-conditioned matrices
        gram.solve(&xy)
            .map_err(|e| SklearsError::NumericalError(format!("Linear solve failed: {}", e)))?
    };

    // Compute intercept
    let intercept = if fit_intercept {
        y_mean - x_mean.dot(&coef)
    } else {
        0.0
    };

    Ok((coef, intercept))
}

/// SVD-based ridge regression solver for maximum numerical stability
///
/// Solves the ridge regression problem min ||Ax - b||^2 + alpha * ||x||^2
/// using the Singular Value Decomposition, which is the most numerically stable
/// approach for ill-conditioned problems.
///
/// # Arguments
/// * `a` - Design matrix of shape (n_samples, n_features)
/// * `b` - Target values of shape (n_samples,)
/// * `alpha` - Regularization parameter
///
/// # Returns
/// * Solution vector x of shape (n_features,)
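///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let a = array![[1.0, 1.0], [1.0, 1.000001], [2.0, 2.0]];
/// let b = array![2.0, 2.0, 4.0];
/// // The SVD-based solve stays stable despite the near-collinear columns.
/// let x = svd_ridge_regression(&a, &b, 0.1)?;
/// ```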
pub fn svd_ridge_regression(
    a: &Array2<Float>,
    b: &Array1<Float>,
    alpha: Float,
) -> Result<Array1<Float>> {
    let n_samples = a.nrows();
    let _n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    if alpha < 0.0 {
        return Err(SklearsError::InvalidInput(
            "Regularization parameter must be non-negative".to_string(),
        ));
    }

    // Use SVD via scirs2-linalg: A = U S V^T
    let (u, s, vt) = svd(&a.view(), true)
        .map_err(|e| SklearsError::NumericalError(format!("SVD failed: {}", e)))?;

    // Compute the regularized solution: x = V * (S^2 + alpha*I)^(-1) * S * U^T * b
    let ut_b = u.t().dot(b);

    // Apply the regularized inverse of the singular values
    let mut regularized_s_inv = Array1::zeros(s.len());
    for (i, &si) in s.iter().enumerate() {
        if i < ut_b.len() {
            regularized_s_inv[i] = si / (si * si + alpha);
        }
    }

    // Compute V * (regularized S inverse) * U^T * b
    let mut temp = Array1::zeros(vt.nrows());
    for i in 0..temp.len().min(regularized_s_inv.len()).min(ut_b.len()) {
        temp[i] = regularized_s_inv[i] * ut_b[i];
    }

    let x = vt.t().dot(&temp);

    Ok(x)
}

/// Numerically stable solution using regularized QR decomposition
///
/// Solves the regularized least squares problem using QR decomposition with
/// regularization, avoiding the formation of normal equations.
///
/// # Arguments
/// * `a` - Design matrix of shape (n_samples, n_features)
/// * `b` - Target values of shape (n_samples,)
/// * `alpha` - Regularization parameter
///
/// # Returns
/// * Solution vector x of shape (n_features,)
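///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let a = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
/// let b = array![1.0, 1.0, 2.0];
/// let x = qr_ridge_regression(&a, &b, 0.1)?;
/// ```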
pub fn qr_ridge_regression(
    a: &Array2<Float>,
    b: &Array1<Float>,
    alpha: Float,
) -> Result<Array1<Float>> {
    let n_samples = a.nrows();
    let n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    if alpha < 0.0 {
        return Err(SklearsError::InvalidInput(
            "Regularization parameter must be non-negative".to_string(),
        ));
    }

    // For ridge regression, we solve the augmented system:
    // [A         ] [x] = [b]
    // [sqrt(α)*I ]       [0]
    //
    // This avoids forming A^T A and is more numerically stable

    let sqrt_alpha = alpha.sqrt();
    let augmented_rows = n_samples + n_features;

    // Create the augmented matrix
    let mut augmented_a = Array2::zeros((augmented_rows, n_features));
    let mut augmented_b = Array1::zeros(augmented_rows);

    // Copy the original A and b
    augmented_a
        .slice_mut(scirs2_core::ndarray::s![0..n_samples, ..])
        .assign(a);
    augmented_b
        .slice_mut(scirs2_core::ndarray::s![0..n_samples])
        .assign(b);

    // Add the regularization block: sqrt(alpha) * I
    for i in 0..n_features {
        augmented_a[[n_samples + i, i]] = sqrt_alpha;
    }
    // augmented_b for the regularization block is already zero

    // Solve using QR decomposition
    stable_normal_equations(&augmented_a, &augmented_b, None)
}

/// Improved condition number calculation using SVD
///
/// Computes the condition number as the ratio of largest to smallest singular value.
/// This is more accurate than diagonal-based heuristics.
pub fn accurate_condition_number(a: &Array2<Float>) -> Result<Float> {
    let min_dim = a.nrows().min(a.ncols());
    if min_dim == 0 {
        return Ok(1.0);
    }

    // Compute the SVD to get singular values using scirs2-linalg
    let (_, s, _) = svd(&a.view(), false)
        .map_err(|e| SklearsError::NumericalError(format!("SVD failed: {}", e)))?;

    if s.is_empty() {
        return Ok(Float::INFINITY);
    }

    let s_max = s[0]; // Singular values are sorted in descending order
    let s_min = s[s.len() - 1];

    if s_min <= Float::EPSILON {
        Ok(Float::INFINITY)
    } else {
        Ok(s_max / s_min)
    }
}

/// Rank-revealing QR decomposition
///
/// Estimates the rank of a matrix from a QR decomposition. Note: column pivoting
/// is not implemented yet; the decomposition is unpivoted, the rank is estimated
/// from the diagonal of R, and the returned permutation is the identity.
///
/// # Arguments
/// * `a` - Input matrix
/// * `rcond` - Relative condition number threshold for rank determination
///
/// # Returns
/// * (Q, R, permutation vector, rank)
pub fn rank_revealing_qr(a: &Array2<Float>, rcond: Option<Float>) -> Result<RankRevealingQrResult> {
    let n_samples = a.nrows();
    let n_features = a.ncols();
    let rcond = rcond.unwrap_or(Float::EPSILON * n_samples.max(n_features) as Float);

    // For now, use an unpivoted QR and estimate the rank from the R diagonal
    let (q, r) = qr(&a.view())
        .map_err(|e| SklearsError::NumericalError(format!("QR decomposition failed: {}", e)))?;

    // Estimate the rank from the diagonal elements of R
    let min_dim = n_samples.min(n_features);
    let mut rank = 0;
    let max_diag = (0..min_dim)
        .map(|i| r[[i, i]].abs())
        .fold(0.0 as Float, |a, b| a.max(b));

    for i in 0..min_dim {
        if r[[i, i]].abs() > rcond * max_diag {
            rank += 1;
        } else {
            break;
        }
    }

    // Return the identity permutation (true pivoting would require a more complex implementation)
    let permutation: Vec<usize> = (0..n_features).collect();

    Ok((q, r, permutation, rank))
}

/// Numerically stable least squares solver with automatic method selection
///
/// Automatically selects the most appropriate numerical method based on
/// matrix properties (condition number, rank, regularization).
///
/// # Arguments
/// * `a` - Design matrix
/// * `b` - Target vector
/// * `alpha` - Regularization parameter (0 for ordinary least squares)
/// * `rcond` - Relative condition number threshold
///
/// # Returns
/// * Solution vector and solver information
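///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let a = array![[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
/// let b = array![2.0, 3.0, 4.0];
/// // alpha = 0.0 requests ordinary least squares.
/// let (x, info) = adaptive_least_squares(&a, &b, 0.0, None)?;
/// println!("solved with {}", info.method_used);
/// ```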
pub fn adaptive_least_squares(
    a: &Array2<Float>,
    b: &Array1<Float>,
    alpha: Float,
    rcond: Option<Float>,
) -> Result<(Array1<Float>, SolverInfo)> {
    let n_samples = a.nrows();
    let n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    let rcond = rcond.unwrap_or(Float::EPSILON * n_samples.max(n_features) as Float);

    // Estimate the condition number (use the fast diagonal-based method first)
    let cond_estimate = condition_number(a)?;

    let (solution, method_used) = if alpha > 0.0 {
        // Regularized problem
        if cond_estimate > 1e12 || n_samples < n_features {
            // Use SVD for extreme ill-conditioning or underdetermined systems
            let solution = svd_ridge_regression(a, b, alpha)?;
            (solution, "SVD-Ridge".to_string())
        } else if cond_estimate > 1e8 {
            // Use QR for moderate ill-conditioning
            let solution = qr_ridge_regression(a, b, alpha)?;
            (solution, "QR-Ridge".to_string())
        } else {
            // Use Cholesky for well-conditioned problems
            let solution = stable_ridge_regression(a, b, alpha, false)?;
            (solution, "Cholesky-Ridge".to_string())
        }
    } else {
        // Ordinary least squares
        if n_samples < n_features {
            return Err(SklearsError::InvalidInput(
                "Underdetermined system requires regularization (alpha > 0)".to_string(),
            ));
        }

        if cond_estimate > 1e12 {
            // Use rank-revealing QR for potential rank deficiency
            let (_q, _r, _perm, rank) = rank_revealing_qr(a, Some(rcond))?;
            if rank < n_features {
                return Err(SklearsError::NumericalError(format!(
                    "Matrix is rank deficient: rank {} < {} features. Consider regularization.",
                    rank, n_features
                )));
            }
            let solution = stable_normal_equations(a, b, Some(rcond))?;
            (solution, "QR-Rank-Revealing".to_string())
        } else if cond_estimate > 1e8 {
            // Use standard QR for moderate ill-conditioning
            let solution = stable_normal_equations(a, b, Some(rcond))?;
            (solution, "QR-Standard".to_string())
        } else {
            // Use Cholesky for well-conditioned problems
            let solution = solve_least_squares(a, b)?;
            (solution, "Cholesky-OLS".to_string())
        }
    };

    let info = SolverInfo {
        method_used,
        condition_number: cond_estimate,
        n_iterations: 1,
        converged: true,
        residual_norm: compute_residual_norm(a, b, &solution),
    };

    Ok((solution, info))
}

/// Information about the numerical solver used
#[derive(Debug, Clone)]
pub struct SolverInfo {
    /// Method used for solving
    pub method_used: String,
    /// Estimated condition number
    pub condition_number: Float,
    /// Number of iterations (for iterative methods)
    pub n_iterations: usize,
    /// Whether the method converged
    pub converged: bool,
    /// Final residual norm ||Ax - b||
    pub residual_norm: Float,
}

/// Compute the residual norm ||Ax - b||
fn compute_residual_norm(a: &Array2<Float>, b: &Array1<Float>, x: &Array1<Float>) -> Float {
    let residual = b - &a.dot(x);
    residual.dot(&residual).sqrt()
}

/// Numerical stability diagnostics for linear regression problems
///
/// Analyzes the numerical properties of a linear regression problem and
/// provides recommendations for numerical stability.
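///
/// # Examples
///
/// A minimal usage sketch (not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let a = array![[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]];
/// let b = array![2.0, 3.0, 4.0];
/// let diag = diagnose_numerical_stability(&a, &b, 0.0)?;
/// for rec in &diag.recommendations {
///     println!("{}", rec);
/// }
/// ```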
pub fn diagnose_numerical_stability(
    a: &Array2<Float>,
    b: &Array1<Float>,
    alpha: Float,
) -> Result<NumericalDiagnostics> {
    let n_samples = a.nrows();
    let n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    // Compute various numerical properties
    let cond_estimate = condition_number(a)?;
    let accurate_cond = if cond_estimate > 1e6 {
        Some(accurate_condition_number(a)?)
    } else {
        None
    };

    // Check for rank deficiency
    let (_q, _r, _perm, rank) = rank_revealing_qr(a, None)?;

    // Analyze feature scaling
    let feature_scales: Vec<Float> = (0..n_features)
        .map(|j| {
            let col = a.column(j);
            col.dot(&col).sqrt() / (n_samples as Float).sqrt()
        })
        .collect();

    let scale_ratio = if !feature_scales.is_empty() {
        let max_scale = feature_scales.iter().fold(0.0 as Float, |a, &b| a.max(b));
        let min_scale = feature_scales
            .iter()
            .fold(Float::INFINITY, |a, &b| a.min(b));
        if min_scale > Float::EPSILON {
            max_scale / min_scale
        } else {
            Float::INFINITY
        }
    } else {
        1.0
    };

    // Generate recommendations
    let mut recommendations = Vec::new();

    if accurate_cond.unwrap_or(cond_estimate) > 1e12 {
        recommendations.push(
            "Matrix is severely ill-conditioned. Consider using an SVD-based solver.".to_string(),
        );
    } else if accurate_cond.unwrap_or(cond_estimate) > 1e8 {
        recommendations
            .push("Matrix is moderately ill-conditioned. Consider QR decomposition.".to_string());
    }

    if rank < n_features {
        recommendations.push(format!(
            "Matrix is rank deficient (rank {} < {} features). Use regularization or feature selection.",
            rank, n_features
        ));
    }

    if scale_ratio > 1e6 {
        recommendations.push(
            "Features have very different scales. Consider feature scaling/normalization."
                .to_string(),
        );
    }

    if n_samples < n_features && alpha == 0.0 {
        recommendations.push(
            "Underdetermined system. Use regularization (Ridge, Lasso, ElasticNet).".to_string(),
        );
    }

    if alpha > 0.0 && accurate_cond.unwrap_or(cond_estimate) > 1e10 {
        recommendations.push(
            "Even with regularization, consider increasing alpha for better numerical stability."
                .to_string(),
        );
    }

    if recommendations.is_empty() {
        recommendations
            .push("Numerical properties look good. Standard solvers should work well.".to_string());
    }

    Ok(NumericalDiagnostics {
        condition_number: cond_estimate,
        accurate_condition_number: accurate_cond,
        rank,
        n_samples,
        n_features,
        scale_ratio,
        alpha,
        recommendations,
    })
}

/// Numerical diagnostics for a linear regression problem
#[derive(Debug, Clone)]
pub struct NumericalDiagnostics {
    /// Estimated condition number (fast calculation)
    pub condition_number: Float,
    /// Accurate condition number (SVD-based, if computed)
    pub accurate_condition_number: Option<Float>,
    /// Matrix rank
    pub rank: usize,
    /// Number of samples
    pub n_samples: usize,
    /// Number of features
    pub n_features: usize,
    /// Ratio of largest to smallest feature scale
    pub scale_ratio: Float,
    /// Regularization parameter
    pub alpha: Float,
    /// Recommendations for numerical stability
    pub recommendations: Vec<String>,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_orthogonal_mp_basic() {
        let x = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]];
        let y = array![1.0, 1.0, 2.0, 3.0];

        let coef = orthogonal_mp(&x, &y, Some(2), None, false).unwrap();
        assert_eq!(coef.len(), 2);

        // The algorithm should produce some coefficients, but the exact values
        // may vary, so we just check that the result is valid.
        assert!(coef.iter().all(|&c| c.is_finite()));
    }

    #[test]
    fn test_orthogonal_mp_gram() {
        let gram = array![[2.0, 1.0], [1.0, 2.0]];
        let xy = array![3.0, 3.0];

        let coef = orthogonal_mp_gram(&gram, &xy, Some(2), None, None).unwrap();
        assert_eq!(coef.len(), 2);
    }

    #[test]
    fn test_ridge_regression_basic() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![1.0, 2.0, 3.0, 4.0];

        let (coef, intercept) = ridge_regression(&x, &y, 0.1, true, "auto").unwrap();
        assert_eq!(coef.len(), 2);

        // With regularization, coefficients should be finite
        assert!(coef.iter().all(|&c| c.is_finite()));
        assert!(intercept.is_finite());
    }

    #[test]
    fn test_ridge_regression_no_intercept() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![1.0, 2.0, 3.0];

        let (coef, intercept) = ridge_regression(&x, &y, 0.1, false, "cholesky").unwrap();
        assert_eq!(coef.len(), 2);
        assert_eq!(intercept, 0.0);
    }

    #[test]
    fn test_invalid_alpha() {
        let x = array![[1.0]];
        let y = array![1.0];

        let result = ridge_regression(&x, &y, -0.1, true, "auto");
        assert!(result.is_err());
    }

    #[test]
    fn test_stable_normal_equations() {
        // Test a simple least squares problem
        let a = array![[1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [1.0, 4.0]];
        let b = array![2.0, 3.0, 4.0, 5.0]; // Perfect linear relationship: y = 1 + x

        let x = stable_normal_equations(&a, &b, None).unwrap();

        // Should get approximately [1.0, 1.0] (intercept = 1, slope = 1)
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_stable_ridge_regression() {
        // Test ridge regression
        let a = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let b = array![1.0, 1.0, 2.0];
        let alpha = 0.1;

        let x = stable_ridge_regression(&a, &b, alpha, false).unwrap();

        // Should get a reasonable solution
        assert!(x.iter().all(|&xi| xi.is_finite()));
        assert_eq!(x.len(), 2);
    }

    #[test]
    fn test_condition_number() {
        // Test condition number calculation
        let a = array![[1.0, 0.0], [0.0, 1.0]]; // Identity matrix, condition number = 1
        let cond = condition_number(&a).unwrap();
        assert!((cond - 1.0).abs() < 1e-10);

        // Test an ill-conditioned matrix
        let a_ill = array![[1.0, 1.0], [1.0, 1.000001]]; // Nearly singular
        let cond_ill = condition_number(&a_ill).unwrap();
        assert!(cond_ill > 1e5); // Should have a large condition number
    }

    #[test]
    fn test_stable_equations_rank_deficient() {
        // Test a rank-deficient matrix
        let a = array![[1.0, 2.0], [2.0, 4.0]]; // Rank 1 matrix
        let b = array![1.0, 2.0];

        let result = stable_normal_equations(&a, &b, None);
        assert!(result.is_err()); // Should fail for a rank-deficient matrix
    }
}