linreg_core/linalg.rs
1//! Minimal Linear Algebra module to replace nalgebra dependency.
2//!
3//! Implements matrix operations, QR decomposition, and solvers needed for OLS.
4//! Uses row-major storage for compatibility with statistical computing conventions.
5//!
6//! # Numerical Stability Considerations
7//!
8//! This implementation uses Householder QR decomposition with careful attention to
9//! numerical stability:
10//!
11//! - **Sign convention**: Uses the numerically stable Householder sign choice
12//! (v = x + sgn(x₀)||x||e₁) to avoid cancellation
13//! - **Tolerance checking**: Uses predefined tolerances to detect near-singular matrices
14//! - **Zero-skipping**: Skips transformations when columns are already zero-aligned
15//!
16//! # Scaling Recommendations
17//!
18//! For optimal numerical stability when predictor variables have vastly different
19//! scales (e.g., one variable in millions, another in thousandths), consider
20//! standardizing predictors before regression. Z-score standardization
21//! (`x_scaled = (x - mean) / std`) is already done in VIF calculation.
22//!
23//! However, the current implementation handles typical OLS cases without explicit
24//! scaling, as QR decomposition is generally stable for well-conditioned matrices.
25
26// ============================================================================
27// Numerical Constants
28// ============================================================================
29
/// Absolute cutoff on a column's 2-norm in the Householder QR ([`Matrix::qr`]).
///
/// When the entries on and below the diagonal of column `k` have 2-norm below this
/// value, the reflection for that column is skipped. This is a fixed absolute
/// threshold that does not scale with the data; it is *not* machine epsilon.
const QR_ZERO_TOLERANCE: f64 = 1e-12;

/// Absolute floor of the singularity tolerance used when inverting upper triangular
/// matrices.
///
/// The effective tolerance is `max(SINGULAR_TOLERANCE, max_diag * 2ε * n * tolerance_mult)`
/// (with `tolerance_mult = 1` for the plain variant). A diagonal element whose magnitude
/// falls below it makes the matrix count as singular. For typical data the absolute
/// floor dominates the relative term.
const SINGULAR_TOLERANCE: f64 = 1e-10;
37
/// Dense two-dimensional `f64` matrix backed by a single contiguous buffer.
///
/// # Layout
///
/// Row-major: the element at (`row`, `col`) is stored at `data[row * cols + col]`,
/// so every row occupies `cols` consecutive slots.
#[derive(Debug, Clone)]
pub struct Matrix {
    /// Row count.
    pub rows: usize,
    /// Column count.
    pub cols: usize,
    /// Elements in row-major order; expected to hold exactly `rows * cols` values.
    pub data: Vec<f64>,
}
53
54impl Matrix {
55 /// Creates a new matrix from the given dimensions and data.
56 ///
57 /// # Panics
58 ///
59 /// Panics if `data.len() != rows * cols`.
60 ///
61 /// # Arguments
62 ///
63 /// * `rows` - Number of rows
64 /// * `cols` - Number of columns
65 /// * `data` - Flat vector of elements in row-major order
66 pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
67 assert_eq!(data.len(), rows * cols, "Data length must match dimensions");
68 Matrix { rows, cols, data }
69 }
70
71 /// Creates a matrix filled with zeros.
72 ///
73 /// # Arguments
74 ///
75 /// * `rows` - Number of rows
76 /// * `cols` - Number of columns
77 pub fn zeros(rows: usize, cols: usize) -> Self {
78 Matrix {
79 rows,
80 cols,
81 data: vec![0.0; rows * cols],
82 }
83 }
84
85 // NOTE: Currently unused but kept as reference implementation.
86 // Uncomment if needed for convenience constructor.
87 /*
88 /// Creates a matrix from a row-major slice.
89 ///
90 /// # Arguments
91 ///
92 /// * `rows` - Number of rows
93 /// * `cols` - Number of columns
94 /// * `slice` - Slice containing matrix elements in row-major order
95 pub fn from_row_slice(rows: usize, cols: usize, slice: &[f64]) -> Self {
96 Matrix::new(rows, cols, slice.to_vec())
97 }
98 */
99
100 /// Gets the element at the specified row and column.
101 ///
102 /// # Arguments
103 ///
104 /// * `row` - Row index (0-based)
105 /// * `col` - Column index (0-based)
106 pub fn get(&self, row: usize, col: usize) -> f64 {
107 self.data[row * self.cols + col]
108 }
109
110 /// Sets the element at the specified row and column.
111 ///
112 /// # Arguments
113 ///
114 /// * `row` - Row index (0-based)
115 /// * `col` - Column index (0-based)
116 /// * `val` - Value to set
117 pub fn set(&mut self, row: usize, col: usize, val: f64) {
118 self.data[row * self.cols + col] = val;
119 }
120
121 /// Returns the transpose of this matrix.
122 ///
123 /// Swaps rows with columns: `result\[col\]\[row\] = self\[row\]\[col\]`.
124 pub fn transpose(&self) -> Matrix {
125 let mut t_data = vec![0.0; self.rows * self.cols];
126 for r in 0..self.rows {
127 for c in 0..self.cols {
128 t_data[c * self.rows + r] = self.get(r, c);
129 }
130 }
131 Matrix::new(self.cols, self.rows, t_data)
132 }
133
134 /// Performs matrix multiplication: `self * other`.
135 ///
136 /// # Panics
137 ///
138 /// Panics if `self.cols != other.rows`.
139 pub fn matmul(&self, other: &Matrix) -> Matrix {
140 assert_eq!(self.cols, other.rows, "Dimension mismatch for multiplication");
141 let mut result = Matrix::zeros(self.rows, other.cols);
142
143 for r in 0..self.rows {
144 for c in 0..other.cols {
145 let mut sum = 0.0;
146 for k in 0..self.cols {
147 sum += self.get(r, k) * other.get(k, c);
148 }
149 result.set(r, c, sum);
150 }
151 }
152 result
153 }
154
155 /// Multiplies this matrix by a vector (treating vector as column matrix).
156 ///
157 /// Computes `self * vec` where vec is treated as an n×1 column matrix.
158 ///
159 /// # Panics
160 ///
161 /// Panics if `self.cols != vec.len()`.
162 ///
163 /// # Arguments
164 ///
165 /// * `vec` - Vector to multiply by
166 pub fn mul_vec(&self, vec: &[f64]) -> Vec<f64> {
167 assert_eq!(self.cols, vec.len(), "Dimension mismatch for matrix-vector multiplication");
168 let mut result = vec![0.0; self.rows];
169
170 for r in 0..self.rows {
171 let mut sum = 0.0;
172 for c in 0..self.cols {
173 sum += self.get(r, c) * vec[c];
174 }
175 result[r] = sum;
176 }
177 result
178 }
179}
180
181// ============================================================================
182// QR Decomposition
183// ============================================================================
184
185impl Matrix {
186 /// Computes the QR decomposition using Householder reflections.
187 ///
188 /// Factorizes the matrix as `A = QR` where Q is orthogonal and R is upper triangular.
189 ///
190 /// # Requirements
191 ///
192 /// This implementation requires `rows >= cols` (tall matrix). For OLS regression,
193 /// we always have more observations than predictors, so this requirement is satisfied.
194 ///
195 /// # Returns
196 ///
197 /// A tuple `(Q, R)` where:
198 /// - `Q` is an orthogonal matrix (QᵀQ = I) of size m×m
199 /// - `R` is an upper triangular matrix of size m×n
200 pub fn qr(&self) -> (Matrix, Matrix) {
201 let m = self.rows;
202 let n = self.cols;
203 let mut q = Matrix::identity(m);
204 let mut r = self.clone();
205
206 for k in 0..n.min(m - 1) {
207 // Create vector x = R[k:, k]
208 let mut x = vec![0.0; m - k];
209 for i in k..m {
210 x[i - k] = r.get(i, k);
211 }
212
213 // Norm of x
214 let norm_x: f64 = x.iter().map(|&v| v * v).sum::<f64>().sqrt();
215 if norm_x < QR_ZERO_TOLERANCE { continue; } // Already zero
216
217 // Create vector v = x + sign(x[0]) * ||x|| * e1
218 //
219 // NOTE: Numerical stability consideration (Householder sign choice)
220 // According to Overton & Yu (2023), the numerically stable choice is
221 // σ = -sgn(x₁) in the formula v = x - σ‖x‖e₁.
222 //
223 // This means: v = x - (-sgn(x₁))‖x‖e₁ = x + sgn(x₁)‖x‖e₁
224 //
225 // Equivalently: u₁ = x₁ + sgn(x₁)‖x‖
226 //
227 // Current implementation uses this formula (the "correct" choice for stability):
228 let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 }; // sgn(x₀) as defined (sgn(0) = +1)
229 let u1 = x[0] + sign * norm_x;
230
231 // Normalize v to get Householder vector
232 let mut v = x; // Re-use storage
233 v[0] = u1;
234
235 let norm_v: f64 = v.iter().map(|&val| val * val).sum::<f64>().sqrt();
236 for val in &mut v { *val /= norm_v; }
237
238 // Apply Householder transformation to R: R = H * R = (I - 2vv^T)R = R - 2v(v^T R)
239 // Focus on submatrix R[k:, k:]
240 for j in k..n {
241 let mut dot = 0.0;
242 for i in 0..m-k {
243 dot += v[i] * r.get(k+i, j);
244 }
245
246 for i in 0..m-k {
247 let val = r.get(k+i, j) - 2.0 * v[i] * dot;
248 r.set(k+i, j, val);
249 }
250 }
251
252 // Update Q: Q = Q * H = Q(I - 2vv^T) = Q - 2(Qv)v^T
253 // Focus on Q[:, k:]
254 for i in 0..m {
255 let mut dot = 0.0;
256 for l in 0..m-k {
257 dot += q.get(i, k+l) * v[l];
258 }
259
260 for l in 0..m-k {
261 let val = q.get(i, k+l) - 2.0 * dot * v[l];
262 q.set(i, k+l, val);
263 }
264 }
265 }
266
267 (q, r)
268 }
269
270 /// Creates an identity matrix of the given size.
271 ///
272 /// # Arguments
273 ///
274 /// * `size` - Number of rows and columns (square matrix)
275 pub fn identity(size: usize) -> Self {
276 let mut data = vec![0.0; size * size];
277 for i in 0..size {
278 data[i * size + i] = 1.0;
279 }
280 Matrix::new(size, size, data)
281 }
282
283 /// Inverts an upper triangular matrix (such as R from QR decomposition).
284 ///
285 /// Uses back-substitution to compute the inverse. This is efficient for
286 /// triangular matrices compared to general matrix inversion.
287 ///
288 /// # Panics
289 ///
290 /// Panics if the matrix is not square.
291 ///
292 /// # Returns
293 ///
294 /// `None` if the matrix is singular (has a zero or near-zero diagonal element).
295 /// A matrix is considered singular if any diagonal element is below the
296 /// internal tolerance (1e-10), which indicates the matrix does not have full rank.
297 ///
298 /// # Note
299 ///
300 /// For upper triangular matrices, singularity is equivalent to having a
301 /// zero (or near-zero) diagonal element. This is much simpler to check than
302 /// for general matrices, which would require computing the condition number.
303 pub fn invert_upper_triangular(&self) -> Option<Matrix> {
304 let n = self.rows;
305 assert_eq!(n, self.cols, "Matrix must be square");
306
307 // Check for singularity using relative tolerance
308 // This scales with the magnitude of diagonal elements, handling matrices
309 // of different scales better than a fixed absolute tolerance.
310 //
311 // Previous implementation used absolute tolerance:
312 // if self.get(i, i).abs() < SINGULAR_TOLERANCE { return None; }
313 //
314 // New implementation uses relative tolerance similar to LAPACK:
315 // tolerance = max_diag * epsilon * n
316 // where epsilon is machine epsilon (~2.2e-16 for f64)
317 let max_diag: f64 = (0..n)
318 .map(|i| self.get(i, i).abs())
319 .fold(0.0_f64, |acc, val| acc.max(val));
320
321 // Use a relative tolerance based on the maximum diagonal element
322 // This is similar to LAPACK's dlamch machine epsilon approach
323 let epsilon = 2.0_f64 * f64::EPSILON; // ~4.4e-16 for f64
324 let relative_tolerance = max_diag * epsilon * n as f64;
325 let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);
326
327 for i in 0..n {
328 if self.get(i, i).abs() < tolerance {
329 return None; // Singular matrix - cannot invert
330 }
331 }
332
333 let mut inv = Matrix::zeros(n, n);
334
335 for i in 0..n {
336 inv.set(i, i, 1.0 / self.get(i, i));
337
338 for j in (0..i).rev() {
339 let mut sum = 0.0;
340 for k in j+1..=i {
341 sum += self.get(j, k) * inv.get(k, i);
342 }
343 inv.set(j, i, -sum / self.get(j, j));
344 }
345 }
346
347 Some(inv)
348 }
349
350 /// Inverts an upper triangular matrix with a custom tolerance multiplier.
351 ///
352 /// The tolerance is computed as `max_diag * epsilon * n * tolerance_mult`.
353 /// A higher tolerance_mult allows more tolerance for near-singular matrices.
354 ///
355 /// # Arguments
356 ///
357 /// * `tolerance_mult` - Multiplier for the tolerance (1.0 = standard, higher = more tolerant)
358 pub fn invert_upper_triangular_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
359 let n = self.rows;
360 assert_eq!(n, self.cols, "Matrix must be square");
361
362 // Check for singularity using relative tolerance
363 let max_diag: f64 = (0..n)
364 .map(|i| self.get(i, i).abs())
365 .fold(0.0_f64, |acc, val| acc.max(val));
366
367 // Use a relative tolerance based on the maximum diagonal element
368 let epsilon = 2.0_f64 * f64::EPSILON;
369 let relative_tolerance = max_diag * epsilon * n as f64 * tolerance_mult;
370 let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);
371
372 for i in 0..n {
373 if self.get(i, i).abs() < tolerance {
374 return None;
375 }
376 }
377
378 let mut inv = Matrix::zeros(n, n);
379
380 for i in 0..n {
381 inv.set(i, i, 1.0 / self.get(i, i));
382
383 for j in (0..i).rev() {
384 let mut sum = 0.0;
385 for k in j+1..=i {
386 sum += self.get(j, k) * inv.get(k, i);
387 }
388 inv.set(j, i, -sum / self.get(j, j));
389 }
390 }
391
392 Some(inv)
393 }
394
395 /// Computes the inverse of a square matrix using QR decomposition.
396 ///
397 /// For an invertible matrix A, computes A⁻¹ such that A * A⁻¹ = I.
398 /// Uses QR decomposition for numerical stability.
399 ///
400 /// # Panics
401 ///
402 /// Panics if the matrix is not square (i.e., `self.rows != self.cols`).
403 /// Check dimensions before calling if the matrix shape is not guaranteed.
404 ///
405 /// # Returns
406 ///
407 /// Returns `Some(inverse)` if the matrix is invertible, or `None` if
408 /// the matrix is singular (non-invertible).
409 pub fn invert(&self) -> Option<Matrix> {
410 let n = self.rows;
411 if n != self.cols {
412 panic!("Matrix must be square for inversion");
413 }
414
415 // Use QR decomposition: A = Q * R
416 let (q, r) = self.qr();
417
418 // Compute R⁻¹ (upper triangular inverse)
419 let r_inv = r.invert_upper_triangular()?;
420
421 // A⁻¹ = R⁻¹ * Q^T
422 let q_transpose = q.transpose();
423 let mut result = Matrix::zeros(n, n);
424
425 for i in 0..n {
426 for j in 0..n {
427 let mut sum = 0.0;
428 for k in 0..n {
429 sum += r_inv.get(i, k) * q_transpose.get(k, j);
430 }
431 result.set(i, j, sum);
432 }
433 }
434
435 Some(result)
436 }
437
438 /// Computes the inverse of X'X given the QR decomposition of X (R's chol2inv).
439 ///
440 /// This is equivalent to computing `(X'X)^(-1)` using the QR decomposition of X.
441 /// R's `chol2inv` function is used for numerical stability in recursive residuals.
442 ///
443 /// # Arguments
444 ///
445 /// * `x` - Input matrix (must have rows >= cols)
446 ///
447 /// # Returns
448 ///
449 /// `Some((X'X)^(-1))` if X has full rank, `None` otherwise.
450 ///
451 /// # Algorithm
452 ///
453 /// Given QR decomposition X = QR where R is upper triangular:
454 /// 1. Extract the upper p×p portion of R (denoted R₁)
455 /// 2. Invert R₁ (upper triangular inverse)
456 /// 3. Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
457 ///
458 /// This works because X'X = R'Q'QR = R'R, and R₁ contains the Cholesky factor.
459 pub fn chol2inv_from_qr(&self) -> Option<Matrix> {
460 self.chol2inv_from_qr_with_tolerance(1.0)
461 }
462
463 /// Computes the inverse of X'X given the QR decomposition with custom tolerance.
464 ///
465 /// Similar to `chol2inv_from_qr` but allows specifying a tolerance multiplier
466 /// for handling near-singular matrices.
467 ///
468 /// # Arguments
469 ///
470 /// * `tolerance_mult` - Multiplier for the tolerance (higher = more tolerant)
471 pub fn chol2inv_from_qr_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
472 let p = self.cols;
473
474 // QR decomposition: X = QR
475 // For X (m×n, m≥n), R is m×n upper triangular
476 // The upper n×n block of R contains the meaningful values
477 let (_, r_full) = self.qr();
478
479 // Extract upper p×p portion from R
480 // For tall matrices (m > p), R has zeros below diagonal in first p rows
481 // For square matrices (m = p), R is p×p upper triangular
482 let mut r1 = Matrix::zeros(p, p);
483 for i in 0..p {
484 // Row i of R1 is row i of R_full, columns 0..p
485 // But we only copy the upper triangular part (columns i..p)
486 for j in i..p {
487 r1.set(i, j, r_full.get(i, j));
488 }
489 // Also copy diagonal if not yet copied
490 if i < p {
491 r1.set(i, i, r_full.get(i, i));
492 }
493 }
494
495 // Invert R₁ (upper triangular) with custom tolerance
496 let r1_inv = r1.invert_upper_triangular_with_tolerance(tolerance_mult)?;
497
498 // Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
499 let mut result = Matrix::zeros(p, p);
500 for i in 0..p {
501 for j in 0..p {
502 let mut sum = 0.0;
503 // result[i,j] = sum(R1_inv[i,k] * R1_inv[j,k] for k=0..p)
504 // R1_inv is upper triangular, but we iterate full range
505 for k in 0..p {
506 sum += r1_inv.get(i, k) * r1_inv.get(j, k);
507 }
508 result.set(i, j, sum);
509 }
510 }
511
512 Some(result)
513 }
514}
515
516// ============================================================================
517// Vector Helper Functions
518// ============================================================================
519
/// Returns the arithmetic mean of `v`, or `0.0` when `v` is empty.
///
/// # Arguments
///
/// * `v` - Values to average
pub fn vec_mean(v: &[f64]) -> f64 {
    match v.len() {
        0 => 0.0,
        len => v.iter().sum::<f64>() / len as f64,
    }
}
531
/// Computes element-wise subtraction of two slices: `a - b`.
///
/// # Arguments
///
/// * `a` - Minuend slice
/// * `b` - Subtrahend slice
///
/// # Panics
///
/// Panics if the slices have different lengths. Without the explicit check, `zip`
/// would silently truncate to the shorter slice and hide the caller's bug.
pub fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "Slices must have the same length for subtraction");
    a.iter().zip(b).map(|(x, y)| x - y).collect()
}
545
/// Computes the dot product of two slices: `Σ(a[i] * b[i])`.
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if the slices have different lengths. Without the explicit check, `zip`
/// would silently truncate to the shorter slice and return a wrong result.
pub fn vec_dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "Slices must have the same length for dot product");
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
559
560// ============================================================================
561// R-Compatible QR Decomposition (LINPACK dqrdc2 with Column Pivoting)
562// ============================================================================
563
564/// QR decomposition result using R's LINPACK dqrdc2 algorithm.
565///
566/// This implements the QR decomposition with column pivoting as used by R's
567/// `qr()` function with `LAPACK=FALSE`. The algorithm is a modification of
568/// LINPACK's DQRDC that:
569/// - Uses Householder transformations
570/// - Implements limited column pivoting based on 2-norms of reduced columns
571/// - Moves columns with near-zero norm to the right-hand edge
572/// - Computes the rank (number of linearly independent columns)
573///
574/// # Fields
575///
576/// * `qr` - The QR factorization (upper triangle contains R, below diagonal
577/// contains Householder vector information)
578/// * `qraux` - Auxiliary information for recovering the orthogonal part Q
579/// * `pivot` - Column permutation: `pivot\[j\]` contains the original column index
580/// now in column j
581/// * `rank` - Number of linearly independent columns (the computed rank)
582#[derive(Clone, Debug)]
583pub struct QRLinpack {
584 /// QR factorization matrix (same dimensions as input)
585 pub qr: Matrix,
586 /// Auxiliary information for Q recovery
587 pub qraux: Vec<f64>,
588 /// Column pivot vector (1-based indices like R)
589 pub pivot: Vec<usize>,
590 /// Computed rank (number of linearly independent columns)
591 pub rank: usize,
592}
593
594impl Matrix {
595 /// Computes QR decomposition using R's LINPACK dqrdc2 algorithm with column pivoting.
596 ///
597 /// This is a port of R's dqrdc2.f, which is a modification of LINPACK's DQRDC.
598 /// The algorithm:
599 /// 1. Uses Householder transformations for QR factorization
600 /// 2. Implements limited column pivoting based on column 2-norms
601 /// 3. Moves columns with near-zero norm to the right-hand edge
602 /// 4. Computes the rank (number of linearly independent columns)
603 ///
604 /// # Arguments
605 ///
606 /// * `tol` - Tolerance for determining linear independence. Default is 1e-7 (R's default).
607 /// Columns with norm < tol * original_norm are considered negligible.
608 ///
609 /// # Returns
610 ///
611 /// A [`QRLinpack`] struct containing the QR factorization, auxiliary information,
612 /// pivot vector, and computed rank.
613 ///
614 /// # Algorithm Details
615 ///
616 /// The decomposition is A * P = Q * R where:
617 /// - P is the permutation matrix coded by `pivot`
618 /// - Q is orthogonal (m × m)
619 /// - R is upper triangular in the first `rank` rows
620 ///
621 /// The `qr` matrix contains:
622 /// - Upper triangle: R matrix (if pivoting was performed, this is R of the permuted matrix)
623 /// - Below diagonal: Householder vector information
624 ///
625 /// # Reference
626 ///
627 /// - R source: src/appl/dqrdc2.f
628 /// - LINPACK documentation: <https://www.netlib.org/linpack/dqrdc.f>
629 pub fn qr_linpack(&self, tol: Option<f64>) -> QRLinpack {
630 let n = self.rows;
631 let p = self.cols;
632 let lup = n.min(p);
633
634 // Default tolerance matches R's qr.default: tol = 1e-07
635 let tol = tol.unwrap_or(1e-07);
636
637 // Initialize working matrices
638 let mut x = self.clone(); // Working copy that will be modified
639 let mut qraux = vec![0.0; p];
640 let mut pivot: Vec<usize> = (1..=p).collect(); // 1-based indices like R
641 let mut work = vec![(0.0, 0.0); p]; // (work[j,1], work[j,2])
642
643 // Compute the norms of the columns of x (initialization)
644 if n > 0 {
645 for j in 0..p {
646 let mut norm = 0.0;
647 for i in 0..n {
648 norm += x.get(i, j) * x.get(i, j);
649 }
650 norm = norm.sqrt();
651 qraux[j] = norm;
652 let original_norm = if norm == 0.0 { 1.0 } else { norm };
653 work[j] = (norm, original_norm);
654 }
655 }
656
657 let mut k = p + 1; // Will be decremented to get the final rank
658
659 // Perform the Householder reduction of x
660 for l in 0..lup {
661 // Cycle columns from l to p until one with non-negligible norm is found
662 // A column is negligible if its norm has fallen below tol * original_norm
663 while l < k - 1 && qraux[l] < work[l].1 * tol {
664 // Move column l to the end (it's negligible)
665 let lp1 = l + 1;
666
667 // Shift columns in x: x(i, l..p-1) = x(i, l+1..p)
668 for i in 0..n {
669 let t = x.get(i, l);
670 for j in lp1..p {
671 x.set(i, j - 1, x.get(i, j));
672 }
673 x.set(i, p - 1, t);
674 }
675
676 // Shift pivot, qraux, and work arrays
677 let saved_pivot = pivot[l];
678 let saved_qraux = qraux[l];
679 let saved_work = work[l];
680
681 for j in lp1..p {
682 pivot[j - 1] = pivot[j];
683 qraux[j - 1] = qraux[j];
684 work[j - 1] = work[j];
685 }
686
687 pivot[p - 1] = saved_pivot;
688 qraux[p - 1] = saved_qraux;
689 work[p - 1] = saved_work;
690
691 k -= 1;
692 }
693
694 if l == n - 1 {
695 // Last row - skip transformation
696 break;
697 }
698
699 // Compute the Householder transformation for column l
700 // nrmxl = norm of x[l:, l]
701 let mut nrmxl = 0.0;
702 for i in l..n {
703 let val = x.get(i, l);
704 nrmxl += val * val;
705 }
706 nrmxl = nrmxl.sqrt();
707
708 if nrmxl == 0.0 {
709 // Zero column - continue to next
710 continue;
711 }
712
713 // Apply sign for numerical stability (dsign in Fortran)
714 let x_ll = x.get(l, l);
715 if x_ll != 0.0 {
716 nrmxl = nrmxl.copysign(x_ll);
717 }
718
719 // Scale the column
720 let scale = 1.0 / nrmxl;
721 for i in l..n {
722 x.set(i, l, x.get(i, l) * scale);
723 }
724 x.set(l, l, 1.0 + x.get(l, l));
725
726 // Apply the transformation to remaining columns, updating the norms
727 let lp1 = l + 1;
728 if p > lp1 {
729 for j in lp1..p {
730 // Compute t = -dot(x[l:, l], x[l:, j]) / x(l, l)
731 let mut dot = 0.0;
732 for i in l..n {
733 dot += x.get(i, l) * x.get(i, j);
734 }
735 let t = -dot / x.get(l, l);
736
737 // x[l:, j] = x[l:, j] + t * x[l:, l]
738 for i in l..n {
739 let val = x.get(i, j) + t * x.get(i, l);
740 x.set(i, j, val);
741 }
742
743 // Update the norm
744 if qraux[j] != 0.0 {
745 // tt = 1.0 - (x(l, j) / qraux[j])^2
746 let x_lj = x.get(l, j).abs();
747 let mut tt = 1.0 - (x_lj / qraux[j]).powi(2);
748 tt = tt.max(0.0);
749
750 // Recompute norm if there is large reduction (BDR mod 9/99)
751 // The tolerance here is on the squared norm
752 if tt.abs() < 1e-6 {
753 // Re-compute norm directly
754 let mut new_norm = 0.0;
755 for i in (l + 1)..n {
756 let val = x.get(i, j);
757 new_norm += val * val;
758 }
759 new_norm = new_norm.sqrt();
760 qraux[j] = new_norm;
761 work[j].0 = new_norm;
762 } else {
763 qraux[j] = qraux[j] * tt.sqrt();
764 }
765 }
766 }
767 }
768
769 // Save the transformation
770 qraux[l] = x.get(l, l);
771 x.set(l, l, -nrmxl);
772 }
773
774 // Compute final rank
775 let rank = k - 1;
776 let rank = rank.min(n);
777
778 QRLinpack {
779 qr: x,
780 qraux,
781 pivot,
782 rank,
783 }
784 }
785
786 /// Solves a linear system using the QR decomposition with column pivoting.
787 ///
788 /// This implements a least squares solver using the pivoted QR decomposition.
789 /// For rank-deficient cases, coefficients corresponding to linearly dependent
790 /// columns are set to `f64::NAN`.
791 ///
792 /// # Arguments
793 ///
794 /// * `qr_result` - QR decomposition from [`Matrix::qr_linpack`]
795 /// * `y` - Right-hand side vector
796 ///
797 /// # Returns
798 ///
799 /// A vector of coefficients, or `None` if the system is exactly singular.
800 ///
801 /// # Algorithm
802 ///
803 /// This solver uses the standard QR decomposition approach:
804 /// 1. Compute the QR decomposition of the permuted matrix
805 /// 2. Extract R matrix (upper triangular with positive diagonal)
806 /// 3. Compute qty = Q^T * y
807 /// 4. Solve R * coef = qty using back substitution
808 /// 5. Apply the pivot permutation to restore original column order
809 ///
810 /// # Note
811 ///
812 /// The LINPACK QR algorithm stores R with mixed signs on the diagonal.
813 /// This solver corrects for that by taking the absolute value of R's diagonal.
814 pub fn qr_solve_linpack(&self, qr_result: &QRLinpack, y: &[f64]) -> Option<Vec<f64>> {
815 let n = self.rows;
816 let p = self.cols;
817 let k = qr_result.rank;
818
819 if y.len() != n {
820 return None;
821 }
822
823 if k == 0 {
824 return None;
825 }
826
827 // Step 1: Compute Q^T * y using the Householder vectors directly
828 // This is more efficient than reconstructing the full Q matrix
829 let mut qty = y.to_vec();
830
831 for j in 0..k {
832 // Check if this Householder transformation is valid
833 let r_jj = qr_result.qr.get(j, j);
834 if r_jj == 0.0 {
835 continue;
836 }
837
838 // Compute dot = v_j^T * qty[j:]
839 // where v_j is the Householder vector stored in qr[j:, j]
840 // The storage convention:
841 // - qr[j,j] = -nrmxl (after final overwrite)
842 // - qr[i,j] for i > j is the scaled Householder vector element
843 // - qraux[j] = 1 + original_x[j,j]/nrmxl (the unscaled first element)
844
845 // Reconstruct the Householder vector v_j
846 // After scaling by 1/nrmxl, we have:
847 // v_scaled[j] = 1 + x[j,j]/nrmxl
848 // v_scaled[i] = x[i,j]/nrmxl for i > j
849 // The actual unit vector is v = v_scaled / ||v_scaled||
850
851 let mut v = vec![0.0; n - j];
852 // Copy the scaled Householder vector from qr
853 for i in j..n {
854 v[i - j] = qr_result.qr.get(i, j);
855 }
856
857 // The j-th element was modified during the QR decomposition
858 // We need to reconstruct it from qraux
859 let alpha = qr_result.qraux[j];
860 if alpha != 0.0 {
861 v[0] = alpha;
862 }
863
864 // Compute the norm of v
865 let v_norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
866 if v_norm < 1e-14 {
867 continue;
868 }
869
870 // Compute dot = v^T * qty[j:]
871 let mut dot = 0.0;
872 for i in j..n {
873 dot += v[i - j] * qty[i];
874 }
875
876 // Apply Householder transformation: qty[j:] = qty[j:] - 2 * v * (v^T * qty[j:]) / (v^T * v)
877 // Since v is already scaled, we use: t = 2 * dot / (v_norm^2)
878 let t = 2.0 * dot / (v_norm * v_norm);
879
880 for i in j..n {
881 qty[i] -= t * v[i - j];
882 }
883 }
884
885 // Step 2: Back substitution on R (solve R * coef = qty)
886 // The R matrix is stored in the upper triangle of qr
887 // Note: The diagonal elements of R are negative (from -nrmxl)
888 // We use them as-is since the signs cancel out in the computation
889 let mut coef_permuted = vec![f64::NAN; p];
890
891 for row in (0..k).rev() {
892 let r_diag = qr_result.qr.get(row, row);
893 // Use relative tolerance for singularity check
894 let max_abs = (0..k).map(|i| qr_result.qr.get(i, i).abs()).fold(0.0_f64, f64::max);
895 let tolerance = 1e-14 * max_abs.max(1.0);
896
897 if r_diag.abs() < tolerance {
898 return None; // Singular
899 }
900
901 let mut sum = qty[row];
902 for col in (row + 1)..k {
903 sum -= qr_result.qr.get(row, col) * coef_permuted[col];
904 }
905 coef_permuted[row] = sum / r_diag;
906 }
907
908 // Step 3: Apply pivot permutation to get coefficients in original order
909 // pivot[j] is 1-based, indicating which original column is now in position j
910 let mut result = vec![0.0; p];
911 for j in 0..p {
912 let original_col = qr_result.pivot[j] - 1; // Convert to 0-based
913 result[original_col] = coef_permuted[j];
914 }
915
916 Some(result)
917 }
918}
919
920/// Performs OLS regression using R's LINPACK QR algorithm.
921///
922/// This function is a drop-in replacement for `fit_ols` that uses the
923/// R-compatible QR decomposition with column pivoting. It handles
924/// rank-deficient matrices more gracefully than the standard QR decomposition.
925///
926/// # Arguments
927///
928/// * `y` - Response variable (n observations)
929/// * `x` - Design matrix (n rows, p columns including intercept)
930///
931/// # Returns
932///
933/// * `Some(Vec<f64>)` - OLS coefficient vector (p elements)
934/// * `None` - If the matrix is exactly singular or dimensions don't match
935///
936/// # Note
937///
938/// For rank-deficient systems, this function uses the pivoted QR which
939/// automatically handles multicollinearity by selecting a linearly
940/// independent subset of columns.
941pub fn fit_ols_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
942 let qr_result = x.qr_linpack(None);
943 x.qr_solve_linpack(&qr_result, y)
944}
945
946/// Fits OLS and predicts using R's LINPACK QR with rank-deficient handling.
947///
948/// This function matches R's `lm.fit` behavior for rank-deficient cases:
949/// coefficients for linearly dependent columns are set to NA, and predictions
950/// are computed using only the valid (non-NA) coefficients and their corresponding
951/// columns. This matches how R handles rank-deficient models in prediction.
952///
953/// # Arguments
954///
955/// * `y` - Response variable (n observations)
956/// * `x` - Design matrix (n rows, p columns including intercept)
957///
958/// # Returns
959///
960/// * `Some(Vec<f64>)` - Predictions (n elements)
961/// * `None` - If the matrix is exactly singular or dimensions don't match
962///
963/// # Algorithm
964///
965/// For rank-deficient systems (rank < p):
966/// 1. Compute QR decomposition with column pivoting
967/// 2. Get coefficients (rank-deficient columns will have NaN)
968/// 3. Build a reduced design matrix with only pivoted, non-singular columns
969/// 4. Compute predictions using only the valid columns
970///
971/// This matches R's behavior where `predict(lm.fit(...))` handles NA coefficients
972/// by excluding the corresponding columns from the prediction.
973pub fn fit_and_predict_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
974 let n = x.rows;
975 let p = x.cols;
976
977 // Compute QR decomposition
978 let qr_result = x.qr_linpack(None);
979 let k = qr_result.rank;
980
981 // Solve for coefficients
982 let beta_permuted = x.qr_solve_linpack(&qr_result, y)?;
983
984 // Check for rank deficiency
985 if k == p {
986 // Full rank - use standard prediction
987 return Some(x.mul_vec(&beta_permuted));
988 }
989
990 // Rank-deficient case: some columns are collinear and have NaN coefficients
991 // We compute predictions using only columns with valid (non-NaN) coefficients
992 // This matches R's behavior where NA coefficients exclude columns from prediction
993
994 let mut pred = vec![0.0; n];
995
996 for row in 0..n {
997 let mut sum = 0.0;
998 for j in 0..p {
999 let b_val = beta_permuted[j];
1000 if b_val.is_nan() {
1001 continue; // Skip collinear columns (matches R's NA coefficient behavior)
1002 }
1003 sum += x.get(row, j) * b_val;
1004 }
1005 pred[row] = sum;
1006 }
1007
1008 Some(pred)
1009}