// linreg_core/linalg.rs
//! Minimal Linear Algebra module to replace nalgebra dependency.
//!
//! Implements matrix operations, QR decomposition, and solvers needed for OLS.
//! Uses row-major storage for compatibility with statistical computing conventions.

#![allow(clippy::needless_range_loop)]
//!
//! # Numerical Stability Considerations
//!
//! This implementation uses Householder QR decomposition with careful attention to
//! numerical stability:
//!
//! - **Sign convention**: Uses the numerically stable Householder sign choice
//!   (v = x + sgn(x₀)||x||e₁) to avoid cancellation
//! - **Tolerance checking**: Uses predefined tolerances to detect near-singular matrices
//! - **Zero-skipping**: Skips transformations when columns are already zero-aligned
//!
//! # Scaling Recommendations
//!
//! For optimal numerical stability when predictor variables have vastly different
//! scales (e.g., one variable in millions, another in thousandths), consider
//! standardizing predictors before regression. Z-score standardization
//! (`x_scaled = (x - mean) / std`) is already done in VIF calculation.
//!
//! However, the current implementation handles typical OLS cases without explicit
//! scaling, as QR decomposition is generally stable for well-conditioned matrices.
// ============================================================================
// Numerical Constants
// ============================================================================

/// Machine epsilon threshold for detecting zero values in QR decomposition.
/// Values below this are treated as zero to avoid numerical instability.
const QR_ZERO_TOLERANCE: f64 = 1e-12;

/// Threshold for detecting singular matrices during inversion.
/// Diagonal elements below this value indicate a near-singular matrix.
/// Used as an absolute floor: inversion combines it with a scale-relative
/// tolerance via `SINGULAR_TOLERANCE.max(relative_tolerance)`.
const SINGULAR_TOLERANCE: f64 = 1e-10;
39
/// A dense matrix stored in row-major order.
///
/// # Storage
///
/// Elements are stored in a single flat vector in row-major order:
/// `data[row * cols + col]`
///
/// # Example
///
/// ```
/// # use linreg_core::linalg::Matrix;
/// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// assert_eq!(m.rows, 2);
/// assert_eq!(m.cols, 3);
/// assert_eq!(m.get(0, 0), 1.0);
/// assert_eq!(m.get(1, 2), 6.0);
/// ```
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Number of rows in the matrix
    pub rows: usize,
    /// Number of columns in the matrix
    pub cols: usize,
    /// Flat vector storing matrix elements in row-major order
    /// (invariant: `data.len() == rows * cols`, enforced by `Matrix::new`)
    pub data: Vec<f64>,
}
66
67impl Matrix {
68    /// Creates a new matrix from the given dimensions and data.
69    ///
70    /// # Panics
71    ///
72    /// Panics if `data.len() != rows * cols`.
73    ///
74    /// # Arguments
75    ///
76    /// * `rows` - Number of rows
77    /// * `cols` - Number of columns
78    /// * `data` - Flat vector of elements in row-major order
79    ///
80    /// # Example
81    ///
82    /// ```
83    /// # use linreg_core::linalg::Matrix;
84    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
85    /// assert_eq!(m.get(0, 0), 1.0);
86    /// assert_eq!(m.get(0, 1), 2.0);
87    /// assert_eq!(m.get(1, 0), 3.0);
88    /// assert_eq!(m.get(1, 1), 4.0);
89    /// ```
90    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
91        assert_eq!(data.len(), rows * cols, "Data length must match dimensions");
92        Matrix { rows, cols, data }
93    }
94
95    /// Creates a matrix filled with zeros.
96    ///
97    /// # Arguments
98    ///
99    /// * `rows` - Number of rows
100    /// * `cols` - Number of columns
101    ///
102    /// # Example
103    ///
104    /// ```
105    /// # use linreg_core::linalg::Matrix;
106    /// let m = Matrix::zeros(3, 2);
107    /// assert_eq!(m.rows, 3);
108    /// assert_eq!(m.cols, 2);
109    /// assert_eq!(m.get(1, 1), 0.0);
110    /// ```
111    pub fn zeros(rows: usize, cols: usize) -> Self {
112        Matrix {
113            rows,
114            cols,
115            data: vec![0.0; rows * cols],
116        }
117    }
118
119    // NOTE: Currently unused but kept as reference implementation.
120    // Uncomment if needed for convenience constructor.
121    /*
122    /// Creates a matrix from a row-major slice.
123    ///
124    /// # Arguments
125    ///
126    /// * `rows` - Number of rows
127    /// * `cols` - Number of columns
128    /// * `slice` - Slice containing matrix elements in row-major order
129    pub fn from_row_slice(rows: usize, cols: usize, slice: &[f64]) -> Self {
130        Matrix::new(rows, cols, slice.to_vec())
131    }
132    */
133
134    /// Gets the element at the specified row and column.
135    ///
136    /// # Arguments
137    ///
138    /// * `row` - Row index (0-based)
139    /// * `col` - Column index (0-based)
140    pub fn get(&self, row: usize, col: usize) -> f64 {
141        self.data[row * self.cols + col]
142    }
143
144    /// Sets the element at the specified row and column.
145    ///
146    /// # Arguments
147    ///
148    /// * `row` - Row index (0-based)
149    /// * `col` - Column index (0-based)
150    /// * `val` - Value to set
151    pub fn set(&mut self, row: usize, col: usize, val: f64) {
152        self.data[row * self.cols + col] = val;
153    }
154
155    /// Returns the transpose of this matrix.
156    ///
157    /// Swaps rows with columns: `result[col][row] = self[row][col]`.
158    ///
159    /// # Example
160    ///
161    /// ```
162    /// # use linreg_core::linalg::Matrix;
163    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
164    /// let t = m.transpose();
165    /// assert_eq!(t.rows, 3);
166    /// assert_eq!(t.cols, 2);
167    /// assert_eq!(t.get(0, 1), 4.0);
168    /// ```
169    pub fn transpose(&self) -> Matrix {
170        let mut t_data = vec![0.0; self.rows * self.cols];
171        for r in 0..self.rows {
172            for c in 0..self.cols {
173                t_data[c * self.rows + r] = self.get(r, c);
174            }
175        }
176        Matrix::new(self.cols, self.rows, t_data)
177    }
178
179    /// Performs matrix multiplication: `self * other`.
180    ///
181    /// # Panics
182    ///
183    /// Panics if `self.cols != other.rows`.
184    ///
185    /// # Example
186    ///
187    /// ```
188    /// # use linreg_core::linalg::Matrix;
189    /// let a = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
190    /// let b = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
191    /// let c = a.matmul(&b);
192    /// assert_eq!(c.rows, 2);
193    /// assert_eq!(c.cols, 2);
194    /// assert_eq!(c.get(0, 0), 22.0); // 1*1 + 2*3 + 3*5
195    /// ```
196    pub fn matmul(&self, other: &Matrix) -> Matrix {
197        assert_eq!(
198            self.cols, other.rows,
199            "Dimension mismatch for multiplication"
200        );
201        let mut result = Matrix::zeros(self.rows, other.cols);
202
203        for r in 0..self.rows {
204            for c in 0..other.cols {
205                let mut sum = 0.0;
206                for k in 0..self.cols {
207                    sum += self.get(r, k) * other.get(k, c);
208                }
209                result.set(r, c, sum);
210            }
211        }
212        result
213    }
214
215    /// Multiplies this matrix by a vector (treating vector as column matrix).
216    ///
217    /// Computes `self * vec` where vec is treated as an n×1 column matrix.
218    ///
219    /// # Panics
220    ///
221    /// Panics if `self.cols != vec.len()`.
222    ///
223    /// # Arguments
224    ///
225    /// * `vec` - Vector to multiply by
226    ///
227    /// # Example
228    ///
229    /// ```
230    /// # use linreg_core::linalg::Matrix;
231    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
232    /// let v = vec![1.0, 2.0, 3.0];
233    /// let result = m.mul_vec(&v);
234    /// assert_eq!(result.len(), 2);
235    /// assert_eq!(result[0], 14.0); // 1*1 + 2*2 + 3*3
236    /// ```
237    #[allow(clippy::needless_range_loop)]
238    pub fn mul_vec(&self, vec: &[f64]) -> Vec<f64> {
239        assert_eq!(
240            self.cols,
241            vec.len(),
242            "Dimension mismatch for matrix-vector multiplication"
243        );
244        let mut result = vec![0.0; self.rows];
245
246        for r in 0..self.rows {
247            let mut sum = 0.0;
248            for c in 0..self.cols {
249                sum += self.get(r, c) * vec[c];
250            }
251            result[r] = sum;
252        }
253        result
254    }
255
256    /// Computes the dot product of a column with a vector: `Σ(data[i * cols + col] * v[i])`.
257    ///
258    /// For a row-major matrix, this iterates through all rows at a fixed column.
259    ///
260    /// # Arguments
261    ///
262    /// * `col` - Column index
263    /// * `v` - Vector to dot with (must have length equal to rows)
264    ///
265    /// # Panics
266    ///
267    /// Panics if `col >= cols` or `v.len() != rows`.
268    #[allow(clippy::needless_range_loop)]
269    pub fn col_dot(&self, col: usize, v: &[f64]) -> f64 {
270        assert!(col < self.cols, "Column index out of bounds");
271        assert_eq!(
272            self.rows,
273            v.len(),
274            "Vector length must match number of rows"
275        );
276
277        let mut sum = 0.0;
278        for row in 0..self.rows {
279            sum += self.get(row, col) * v[row];
280        }
281        sum
282    }
283
284    /// Performs the column-vector operation in place: `v += alpha * column_col`.
285    ///
286    /// This is the AXPY operation where the column is treated as a vector.
287    /// For row-major storage, we iterate through rows at a fixed column.
288    ///
289    /// # Arguments
290    ///
291    /// * `col` - Column index
292    /// * `alpha` - Scaling factor for the column
293    /// * `v` - Vector to modify in place (must have length equal to rows)
294    ///
295    /// # Panics
296    ///
297    /// Panics if `col >= cols` or `v.len() != rows`.
298    #[allow(clippy::needless_range_loop)]
299    pub fn col_axpy_inplace(&self, col: usize, alpha: f64, v: &mut [f64]) {
300        assert!(col < self.cols, "Column index out of bounds");
301        assert_eq!(
302            self.rows,
303            v.len(),
304            "Vector length must match number of rows"
305        );
306
307        for row in 0..self.rows {
308            v[row] += alpha * self.get(row, col);
309        }
310    }
311
312    /// Computes the squared L2 norm of a column: `Σ(data[i * cols + col]²)`.
313    ///
314    /// # Arguments
315    ///
316    /// * `col` - Column index
317    ///
318    /// # Panics
319    ///
320    /// Panics if `col >= cols`.
321    #[allow(clippy::needless_range_loop)]
322    pub fn col_norm2(&self, col: usize) -> f64 {
323        assert!(col < self.cols, "Column index out of bounds");
324
325        let mut sum = 0.0;
326        for row in 0..self.rows {
327            let val = self.get(row, col);
328            sum += val * val;
329        }
330        sum
331    }
332
333    /// Adds a value to diagonal elements starting from a given index.
334    ///
335    /// This is useful for ridge regression where we add `lambda * I` to `X^T X`,
336    /// but the intercept column should not be penalized.
337    ///
338    /// # Arguments
339    ///
340    /// * `alpha` - Value to add to diagonal elements
341    /// * `start_index` - Starting diagonal index (0 = first diagonal element)
342    ///
343    /// # Panics
344    ///
345    /// Panics if the matrix is not square.
346    ///
347    /// # Example
348    ///
349    /// For a 3×3 identity matrix with intercept in first column (unpenalized):
350    /// ```text
351    /// add_diagonal_in_place(lambda, 1) on:
352    /// [1.0, 0.0, 0.0]       [1.0,   0.0,   0.0  ]
353    /// [0.0, 1.0, 0.0]  ->   [0.0,  1.0+λ, 0.0  ]
354    /// [0.0, 0.0, 1.0]       [0.0,   0.0,  1.0+λ]
355    /// ```
356    pub fn add_diagonal_in_place(&mut self, alpha: f64, start_index: usize) {
357        assert_eq!(self.rows, self.cols, "Matrix must be square");
358        let n = self.rows;
359        for i in start_index..n {
360            let current = self.get(i, i);
361            self.set(i, i, current + alpha);
362        }
363    }
364}
365
366// ============================================================================
367// QR Decomposition
368// ============================================================================
369
370impl Matrix {
371    /// Computes the QR decomposition using Householder reflections.
372    ///
373    /// Factorizes the matrix as `A = QR` where Q is orthogonal and R is upper triangular.
374    ///
375    /// # Requirements
376    ///
377    /// This implementation requires `rows >= cols` (tall matrix). For OLS regression,
378    /// we always have more observations than predictors, so this requirement is satisfied.
379    ///
380    /// # Returns
381    ///
382    /// A tuple `(Q, R)` where:
383    /// - `Q` is an orthogonal matrix (QᵀQ = I) of size m×m
384    /// - `R` is an upper triangular matrix of size m×n
385    #[allow(clippy::needless_range_loop)]
386    pub fn qr(&self) -> (Matrix, Matrix) {
387        let m = self.rows;
388        let n = self.cols;
389        let mut q = Matrix::identity(m);
390        let mut r = self.clone();
391
392        for k in 0..n.min(m - 1) {
393            // Create vector x = R[k:, k]
394            let mut x = vec![0.0; m - k];
395            for i in k..m {
396                x[i - k] = r.get(i, k);
397            }
398
399            // Norm of x
400            let norm_x: f64 = x.iter().map(|&v| v * v).sum::<f64>().sqrt();
401            if norm_x < QR_ZERO_TOLERANCE {
402                continue;
403            } // Already zero
404
405            // Create vector v = x + sign(x[0]) * ||x|| * e1
406            //
407            // NOTE: Numerical stability consideration (Householder sign choice)
408            // According to Overton & Yu (2023), the numerically stable choice is
409            // σ = -sgn(x₁) in the formula v = x - σ‖x‖e₁.
410            //
411            // This means: v = x - (-sgn(x₁))‖x‖e₁ = x + sgn(x₁)‖x‖e₁
412            //
413            // Equivalently: u₁ = x₁ + sgn(x₁)‖x‖
414            //
415            // Current implementation uses this formula (the "correct" choice for stability):
416            let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 }; // sgn(x₀) as defined (sgn(0) = +1)
417            let u1 = x[0] + sign * norm_x;
418
419            // Normalize v to get Householder vector
420            let mut v = x; // Re-use storage
421            v[0] = u1;
422
423            let norm_v: f64 = v.iter().map(|&val| val * val).sum::<f64>().sqrt();
424            for val in &mut v {
425                *val /= norm_v;
426            }
427
428            // Apply Householder transformation to R: R = H * R = (I - 2vv^T)R = R - 2v(v^T R)
429            // Focus on submatrix R[k:, k:]
430            for j in k..n {
431                let mut dot = 0.0;
432                for i in 0..m - k {
433                    dot += v[i] * r.get(k + i, j);
434                }
435
436                for i in 0..m - k {
437                    let val = r.get(k + i, j) - 2.0 * v[i] * dot;
438                    r.set(k + i, j, val);
439                }
440            }
441
442            // Update Q: Q = Q * H = Q(I - 2vv^T) = Q - 2(Qv)v^T
443            // Focus on Q[:, k:]
444            for i in 0..m {
445                let mut dot = 0.0;
446                for l in 0..m - k {
447                    dot += q.get(i, k + l) * v[l];
448                }
449
450                for l in 0..m - k {
451                    let val = q.get(i, k + l) - 2.0 * dot * v[l];
452                    q.set(i, k + l, val);
453                }
454            }
455        }
456
457        (q, r)
458    }
459
460    /// Creates an identity matrix of the given size.
461    ///
462    /// # Arguments
463    ///
464    /// * `size` - Number of rows and columns (square matrix)
465    ///
466    /// # Example
467    ///
468    /// ```
469    /// # use linreg_core::linalg::Matrix;
470    /// let i = Matrix::identity(3);
471    /// assert_eq!(i.get(0, 0), 1.0);
472    /// assert_eq!(i.get(1, 1), 1.0);
473    /// assert_eq!(i.get(2, 2), 1.0);
474    /// assert_eq!(i.get(0, 1), 0.0);
475    /// ```
476    pub fn identity(size: usize) -> Self {
477        let mut data = vec![0.0; size * size];
478        for i in 0..size {
479            data[i * size + i] = 1.0;
480        }
481        Matrix::new(size, size, data)
482    }
483
484    /// Inverts an upper triangular matrix (such as R from QR decomposition).
485    ///
486    /// Uses back-substitution to compute the inverse. This is efficient for
487    /// triangular matrices compared to general matrix inversion.
488    ///
489    /// # Panics
490    ///
491    /// Panics if the matrix is not square.
492    ///
493    /// # Returns
494    ///
495    /// `None` if the matrix is singular (has a zero or near-zero diagonal element).
496    /// A matrix is considered singular if any diagonal element is below the
497    /// internal tolerance (1e-10), which indicates the matrix does not have full rank.
498    ///
499    /// # Note
500    ///
501    /// For upper triangular matrices, singularity is equivalent to having a
502    /// zero (or near-zero) diagonal element. This is much simpler to check than
503    /// for general matrices, which would require computing the condition number.
504    pub fn invert_upper_triangular(&self) -> Option<Matrix> {
505        let n = self.rows;
506        assert_eq!(n, self.cols, "Matrix must be square");
507
508        // Check for singularity using relative tolerance
509        // This scales with the magnitude of diagonal elements, handling matrices
510        // of different scales better than a fixed absolute tolerance.
511        //
512        // Previous implementation used absolute tolerance:
513        //   if self.get(i, i).abs() < SINGULAR_TOLERANCE { return None; }
514        //
515        // New implementation uses relative tolerance similar to LAPACK:
516        //   tolerance = max_diag * epsilon * n
517        // where epsilon is machine epsilon (~2.2e-16 for f64)
518        let max_diag: f64 = (0..n)
519            .map(|i| self.get(i, i).abs())
520            .fold(0.0_f64, |acc, val| acc.max(val));
521
522        // Use a relative tolerance based on the maximum diagonal element
523        // This is similar to LAPACK's dlamch machine epsilon approach
524        let epsilon = 2.0_f64 * f64::EPSILON; // ~4.4e-16 for f64
525        let relative_tolerance = max_diag * epsilon * n as f64;
526        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);
527
528        for i in 0..n {
529            if self.get(i, i).abs() < tolerance {
530                return None; // Singular matrix - cannot invert
531            }
532        }
533
534        let mut inv = Matrix::zeros(n, n);
535
536        for i in 0..n {
537            inv.set(i, i, 1.0 / self.get(i, i));
538
539            for j in (0..i).rev() {
540                let mut sum = 0.0;
541                for k in j + 1..=i {
542                    sum += self.get(j, k) * inv.get(k, i);
543                }
544                inv.set(j, i, -sum / self.get(j, j));
545            }
546        }
547
548        Some(inv)
549    }
550
551    /// Inverts an upper triangular matrix with a custom tolerance multiplier.
552    ///
553    /// The tolerance is computed as `max_diag * epsilon * n * tolerance_mult`.
554    /// A higher tolerance_mult allows more tolerance for near-singular matrices.
555    ///
556    /// # Arguments
557    ///
558    /// * `tolerance_mult` - Multiplier for the tolerance (1.0 = standard, higher = more tolerant)
559    pub fn invert_upper_triangular_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
560        let n = self.rows;
561        assert_eq!(n, self.cols, "Matrix must be square");
562
563        // Check for singularity using relative tolerance
564        let max_diag: f64 = (0..n)
565            .map(|i| self.get(i, i).abs())
566            .fold(0.0_f64, |acc, val| acc.max(val));
567
568        // Use a relative tolerance based on the maximum diagonal element
569        let epsilon = 2.0_f64 * f64::EPSILON;
570        let relative_tolerance = max_diag * epsilon * n as f64 * tolerance_mult;
571        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);
572
573        for i in 0..n {
574            if self.get(i, i).abs() < tolerance {
575                return None;
576            }
577        }
578
579        let mut inv = Matrix::zeros(n, n);
580
581        for i in 0..n {
582            inv.set(i, i, 1.0 / self.get(i, i));
583
584            for j in (0..i).rev() {
585                let mut sum = 0.0;
586                for k in j + 1..=i {
587                    sum += self.get(j, k) * inv.get(k, i);
588                }
589                inv.set(j, i, -sum / self.get(j, j));
590            }
591        }
592
593        Some(inv)
594    }
595
596    /// Computes the inverse of a square matrix using QR decomposition.
597    ///
598    /// For an invertible matrix A, computes A⁻¹ such that A * A⁻¹ = I.
599    /// Uses QR decomposition for numerical stability.
600    ///
601    /// # Panics
602    ///
603    /// Panics if the matrix is not square (i.e., `self.rows != self.cols`).
604    /// Check dimensions before calling if the matrix shape is not guaranteed.
605    ///
606    /// # Returns
607    ///
608    /// Returns `Some(inverse)` if the matrix is invertible, or `None` if
609    /// the matrix is singular (non-invertible).
610    pub fn invert(&self) -> Option<Matrix> {
611        let n = self.rows;
612        if n != self.cols {
613            panic!("Matrix must be square for inversion");
614        }
615
616        // Use QR decomposition: A = Q * R
617        let (q, r) = self.qr();
618
619        // Compute R⁻¹ (upper triangular inverse)
620        let r_inv = r.invert_upper_triangular()?;
621
622        // A⁻¹ = R⁻¹ * Q^T
623        let q_transpose = q.transpose();
624        let mut result = Matrix::zeros(n, n);
625
626        for i in 0..n {
627            for j in 0..n {
628                let mut sum = 0.0;
629                for k in 0..n {
630                    sum += r_inv.get(i, k) * q_transpose.get(k, j);
631                }
632                result.set(i, j, sum);
633            }
634        }
635
636        Some(result)
637    }
638
639    /// Computes the inverse of X'X given the QR decomposition of X (R's chol2inv).
640    ///
641    /// This is equivalent to computing `(X'X)^(-1)` using the QR decomposition of X.
642    /// R's `chol2inv` function is used for numerical stability in recursive residuals.
643    ///
644    /// # Arguments
645    ///
646    /// * `x` - Input matrix (must have rows >= cols)
647    ///
648    /// # Returns
649    ///
650    /// `Some((X'X)^(-1))` if X has full rank, `None` otherwise.
651    ///
652    /// # Algorithm
653    ///
654    /// Given QR decomposition X = QR where R is upper triangular:
655    /// 1. Extract the upper p×p portion of R (denoted R₁)
656    /// 2. Invert R₁ (upper triangular inverse)
657    /// 3. Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
658    ///
659    /// This works because X'X = R'Q'QR = R'R, and R₁ contains the Cholesky factor.
660    pub fn chol2inv_from_qr(&self) -> Option<Matrix> {
661        self.chol2inv_from_qr_with_tolerance(1.0)
662    }
663
664    /// Computes the inverse of X'X given the QR decomposition with custom tolerance.
665    ///
666    /// Similar to `chol2inv_from_qr` but allows specifying a tolerance multiplier
667    /// for handling near-singular matrices.
668    ///
669    /// # Arguments
670    ///
671    /// * `tolerance_mult` - Multiplier for the tolerance (higher = more tolerant)
672    pub fn chol2inv_from_qr_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
673        let p = self.cols;
674
675        // QR decomposition: X = QR
676        // For X (m×n, m≥n), R is m×n upper triangular
677        // The upper n×n block of R contains the meaningful values
678        let (_, r_full) = self.qr();
679
680        // Extract upper p×p portion from R
681        // For tall matrices (m > p), R has zeros below diagonal in first p rows
682        // For square matrices (m = p), R is p×p upper triangular
683        let mut r1 = Matrix::zeros(p, p);
684        for i in 0..p {
685            // Row i of R1 is row i of R_full, columns 0..p
686            // But we only copy the upper triangular part (columns i..p)
687            for j in i..p {
688                r1.set(i, j, r_full.get(i, j));
689            }
690            // Also copy diagonal if not yet copied
691            if i < p {
692                r1.set(i, i, r_full.get(i, i));
693            }
694        }
695
696        // Invert R₁ (upper triangular) with custom tolerance
697        let r1_inv = r1.invert_upper_triangular_with_tolerance(tolerance_mult)?;
698
699        // Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
700        let mut result = Matrix::zeros(p, p);
701        for i in 0..p {
702            for j in 0..p {
703                let mut sum = 0.0;
704                // result[i,j] = sum(R1_inv[i,k] * R1_inv[j,k] for k=0..p)
705                // R1_inv is upper triangular, but we iterate full range
706                for k in 0..p {
707                    sum += r1_inv.get(i, k) * r1_inv.get(j, k);
708                }
709                result.set(i, j, sum);
710            }
711        }
712
713        Some(result)
714    }
715}
716
// ============================================================================
// Vector Helper Functions
// ============================================================================

/// Returns the arithmetic mean of the values in `v`.
///
/// An empty slice yields 0.0 rather than NaN.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_mean;
///
/// assert_eq!(vec_mean(&[1.0, 2.0, 3.0, 4.0, 5.0]), 3.0);
/// assert_eq!(vec_mean(&[]), 0.0);
/// ```
///
/// # Arguments
///
/// * `v` - Slice of values
pub fn vec_mean(v: &[f64]) -> f64 {
    match v.len() {
        0 => 0.0,
        n => v.iter().sum::<f64>() / n as f64,
    }
}
743
/// Returns the element-wise difference `a - b` of two equally-long slices.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_sub;
///
/// let a = vec![5.0, 4.0, 3.0];
/// let b = vec![1.0, 1.0, 1.0];
/// let result = vec_sub(&a, &b);
/// assert_eq!(result, vec![4.0, 3.0, 2.0]);
/// ```
///
/// # Arguments
///
/// * `a` - Minuend slice
/// * `b` - Subtrahend slice
///
/// # Panics
///
/// Panics if slices have different lengths.
pub fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_sub: slice lengths must match");
    let mut out = Vec::with_capacity(a.len());
    for (&x, &y) in a.iter().zip(b.iter()) {
        out.push(x - y);
    }
    out
}
769
/// Returns the dot product of two equally-long slices: `Σ(a[i] * b[i])`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_dot;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_dot(&a, &b), 32.0);  // 1*4 + 2*5 + 3*6
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if slices have different lengths.
pub fn vec_dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "vec_dot: slice lengths must match");
    // Left-to-right accumulation, identical order to a plain `.sum()`.
    a.iter()
        .zip(b.iter())
        .fold(0.0, |acc, (&x, &y)| acc + x * y)
}
794
/// Returns the element-wise sum `a + b` of two equally-long slices.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_add;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_add(&a, &b), vec![5.0, 7.0, 9.0]);
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if slices have different lengths.
pub fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_add: slice lengths must match");
    (0..a.len()).map(|i| a[i] + b[i]).collect()
}
819
/// Performs the BLAS-style AXPY update in place: `dst[i] += alpha * src[i]`.
///
/// # Arguments
///
/// * `dst` - Destination slice (modified in place)
/// * `alpha` - Scaling factor for src
/// * `src` - Source slice
///
/// # Panics
///
/// Panics if slices have different lengths.
///
/// # Example
///
/// ```
/// use linreg_core::linalg::vec_axpy_inplace;
///
/// let mut dst = vec![1.0, 2.0, 3.0];
/// let src = vec![0.5, 0.5, 0.5];
/// vec_axpy_inplace(&mut dst, 2.0, &src);
/// assert_eq!(dst, vec![2.0, 3.0, 4.0]);  // [1+2*0.5, 2+2*0.5, 3+2*0.5]
/// ```
pub fn vec_axpy_inplace(dst: &mut [f64], alpha: f64, src: &[f64]) {
    assert_eq!(
        dst.len(),
        src.len(),
        "vec_axpy_inplace: slice lengths must match"
    );
    for i in 0..dst.len() {
        dst[i] += alpha * src[i];
    }
}
854
/// Multiplies every element of `v` by `alpha`, modifying `v` in place.
///
/// # Arguments
///
/// * `v` - Vector to scale (modified in place)
/// * `alpha` - Scaling factor
///
/// # Example
///
/// ```
/// use linreg_core::linalg::vec_scale_inplace;
///
/// let mut v = vec![1.0, 2.0, 3.0];
/// vec_scale_inplace(&mut v, 2.5);
/// assert_eq!(v, vec![2.5, 5.0, 7.5]);
/// ```
pub fn vec_scale_inplace(v: &mut [f64], alpha: f64) {
    v.iter_mut().for_each(|val| *val *= alpha);
}
876
/// Returns a new vector equal to `v` with every element multiplied by `alpha`.
///
/// The input slice is left untouched.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_scale;
///
/// let v = vec![1.0, 2.0, 3.0];
/// let scaled = vec_scale(&v, 2.5);
/// assert_eq!(scaled, vec![2.5, 5.0, 7.5]);
/// // Original is unchanged
/// assert_eq!(v, vec![1.0, 2.0, 3.0]);
/// ```
///
/// # Arguments
///
/// * `v` - Vector to scale
/// * `alpha` - Scaling factor
pub fn vec_scale(v: &[f64], alpha: f64) -> Vec<f64> {
    let mut out = v.to_vec();
    for x in out.iter_mut() {
        *x *= alpha;
    }
    out
}
898
/// Returns the Euclidean (L2) norm of `v`: `sqrt(Σ(v[i]²))`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_l2_norm;
///
/// // Pythagorean triple: 3-4-5
/// assert_eq!(vec_l2_norm(&[3.0, 4.0]), 5.0);
/// // Unit vector
/// assert_eq!(vec_l2_norm(&[1.0, 0.0, 0.0]), 1.0);
/// ```
///
/// # Arguments
///
/// * `v` - Vector slice
pub fn vec_l2_norm(v: &[f64]) -> f64 {
    let mut acc = 0.0;
    for &x in v {
        acc += x * x;
    }
    acc.sqrt()
}
918
/// Returns the largest absolute value found in `v` (0.0 for an empty slice).
///
/// # Arguments
///
/// * `v` - Vector slice
///
/// # Example
///
/// ```
/// use linreg_core::linalg::vec_max_abs;
///
/// assert_eq!(vec_max_abs(&[1.0, -5.0, 3.0]), 5.0);
/// assert_eq!(vec_max_abs(&[-2.5, -1.5]), 2.5);
/// ```
pub fn vec_max_abs(v: &[f64]) -> f64 {
    v.iter().fold(0.0_f64, |best, &x| best.max(x.abs()))
}
936
// ============================================================================
// R-Compatible QR Decomposition (LINPACK dqrdc2 with Column Pivoting)
// ============================================================================

/// QR decomposition result using R's LINPACK dqrdc2 algorithm.
///
/// This implements the QR decomposition with column pivoting as used by R's
/// `qr()` function with `LAPACK=FALSE`. The algorithm is a modification of
/// LINPACK's DQRDC that:
/// - Uses Householder transformations
/// - Implements limited column pivoting based on 2-norms of reduced columns
/// - Moves columns with near-zero norm to the right-hand edge
/// - Computes the rank (number of linearly independent columns)
///
/// # Fields
///
/// * `qr` - The QR factorization (upper triangle contains R, below diagonal
///   contains Householder vector information)
/// * `qraux` - Auxiliary information for recovering the orthogonal part Q
/// * `pivot` - Column permutation: `pivot\[j\]` contains the original column index
///   now in column j
/// * `rank` - Number of linearly independent columns (the computed rank)
#[derive(Clone, Debug)]
pub struct QRLinpack {
    /// QR factorization matrix (same dimensions as input)
    pub qr: Matrix,
    /// Auxiliary information for Q recovery
    pub qraux: Vec<f64>,
    /// Column pivot vector (1-based indices like R)
    pub pivot: Vec<usize>,
    /// Computed rank (number of linearly independent columns)
    pub rank: usize,
}
970
971impl Matrix {
972    /// Computes QR decomposition using R's LINPACK dqrdc2 algorithm with column pivoting.
973    ///
974    /// This is a port of R's dqrdc2.f, which is a modification of LINPACK's DQRDC.
975    /// The algorithm:
976    /// 1. Uses Householder transformations for QR factorization
977    /// 2. Implements limited column pivoting based on column 2-norms
978    /// 3. Moves columns with near-zero norm to the right-hand edge
979    /// 4. Computes the rank (number of linearly independent columns)
980    ///
981    /// # Arguments
982    ///
983    /// * `tol` - Tolerance for determining linear independence. Default is 1e-7 (R's default).
984    ///   Columns with norm < tol * original_norm are considered negligible.
985    ///
986    /// # Returns
987    ///
988    /// A [`QRLinpack`] struct containing the QR factorization, auxiliary information,
989    /// pivot vector, and computed rank.
990    ///
991    /// # Algorithm Details
992    ///
993    /// The decomposition is A * P = Q * R where:
994    /// - P is the permutation matrix coded by `pivot`
995    /// - Q is orthogonal (m × m)
996    /// - R is upper triangular in the first `rank` rows
997    ///
998    /// The `qr` matrix contains:
999    /// - Upper triangle: R matrix (if pivoting was performed, this is R of the permuted matrix)
1000    /// - Below diagonal: Householder vector information
1001    ///
1002    /// # Reference
1003    ///
1004    /// - R source: src/appl/dqrdc2.f
1005    /// - LINPACK documentation: <https://www.netlib.org/linpack/dqrdc.f>
1006    pub fn qr_linpack(&self, tol: Option<f64>) -> QRLinpack {
1007        let n = self.rows;
1008        let p = self.cols;
1009        let lup = n.min(p);
1010
1011        // Default tolerance matches R's qr.default: tol = 1e-07
1012        let tol = tol.unwrap_or(1e-07);
1013
1014        // Initialize working matrices
1015        let mut x = self.clone(); // Working copy that will be modified
1016        let mut qraux = vec![0.0; p];
1017        let mut pivot: Vec<usize> = (1..=p).collect(); // 1-based indices like R
1018        let mut work = vec![(0.0, 0.0); p]; // (work[j,1], work[j,2])
1019
1020        // Compute the norms of the columns of x (initialization)
1021        if n > 0 {
1022            for j in 0..p {
1023                let mut norm = 0.0;
1024                for i in 0..n {
1025                    norm += x.get(i, j) * x.get(i, j);
1026                }
1027                norm = norm.sqrt();
1028                qraux[j] = norm;
1029                let original_norm = if norm == 0.0 { 1.0 } else { norm };
1030                work[j] = (norm, original_norm);
1031            }
1032        }
1033
1034        let mut k = p + 1; // Will be decremented to get the final rank
1035
1036        // Perform the Householder reduction of x
1037        for l in 0..lup {
1038            // Cycle columns from l to p until one with non-negligible norm is found
1039            // A column is negligible if its norm has fallen below tol * original_norm
1040            while l < k - 1 && qraux[l] < work[l].1 * tol {
1041                // Move column l to the end (it's negligible)
1042                let lp1 = l + 1;
1043
1044                // Shift columns in x: x(i, l..p-1) = x(i, l+1..p)
1045                for i in 0..n {
1046                    let t = x.get(i, l);
1047                    for j in lp1..p {
1048                        x.set(i, j - 1, x.get(i, j));
1049                    }
1050                    x.set(i, p - 1, t);
1051                }
1052
1053                // Shift pivot, qraux, and work arrays
1054                let saved_pivot = pivot[l];
1055                let saved_qraux = qraux[l];
1056                let saved_work = work[l];
1057
1058                for j in lp1..p {
1059                    pivot[j - 1] = pivot[j];
1060                    qraux[j - 1] = qraux[j];
1061                    work[j - 1] = work[j];
1062                }
1063
1064                pivot[p - 1] = saved_pivot;
1065                qraux[p - 1] = saved_qraux;
1066                work[p - 1] = saved_work;
1067
1068                k -= 1;
1069            }
1070
1071            if l == n - 1 {
1072                // Last row - skip transformation
1073                break;
1074            }
1075
1076            // Compute the Householder transformation for column l
1077            // nrmxl = norm of x[l:, l]
1078            let mut nrmxl = 0.0;
1079            for i in l..n {
1080                let val = x.get(i, l);
1081                nrmxl += val * val;
1082            }
1083            nrmxl = nrmxl.sqrt();
1084
1085            if nrmxl == 0.0 {
1086                // Zero column - continue to next
1087                continue;
1088            }
1089
1090            // Apply sign for numerical stability
1091            let x_ll = x.get(l, l);
1092            if x_ll != 0.0 {
1093                nrmxl = nrmxl.copysign(x_ll);
1094            }
1095
1096            // Scale the column
1097            let scale = 1.0 / nrmxl;
1098            for i in l..n {
1099                x.set(i, l, x.get(i, l) * scale);
1100            }
1101            x.set(l, l, 1.0 + x.get(l, l));
1102
1103            // Apply the transformation to remaining columns, updating the norms
1104            let lp1 = l + 1;
1105            if p > lp1 {
1106                for j in lp1..p {
1107                    // Compute t = -dot(x[l:, l], x[l:, j]) / x(l, l)
1108                    let mut dot = 0.0;
1109                    for i in l..n {
1110                        dot += x.get(i, l) * x.get(i, j);
1111                    }
1112                    let t = -dot / x.get(l, l);
1113
1114                    // x[l:, j] = x[l:, j] + t * x[l:, l]
1115                    for i in l..n {
1116                        let val = x.get(i, j) + t * x.get(i, l);
1117                        x.set(i, j, val);
1118                    }
1119
1120                    // Update the norm
1121                    if qraux[j] != 0.0 {
1122                        // tt = 1.0 - (x(l, j) / qraux[j])^2
1123                        let x_lj = x.get(l, j).abs();
1124                        let mut tt = 1.0 - (x_lj / qraux[j]).powi(2);
1125                        tt = tt.max(0.0);
1126
1127                        // Recompute norm if there is large reduction (BDR mod 9/99)
1128                        // The tolerance here is on the squared norm
1129                        if tt.abs() < 1e-6 {
1130                            // Re-compute norm directly
1131                            let mut new_norm = 0.0;
1132                            for i in (l + 1)..n {
1133                                let val = x.get(i, j);
1134                                new_norm += val * val;
1135                            }
1136                            new_norm = new_norm.sqrt();
1137                            qraux[j] = new_norm;
1138                            work[j].0 = new_norm;
1139                        } else {
1140                            qraux[j] *= tt.sqrt();
1141                        }
1142                    }
1143                }
1144            }
1145
1146            // Save the transformation
1147            qraux[l] = x.get(l, l);
1148            x.set(l, l, -nrmxl);
1149        }
1150
1151        // Compute final rank
1152        let rank = k - 1;
1153        let rank = rank.min(n);
1154
1155        QRLinpack {
1156            qr: x,
1157            qraux,
1158            pivot,
1159            rank,
1160        }
1161    }
1162
1163    /// Solves a linear system using the QR decomposition with column pivoting.
1164    ///
1165    /// This implements a least squares solver using the pivoted QR decomposition.
1166    /// For rank-deficient cases, coefficients corresponding to linearly dependent
1167    /// columns are set to `f64::NAN`.
1168    ///
1169    /// # Arguments
1170    ///
1171    /// * `qr_result` - QR decomposition from [`Matrix::qr_linpack`]
1172    /// * `y` - Right-hand side vector
1173    ///
1174    /// # Returns
1175    ///
1176    /// A vector of coefficients, or `None` if the system is exactly singular.
1177    ///
1178    /// # Algorithm
1179    ///
1180    /// This solver uses the standard QR decomposition approach:
1181    /// 1. Compute the QR decomposition of the permuted matrix
1182    /// 2. Extract R matrix (upper triangular with positive diagonal)
1183    /// 3. Compute qty = Q^T * y
1184    /// 4. Solve R * coef = qty using back substitution
1185    /// 5. Apply the pivot permutation to restore original column order
1186    ///
1187    /// # Note
1188    ///
1189    /// The LINPACK QR algorithm stores R with mixed signs on the diagonal.
1190    /// This solver corrects for that by taking the absolute value of R's diagonal.
1191    #[allow(clippy::needless_range_loop)]
1192    pub fn qr_solve_linpack(&self, qr_result: &QRLinpack, y: &[f64]) -> Option<Vec<f64>> {
1193        let n = self.rows;
1194        let p = self.cols;
1195        let k = qr_result.rank;
1196
1197        if y.len() != n {
1198            return None;
1199        }
1200
1201        if k == 0 {
1202            return None;
1203        }
1204
1205        // Step 1: Compute Q^T * y using the Householder vectors directly
1206        // This is more efficient than reconstructing the full Q matrix
1207        let mut qty = y.to_vec();
1208
1209        for j in 0..k {
1210            // Check if this Householder transformation is valid
1211            let r_jj = qr_result.qr.get(j, j);
1212            if r_jj == 0.0 {
1213                continue;
1214            }
1215
1216            // Compute dot = v_j^T * qty[j:]
1217            // where v_j is the Householder vector stored in qr[j:, j]
1218            // The storage convention:
1219            // - qr[j,j] = -nrmxl (after final overwrite)
1220            // - qr[i,j] for i > j is the scaled Householder vector element
1221            // - qraux[j] = 1 + original_x[j,j]/nrmxl (the unscaled first element)
1222
1223            // Reconstruct the Householder vector v_j
1224            // After scaling by 1/nrmxl, we have:
1225            // v_scaled[j] = 1 + x[j,j]/nrmxl
1226            // v_scaled[i] = x[i,j]/nrmxl for i > j
1227            // The actual unit vector is v = v_scaled / ||v_scaled||
1228
1229            let mut v = vec![0.0; n - j];
1230            // Copy the scaled Householder vector from qr
1231            for i in j..n {
1232                v[i - j] = qr_result.qr.get(i, j);
1233            }
1234
1235            // The j-th element was modified during the QR decomposition
1236            // We need to reconstruct it from qraux
1237            let alpha = qr_result.qraux[j];
1238            if alpha != 0.0 {
1239                v[0] = alpha;
1240            }
1241
1242            // Compute the norm of v
1243            let v_norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
1244            if v_norm < 1e-14 {
1245                continue;
1246            }
1247
1248            // Compute dot = v^T * qty[j:]
1249            let mut dot = 0.0;
1250            for i in j..n {
1251                dot += v[i - j] * qty[i];
1252            }
1253
1254            // Apply Householder transformation: qty[j:] = qty[j:] - 2 * v * (v^T * qty[j:]) / (v^T * v)
1255            // Since v is already scaled, we use: t = 2 * dot / (v_norm^2)
1256            let t = 2.0 * dot / (v_norm * v_norm);
1257
1258            for i in j..n {
1259                qty[i] -= t * v[i - j];
1260            }
1261        }
1262
1263        // Step 2: Back substitution on R (solve R * coef = qty)
1264        // The R matrix is stored in the upper triangle of qr
1265        // Note: The diagonal elements of R are negative (from -nrmxl)
1266        // We use them as-is since the signs cancel out in the computation
1267        let mut coef_permuted = vec![f64::NAN; p];
1268
1269        for row in (0..k).rev() {
1270            let r_diag = qr_result.qr.get(row, row);
1271            // Use relative tolerance for singularity check
1272            let max_abs = (0..k)
1273                .map(|i| qr_result.qr.get(i, i).abs())
1274                .fold(0.0_f64, f64::max);
1275            let tolerance = 1e-14 * max_abs.max(1.0);
1276
1277            if r_diag.abs() < tolerance {
1278                return None; // Singular
1279            }
1280
1281            let mut sum = qty[row];
1282            for col in (row + 1)..k {
1283                sum -= qr_result.qr.get(row, col) * coef_permuted[col];
1284            }
1285            coef_permuted[row] = sum / r_diag;
1286        }
1287
1288        // Step 3: Apply pivot permutation to get coefficients in original order
1289        // pivot[j] is 1-based, indicating which original column is now in position j
1290        let mut result = vec![0.0; p];
1291        for j in 0..p {
1292            let original_col = qr_result.pivot[j] - 1; // Convert to 0-based
1293            result[original_col] = coef_permuted[j];
1294        }
1295
1296        Some(result)
1297    }
1298}
1299
/// Performs OLS regression using R's LINPACK QR algorithm.
///
/// This function is a drop-in replacement for `fit_ols` that uses the
/// R-compatible QR decomposition with column pivoting. It handles
/// rank-deficient matrices more gracefully than the standard QR decomposition.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows × p columns, including intercept column)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - OLS coefficient estimates (length p)
/// * `None` - If the matrix is exactly singular or dimensions don't match
///
/// # Notes
///
/// - For rank-deficient systems, the pivoted QR automatically handles
///   multicollinearity by selecting a linearly independent subset of columns
/// - Coefficients for dropped (collinear) columns are set to `NaN`
/// - This is a convenience wrapper around `Matrix::qr_linpack` and `Matrix::qr_solve_linpack`
///
/// # Example
///
/// ```
/// # use linreg_core::linalg::{fit_ols_linpack, Matrix};
/// let y = vec![2.0, 4.0, 6.0];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
///
/// let beta = fit_ols_linpack(&y, &x).unwrap();
/// assert_eq!(beta.len(), 2);  // Intercept and slope
/// ```
pub fn fit_ols_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let qr_result = x.qr_linpack(None);
    x.qr_solve_linpack(&qr_result, y)
}
1350
1351// ============================================================================
1352// SVD Decomposition (Golub-Kahan)
1353// ============================================================================
1354
/// SVD decomposition result.
///
/// Contains the singular value decomposition A = U * Sigma * V^T where:
/// - U is an m×min(m,n) orthogonal matrix (left singular vectors)
/// - Sigma is a vector of min(m,n) singular values (sorted in descending order)
/// - V is an n×n orthogonal matrix (right singular vectors, stored transposed as V^T)
#[derive(Clone, Debug)]
pub struct SVDResult {
    /// Left singular vectors (m × k matrix where k = min(m,n)).
    /// Columns for singular values below tolerance are left as zeros.
    pub u: Matrix,
    /// Singular values (k elements, sorted in descending order, clamped >= 0)
    pub sigma: Vec<f64>,
    /// Right singular vectors transposed (n × n matrix, rows are V^T)
    pub v_t: Matrix,
}
1370
impl Matrix {
    /// Computes the Singular Value Decomposition (SVD) via eigendecomposition of A^T * A.
    ///
    /// Factorizes the matrix as `A = U * Sigma * V^T` where:
    /// - U is an m×k orthogonal matrix (k = min(m,n))
    /// - Sigma is a diagonal matrix of singular values (sorted in descending order)
    /// - V is an n×n orthogonal matrix
    ///
    /// Despite the historical name in earlier docs, this is not a full
    /// Golub-Kahan bidiagonalization: it forms the normal matrix A^T * A and
    /// diagonalizes it with unshifted QR iteration, which is adequate for the
    /// small matrices encountered in LOESS local fitting.
    ///
    /// # Algorithm
    ///
    /// The algorithm follows these steps:
    /// 1. Compute A^T * A (smaller symmetric matrix when m >= n)
    /// 2. Eigen-decompose A^T * A using QR iteration
    /// 3. Singular values are sqrt of eigenvalues
    /// 4. V contains eigenvectors of A^T * A
    /// 5. U = A * V * Sigma^(-1)
    ///
    /// # Returns
    ///
    /// A [`SVDResult`] struct containing U, Sigma, and V^T.
    ///
    /// # Note
    ///
    /// This implementation is designed for numerical stability with rank-deficient
    /// matrices, which is essential for LOESS fitting where some neighborhoods may
    /// have collinear points. Columns of U whose singular value is below 1e-14
    /// are left as zero vectors.
    ///
    /// # Alternative
    ///
    /// For potentially higher accuracy on small matrices, see [`Matrix::svd_jacobi`].
    #[allow(clippy::needless_range_loop)]
    pub fn svd(&self) -> SVDResult {
        let m = self.rows;
        let n = self.cols;
        let k = m.min(n);

        // For typical LOESS cases, m >= n (tall matrix)
        // We use the covariance method: A^T * A = V * Sigma^2 * V^T
        // This is more efficient for tall matrices

        // Compute A^T * A (n × n symmetric matrix); only the upper triangle
        // is computed, then mirrored.
        let mut ata = Matrix::zeros(n, n);
        for i in 0..n {
            for j in i..n {
                // ata[i,j] = sum(A[p,i] * A[p,j] for p in 0..m)
                let mut sum = 0.0;
                for p in 0..m {
                    sum += self.get(p, i) * self.get(p, j);
                }
                ata.set(i, j, sum);
                ata.set(j, i, sum); // Symmetric
            }
        }

        // Eigen-decomposition of A^T * A using QR algorithm
        // Start with identity as V (will converge to eigenvectors)
        let mut v = Matrix::identity(n);
        let mut lambda = Vec::with_capacity(n);

        // QR iteration for symmetric matrices
        // This finds eigenvalues (lambda) and eigenvectors (columns of V)
        // NOTE(review): unshifted QR iteration capped at 100 sweeps; if the
        // off-diagonal sum has not fallen below tolerance by then, the
        // diagonal is used as-is without signaling non-convergence.
        let max_iterations = 100;
        let tolerance = 1e-14;

        // Working copy for QR iteration
        let mut a_work = ata.clone();

        for _iter in 0..max_iterations {
            // Check convergence - sum of off-diagonal elements
            let mut off_diag_sum = 0.0;
            for i in 0..n {
                for j in (i + 1)..n {
                    off_diag_sum += a_work.get(i, j).abs();
                }
            }

            if off_diag_sum < tolerance {
                break;
            }

            // QR decomposition with Wilkinson shift for faster convergence
            // For simplicity, we use basic QR without shift
            let (q, r) = a_work.qr();
            a_work = r.matmul(&q);
            v = v.matmul(&q);
        }

        // Extract eigenvalues from diagonal
        for i in 0..n {
            lambda.push(a_work.get(i, i));
        }

        // Sort eigenvalues and corresponding eigenvectors in descending order
        // (selection-style pairwise swaps; n is small, so O(n^2) is fine).
        // We need to keep V synchronized
        for i in 0..n {
            for j in (i + 1)..n {
                if lambda[j] > lambda[i] {
                    // Swap eigenvalues
                    lambda.swap(i, j);

                    // Swap corresponding columns of V
                    #[allow(clippy::manual_swap)]
                    for row in 0..n {
                        let temp = v.get(row, i);
                        v.set(row, i, v.get(row, j));
                        v.set(row, j, temp);
                    }
                }
            }
        }

        // Singular values are sqrt of eigenvalues (clamp non-negative, since
        // rounding can make small eigenvalues slightly negative)
        let mut sigma = Vec::with_capacity(k);
        for i in 0..k {
            let s = lambda[i].max(0.0).sqrt();
            sigma.push(s);
        }

        // Compute U = A * V * Sigma^(-1)
        // Only compute for non-zero singular values; other columns stay zero
        let mut u = Matrix::zeros(m, k);
        for j in 0..k {
            if sigma[j] > 1e-14 {
                // U[:,j] = (A * V[:,j]) / sigma[j]
                for i in 0..m {
                    let mut sum = 0.0;
                    for p in 0..n {
                        sum += self.get(i, p) * v.get(p, j);
                    }
                    u.set(i, j, sum / sigma[j]);
                }
            }
        }

        // V^T is the transpose of V
        let v_t = v.transpose();

        SVDResult { u, sigma, v_t }
    }

    /// Solves a least squares problem using SVD with pseudoinverse for rank-deficient matrices.
    ///
    /// This implements the pseudoinverse solution: x = V * Sigma^+ * U^T * b
    /// where Sigma^+ replaces 1/sigma\[i\] with 0 for sigma\[i\] below tolerance.
    ///
    /// # Arguments
    ///
    /// * `svd_result` - SVD decomposition from [`Matrix::svd`]
    /// * `b` - Right-hand side vector (m elements)
    ///
    /// # Returns
    ///
    /// A vector of coefficients (n elements) that minimizes ||Ax - b||.
    ///
    #[allow(clippy::needless_range_loop)]
    pub fn svd_solve(&self, svd_result: &SVDResult, b: &[f64]) -> Vec<f64> {
        let m = self.rows;
        let n = self.cols;
        let k = m.min(n);

        // Tolerance: sigma[0] * 100 * epsilon (relative to the largest
        // singular value; absolute fallback when the matrix is all-zero)
        let max_sigma = svd_result.sigma.first().copied().unwrap_or(0.0);
        let tol = if max_sigma > 0.0 {
            max_sigma * 100.0 * f64::EPSILON
        } else {
            1e-14
        };

        // Compute U^T * b
        let mut ut_b = vec![0.0; k];
        for j in 0..k {
            let mut sum = 0.0;
            for i in 0..m {
                sum += svd_result.u.get(i, j) * b[i];
            }
            ut_b[j] = sum;
        }

        // Compute coefficients in V space: c[j] = ut_b[j] / sigma[j] if sigma[j] > tol, else 0
        let mut coeffs_v = vec![0.0; k];
        for j in 0..k {
            if svd_result.sigma[j] > tol {
                coeffs_v[j] = ut_b[j] / svd_result.sigma[j];
            } else {
                coeffs_v[j] = 0.0; // Singular value below threshold - use pseudoinverse (set to 0)
            }
        }

        // Transform back to original space: x = V * coeffs_v
        // Since v_t contains rows of V^T, we compute x[i] = sum_j V[i,j] * c[j]
        let mut x = vec![0.0; n];
        for i in 0..n {
            let mut sum = 0.0;
            for j in 0..k {
                // v_t.get(j, i) is element (j, i) of V^T, which is V[i, j]
                sum += svd_result.v_t.get(j, i) * coeffs_v[j];
            }
            x[i] = sum;
        }

        x
    }
}
1577
1578/// Fits OLS and predicts using R's LINPACK QR with rank-deficient handling.
1579///
1580/// This function matches R's `lm.fit` behavior for rank-deficient cases:
1581/// coefficients for linearly dependent columns are set to NA, and predictions
1582/// are computed using only the valid (non-NA) coefficients and their corresponding
1583/// columns. This matches how R handles rank-deficient models in prediction.
1584///
1585/// # Arguments
1586///
1587/// * `y` - Response variable (n observations)
1588/// * `x` - Design matrix (n rows, p columns including intercept)
1589///
1590/// # Returns
1591///
1592/// * `Some(Vec<f64>)` - Predictions (n elements)
1593/// * `None` - If the matrix is exactly singular or dimensions don't match
1594///
1595/// # Algorithm
1596///
1597/// For rank-deficient systems (rank < p):
1598/// 1. Compute QR decomposition with column pivoting
1599/// 2. Get coefficients (rank-deficient columns will have NaN)
1600/// 3. Build a reduced design matrix with only pivoted, non-singular columns
1601/// 4. Compute predictions using only the valid columns
1602///
1603/// This matches R's behavior where `predict(lm.fit(...))` handles NA coefficients
1604/// by excluding the corresponding columns from the prediction.
1605///
1606/// # Example
1607///
1608/// ```
1609/// # use linreg_core::linalg::{fit_and_predict_linpack, Matrix};
1610/// let y = vec![2.0, 4.0, 6.0];
1611/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
1612///
1613/// let preds = fit_and_predict_linpack(&y, &x).unwrap();
1614/// assert_eq!(preds.len(), 3);  // One prediction per observation
1615/// ```
1616#[allow(clippy::needless_range_loop)]
1617pub fn fit_and_predict_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
1618    let n = x.rows;
1619    let p = x.cols;
1620
1621    // Compute QR decomposition
1622    let qr_result = x.qr_linpack(None);
1623    let k = qr_result.rank;
1624
1625    // Solve for coefficients
1626    let beta_permuted = x.qr_solve_linpack(&qr_result, y)?;
1627
1628    // Check for rank deficiency
1629    if k == p {
1630        // Full rank - use standard prediction
1631        return Some(x.mul_vec(&beta_permuted));
1632    }
1633
1634    // Rank-deficient case: some columns are collinear and have NaN coefficients
1635    // We compute predictions using only columns with valid (non-NaN) coefficients
1636    // This matches R's behavior where NA coefficients exclude columns from prediction
1637
1638    let mut pred = vec![0.0; n];
1639
1640    for row in 0..n {
1641        let mut sum = 0.0;
1642        for j in 0..p {
1643            let b_val = beta_permuted[j];
1644            if b_val.is_nan() {
1645                continue; // Skip collinear columns (matches R's NA coefficient behavior)
1646            }
1647            sum += x.get(row, j) * b_val;
1648        }
1649        pred[row] = sum;
1650    }
1651
1652    Some(pred)
1653}
1654
1655// ============================================================================
1656// SVD and Pseudoinverse for Robust Weighted Least Squares
1657// ============================================================================
1658
1659impl Matrix {
    /// Compute SVD decomposition using the eigendecomposition method (Jacobi)
    ///
    /// For a matrix A (m×n), this computes A = U * Σ * V^T where:
    /// - U is m×k orthogonal (left singular vectors, k = min(m,n))
    /// - Σ is k singular values (sorted in descending order)
    /// - V is n×n orthogonal (right singular vectors)
    ///
    /// This implementation uses the method of computing the eigendecomposition
    /// of A^T*A to get V and singular values, then computing U = A * V * Σ^(-1).
    /// The Jacobi method is used for the eigendecomposition, which provides
    /// excellent numerical accuracy for small to medium-sized matrices.
    ///
    /// Returns None if the decomposition fails (i.e. if `symmetric_eigen` fails).
    ///
    /// # Alternative
    ///
    /// For a simpler/faster approach suitable for LOESS, see [`Matrix::svd`].
    pub fn svd_jacobi(&self) -> Option<SVDResult> {
        let m = self.rows;
        let n = self.cols;

        if m < n {
            // For wide matrices, compute SVD of transpose instead
            let at = self.transpose();
            let svd_t = at.svd_jacobi()?;
            // A = U * Σ * V^T = (V' * Σ' * U'^T)^T = U' * Σ' * V'^T
            // So swap U and V, and V becomes V^T
            // NOTE(review): in this branch the returned factors have the
            // economy shapes of the transposed problem, not the n×n V^T
            // documented on SVDResult — confirm callers only use this on
            // tall matrices or handle the reduced shapes.
            return Some(SVDResult {
                u: svd_t.v_t.transpose(),  // V' becomes U (need to transpose to get correct format)
                sigma: svd_t.sigma,
                v_t: svd_t.u.transpose(), // U' becomes V (need to transpose to get V^T)
            });
        }

        // Compute A^T * A (n×n symmetric matrix)
        let ata = self.transpose().matmul(self);

        // Compute eigendecomposition of A^T * A using Jacobi method
        let (v, s_sq) = ata.symmetric_eigen()?;

        // Singular values are sqrt of eigenvalues.
        // Subtle but correct: for a slightly negative eigenvalue, sqrt()
        // yields NaN and f64::max ignores the NaN operand, so `.max(0.0)`
        // clamps the result to 0.0.
        let s: Vec<f64> = s_sq.iter().map(|&x| x.sqrt().max(0.0)).collect();

        // Sort by singular values (descending)
        let mut indexed: Vec<(usize, f64)> = s.iter().enumerate()
            .map(|(i, &val)| (i, val))
            .collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        // Reorder V according to sorted singular values
        let mut v_sorted = Matrix::zeros(n, n);
        let mut s_sorted = vec![0.0; n];
        for (new_idx, (old_idx, val)) in indexed.iter().enumerate() {
            s_sorted[new_idx] = *val;
            for row in 0..n {
                v_sorted.set(row, new_idx, v.get(row, *old_idx));
            }
        }

        // Compute U = A * V * Σ^(-1)
        // Columns with singular value <= EPSILON are left as zero vectors
        let mut u = Matrix::zeros(m, n);
        for i in 0..m {
            for j in 0..n {
                if s_sorted[j] > f64::EPSILON {
                    let mut sum = 0.0;
                    for k in 0..n {
                        sum += self.get(i, k) * v_sorted.get(k, j);
                    }
                    u.set(i, j, sum / s_sorted[j]);
                }
            }
        }

        // V^T is the transpose of V (columns of V are right singular vectors)
        let v_t = v_sorted.transpose();

        Some(SVDResult {
            u,
            sigma: s_sorted,
            v_t,
        })
    }
1742
1743    /// Compute eigendecomposition of a symmetric matrix using Jacobi method
1744    ///
1745    /// For a symmetric matrix A (n×n), computes eigenvalues λ and eigenvectors V
1746    /// such that A = V * Λ * V^T where Λ is diagonal and V is orthogonal.
1747    ///
1748    /// Returns (V, eigenvalues) where eigenvalues\[i\] is the eigenvalue for column i of V.
1749    fn symmetric_eigen(&self) -> Option<(Matrix, Vec<f64>)> {
1750        let n = self.rows;
1751        assert_eq!(n, self.cols, "Matrix must be square");
1752
1753        let max_iterations = 100;
1754        let tolerance = 1e-10;
1755
1756        // Start with identity matrix for eigenvectors
1757        let mut v = Matrix::identity(n);
1758
1759        // Working copy of the matrix
1760        let mut a = self.clone();
1761
1762        for _iter in 0..max_iterations {
1763            let mut max_off_diag = 0.0;
1764
1765            // Find largest off-diagonal element
1766            let mut p = 0;
1767            let mut q = 1;
1768            for i in 0..n {
1769                for j in (i + 1)..n {
1770                    let val = a.get(i, j).abs();
1771                    if val > max_off_diag {
1772                        max_off_diag = val;
1773                        p = i;
1774                        q = j;
1775                    }
1776                }
1777            }
1778
1779            if max_off_diag < tolerance {
1780                break;
1781            }
1782
1783            // Jacobi rotation to zero out a(p,q)
1784            let app = a.get(p, p);
1785            let aqq = a.get(q, q);
1786            let apq = a.get(p, q);
1787
1788            if apq.abs() < tolerance {
1789                continue;
1790            }
1791
1792            // Compute rotation angle
1793            let tau = (aqq - app) / (2.0 * apq);
1794            let t = if tau >= 0.0 {
1795                1.0 / (tau + (1.0 + tau * tau).sqrt())
1796            } else {
1797                -1.0 / (-tau + (1.0 + tau * tau).sqrt())
1798            };
1799
1800            let c = 1.0 / (1.0 + t * t).sqrt();
1801            let s = t * c;
1802
1803            // Update matrix A
1804            for i in 0..n {
1805                if i != p && i != q {
1806                    let aip = a.get(i, p);
1807                    let aiq = a.get(i, q);
1808                    a.set(i, p, aip - s * (aiq + s * aip));
1809                    a.set(i, q, aiq + s * (aip - s * aiq));
1810                }
1811            }
1812
1813            a.set(p, p, app - t * apq);
1814            a.set(q, q, aqq + t * apq);
1815            a.set(p, q, 0.0);
1816            a.set(q, p, 0.0);
1817
1818            // Update eigenvectors V
1819            for i in 0..n {
1820                let vip = v.get(i, p);
1821                let viq = v.get(i, q);
1822                v.set(i, p, vip - s * (viq + s * vip));
1823                v.set(i, q, viq + s * (vip - s * viq));
1824            }
1825        }
1826
1827        // Extract eigenvalues from diagonal
1828        let eigenvalues: Vec<f64> = (0..n).map(|i| a.get(i, i)).collect();
1829
1830        Some((v, eigenvalues))
1831    }
1832
1833    /// Compute Moore-Penrose pseudoinverse using SVD
1834    ///
1835    /// For matrix A, computes A^+ = V * Σ^+ * U^T where:
1836    /// - Σ^+ has 1/σ for σ > tolerance, and 0 otherwise
1837    ///
1838    /// This provides a least-squares solution for rank-deficient or singular matrices.
1839    pub fn pseudo_inverse(&self, tolerance: Option<f64>) -> Option<Matrix> {
1840        let svd = self.svd_jacobi()?;
1841
1842        let m = self.rows;
1843        let n = self.cols;
1844
1845        // Use provided tolerance or compute based on largest singular value
1846        let tol = tolerance.unwrap_or_else(|| {
1847            let max_s = svd.sigma.iter().fold(0.0_f64, |a: f64, &b| a.max(b));
1848            if max_s > 0.0 {
1849                max_s * f64::EPSILON * (m.max(n) as f64)
1850            } else {
1851                f64::EPSILON
1852            }
1853        });
1854
1855        // Compute Σ^+ (pseudoinverse of singular values)
1856        let mut s_pinv = vec![0.0; svd.sigma.len()];
1857        for (i, &s_val) in svd.sigma.iter().enumerate() {
1858            if s_val > tol {
1859                s_pinv[i] = 1.0 / s_val;
1860            }
1861        }
1862
1863        // Compute A^+ = V * Σ^+ * U^T
1864        // First compute Σ^+ * U^T (n×m matrix)
1865        let mut s_ut = Matrix::zeros(n, m);
1866        for i in 0..n {
1867            for j in 0..m {
1868                s_ut.set(i, j, s_pinv[i] * svd.u.get(j, i));
1869            }
1870        }
1871
1872        // Then compute V * (Σ^+ * U^T)
1873        // Since v_t is V^T, we transpose it to get V
1874        let v = svd.v_t.transpose();
1875        let pseudoinv = v.matmul(&s_ut);
1876
1877        Some(pseudoinv)
1878    }
1879}
1880
1881// ============================================================================
1882// SVD Tests
1883// ============================================================================
1884
#[cfg(test)]
mod svd_tests {
    use super::*;

    /// Assert every entry of `m` matches the identity matrix within `eps`.
    fn assert_identity(m: &Matrix, eps: f64) {
        for r in 0..m.rows {
            for c in 0..m.cols {
                let want = if r == c { 1.0 } else { 0.0 };
                assert!((m.get(r, c) - want).abs() < eps);
            }
        }
    }

    #[test]
    fn test_svd_simple_matrix() {
        // SVD of a small dense 2x2 matrix.
        let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
        let svd = m.svd();

        // Singular values must come back in non-increasing order.
        for pair in svd.sigma.windows(2) {
            assert!(pair[0] >= pair[1]);
        }

        // U must be orthogonal: U^T * U = I.
        let gram = svd.u.transpose().matmul(&svd.u);
        assert_identity(&gram, 1e-10);
    }

    #[test]
    fn test_svd_solve_basic() {
        // System: x + y = 2, x + 2y = 4, x + 3y = 6  =>  x = 0, y = 2.
        let m = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
        let svd = m.svd();

        let rhs = vec![2.0, 4.0, 6.0];
        let solution = m.svd_solve(&svd, &rhs);

        assert!(solution[0].abs() < 1e-10);
        assert!((solution[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_svd_tolerance_formula() {
        // The default solver cutoff is sigma[0] * 100 * machine epsilon;
        // sanity-check its magnitude for a well-scaled matrix.
        let m = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
        let svd = m.svd();

        let expected_tol = svd.sigma[0] * 100.0 * f64::EPSILON;
        assert!(expected_tol > 0.0);
        assert!(expected_tol < 1e-10);
    }

    #[test]
    fn test_svd_solve_rank_deficient() {
        // Column 2 is exactly 2x column 1, so the matrix has rank 1.
        let m = Matrix::new(3, 2, vec![1.0, 2.0, 2.0, 4.0, 3.0, 6.0]);
        let svd = m.svd();

        // Exactly one significant singular value.
        assert!(svd.sigma[0] > 1e-10);
        assert!(svd.sigma[1] < 1e-10);

        // The pseudoinverse-based solve must still produce a usable answer.
        let rhs = vec![3.0, 6.0, 9.0];
        let x = m.svd_solve(&svd, &rhs);
        assert!(x.iter().all(|v| v.is_finite()));

        // Fitted value at the first row: 1*x0 + 2*x1 ≈ 3.
        let fitted = m.get(0, 0) * x[0] + m.get(0, 1) * x[1];
        assert!((fitted - rhs[0]).abs() < 1e-6);
    }

    #[test]
    fn test_svd_jacobi_kahan_produce_results() {
        // Singular 3x3 matrix (rank 2); run through both SVD backends.
        let m = Matrix::new(3, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);

        let kahan = m.svd(); // default backend
        let jacobi = m.svd_jacobi().unwrap();

        // Same count of singular values from both algorithms.
        assert_eq!(kahan.sigma.len(), jacobi.sigma.len());

        // Both must be sorted in non-increasing order.
        for w in kahan.sigma.windows(2) {
            assert!(w[0] >= w[1]);
        }
        for w in jacobi.sigma.windows(2) {
            assert!(w[0] >= w[1]);
        }

        // Both must flag the rank deficiency with a near-zero smallest value.
        assert!(kahan.sigma[2] < 1e-10);
        assert!(jacobi.sigma[2] < 1e-10);
    }

    #[test]
    fn test_svd_jacobi_rank_deficient() {
        // Jacobi backend on a rank-1 tall matrix (col 2 == 2 * col 1).
        let m = Matrix::new(3, 2, vec![1.0, 2.0, 2.0, 4.0, 3.0, 6.0]);
        let svd = m.svd_jacobi().unwrap();

        assert_eq!(svd.sigma.len(), 2);
        // Rank 1 => second singular value is numerically zero.
        assert!(svd.sigma[1] < 1e-10);
    }

    #[test]
    fn test_pseudo_inverse_basic() {
        // The pseudoinverse of the identity is the identity itself.
        let eye = Matrix::new(2, 2, vec![1.0, 0.0, 0.0, 1.0]);
        let pinv = eye.pseudo_inverse(None).unwrap();
        assert_identity(&pinv, 1e-10);
    }

    #[test]
    fn test_pseudo_inverse_rank_deficient() {
        // Rank-1 tall matrix: second column entirely zero.
        let m = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]);
        let pinv = m.pseudo_inverse(None).unwrap();

        // A 3x2 input yields a 2x3 pseudoinverse with all entries finite.
        assert_eq!((pinv.rows, pinv.cols), (2, 3));
        for r in 0..pinv.rows {
            for c in 0..pinv.cols {
                assert!(pinv.get(r, c).is_finite());
            }
        }
    }

    #[test]
    fn test_svd_wide_matrix() {
        // 2x4 (wide): expect min(2, 4) = 2 singular values.
        let m = Matrix::new(2, 4, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let svd = m.svd();

        assert_eq!(svd.sigma.len(), 2);
        assert_eq!((svd.u.rows, svd.u.cols), (2, 2));
        assert_eq!((svd.v_t.rows, svd.v_t.cols), (4, 4));
    }

    #[test]
    fn test_svd_tall_matrix() {
        // 4x2 (tall): expect min(4, 2) = 2 singular values.
        let m = Matrix::new(4, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let svd = m.svd();

        assert_eq!(svd.sigma.len(), 2);
        assert_eq!((svd.u.rows, svd.u.cols), (4, 2));
        assert_eq!((svd.v_t.rows, svd.v_t.cols), (2, 2));
    }

    #[test]
    fn test_svd_solve_with_custom_tolerance() {
        // Rank-deficient system solved with the default tolerance must
        // still yield finite coefficients.
        let m = Matrix::new(3, 2, vec![1.0, 2.0, 2.0, 4.0, 3.0, 6.0]);
        let svd = m.svd();

        let rhs = vec![3.0, 6.0, 9.0];
        let x = m.svd_solve(&svd, &rhs);
        assert!(x.iter().all(|v| v.is_finite()));
    }
}