linreg_core/linalg.rs

//! Minimal Linear Algebra module to replace the nalgebra dependency.
//!
//! Implements matrix operations, QR decomposition, and solvers needed for OLS.
//! Uses row-major storage for compatibility with statistical computing conventions.
//!
//! # Numerical Stability Considerations
//!
//! This implementation uses Householder QR decomposition with careful attention to
//! numerical stability:
//!
//! - **Sign convention**: Uses the numerically stable Householder sign choice
//!   (v = x + sgn(x₀)||x||e₁) to avoid cancellation
//! - **Tolerance checking**: Uses predefined tolerances to detect near-singular matrices
//! - **Zero-skipping**: Skips transformations when columns are already zero-aligned
//!
//! # Scaling Recommendations
//!
//! For optimal numerical stability when predictor variables have vastly different
//! scales (e.g., one variable in millions, another in thousandths), consider
//! standardizing predictors before regression. Z-score standardization
//! (`x_scaled = (x - mean) / std`) is already applied in the VIF calculation.
//!
//! However, the current implementation handles typical OLS cases without explicit
//! scaling, as QR decomposition is generally stable for well-conditioned matrices.
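//!
//! As a minimal sketch of such standardization (illustrative values; the `use`
//! path assumes this module is public as `linreg_core::linalg`):
//!
//! ```ignore
//! use linreg_core::linalg::vec_mean;
//!
//! let x = vec![1.0e6, 2.0e6, 3.0e6]; // predictor on a very large scale
//! let mean = vec_mean(&x);
//! let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (x.len() - 1) as f64;
//! let x_scaled: Vec<f64> = x.iter().map(|v| (v - mean) / var.sqrt()).collect();
//! // x_scaled is now [-1.0, 0.0, 1.0]: mean 0, sample standard deviation 1.
//! ```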

// ============================================================================
// Numerical Constants
// ============================================================================

/// Absolute threshold below which values are treated as zero in the QR
/// decomposition, to avoid numerical instability.
const QR_ZERO_TOLERANCE: f64 = 1e-12;

/// Threshold for detecting singular matrices during inversion.
/// Diagonal elements below this value indicate a near-singular matrix.
const SINGULAR_TOLERANCE: f64 = 1e-10;

/// A dense matrix stored in row-major order.
///
/// # Storage
///
/// Elements are stored in a single flat vector in row-major order:
/// `data[row * cols + col]`
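///
/// A minimal sketch (illustrative values):
///
/// ```ignore
/// // A 2×3 matrix laid out row by row:
/// //   [ 1 2 3 ]
/// //   [ 4 5 6 ]
/// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// assert_eq!(m.get(1, 2), 6.0); // data[1 * 3 + 2]
/// ```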
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Number of rows in the matrix
    pub rows: usize,
    /// Number of columns in the matrix
    pub cols: usize,
    /// Flat vector storing matrix elements in row-major order
    pub data: Vec<f64>,
}

impl Matrix {
    /// Creates a new matrix from the given dimensions and data.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rows * cols`.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `data` - Flat vector of elements in row-major order
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols, "Data length must match dimensions");
        Matrix { rows, cols, data }
    }

    /// Creates a matrix filled with zeros.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    pub fn zeros(rows: usize, cols: usize) -> Self {
        Matrix {
            rows,
            cols,
            data: vec![0.0; rows * cols],
        }
    }

    // NOTE: Currently unused but kept as a reference implementation.
    // Uncomment if a convenience constructor is needed.
    /*
    /// Creates a matrix from a row-major slice.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `slice` - Slice containing matrix elements in row-major order
    pub fn from_row_slice(rows: usize, cols: usize, slice: &[f64]) -> Self {
        Matrix::new(rows, cols, slice.to_vec())
    }
    */

    /// Gets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    pub fn get(&self, row: usize, col: usize) -> f64 {
        self.data[row * self.cols + col]
    }

    /// Sets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    /// * `val` - Value to set
    pub fn set(&mut self, row: usize, col: usize, val: f64) {
        self.data[row * self.cols + col] = val;
    }

    /// Returns the transpose of this matrix.
    ///
    /// Swaps rows with columns: `result[col][row] = self[row][col]`.
    pub fn transpose(&self) -> Matrix {
        let mut t_data = vec![0.0; self.rows * self.cols];
        for r in 0..self.rows {
            for c in 0..self.cols {
                t_data[c * self.rows + r] = self.get(r, c);
            }
        }
        Matrix::new(self.cols, self.rows, t_data)
    }

    /// Performs matrix multiplication: `self * other`.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != other.rows`.
    pub fn matmul(&self, other: &Matrix) -> Matrix {
        assert_eq!(self.cols, other.rows, "Dimension mismatch for multiplication");
        let mut result = Matrix::zeros(self.rows, other.cols);

        for r in 0..self.rows {
            for c in 0..other.cols {
                let mut sum = 0.0;
                for k in 0..self.cols {
                    sum += self.get(r, k) * other.get(k, c);
                }
                result.set(r, c, sum);
            }
        }
        result
    }

    /// Multiplies this matrix by a vector (treating the vector as a column matrix).
    ///
    /// Computes `self * vec` where `vec` is treated as an n×1 column matrix.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != vec.len()`.
    ///
    /// # Arguments
    ///
    /// * `vec` - Vector to multiply by
    pub fn mul_vec(&self, vec: &[f64]) -> Vec<f64> {
        assert_eq!(self.cols, vec.len(), "Dimension mismatch for matrix-vector multiplication");
        let mut result = vec![0.0; self.rows];

        for r in 0..self.rows {
            let mut sum = 0.0;
            for c in 0..self.cols {
                sum += self.get(r, c) * vec[c];
            }
            result[r] = sum;
        }
        result
    }
}

// ============================================================================
// QR Decomposition
// ============================================================================

impl Matrix {
    /// Computes the QR decomposition using Householder reflections.
    ///
    /// Factorizes the matrix as `A = QR` where Q is orthogonal and R is upper triangular.
    ///
    /// # Requirements
    ///
    /// This implementation requires `rows >= cols` (tall matrix). For OLS regression,
    /// we always have more observations than predictors, so this requirement is satisfied.
    ///
    /// # Returns
    ///
    /// A tuple `(Q, R)` where:
    /// - `Q` is an orthogonal matrix (QᵀQ = I) of size m×m
    /// - `R` is an upper triangular matrix of size m×n
    pub fn qr(&self) -> (Matrix, Matrix) {
        let m = self.rows;
        let n = self.cols;
        let mut q = Matrix::identity(m);
        let mut r = self.clone();

        for k in 0..n.min(m - 1) {
            // Create vector x = R[k:, k]
            let mut x = vec![0.0; m - k];
            for i in k..m {
                x[i - k] = r.get(i, k);
            }

            // Norm of x
            let norm_x: f64 = x.iter().map(|&v| v * v).sum::<f64>().sqrt();
            if norm_x < QR_ZERO_TOLERANCE { continue; } // Already zero

            // Create vector v = x + sign(x[0]) * ||x|| * e1
            //
            // NOTE: Numerical stability consideration (Householder sign choice).
            // According to Overton & Yu (2023), the numerically stable choice is
            // σ = -sgn(x₁) in the formula v = x - σ‖x‖e₁, where x₁ is the first
            // element of x in 1-based math notation (x[0] below).
            //
            // This means: v = x - (-sgn(x₁))‖x‖e₁ = x + sgn(x₁)‖x‖e₁
            //
            // Equivalently: u₁ = x₁ + sgn(x₁)‖x‖
            //
            // The implementation below uses this formula (the stable choice):
            let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 }; // sgn with the convention sgn(0) = +1
            let u1 = x[0] + sign * norm_x;

            // Normalize v to get the Householder vector
            let mut v = x; // Re-use storage
            v[0] = u1;

            let norm_v: f64 = v.iter().map(|&val| val * val).sum::<f64>().sqrt();
            for val in &mut v { *val /= norm_v; }

            // Apply the Householder transformation to R: R = H * R = (I - 2vvᵀ)R = R - 2v(vᵀR)
            // Focus on submatrix R[k:, k:]
            for j in k..n {
                let mut dot = 0.0;
                for i in 0..m - k {
                    dot += v[i] * r.get(k + i, j);
                }

                for i in 0..m - k {
                    let val = r.get(k + i, j) - 2.0 * v[i] * dot;
                    r.set(k + i, j, val);
                }
            }

            // Update Q: Q = Q * H = Q(I - 2vvᵀ) = Q - 2(Qv)vᵀ
            // Focus on Q[:, k:]
            for i in 0..m {
                let mut dot = 0.0;
                for l in 0..m - k {
                    dot += q.get(i, k + l) * v[l];
                }

                for l in 0..m - k {
                    let val = q.get(i, k + l) - 2.0 * dot * v[l];
                    q.set(i, k + l, val);
                }
            }
        }

        (q, r)
    }

    /// Creates an identity matrix of the given size.
    ///
    /// # Arguments
    ///
    /// * `size` - Number of rows and columns (square matrix)
    pub fn identity(size: usize) -> Self {
        let mut data = vec![0.0; size * size];
        for i in 0..size {
            data[i * size + i] = 1.0;
        }
        Matrix::new(size, size, data)
    }

    /// Inverts an upper triangular matrix (such as R from QR decomposition).
    ///
    /// Uses back-substitution to compute the inverse. This is efficient for
    /// triangular matrices compared to general matrix inversion.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    ///
    /// # Returns
    ///
    /// `None` if the matrix is singular (has a zero or near-zero diagonal element).
    /// A matrix is considered singular if any diagonal element is below the
    /// internal tolerance (1e-10), which indicates the matrix does not have full rank.
    ///
    /// # Note
    ///
    /// For upper triangular matrices, singularity is equivalent to having a
    /// zero (or near-zero) diagonal element. This is much simpler to check than
    /// for general matrices, which would require computing the condition number.
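    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative values):
    ///
    /// ```ignore
    /// // R = [ 2 1 ]   has inverse   [ 0.5 -0.25 ]
    /// //     [ 0 2 ]                 [ 0.0  0.5  ]
    /// let r = Matrix::new(2, 2, vec![2.0, 1.0, 0.0, 2.0]);
    /// let r_inv = r.invert_upper_triangular().expect("non-singular");
    /// assert!((r_inv.get(0, 1) + 0.25).abs() < 1e-12);
    /// ```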
    pub fn invert_upper_triangular(&self) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance.
        // This scales with the magnitude of the diagonal elements, handling matrices
        // of different scales better than a fixed absolute tolerance.
        //
        // Previous implementation used an absolute tolerance:
        //   if self.get(i, i).abs() < SINGULAR_TOLERANCE { return None; }
        //
        // New implementation uses a relative tolerance similar to LAPACK:
        //   tolerance = max_diag * epsilon * n
        // where epsilon is a small multiple of machine epsilon (2ε ≈ 4.4e-16 for f64)
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        // Use a relative tolerance based on the maximum diagonal element.
        // This is similar to LAPACK's dlamch machine-epsilon approach.
        let epsilon = 2.0_f64 * f64::EPSILON; // ~4.4e-16 for f64
        let relative_tolerance = max_diag * epsilon * n as f64;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None; // Singular matrix - cannot invert
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Inverts an upper triangular matrix with a custom tolerance multiplier.
    ///
    /// The tolerance is computed as `max_diag * epsilon * n * tolerance_mult`.
    /// A higher `tolerance_mult` allows more tolerance for near-singular matrices.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the tolerance (1.0 = standard, higher = more tolerant)
    pub fn invert_upper_triangular_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        // Use a relative tolerance based on the maximum diagonal element
        let epsilon = 2.0_f64 * f64::EPSILON;
        let relative_tolerance = max_diag * epsilon * n as f64 * tolerance_mult;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None;
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Computes the inverse of a square matrix using QR decomposition.
    ///
    /// For an invertible matrix A, computes A⁻¹ such that A * A⁻¹ = I.
    /// Uses QR decomposition for numerical stability.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square (i.e., `self.rows != self.cols`).
    /// Check dimensions before calling if the matrix shape is not guaranteed.
    ///
    /// # Returns
    ///
    /// Returns `Some(inverse)` if the matrix is invertible, or `None` if
    /// the matrix is singular (non-invertible).
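    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative values):
    ///
    /// ```ignore
    /// let a = Matrix::new(2, 2, vec![4.0, 7.0, 2.0, 6.0]); // det = 10
    /// let a_inv = a.invert().expect("invertible");
    /// let product = a.matmul(&a_inv); // should be ≈ I
    /// assert!((product.get(0, 0) - 1.0).abs() < 1e-10);
    /// assert!(product.get(0, 1).abs() < 1e-10);
    /// ```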
    pub fn invert(&self) -> Option<Matrix> {
        let n = self.rows;
        if n != self.cols {
            panic!("Matrix must be square for inversion");
        }

        // Use QR decomposition: A = Q * R
        let (q, r) = self.qr();

        // Compute R⁻¹ (upper triangular inverse)
        let r_inv = r.invert_upper_triangular()?;

        // A⁻¹ = R⁻¹ * Qᵀ
        let q_transpose = q.transpose();
        let mut result = Matrix::zeros(n, n);

        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0;
                for k in 0..n {
                    sum += r_inv.get(i, k) * q_transpose.get(k, j);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }

    /// Computes the inverse of X'X given the QR decomposition of X (R's chol2inv).
    ///
    /// This is equivalent to computing `(X'X)^(-1)` using the QR decomposition of X,
    /// mirroring R's `chol2inv` function; it is used for numerical stability in
    /// recursive residuals.
    ///
    /// # Requirements
    ///
    /// `self` is the design matrix X and must have `rows >= cols`.
    ///
    /// # Returns
    ///
    /// `Some((X'X)^(-1))` if X has full rank, `None` otherwise.
    ///
    /// # Algorithm
    ///
    /// Given the QR decomposition X = QR where R is upper triangular:
    /// 1. Extract the upper p×p portion of R (denoted R₁)
    /// 2. Invert R₁ (upper triangular inverse)
    /// 3. Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
    ///
    /// This works because X'X = R'Q'QR = R'R, and R₁ contains the Cholesky factor.
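    ///
    /// # Example
    ///
    /// A minimal sketch cross-checking against the direct route (illustrative values):
    ///
    /// ```ignore
    /// // X is a 3×2 design matrix with an intercept column.
    /// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
    /// let via_qr = x.chol2inv_from_qr().expect("full rank");
    /// let direct = x.transpose().matmul(&x).invert().expect("invertible");
    /// assert!((via_qr.get(0, 0) - direct.get(0, 0)).abs() < 1e-10);
    /// ```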
    pub fn chol2inv_from_qr(&self) -> Option<Matrix> {
        self.chol2inv_from_qr_with_tolerance(1.0)
    }

    /// Computes the inverse of X'X given the QR decomposition with custom tolerance.
    ///
    /// Similar to `chol2inv_from_qr` but allows specifying a tolerance multiplier
    /// for handling near-singular matrices.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the tolerance (higher = more tolerant)
    pub fn chol2inv_from_qr_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let p = self.cols;

        // QR decomposition: X = QR
        // For X (m×p, m ≥ p), R is m×p upper triangular.
        // The upper p×p block of R contains the meaningful values.
        let (_, r_full) = self.qr();

        // Extract the upper p×p portion of R.
        // For tall matrices (m > p), R has zeros below the diagonal in the first p rows.
        // For square matrices (m = p), R is p×p upper triangular.
        let mut r1 = Matrix::zeros(p, p);
        for i in 0..p {
            // Row i of R₁ is the upper triangular part of row i of R_full
            // (columns i..p, which includes the diagonal at j = i).
            for j in i..p {
                r1.set(i, j, r_full.get(i, j));
            }
        }

        // Invert R₁ (upper triangular) with custom tolerance
        let r1_inv = r1.invert_upper_triangular_with_tolerance(tolerance_mult)?;

        // Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
        let mut result = Matrix::zeros(p, p);
        for i in 0..p {
            for j in 0..p {
                let mut sum = 0.0;
                // result[i,j] = Σ R1_inv[i,k] * R1_inv[j,k] for k = 0..p
                // (R1_inv is upper triangular, but we iterate the full range)
                for k in 0..p {
                    sum += r1_inv.get(i, k) * r1_inv.get(j, k);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }
}

// ============================================================================
// Vector Helper Functions
// ============================================================================

/// Computes the arithmetic mean of a slice of f64 values.
///
/// Returns 0.0 for empty slices.
///
/// # Arguments
///
/// * `v` - Slice of values
pub fn vec_mean(v: &[f64]) -> f64 {
    if v.is_empty() { return 0.0; }
    v.iter().sum::<f64>() / v.len() as f64
}

/// Computes element-wise subtraction of two slices: `a - b`.
///
/// # Arguments
///
/// * `a` - Minuend slice
/// * `b` - Subtrahend slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
pub fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    // zip() alone would silently truncate to the shorter slice, so enforce the
    // documented length contract explicitly.
    assert_eq!(a.len(), b.len(), "Slices must have equal length");
    a.iter().zip(b.iter()).map(|(x, y)| x - y).collect()
}

/// Computes the dot product of two slices: `Σ(a[i] * b[i])`.
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
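///
/// # Example
///
/// A minimal sketch computing a residual sum of squares (illustrative values):
///
/// ```ignore
/// let y = vec![1.0, 2.0, 3.0];
/// let fitted = vec![0.9, 2.1, 3.0];
/// let resid = vec_sub(&y, &fitted);
/// let rss = vec_dot(&resid, &resid); // ≈ 0.01 + 0.01 + 0.0
/// assert!((rss - 0.02).abs() < 1e-12);
/// ```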
pub fn vec_dot(a: &[f64], b: &[f64]) -> f64 {
    // As with vec_sub, assert the documented length contract rather than letting
    // zip() truncate silently.
    assert_eq!(a.len(), b.len(), "Slices must have equal length");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

// ============================================================================
// R-Compatible QR Decomposition (LINPACK dqrdc2 with Column Pivoting)
// ============================================================================

/// QR decomposition result using R's LINPACK dqrdc2 algorithm.
///
/// This implements the QR decomposition with column pivoting as used by R's
/// `qr()` function with `LAPACK=FALSE`. The algorithm is a modification of
/// LINPACK's DQRDC that:
/// - Uses Householder transformations
/// - Implements limited column pivoting based on 2-norms of reduced columns
/// - Moves columns with near-zero norm to the right-hand edge
/// - Computes the rank (number of linearly independent columns)
///
/// # Fields
///
/// * `qr` - The QR factorization (upper triangle contains R, below diagonal
///   contains Householder vector information)
/// * `qraux` - Auxiliary information for recovering the orthogonal part Q
/// * `pivot` - Column permutation: `pivot[j]` contains the original column index
///   now in column j
/// * `rank` - Number of linearly independent columns (the computed rank)
#[derive(Clone, Debug)]
pub struct QRLinpack {
    /// QR factorization matrix (same dimensions as input)
    pub qr: Matrix,
    /// Auxiliary information for Q recovery
    pub qraux: Vec<f64>,
    /// Column pivot vector (1-based indices like R)
    pub pivot: Vec<usize>,
    /// Computed rank (number of linearly independent columns)
    pub rank: usize,
}

impl Matrix {
    /// Computes QR decomposition using R's LINPACK dqrdc2 algorithm with column pivoting.
    ///
    /// This is a port of R's dqrdc2.f, which is a modification of LINPACK's DQRDC.
    /// The algorithm:
    /// 1. Uses Householder transformations for QR factorization
    /// 2. Implements limited column pivoting based on column 2-norms
    /// 3. Moves columns with near-zero norm to the right-hand edge
    /// 4. Computes the rank (number of linearly independent columns)
    ///
    /// # Arguments
    ///
    /// * `tol` - Tolerance for determining linear independence. Default is 1e-7 (R's default).
    ///   Columns with norm < tol * original_norm are considered negligible.
    ///
    /// # Returns
    ///
    /// A [`QRLinpack`] struct containing the QR factorization, auxiliary information,
    /// pivot vector, and computed rank.
    ///
    /// # Algorithm Details
    ///
    /// The decomposition is A * P = Q * R where:
    /// - P is the permutation matrix coded by `pivot`
    /// - Q is orthogonal (m × m)
    /// - R is upper triangular in the first `rank` rows
    ///
    /// The `qr` matrix contains:
    /// - Upper triangle: R matrix (if pivoting was performed, this is R of the permuted matrix)
    /// - Below diagonal: Householder vector information
    ///
    /// # Reference
    ///
    /// - R source: src/appl/dqrdc2.f
    /// - LINPACK documentation: <https://www.netlib.org/linpack/dqrdc.f>
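    ///
    /// # Example
    ///
    /// A minimal sketch with an exactly collinear column (illustrative values):
    ///
    /// ```ignore
    /// // The third column is twice the second, so the rank is 2, not 3.
    /// let x = Matrix::new(3, 3, vec![
    ///     1.0, 1.0, 2.0,
    ///     1.0, 2.0, 4.0,
    ///     1.0, 3.0, 6.0,
    /// ]);
    /// let qr = x.qr_linpack(None);
    /// assert_eq!(qr.rank, 2);
    /// ```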
    pub fn qr_linpack(&self, tol: Option<f64>) -> QRLinpack {
        let n = self.rows;
        let p = self.cols;
        let lup = n.min(p);

        // Default tolerance matches R's qr.default: tol = 1e-07
        let tol = tol.unwrap_or(1e-07);

        // Initialize working matrices
        let mut x = self.clone(); // Working copy that will be modified
        let mut qraux = vec![0.0; p];
        let mut pivot: Vec<usize> = (1..=p).collect(); // 1-based indices like R
        let mut work = vec![(0.0, 0.0); p]; // (work[j,1], work[j,2]) in the Fortran source

        // Compute the norms of the columns of x (initialization)
        if n > 0 {
            for j in 0..p {
                let mut norm = 0.0;
                for i in 0..n {
                    norm += x.get(i, j) * x.get(i, j);
                }
                norm = norm.sqrt();
                qraux[j] = norm;
                let original_norm = if norm == 0.0 { 1.0 } else { norm };
                work[j] = (norm, original_norm);
            }
        }

        let mut k = p + 1; // Will be decremented to get the final rank

        // Perform the Householder reduction of x
        for l in 0..lup {
            // Cycle columns from l to p until one with non-negligible norm is found.
            // A column is negligible if its norm has fallen below tol * original_norm.
            while l < k - 1 && qraux[l] < work[l].1 * tol {
                // Move column l to the end (it's negligible)
                let lp1 = l + 1;

                // Shift columns in x: x(i, l..p-1) = x(i, l+1..p)
                for i in 0..n {
                    let t = x.get(i, l);
                    for j in lp1..p {
                        x.set(i, j - 1, x.get(i, j));
                    }
                    x.set(i, p - 1, t);
                }

                // Shift the pivot, qraux, and work arrays
                let saved_pivot = pivot[l];
                let saved_qraux = qraux[l];
                let saved_work = work[l];

                for j in lp1..p {
                    pivot[j - 1] = pivot[j];
                    qraux[j - 1] = qraux[j];
                    work[j - 1] = work[j];
                }

                pivot[p - 1] = saved_pivot;
                qraux[p - 1] = saved_qraux;
                work[p - 1] = saved_work;

                k -= 1;
            }

            if l == n - 1 {
                // Last row - skip the transformation
                break;
            }

            // Compute the Householder transformation for column l
            // nrmxl = norm of x[l:, l]
            let mut nrmxl = 0.0;
            for i in l..n {
                let val = x.get(i, l);
                nrmxl += val * val;
            }
            nrmxl = nrmxl.sqrt();

            if nrmxl == 0.0 {
                // Zero column - continue to the next
                continue;
            }

            // Apply the sign for numerical stability (dsign in Fortran)
            let x_ll = x.get(l, l);
            if x_ll != 0.0 {
                nrmxl = nrmxl.copysign(x_ll);
            }

            // Scale the column
            let scale = 1.0 / nrmxl;
            for i in l..n {
                x.set(i, l, x.get(i, l) * scale);
            }
            x.set(l, l, 1.0 + x.get(l, l));

            // Apply the transformation to the remaining columns, updating the norms
            let lp1 = l + 1;
            if p > lp1 {
                for j in lp1..p {
                    // Compute t = -dot(x[l:, l], x[l:, j]) / x(l, l)
                    let mut dot = 0.0;
                    for i in l..n {
                        dot += x.get(i, l) * x.get(i, j);
                    }
                    let t = -dot / x.get(l, l);

                    // x[l:, j] = x[l:, j] + t * x[l:, l]
                    for i in l..n {
                        let val = x.get(i, j) + t * x.get(i, l);
                        x.set(i, j, val);
                    }

                    // Update the norm
                    if qraux[j] != 0.0 {
                        // tt = 1.0 - (x(l, j) / qraux[j])^2
                        let x_lj = x.get(l, j).abs();
                        let mut tt = 1.0 - (x_lj / qraux[j]).powi(2);
                        tt = tt.max(0.0);

                        // Recompute the norm directly if there was a large reduction
                        // (BDR modification, 9/99). The tolerance here is on the
                        // squared norm.
                        if tt.abs() < 1e-6 {
                            let mut new_norm = 0.0;
                            for i in (l + 1)..n {
                                let val = x.get(i, j);
                                new_norm += val * val;
                            }
                            new_norm = new_norm.sqrt();
                            qraux[j] = new_norm;
                            work[j].0 = new_norm;
                        } else {
                            qraux[j] *= tt.sqrt();
                        }
                    }
                }
            }

            // Save the transformation
            qraux[l] = x.get(l, l);
            x.set(l, l, -nrmxl);
        }

        // Compute the final rank
        let rank = (k - 1).min(n);

        QRLinpack {
            qr: x,
            qraux,
            pivot,
            rank,
        }
    }

    /// Solves a linear system using the QR decomposition with column pivoting.
    ///
    /// This implements a least squares solver using the pivoted QR decomposition.
    /// For rank-deficient cases, coefficients corresponding to linearly dependent
    /// columns are set to `f64::NAN`.
    ///
    /// # Arguments
    ///
    /// * `qr_result` - QR decomposition from [`Matrix::qr_linpack`]
    /// * `y` - Right-hand side vector
    ///
    /// # Returns
    ///
    /// A vector of coefficients, or `None` if the system is exactly singular.
    ///
    /// # Algorithm
    ///
    /// This solver uses the standard QR decomposition approach:
    /// 1. Compute the QR decomposition of the permuted matrix
    /// 2. Extract the R matrix (upper triangular)
    /// 3. Compute qty = Qᵀ * y
    /// 4. Solve R * coef = qty using back substitution
    /// 5. Apply the pivot permutation to restore the original column order
    ///
    /// # Note
    ///
    /// The LINPACK QR algorithm stores R with mixed signs on the diagonal (each
    /// diagonal entry is `-nrmxl` for its column). The solver uses the signed
    /// diagonal as-is: the same signs appear in Qᵀy, so they cancel in the solve.
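    ///
    /// # Example
    ///
    /// A minimal sketch fitting y = 1 + 2x exactly (illustrative values):
    ///
    /// ```ignore
    /// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
    /// let y = [1.0, 3.0, 5.0];
    /// let qr = x.qr_linpack(None);
    /// let beta = x.qr_solve_linpack(&qr, &y).expect("solvable");
    /// assert!((beta[0] - 1.0).abs() < 1e-10); // intercept
    /// assert!((beta[1] - 2.0).abs() < 1e-10); // slope
    /// ```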
    pub fn qr_solve_linpack(&self, qr_result: &QRLinpack, y: &[f64]) -> Option<Vec<f64>> {
        let n = self.rows;
        let p = self.cols;
        let k = qr_result.rank;

        if y.len() != n {
            return None;
        }

        if k == 0 {
            return None;
        }

        // Step 1: Compute Qᵀ * y using the Householder vectors directly.
        // This is more efficient than reconstructing the full Q matrix.
        let mut qty = y.to_vec();

        for j in 0..k {
            // Check if this Householder transformation is valid
            let r_jj = qr_result.qr.get(j, j);
            if r_jj == 0.0 {
                continue;
            }

            // Compute dot = v_jᵀ * qty[j:]
            // where v_j is the Householder vector stored in qr[j:, j].
            // The storage convention:
            // - qr[j,j] = -nrmxl (after the final overwrite)
            // - qr[i,j] for i > j is the scaled Householder vector element
            // - qraux[j] = 1 + original_x[j,j]/nrmxl (the unscaled first element)

            // Reconstruct the Householder vector v_j.
            // After scaling by 1/nrmxl, we have:
            // v_scaled[j] = 1 + x[j,j]/nrmxl
            // v_scaled[i] = x[i,j]/nrmxl for i > j
            // The actual unit vector is v = v_scaled / ||v_scaled||

            let mut v = vec![0.0; n - j];
            // Copy the scaled Householder vector from qr
            for i in j..n {
                v[i - j] = qr_result.qr.get(i, j);
            }

            // The j-th element was overwritten with -nrmxl during the decomposition;
            // reconstruct it from qraux
            let alpha = qr_result.qraux[j];
            if alpha != 0.0 {
                v[0] = alpha;
            }

            // Compute the norm of v
            let v_norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if v_norm < 1e-14 {
                continue;
            }

            // Compute dot = vᵀ * qty[j:]
            let mut dot = 0.0;
            for i in j..n {
                dot += v[i - j] * qty[i];
            }

            // Apply the Householder transformation:
            // qty[j:] = qty[j:] - 2 * v * (vᵀ * qty[j:]) / (vᵀ * v)
            let t = 2.0 * dot / (v_norm * v_norm);

            for i in j..n {
                qty[i] -= t * v[i - j];
            }
        }

        // Step 2: Back substitution on R (solve R * coef = qty).
        // R is stored in the upper triangle of qr. Its diagonal elements carry
        // mixed signs (from -nrmxl); we use them as-is since the same signs
        // appear in qty and cancel in the solve.
        let mut coef_permuted = vec![f64::NAN; p];

        // Relative tolerance for the singularity check (loop-invariant, so it
        // is computed once before the back substitution).
        let max_abs = (0..k).map(|i| qr_result.qr.get(i, i).abs()).fold(0.0_f64, f64::max);
        let tolerance = 1e-14 * max_abs.max(1.0);

        for row in (0..k).rev() {
            let r_diag = qr_result.qr.get(row, row);
            if r_diag.abs() < tolerance {
                return None; // Singular
            }

            let mut sum = qty[row];
            for col in (row + 1)..k {
                sum -= qr_result.qr.get(row, col) * coef_permuted[col];
            }
            coef_permuted[row] = sum / r_diag;
        }

        // Step 3: Apply the pivot permutation to get coefficients in the original order.
        // pivot[j] is 1-based, indicating which original column is now in position j.
        let mut result = vec![0.0; p];
        for j in 0..p {
            let original_col = qr_result.pivot[j] - 1; // Convert to 0-based
            result[original_col] = coef_permuted[j];
        }

        Some(result)
    }
}

/// Performs OLS regression using R's LINPACK QR algorithm.
///
/// This function is a drop-in replacement for `fit_ols` that uses the
/// R-compatible QR decomposition with column pivoting. It handles
/// rank-deficient matrices more gracefully than the standard QR decomposition.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - OLS coefficient vector (p elements)
/// * `None` - If the matrix is exactly singular or dimensions don't match
///
/// # Note
///
/// For rank-deficient systems, this function uses the pivoted QR, which
/// automatically handles multicollinearity by selecting a linearly
/// independent subset of columns.
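///
/// # Example
///
/// A minimal sketch with a duplicated predictor (illustrative values):
///
/// ```ignore
/// // The third column duplicates the second, so one coefficient comes back NaN.
/// let x = Matrix::new(3, 3, vec![
///     1.0, 0.0, 0.0,
///     1.0, 1.0, 1.0,
///     1.0, 2.0, 2.0,
/// ]);
/// let y = [1.0, 3.0, 5.0];
/// let beta = fit_ols_linpack(&y, &x).expect("solvable");
/// assert!(beta.iter().any(|b| b.is_nan())); // the dropped collinear column
/// ```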
pub fn fit_ols_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let qr_result = x.qr_linpack(None);
    x.qr_solve_linpack(&qr_result, y)
}

/// Fits OLS and predicts using R's LINPACK QR with rank-deficient handling.
///
/// This function matches R's `lm.fit` behavior for rank-deficient cases:
/// coefficients for linearly dependent columns are set to NA, and predictions
/// are computed using only the valid (non-NA) coefficients and their corresponding
/// columns. This matches how R handles rank-deficient models in prediction.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - Predictions (n elements)
/// * `None` - If the matrix is exactly singular or dimensions don't match
///
/// # Algorithm
///
/// For rank-deficient systems (rank < p):
/// 1. Compute the QR decomposition with column pivoting
/// 2. Solve for coefficients (linearly dependent columns get NaN)
/// 3. Compute predictions from the valid columns only, skipping any column
///    whose coefficient is NaN
///
/// This matches R's behavior where `predict(lm.fit(...))` handles NA coefficients
/// by excluding the corresponding columns from the prediction.
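///
/// # Example
///
/// A minimal sketch reusing the duplicated-predictor design above (illustrative values):
///
/// ```ignore
/// let x = Matrix::new(3, 3, vec![
///     1.0, 0.0, 0.0,
///     1.0, 1.0, 1.0,
///     1.0, 2.0, 2.0,
/// ]);
/// let y = [1.0, 3.0, 5.0]; // exactly 1 + 2x in the independent columns
/// let pred = fit_and_predict_linpack(&y, &x).expect("predictable");
/// assert!((pred[0] - 1.0).abs() < 1e-10); // fitted values reproduce y
/// ```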
pub fn fit_and_predict_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let n = x.rows;
    let p = x.cols;

    // Compute the QR decomposition
    let qr_result = x.qr_linpack(None);
    let k = qr_result.rank;

    // Solve for coefficients (already returned in the original column order)
    let beta = x.qr_solve_linpack(&qr_result, y)?;

    // Check for rank deficiency
    if k == p {
        // Full rank - use standard prediction
        return Some(x.mul_vec(&beta));
    }

    // Rank-deficient case: some columns are collinear and have NaN coefficients.
    // We compute predictions using only columns with valid (non-NaN) coefficients.
    // This matches R's behavior, where NA coefficients exclude columns from prediction.

    let mut pred = vec![0.0; n];

    for row in 0..n {
        let mut sum = 0.0;
        for j in 0..p {
            let b_val = beta[j];
            if b_val.is_nan() {
                continue; // Skip collinear columns (matches R's NA coefficient behavior)
            }
            sum += x.get(row, j) * b_val;
        }
        pred[row] = sum;
    }

    Some(pred)
}