linreg_core/linalg.rs

//! Minimal linear algebra module to replace the nalgebra dependency.
//!
//! Implements matrix operations, QR decomposition, and solvers needed for OLS.
//! Uses row-major storage for compatibility with statistical computing conventions.
//!
//! # Numerical Stability Considerations
//!
//! This implementation uses Householder QR decomposition with careful attention to
//! numerical stability:
//!
//! - **Sign convention**: Uses the numerically stable Householder sign choice
//!   (v = x + sgn(x₀)||x||e₁) to avoid cancellation
//! - **Tolerance checking**: Uses predefined tolerances to detect near-singular matrices
//! - **Zero-skipping**: Skips transformations when the remaining column is already
//!   numerically zero
//!
//! # Scaling Recommendations
//!
//! For optimal numerical stability when predictor variables have vastly different
//! scales (e.g., one variable in millions, another in thousandths), consider
//! standardizing predictors before regression. Z-score standardization
//! (`x_scaled = (x - mean) / std`) is already applied in the VIF calculation.
//!
//! However, the current implementation handles typical OLS cases without explicit
//! scaling, as QR decomposition is generally stable for well-conditioned matrices.
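//!
//! As a sketch, standardizing one predictor with this module's own helpers (the
//! unbiased sample standard deviation is computed inline, since no helper for it
//! is provided here):
//!
//! ```
//! use linreg_core::linalg::{vec_mean, vec_scale, vec_sub};
//!
//! let x = vec![10.0, 20.0, 30.0];
//! let mean = vec_mean(&x);
//! let centered = vec_sub(&x, &vec![mean; x.len()]);
//! let sd = (centered.iter().map(|v| v * v).sum::<f64>() / (x.len() as f64 - 1.0)).sqrt();
//! let x_scaled = vec_scale(&centered, 1.0 / sd);
//! assert!((x_scaled[0] + 1.0).abs() < 1e-12); // (10 - 20) / 10 = -1
//! ```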

// ============================================================================
// Numerical Constants
// ============================================================================

/// Threshold for detecting numerically zero values in QR decomposition.
/// Values below this are treated as zero to avoid numerical instability.
const QR_ZERO_TOLERANCE: f64 = 1e-12;

/// Threshold for detecting singular matrices during inversion.
/// Diagonal elements below this value indicate a near-singular matrix.
const SINGULAR_TOLERANCE: f64 = 1e-10;

/// A dense matrix stored in row-major order.
///
/// # Storage
///
/// Elements are stored in a single flat vector in row-major order:
/// `data[row * cols + col]`
///
/// # Example
///
/// ```
/// # use linreg_core::linalg::Matrix;
/// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// assert_eq!(m.rows, 2);
/// assert_eq!(m.cols, 3);
/// assert_eq!(m.get(0, 0), 1.0);
/// assert_eq!(m.get(1, 2), 6.0);
/// ```
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Number of rows in the matrix
    pub rows: usize,
    /// Number of columns in the matrix
    pub cols: usize,
    /// Flat vector storing matrix elements in row-major order
    pub data: Vec<f64>,
}

impl Matrix {
    /// Creates a new matrix from the given dimensions and data.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rows * cols`.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `data` - Flat vector of elements in row-major order
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// assert_eq!(m.get(0, 0), 1.0);
    /// assert_eq!(m.get(0, 1), 2.0);
    /// assert_eq!(m.get(1, 0), 3.0);
    /// assert_eq!(m.get(1, 1), 4.0);
    /// ```
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols, "Data length must match dimensions");
        Matrix { rows, cols, data }
    }

    /// Creates a matrix filled with zeros.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::zeros(3, 2);
    /// assert_eq!(m.rows, 3);
    /// assert_eq!(m.cols, 2);
    /// assert_eq!(m.get(1, 1), 0.0);
    /// ```
    pub fn zeros(rows: usize, cols: usize) -> Self {
        Matrix {
            rows,
            cols,
            data: vec![0.0; rows * cols],
        }
    }

    // NOTE: Currently unused but kept as a reference implementation.
    // Uncomment if a convenience constructor is needed.
    /*
    /// Creates a matrix from a row-major slice.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `slice` - Slice containing matrix elements in row-major order
    pub fn from_row_slice(rows: usize, cols: usize, slice: &[f64]) -> Self {
        Matrix::new(rows, cols, slice.to_vec())
    }
    */

    /// Gets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    pub fn get(&self, row: usize, col: usize) -> f64 {
        self.data[row * self.cols + col]
    }

    /// Sets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    /// * `val` - Value to set
    pub fn set(&mut self, row: usize, col: usize, val: f64) {
        self.data[row * self.cols + col] = val;
    }

    /// Returns the transpose of this matrix.
    ///
    /// Swaps rows with columns: `result[col][row] = self[row][col]`.
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let t = m.transpose();
    /// assert_eq!(t.rows, 3);
    /// assert_eq!(t.cols, 2);
    /// assert_eq!(t.get(0, 1), 4.0);
    /// ```
    pub fn transpose(&self) -> Matrix {
        let mut t_data = vec![0.0; self.rows * self.cols];
        for r in 0..self.rows {
            for c in 0..self.cols {
                t_data[c * self.rows + r] = self.get(r, c);
            }
        }
        Matrix::new(self.cols, self.rows, t_data)
    }

    /// Performs matrix multiplication: `self * other`.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != other.rows`.
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let c = a.matmul(&b);
    /// assert_eq!(c.rows, 2);
    /// assert_eq!(c.cols, 2);
    /// assert_eq!(c.get(0, 0), 22.0); // 1*1 + 2*3 + 3*5
    /// ```
    pub fn matmul(&self, other: &Matrix) -> Matrix {
        assert_eq!(
            self.cols, other.rows,
            "Dimension mismatch for multiplication"
        );
        let mut result = Matrix::zeros(self.rows, other.cols);

        for r in 0..self.rows {
            for c in 0..other.cols {
                let mut sum = 0.0;
                for k in 0..self.cols {
                    sum += self.get(r, k) * other.get(k, c);
                }
                result.set(r, c, sum);
            }
        }
        result
    }

    /// Multiplies this matrix by a vector (treating the vector as a column matrix).
    ///
    /// Computes `self * vec` where `vec` is treated as an n×1 column matrix.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != vec.len()`.
    ///
    /// # Arguments
    ///
    /// * `vec` - Vector to multiply by
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let v = vec![1.0, 2.0, 3.0];
    /// let result = m.mul_vec(&v);
    /// assert_eq!(result.len(), 2);
    /// assert_eq!(result[0], 14.0); // 1*1 + 2*2 + 3*3
    /// ```
    #[allow(clippy::needless_range_loop)]
    pub fn mul_vec(&self, vec: &[f64]) -> Vec<f64> {
        assert_eq!(
            self.cols,
            vec.len(),
            "Dimension mismatch for matrix-vector multiplication"
        );
        let mut result = vec![0.0; self.rows];

        for r in 0..self.rows {
            let mut sum = 0.0;
            for c in 0..self.cols {
                sum += self.get(r, c) * vec[c];
            }
            result[r] = sum;
        }
        result
    }

    /// Computes the dot product of a column with a vector: `Σ(data[i * cols + col] * v[i])`.
    ///
    /// For a row-major matrix, this iterates through all rows at a fixed column.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    /// * `v` - Vector to dot with (must have length equal to rows)
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols` or `v.len() != rows`.
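    ///
    /// # Example
    ///
    /// Dotting column 0 with a vector of ones sums that column:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// // Column 0 is [1.0, 3.0]
    /// assert_eq!(m.col_dot(0, &[1.0, 1.0]), 4.0);
    /// ```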
    #[allow(clippy::needless_range_loop)]
    pub fn col_dot(&self, col: usize, v: &[f64]) -> f64 {
        assert!(col < self.cols, "Column index out of bounds");
        assert_eq!(
            self.rows,
            v.len(),
            "Vector length must match number of rows"
        );

        let mut sum = 0.0;
        for row in 0..self.rows {
            sum += self.get(row, col) * v[row];
        }
        sum
    }

    /// Performs a column-vector update in place: `v += alpha * column(col)`.
    ///
    /// This is the AXPY operation where the column is treated as a vector.
    /// For row-major storage, we iterate through rows at a fixed column.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    /// * `alpha` - Scaling factor for the column
    /// * `v` - Vector to modify in place (must have length equal to rows)
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols` or `v.len() != rows`.
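    ///
    /// # Example
    ///
    /// Accumulating twice column 1 into an existing vector:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// let mut v = vec![10.0, 20.0];
    /// m.col_axpy_inplace(1, 2.0, &mut v); // column 1 is [2.0, 4.0]
    /// assert_eq!(v, vec![14.0, 28.0]);
    /// ```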
    #[allow(clippy::needless_range_loop)]
    pub fn col_axpy_inplace(&self, col: usize, alpha: f64, v: &mut [f64]) {
        assert!(col < self.cols, "Column index out of bounds");
        assert_eq!(
            self.rows,
            v.len(),
            "Vector length must match number of rows"
        );

        for row in 0..self.rows {
            v[row] += alpha * self.get(row, col);
        }
    }

    /// Computes the squared L2 norm of a column: `Σ(data[i * cols + col]²)`.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols`.
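    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![3.0, 0.0, 4.0, 0.0]);
    /// // Column 0 is [3.0, 4.0], so the squared norm is 9 + 16
    /// assert_eq!(m.col_norm2(0), 25.0);
    /// ```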
    #[allow(clippy::needless_range_loop)]
    pub fn col_norm2(&self, col: usize) -> f64 {
        assert!(col < self.cols, "Column index out of bounds");

        let mut sum = 0.0;
        for row in 0..self.rows {
            let val = self.get(row, col);
            sum += val * val;
        }
        sum
    }

    /// Adds a value to diagonal elements starting from a given index.
    ///
    /// This is useful for ridge regression where we add `lambda * I` to `X^T X`,
    /// but the intercept column should not be penalized.
    ///
    /// # Arguments
    ///
    /// * `alpha` - Value to add to diagonal elements
    /// * `start_index` - Starting diagonal index (0 = first diagonal element)
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    ///
    /// # Example
    ///
    /// For a 3×3 identity matrix with the intercept in the first column (unpenalized):
    /// ```text
    /// add_diagonal_in_place(lambda, 1) on:
    /// [1.0, 0.0, 0.0]       [1.0, 0.0,   0.0  ]
    /// [0.0, 1.0, 0.0]  ->   [0.0, 1.0+λ, 0.0  ]
    /// [0.0, 0.0, 1.0]       [0.0, 0.0,   1.0+λ]
    /// ```
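    ///
    /// The same idea as a runnable check:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let mut m = Matrix::identity(3);
    /// m.add_diagonal_in_place(0.5, 1); // skip the intercept at index 0
    /// assert_eq!(m.get(0, 0), 1.0);
    /// assert_eq!(m.get(1, 1), 1.5);
    /// assert_eq!(m.get(2, 2), 1.5);
    /// ```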
    pub fn add_diagonal_in_place(&mut self, alpha: f64, start_index: usize) {
        assert_eq!(self.rows, self.cols, "Matrix must be square");
        let n = self.rows;
        for i in start_index..n {
            let current = self.get(i, i);
            self.set(i, i, current + alpha);
        }
    }
}

// ============================================================================
// QR Decomposition
// ============================================================================

impl Matrix {
    /// Computes the QR decomposition using Householder reflections.
    ///
    /// Factorizes the matrix as `A = QR` where Q is orthogonal and R is upper triangular.
    ///
    /// # Requirements
    ///
    /// This implementation requires `rows >= cols` (tall matrix). For OLS regression,
    /// we always have more observations than predictors, so this requirement is satisfied.
    ///
    /// # Returns
    ///
    /// A tuple `(Q, R)` where:
    /// - `Q` is an orthogonal matrix (QᵀQ = I) of size m×m
    /// - `R` is an upper triangular matrix of size m×n
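    ///
    /// # Example
    ///
    /// A quick reconstruction check, `A ≈ QR` (up to floating-point roundoff):
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let (q, r) = a.qr();
    /// let qr = q.matmul(&r);
    /// for i in 0..a.rows {
    ///     for j in 0..a.cols {
    ///         assert!((qr.get(i, j) - a.get(i, j)).abs() < 1e-10);
    ///     }
    /// }
    /// ```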
    #[allow(clippy::needless_range_loop)]
    pub fn qr(&self) -> (Matrix, Matrix) {
        let m = self.rows;
        let n = self.cols;
        let mut q = Matrix::identity(m);
        let mut r = self.clone();

        for k in 0..n.min(m - 1) {
            // Create vector x = R[k:, k]
            let mut x = vec![0.0; m - k];
            for i in k..m {
                x[i - k] = r.get(i, k);
            }

            // Norm of x
            let norm_x: f64 = x.iter().map(|&v| v * v).sum::<f64>().sqrt();
            if norm_x < QR_ZERO_TOLERANCE {
                continue; // Column is already zero
            }

            // Create vector v = x + sign(x[0]) * ||x|| * e1
            //
            // NOTE: Numerical stability consideration (Householder sign choice)
            // According to Overton & Yu (2023), the numerically stable choice is
            // σ = -sgn(x₁) in the formula v = x - σ‖x‖e₁.
            //
            // This means: v = x - (-sgn(x₁))‖x‖e₁ = x + sgn(x₁)‖x‖e₁
            //
            // Equivalently: u₁ = x₁ + sgn(x₁)‖x‖
            //
            // This implementation uses that formula (the stable choice):
            let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 }; // sgn(x₀) with sgn(0) = +1
            let u1 = x[0] + sign * norm_x;

            // Normalize v to get the Householder vector
            let mut v = x; // Re-use storage
            v[0] = u1;

            let norm_v: f64 = v.iter().map(|&val| val * val).sum::<f64>().sqrt();
            for val in &mut v {
                *val /= norm_v;
            }

            // Apply Householder transformation to R: R = H * R = (I - 2vv^T)R = R - 2v(v^T R)
            // Focus on submatrix R[k:, k:]
            for j in k..n {
                let mut dot = 0.0;
                for i in 0..m - k {
                    dot += v[i] * r.get(k + i, j);
                }

                for i in 0..m - k {
                    let val = r.get(k + i, j) - 2.0 * v[i] * dot;
                    r.set(k + i, j, val);
                }
            }

            // Update Q: Q = Q * H = Q(I - 2vv^T) = Q - 2(Qv)v^T
            // Focus on Q[:, k:]
            for i in 0..m {
                let mut dot = 0.0;
                for l in 0..m - k {
                    dot += q.get(i, k + l) * v[l];
                }

                for l in 0..m - k {
                    let val = q.get(i, k + l) - 2.0 * dot * v[l];
                    q.set(i, k + l, val);
                }
            }
        }

        (q, r)
    }

    /// Creates an identity matrix of the given size.
    ///
    /// # Arguments
    ///
    /// * `size` - Number of rows and columns (square matrix)
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let i = Matrix::identity(3);
    /// assert_eq!(i.get(0, 0), 1.0);
    /// assert_eq!(i.get(1, 1), 1.0);
    /// assert_eq!(i.get(2, 2), 1.0);
    /// assert_eq!(i.get(0, 1), 0.0);
    /// ```
    pub fn identity(size: usize) -> Self {
        let mut data = vec![0.0; size * size];
        for i in 0..size {
            data[i * size + i] = 1.0;
        }
        Matrix::new(size, size, data)
    }

    /// Inverts an upper triangular matrix (such as R from QR decomposition).
    ///
    /// Uses back-substitution to compute the inverse. This is efficient for
    /// triangular matrices compared to general matrix inversion.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    ///
    /// # Returns
    ///
    /// `None` if the matrix is singular (has a zero or near-zero diagonal element).
    /// A matrix is considered singular if any diagonal element is below the
    /// internal tolerance, which indicates the matrix does not have full rank.
    ///
    /// # Note
    ///
    /// For upper triangular matrices, singularity is equivalent to having a
    /// zero (or near-zero) diagonal element. This is much simpler to check than
    /// for general matrices, which would require computing the condition number.
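    ///
    /// # Example
    ///
    /// The product of R and its computed inverse should be the identity:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let r = Matrix::new(2, 2, vec![2.0, 1.0, 0.0, 4.0]);
    /// let inv = r.invert_upper_triangular().unwrap();
    /// let prod = r.matmul(&inv);
    /// assert!((prod.get(0, 0) - 1.0).abs() < 1e-12);
    /// assert!(prod.get(0, 1).abs() < 1e-12);
    /// assert!((prod.get(1, 1) - 1.0).abs() < 1e-12);
    /// ```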
    pub fn invert_upper_triangular(&self) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance.
        // This scales with the magnitude of the diagonal elements, handling matrices
        // of different scales better than a fixed absolute tolerance.
        //
        // Previous implementation used an absolute tolerance:
        //   if self.get(i, i).abs() < SINGULAR_TOLERANCE { return None; }
        //
        // New implementation uses a relative tolerance similar to LAPACK:
        //   tolerance = max_diag * epsilon * n
        // where epsilon is twice machine epsilon (~4.4e-16 for f64)
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        // Use a relative tolerance based on the maximum diagonal element
        // This is similar to LAPACK's dlamch machine epsilon approach
        let epsilon = 2.0_f64 * f64::EPSILON; // ~4.4e-16 for f64
        let relative_tolerance = max_diag * epsilon * n as f64;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None; // Singular matrix - cannot invert
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Inverts an upper triangular matrix with a custom tolerance multiplier.
    ///
    /// The tolerance is computed as `max_diag * epsilon * n * tolerance_mult`.
    /// A higher `tolerance_mult` allows more tolerance for near-singular matrices.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the tolerance (1.0 = standard, higher = more tolerant)
    pub fn invert_upper_triangular_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        // Use a relative tolerance based on the maximum diagonal element
        let epsilon = 2.0_f64 * f64::EPSILON;
        let relative_tolerance = max_diag * epsilon * n as f64 * tolerance_mult;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None;
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Computes the inverse of a square matrix using QR decomposition.
    ///
    /// For an invertible matrix A, computes A⁻¹ such that A * A⁻¹ = I.
    /// Uses QR decomposition for numerical stability.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square (i.e., `self.rows != self.cols`).
    /// Check dimensions before calling if the matrix shape is not guaranteed.
    ///
    /// # Returns
    ///
    /// Returns `Some(inverse)` if the matrix is invertible, or `None` if
    /// the matrix is singular (non-invertible).
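    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(2, 2, vec![4.0, 7.0, 2.0, 6.0]);
    /// let inv = a.invert().unwrap();
    /// let prod = a.matmul(&inv);
    /// // A * A⁻¹ ≈ I
    /// assert!((prod.get(0, 0) - 1.0).abs() < 1e-10);
    /// assert!(prod.get(0, 1).abs() < 1e-10);
    /// assert!((prod.get(1, 1) - 1.0).abs() < 1e-10);
    /// ```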
    pub fn invert(&self) -> Option<Matrix> {
        let n = self.rows;
        if n != self.cols {
            panic!("Matrix must be square for inversion");
        }

        // Use QR decomposition: A = Q * R
        let (q, r) = self.qr();

        // Compute R⁻¹ (upper triangular inverse)
        let r_inv = r.invert_upper_triangular()?;

        // A⁻¹ = R⁻¹ * Q^T
        let q_transpose = q.transpose();
        let mut result = Matrix::zeros(n, n);

        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0;
                for k in 0..n {
                    sum += r_inv.get(i, k) * q_transpose.get(k, j);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }

    /// Computes the inverse of X'X given the QR decomposition of X (R's chol2inv).
    ///
    /// This is equivalent to computing `(X'X)^(-1)` using the QR decomposition of X.
    /// R's `chol2inv` function is used for numerical stability in recursive residuals.
    ///
    /// # Input
    ///
    /// The receiver (`self`) is the design matrix X and must have `rows >= cols`.
    ///
    /// # Returns
    ///
    /// `Some((X'X)^(-1))` if X has full rank, `None` otherwise.
    ///
    /// # Algorithm
    ///
    /// Given the QR decomposition X = QR where R is upper triangular:
    /// 1. Extract the upper p×p portion of R (denoted R₁)
    /// 2. Invert R₁ (upper triangular inverse)
    /// 3. Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
    ///
    /// This works because X'X = R'Q'QR = R'R, and R₁ contains the Cholesky factor.
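    ///
    /// # Example
    ///
    /// Comparing against the direct computation of (X'X)⁻¹:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
    /// let via_qr = x.chol2inv_from_qr().unwrap();
    /// let direct = x.transpose().matmul(&x).invert().unwrap();
    /// for i in 0..2 {
    ///     for j in 0..2 {
    ///         assert!((via_qr.get(i, j) - direct.get(i, j)).abs() < 1e-10);
    ///     }
    /// }
    /// ```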
    pub fn chol2inv_from_qr(&self) -> Option<Matrix> {
        self.chol2inv_from_qr_with_tolerance(1.0)
    }

    /// Computes the inverse of X'X given the QR decomposition with custom tolerance.
    ///
    /// Similar to `chol2inv_from_qr` but allows specifying a tolerance multiplier
    /// for handling near-singular matrices.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the tolerance (higher = more tolerant)
    pub fn chol2inv_from_qr_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let p = self.cols;

        // QR decomposition: X = QR
        // For X (m×p, m≥p), R is m×p upper trapezoidal
        // The upper p×p block of R contains the meaningful values
        let (_, r_full) = self.qr();

        // Extract the upper p×p portion of R
        // For tall matrices (m > p), R has zeros below the diagonal in the first p rows
        // For square matrices (m = p), R is p×p upper triangular
        let mut r1 = Matrix::zeros(p, p);
        for i in 0..p {
            // Row i of R₁ is the upper triangular part of row i of R_full
            // (columns i..p, which includes the diagonal at j = i)
            for j in i..p {
                r1.set(i, j, r_full.get(i, j));
            }
        }

        // Invert R₁ (upper triangular) with the custom tolerance
        let r1_inv = r1.invert_upper_triangular_with_tolerance(tolerance_mult)?;

        // Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
        let mut result = Matrix::zeros(p, p);
        for i in 0..p {
            for j in 0..p {
                let mut sum = 0.0;
                // result[i,j] = sum(R1_inv[i,k] * R1_inv[j,k] for k=0..p)
                // R1_inv is upper triangular, but we iterate the full range
                for k in 0..p {
                    sum += r1_inv.get(i, k) * r1_inv.get(j, k);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }
}

// ============================================================================
// Vector Helper Functions
// ============================================================================

/// Computes the arithmetic mean of a slice of f64 values.
///
/// Returns 0.0 for empty slices.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_mean;
///
/// assert_eq!(vec_mean(&[1.0, 2.0, 3.0, 4.0, 5.0]), 3.0);
/// assert_eq!(vec_mean(&[]), 0.0);
/// ```
///
/// # Arguments
///
/// * `v` - Slice of values
pub fn vec_mean(v: &[f64]) -> f64 {
    if v.is_empty() {
        return 0.0;
    }
    v.iter().sum::<f64>() / v.len() as f64
}

/// Computes element-wise subtraction of two slices: `a - b`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_sub;
///
/// let a = vec![5.0, 4.0, 3.0];
/// let b = vec![1.0, 1.0, 1.0];
/// let result = vec_sub(&a, &b);
/// assert_eq!(result, vec![4.0, 3.0, 2.0]);
/// ```
///
/// # Arguments
///
/// * `a` - Minuend slice
/// * `b` - Subtrahend slice
///
/// # Panics
///
/// Panics if slices have different lengths.
pub fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_sub: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x - y).collect()
}

/// Computes the dot product of two slices: `Σ(a[i] * b[i])`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_dot;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_dot(&a, &b), 32.0);  // 1*4 + 2*5 + 3*6
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if slices have different lengths.
pub fn vec_dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "vec_dot: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Computes element-wise addition of two slices: `a + b`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_add;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_add(&a, &b), vec![5.0, 7.0, 9.0]);
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if slices have different lengths.
pub fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_add: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}

/// Computes a scaled vector addition in place: `dst += alpha * src`.
///
/// This is the classic BLAS AXPY operation.
///
/// # Arguments
///
/// * `dst` - Destination slice (modified in place)
/// * `alpha` - Scaling factor for src
/// * `src` - Source slice
///
/// # Panics
///
/// Panics if slices have different lengths.
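///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_axpy_inplace;
///
/// let mut dst = vec![1.0, 2.0, 3.0];
/// vec_axpy_inplace(&mut dst, 2.0, &[1.0, 1.0, 1.0]);
/// assert_eq!(dst, vec![3.0, 4.0, 5.0]);
/// ```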
pub fn vec_axpy_inplace(dst: &mut [f64], alpha: f64, src: &[f64]) {
    assert_eq!(
        dst.len(),
        src.len(),
        "vec_axpy_inplace: slice lengths must match"
    );
    for (d, &s) in dst.iter_mut().zip(src.iter()) {
        *d += alpha * s;
    }
}

/// Scales a vector in place: `v *= alpha`.
///
/// # Arguments
///
/// * `v` - Vector to scale (modified in place)
/// * `alpha` - Scaling factor
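///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_scale_inplace;
///
/// let mut v = vec![1.0, 2.0];
/// vec_scale_inplace(&mut v, 3.0);
/// assert_eq!(v, vec![3.0, 6.0]);
/// ```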
pub fn vec_scale_inplace(v: &mut [f64], alpha: f64) {
    for val in v.iter_mut() {
        *val *= alpha;
    }
}

/// Returns a scaled copy of a vector: `v * alpha`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_scale;
///
/// let v = vec![1.0, 2.0, 3.0];
/// let scaled = vec_scale(&v, 2.5);
/// assert_eq!(scaled, vec![2.5, 5.0, 7.5]);
/// // Original is unchanged
/// assert_eq!(v, vec![1.0, 2.0, 3.0]);
/// ```
///
/// # Arguments
///
/// * `v` - Vector to scale
/// * `alpha` - Scaling factor
pub fn vec_scale(v: &[f64], alpha: f64) -> Vec<f64> {
    v.iter().map(|&x| x * alpha).collect()
}

/// Computes the L2 norm (Euclidean norm) of a vector: `sqrt(Σ(v[i]²))`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_l2_norm;
///
/// // Pythagorean triple: 3-4-5
/// assert_eq!(vec_l2_norm(&[3.0, 4.0]), 5.0);
/// // Unit vector
/// assert_eq!(vec_l2_norm(&[1.0, 0.0, 0.0]), 1.0);
/// ```
///
/// # Arguments
///
/// * `v` - Vector slice
pub fn vec_l2_norm(v: &[f64]) -> f64 {
    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
}

/// Computes the maximum absolute value in a vector.
///
/// Returns 0.0 for empty slices.
///
/// # Arguments
///
/// * `v` - Vector slice
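///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_max_abs;
///
/// assert_eq!(vec_max_abs(&[-3.0, 2.0, 1.0]), 3.0);
/// assert_eq!(vec_max_abs(&[]), 0.0);
/// ```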
pub fn vec_max_abs(v: &[f64]) -> f64 {
    v.iter().map(|&x| x.abs()).fold(0.0_f64, f64::max)
}

// ============================================================================
// R-Compatible QR Decomposition (LINPACK dqrdc2 with Column Pivoting)
// ============================================================================

/// QR decomposition result using R's LINPACK dqrdc2 algorithm.
///
/// This implements the QR decomposition with column pivoting as used by R's
/// `qr()` function with `LAPACK=FALSE`. The algorithm is a modification of
/// LINPACK's DQRDC that:
/// - Uses Householder transformations
/// - Implements limited column pivoting based on 2-norms of reduced columns
/// - Moves columns with near-zero norm to the right-hand edge
/// - Computes the rank (number of linearly independent columns)
///
/// # Fields
///
/// * `qr` - The QR factorization (upper triangle contains R, below diagonal
///   contains Householder vector information)
/// * `qraux` - Auxiliary information for recovering the orthogonal part Q
/// * `pivot` - Column permutation: `pivot\[j\]` contains the original column index
///   now in column j
/// * `rank` - Number of linearly independent columns (the computed rank)
#[derive(Clone, Debug)]
pub struct QRLinpack {
    /// QR factorization matrix (same dimensions as input)
    pub qr: Matrix,
    /// Auxiliary information for Q recovery
    pub qraux: Vec<f64>,
    /// Column pivot vector (1-based indices like R)
    pub pivot: Vec<usize>,
    /// Computed rank (number of linearly independent columns)
    pub rank: usize,
}

impl Matrix {
    /// Computes QR decomposition using R's LINPACK dqrdc2 algorithm with column pivoting.
    ///
    /// This is a port of R's dqrdc2.f, which is a modification of LINPACK's DQRDC.
    /// The algorithm:
    /// 1. Uses Householder transformations for QR factorization
    /// 2. Implements limited column pivoting based on column 2-norms
    /// 3. Moves columns with near-zero norm to the right-hand edge
    /// 4. Computes the rank (number of linearly independent columns)
    ///
    /// # Arguments
    ///
    /// * `tol` - Tolerance for determining linear independence. Default is 1e-7 (R's default).
    ///   Columns with norm < tol * original_norm are considered negligible.
    ///
    /// # Returns
    ///
    /// A [`QRLinpack`] struct containing the QR factorization, auxiliary information,
    /// pivot vector, and computed rank.
    ///
    /// # Algorithm Details
    ///
    /// The decomposition is A * P = Q * R where:
    /// - P is the permutation matrix coded by `pivot`
    /// - Q is orthogonal (m × m)
    /// - R is upper triangular in the first `rank` rows
    ///
    /// The `qr` matrix contains:
    /// - Upper triangle: R matrix (if pivoting was performed, this is R of the permuted matrix)
    /// - Below diagonal: Householder vector information
    ///
    /// # Reference
    ///
    /// - R source: src/appl/dqrdc2.f
    /// - LINPACK documentation: <https://www.netlib.org/linpack/dqrdc.f>
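    ///
    /// # Example
    ///
    /// A small sketch: an exactly collinear third column is detected and the
    /// computed rank drops to 2 (using R's default tolerance):
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 3, vec![1.0, 0.0, 1.0,
    ///                                1.0, 1.0, 2.0,
    ///                                1.0, 2.0, 3.0]); // col 2 = col 0 + col 1
    /// let qr = x.qr_linpack(None);
    /// assert_eq!(qr.rank, 2);
    /// ```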
    pub fn qr_linpack(&self, tol: Option<f64>) -> QRLinpack {
        let n = self.rows;
        let p = self.cols;
        let lup = n.min(p);

        // Default tolerance matches R's qr.default: tol = 1e-07
        let tol = tol.unwrap_or(1e-07);

        // Initialize working matrices
        let mut x = self.clone(); // Working copy that will be modified
        let mut qraux = vec![0.0; p];
        let mut pivot: Vec<usize> = (1..=p).collect(); // 1-based indices like R
        let mut work = vec![(0.0, 0.0); p]; // (work[j,1], work[j,2])

        // Compute the norms of the columns of x (initialization)
        if n > 0 {
            for j in 0..p {
                let mut norm = 0.0;
                for i in 0..n {
                    norm += x.get(i, j) * x.get(i, j);
                }
                norm = norm.sqrt();
                qraux[j] = norm;
                let original_norm = if norm == 0.0 { 1.0 } else { norm };
                work[j] = (norm, original_norm);
            }
        }

        let mut k = p + 1; // Will be decremented to get the final rank

        // Perform the Householder reduction of x
        for l in 0..lup {
            // Cycle columns from l to p until one with non-negligible norm is found
            // A column is negligible if its norm has fallen below tol * original_norm
            while l < k - 1 && qraux[l] < work[l].1 * tol {
                // Move column l to the end (it's negligible)
                let lp1 = l + 1;

                // Shift columns in x: x(i, l..p-1) = x(i, l+1..p)
                for i in 0..n {
                    let t = x.get(i, l);
                    for j in lp1..p {
                        x.set(i, j - 1, x.get(i, j));
                    }
                    x.set(i, p - 1, t);
                }

                // Shift pivot, qraux, and work arrays
                let saved_pivot = pivot[l];
                let saved_qraux = qraux[l];
                let saved_work = work[l];

                for j in lp1..p {
                    pivot[j - 1] = pivot[j];
                    qraux[j - 1] = qraux[j];
                    work[j - 1] = work[j];
                }

                pivot[p - 1] = saved_pivot;
                qraux[p - 1] = saved_qraux;
                work[p - 1] = saved_work;

                k -= 1;
            }

            if l == n - 1 {
                // Last row - skip the transformation
                break;
            }

            // Compute the Householder transformation for column l
            // nrmxl = norm of x[l:, l]
            let mut nrmxl = 0.0;
            for i in l..n {
                let val = x.get(i, l);
                nrmxl += val * val;
            }
            nrmxl = nrmxl.sqrt();

            if nrmxl == 0.0 {
                // Zero column - continue to next
                continue;
            }

            // Apply sign for numerical stability (dsign in Fortran)
            let x_ll = x.get(l, l);
            if x_ll != 0.0 {
                nrmxl = nrmxl.copysign(x_ll);
            }

            // Scale the column
            let scale = 1.0 / nrmxl;
            for i in l..n {
                x.set(i, l, x.get(i, l) * scale);
            }
            x.set(l, l, 1.0 + x.get(l, l));

            // Apply the transformation to remaining columns, updating the norms
            let lp1 = l + 1;
            if p > lp1 {
                for j in lp1..p {
                    // Compute t = -dot(x[l:, l], x[l:, j]) / x(l, l)
                    let mut dot = 0.0;
                    for i in l..n {
                        dot += x.get(i, l) * x.get(i, j);
                    }
                    let t = -dot / x.get(l, l);

                    // x[l:, j] = x[l:, j] + t * x[l:, l]
                    for i in l..n {
                        let val = x.get(i, j) + t * x.get(i, l);
                        x.set(i, j, val);
                    }

                    // Update the norm
                    if qraux[j] != 0.0 {
                        // tt = 1.0 - (x(l, j) / qraux[j])^2
                        let x_lj = x.get(l, j).abs();
                        let mut tt = 1.0 - (x_lj / qraux[j]).powi(2);
                        tt = tt.max(0.0);

                        // Recompute the norm if there is a large reduction (BDR mod 9/99)
                        // The tolerance here is on the squared norm
                        if tt.abs() < 1e-6 {
                            // Re-compute the norm directly
                            let mut new_norm = 0.0;
                            for i in (l + 1)..n {
                                let val = x.get(i, j);
                                new_norm += val * val;
                            }
                            new_norm = new_norm.sqrt();
                            qraux[j] = new_norm;
                            work[j].0 = new_norm;
                        } else {
                            qraux[j] *= tt.sqrt();
                        }
                    }
                }
            }

            // Save the transformation
            qraux[l] = x.get(l, l);
            x.set(l, l, -nrmxl);
        }

        // Compute the final rank
        let rank = k - 1;
        let rank = rank.min(n);

        QRLinpack {
            qr: x,
            qraux,
            pivot,
            rank,
        }
    }
    /// Solves a linear system using the QR decomposition with column pivoting.
    ///
    /// This implements a least squares solver using the pivoted QR decomposition.
    /// For rank-deficient cases, coefficients corresponding to linearly dependent
    /// columns are set to `f64::NAN`.
    ///
    /// # Arguments
    ///
    /// * `qr_result` - QR decomposition from [`Matrix::qr_linpack`]
    /// * `y` - Right-hand side vector
    ///
    /// # Returns
    ///
    /// A vector of coefficients, or `None` if the system is exactly singular.
    ///
    /// # Algorithm
    ///
    /// This solver uses the standard QR approach:
    /// 1. Compute qty = Q^T * y by applying the stored Householder transformations
    /// 2. Solve R * coef = qty by back substitution over the first `rank` rows
    /// 3. Apply the pivot permutation to restore the original column order
    ///
    /// # Note
    ///
    /// The LINPACK QR algorithm stores R with mixed signs on the diagonal (each
    /// diagonal entry is -nrmxl, whose sign follows the pivoted column). No sign
    /// correction is needed here: the same signs enter Q^T * y, so they cancel
    /// during back substitution.
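    ///
    /// # Example
    ///
    /// An exact fit, y = 1 + 2·x, recovered through the pivoted QR:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
    /// let y = [1.0, 3.0, 5.0];
    /// let qr = x.qr_linpack(None);
    /// let beta = x.qr_solve_linpack(&qr, &y).unwrap();
    /// assert!((beta[0] - 1.0).abs() < 1e-10);
    /// assert!((beta[1] - 2.0).abs() < 1e-10);
    /// ```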
    #[allow(clippy::needless_range_loop)]
    pub fn qr_solve_linpack(&self, qr_result: &QRLinpack, y: &[f64]) -> Option<Vec<f64>> {
        let n = self.rows;
        let p = self.cols;
        let k = qr_result.rank;

        if y.len() != n {
            return None;
        }

        if k == 0 {
            return None;
        }

        // Step 1: Compute Q^T * y using the Householder vectors directly
        // This is more efficient than reconstructing the full Q matrix
        let mut qty = y.to_vec();

        for j in 0..k {
            // Check if this Householder transformation is valid
            let r_jj = qr_result.qr.get(j, j);
            if r_jj == 0.0 {
                continue;
            }

            // Compute dot = v_j^T * qty[j:]
            // where v_j is the Householder vector stored in qr[j:, j]
            // The storage convention:
            // - qr[j,j] = -nrmxl (after the final overwrite)
            // - qr[i,j] for i > j is the scaled Householder vector element
            // - qraux[j] = 1 + original_x[j,j]/nrmxl (the unscaled first element)

            // Reconstruct the Householder vector v_j
            // After scaling by 1/nrmxl, we have:
            // v_scaled[j] = 1 + x[j,j]/nrmxl
            // v_scaled[i] = x[i,j]/nrmxl for i > j
            // The actual unit vector is v = v_scaled / ||v_scaled||

            let mut v = vec![0.0; n - j];
            // Copy the scaled Householder vector from qr
            for i in j..n {
                v[i - j] = qr_result.qr.get(i, j);
            }

            // The j-th element was modified during the QR decomposition
            // We need to reconstruct it from qraux
            let alpha = qr_result.qraux[j];
            if alpha != 0.0 {
                v[0] = alpha;
            }

            // Compute the norm of v
            let v_norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if v_norm < 1e-14 {
                continue;
            }

            // Compute dot = v^T * qty[j:]
            let mut dot = 0.0;
            for i in j..n {
                dot += v[i - j] * qty[i];
            }

            // Apply Householder transformation: qty[j:] = qty[j:] - 2 * v * (v^T * qty[j:]) / (v^T * v)
            // Since v is not unit length, we use: t = 2 * dot / (v_norm^2)
            let t = 2.0 * dot / (v_norm * v_norm);

            for i in j..n {
                qty[i] -= t * v[i - j];
            }
        }

        // Step 2: Back substitution on R (solve R * coef = qty)
        // The R matrix is stored in the upper triangle of qr
        // Note: The diagonal elements of R carry the signs chosen during the
        // decomposition (each is -nrmxl). We use them as-is since the same signs
        // entered Q^T * y, so they cancel in the computation.
        let mut coef_permuted = vec![f64::NAN; p];

        // Relative tolerance for the singularity check, based on the largest
        // diagonal element of R (hoisted out of the loop; it does not change)
        let max_abs = (0..k)
            .map(|i| qr_result.qr.get(i, i).abs())
            .fold(0.0_f64, f64::max);
        let tolerance = 1e-14 * max_abs.max(1.0);

        for row in (0..k).rev() {
            let r_diag = qr_result.qr.get(row, row);
            if r_diag.abs() < tolerance {
                return None; // Singular
            }

            let mut sum = qty[row];
            for col in (row + 1)..k {
                sum -= qr_result.qr.get(row, col) * coef_permuted[col];
            }
            coef_permuted[row] = sum / r_diag;
        }

        // Step 3: Apply the pivot permutation to return coefficients in the
        // original column order
        // pivot[j] is 1-based, indicating which original column is now in position j
        let mut result = vec![0.0; p];
        for j in 0..p {
            let original_col = qr_result.pivot[j] - 1; // Convert to 0-based
            result[original_col] = coef_permuted[j];
        }

        Some(result)
    }
}

/// Performs OLS regression using R's LINPACK QR algorithm.
///
/// This function is a drop-in replacement for `fit_ols` that uses the
/// R-compatible QR decomposition with column pivoting. It handles
/// rank-deficient matrices more gracefully than the standard QR decomposition.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - OLS coefficient vector (p elements)
/// * `None` - If the matrix is exactly singular or dimensions don't match
///
/// # Note
///
/// For rank-deficient systems, this function uses the pivoted QR, which
/// automatically handles multicollinearity by selecting a linearly
/// independent subset of columns.
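///
/// # Example
///
/// A sketch of the rank-deficient behavior: the third column duplicates the
/// second, so its coefficient comes back as NaN while the others are fitted:
///
/// ```
/// use linreg_core::linalg::{fit_ols_linpack, Matrix};
///
/// let x = Matrix::new(3, 3, vec![1.0, 0.0, 0.0,
///                                1.0, 1.0, 1.0,
///                                1.0, 2.0, 2.0]);
/// let y = [1.0, 3.0, 5.0]; // y = 1 + 2·x
/// let beta = fit_ols_linpack(&y, &x).unwrap();
/// assert!((beta[0] - 1.0).abs() < 1e-10);
/// assert!((beta[1] - 2.0).abs() < 1e-10);
/// assert!(beta[2].is_nan());
/// ```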
pub fn fit_ols_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let qr_result = x.qr_linpack(None);
    x.qr_solve_linpack(&qr_result, y)
}

/// Fits OLS and predicts using R's LINPACK QR with rank-deficient handling.
///
/// This function matches R's `lm.fit` behavior for rank-deficient cases:
/// coefficients for linearly dependent columns are set to NA, and predictions
/// are computed using only the valid (non-NA) coefficients and their corresponding
/// columns. This matches how R handles rank-deficient models in prediction.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - Predictions (n elements)
/// * `None` - If the matrix is exactly singular or dimensions don't match
///
/// # Algorithm
///
/// For rank-deficient systems (rank < p):
/// 1. Compute the QR decomposition with column pivoting
/// 2. Get the coefficients (linearly dependent columns will have NaN)
/// 3. Compute predictions using only the columns with valid coefficients
///
/// This matches R's behavior where `predict(lm.fit(...))` handles NA coefficients
/// by excluding the corresponding columns from the prediction.
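///
/// # Example
///
/// With an exactly collinear column, predictions still reproduce an exact fit:
///
/// ```
/// use linreg_core::linalg::{fit_and_predict_linpack, Matrix};
///
/// let x = Matrix::new(3, 3, vec![1.0, 0.0, 0.0,
///                                1.0, 1.0, 1.0,
///                                1.0, 2.0, 2.0]);
/// let y = [1.0, 3.0, 5.0]; // y = 1 + 2·x
/// let pred = fit_and_predict_linpack(&y, &x).unwrap();
/// for (p, yi) in pred.iter().zip(y.iter()) {
///     assert!((p - yi).abs() < 1e-8);
/// }
/// ```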
#[allow(clippy::needless_range_loop)]
pub fn fit_and_predict_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let n = x.rows;
    let p = x.cols;

    // Compute the QR decomposition
    let qr_result = x.qr_linpack(None);
    let k = qr_result.rank;

    // Solve for coefficients (already returned in the original column order)
    let beta = x.qr_solve_linpack(&qr_result, y)?;

    // Check for rank deficiency
    if k == p {
        // Full rank - use standard prediction
        return Some(x.mul_vec(&beta));
    }

    // Rank-deficient case: some columns are collinear and have NaN coefficients
    // We compute predictions using only columns with valid (non-NaN) coefficients
    // This matches R's behavior where NA coefficients exclude columns from prediction

    let mut pred = vec![0.0; n];

    for row in 0..n {
        let mut sum = 0.0;
        for j in 0..p {
            let b_val = beta[j];
            if b_val.is_nan() {
                continue; // Skip collinear columns (matches R's NA coefficient behavior)
            }
            sum += x.get(row, j) * b_val;
        }
        pred[row] = sum;
    }

    Some(pred)
}