linreg_core/linalg.rs
//! Minimal linear algebra module to replace the nalgebra dependency.
//!
//! Implements matrix operations, QR decomposition, and solvers needed for OLS.
//! Uses row-major storage for compatibility with statistical computing conventions.
//!
//! # Numerical Stability Considerations
//!
//! This implementation uses Householder QR decomposition with careful attention to
//! numerical stability:
//!
//! - **Sign convention**: Uses the numerically stable Householder sign choice
//!   (v = x + sgn(x₀)||x||e₁) to avoid cancellation
//! - **Tolerance checking**: Uses predefined tolerances to detect near-singular matrices
//! - **Zero-skipping**: Skips transformations when columns are already zero-aligned
//!
//! # Scaling Recommendations
//!
//! For optimal numerical stability when predictor variables have vastly different
//! scales (e.g., one variable in millions, another in thousandths), consider
//! standardizing predictors before regression. Z-score standardization
//! (`x_scaled = (x - mean) / std`) is already applied in the VIF calculation.
//!
//! However, the current implementation handles typical OLS cases without explicit
//! scaling, as QR decomposition is generally stable for well-conditioned matrices.
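//!
//! As an illustrative sketch (not a function exported by this module), predictors
//! can be standardized before building the design matrix:
//!
//! ```
//! # use linreg_core::linalg::vec_mean;
//! let x = vec![1.0, 2.0, 3.0, 4.0];
//! let mean = vec_mean(&x);
//! let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (x.len() as f64 - 1.0);
//! let std = var.sqrt();
//! let x_scaled: Vec<f64> = x.iter().map(|v| (v - mean) / std).collect();
//! assert!(vec_mean(&x_scaled).abs() < 1e-12);
//! ```
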
// ============================================================================
// Numerical Constants
// ============================================================================

/// Machine epsilon threshold for detecting zero values in QR decomposition.
/// Values below this are treated as zero to avoid numerical instability.
const QR_ZERO_TOLERANCE: f64 = 1e-12;

/// Threshold for detecting singular matrices during inversion.
/// Diagonal elements below this value indicate a near-singular matrix.
const SINGULAR_TOLERANCE: f64 = 1e-10;

/// A dense matrix stored in row-major order.
///
/// # Storage
///
/// Elements are stored in a single flat vector in row-major order:
/// `data[row * cols + col]`
///
/// # Example
///
/// ```
/// # use linreg_core::linalg::Matrix;
/// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// assert_eq!(m.rows, 2);
/// assert_eq!(m.cols, 3);
/// assert_eq!(m.get(0, 0), 1.0);
/// assert_eq!(m.get(1, 2), 6.0);
/// ```
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Number of rows in the matrix
    pub rows: usize,
    /// Number of columns in the matrix
    pub cols: usize,
    /// Flat vector storing matrix elements in row-major order
    pub data: Vec<f64>,
}

impl Matrix {
    /// Creates a new matrix from the given dimensions and data.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rows * cols`.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `data` - Flat vector of elements in row-major order
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// assert_eq!(m.get(0, 0), 1.0);
    /// assert_eq!(m.get(0, 1), 2.0);
    /// assert_eq!(m.get(1, 0), 3.0);
    /// assert_eq!(m.get(1, 1), 4.0);
    /// ```
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols, "Data length must match dimensions");
        Matrix { rows, cols, data }
    }

    /// Creates a matrix filled with zeros.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::zeros(3, 2);
    /// assert_eq!(m.rows, 3);
    /// assert_eq!(m.cols, 2);
    /// assert_eq!(m.get(1, 1), 0.0);
    /// ```
    pub fn zeros(rows: usize, cols: usize) -> Self {
        Matrix {
            rows,
            cols,
            data: vec![0.0; rows * cols],
        }
    }

    // NOTE: Currently unused but kept as a reference implementation.
    // Uncomment if needed as a convenience constructor.
    /*
    /// Creates a matrix from a row-major slice.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `slice` - Slice containing matrix elements in row-major order
    pub fn from_row_slice(rows: usize, cols: usize, slice: &[f64]) -> Self {
        Matrix::new(rows, cols, slice.to_vec())
    }
    */

    /// Gets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    pub fn get(&self, row: usize, col: usize) -> f64 {
        self.data[row * self.cols + col]
    }

    /// Sets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    /// * `val` - Value to set
    pub fn set(&mut self, row: usize, col: usize, val: f64) {
        self.data[row * self.cols + col] = val;
    }

    /// Returns the transpose of this matrix.
    ///
    /// Swaps rows with columns: `result[col][row] = self[row][col]`.
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let t = m.transpose();
    /// assert_eq!(t.rows, 3);
    /// assert_eq!(t.cols, 2);
    /// assert_eq!(t.get(0, 1), 4.0);
    /// ```
    pub fn transpose(&self) -> Matrix {
        let mut t_data = vec![0.0; self.rows * self.cols];
        for r in 0..self.rows {
            for c in 0..self.cols {
                t_data[c * self.rows + r] = self.get(r, c);
            }
        }
        Matrix::new(self.cols, self.rows, t_data)
    }

    /// Performs matrix multiplication: `self * other`.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != other.rows`.
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let c = a.matmul(&b);
    /// assert_eq!(c.rows, 2);
    /// assert_eq!(c.cols, 2);
    /// assert_eq!(c.get(0, 0), 22.0); // 1*1 + 2*3 + 3*5
    /// ```
    pub fn matmul(&self, other: &Matrix) -> Matrix {
        assert_eq!(
            self.cols, other.rows,
            "Dimension mismatch for multiplication"
        );
        let mut result = Matrix::zeros(self.rows, other.cols);

        for r in 0..self.rows {
            for c in 0..other.cols {
                let mut sum = 0.0;
                for k in 0..self.cols {
                    sum += self.get(r, k) * other.get(k, c);
                }
                result.set(r, c, sum);
            }
        }
        result
    }

    /// Multiplies this matrix by a vector (treating the vector as a column matrix).
    ///
    /// Computes `self * vec` where `vec` is treated as an n×1 column matrix.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != vec.len()`.
    ///
    /// # Arguments
    ///
    /// * `vec` - Vector to multiply by
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let v = vec![1.0, 2.0, 3.0];
    /// let result = m.mul_vec(&v);
    /// assert_eq!(result.len(), 2);
    /// assert_eq!(result[0], 14.0); // 1*1 + 2*2 + 3*3
    /// ```
    #[allow(clippy::needless_range_loop)]
    pub fn mul_vec(&self, vec: &[f64]) -> Vec<f64> {
        assert_eq!(
            self.cols,
            vec.len(),
            "Dimension mismatch for matrix-vector multiplication"
        );
        let mut result = vec![0.0; self.rows];

        for r in 0..self.rows {
            let mut sum = 0.0;
            for c in 0..self.cols {
                sum += self.get(r, c) * vec[c];
            }
            result[r] = sum;
        }
        result
    }

    /// Computes the dot product of a column with a vector: `Σ(data[i * cols + col] * v[i])`.
    ///
    /// For a row-major matrix, this iterates through all rows at a fixed column.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    /// * `v` - Vector to dot with (must have length equal to rows)
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols` or `v.len() != rows`.
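    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// // Column 1 is [2.0, 4.0]; its dot product with [1.0, 1.0] is 6.0.
    /// assert_eq!(m.col_dot(1, &[1.0, 1.0]), 6.0);
    /// ```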
    #[allow(clippy::needless_range_loop)]
    pub fn col_dot(&self, col: usize, v: &[f64]) -> f64 {
        assert!(col < self.cols, "Column index out of bounds");
        assert_eq!(
            self.rows,
            v.len(),
            "Vector length must match number of rows"
        );

        let mut sum = 0.0;
        for row in 0..self.rows {
            sum += self.get(row, col) * v[row];
        }
        sum
    }

    /// Performs the column-vector operation in place: `v += alpha * column_col`.
    ///
    /// This is the AXPY operation where the column is treated as a vector.
    /// For row-major storage, we iterate through rows at a fixed column.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    /// * `alpha` - Scaling factor for the column
    /// * `v` - Vector to modify in place (must have length equal to rows)
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols` or `v.len() != rows`.
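    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// let mut v = vec![1.0, 1.0];
    /// // v += 2.0 * column 0, i.e. [1, 1] + 2 * [1, 3] = [3, 7]
    /// m.col_axpy_inplace(0, 2.0, &mut v);
    /// assert_eq!(v, vec![3.0, 7.0]);
    /// ```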
    #[allow(clippy::needless_range_loop)]
    pub fn col_axpy_inplace(&self, col: usize, alpha: f64, v: &mut [f64]) {
        assert!(col < self.cols, "Column index out of bounds");
        assert_eq!(
            self.rows,
            v.len(),
            "Vector length must match number of rows"
        );

        for row in 0..self.rows {
            v[row] += alpha * self.get(row, col);
        }
    }

    /// Computes the squared L2 norm of a column: `Σ(data[i * cols + col]²)`.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols`.
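    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![3.0, 0.0, 4.0, 0.0]);
    /// // Column 0 is [3.0, 4.0], so its squared norm is 9 + 16 = 25.
    /// assert_eq!(m.col_norm2(0), 25.0);
    /// ```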
    #[allow(clippy::needless_range_loop)]
    pub fn col_norm2(&self, col: usize) -> f64 {
        assert!(col < self.cols, "Column index out of bounds");

        let mut sum = 0.0;
        for row in 0..self.rows {
            let val = self.get(row, col);
            sum += val * val;
        }
        sum
    }

    /// Adds a value to diagonal elements starting from a given index.
    ///
    /// This is useful for ridge regression where we add `lambda * I` to `X^T X`,
    /// but the intercept column should not be penalized.
    ///
    /// # Arguments
    ///
    /// * `alpha` - Value to add to diagonal elements
    /// * `start_index` - Starting diagonal index (0 = first diagonal element)
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    ///
    /// # Example
    ///
    /// For a 3×3 identity matrix with the intercept in the first column (unpenalized):
    /// ```text
    /// add_diagonal_in_place(lambda, 1) on:
    /// [1.0, 0.0, 0.0]    [1.0, 0.0,   0.0  ]
    /// [0.0, 1.0, 0.0] -> [0.0, 1.0+λ, 0.0  ]
    /// [0.0, 0.0, 1.0]    [0.0, 0.0,   1.0+λ]
    /// ```
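    ///
    /// A runnable version of the same diagram:
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let mut m = Matrix::identity(3);
    /// m.add_diagonal_in_place(0.5, 1); // leave the intercept entry at (0, 0) alone
    /// assert_eq!(m.get(0, 0), 1.0);
    /// assert_eq!(m.get(1, 1), 1.5);
    /// assert_eq!(m.get(2, 2), 1.5);
    /// ```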
    pub fn add_diagonal_in_place(&mut self, alpha: f64, start_index: usize) {
        assert_eq!(self.rows, self.cols, "Matrix must be square");
        let n = self.rows;
        for i in start_index..n {
            let current = self.get(i, i);
            self.set(i, i, current + alpha);
        }
    }
}

// ============================================================================
// QR Decomposition
// ============================================================================

impl Matrix {
    /// Computes the QR decomposition using Householder reflections.
    ///
    /// Factorizes the matrix as `A = QR` where Q is orthogonal and R is upper triangular.
    ///
    /// # Requirements
    ///
    /// This implementation requires `rows >= cols` (a tall matrix). For OLS regression,
    /// we always have more observations than predictors, so this requirement is satisfied.
    ///
    /// # Returns
    ///
    /// A tuple `(Q, R)` where:
    /// - `Q` is an orthogonal matrix (QᵀQ = I) of size m×m
    /// - `R` is an upper triangular matrix of size m×n
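    ///
    /// # Example
    ///
    /// A quick check that the factors reproduce the input (illustrative, 3×2):
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
    /// let (q, r) = a.qr();
    /// let qr = q.matmul(&r);
    /// for i in 0..3 {
    ///     for j in 0..2 {
    ///         assert!((qr.get(i, j) - a.get(i, j)).abs() < 1e-10);
    ///     }
    /// }
    /// ```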
    #[allow(clippy::needless_range_loop)]
    pub fn qr(&self) -> (Matrix, Matrix) {
        let m = self.rows;
        let n = self.cols;
        let mut q = Matrix::identity(m);
        let mut r = self.clone();

        for k in 0..n.min(m - 1) {
            // Create vector x = R[k:, k]
            let mut x = vec![0.0; m - k];
            for i in k..m {
                x[i - k] = r.get(i, k);
            }

            // Norm of x
            let norm_x: f64 = x.iter().map(|&v| v * v).sum::<f64>().sqrt();
            if norm_x < QR_ZERO_TOLERANCE {
                continue; // Column is already zero
            }

            // Create vector v = x + sign(x[0]) * ||x|| * e1
            //
            // NOTE: Numerical stability consideration (Householder sign choice).
            // According to Overton & Yu (2023), the numerically stable choice is
            // σ = -sgn(x₁) in the formula v = x - σ‖x‖e₁.
            //
            // This means: v = x - (-sgn(x₁))‖x‖e₁ = x + sgn(x₁)‖x‖e₁
            //
            // Equivalently: u₁ = x₁ + sgn(x₁)‖x‖
            //
            // The implementation below uses this formula (the stable choice):
            let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 }; // sgn(x₀) with sgn(0) = +1
            let u1 = x[0] + sign * norm_x;

            // Normalize v to get the Householder vector
            let mut v = x; // Re-use storage
            v[0] = u1;

            let norm_v: f64 = v.iter().map(|&val| val * val).sum::<f64>().sqrt();
            for val in &mut v {
                *val /= norm_v;
            }

            // Apply the Householder transformation to R: R = H * R = (I - 2vvᵀ)R = R - 2v(vᵀR)
            // Focus on the submatrix R[k:, k:]
            for j in k..n {
                let mut dot = 0.0;
                for i in 0..m - k {
                    dot += v[i] * r.get(k + i, j);
                }

                for i in 0..m - k {
                    let val = r.get(k + i, j) - 2.0 * v[i] * dot;
                    r.set(k + i, j, val);
                }
            }

            // Update Q: Q = Q * H = Q(I - 2vvᵀ) = Q - 2(Qv)vᵀ
            // Focus on Q[:, k:]
            for i in 0..m {
                let mut dot = 0.0;
                for l in 0..m - k {
                    dot += q.get(i, k + l) * v[l];
                }

                for l in 0..m - k {
                    let val = q.get(i, k + l) - 2.0 * dot * v[l];
                    q.set(i, k + l, val);
                }
            }
        }

        (q, r)
    }
457
458 /// Creates an identity matrix of the given size.
459 ///
460 /// # Arguments
461 ///
462 /// * `size` - Number of rows and columns (square matrix)
463 ///
464 /// # Example
465 ///
466 /// ```
467 /// # use linreg_core::linalg::Matrix;
468 /// let i = Matrix::identity(3);
469 /// assert_eq!(i.get(0, 0), 1.0);
470 /// assert_eq!(i.get(1, 1), 1.0);
471 /// assert_eq!(i.get(2, 2), 1.0);
472 /// assert_eq!(i.get(0, 1), 0.0);
473 /// ```
474 pub fn identity(size: usize) -> Self {
475 let mut data = vec![0.0; size * size];
476 for i in 0..size {
477 data[i * size + i] = 1.0;
478 }
479 Matrix::new(size, size, data)
480 }
481
482 /// Inverts an upper triangular matrix (such as R from QR decomposition).
483 ///
484 /// Uses back-substitution to compute the inverse. This is efficient for
485 /// triangular matrices compared to general matrix inversion.
486 ///
487 /// # Panics
488 ///
489 /// Panics if the matrix is not square.
490 ///
491 /// # Returns
492 ///
493 /// `None` if the matrix is singular (has a zero or near-zero diagonal element).
494 /// A matrix is considered singular if any diagonal element is below the
495 /// internal tolerance (1e-10), which indicates the matrix does not have full rank.
496 ///
497 /// # Note
498 ///
499 /// For upper triangular matrices, singularity is equivalent to having a
500 /// zero (or near-zero) diagonal element. This is much simpler to check than
501 /// for general matrices, which would require computing the condition number.
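    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// // R = [2, 1; 0, 4], so R⁻¹ = [0.5, -0.125; 0, 0.25]
    /// let r = Matrix::new(2, 2, vec![2.0, 1.0, 0.0, 4.0]);
    /// let inv = r.invert_upper_triangular().unwrap();
    /// assert!((inv.get(0, 0) - 0.5).abs() < 1e-12);
    /// assert!((inv.get(0, 1) + 0.125).abs() < 1e-12);
    /// assert!((inv.get(1, 1) - 0.25).abs() < 1e-12);
    /// ```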
    pub fn invert_upper_triangular(&self) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance.
        // This scales with the magnitude of the diagonal elements, handling matrices
        // of different scales better than a fixed absolute tolerance.
        //
        // The previous implementation used an absolute tolerance:
        //     if self.get(i, i).abs() < SINGULAR_TOLERANCE { return None; }
        //
        // The current implementation uses a relative tolerance similar to LAPACK:
        //     tolerance = max_diag * epsilon * n
        // where epsilon is twice machine epsilon (~4.4e-16 for f64)
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        // Use a relative tolerance based on the maximum diagonal element.
        // This is similar to LAPACK's dlamch machine-epsilon approach.
        let epsilon = 2.0_f64 * f64::EPSILON; // ~4.4e-16 for f64
        let relative_tolerance = max_diag * epsilon * n as f64;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None; // Singular matrix - cannot invert
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Inverts an upper triangular matrix with a custom tolerance multiplier.
    ///
    /// The tolerance is computed as `max_diag * epsilon * n * tolerance_mult`.
    /// A higher `tolerance_mult` allows more tolerance for near-singular matrices.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the tolerance (1.0 = standard, higher = more tolerant)
    pub fn invert_upper_triangular_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        // Use a relative tolerance based on the maximum diagonal element
        let epsilon = 2.0_f64 * f64::EPSILON;
        let relative_tolerance = max_diag * epsilon * n as f64 * tolerance_mult;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None;
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Computes the inverse of a square matrix using QR decomposition.
    ///
    /// For an invertible matrix A, computes A⁻¹ such that A * A⁻¹ = I.
    /// Uses QR decomposition for numerical stability.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square (i.e., `self.rows != self.cols`).
    /// Check dimensions before calling if the matrix shape is not guaranteed.
    ///
    /// # Returns
    ///
    /// Returns `Some(inverse)` if the matrix is invertible, or `None` if
    /// the matrix is singular (non-invertible).
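    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(2, 2, vec![4.0, 7.0, 2.0, 6.0]);
    /// let inv = a.invert().unwrap();
    /// // A * A⁻¹ should be the identity (up to rounding)
    /// let prod = a.matmul(&inv);
    /// assert!((prod.get(0, 0) - 1.0).abs() < 1e-10);
    /// assert!(prod.get(0, 1).abs() < 1e-10);
    /// assert!((prod.get(1, 1) - 1.0).abs() < 1e-10);
    /// ```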
    pub fn invert(&self) -> Option<Matrix> {
        let n = self.rows;
        if n != self.cols {
            panic!("Matrix must be square for inversion");
        }

        // Use QR decomposition: A = Q * R
        let (q, r) = self.qr();

        // Compute R⁻¹ (upper triangular inverse)
        let r_inv = r.invert_upper_triangular()?;

        // A⁻¹ = R⁻¹ * Qᵀ
        let q_transpose = q.transpose();
        let mut result = Matrix::zeros(n, n);

        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0;
                for k in 0..n {
                    sum += r_inv.get(i, k) * q_transpose.get(k, j);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }

    /// Computes the inverse of X'X given the QR decomposition of X (R's `chol2inv`).
    ///
    /// This is equivalent to computing `(X'X)^(-1)` using the QR decomposition of X,
    /// where X is the matrix itself (`self`, which must have `rows >= cols`). It
    /// mirrors R's `chol2inv` function and is used for numerical stability in
    /// recursive residuals.
    ///
    /// # Returns
    ///
    /// `Some((X'X)^(-1))` if X has full rank, `None` otherwise.
    ///
    /// # Algorithm
    ///
    /// Given the QR decomposition X = QR where R is upper triangular:
    /// 1. Extract the upper p×p portion of R (denoted R₁)
    /// 2. Invert R₁ (upper triangular inverse)
    /// 3. Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
    ///
    /// This works because X'X = R'Q'QR = R'R, so R₁ contains the Cholesky factor.
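    ///
    /// # Example
    ///
    /// Cross-checking against the direct computation `(XᵀX)⁻¹`:
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
    /// let via_qr = x.chol2inv_from_qr().unwrap();
    /// let direct = x.transpose().matmul(&x).invert().unwrap();
    /// for i in 0..2 {
    ///     for j in 0..2 {
    ///         assert!((via_qr.get(i, j) - direct.get(i, j)).abs() < 1e-10);
    ///     }
    /// }
    /// ```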
    pub fn chol2inv_from_qr(&self) -> Option<Matrix> {
        self.chol2inv_from_qr_with_tolerance(1.0)
    }

    /// Computes the inverse of X'X given the QR decomposition, with a custom tolerance.
    ///
    /// Similar to `chol2inv_from_qr` but allows specifying a tolerance multiplier
    /// for handling near-singular matrices.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the tolerance (higher = more tolerant)
    pub fn chol2inv_from_qr_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let p = self.cols;

        // QR decomposition: X = QR
        // For X (m×n, m≥n), R is m×n upper triangular.
        // The upper n×n block of R contains the meaningful values.
        let (_, r_full) = self.qr();

        // Extract the upper p×p portion of R.
        // For tall matrices (m > p), R has zeros below the diagonal in the first p rows.
        // For square matrices (m = p), R is p×p upper triangular.
        let mut r1 = Matrix::zeros(p, p);
        for i in 0..p {
            // Row i of R₁ is row i of r_full; only the upper triangular part
            // (columns i..p, which already includes the diagonal) is nonzero.
            for j in i..p {
                r1.set(i, j, r_full.get(i, j));
            }
        }

        // Invert R₁ (upper triangular) with the custom tolerance
        let r1_inv = r1.invert_upper_triangular_with_tolerance(tolerance_mult)?;

        // Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
        let mut result = Matrix::zeros(p, p);
        for i in 0..p {
            for j in 0..p {
                let mut sum = 0.0;
                // result[i,j] = sum(R1_inv[i,k] * R1_inv[j,k] for k = 0..p)
                // R1_inv is upper triangular, but we iterate the full range
                for k in 0..p {
                    sum += r1_inv.get(i, k) * r1_inv.get(j, k);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }
}

// ============================================================================
// Vector Helper Functions
// ============================================================================

/// Computes the arithmetic mean of a slice of f64 values.
///
/// Returns 0.0 for empty slices.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_mean;
///
/// assert_eq!(vec_mean(&[1.0, 2.0, 3.0, 4.0, 5.0]), 3.0);
/// assert_eq!(vec_mean(&[]), 0.0);
/// ```
///
/// # Arguments
///
/// * `v` - Slice of values
pub fn vec_mean(v: &[f64]) -> f64 {
    if v.is_empty() {
        return 0.0;
    }
    v.iter().sum::<f64>() / v.len() as f64
}

/// Computes element-wise subtraction of two slices: `a - b`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_sub;
///
/// let a = vec![5.0, 4.0, 3.0];
/// let b = vec![1.0, 1.0, 1.0];
/// let result = vec_sub(&a, &b);
/// assert_eq!(result, vec![4.0, 3.0, 2.0]);
/// ```
///
/// # Arguments
///
/// * `a` - Minuend slice
/// * `b` - Subtrahend slice
///
/// # Panics
///
/// Panics if slices have different lengths.
pub fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_sub: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x - y).collect()
}

/// Computes the dot product of two slices: `Σ(a[i] * b[i])`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_dot;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_dot(&a, &b), 32.0); // 1*4 + 2*5 + 3*6
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if slices have different lengths.
pub fn vec_dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "vec_dot: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Computes element-wise addition of two slices: `a + b`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_add;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_add(&a, &b), vec![5.0, 7.0, 9.0]);
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if slices have different lengths.
pub fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_add: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}

/// Computes a scaled vector addition in place: `dst += alpha * src`.
///
/// This is the classic BLAS AXPY operation.
///
/// # Arguments
///
/// * `dst` - Destination slice (modified in place)
/// * `alpha` - Scaling factor for src
/// * `src` - Source slice
///
/// # Panics
///
/// Panics if slices have different lengths.
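///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_axpy_inplace;
///
/// let mut dst = vec![1.0, 2.0];
/// vec_axpy_inplace(&mut dst, 2.0, &[10.0, 20.0]);
/// assert_eq!(dst, vec![21.0, 42.0]);
/// ```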
pub fn vec_axpy_inplace(dst: &mut [f64], alpha: f64, src: &[f64]) {
    assert_eq!(
        dst.len(),
        src.len(),
        "vec_axpy_inplace: slice lengths must match"
    );
    for (d, &s) in dst.iter_mut().zip(src.iter()) {
        *d += alpha * s;
    }
}

/// Scales a vector in place: `v *= alpha`.
///
/// # Arguments
///
/// * `v` - Vector to scale (modified in place)
/// * `alpha` - Scaling factor
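///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_scale_inplace;
///
/// let mut v = vec![1.0, 2.0, 3.0];
/// vec_scale_inplace(&mut v, 2.0);
/// assert_eq!(v, vec![2.0, 4.0, 6.0]);
/// ```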
pub fn vec_scale_inplace(v: &mut [f64], alpha: f64) {
    for val in v.iter_mut() {
        *val *= alpha;
    }
}

/// Returns a scaled copy of a vector: `v * alpha`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_scale;
///
/// let v = vec![1.0, 2.0, 3.0];
/// let scaled = vec_scale(&v, 2.5);
/// assert_eq!(scaled, vec![2.5, 5.0, 7.5]);
/// // The original is unchanged
/// assert_eq!(v, vec![1.0, 2.0, 3.0]);
/// ```
///
/// # Arguments
///
/// * `v` - Vector to scale
/// * `alpha` - Scaling factor
pub fn vec_scale(v: &[f64], alpha: f64) -> Vec<f64> {
    v.iter().map(|&x| x * alpha).collect()
}

/// Computes the L2 norm (Euclidean norm) of a vector: `sqrt(Σ(v[i]²))`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_l2_norm;
///
/// // Pythagorean triple: 3-4-5
/// assert_eq!(vec_l2_norm(&[3.0, 4.0]), 5.0);
/// // Unit vector
/// assert_eq!(vec_l2_norm(&[1.0, 0.0, 0.0]), 1.0);
/// ```
///
/// # Arguments
///
/// * `v` - Vector slice
pub fn vec_l2_norm(v: &[f64]) -> f64 {
    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
}

/// Computes the maximum absolute value in a vector.
///
/// Returns 0.0 for an empty slice.
///
/// # Arguments
///
/// * `v` - Vector slice
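///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_max_abs;
///
/// assert_eq!(vec_max_abs(&[-3.0, 2.0, 1.0]), 3.0);
/// ```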
pub fn vec_max_abs(v: &[f64]) -> f64 {
    v.iter().map(|&x| x.abs()).fold(0.0_f64, f64::max)
}

// ============================================================================
// R-Compatible QR Decomposition (LINPACK dqrdc2 with Column Pivoting)
// ============================================================================

/// QR decomposition result using R's LINPACK dqrdc2 algorithm.
///
/// This implements the QR decomposition with column pivoting as used by R's
/// `qr()` function with `LAPACK=FALSE`. The algorithm is a modification of
/// LINPACK's DQRDC that:
/// - Uses Householder transformations
/// - Implements limited column pivoting based on 2-norms of reduced columns
/// - Moves columns with near-zero norm to the right-hand edge
/// - Computes the rank (number of linearly independent columns)
///
/// # Fields
///
/// * `qr` - The QR factorization (the upper triangle contains R; below the
///   diagonal is Householder vector information)
/// * `qraux` - Auxiliary information for recovering the orthogonal part Q
/// * `pivot` - Column permutation: `pivot\[j\]` contains the original column index
///   now in column j
/// * `rank` - Number of linearly independent columns (the computed rank)
#[derive(Clone, Debug)]
pub struct QRLinpack {
    /// QR factorization matrix (same dimensions as input)
    pub qr: Matrix,
    /// Auxiliary information for Q recovery
    pub qraux: Vec<f64>,
    /// Column pivot vector (1-based indices like R)
    pub pivot: Vec<usize>,
    /// Computed rank (number of linearly independent columns)
    pub rank: usize,
}

impl Matrix {
    /// Computes QR decomposition using R's LINPACK dqrdc2 algorithm with column pivoting.
    ///
    /// This is a port of R's dqrdc2.f, which is a modification of LINPACK's DQRDC.
    /// The algorithm:
    /// 1. Uses Householder transformations for the QR factorization
    /// 2. Implements limited column pivoting based on column 2-norms
    /// 3. Moves columns with near-zero norm to the right-hand edge
    /// 4. Computes the rank (number of linearly independent columns)
    ///
    /// # Arguments
    ///
    /// * `tol` - Tolerance for determining linear independence. Default is 1e-7 (R's default).
    ///   Columns with norm < tol * original_norm are considered negligible.
    ///
    /// # Returns
    ///
    /// A [`QRLinpack`] struct containing the QR factorization, auxiliary information,
    /// pivot vector, and computed rank.
    ///
    /// # Algorithm Details
    ///
    /// The decomposition is A * P = Q * R where:
    /// - P is the permutation matrix coded by `pivot`
    /// - Q is orthogonal (m × m)
    /// - R is upper triangular in the first `rank` rows
    ///
    /// The `qr` matrix contains:
    /// - Upper triangle: the R matrix (if pivoting was performed, this is R of the permuted matrix)
    /// - Below the diagonal: Householder vector information
    ///
    /// # Reference
    ///
    /// - R source: src/appl/dqrdc2.f
    /// - LINPACK documentation: <https://www.netlib.org/linpack/dqrdc.f>
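    ///
    /// # Example
    ///
    /// Rank detection on a 3×2 matrix whose second column duplicates the first:
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
    /// let qr = x.qr_linpack(None);
    /// assert_eq!(qr.rank, 1);
    /// // The pivot is a permutation of the original (1-based) column indices.
    /// assert_eq!(qr.pivot.len(), 2);
    /// ```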
    pub fn qr_linpack(&self, tol: Option<f64>) -> QRLinpack {
        let n = self.rows;
        let p = self.cols;
        let lup = n.min(p);

        // Default tolerance matches R's qr.default: tol = 1e-07
        let tol = tol.unwrap_or(1e-07);

        // Initialize working matrices
        let mut x = self.clone(); // Working copy that will be modified
        let mut qraux = vec![0.0; p];
        let mut pivot: Vec<usize> = (1..=p).collect(); // 1-based indices like R
        let mut work = vec![(0.0, 0.0); p]; // (work[j,1], work[j,2])

        // Compute the norms of the columns of x (initialization)
        if n > 0 {
            for j in 0..p {
                let mut norm = 0.0;
                for i in 0..n {
                    norm += x.get(i, j) * x.get(i, j);
                }
                norm = norm.sqrt();
                qraux[j] = norm;
                let original_norm = if norm == 0.0 { 1.0 } else { norm };
                work[j] = (norm, original_norm);
            }
        }

        let mut k = p + 1; // Will be decremented to get the final rank

        // Perform the Householder reduction of x
        for l in 0..lup {
            // Cycle columns from l to p until one with non-negligible norm is found.
            // A column is negligible if its norm has fallen below tol * original_norm.
            while l < k - 1 && qraux[l] < work[l].1 * tol {
                // Move column l to the end (it's negligible)
                let lp1 = l + 1;

                // Shift columns in x: x(i, l..p-1) = x(i, l+1..p)
                for i in 0..n {
                    let t = x.get(i, l);
                    for j in lp1..p {
                        x.set(i, j - 1, x.get(i, j));
                    }
                    x.set(i, p - 1, t);
                }

                // Shift the pivot, qraux, and work arrays
                let saved_pivot = pivot[l];
                let saved_qraux = qraux[l];
                let saved_work = work[l];

                for j in lp1..p {
                    pivot[j - 1] = pivot[j];
                    qraux[j - 1] = qraux[j];
                    work[j - 1] = work[j];
                }

                pivot[p - 1] = saved_pivot;
                qraux[p - 1] = saved_qraux;
                work[p - 1] = saved_work;

                k -= 1;
            }

            if l == n - 1 {
                // Last row - skip the transformation
                break;
            }

            // Compute the Householder transformation for column l
            // nrmxl = norm of x[l:, l]
            let mut nrmxl = 0.0;
            for i in l..n {
                let val = x.get(i, l);
                nrmxl += val * val;
            }
            nrmxl = nrmxl.sqrt();

            if nrmxl == 0.0 {
                // Zero column - continue to the next
                continue;
            }

            // Apply the sign for numerical stability (dsign in Fortran)
            let x_ll = x.get(l, l);
            if x_ll != 0.0 {
                nrmxl = nrmxl.copysign(x_ll);
            }

            // Scale the column
            let scale = 1.0 / nrmxl;
            for i in l..n {
                x.set(i, l, x.get(i, l) * scale);
            }
            x.set(l, l, 1.0 + x.get(l, l));

            // Apply the transformation to the remaining columns, updating the norms
            let lp1 = l + 1;
            if p > lp1 {
                for j in lp1..p {
                    // Compute t = -dot(x[l:, l], x[l:, j]) / x(l, l)
                    let mut dot = 0.0;
                    for i in l..n {
                        dot += x.get(i, l) * x.get(i, j);
                    }
                    let t = -dot / x.get(l, l);

                    // x[l:, j] = x[l:, j] + t * x[l:, l]
                    for i in l..n {
                        let val = x.get(i, j) + t * x.get(i, l);
                        x.set(i, j, val);
                    }

                    // Update the norm
                    if qraux[j] != 0.0 {
                        // tt = 1.0 - (x(l, j) / qraux[j])^2
                        let x_lj = x.get(l, j).abs();
                        let mut tt = 1.0 - (x_lj / qraux[j]).powi(2);
                        tt = tt.max(0.0);

                        // Recompute the norm if there was a large reduction (BDR mod 9/99).
                        // The tolerance here is on the squared norm.
                        if tt.abs() < 1e-6 {
                            // Re-compute the norm directly
                            let mut new_norm = 0.0;
                            for i in (l + 1)..n {
                                let val = x.get(i, j);
                                new_norm += val * val;
                            }
                            new_norm = new_norm.sqrt();
                            qraux[j] = new_norm;
                            work[j].0 = new_norm;
                        } else {
                            qraux[j] *= tt.sqrt();
                        }
                    }
                }
            }

            // Save the transformation
            qraux[l] = x.get(l, l);
            x.set(l, l, -nrmxl);
        }

        // Compute the final rank
        let rank = (k - 1).min(n);

        QRLinpack {
            qr: x,
            qraux,
            pivot,
            rank,
        }
    }

    /// Solves a linear system using the QR decomposition with column pivoting.
    ///
    /// This implements a least squares solver using the pivoted QR decomposition.
    /// For rank-deficient cases, coefficients corresponding to linearly dependent
    /// columns are set to `f64::NAN`.
    ///
    /// # Arguments
    ///
    /// * `qr_result` - QR decomposition from [`Matrix::qr_linpack`]
    /// * `y` - Right-hand side vector
    ///
    /// # Returns
    ///
    /// A vector of coefficients, or `None` if the system is exactly singular.
    ///
    /// # Algorithm
    ///
    /// This solver uses the standard QR decomposition approach:
    /// 1. Compute the QR decomposition of the permuted matrix
    /// 2. Extract the R matrix from the upper triangle
    /// 3. Compute qty = Q^T * y
    /// 4. Solve R * coef = qty using back substitution
    /// 5. Apply the pivot permutation to restore the original column order
    ///
    /// # Note
    ///
    /// The LINPACK QR algorithm stores R with mixed signs on the diagonal.
    /// This solver uses the signed diagonal as-is: the signs cancel between
    /// the computation of Q^T * y and the back substitution.
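    ///
    /// # Example
    ///
    /// The two-step usage (decompose, then solve). Here y is an exact linear
    /// combination of the columns, so the coefficients are recovered:
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
    /// let y = [2.0, 5.0, 8.0]; // y = 2 * col0 + 3 * col1
    /// let qr = x.qr_linpack(None);
    /// let beta = x.qr_solve_linpack(&qr, &y).unwrap();
    /// assert!((beta[0] - 2.0).abs() < 1e-8);
    /// assert!((beta[1] - 3.0).abs() < 1e-8);
    /// ```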
    #[allow(clippy::needless_range_loop)]
    pub fn qr_solve_linpack(&self, qr_result: &QRLinpack, y: &[f64]) -> Option<Vec<f64>> {
        let n = self.rows;
        let p = self.cols;
        let k = qr_result.rank;

        if y.len() != n {
            return None;
        }

        if k == 0 {
            return None;
        }

        // Step 1: Compute Q^T * y using the Householder vectors directly.
        // This is more efficient than reconstructing the full Q matrix.
        let mut qty = y.to_vec();

        for j in 0..k {
            // Check whether this Householder transformation is valid
            let r_jj = qr_result.qr.get(j, j);
            if r_jj == 0.0 {
                continue;
            }

            // Compute dot = v_j^T * qty[j:]
            // where v_j is the Householder vector stored in qr[j:, j].
            // The storage convention:
            // - qr[j,j] = -nrmxl (after the final overwrite)
            // - qr[i,j] for i > j is the scaled Householder vector element
            // - qraux[j] = 1 + original_x[j,j]/nrmxl (the first element of the
            //   scaled Householder vector)

            // Reconstruct the Householder vector v_j.
            // After scaling by 1/nrmxl, we have:
            //   v_scaled[j] = 1 + x[j,j]/nrmxl
            //   v_scaled[i] = x[i,j]/nrmxl for i > j
            // The actual unit vector is v = v_scaled / ||v_scaled||

            let mut v = vec![0.0; n - j];
            // Copy the scaled Householder vector from qr
            for i in j..n {
                v[i - j] = qr_result.qr.get(i, j);
            }

            // The j-th element was overwritten with -nrmxl during the QR
            // decomposition, so reconstruct it from qraux
            let alpha = qr_result.qraux[j];
            if alpha != 0.0 {
                v[0] = alpha;
            }

            // Compute the norm of v
            let v_norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if v_norm < 1e-14 {
                continue;
            }

            // Compute dot = v^T * qty[j:]
            let mut dot = 0.0;
            for i in j..n {
                dot += v[i - j] * qty[i];
            }

            // Apply the Householder transformation:
            //   qty[j:] = qty[j:] - 2 * v * (v^T * qty[j:]) / (v^T * v)
            // Since v is already scaled, we use: t = 2 * dot / (v_norm^2)
            let t = 2.0 * dot / (v_norm * v_norm);

            for i in j..n {
                qty[i] -= t * v[i - j];
            }
        }

        // Step 2: Back substitution on R (solve R * coef = qty).
        // The R matrix is stored in the upper triangle of qr.
        // Note: The diagonal elements of R carry the sign of -nrmxl.
        // We use them as-is since the signs cancel out in the computation.
        let mut coef_permuted = vec![f64::NAN; p];

        // Use a relative tolerance for the singularity check
        let max_abs = (0..k)
            .map(|i| qr_result.qr.get(i, i).abs())
            .fold(0.0_f64, f64::max);
        let tolerance = 1e-14 * max_abs.max(1.0);

        for row in (0..k).rev() {
            let r_diag = qr_result.qr.get(row, row);
            if r_diag.abs() < tolerance {
                return None; // Singular
            }

            let mut sum = qty[row];
            for col in (row + 1)..k {
                sum -= qr_result.qr.get(row, col) * coef_permuted[col];
            }
            coef_permuted[row] = sum / r_diag;
        }

        // Step 3: Apply the pivot permutation to get coefficients in the original order.
        // pivot[j] is 1-based, indicating which original column is now in position j.
        let mut result = vec![0.0; p];
        for j in 0..p {
            let original_col = qr_result.pivot[j] - 1; // Convert to 0-based
            result[original_col] = coef_permuted[j];
        }

        Some(result)
    }
}

/// Performs OLS regression using R's LINPACK QR algorithm.
///
/// This function is a drop-in replacement for `fit_ols` that uses the
/// R-compatible QR decomposition with column pivoting. It handles
/// rank-deficient matrices more gracefully than the standard QR decomposition.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including the intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - OLS coefficient vector (p elements)
/// * `None` - If the matrix is exactly singular or the dimensions don't match
///
/// # Note
///
/// For rank-deficient systems, this function uses the pivoted QR, which
/// automatically handles multicollinearity by selecting a linearly
/// independent subset of columns.
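///
/// # Example
///
/// ```
/// use linreg_core::linalg::{fit_ols_linpack, Matrix};
///
/// // Intercept-plus-slope design; y = 2 + 3x exactly
/// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
/// let beta = fit_ols_linpack(&[2.0, 5.0, 8.0], &x).unwrap();
/// assert!((beta[0] - 2.0).abs() < 1e-8);
/// assert!((beta[1] - 3.0).abs() < 1e-8);
/// ```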
pub fn fit_ols_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let qr_result = x.qr_linpack(None);
    x.qr_solve_linpack(&qr_result, y)
}

/// Fits OLS and predicts using R's LINPACK QR with rank-deficient handling.
///
/// This function matches R's `lm.fit` behavior for rank-deficient cases:
/// coefficients for linearly dependent columns are set to NA, and predictions
/// are computed using only the valid (non-NA) coefficients and their corresponding
/// columns. This matches how R handles rank-deficient models in prediction.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including the intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - Predictions (n elements)
/// * `None` - If the matrix is exactly singular or the dimensions don't match
///
/// # Algorithm
///
/// For rank-deficient systems (rank < p):
/// 1. Compute the QR decomposition with column pivoting
/// 2. Solve for the coefficients (linearly dependent columns get NaN)
/// 3. Compute predictions using only the columns with valid (non-NaN) coefficients
///
/// This matches R's behavior where `predict(lm.fit(...))` handles NA coefficients
/// by excluding the corresponding columns from the prediction.
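///
/// # Example
///
/// Here the third column duplicates the second, so the model is rank-deficient,
/// but the fitted values still reproduce this exactly collinear, noise-free y:
/// ```
/// use linreg_core::linalg::{fit_and_predict_linpack, Matrix};
///
/// let x = Matrix::new(3, 3, vec![
///     1.0, 0.0, 0.0,
///     1.0, 1.0, 1.0,
///     1.0, 2.0, 2.0,
/// ]);
/// let y = [2.0, 5.0, 8.0];
/// let pred = fit_and_predict_linpack(&y, &x).unwrap();
/// for i in 0..3 {
///     assert!((pred[i] - y[i]).abs() < 1e-8);
/// }
/// ```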
#[allow(clippy::needless_range_loop)]
pub fn fit_and_predict_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let n = x.rows;
    let p = x.cols;

    // Compute the QR decomposition
    let qr_result = x.qr_linpack(None);
    let k = qr_result.rank;

    // Solve for the coefficients (already returned in the original column order)
    let beta = x.qr_solve_linpack(&qr_result, y)?;

    // Check for rank deficiency
    if k == p {
        // Full rank - use standard prediction
        return Some(x.mul_vec(&beta));
    }

    // Rank-deficient case: some columns are collinear and have NaN coefficients.
    // We compute predictions using only columns with valid (non-NaN) coefficients.
    // This matches R's behavior, where NA coefficients exclude columns from prediction.

    let mut pred = vec![0.0; n];

    for row in 0..n {
        let mut sum = 0.0;
        for j in 0..p {
            let b_val = beta[j];
            if b_val.is_nan() {
                continue; // Skip collinear columns (matches R's NA coefficient behavior)
            }
            sum += x.get(row, j) * b_val;
        }
        pred[row] = sum;
    }

    Some(pred)