linreg_core/
linalg.rs

1//! Minimal Linear Algebra module to replace nalgebra dependency.
2//!
3//! Implements matrix operations, QR decomposition, and solvers needed for OLS.
4//! Uses row-major storage for compatibility with statistical computing conventions.
5//!
6//! # Numerical Stability Considerations
7//!
8//! This implementation uses Householder QR decomposition with careful attention to
9//! numerical stability:
10//!
11//! - **Sign convention**: Uses the numerically stable Householder sign choice
12//!   (v = x + sgn(x₀)||x||e₁) to avoid cancellation
13//! - **Tolerance checking**: Uses predefined tolerances to detect near-singular matrices
14//! - **Zero-skipping**: Skips transformations when columns are already zero-aligned
15//!
16//! # Scaling Recommendations
17//!
18//! For optimal numerical stability when predictor variables have vastly different
19//! scales (e.g., one variable in millions, another in thousandths), consider
20//! standardizing predictors before regression. Z-score standardization
21//! (`x_scaled = (x - mean) / std`) is already done in VIF calculation.
22//!
23//! However, the current implementation handles typical OLS cases without explicit
24//! scaling, as QR decomposition is generally stable for well-conditioned matrices.
25
26// ============================================================================
27// Numerical Constants
28// ============================================================================
29
30/// Machine epsilon threshold for detecting zero values in QR decomposition.
31/// Values below this are treated as zero to avoid numerical instability.
32const QR_ZERO_TOLERANCE: f64 = 1e-12;
33
34/// Threshold for detecting singular matrices during inversion.
35/// Diagonal elements below this value indicate a near-singular matrix.
36const SINGULAR_TOLERANCE: f64 = 1e-10;
37
38/// A dense matrix stored in row-major order.
39///
40/// # Storage
41///
42/// Elements are stored in a single flat vector in row-major order:
43/// `data[row * cols + col]`
44///
45/// # Example
46///
47/// ```
48/// # use linreg_core::linalg::Matrix;
49/// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
50/// assert_eq!(m.rows, 2);
51/// assert_eq!(m.cols, 3);
52/// assert_eq!(m.get(0, 0), 1.0);
53/// assert_eq!(m.get(1, 2), 6.0);
54/// ```
55#[derive(Clone, Debug)]
56pub struct Matrix {
57    /// Number of rows in the matrix
58    pub rows: usize,
59    /// Number of columns in the matrix
60    pub cols: usize,
61    /// Flat vector storing matrix elements in row-major order
62    pub data: Vec<f64>,
63}
64
65impl Matrix {
66    /// Creates a new matrix from the given dimensions and data.
67    ///
68    /// # Panics
69    ///
70    /// Panics if `data.len() != rows * cols`.
71    ///
72    /// # Arguments
73    ///
74    /// * `rows` - Number of rows
75    /// * `cols` - Number of columns
76    /// * `data` - Flat vector of elements in row-major order
77    ///
78    /// # Example
79    ///
80    /// ```
81    /// # use linreg_core::linalg::Matrix;
82    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
83    /// assert_eq!(m.get(0, 0), 1.0);
84    /// assert_eq!(m.get(0, 1), 2.0);
85    /// assert_eq!(m.get(1, 0), 3.0);
86    /// assert_eq!(m.get(1, 1), 4.0);
87    /// ```
88    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
89        assert_eq!(data.len(), rows * cols, "Data length must match dimensions");
90        Matrix { rows, cols, data }
91    }
92
93    /// Creates a matrix filled with zeros.
94    ///
95    /// # Arguments
96    ///
97    /// * `rows` - Number of rows
98    /// * `cols` - Number of columns
99    ///
100    /// # Example
101    ///
102    /// ```
103    /// # use linreg_core::linalg::Matrix;
104    /// let m = Matrix::zeros(3, 2);
105    /// assert_eq!(m.rows, 3);
106    /// assert_eq!(m.cols, 2);
107    /// assert_eq!(m.get(1, 1), 0.0);
108    /// ```
109    pub fn zeros(rows: usize, cols: usize) -> Self {
110        Matrix {
111            rows,
112            cols,
113            data: vec![0.0; rows * cols],
114        }
115    }
116
117    // NOTE: Currently unused but kept as reference implementation.
118    // Uncomment if needed for convenience constructor.
119    /*
120    /// Creates a matrix from a row-major slice.
121    ///
122    /// # Arguments
123    ///
124    /// * `rows` - Number of rows
125    /// * `cols` - Number of columns
126    /// * `slice` - Slice containing matrix elements in row-major order
127    pub fn from_row_slice(rows: usize, cols: usize, slice: &[f64]) -> Self {
128        Matrix::new(rows, cols, slice.to_vec())
129    }
130    */
131
132    /// Gets the element at the specified row and column.
133    ///
134    /// # Arguments
135    ///
136    /// * `row` - Row index (0-based)
137    /// * `col` - Column index (0-based)
138    pub fn get(&self, row: usize, col: usize) -> f64 {
139        self.data[row * self.cols + col]
140    }
141
142    /// Sets the element at the specified row and column.
143    ///
144    /// # Arguments
145    ///
146    /// * `row` - Row index (0-based)
147    /// * `col` - Column index (0-based)
148    /// * `val` - Value to set
149    pub fn set(&mut self, row: usize, col: usize, val: f64) {
150        self.data[row * self.cols + col] = val;
151    }
152
153    /// Returns the transpose of this matrix.
154    ///
155    /// Swaps rows with columns: `result[col][row] = self[row][col]`.
156    ///
157    /// # Example
158    ///
159    /// ```
160    /// # use linreg_core::linalg::Matrix;
161    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
162    /// let t = m.transpose();
163    /// assert_eq!(t.rows, 3);
164    /// assert_eq!(t.cols, 2);
165    /// assert_eq!(t.get(0, 1), 4.0);
166    /// ```
167    pub fn transpose(&self) -> Matrix {
168        let mut t_data = vec![0.0; self.rows * self.cols];
169        for r in 0..self.rows {
170            for c in 0..self.cols {
171                t_data[c * self.rows + r] = self.get(r, c);
172            }
173        }
174        Matrix::new(self.cols, self.rows, t_data)
175    }
176
177    /// Performs matrix multiplication: `self * other`.
178    ///
179    /// # Panics
180    ///
181    /// Panics if `self.cols != other.rows`.
182    ///
183    /// # Example
184    ///
185    /// ```
186    /// # use linreg_core::linalg::Matrix;
187    /// let a = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
188    /// let b = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
189    /// let c = a.matmul(&b);
190    /// assert_eq!(c.rows, 2);
191    /// assert_eq!(c.cols, 2);
192    /// assert_eq!(c.get(0, 0), 22.0); // 1*1 + 2*3 + 3*5
193    /// ```
194    pub fn matmul(&self, other: &Matrix) -> Matrix {
195        assert_eq!(self.cols, other.rows, "Dimension mismatch for multiplication");
196        let mut result = Matrix::zeros(self.rows, other.cols);
197
198        for r in 0..self.rows {
199            for c in 0..other.cols {
200                let mut sum = 0.0;
201                for k in 0..self.cols {
202                    sum += self.get(r, k) * other.get(k, c);
203                }
204                result.set(r, c, sum);
205            }
206        }
207        result
208    }
209
210    /// Multiplies this matrix by a vector (treating vector as column matrix).
211    ///
212    /// Computes `self * vec` where vec is treated as an n×1 column matrix.
213    ///
214    /// # Panics
215    ///
216    /// Panics if `self.cols != vec.len()`.
217    ///
218    /// # Arguments
219    ///
220    /// * `vec` - Vector to multiply by
221    ///
222    /// # Example
223    ///
224    /// ```
225    /// # use linreg_core::linalg::Matrix;
226    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
227    /// let v = vec![1.0, 2.0, 3.0];
228    /// let result = m.mul_vec(&v);
229    /// assert_eq!(result.len(), 2);
230    /// assert_eq!(result[0], 14.0); // 1*1 + 2*2 + 3*3
231    /// ```
232    #[allow(clippy::needless_range_loop)]
233    pub fn mul_vec(&self, vec: &[f64]) -> Vec<f64> {
234        assert_eq!(self.cols, vec.len(), "Dimension mismatch for matrix-vector multiplication");
235        let mut result = vec![0.0; self.rows];
236
237        for r in 0..self.rows {
238            let mut sum = 0.0;
239            for c in 0..self.cols {
240                sum += self.get(r, c) * vec[c];
241            }
242            result[r] = sum;
243        }
244        result
245    }
246
247    /// Computes the dot product of a column with a vector: `Σ(data[i * cols + col] * v[i])`.
248    ///
249    /// For a row-major matrix, this iterates through all rows at a fixed column.
250    ///
251    /// # Arguments
252    ///
253    /// * `col` - Column index
254    /// * `v` - Vector to dot with (must have length equal to rows)
255    ///
256    /// # Panics
257    ///
258    /// Panics if `col >= cols` or `v.len() != rows`.
259    #[allow(clippy::needless_range_loop)]
260    pub fn col_dot(&self, col: usize, v: &[f64]) -> f64 {
261        assert!(col < self.cols, "Column index out of bounds");
262        assert_eq!(self.rows, v.len(), "Vector length must match number of rows");
263
264        let mut sum = 0.0;
265        for row in 0..self.rows {
266            sum += self.get(row, col) * v[row];
267        }
268        sum
269    }
270
271    /// Performs the column-vector operation in place: `v += alpha * column_col`.
272    ///
273    /// This is the AXPY operation where the column is treated as a vector.
274    /// For row-major storage, we iterate through rows at a fixed column.
275    ///
276    /// # Arguments
277    ///
278    /// * `col` - Column index
279    /// * `alpha` - Scaling factor for the column
280    /// * `v` - Vector to modify in place (must have length equal to rows)
281    ///
282    /// # Panics
283    ///
284    /// Panics if `col >= cols` or `v.len() != rows`.
285    #[allow(clippy::needless_range_loop)]
286    pub fn col_axpy_inplace(&self, col: usize, alpha: f64, v: &mut [f64]) {
287        assert!(col < self.cols, "Column index out of bounds");
288        assert_eq!(self.rows, v.len(), "Vector length must match number of rows");
289
290        for row in 0..self.rows {
291            v[row] += alpha * self.get(row, col);
292        }
293    }
294
295    /// Computes the squared L2 norm of a column: `Σ(data[i * cols + col]²)`.
296    ///
297    /// # Arguments
298    ///
299    /// * `col` - Column index
300    ///
301    /// # Panics
302    ///
303    /// Panics if `col >= cols`.
304    #[allow(clippy::needless_range_loop)]
305    pub fn col_norm2(&self, col: usize) -> f64 {
306        assert!(col < self.cols, "Column index out of bounds");
307
308        let mut sum = 0.0;
309        for row in 0..self.rows {
310            let val = self.get(row, col);
311            sum += val * val;
312        }
313        sum
314    }
315
316    /// Adds a value to diagonal elements starting from a given index.
317    ///
318    /// This is useful for ridge regression where we add `lambda * I` to `X^T X`,
319    /// but the intercept column should not be penalized.
320    ///
321    /// # Arguments
322    ///
323    /// * `alpha` - Value to add to diagonal elements
324    /// * `start_index` - Starting diagonal index (0 = first diagonal element)
325    ///
326    /// # Panics
327    ///
328    /// Panics if the matrix is not square.
329    ///
330    /// # Example
331    ///
332    /// For a 3×3 identity matrix with intercept in first column (unpenalized):
333    /// ```text
334    /// add_diagonal_in_place(lambda, 1) on:
335    /// [1.0, 0.0, 0.0]       [1.0,   0.0,   0.0  ]
336    /// [0.0, 1.0, 0.0]  ->   [0.0,  1.0+λ, 0.0  ]
337    /// [0.0, 0.0, 1.0]       [0.0,   0.0,  1.0+λ]
338    /// ```
339    pub fn add_diagonal_in_place(&mut self, alpha: f64, start_index: usize) {
340        assert_eq!(self.rows, self.cols, "Matrix must be square");
341        let n = self.rows;
342        for i in start_index..n {
343            let current = self.get(i, i);
344            self.set(i, i, current + alpha);
345        }
346    }
347}
348
349// ============================================================================
350// QR Decomposition
351// ============================================================================
352
353impl Matrix {
354    /// Computes the QR decomposition using Householder reflections.
355    ///
356    /// Factorizes the matrix as `A = QR` where Q is orthogonal and R is upper triangular.
357    ///
358    /// # Requirements
359    ///
360    /// This implementation requires `rows >= cols` (tall matrix). For OLS regression,
361    /// we always have more observations than predictors, so this requirement is satisfied.
362    ///
363    /// # Returns
364    ///
365    /// A tuple `(Q, R)` where:
366    /// - `Q` is an orthogonal matrix (QᵀQ = I) of size m×m
367    /// - `R` is an upper triangular matrix of size m×n
368    #[allow(clippy::needless_range_loop)]
369    pub fn qr(&self) -> (Matrix, Matrix) {
370        let m = self.rows;
371        let n = self.cols;
372        let mut q = Matrix::identity(m);
373        let mut r = self.clone();
374
375        for k in 0..n.min(m - 1) {
376            // Create vector x = R[k:, k]
377            let mut x = vec![0.0; m - k];
378            for i in k..m {
379                x[i - k] = r.get(i, k);
380            }
381
382            // Norm of x
383            let norm_x: f64 = x.iter().map(|&v| v * v).sum::<f64>().sqrt();
384            if norm_x < QR_ZERO_TOLERANCE { continue; } // Already zero
385
386            // Create vector v = x + sign(x[0]) * ||x|| * e1
387            //
388            // NOTE: Numerical stability consideration (Householder sign choice)
389            // According to Overton & Yu (2023), the numerically stable choice is
390            // σ = -sgn(x₁) in the formula v = x - σ‖x‖e₁.
391            //
392            // This means: v = x - (-sgn(x₁))‖x‖e₁ = x + sgn(x₁)‖x‖e₁
393            //
394            // Equivalently: u₁ = x₁ + sgn(x₁)‖x‖
395            //
396            // Current implementation uses this formula (the "correct" choice for stability):
397            let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 };  // sgn(x₀) as defined (sgn(0) = +1)
398            let u1 = x[0] + sign * norm_x;
399            
400            // Normalize v to get Householder vector
401            let mut v = x; // Re-use storage
402            v[0] = u1;
403
404            let norm_v: f64 = v.iter().map(|&val| val * val).sum::<f64>().sqrt();
405            for val in &mut v { *val /= norm_v; }
406
407            // Apply Householder transformation to R: R = H * R = (I - 2vv^T)R = R - 2v(v^T R)
408            // Focus on submatrix R[k:, k:]
409            for j in k..n {
410                let mut dot = 0.0;
411                for i in 0..m-k {
412                    dot += v[i] * r.get(k+i, j);
413                }
414                
415                for i in 0..m-k {
416                    let val = r.get(k+i, j) - 2.0 * v[i] * dot;
417                    r.set(k+i, j, val);
418                }
419            }
420
421            // Update Q: Q = Q * H = Q(I - 2vv^T) = Q - 2(Qv)v^T
422            // Focus on Q[:, k:]
423            for i in 0..m {
424                let mut dot = 0.0;
425                for l in 0..m-k {
426                    dot += q.get(i, k+l) * v[l];
427                }
428                
429                for l in 0..m-k {
430                    let val = q.get(i, k+l) - 2.0 * dot * v[l];
431                    q.set(i, k+l, val);
432                }
433            }
434        }
435
436        (q, r)
437    }
438
439    /// Creates an identity matrix of the given size.
440    ///
441    /// # Arguments
442    ///
443    /// * `size` - Number of rows and columns (square matrix)
444    ///
445    /// # Example
446    ///
447    /// ```
448    /// # use linreg_core::linalg::Matrix;
449    /// let i = Matrix::identity(3);
450    /// assert_eq!(i.get(0, 0), 1.0);
451    /// assert_eq!(i.get(1, 1), 1.0);
452    /// assert_eq!(i.get(2, 2), 1.0);
453    /// assert_eq!(i.get(0, 1), 0.0);
454    /// ```
455    pub fn identity(size: usize) -> Self {
456        let mut data = vec![0.0; size * size];
457        for i in 0..size {
458            data[i * size + i] = 1.0;
459        }
460        Matrix::new(size, size, data)
461    }
462
463    /// Inverts an upper triangular matrix (such as R from QR decomposition).
464    ///
465    /// Uses back-substitution to compute the inverse. This is efficient for
466    /// triangular matrices compared to general matrix inversion.
467    ///
468    /// # Panics
469    ///
470    /// Panics if the matrix is not square.
471    ///
472    /// # Returns
473    ///
474    /// `None` if the matrix is singular (has a zero or near-zero diagonal element).
475    /// A matrix is considered singular if any diagonal element is below the
476    /// internal tolerance (1e-10), which indicates the matrix does not have full rank.
477    ///
478    /// # Note
479    ///
480    /// For upper triangular matrices, singularity is equivalent to having a
481    /// zero (or near-zero) diagonal element. This is much simpler to check than
482    /// for general matrices, which would require computing the condition number.
483    pub fn invert_upper_triangular(&self) -> Option<Matrix> {
484        let n = self.rows;
485        assert_eq!(n, self.cols, "Matrix must be square");
486
487        // Check for singularity using relative tolerance
488        // This scales with the magnitude of diagonal elements, handling matrices
489        // of different scales better than a fixed absolute tolerance.
490        //
491        // Previous implementation used absolute tolerance:
492        //   if self.get(i, i).abs() < SINGULAR_TOLERANCE { return None; }
493        //
494        // New implementation uses relative tolerance similar to LAPACK:
495        //   tolerance = max_diag * epsilon * n
496        // where epsilon is machine epsilon (~2.2e-16 for f64)
497        let max_diag: f64 = (0..n)
498            .map(|i| self.get(i, i).abs())
499            .fold(0.0_f64, |acc, val| acc.max(val));
500
501        // Use a relative tolerance based on the maximum diagonal element
502        // This is similar to LAPACK's dlamch machine epsilon approach
503        let epsilon = 2.0_f64 * f64::EPSILON;  // ~4.4e-16 for f64
504        let relative_tolerance = max_diag * epsilon * n as f64;
505        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);
506
507        for i in 0..n {
508            if self.get(i, i).abs() < tolerance {
509                return None; // Singular matrix - cannot invert
510            }
511        }
512
513        let mut inv = Matrix::zeros(n, n);
514
515        for i in 0..n {
516            inv.set(i, i, 1.0 / self.get(i, i));
517
518            for j in (0..i).rev() {
519                let mut sum = 0.0;
520                for k in j+1..=i {
521                    sum += self.get(j, k) * inv.get(k, i);
522                }
523                inv.set(j, i, -sum / self.get(j, j));
524            }
525        }
526
527        Some(inv)
528    }
529
530    /// Inverts an upper triangular matrix with a custom tolerance multiplier.
531    ///
532    /// The tolerance is computed as `max_diag * epsilon * n * tolerance_mult`.
533    /// A higher tolerance_mult allows more tolerance for near-singular matrices.
534    ///
535    /// # Arguments
536    ///
537    /// * `tolerance_mult` - Multiplier for the tolerance (1.0 = standard, higher = more tolerant)
538    pub fn invert_upper_triangular_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
539        let n = self.rows;
540        assert_eq!(n, self.cols, "Matrix must be square");
541
542        // Check for singularity using relative tolerance
543        let max_diag: f64 = (0..n)
544            .map(|i| self.get(i, i).abs())
545            .fold(0.0_f64, |acc, val| acc.max(val));
546
547        // Use a relative tolerance based on the maximum diagonal element
548        let epsilon = 2.0_f64 * f64::EPSILON;
549        let relative_tolerance = max_diag * epsilon * n as f64 * tolerance_mult;
550        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);
551
552        for i in 0..n {
553            if self.get(i, i).abs() < tolerance {
554                return None;
555            }
556        }
557
558        let mut inv = Matrix::zeros(n, n);
559
560        for i in 0..n {
561            inv.set(i, i, 1.0 / self.get(i, i));
562
563            for j in (0..i).rev() {
564                let mut sum = 0.0;
565                for k in j+1..=i {
566                    sum += self.get(j, k) * inv.get(k, i);
567                }
568                inv.set(j, i, -sum / self.get(j, j));
569            }
570        }
571
572        Some(inv)
573    }
574
575    /// Computes the inverse of a square matrix using QR decomposition.
576    ///
577    /// For an invertible matrix A, computes A⁻¹ such that A * A⁻¹ = I.
578    /// Uses QR decomposition for numerical stability.
579    ///
580    /// # Panics
581    ///
582    /// Panics if the matrix is not square (i.e., `self.rows != self.cols`).
583    /// Check dimensions before calling if the matrix shape is not guaranteed.
584    ///
585    /// # Returns
586    ///
587    /// Returns `Some(inverse)` if the matrix is invertible, or `None` if
588    /// the matrix is singular (non-invertible).
589    pub fn invert(&self) -> Option<Matrix> {
590        let n = self.rows;
591        if n != self.cols {
592            panic!("Matrix must be square for inversion");
593        }
594
595        // Use QR decomposition: A = Q * R
596        let (q, r) = self.qr();
597
598        // Compute R⁻¹ (upper triangular inverse)
599        let r_inv = r.invert_upper_triangular()?;
600
601        // A⁻¹ = R⁻¹ * Q^T
602        let q_transpose = q.transpose();
603        let mut result = Matrix::zeros(n, n);
604
605        for i in 0..n {
606            for j in 0..n {
607                let mut sum = 0.0;
608                for k in 0..n {
609                    sum += r_inv.get(i, k) * q_transpose.get(k, j);
610                }
611                result.set(i, j, sum);
612            }
613        }
614
615        Some(result)
616    }
617
618    /// Computes the inverse of X'X given the QR decomposition of X (R's chol2inv).
619    ///
620    /// This is equivalent to computing `(X'X)^(-1)` using the QR decomposition of X.
621    /// R's `chol2inv` function is used for numerical stability in recursive residuals.
622    ///
623    /// # Arguments
624    ///
625    /// * `x` - Input matrix (must have rows >= cols)
626    ///
627    /// # Returns
628    ///
629    /// `Some((X'X)^(-1))` if X has full rank, `None` otherwise.
630    ///
631    /// # Algorithm
632    ///
633    /// Given QR decomposition X = QR where R is upper triangular:
634    /// 1. Extract the upper p×p portion of R (denoted R₁)
635    /// 2. Invert R₁ (upper triangular inverse)
636    /// 3. Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
637    ///
638    /// This works because X'X = R'Q'QR = R'R, and R₁ contains the Cholesky factor.
639    pub fn chol2inv_from_qr(&self) -> Option<Matrix> {
640        self.chol2inv_from_qr_with_tolerance(1.0)
641    }
642
643    /// Computes the inverse of X'X given the QR decomposition with custom tolerance.
644    ///
645    /// Similar to `chol2inv_from_qr` but allows specifying a tolerance multiplier
646    /// for handling near-singular matrices.
647    ///
648    /// # Arguments
649    ///
650    /// * `tolerance_mult` - Multiplier for the tolerance (higher = more tolerant)
651    pub fn chol2inv_from_qr_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
652        let p = self.cols;
653
654        // QR decomposition: X = QR
655        // For X (m×n, m≥n), R is m×n upper triangular
656        // The upper n×n block of R contains the meaningful values
657        let (_, r_full) = self.qr();
658
659        // Extract upper p×p portion from R
660        // For tall matrices (m > p), R has zeros below diagonal in first p rows
661        // For square matrices (m = p), R is p×p upper triangular
662        let mut r1 = Matrix::zeros(p, p);
663        for i in 0..p {
664            // Row i of R1 is row i of R_full, columns 0..p
665            // But we only copy the upper triangular part (columns i..p)
666            for j in i..p {
667                r1.set(i, j, r_full.get(i, j));
668            }
669            // Also copy diagonal if not yet copied
670            if i < p {
671                r1.set(i, i, r_full.get(i, i));
672            }
673        }
674
675        // Invert R₁ (upper triangular) with custom tolerance
676        let r1_inv = r1.invert_upper_triangular_with_tolerance(tolerance_mult)?;
677
678        // Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
679        let mut result = Matrix::zeros(p, p);
680        for i in 0..p {
681            for j in 0..p {
682                let mut sum = 0.0;
683                // result[i,j] = sum(R1_inv[i,k] * R1_inv[j,k] for k=0..p)
684                // R1_inv is upper triangular, but we iterate full range
685                for k in 0..p {
686                    sum += r1_inv.get(i, k) * r1_inv.get(j, k);
687                }
688                result.set(i, j, sum);
689            }
690        }
691
692        Some(result)
693    }
694}
695
696// ============================================================================
697// Vector Helper Functions
698// ============================================================================
699
700/// Computes the arithmetic mean of a slice of f64 values.
701///
702/// Returns 0.0 for empty slices.
703///
704/// # Examples
705///
706/// ```
707/// use linreg_core::linalg::vec_mean;
708///
709/// assert_eq!(vec_mean(&[1.0, 2.0, 3.0, 4.0, 5.0]), 3.0);
710/// assert_eq!(vec_mean(&[]), 0.0);
711/// ```
712///
713/// # Arguments
714///
715/// * `v` - Slice of values
716pub fn vec_mean(v: &[f64]) -> f64 {
717    if v.is_empty() { return 0.0; }
718    v.iter().sum::<f64>() / v.len() as f64
719}
720
721/// Computes element-wise subtraction of two slices: `a - b`.
722///
723/// # Examples
724///
725/// ```
726/// use linreg_core::linalg::vec_sub;
727///
728/// let a = vec![5.0, 4.0, 3.0];
729/// let b = vec![1.0, 1.0, 1.0];
730/// let result = vec_sub(&a, &b);
731/// assert_eq!(result, vec![4.0, 3.0, 2.0]);
732/// ```
733///
734/// # Arguments
735///
736/// * `a` - Minuend slice
737/// * `b` - Subtrahend slice
738///
739/// # Panics
740///
741/// Panics if slices have different lengths.
742pub fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
743    assert_eq!(a.len(), b.len(), "vec_sub: slice lengths must match");
744    a.iter().zip(b.iter()).map(|(x, y)| x - y).collect()
745}
746
747/// Computes the dot product of two slices: `Σ(a[i] * b[i])`.
748///
749/// # Examples
750///
751/// ```
752/// use linreg_core::linalg::vec_dot;
753///
754/// let a = vec![1.0, 2.0, 3.0];
755/// let b = vec![4.0, 5.0, 6.0];
756/// assert_eq!(vec_dot(&a, &b), 32.0);  // 1*4 + 2*5 + 3*6
757/// ```
758///
759/// # Arguments
760///
761/// * `a` - First slice
762/// * `b` - Second slice
763///
764/// # Panics
765///
766/// Panics if slices have different lengths.
767pub fn vec_dot(a: &[f64], b: &[f64]) -> f64 {
768    assert_eq!(a.len(), b.len(), "vec_dot: slice lengths must match");
769    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
770}
771
772/// Computes element-wise addition of two slices: `a + b`.
773///
774/// # Examples
775///
776/// ```
777/// use linreg_core::linalg::vec_add;
778///
779/// let a = vec![1.0, 2.0, 3.0];
780/// let b = vec![4.0, 5.0, 6.0];
781/// assert_eq!(vec_add(&a, &b), vec![5.0, 7.0, 9.0]);
782/// ```
783///
784/// # Arguments
785///
786/// * `a` - First slice
787/// * `b` - Second slice
788///
789/// # Panics
790///
791/// Panics if slices have different lengths.
792pub fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
793    assert_eq!(a.len(), b.len(), "vec_add: slice lengths must match");
794    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
795}
796
797/// Computes a scaled vector addition in place: `dst += alpha * src`.
798///
799/// This is the classic BLAS AXPY operation.
800///
801/// # Arguments
802///
803/// * `dst` - Destination slice (modified in place)
804/// * `alpha` - Scaling factor for src
805/// * `src` - Source slice
806///
807/// # Panics
808///
809/// Panics if slices have different lengths.
810pub fn vec_axpy_inplace(dst: &mut [f64], alpha: f64, src: &[f64]) {
811    assert_eq!(dst.len(), src.len(), "vec_axpy_inplace: slice lengths must match");
812    for (d, &s) in dst.iter_mut().zip(src.iter()) {
813        *d += alpha * s;
814    }
815}
816
817/// Scales a vector in place: `v *= alpha`.
818///
819/// # Arguments
820///
821/// * `v` - Vector to scale (modified in place)
822/// * `alpha` - Scaling factor
823pub fn vec_scale_inplace(v: &mut [f64], alpha: f64) {
824    for val in v.iter_mut() {
825        *val *= alpha;
826    }
827}
828
829/// Returns a scaled copy of a vector: `v * alpha`.
830///
831/// # Examples
832///
833/// ```
834/// use linreg_core::linalg::vec_scale;
835///
836/// let v = vec![1.0, 2.0, 3.0];
837/// let scaled = vec_scale(&v, 2.5);
838/// assert_eq!(scaled, vec![2.5, 5.0, 7.5]);
839/// // Original is unchanged
840/// assert_eq!(v, vec![1.0, 2.0, 3.0]);
841/// ```
842///
843/// # Arguments
844///
845/// * `v` - Vector to scale
846/// * `alpha` - Scaling factor
847pub fn vec_scale(v: &[f64], alpha: f64) -> Vec<f64> {
848    v.iter().map(|&x| x * alpha).collect()
849}
850
851/// Computes the L2 norm (Euclidean norm) of a vector: `sqrt(Σ(v[i]²))`.
852///
853/// # Examples
854///
855/// ```
856/// use linreg_core::linalg::vec_l2_norm;
857///
858/// // Pythagorean triple: 3-4-5
859/// assert_eq!(vec_l2_norm(&[3.0, 4.0]), 5.0);
860/// // Unit vector
861/// assert_eq!(vec_l2_norm(&[1.0, 0.0, 0.0]), 1.0);
862/// ```
863///
864/// # Arguments
865///
866/// * `v` - Vector slice
867pub fn vec_l2_norm(v: &[f64]) -> f64 {
868    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
869}
870
871/// Computes the maximum absolute value in a vector.
872///
873/// # Arguments
874///
875/// * `v` - Vector slice
876pub fn vec_max_abs(v: &[f64]) -> f64 {
877    v.iter().map(|&x| x.abs()).fold(0.0_f64, f64::max)
878}
879
880// ============================================================================
881// R-Compatible QR Decomposition (LINPACK dqrdc2 with Column Pivoting)
882// ============================================================================
883
884/// QR decomposition result using R's LINPACK dqrdc2 algorithm.
885///
886/// This implements the QR decomposition with column pivoting as used by R's
887/// `qr()` function with `LAPACK=FALSE`. The algorithm is a modification of
888/// LINPACK's DQRDC that:
889/// - Uses Householder transformations
890/// - Implements limited column pivoting based on 2-norms of reduced columns
891/// - Moves columns with near-zero norm to the right-hand edge
892/// - Computes the rank (number of linearly independent columns)
893///
894/// # Fields
895///
896/// * `qr` - The QR factorization (upper triangle contains R, below diagonal
897///   contains Householder vector information)
898/// * `qraux` - Auxiliary information for recovering the orthogonal part Q
899/// * `pivot` - Column permutation: `pivot\[j\]` contains the original column index
900///   now in column j
901/// * `rank` - Number of linearly independent columns (the computed rank)
902#[derive(Clone, Debug)]
903pub struct QRLinpack {
904    /// QR factorization matrix (same dimensions as input)
905    pub qr: Matrix,
906    /// Auxiliary information for Q recovery
907    pub qraux: Vec<f64>,
908    /// Column pivot vector (1-based indices like R)
909    pub pivot: Vec<usize>,
910    /// Computed rank (number of linearly independent columns)
911    pub rank: usize,
912}
913
914impl Matrix {
915    /// Computes QR decomposition using R's LINPACK dqrdc2 algorithm with column pivoting.
916    ///
917    /// This is a port of R's dqrdc2.f, which is a modification of LINPACK's DQRDC.
918    /// The algorithm:
919    /// 1. Uses Householder transformations for QR factorization
920    /// 2. Implements limited column pivoting based on column 2-norms
921    /// 3. Moves columns with near-zero norm to the right-hand edge
922    /// 4. Computes the rank (number of linearly independent columns)
923    ///
924    /// # Arguments
925    ///
926    /// * `tol` - Tolerance for determining linear independence. Default is 1e-7 (R's default).
927    ///   Columns with norm < tol * original_norm are considered negligible.
928    ///
929    /// # Returns
930    ///
931    /// A [`QRLinpack`] struct containing the QR factorization, auxiliary information,
932    /// pivot vector, and computed rank.
933    ///
934    /// # Algorithm Details
935    ///
936    /// The decomposition is A * P = Q * R where:
937    /// - P is the permutation matrix coded by `pivot`
938    /// - Q is orthogonal (m × m)
939    /// - R is upper triangular in the first `rank` rows
940    ///
941    /// The `qr` matrix contains:
942    /// - Upper triangle: R matrix (if pivoting was performed, this is R of the permuted matrix)
943    /// - Below diagonal: Householder vector information
944    ///
945    /// # Reference
946    ///
947    /// - R source: src/appl/dqrdc2.f
948    /// - LINPACK documentation: <https://www.netlib.org/linpack/dqrdc.f>
949    pub fn qr_linpack(&self, tol: Option<f64>) -> QRLinpack {
950        let n = self.rows;
951        let p = self.cols;
952        let lup = n.min(p);
953
954        // Default tolerance matches R's qr.default: tol = 1e-07
955        let tol = tol.unwrap_or(1e-07);
956
957        // Initialize working matrices
958        let mut x = self.clone(); // Working copy that will be modified
959        let mut qraux = vec![0.0; p];
960        let mut pivot: Vec<usize> = (1..=p).collect(); // 1-based indices like R
961        let mut work = vec![(0.0, 0.0); p]; // (work[j,1], work[j,2])
962
963        // Compute the norms of the columns of x (initialization)
964        if n > 0 {
965            for j in 0..p {
966                let mut norm = 0.0;
967                for i in 0..n {
968                    norm += x.get(i, j) * x.get(i, j);
969                }
970                norm = norm.sqrt();
971                qraux[j] = norm;
972                let original_norm = if norm == 0.0 { 1.0 } else { norm };
973                work[j] = (norm, original_norm);
974            }
975        }
976
977        let mut k = p + 1; // Will be decremented to get the final rank
978
979        // Perform the Householder reduction of x
980        for l in 0..lup {
981            // Cycle columns from l to p until one with non-negligible norm is found
982            // A column is negligible if its norm has fallen below tol * original_norm
983            while l < k - 1 && qraux[l] < work[l].1 * tol {
984                // Move column l to the end (it's negligible)
985                let lp1 = l + 1;
986
987                // Shift columns in x: x(i, l..p-1) = x(i, l+1..p)
988                for i in 0..n {
989                    let t = x.get(i, l);
990                    for j in lp1..p {
991                        x.set(i, j - 1, x.get(i, j));
992                    }
993                    x.set(i, p - 1, t);
994                }
995
996                // Shift pivot, qraux, and work arrays
997                let saved_pivot = pivot[l];
998                let saved_qraux = qraux[l];
999                let saved_work = work[l];
1000
1001                for j in lp1..p {
1002                    pivot[j - 1] = pivot[j];
1003                    qraux[j - 1] = qraux[j];
1004                    work[j - 1] = work[j];
1005                }
1006
1007                pivot[p - 1] = saved_pivot;
1008                qraux[p - 1] = saved_qraux;
1009                work[p - 1] = saved_work;
1010
1011                k -= 1;
1012            }
1013
1014            if l == n - 1 {
1015                // Last row - skip transformation
1016                break;
1017            }
1018
1019            // Compute the Householder transformation for column l
1020            // nrmxl = norm of x[l:, l]
1021            let mut nrmxl = 0.0;
1022            for i in l..n {
1023                let val = x.get(i, l);
1024                nrmxl += val * val;
1025            }
1026            nrmxl = nrmxl.sqrt();
1027
1028            if nrmxl == 0.0 {
1029                // Zero column - continue to next
1030                continue;
1031            }
1032
1033            // Apply sign for numerical stability (dsign in Fortran)
1034            let x_ll = x.get(l, l);
1035            if x_ll != 0.0 {
1036                nrmxl = nrmxl.copysign(x_ll);
1037            }
1038
1039            // Scale the column
1040            let scale = 1.0 / nrmxl;
1041            for i in l..n {
1042                x.set(i, l, x.get(i, l) * scale);
1043            }
1044            x.set(l, l, 1.0 + x.get(l, l));
1045
1046            // Apply the transformation to remaining columns, updating the norms
1047            let lp1 = l + 1;
1048            if p > lp1 {
1049                for j in lp1..p {
1050                    // Compute t = -dot(x[l:, l], x[l:, j]) / x(l, l)
1051                    let mut dot = 0.0;
1052                    for i in l..n {
1053                        dot += x.get(i, l) * x.get(i, j);
1054                    }
1055                    let t = -dot / x.get(l, l);
1056
1057                    // x[l:, j] = x[l:, j] + t * x[l:, l]
1058                    for i in l..n {
1059                        let val = x.get(i, j) + t * x.get(i, l);
1060                        x.set(i, j, val);
1061                    }
1062
1063                    // Update the norm
1064                    if qraux[j] != 0.0 {
1065                        // tt = 1.0 - (x(l, j) / qraux[j])^2
1066                        let x_lj = x.get(l, j).abs();
1067                        let mut tt = 1.0 - (x_lj / qraux[j]).powi(2);
1068                        tt = tt.max(0.0);
1069
1070                        // Recompute norm if there is large reduction (BDR mod 9/99)
1071                        // The tolerance here is on the squared norm
1072                        if tt.abs() < 1e-6 {
1073                            // Re-compute norm directly
1074                            let mut new_norm = 0.0;
1075                            for i in (l + 1)..n {
1076                                let val = x.get(i, j);
1077                                new_norm += val * val;
1078                            }
1079                            new_norm = new_norm.sqrt();
1080                            qraux[j] = new_norm;
1081                            work[j].0 = new_norm;
1082                        } else {
1083                            qraux[j] = qraux[j] * tt.sqrt();
1084                        }
1085                    }
1086                }
1087            }
1088
1089            // Save the transformation
1090            qraux[l] = x.get(l, l);
1091            x.set(l, l, -nrmxl);
1092        }
1093
1094        // Compute final rank
1095        let rank = k - 1;
1096        let rank = rank.min(n);
1097
1098        QRLinpack {
1099            qr: x,
1100            qraux,
1101            pivot,
1102            rank,
1103        }
1104    }
1105
1106    /// Solves a linear system using the QR decomposition with column pivoting.
1107    ///
1108    /// This implements a least squares solver using the pivoted QR decomposition.
1109    /// For rank-deficient cases, coefficients corresponding to linearly dependent
1110    /// columns are set to `f64::NAN`.
1111    ///
1112    /// # Arguments
1113    ///
1114    /// * `qr_result` - QR decomposition from [`Matrix::qr_linpack`]
1115    /// * `y` - Right-hand side vector
1116    ///
1117    /// # Returns
1118    ///
1119    /// A vector of coefficients, or `None` if the system is exactly singular.
1120    ///
1121    /// # Algorithm
1122    ///
1123    /// This solver uses the standard QR decomposition approach:
1124    /// 1. Compute the QR decomposition of the permuted matrix
1125    /// 2. Extract R matrix (upper triangular with positive diagonal)
1126    /// 3. Compute qty = Q^T * y
1127    /// 4. Solve R * coef = qty using back substitution
1128    /// 5. Apply the pivot permutation to restore original column order
1129    ///
1130    /// # Note
1131    ///
1132    /// The LINPACK QR algorithm stores R with mixed signs on the diagonal.
1133    /// This solver corrects for that by taking the absolute value of R's diagonal.
1134    pub fn qr_solve_linpack(&self, qr_result: &QRLinpack, y: &[f64]) -> Option<Vec<f64>> {
1135        let n = self.rows;
1136        let p = self.cols;
1137        let k = qr_result.rank;
1138
1139        if y.len() != n {
1140            return None;
1141        }
1142
1143        if k == 0 {
1144            return None;
1145        }
1146
1147        // Step 1: Compute Q^T * y using the Householder vectors directly
1148        // This is more efficient than reconstructing the full Q matrix
1149        let mut qty = y.to_vec();
1150
1151        for j in 0..k {
1152            // Check if this Householder transformation is valid
1153            let r_jj = qr_result.qr.get(j, j);
1154            if r_jj == 0.0 {
1155                continue;
1156            }
1157
1158            // Compute dot = v_j^T * qty[j:]
1159            // where v_j is the Householder vector stored in qr[j:, j]
1160            // The storage convention:
1161            // - qr[j,j] = -nrmxl (after final overwrite)
1162            // - qr[i,j] for i > j is the scaled Householder vector element
1163            // - qraux[j] = 1 + original_x[j,j]/nrmxl (the unscaled first element)
1164
1165            // Reconstruct the Householder vector v_j
1166            // After scaling by 1/nrmxl, we have:
1167            // v_scaled[j] = 1 + x[j,j]/nrmxl
1168            // v_scaled[i] = x[i,j]/nrmxl for i > j
1169            // The actual unit vector is v = v_scaled / ||v_scaled||
1170
1171            let mut v = vec![0.0; n - j];
1172            // Copy the scaled Householder vector from qr
1173            for i in j..n {
1174                v[i - j] = qr_result.qr.get(i, j);
1175            }
1176
1177            // The j-th element was modified during the QR decomposition
1178            // We need to reconstruct it from qraux
1179            let alpha = qr_result.qraux[j];
1180            if alpha != 0.0 {
1181                v[0] = alpha;
1182            }
1183
1184            // Compute the norm of v
1185            let v_norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
1186            if v_norm < 1e-14 {
1187                continue;
1188            }
1189
1190            // Compute dot = v^T * qty[j:]
1191            let mut dot = 0.0;
1192            for i in j..n {
1193                dot += v[i - j] * qty[i];
1194            }
1195
1196            // Apply Householder transformation: qty[j:] = qty[j:] - 2 * v * (v^T * qty[j:]) / (v^T * v)
1197            // Since v is already scaled, we use: t = 2 * dot / (v_norm^2)
1198            let t = 2.0 * dot / (v_norm * v_norm);
1199
1200            for i in j..n {
1201                qty[i] -= t * v[i - j];
1202            }
1203        }
1204
1205        // Step 2: Back substitution on R (solve R * coef = qty)
1206        // The R matrix is stored in the upper triangle of qr
1207        // Note: The diagonal elements of R are negative (from -nrmxl)
1208        // We use them as-is since the signs cancel out in the computation
1209        let mut coef_permuted = vec![f64::NAN; p];
1210
1211        for row in (0..k).rev() {
1212            let r_diag = qr_result.qr.get(row, row);
1213            // Use relative tolerance for singularity check
1214            let max_abs = (0..k).map(|i| qr_result.qr.get(i, i).abs()).fold(0.0_f64, f64::max);
1215            let tolerance = 1e-14 * max_abs.max(1.0);
1216
1217            if r_diag.abs() < tolerance {
1218                return None;  // Singular
1219            }
1220
1221            let mut sum = qty[row];
1222            for col in (row + 1)..k {
1223                sum -= qr_result.qr.get(row, col) * coef_permuted[col];
1224            }
1225            coef_permuted[row] = sum / r_diag;
1226        }
1227
1228        // Step 3: Apply pivot permutation to get coefficients in original order
1229        // pivot[j] is 1-based, indicating which original column is now in position j
1230        let mut result = vec![0.0; p];
1231        for j in 0..p {
1232            let original_col = qr_result.pivot[j] - 1;  // Convert to 0-based
1233            result[original_col] = coef_permuted[j];
1234        }
1235
1236        Some(result)
1237    }
1238}
1239
1240/// Performs OLS regression using R's LINPACK QR algorithm.
1241///
1242/// This function is a drop-in replacement for `fit_ols` that uses the
1243/// R-compatible QR decomposition with column pivoting. It handles
1244/// rank-deficient matrices more gracefully than the standard QR decomposition.
1245///
1246/// # Arguments
1247///
1248/// * `y` - Response variable (n observations)
1249/// * `x` - Design matrix (n rows, p columns including intercept)
1250///
1251/// # Returns
1252///
1253/// * `Some(Vec<f64>)` - OLS coefficient vector (p elements)
1254/// * `None` - If the matrix is exactly singular or dimensions don't match
1255///
1256/// # Note
1257///
1258/// For rank-deficient systems, this function uses the pivoted QR which
1259/// automatically handles multicollinearity by selecting a linearly
1260/// independent subset of columns.
1261pub fn fit_ols_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
1262    let qr_result = x.qr_linpack(None);
1263    x.qr_solve_linpack(&qr_result, y)
1264}
1265
1266/// Fits OLS and predicts using R's LINPACK QR with rank-deficient handling.
1267///
1268/// This function matches R's `lm.fit` behavior for rank-deficient cases:
1269/// coefficients for linearly dependent columns are set to NA, and predictions
1270/// are computed using only the valid (non-NA) coefficients and their corresponding
1271/// columns. This matches how R handles rank-deficient models in prediction.
1272///
1273/// # Arguments
1274///
1275/// * `y` - Response variable (n observations)
1276/// * `x` - Design matrix (n rows, p columns including intercept)
1277///
1278/// # Returns
1279///
1280/// * `Some(Vec<f64>)` - Predictions (n elements)
1281/// * `None` - If the matrix is exactly singular or dimensions don't match
1282///
1283/// # Algorithm
1284///
1285/// For rank-deficient systems (rank < p):
1286/// 1. Compute QR decomposition with column pivoting
1287/// 2. Get coefficients (rank-deficient columns will have NaN)
1288/// 3. Build a reduced design matrix with only pivoted, non-singular columns
1289/// 4. Compute predictions using only the valid columns
1290///
1291/// This matches R's behavior where `predict(lm.fit(...))` handles NA coefficients
1292/// by excluding the corresponding columns from the prediction.
1293pub fn fit_and_predict_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
1294    let n = x.rows;
1295    let p = x.cols;
1296
1297    // Compute QR decomposition
1298    let qr_result = x.qr_linpack(None);
1299    let k = qr_result.rank;
1300
1301    // Solve for coefficients
1302    let beta_permuted = x.qr_solve_linpack(&qr_result, y)?;
1303
1304    // Check for rank deficiency
1305    if k == p {
1306        // Full rank - use standard prediction
1307        return Some(x.mul_vec(&beta_permuted));
1308    }
1309
1310    // Rank-deficient case: some columns are collinear and have NaN coefficients
1311    // We compute predictions using only columns with valid (non-NaN) coefficients
1312    // This matches R's behavior where NA coefficients exclude columns from prediction
1313
1314    let mut pred = vec![0.0; n];
1315
1316    for row in 0..n {
1317        let mut sum = 0.0;
1318        for j in 0..p {
1319            let b_val = beta_permuted[j];
1320            if b_val.is_nan() {
1321                continue;  // Skip collinear columns (matches R's NA coefficient behavior)
1322            }
1323            sum += x.get(row, j) * b_val;
1324        }
1325        pred[row] = sum;
1326    }
1327
1328    Some(pred)
1329}