linreg_core/linalg.rs
//! Minimal linear algebra module to replace the nalgebra dependency.
//!
//! Implements the matrix operations, QR decompositions, and solvers needed for OLS.
//! Uses row-major storage for compatibility with statistical computing conventions.
//!
//! # Numerical Stability Considerations
//!
//! This implementation uses Householder QR decomposition with careful attention to
//! numerical stability:
//!
//! - **Sign convention**: Uses the numerically stable Householder sign choice
//!   (v = x + sgn(x₀)‖x‖e₁) to avoid cancellation
//! - **Tolerance checking**: Uses predefined tolerances to detect near-singular matrices
//! - **Zero-skipping**: Skips a transformation when the remaining column is already
//!   effectively zero
//!
//! # Scaling Recommendations
//!
//! For optimal numerical stability when predictor variables have vastly different
//! scales (e.g., one variable in millions, another in thousandths), consider
//! standardizing predictors before regression. Z-score standardization
//! (`x_scaled = (x - mean) / std`) is already applied in the VIF calculation.
//!
//! However, the current implementation handles typical OLS cases without explicit
//! scaling, as QR decomposition is generally stable for well-conditioned matrices.
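//!
//! # Example
//!
//! A minimal sketch of the z-score standardization mentioned above, using
//! `vec_mean` from this module (the sample standard deviation is computed inline):
//!
//! ```
//! use linreg_core::linalg::vec_mean;
//!
//! let x = vec![2.0, 4.0, 6.0, 8.0];
//! let mean = vec_mean(&x);
//! let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (x.len() as f64 - 1.0);
//! let x_scaled: Vec<f64> = x.iter().map(|v| (v - mean) / var.sqrt()).collect();
//! // The standardized values have (near-)zero mean
//! assert!(vec_mean(&x_scaled).abs() < 1e-12);
//! ```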

// ============================================================================
// Numerical Constants
// ============================================================================

/// Machine epsilon threshold for detecting zero values in QR decomposition.
/// Values below this are treated as zero to avoid numerical instability.
const QR_ZERO_TOLERANCE: f64 = 1e-12;

/// Threshold for detecting singular matrices during inversion.
/// Diagonal elements below this value indicate a near-singular matrix.
const SINGULAR_TOLERANCE: f64 = 1e-10;

/// A dense matrix stored in row-major order.
///
/// # Storage
///
/// Elements are stored in a single flat vector in row-major order:
/// `data[row * cols + col]`
///
/// # Example
///
/// ```
/// # use linreg_core::linalg::Matrix;
/// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// assert_eq!(m.rows, 2);
/// assert_eq!(m.cols, 3);
/// assert_eq!(m.get(0, 0), 1.0);
/// assert_eq!(m.get(1, 2), 6.0);
/// ```
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Number of rows in the matrix
    pub rows: usize,
    /// Number of columns in the matrix
    pub cols: usize,
    /// Flat vector storing matrix elements in row-major order
    pub data: Vec<f64>,
}

impl Matrix {
    /// Creates a new matrix from the given dimensions and data.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rows * cols`.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `data` - Flat vector of elements in row-major order
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// assert_eq!(m.get(0, 0), 1.0);
    /// assert_eq!(m.get(0, 1), 2.0);
    /// assert_eq!(m.get(1, 0), 3.0);
    /// assert_eq!(m.get(1, 1), 4.0);
    /// ```
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols, "Data length must match dimensions");
        Matrix { rows, cols, data }
    }

    /// Creates a matrix filled with zeros.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::zeros(3, 2);
    /// assert_eq!(m.rows, 3);
    /// assert_eq!(m.cols, 2);
    /// assert_eq!(m.get(1, 1), 0.0);
    /// ```
    pub fn zeros(rows: usize, cols: usize) -> Self {
        Matrix {
            rows,
            cols,
            data: vec![0.0; rows * cols],
        }
    }

    // NOTE: Currently unused but kept as a reference implementation.
    // Uncomment if needed as a convenience constructor.
    /*
    /// Creates a matrix from a row-major slice.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `slice` - Slice containing matrix elements in row-major order
    pub fn from_row_slice(rows: usize, cols: usize, slice: &[f64]) -> Self {
        Matrix::new(rows, cols, slice.to_vec())
    }
    */

    /// Gets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    pub fn get(&self, row: usize, col: usize) -> f64 {
        self.data[row * self.cols + col]
    }

    /// Sets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    /// * `val` - Value to set
    pub fn set(&mut self, row: usize, col: usize, val: f64) {
        self.data[row * self.cols + col] = val;
    }

    /// Returns the transpose of this matrix.
    ///
    /// Swaps rows with columns: `result[col][row] = self[row][col]`.
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let t = m.transpose();
    /// assert_eq!(t.rows, 3);
    /// assert_eq!(t.cols, 2);
    /// assert_eq!(t.get(0, 1), 4.0);
    /// ```
    pub fn transpose(&self) -> Matrix {
        let mut t_data = vec![0.0; self.rows * self.cols];
        for r in 0..self.rows {
            for c in 0..self.cols {
                t_data[c * self.rows + r] = self.get(r, c);
            }
        }
        Matrix::new(self.cols, self.rows, t_data)
    }

    /// Performs matrix multiplication: `self * other`.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != other.rows`.
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let c = a.matmul(&b);
    /// assert_eq!(c.rows, 2);
    /// assert_eq!(c.cols, 2);
    /// assert_eq!(c.get(0, 0), 22.0); // 1*1 + 2*3 + 3*5
    /// ```
    pub fn matmul(&self, other: &Matrix) -> Matrix {
        assert_eq!(self.cols, other.rows, "Dimension mismatch for multiplication");
        let mut result = Matrix::zeros(self.rows, other.cols);

        for r in 0..self.rows {
            for c in 0..other.cols {
                let mut sum = 0.0;
                for k in 0..self.cols {
                    sum += self.get(r, k) * other.get(k, c);
                }
                result.set(r, c, sum);
            }
        }
        result
    }

    /// Multiplies this matrix by a vector (treating the vector as a column matrix).
    ///
    /// Computes `self * vec` where `vec` is treated as an n×1 column matrix.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != vec.len()`.
    ///
    /// # Arguments
    ///
    /// * `vec` - Vector to multiply by
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let v = vec![1.0, 2.0, 3.0];
    /// let result = m.mul_vec(&v);
    /// assert_eq!(result.len(), 2);
    /// assert_eq!(result[0], 14.0); // 1*1 + 2*2 + 3*3
    /// ```
    #[allow(clippy::needless_range_loop)]
    pub fn mul_vec(&self, vec: &[f64]) -> Vec<f64> {
        assert_eq!(self.cols, vec.len(), "Dimension mismatch for matrix-vector multiplication");
        let mut result = vec![0.0; self.rows];

        for r in 0..self.rows {
            let mut sum = 0.0;
            for c in 0..self.cols {
                sum += self.get(r, c) * vec[c];
            }
            result[r] = sum;
        }
        result
    }

    /// Computes the dot product of a column with a vector: `Σ(data[i * cols + col] * v[i])`.
    ///
    /// For a row-major matrix, this iterates through all rows at a fixed column.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    /// * `v` - Vector to dot with (must have length equal to rows)
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols` or `v.len() != rows`.
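    ///
    /// # Example
    ///
    /// A small usage sketch; column 1 of the matrix below is `[2.0, 4.0]`:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// assert_eq!(m.col_dot(1, &[1.0, 1.0]), 6.0); // 2 + 4
    /// ```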
    #[allow(clippy::needless_range_loop)]
    pub fn col_dot(&self, col: usize, v: &[f64]) -> f64 {
        assert!(col < self.cols, "Column index out of bounds");
        assert_eq!(self.rows, v.len(), "Vector length must match number of rows");

        let mut sum = 0.0;
        for row in 0..self.rows {
            sum += self.get(row, col) * v[row];
        }
        sum
    }

    /// Performs the column-vector operation in place: `v += alpha * column_col`.
    ///
    /// This is the AXPY operation where the column is treated as a vector.
    /// For row-major storage, we iterate through rows at a fixed column.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    /// * `alpha` - Scaling factor for the column
    /// * `v` - Vector to modify in place (must have length equal to rows)
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols` or `v.len() != rows`.
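    ///
    /// # Example
    ///
    /// A small sketch accumulating twice column 0 into an existing vector:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// let mut v = vec![10.0, 10.0];
    /// m.col_axpy_inplace(0, 2.0, &mut v); // v += 2 * [1.0, 3.0]
    /// assert_eq!(v, vec![12.0, 16.0]);
    /// ```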
    #[allow(clippy::needless_range_loop)]
    pub fn col_axpy_inplace(&self, col: usize, alpha: f64, v: &mut [f64]) {
        assert!(col < self.cols, "Column index out of bounds");
        assert_eq!(self.rows, v.len(), "Vector length must match number of rows");

        for row in 0..self.rows {
            v[row] += alpha * self.get(row, col);
        }
    }

    /// Computes the squared L2 norm of a column: `Σ(data[i * cols + col]²)`.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols`.
    #[allow(clippy::needless_range_loop)]
    pub fn col_norm2(&self, col: usize) -> f64 {
        assert!(col < self.cols, "Column index out of bounds");

        let mut sum = 0.0;
        for row in 0..self.rows {
            let val = self.get(row, col);
            sum += val * val;
        }
        sum
    }

    /// Adds a value to diagonal elements starting from a given index.
    ///
    /// This is useful for ridge regression, where we add `lambda * I` to `X^T X`
    /// but the intercept column should not be penalized.
    ///
    /// # Arguments
    ///
    /// * `alpha` - Value to add to diagonal elements
    /// * `start_index` - Starting diagonal index (0 = first diagonal element)
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    ///
    /// # Example
    ///
    /// For a 3×3 identity matrix with the intercept in the first column (unpenalized):
    /// ```text
    /// add_diagonal_in_place(lambda, 1) on:
    /// [1.0, 0.0, 0.0]    [1.0, 0.0,   0.0  ]
    /// [0.0, 1.0, 0.0] -> [0.0, 1.0+λ, 0.0  ]
    /// [0.0, 0.0, 1.0]    [0.0, 0.0,   1.0+λ]
    /// ```
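    ///
    /// A runnable version of the same operation:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let mut m = Matrix::identity(3);
    /// m.add_diagonal_in_place(0.5, 1); // leave the intercept (index 0) unpenalized
    /// assert_eq!(m.get(0, 0), 1.0);
    /// assert_eq!(m.get(1, 1), 1.5);
    /// assert_eq!(m.get(2, 2), 1.5);
    /// ```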
    pub fn add_diagonal_in_place(&mut self, alpha: f64, start_index: usize) {
        assert_eq!(self.rows, self.cols, "Matrix must be square");
        let n = self.rows;
        for i in start_index..n {
            let current = self.get(i, i);
            self.set(i, i, current + alpha);
        }
    }
}

// ============================================================================
// QR Decomposition
// ============================================================================

impl Matrix {
    /// Computes the QR decomposition using Householder reflections.
    ///
    /// Factorizes the matrix as `A = QR`, where Q is orthogonal and R is upper triangular.
    ///
    /// # Requirements
    ///
    /// This implementation requires `rows >= cols` (a tall matrix). For OLS regression
    /// we always have more observations than predictors, so this requirement is satisfied.
    ///
    /// # Returns
    ///
    /// A tuple `(Q, R)` where:
    /// - `Q` is an orthogonal matrix (QᵀQ = I) of size m×m
    /// - `R` is an upper triangular matrix of size m×n
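    ///
    /// # Example
    ///
    /// A small sketch; the reconstruction `Q * R` is checked approximately, since
    /// the signs of individual entries depend on the Householder convention:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
    /// let (q, r) = a.qr();
    /// assert_eq!(q.rows, 3); // Q is m×m
    /// assert_eq!(r.cols, 2); // R is m×n
    /// let qr_product = q.matmul(&r);
    /// for i in 0..3 {
    ///     for j in 0..2 {
    ///         assert!((qr_product.get(i, j) - a.get(i, j)).abs() < 1e-10);
    ///     }
    /// }
    /// ```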
    #[allow(clippy::needless_range_loop)]
    pub fn qr(&self) -> (Matrix, Matrix) {
        let m = self.rows;
        let n = self.cols;
        let mut q = Matrix::identity(m);
        let mut r = self.clone();

        for k in 0..n.min(m - 1) {
            // Create vector x = R[k:, k]
            let mut x = vec![0.0; m - k];
            for i in k..m {
                x[i - k] = r.get(i, k);
            }

            // Norm of x
            let norm_x: f64 = x.iter().map(|&v| v * v).sum::<f64>().sqrt();
            if norm_x < QR_ZERO_TOLERANCE {
                continue; // Column is already zero below the diagonal
            }

            // Create vector v = x + sign(x[0]) * ||x|| * e1
            //
            // NOTE: Numerical stability consideration (Householder sign choice).
            // According to Overton & Yu (2023), the numerically stable choice is
            // σ = -sgn(x₁) in the formula v = x - σ‖x‖e₁.
            //
            // This means: v = x - (-sgn(x₁))‖x‖e₁ = x + sgn(x₁)‖x‖e₁
            //
            // Equivalently: u₁ = x₁ + sgn(x₁)‖x‖
            //
            // The implementation below uses this formula (the stable choice):
            let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 }; // sgn(x₀), with sgn(0) = +1
            let u1 = x[0] + sign * norm_x;

            // Normalize v to get the Householder vector
            let mut v = x; // Re-use storage
            v[0] = u1;

            let norm_v: f64 = v.iter().map(|&val| val * val).sum::<f64>().sqrt();
            for val in &mut v {
                *val /= norm_v;
            }

            // Apply the Householder transformation to R:
            // R = H * R = (I - 2vvᵀ)R = R - 2v(vᵀR), acting on the submatrix R[k:, k:]
            for j in k..n {
                let mut dot = 0.0;
                for i in 0..m - k {
                    dot += v[i] * r.get(k + i, j);
                }

                for i in 0..m - k {
                    let val = r.get(k + i, j) - 2.0 * v[i] * dot;
                    r.set(k + i, j, val);
                }
            }

            // Update Q: Q = Q * H = Q(I - 2vvᵀ) = Q - 2(Qv)vᵀ, acting on Q[:, k:]
            for i in 0..m {
                let mut dot = 0.0;
                for l in 0..m - k {
                    dot += q.get(i, k + l) * v[l];
                }

                for l in 0..m - k {
                    let val = q.get(i, k + l) - 2.0 * dot * v[l];
                    q.set(i, k + l, val);
                }
            }
        }

        (q, r)
    }

    /// Creates an identity matrix of the given size.
    ///
    /// # Arguments
    ///
    /// * `size` - Number of rows and columns (square matrix)
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let i = Matrix::identity(3);
    /// assert_eq!(i.get(0, 0), 1.0);
    /// assert_eq!(i.get(1, 1), 1.0);
    /// assert_eq!(i.get(2, 2), 1.0);
    /// assert_eq!(i.get(0, 1), 0.0);
    /// ```
    pub fn identity(size: usize) -> Self {
        let mut data = vec![0.0; size * size];
        for i in 0..size {
            data[i * size + i] = 1.0;
        }
        Matrix::new(size, size, data)
    }

    /// Inverts an upper triangular matrix (such as R from QR decomposition).
    ///
    /// Uses back-substitution to compute the inverse. This is efficient for
    /// triangular matrices compared to general matrix inversion.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    ///
    /// # Returns
    ///
    /// `None` if the matrix is singular (has a zero or near-zero diagonal element).
    /// A matrix is considered singular if any diagonal element falls below the
    /// internal tolerance (at least 1e-10, scaled relative to the largest diagonal
    /// element), which indicates the matrix does not have full rank.
    ///
    /// # Note
    ///
    /// For upper triangular matrices, singularity is equivalent to having a
    /// zero (or near-zero) diagonal element. This is much simpler to check than
    /// for general matrices, which would require estimating the condition number.
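    ///
    /// # Example
    ///
    /// A minimal sketch inverting a 2×2 upper triangular matrix:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let r = Matrix::new(2, 2, vec![2.0, 1.0, 0.0, 4.0]);
    /// let r_inv = r.invert_upper_triangular().expect("full rank");
    /// // R * R⁻¹ should reproduce the identity
    /// let prod = r.matmul(&r_inv);
    /// assert!((prod.get(0, 0) - 1.0).abs() < 1e-12);
    /// assert!(prod.get(0, 1).abs() < 1e-12);
    /// assert!((prod.get(1, 1) - 1.0).abs() < 1e-12);
    /// ```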
    pub fn invert_upper_triangular(&self) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance. This scales with the
        // magnitude of the diagonal elements, handling matrices of different
        // scales better than a fixed absolute tolerance.
        //
        // A previous implementation used an absolute tolerance:
        //     if self.get(i, i).abs() < SINGULAR_TOLERANCE { return None; }
        //
        // The current implementation uses a relative tolerance similar to LAPACK:
        //     tolerance = max_diag * epsilon * n
        // where epsilon is on the order of machine epsilon (~2.2e-16 for f64).
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        // Relative tolerance based on the maximum diagonal element,
        // floored at the absolute SINGULAR_TOLERANCE
        let epsilon = 2.0_f64 * f64::EPSILON; // ~4.4e-16 for f64
        let relative_tolerance = max_diag * epsilon * n as f64;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None; // Singular matrix - cannot invert
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Inverts an upper triangular matrix with a custom tolerance multiplier.
    ///
    /// The tolerance is computed as `max_diag * epsilon * n * tolerance_mult`.
    /// A higher `tolerance_mult` allows more tolerance for near-singular matrices.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the tolerance (1.0 = standard, higher = more tolerant)
    pub fn invert_upper_triangular_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        // Relative tolerance based on the maximum diagonal element
        let epsilon = 2.0_f64 * f64::EPSILON;
        let relative_tolerance = max_diag * epsilon * n as f64 * tolerance_mult;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None;
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Computes the inverse of a square matrix using QR decomposition.
    ///
    /// For an invertible matrix A, computes A⁻¹ such that A * A⁻¹ = I.
    /// Uses QR decomposition for numerical stability.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square (i.e., `self.rows != self.cols`).
    /// Check dimensions before calling if the matrix shape is not guaranteed.
    ///
    /// # Returns
    ///
    /// Returns `Some(inverse)` if the matrix is invertible, or `None` if
    /// the matrix is singular (non-invertible).
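    ///
    /// # Example
    ///
    /// A small sketch on a well-conditioned 2×2 matrix:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(2, 2, vec![4.0, 7.0, 2.0, 6.0]); // det = 10
    /// let a_inv = a.invert().expect("invertible");
    /// let prod = a.matmul(&a_inv);
    /// assert!((prod.get(0, 0) - 1.0).abs() < 1e-10);
    /// assert!(prod.get(0, 1).abs() < 1e-10);
    /// ```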
    pub fn invert(&self) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square for inversion");

        // Use QR decomposition: A = Q * R
        let (q, r) = self.qr();

        // Compute R⁻¹ (upper triangular inverse)
        let r_inv = r.invert_upper_triangular()?;

        // A⁻¹ = R⁻¹ * Qᵀ
        let q_transpose = q.transpose();
        let mut result = Matrix::zeros(n, n);

        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0;
                for k in 0..n {
                    sum += r_inv.get(i, k) * q_transpose.get(k, j);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }

    /// Computes the inverse of X'X given the QR decomposition of X (R's `chol2inv`).
    ///
    /// This is equivalent to computing `(X'X)^(-1)` using the QR decomposition of X.
    /// R's `chol2inv` function is used for numerical stability in recursive residuals.
    ///
    /// # Requirements
    ///
    /// `self` is the input matrix X and must have rows >= cols.
    ///
    /// # Returns
    ///
    /// `Some((X'X)^(-1))` if X has full rank, `None` otherwise.
    ///
    /// # Algorithm
    ///
    /// Given the QR decomposition X = QR where R is upper triangular:
    /// 1. Extract the upper p×p portion of R (denoted R₁)
    /// 2. Invert R₁ (upper triangular inverse)
    /// 3. Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
    ///
    /// This works because X'X = R'Q'QR = R'R, so R₁ is a Cholesky factor of X'X.
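    ///
    /// # Example
    ///
    /// A sketch comparing against the direct route through X'X:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
    /// let xtx_inv = x.chol2inv_from_qr().expect("full rank");
    /// let direct = x.transpose().matmul(&x).invert().expect("invertible");
    /// for i in 0..2 {
    ///     for j in 0..2 {
    ///         assert!((xtx_inv.get(i, j) - direct.get(i, j)).abs() < 1e-8);
    ///     }
    /// }
    /// ```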
    pub fn chol2inv_from_qr(&self) -> Option<Matrix> {
        self.chol2inv_from_qr_with_tolerance(1.0)
    }

    /// Computes the inverse of X'X given the QR decomposition, with a custom tolerance.
    ///
    /// Similar to `chol2inv_from_qr` but allows specifying a tolerance multiplier
    /// for handling near-singular matrices.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the tolerance (higher = more tolerant)
    pub fn chol2inv_from_qr_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let p = self.cols;

        // QR decomposition: X = QR
        // For X (m×n, m≥n), R is m×n upper triangular;
        // the upper n×n block of R contains the meaningful values
        let (_, r_full) = self.qr();

        // Extract the upper p×p portion of R.
        // For tall matrices (m > p), R has zeros below the diagonal in its first p rows;
        // for square matrices (m = p), R is p×p upper triangular.
        let mut r1 = Matrix::zeros(p, p);
        for i in 0..p {
            // Row i of R₁ is row i of r_full; only the upper triangular part
            // (columns i..p, including the diagonal) is nonzero
            for j in i..p {
                r1.set(i, j, r_full.get(i, j));
            }
        }

        // Invert R₁ (upper triangular) with the custom tolerance
        let r1_inv = r1.invert_upper_triangular_with_tolerance(tolerance_mult)?;

        // Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
        let mut result = Matrix::zeros(p, p);
        for i in 0..p {
            for j in 0..p {
                // result[i,j] = Σ_k R₁⁻¹[i,k] * R₁⁻¹[j,k]
                // R₁⁻¹ is upper triangular, but we iterate the full range
                let mut sum = 0.0;
                for k in 0..p {
                    sum += r1_inv.get(i, k) * r1_inv.get(j, k);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }
}

// ============================================================================
// Vector Helper Functions
// ============================================================================

/// Computes the arithmetic mean of a slice of f64 values.
///
/// Returns 0.0 for empty slices.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_mean;
///
/// assert_eq!(vec_mean(&[1.0, 2.0, 3.0, 4.0, 5.0]), 3.0);
/// assert_eq!(vec_mean(&[]), 0.0);
/// ```
///
/// # Arguments
///
/// * `v` - Slice of values
pub fn vec_mean(v: &[f64]) -> f64 {
    if v.is_empty() {
        return 0.0;
    }
    v.iter().sum::<f64>() / v.len() as f64
}

/// Computes element-wise subtraction of two slices: `a - b`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_sub;
///
/// let a = vec![5.0, 4.0, 3.0];
/// let b = vec![1.0, 1.0, 1.0];
/// let result = vec_sub(&a, &b);
/// assert_eq!(result, vec![4.0, 3.0, 2.0]);
/// ```
///
/// # Arguments
///
/// * `a` - Minuend slice
/// * `b` - Subtrahend slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
pub fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_sub: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x - y).collect()
}

/// Computes the dot product of two slices: `Σ(a[i] * b[i])`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_dot;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_dot(&a, &b), 32.0); // 1*4 + 2*5 + 3*6
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
pub fn vec_dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "vec_dot: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Computes element-wise addition of two slices: `a + b`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_add;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_add(&a, &b), vec![5.0, 7.0, 9.0]);
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
pub fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_add: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}

/// Computes a scaled vector addition in place: `dst += alpha * src`.
///
/// This is the classic BLAS AXPY operation.
///
/// # Arguments
///
/// * `dst` - Destination slice (modified in place)
/// * `alpha` - Scaling factor for `src`
/// * `src` - Source slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
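///
/// # Examples
///
/// A small usage sketch:
///
/// ```
/// use linreg_core::linalg::vec_axpy_inplace;
///
/// let mut dst = vec![1.0, 2.0];
/// vec_axpy_inplace(&mut dst, 3.0, &[10.0, 20.0]);
/// assert_eq!(dst, vec![31.0, 62.0]);
/// ```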
pub fn vec_axpy_inplace(dst: &mut [f64], alpha: f64, src: &[f64]) {
    assert_eq!(dst.len(), src.len(), "vec_axpy_inplace: slice lengths must match");
    for (d, &s) in dst.iter_mut().zip(src.iter()) {
        *d += alpha * s;
    }
}

/// Scales a vector in place: `v *= alpha`.
///
/// # Arguments
///
/// * `v` - Vector to scale (modified in place)
/// * `alpha` - Scaling factor
pub fn vec_scale_inplace(v: &mut [f64], alpha: f64) {
    for val in v.iter_mut() {
        *val *= alpha;
    }
}

/// Returns a scaled copy of a vector: `v * alpha`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_scale;
///
/// let v = vec![1.0, 2.0, 3.0];
/// let scaled = vec_scale(&v, 2.5);
/// assert_eq!(scaled, vec![2.5, 5.0, 7.5]);
/// // The original is unchanged
/// assert_eq!(v, vec![1.0, 2.0, 3.0]);
/// ```
///
/// # Arguments
///
/// * `v` - Vector to scale
/// * `alpha` - Scaling factor
pub fn vec_scale(v: &[f64], alpha: f64) -> Vec<f64> {
    v.iter().map(|&x| x * alpha).collect()
}

/// Computes the L2 norm (Euclidean norm) of a vector: `sqrt(Σ(v[i]²))`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_l2_norm;
///
/// // Pythagorean triple: 3-4-5
/// assert_eq!(vec_l2_norm(&[3.0, 4.0]), 5.0);
/// // Unit vector
/// assert_eq!(vec_l2_norm(&[1.0, 0.0, 0.0]), 1.0);
/// ```
///
/// # Arguments
///
/// * `v` - Vector slice
pub fn vec_l2_norm(v: &[f64]) -> f64 {
    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
}

/// Computes the maximum absolute value in a vector.
///
/// Returns 0.0 for empty slices.
///
/// # Arguments
///
/// * `v` - Vector slice
pub fn vec_max_abs(v: &[f64]) -> f64 {
    v.iter().map(|&x| x.abs()).fold(0.0_f64, f64::max)
}

// ============================================================================
// R-Compatible QR Decomposition (LINPACK dqrdc2 with Column Pivoting)
// ============================================================================

/// QR decomposition result using R's LINPACK dqrdc2 algorithm.
///
/// This implements the QR decomposition with column pivoting as used by R's
/// `qr()` function with `LAPACK=FALSE`. The algorithm is a modification of
/// LINPACK's DQRDC that:
/// - Uses Householder transformations
/// - Implements limited column pivoting based on the 2-norms of the reduced columns
/// - Moves columns with near-zero norm to the right-hand edge
/// - Computes the rank (number of linearly independent columns)
///
/// # Fields
///
/// * `qr` - The QR factorization (the upper triangle contains R; below the diagonal
///   is Householder vector information)
/// * `qraux` - Auxiliary information for recovering the orthogonal part Q
/// * `pivot` - Column permutation: `pivot\[j\]` contains the original column index
///   now in column j
/// * `rank` - Number of linearly independent columns (the computed rank)
#[derive(Clone, Debug)]
pub struct QRLinpack {
    /// QR factorization matrix (same dimensions as the input)
    pub qr: Matrix,
    /// Auxiliary information for Q recovery
    pub qraux: Vec<f64>,
    /// Column pivot vector (1-based indices, like R)
    pub pivot: Vec<usize>,
    /// Computed rank (number of linearly independent columns)
    pub rank: usize,
}

impl Matrix {
    /// Computes the QR decomposition using R's LINPACK dqrdc2 algorithm with column pivoting.
    ///
    /// This is a port of R's dqrdc2.f, which is a modification of LINPACK's DQRDC.
    /// The algorithm:
    /// 1. Uses Householder transformations for the QR factorization
    /// 2. Implements limited column pivoting based on column 2-norms
    /// 3. Moves columns with near-zero norm to the right-hand edge
    /// 4. Computes the rank (number of linearly independent columns)
    ///
    /// # Arguments
    ///
    /// * `tol` - Tolerance for determining linear independence. Default is 1e-7 (R's default).
    ///   Columns with norm < tol * original_norm are considered negligible.
    ///
    /// # Returns
    ///
    /// A [`QRLinpack`] struct containing the QR factorization, auxiliary information,
    /// pivot vector, and computed rank.
    ///
    /// # Algorithm Details
    ///
    /// The decomposition is A * P = Q * R where:
    /// - P is the permutation matrix coded by `pivot`
    /// - Q is orthogonal (m × m)
    /// - R is upper triangular in the first `rank` rows
    ///
    /// The `qr` matrix contains:
    /// - Upper triangle: the R matrix (if pivoting was performed, this is R of the permuted matrix)
    /// - Below the diagonal: Householder vector information
    ///
    /// # Reference
    ///
    /// - R source: src/appl/dqrdc2.f
    /// - LINPACK documentation: <https://www.netlib.org/linpack/dqrdc.f>
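    ///
    /// # Example
    ///
    /// A sketch with an exactly collinear column; the detected rank should be 2,
    /// since the third column is twice the second:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 3, vec![
    ///     1.0, 1.0, 2.0,
    ///     1.0, 2.0, 4.0,
    ///     1.0, 3.0, 6.0,
    /// ]);
    /// let qr = x.qr_linpack(None);
    /// assert_eq!(qr.rank, 2);
    /// ```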
    pub fn qr_linpack(&self, tol: Option<f64>) -> QRLinpack {
        let n = self.rows;
        let p = self.cols;
        let lup = n.min(p);

        // Default tolerance matches R's qr.default: tol = 1e-07
        let tol = tol.unwrap_or(1e-07);

        // Initialize working matrices
        let mut x = self.clone(); // Working copy that will be modified
        let mut qraux = vec![0.0; p];
        let mut pivot: Vec<usize> = (1..=p).collect(); // 1-based indices like R
        let mut work = vec![(0.0, 0.0); p]; // (current norm, original norm)

        // Compute the norms of the columns of x (initialization)
        if n > 0 {
            for j in 0..p {
                let mut norm = 0.0;
                for i in 0..n {
                    norm += x.get(i, j) * x.get(i, j);
                }
                norm = norm.sqrt();
                qraux[j] = norm;
                let original_norm = if norm == 0.0 { 1.0 } else { norm };
                work[j] = (norm, original_norm);
            }
        }

        let mut k = p + 1; // Will be decremented; the final rank is k - 1

        // Perform the Householder reduction of x
        for l in 0..lup {
            // Cycle columns from l to p until one with non-negligible norm is found.
            // A column is negligible if its norm has fallen below tol * original_norm.
            while l < k - 1 && qraux[l] < work[l].1 * tol {
                // Move column l to the end (it's negligible)
                let lp1 = l + 1;

                // Shift columns in x: x(i, l..p-1) = x(i, l+1..p)
                for i in 0..n {
                    let t = x.get(i, l);
                    for j in lp1..p {
                        x.set(i, j - 1, x.get(i, j));
                    }
                    x.set(i, p - 1, t);
                }

                // Shift the pivot, qraux, and work arrays
                let saved_pivot = pivot[l];
                let saved_qraux = qraux[l];
                let saved_work = work[l];

                for j in lp1..p {
                    pivot[j - 1] = pivot[j];
                    qraux[j - 1] = qraux[j];
                    work[j - 1] = work[j];
                }

                pivot[p - 1] = saved_pivot;
                qraux[p - 1] = saved_qraux;
                work[p - 1] = saved_work;

                k -= 1;
            }

            if l == n - 1 {
                // Last row - skip the transformation
                break;
            }

            // Compute the Householder transformation for column l:
            // nrmxl = norm of x[l:, l]
            let mut nrmxl = 0.0;
            for i in l..n {
                let val = x.get(i, l);
                nrmxl += val * val;
            }
            nrmxl = nrmxl.sqrt();

            if nrmxl == 0.0 {
                // Zero column - continue to the next one
                continue;
            }

            // Apply the sign for numerical stability (dsign in Fortran)
            let x_ll = x.get(l, l);
            if x_ll != 0.0 {
                nrmxl = nrmxl.copysign(x_ll);
            }

            // Scale the column
            let scale = 1.0 / nrmxl;
            for i in l..n {
                x.set(i, l, x.get(i, l) * scale);
            }
            x.set(l, l, 1.0 + x.get(l, l));

            // Apply the transformation to the remaining columns, updating the norms
            let lp1 = l + 1;
            if p > lp1 {
                for j in lp1..p {
                    // Compute t = -dot(x[l:, l], x[l:, j]) / x(l, l)
                    let mut dot = 0.0;
                    for i in l..n {
                        dot += x.get(i, l) * x.get(i, j);
                    }
                    let t = -dot / x.get(l, l);

                    // x[l:, j] = x[l:, j] + t * x[l:, l]
                    for i in l..n {
                        let val = x.get(i, j) + t * x.get(i, l);
                        x.set(i, j, val);
                    }

                    // Update the norm
                    if qraux[j] != 0.0 {
                        // tt = 1.0 - (x(l, j) / qraux[j])^2
                        let x_lj = x.get(l, j).abs();
                        let mut tt = 1.0 - (x_lj / qraux[j]).powi(2);
                        tt = tt.max(0.0);

                        // Recompute the norm directly if the downdate caused a large
                        // reduction (BDR mod 9/99); the tolerance is on the squared norm
                        if tt.abs() < 1e-6 {
                            let mut new_norm = 0.0;
                            for i in (l + 1)..n {
                                let val = x.get(i, j);
                                new_norm += val * val;
                            }
                            new_norm = new_norm.sqrt();
                            qraux[j] = new_norm;
                            work[j].0 = new_norm;
                        } else {
                            qraux[j] *= tt.sqrt();
                        }
                    }
                }
            }

            // Save the transformation
            qraux[l] = x.get(l, l);
            x.set(l, l, -nrmxl);
        }

        // Compute the final rank
        let rank = (k - 1).min(n);

        QRLinpack {
            qr: x,
            qraux,
            pivot,
            rank,
        }
    }

    /// Solves a linear system using the QR decomposition with column pivoting.
    ///
    /// This implements a least squares solver using the pivoted QR decomposition.
    /// For rank-deficient cases, coefficients corresponding to linearly dependent
    /// columns are set to `f64::NAN`.
    ///
    /// # Arguments
    ///
    /// * `qr_result` - QR decomposition from [`Matrix::qr_linpack`]
    /// * `y` - Right-hand side vector
    ///
    /// # Returns
    ///
    /// A vector of coefficients, or `None` if the system is exactly singular.
    ///
    /// # Algorithm
    ///
    /// This solver uses the standard QR approach:
    /// 1. Take the pivoted QR decomposition computed by [`Matrix::qr_linpack`]
    /// 2. Compute qty = Qᵀ * y by applying the stored Householder reflections
    /// 3. Solve R * coef = qty by back substitution on the upper triangle
    /// 4. Apply the pivot permutation to restore the original column order
    ///
    /// # Note
    ///
    /// The LINPACK QR algorithm stores R with mixed signs on the diagonal.
    /// The back substitution uses those diagonal entries as-is; the signs
    /// cancel in the computation, so no correction is needed.
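    ///
    /// # Example
    ///
    /// A sketch fitting the exact line y = 1 + 2x:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
    /// let y = vec![1.0, 3.0, 5.0];
    /// let qr = x.qr_linpack(None);
    /// let beta = x.qr_solve_linpack(&qr, &y).expect("full rank");
    /// assert!((beta[0] - 1.0).abs() < 1e-10); // intercept
    /// assert!((beta[1] - 2.0).abs() < 1e-10); // slope
    /// ```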
    pub fn qr_solve_linpack(&self, qr_result: &QRLinpack, y: &[f64]) -> Option<Vec<f64>> {
        let n = self.rows;
        let p = self.cols;
        let k = qr_result.rank;

        if y.len() != n {
            return None;
        }

        if k == 0 {
            return None;
        }

        // Step 1: Compute Qᵀ * y using the Householder vectors directly.
        // This is more efficient than reconstructing the full Q matrix.
        let mut qty = y.to_vec();

        for j in 0..k {
            // Check whether this Householder transformation is valid
            let r_jj = qr_result.qr.get(j, j);
            if r_jj == 0.0 {
                continue;
            }

            // Reconstruct the Householder vector v_j stored in qr[j:, j].
            // The storage convention:
            //   - qr[j,j] = -nrmxl (after the final overwrite)
            //   - qr[i,j] for i > j is the scaled Householder vector element
            //   - qraux[j] = 1 + original_x[j,j]/nrmxl (the first element of the
            //     scaled Householder vector, displaced by the final overwrite)
            let mut v = vec![0.0; n - j];
            for i in j..n {
                v[i - j] = qr_result.qr.get(i, j);
            }

            // The j-th element was overwritten with -nrmxl during the
            // decomposition; restore it from qraux
            let alpha = qr_result.qraux[j];
            if alpha != 0.0 {
                v[0] = alpha;
            }

            // Compute the norm of v
            let v_norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if v_norm < 1e-14 {
                continue;
            }

            // Compute dot = vᵀ * qty[j:]
            let mut dot = 0.0;
            for i in j..n {
                dot += v[i - j] * qty[i];
            }

            // Apply the Householder transformation:
            // qty[j:] = qty[j:] - 2 * v * (vᵀ * qty[j:]) / (vᵀ * v)
            let t = 2.0 * dot / (v_norm * v_norm);

            for i in j..n {
                qty[i] -= t * v[i - j];
            }
        }

        // Step 2: Back substitution on R (solve R * coef = qty).
        // R is stored in the upper triangle of qr. Its diagonal elements carry
        // mixed signs (from -nrmxl); we use them as-is since the signs cancel.
        let mut coef_permuted = vec![f64::NAN; p];

        // Relative tolerance for the singularity check (constant across rows,
        // so it is computed once before the loop)
        let max_abs = (0..k).map(|i| qr_result.qr.get(i, i).abs()).fold(0.0_f64, f64::max);
        let tolerance = 1e-14 * max_abs.max(1.0);

        for row in (0..k).rev() {
            let r_diag = qr_result.qr.get(row, row);
            if r_diag.abs() < tolerance {
                return None; // Singular
            }

            let mut sum = qty[row];
            for col in (row + 1)..k {
                sum -= qr_result.qr.get(row, col) * coef_permuted[col];
            }
            coef_permuted[row] = sum / r_diag;
        }

        // Step 3: Apply the pivot permutation to return coefficients in the
        // original column order. pivot[j] is 1-based, indicating which original
        // column now sits in position j.
        let mut result = vec![0.0; p];
        for j in 0..p {
            let original_col = qr_result.pivot[j] - 1; // Convert to 0-based
            result[original_col] = coef_permuted[j];
        }

        Some(result)
    }
}

/// Performs OLS regression using R's LINPACK QR algorithm.
///
/// This function is a drop-in replacement for `fit_ols` that uses the
/// R-compatible QR decomposition with column pivoting. It handles
/// rank-deficient matrices more gracefully than the standard QR decomposition.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including the intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - OLS coefficient vector (p elements)
/// * `None` - If the matrix is exactly singular or the dimensions don't match
///
/// # Note
///
/// For rank-deficient systems, this function uses the pivoted QR, which
/// automatically handles multicollinearity by selecting a linearly
/// independent subset of columns.
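///
/// # Example
///
/// A sketch of a simple line fit through this convenience wrapper:
///
/// ```
/// use linreg_core::linalg::{fit_ols_linpack, Matrix};
///
/// let x = Matrix::new(3, 2, vec![1.0, 0.0, 1.0, 1.0, 1.0, 2.0]);
/// let y = vec![1.0, 3.0, 5.0]; // y = 1 + 2x exactly
/// let beta = fit_ols_linpack(&y, &x).expect("solvable");
/// assert!((beta[0] - 1.0).abs() < 1e-10);
/// assert!((beta[1] - 2.0).abs() < 1e-10);
/// ```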
pub fn fit_ols_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let qr_result = x.qr_linpack(None);
    x.qr_solve_linpack(&qr_result, y)
}

/// Fits OLS and predicts using R's LINPACK QR with rank-deficient handling.
///
/// This function matches R's `lm.fit` behavior for rank-deficient cases:
/// coefficients for linearly dependent columns are set to NA (here `f64::NAN`),
/// and predictions are computed using only the valid coefficients and their
/// corresponding columns. This matches how R handles rank-deficient models in
/// prediction.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including the intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - Predictions (n elements)
/// * `None` - If the matrix is exactly singular or the dimensions don't match
///
/// # Algorithm
///
/// For rank-deficient systems (rank < p):
/// 1. Compute the QR decomposition with column pivoting
/// 2. Solve for the coefficients (linearly dependent columns come back as NaN)
/// 3. Compute predictions from the valid columns only, skipping any column
///    whose coefficient is NaN
///
/// This matches R's behavior, where `predict` on a rank-deficient fit handles
/// NA coefficients by excluding the corresponding columns from the prediction.
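///
/// # Example
///
/// A sketch with a duplicated predictor (a rank-deficient design); since y lies
/// in the column space, the predictions should still reproduce it:
///
/// ```
/// use linreg_core::linalg::{fit_and_predict_linpack, Matrix};
///
/// // The second and third columns are identical
/// let x = Matrix::new(3, 3, vec![
///     1.0, 0.0, 0.0,
///     1.0, 1.0, 1.0,
///     1.0, 2.0, 2.0,
/// ]);
/// let y = vec![1.0, 3.0, 5.0];
/// let pred = fit_and_predict_linpack(&y, &x).expect("solvable");
/// for (p, yi) in pred.iter().zip(y.iter()) {
///     assert!((p - yi).abs() < 1e-10);
/// }
/// ```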
pub fn fit_and_predict_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let n = x.rows;
    let p = x.cols;

    // Compute the QR decomposition
    let qr_result = x.qr_linpack(None);
    let k = qr_result.rank;

    // Solve for the coefficients (returned in the original column order,
    // since qr_solve_linpack already applies the pivot permutation)
    let beta = x.qr_solve_linpack(&qr_result, y)?;

    // Full rank - use standard prediction
    if k == p {
        return Some(x.mul_vec(&beta));
    }

    // Rank-deficient case: some columns are collinear and have NaN coefficients.
    // We compute predictions using only columns with valid (non-NaN) coefficients,
    // matching R's behavior where NA coefficients exclude columns from prediction.
    let mut pred = vec![0.0; n];

    for row in 0..n {
        let mut sum = 0.0;
        for j in 0..p {
            let b_val = beta[j];
            if b_val.is_nan() {
                continue; // Skip collinear columns (matches R's NA coefficient behavior)
            }
            sum += x.get(row, j) * b_val;
        }
        pred[row] = sum;
    }

    Some(pred)
}