linreg_core/linalg.rs
//! Minimal linear algebra module to replace the nalgebra dependency.
//!
//! Implements matrix operations, QR decomposition, and solvers needed for OLS.
//! Uses row-major storage for compatibility with statistical computing conventions.
//!
//! # Numerical Stability Considerations
//!
//! This implementation uses Householder QR decomposition with careful attention to
//! numerical stability:
//!
//! - **Sign convention**: Uses the numerically stable Householder sign choice
//!   (v = x + sgn(x₀)||x||e₁) to avoid cancellation
//! - **Tolerance checking**: Uses predefined tolerances to detect near-singular matrices
//! - **Zero-skipping**: Skips transformations when a column is already effectively zero
//!
//! # Scaling Recommendations
//!
//! For optimal numerical stability when predictor variables have vastly different
//! scales (e.g., one variable in millions, another in thousandths), consider
//! standardizing predictors before regression. Z-score standardization
//! (`x_scaled = (x - mean) / std`) is already applied in the VIF calculation.
//!
//! However, the current implementation handles typical OLS cases without explicit
//! scaling, as QR decomposition is generally stable for well-conditioned matrices.
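//!
//! As a standalone sketch (not part of this module's API; the data values are
//! made up for the example), z-score standardization of one predictor looks like:
//!
//! ```
//! let x = [1.0e6, 2.0e6, 3.0e6];
//! let mean = x.iter().sum::<f64>() / x.len() as f64;
//! // Sample variance (n - 1 denominator)
//! let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / (x.len() - 1) as f64;
//! let scaled: Vec<f64> = x.iter().map(|v| (v - mean) / var.sqrt()).collect();
//! assert_eq!(scaled, vec![-1.0, 0.0, 1.0]);
//! ```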

// ============================================================================
// Numerical Constants
// ============================================================================

/// Threshold for treating values as zero in QR decomposition.
/// Values below this are treated as zero to avoid numerical instability.
const QR_ZERO_TOLERANCE: f64 = 1e-12;

/// Threshold for detecting singular matrices during inversion.
/// Diagonal elements below this value indicate a near-singular matrix.
const SINGULAR_TOLERANCE: f64 = 1e-10;
/// A dense matrix stored in row-major order.
///
/// # Storage
///
/// Elements are stored in a single flat vector in row-major order:
/// `data[row * cols + col]`
///
/// # Example
///
/// ```
/// # use linreg_core::linalg::Matrix;
/// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
/// assert_eq!(m.rows, 2);
/// assert_eq!(m.cols, 3);
/// assert_eq!(m.get(0, 0), 1.0);
/// assert_eq!(m.get(1, 2), 6.0);
/// ```
#[derive(Clone, Debug)]
pub struct Matrix {
    /// Number of rows in the matrix
    pub rows: usize,
    /// Number of columns in the matrix
    pub cols: usize,
    /// Flat vector storing matrix elements in row-major order
    pub data: Vec<f64>,
}

impl Matrix {
    /// Creates a new matrix from the given dimensions and data.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rows * cols`.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `data` - Flat vector of elements in row-major order
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// assert_eq!(m.get(0, 0), 1.0);
    /// assert_eq!(m.get(0, 1), 2.0);
    /// assert_eq!(m.get(1, 0), 3.0);
    /// assert_eq!(m.get(1, 1), 4.0);
    /// ```
    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Self {
        assert_eq!(data.len(), rows * cols, "Data length must match dimensions");
        Matrix { rows, cols, data }
    }

    /// Creates a matrix filled with zeros.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::zeros(3, 2);
    /// assert_eq!(m.rows, 3);
    /// assert_eq!(m.cols, 2);
    /// assert_eq!(m.get(1, 1), 0.0);
    /// ```
    pub fn zeros(rows: usize, cols: usize) -> Self {
        Matrix {
            rows,
            cols,
            data: vec![0.0; rows * cols],
        }
    }

    // NOTE: Currently unused but kept as a reference implementation.
    // Uncomment if needed as a convenience constructor.
    /*
    /// Creates a matrix from a row-major slice.
    ///
    /// # Arguments
    ///
    /// * `rows` - Number of rows
    /// * `cols` - Number of columns
    /// * `slice` - Slice containing matrix elements in row-major order
    pub fn from_row_slice(rows: usize, cols: usize, slice: &[f64]) -> Self {
        Matrix::new(rows, cols, slice.to_vec())
    }
    */

    /// Gets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    pub fn get(&self, row: usize, col: usize) -> f64 {
        self.data[row * self.cols + col]
    }

    /// Sets the element at the specified row and column.
    ///
    /// # Arguments
    ///
    /// * `row` - Row index (0-based)
    /// * `col` - Column index (0-based)
    /// * `val` - Value to set
    pub fn set(&mut self, row: usize, col: usize, val: f64) {
        self.data[row * self.cols + col] = val;
    }

    /// Returns the transpose of this matrix.
    ///
    /// Swaps rows with columns: `result[col][row] = self[row][col]`.
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let t = m.transpose();
    /// assert_eq!(t.rows, 3);
    /// assert_eq!(t.cols, 2);
    /// assert_eq!(t.get(0, 1), 4.0);
    /// ```
    pub fn transpose(&self) -> Matrix {
        let mut t_data = vec![0.0; self.rows * self.cols];
        for r in 0..self.rows {
            for c in 0..self.cols {
                t_data[c * self.rows + r] = self.get(r, c);
            }
        }
        Matrix::new(self.cols, self.rows, t_data)
    }

    /// Performs matrix multiplication: `self * other`.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != other.rows`.
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let b = Matrix::new(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let c = a.matmul(&b);
    /// assert_eq!(c.rows, 2);
    /// assert_eq!(c.cols, 2);
    /// assert_eq!(c.get(0, 0), 22.0); // 1*1 + 2*3 + 3*5
    /// ```
    pub fn matmul(&self, other: &Matrix) -> Matrix {
        assert_eq!(
            self.cols, other.rows,
            "Dimension mismatch for multiplication"
        );
        let mut result = Matrix::zeros(self.rows, other.cols);

        for r in 0..self.rows {
            for c in 0..other.cols {
                let mut sum = 0.0;
                for k in 0..self.cols {
                    sum += self.get(r, k) * other.get(k, c);
                }
                result.set(r, c, sum);
            }
        }
        result
    }

    /// Multiplies this matrix by a vector (treating the vector as a column matrix).
    ///
    /// Computes `self * vec` where `vec` is treated as an n×1 column matrix.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != vec.len()`.
    ///
    /// # Arguments
    ///
    /// * `vec` - Vector to multiply by
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    /// let v = vec![1.0, 2.0, 3.0];
    /// let result = m.mul_vec(&v);
    /// assert_eq!(result.len(), 2);
    /// assert_eq!(result[0], 14.0); // 1*1 + 2*2 + 3*3
    /// ```
    #[allow(clippy::needless_range_loop)]
    pub fn mul_vec(&self, vec: &[f64]) -> Vec<f64> {
        assert_eq!(
            self.cols,
            vec.len(),
            "Dimension mismatch for matrix-vector multiplication"
        );
        let mut result = vec![0.0; self.rows];

        for r in 0..self.rows {
            let mut sum = 0.0;
            for c in 0..self.cols {
                sum += self.get(r, c) * vec[c];
            }
            result[r] = sum;
        }
        result
    }

    /// Computes the dot product of a column with a vector: `Σ(data[i * cols + col] * v[i])`.
    ///
    /// For a row-major matrix, this iterates through all rows at a fixed column.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    /// * `v` - Vector to dot with (must have length equal to rows)
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols` or `v.len() != rows`.
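    ///
    /// # Example
    ///
    /// A small doctest sketch (values chosen so the arithmetic is exact):
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// // Column 1 is [2.0, 4.0]; its dot product with [1.0, 1.0] is 6.0
    /// assert_eq!(m.col_dot(1, &[1.0, 1.0]), 6.0);
    /// ```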
    #[allow(clippy::needless_range_loop)]
    pub fn col_dot(&self, col: usize, v: &[f64]) -> f64 {
        assert!(col < self.cols, "Column index out of bounds");
        assert_eq!(
            self.rows,
            v.len(),
            "Vector length must match number of rows"
        );

        let mut sum = 0.0;
        for row in 0..self.rows {
            sum += self.get(row, col) * v[row];
        }
        sum
    }

    /// Performs a column AXPY operation in place: `v += alpha * column(col)`.
    ///
    /// This is the AXPY operation where the column is treated as a vector.
    /// For row-major storage, we iterate through rows at a fixed column.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    /// * `alpha` - Scaling factor for the column
    /// * `v` - Vector to modify in place (must have length equal to rows)
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols` or `v.len() != rows`.
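    ///
    /// # Example
    ///
    /// A small doctest sketch (values chosen so the arithmetic is exact):
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![1.0, 2.0, 3.0, 4.0]);
    /// let mut v = vec![10.0, 20.0];
    /// m.col_axpy_inplace(0, 2.0, &mut v); // v += 2 * [1.0, 3.0]
    /// assert_eq!(v, vec![12.0, 26.0]);
    /// ```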
    #[allow(clippy::needless_range_loop)]
    pub fn col_axpy_inplace(&self, col: usize, alpha: f64, v: &mut [f64]) {
        assert!(col < self.cols, "Column index out of bounds");
        assert_eq!(
            self.rows,
            v.len(),
            "Vector length must match number of rows"
        );

        for row in 0..self.rows {
            v[row] += alpha * self.get(row, col);
        }
    }

    /// Computes the squared L2 norm of a column: `Σ(data[i * cols + col]²)`.
    ///
    /// # Arguments
    ///
    /// * `col` - Column index
    ///
    /// # Panics
    ///
    /// Panics if `col >= cols`.
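    ///
    /// # Example
    ///
    /// A small doctest sketch using the 3-4-5 Pythagorean triple:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let m = Matrix::new(2, 2, vec![3.0, 0.0, 4.0, 0.0]);
    /// assert_eq!(m.col_norm2(0), 25.0); // 3² + 4²
    /// ```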
    #[allow(clippy::needless_range_loop)]
    pub fn col_norm2(&self, col: usize) -> f64 {
        assert!(col < self.cols, "Column index out of bounds");

        let mut sum = 0.0;
        for row in 0..self.rows {
            let val = self.get(row, col);
            sum += val * val;
        }
        sum
    }

    /// Adds a value to diagonal elements starting from a given index.
    ///
    /// This is useful for ridge regression, where we add `lambda * I` to `X^T X`
    /// but the intercept column should not be penalized.
    ///
    /// # Arguments
    ///
    /// * `alpha` - Value to add to diagonal elements
    /// * `start_index` - Starting diagonal index (0 = first diagonal element)
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    ///
    /// # Example
    ///
    /// For a 3×3 identity matrix with the intercept in the first column (unpenalized):
    /// ```text
    /// add_diagonal_in_place(lambda, 1) on:
    /// [1.0, 0.0, 0.0]    [1.0, 0.0,   0.0  ]
    /// [0.0, 1.0, 0.0] -> [0.0, 1.0+λ, 0.0  ]
    /// [0.0, 0.0, 1.0]    [0.0, 0.0,   1.0+λ]
    /// ```
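    ///
    /// The same operation as a doctest sketch (λ = 0.5 chosen for the example):
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let mut m = Matrix::identity(3);
    /// m.add_diagonal_in_place(0.5, 1);
    /// assert_eq!(m.get(0, 0), 1.0); // intercept diagonal untouched
    /// assert_eq!(m.get(1, 1), 1.5);
    /// assert_eq!(m.get(2, 2), 1.5);
    /// ```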
    pub fn add_diagonal_in_place(&mut self, alpha: f64, start_index: usize) {
        assert_eq!(self.rows, self.cols, "Matrix must be square");
        let n = self.rows;
        for i in start_index..n {
            let current = self.get(i, i);
            self.set(i, i, current + alpha);
        }
    }
}

// ============================================================================
// QR Decomposition
// ============================================================================

impl Matrix {
    /// Computes the QR decomposition using Householder reflections.
    ///
    /// Factorizes the matrix as `A = QR` where Q is orthogonal and R is upper triangular.
    ///
    /// # Requirements
    ///
    /// This implementation requires `rows >= cols` (a tall matrix). For OLS regression
    /// we always have more observations than predictors, so this requirement is satisfied.
    ///
    /// # Returns
    ///
    /// A tuple `(Q, R)` where:
    /// - `Q` is an orthogonal matrix (QᵀQ = I) of size m×m
    /// - `R` is an upper triangular matrix of size m×n
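    ///
    /// # Example
    ///
    /// A doctest sketch verifying the defining property A ≈ Q·R:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
    /// let (q, r) = a.qr();
    /// let qr = q.matmul(&r);
    /// for i in 0..3 {
    ///     for j in 0..2 {
    ///         assert!((qr.get(i, j) - a.get(i, j)).abs() < 1e-10);
    ///     }
    /// }
    /// ```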
    #[allow(clippy::needless_range_loop)]
    pub fn qr(&self) -> (Matrix, Matrix) {
        let m = self.rows;
        let n = self.cols;
        let mut q = Matrix::identity(m);
        let mut r = self.clone();

        for k in 0..n.min(m - 1) {
            // Create vector x = R[k:, k]
            let mut x = vec![0.0; m - k];
            for i in k..m {
                x[i - k] = r.get(i, k);
            }

            // Norm of x
            let norm_x: f64 = x.iter().map(|&v| v * v).sum::<f64>().sqrt();
            if norm_x < QR_ZERO_TOLERANCE {
                continue; // Column is already zero
            }

            // Create vector v = x + sgn(x₀) * ‖x‖ * e₁
            //
            // Numerical stability (Householder sign choice): following
            // Overton & Yu (2023), the stable choice is σ = -sgn(x₁) in
            // v = x - σ‖x‖e₁, i.e. v = x + sgn(x₁)‖x‖e₁. Equivalently,
            // u₁ = x₁ + sgn(x₁)‖x‖, which avoids cancellation.
            let sign = if x[0] >= 0.0 { 1.0 } else { -1.0 }; // sgn(x₀), with sgn(0) = +1
            let u1 = x[0] + sign * norm_x;

            // Normalize v to get the Householder vector
            let mut v = x; // Re-use storage
            v[0] = u1;

            let norm_v: f64 = v.iter().map(|&val| val * val).sum::<f64>().sqrt();
            for val in &mut v {
                *val /= norm_v;
            }

            // Apply Householder transformation to R: R = H * R = (I - 2vvᵀ)R = R - 2v(vᵀR)
            // Focus on submatrix R[k:, k:]
            for j in k..n {
                let mut dot = 0.0;
                for i in 0..m - k {
                    dot += v[i] * r.get(k + i, j);
                }

                for i in 0..m - k {
                    let val = r.get(k + i, j) - 2.0 * v[i] * dot;
                    r.set(k + i, j, val);
                }
            }

            // Update Q: Q = Q * H = Q(I - 2vvᵀ) = Q - 2(Qv)vᵀ
            // Focus on Q[:, k:]
            for i in 0..m {
                let mut dot = 0.0;
                for l in 0..m - k {
                    dot += q.get(i, k + l) * v[l];
                }

                for l in 0..m - k {
                    let val = q.get(i, k + l) - 2.0 * dot * v[l];
                    q.set(i, k + l, val);
                }
            }
        }

        (q, r)
    }

    /// Creates an identity matrix of the given size.
    ///
    /// # Arguments
    ///
    /// * `size` - Number of rows and columns (square matrix)
    ///
    /// # Example
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let i = Matrix::identity(3);
    /// assert_eq!(i.get(0, 0), 1.0);
    /// assert_eq!(i.get(1, 1), 1.0);
    /// assert_eq!(i.get(2, 2), 1.0);
    /// assert_eq!(i.get(0, 1), 0.0);
    /// ```
    pub fn identity(size: usize) -> Self {
        let mut data = vec![0.0; size * size];
        for i in 0..size {
            data[i * size + i] = 1.0;
        }
        Matrix::new(size, size, data)
    }

    /// Inverts an upper triangular matrix (such as R from QR decomposition).
    ///
    /// Uses back-substitution to compute the inverse. This is efficient for
    /// triangular matrices compared to general matrix inversion.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    ///
    /// # Returns
    ///
    /// `None` if the matrix is singular (has a zero or near-zero diagonal element).
    /// A matrix is considered singular if any diagonal element is below the
    /// internal tolerance (at least 1e-10; see below), which indicates the matrix
    /// does not have full rank.
    ///
    /// # Note
    ///
    /// For upper triangular matrices, singularity is equivalent to having a
    /// zero (or near-zero) diagonal element. This is much simpler to check than
    /// for general matrices, which would require computing the condition number.
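    ///
    /// # Example
    ///
    /// A doctest sketch (entries chosen so the inverse is exact in floating point):
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let r = Matrix::new(2, 2, vec![2.0, 1.0, 0.0, 4.0]);
    /// let inv = r.invert_upper_triangular().unwrap();
    /// assert_eq!(inv.get(0, 0), 0.5);
    /// assert_eq!(inv.get(0, 1), -0.125);
    /// assert_eq!(inv.get(1, 0), 0.0);
    /// assert_eq!(inv.get(1, 1), 0.25);
    /// ```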
    pub fn invert_upper_triangular(&self) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance, similar to LAPACK:
        // tolerance = max_diag * epsilon * n, where epsilon is on the order of
        // machine epsilon (~2.2e-16 for f64). Scaling with the largest diagonal
        // element handles matrices of different magnitudes better than a fixed
        // absolute tolerance; SINGULAR_TOLERANCE remains an absolute floor.
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        let epsilon = 2.0_f64 * f64::EPSILON; // ~4.4e-16 for f64
        let relative_tolerance = max_diag * epsilon * n as f64;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None; // Singular matrix - cannot invert
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Inverts an upper triangular matrix with a custom tolerance multiplier.
    ///
    /// The tolerance is computed as `max_diag * epsilon * n * tolerance_mult`
    /// (with `SINGULAR_TOLERANCE` as a floor). A higher `tolerance_mult` raises
    /// the singularity threshold, so more near-singular matrices are rejected
    /// (returning `None`) rather than inverted.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the tolerance (1.0 = standard)
    pub fn invert_upper_triangular_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square");

        // Check for singularity using a relative tolerance
        let max_diag: f64 = (0..n)
            .map(|i| self.get(i, i).abs())
            .fold(0.0_f64, |acc, val| acc.max(val));

        // Use a relative tolerance based on the maximum diagonal element
        let epsilon = 2.0_f64 * f64::EPSILON;
        let relative_tolerance = max_diag * epsilon * n as f64 * tolerance_mult;
        let tolerance = SINGULAR_TOLERANCE.max(relative_tolerance);

        for i in 0..n {
            if self.get(i, i).abs() < tolerance {
                return None;
            }
        }

        let mut inv = Matrix::zeros(n, n);

        for i in 0..n {
            inv.set(i, i, 1.0 / self.get(i, i));

            for j in (0..i).rev() {
                let mut sum = 0.0;
                for k in j + 1..=i {
                    sum += self.get(j, k) * inv.get(k, i);
                }
                inv.set(j, i, -sum / self.get(j, j));
            }
        }

        Some(inv)
    }

    /// Computes the inverse of a square matrix using QR decomposition.
    ///
    /// For an invertible matrix A, computes A⁻¹ such that A * A⁻¹ = I.
    /// Uses QR decomposition for numerical stability.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square (i.e., `self.rows != self.cols`).
    /// Check dimensions before calling if the matrix shape is not guaranteed.
    ///
    /// # Returns
    ///
    /// Returns `Some(inverse)` if the matrix is invertible, or `None` if
    /// the matrix is singular (non-invertible).
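    ///
    /// # Example
    ///
    /// A doctest sketch checking A · A⁻¹ ≈ I on a well-conditioned 2×2 matrix:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let a = Matrix::new(2, 2, vec![4.0, 7.0, 2.0, 6.0]);
    /// let inv = a.invert().unwrap();
    /// let prod = a.matmul(&inv);
    /// for i in 0..2 {
    ///     for j in 0..2 {
    ///         let expected = if i == j { 1.0 } else { 0.0 };
    ///         assert!((prod.get(i, j) - expected).abs() < 1e-12);
    ///     }
    /// }
    /// ```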
    pub fn invert(&self) -> Option<Matrix> {
        let n = self.rows;
        assert_eq!(n, self.cols, "Matrix must be square for inversion");

        // Use QR decomposition: A = Q * R
        let (q, r) = self.qr();

        // Compute R⁻¹ (upper triangular inverse)
        let r_inv = r.invert_upper_triangular()?;

        // A⁻¹ = R⁻¹ * Qᵀ
        let q_transpose = q.transpose();
        let mut result = Matrix::zeros(n, n);

        for i in 0..n {
            for j in 0..n {
                let mut sum = 0.0;
                for k in 0..n {
                    sum += r_inv.get(i, k) * q_transpose.get(k, j);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }

    /// Computes the inverse of X'X via the QR decomposition of X (R's `chol2inv`).
    ///
    /// This is equivalent to computing `(X'X)^(-1)` using the QR decomposition of X,
    /// mirroring R's `chol2inv` function, which is used for numerical stability in
    /// recursive residuals.
    ///
    /// # Requirements
    ///
    /// The matrix must have `rows >= cols`.
    ///
    /// # Returns
    ///
    /// `Some((X'X)^(-1))` if X has full rank, `None` otherwise.
    ///
    /// # Algorithm
    ///
    /// Given the QR decomposition X = QR where R is upper triangular:
    /// 1. Extract the upper p×p portion of R (denoted R₁)
    /// 2. Invert R₁ (upper triangular inverse)
    /// 3. Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
    ///
    /// This works because X'X = R'Q'QR = R'R, so R₁ contains the Cholesky factor.
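    ///
    /// # Example
    ///
    /// A doctest sketch comparing against the direct computation `(XᵀX)⁻¹`:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
    /// let via_qr = x.chol2inv_from_qr().unwrap();
    /// let direct = x.transpose().matmul(&x).invert().unwrap();
    /// for i in 0..2 {
    ///     for j in 0..2 {
    ///         assert!((via_qr.get(i, j) - direct.get(i, j)).abs() < 1e-8);
    ///     }
    /// }
    /// ```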
    pub fn chol2inv_from_qr(&self) -> Option<Matrix> {
        self.chol2inv_from_qr_with_tolerance(1.0)
    }

    /// Computes the inverse of X'X from the QR decomposition, with a custom tolerance.
    ///
    /// Similar to `chol2inv_from_qr` but allows specifying a tolerance multiplier
    /// for handling near-singular matrices.
    ///
    /// # Arguments
    ///
    /// * `tolerance_mult` - Multiplier for the singularity tolerance (higher values
    ///   reject more near-singular matrices)
    pub fn chol2inv_from_qr_with_tolerance(&self, tolerance_mult: f64) -> Option<Matrix> {
        let p = self.cols;

        // QR decomposition: X = QR.
        // For X (m×n, m ≥ n), R is m×n upper triangular; the upper n×n block
        // contains the meaningful values.
        let (_, r_full) = self.qr();

        // Extract the upper p×p portion of R. For tall matrices (m > p), R has
        // zeros below the diagonal in its first p rows; for square matrices
        // (m = p), R is already p×p upper triangular.
        let mut r1 = Matrix::zeros(p, p);
        for i in 0..p {
            // Row i of R₁ is the upper triangular part of row i of R
            // (columns i..p, which includes the diagonal)
            for j in i..p {
                r1.set(i, j, r_full.get(i, j));
            }
        }

        // Invert R₁ (upper triangular) with the custom tolerance
        let r1_inv = r1.invert_upper_triangular_with_tolerance(tolerance_mult)?;

        // Compute (X'X)^(-1) = R₁^(-1) × R₁^(-T)
        let mut result = Matrix::zeros(p, p);
        for i in 0..p {
            for j in 0..p {
                // result[i,j] = Σₖ R1_inv[i,k] * R1_inv[j,k]
                // R1_inv is upper triangular, but we iterate the full range
                let mut sum = 0.0;
                for k in 0..p {
                    sum += r1_inv.get(i, k) * r1_inv.get(j, k);
                }
                result.set(i, j, sum);
            }
        }

        Some(result)
    }
}

// ============================================================================
// Vector Helper Functions
// ============================================================================

/// Computes the arithmetic mean of a slice of f64 values.
///
/// Returns 0.0 for empty slices.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_mean;
///
/// assert_eq!(vec_mean(&[1.0, 2.0, 3.0, 4.0, 5.0]), 3.0);
/// assert_eq!(vec_mean(&[]), 0.0);
/// ```
///
/// # Arguments
///
/// * `v` - Slice of values
pub fn vec_mean(v: &[f64]) -> f64 {
    if v.is_empty() {
        return 0.0;
    }
    v.iter().sum::<f64>() / v.len() as f64
}

/// Computes element-wise subtraction of two slices: `a - b`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_sub;
///
/// let a = vec![5.0, 4.0, 3.0];
/// let b = vec![1.0, 1.0, 1.0];
/// let result = vec_sub(&a, &b);
/// assert_eq!(result, vec![4.0, 3.0, 2.0]);
/// ```
///
/// # Arguments
///
/// * `a` - Minuend slice
/// * `b` - Subtrahend slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
pub fn vec_sub(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_sub: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x - y).collect()
}

/// Computes the dot product of two slices: `Σ(a[i] * b[i])`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_dot;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_dot(&a, &b), 32.0); // 1*4 + 2*5 + 3*6
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
pub fn vec_dot(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "vec_dot: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Computes element-wise addition of two slices: `a + b`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_add;
///
/// let a = vec![1.0, 2.0, 3.0];
/// let b = vec![4.0, 5.0, 6.0];
/// assert_eq!(vec_add(&a, &b), vec![5.0, 7.0, 9.0]);
/// ```
///
/// # Arguments
///
/// * `a` - First slice
/// * `b` - Second slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
pub fn vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), b.len(), "vec_add: slice lengths must match");
    a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
}

/// Computes a scaled vector addition in place: `dst += alpha * src`.
///
/// This is the classic BLAS AXPY operation.
///
/// # Arguments
///
/// * `dst` - Destination slice (modified in place)
/// * `alpha` - Scaling factor for `src`
/// * `src` - Source slice
///
/// # Panics
///
/// Panics if the slices have different lengths.
///
/// # Example
///
/// ```
/// use linreg_core::linalg::vec_axpy_inplace;
///
/// let mut dst = vec![1.0, 2.0, 3.0];
/// let src = vec![0.5, 0.5, 0.5];
/// vec_axpy_inplace(&mut dst, 2.0, &src);
/// assert_eq!(dst, vec![2.0, 3.0, 4.0]); // [1+2*0.5, 2+2*0.5, 3+2*0.5]
/// ```
pub fn vec_axpy_inplace(dst: &mut [f64], alpha: f64, src: &[f64]) {
    assert_eq!(
        dst.len(),
        src.len(),
        "vec_axpy_inplace: slice lengths must match"
    );
    for (d, &s) in dst.iter_mut().zip(src.iter()) {
        *d += alpha * s;
    }
}

/// Scales a vector in place: `v *= alpha`.
///
/// # Arguments
///
/// * `v` - Vector to scale (modified in place)
/// * `alpha` - Scaling factor
///
/// # Example
///
/// ```
/// use linreg_core::linalg::vec_scale_inplace;
///
/// let mut v = vec![1.0, 2.0, 3.0];
/// vec_scale_inplace(&mut v, 2.5);
/// assert_eq!(v, vec![2.5, 5.0, 7.5]);
/// ```
pub fn vec_scale_inplace(v: &mut [f64], alpha: f64) {
    for val in v.iter_mut() {
        *val *= alpha;
    }
}

/// Returns a scaled copy of a vector: `v * alpha`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_scale;
///
/// let v = vec![1.0, 2.0, 3.0];
/// let scaled = vec_scale(&v, 2.5);
/// assert_eq!(scaled, vec![2.5, 5.0, 7.5]);
/// // The original is unchanged
/// assert_eq!(v, vec![1.0, 2.0, 3.0]);
/// ```
///
/// # Arguments
///
/// * `v` - Vector to scale
/// * `alpha` - Scaling factor
pub fn vec_scale(v: &[f64], alpha: f64) -> Vec<f64> {
    v.iter().map(|&x| x * alpha).collect()
}

/// Computes the L2 norm (Euclidean norm) of a vector: `sqrt(Σ(v[i]²))`.
///
/// # Examples
///
/// ```
/// use linreg_core::linalg::vec_l2_norm;
///
/// // Pythagorean triple: 3-4-5
/// assert_eq!(vec_l2_norm(&[3.0, 4.0]), 5.0);
/// // Unit vector
/// assert_eq!(vec_l2_norm(&[1.0, 0.0, 0.0]), 1.0);
/// ```
///
/// # Arguments
///
/// * `v` - Vector slice
pub fn vec_l2_norm(v: &[f64]) -> f64 {
    v.iter().map(|&x| x * x).sum::<f64>().sqrt()
}

/// Computes the maximum absolute value in a vector.
///
/// # Arguments
///
/// * `v` - Vector slice
///
/// # Example
///
/// ```
/// use linreg_core::linalg::vec_max_abs;
///
/// assert_eq!(vec_max_abs(&[1.0, -5.0, 3.0]), 5.0);
/// assert_eq!(vec_max_abs(&[-2.5, -1.5]), 2.5);
/// ```
pub fn vec_max_abs(v: &[f64]) -> f64 {
    v.iter().map(|&x| x.abs()).fold(0.0_f64, f64::max)
}

// ============================================================================
// R-Compatible QR Decomposition (LINPACK dqrdc2 with Column Pivoting)
// ============================================================================

/// QR decomposition result using R's LINPACK dqrdc2 algorithm.
///
/// This implements the QR decomposition with column pivoting as used by R's
/// `qr()` function with `LAPACK=FALSE`. The algorithm is a modification of
/// LINPACK's DQRDC that:
/// - Uses Householder transformations
/// - Implements limited column pivoting based on 2-norms of reduced columns
/// - Moves columns with near-zero norm to the right-hand edge
/// - Computes the rank (number of linearly independent columns)
///
/// # Fields
///
/// * `qr` - The QR factorization (upper triangle contains R, below diagonal
///   contains Householder vector information)
/// * `qraux` - Auxiliary information for recovering the orthogonal part Q
/// * `pivot` - Column permutation: `pivot\[j\]` contains the original column index
///   now in column j
/// * `rank` - Number of linearly independent columns (the computed rank)
#[derive(Clone, Debug)]
pub struct QRLinpack {
    /// QR factorization matrix (same dimensions as input)
    pub qr: Matrix,
    /// Auxiliary information for Q recovery
    pub qraux: Vec<f64>,
    /// Column pivot vector (1-based indices, like R)
    pub pivot: Vec<usize>,
    /// Computed rank (number of linearly independent columns)
    pub rank: usize,
}

impl Matrix {
    /// Computes QR decomposition using R's LINPACK dqrdc2 algorithm with column pivoting.
    ///
    /// This is a port of R's dqrdc2.f, which is a modification of LINPACK's DQRDC.
    /// The algorithm:
    /// 1. Uses Householder transformations for QR factorization
    /// 2. Implements limited column pivoting based on column 2-norms
    /// 3. Moves columns with near-zero norm to the right-hand edge
    /// 4. Computes the rank (number of linearly independent columns)
    ///
    /// # Arguments
    ///
    /// * `tol` - Tolerance for determining linear independence. Default is 1e-7 (R's default).
    ///   Columns with norm < tol * original_norm are considered negligible.
    ///
    /// # Returns
    ///
    /// A [`QRLinpack`] struct containing the QR factorization, auxiliary information,
    /// pivot vector, and computed rank.
    ///
    /// # Algorithm Details
    ///
    /// The decomposition is A * P = Q * R where:
    /// - P is the permutation matrix coded by `pivot`
    /// - Q is orthogonal (m × m)
    /// - R is upper triangular in the first `rank` rows
    ///
    /// The `qr` matrix contains:
    /// - Upper triangle: the R matrix (if pivoting was performed, this is R of the permuted matrix)
    /// - Below diagonal: Householder vector information
    ///
    /// # Reference
    ///
    /// - R source: src/appl/dqrdc2.f
    /// - LINPACK documentation: <https://www.netlib.org/linpack/dqrdc.f>
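    ///
    /// # Example
    ///
    /// A doctest sketch with an exactly collinear column (the third column
    /// duplicates the second), so the computed rank is 2 rather than 3:
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 3, vec![
    ///     1.0, 2.0, 2.0,
    ///     1.0, 3.0, 3.0,
    ///     1.0, 4.0, 4.0,
    /// ]);
    /// let qr = x.qr_linpack(None);
    /// assert_eq!(qr.rank, 2);
    /// ```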
    pub fn qr_linpack(&self, tol: Option<f64>) -> QRLinpack {
        let n = self.rows;
        let p = self.cols;
        let lup = n.min(p);

        // Default tolerance matches R's qr.default: tol = 1e-07
        let tol = tol.unwrap_or(1e-07);

        // Initialize working matrices
        let mut x = self.clone(); // Working copy that will be modified
        let mut qraux = vec![0.0; p];
        let mut pivot: Vec<usize> = (1..=p).collect(); // 1-based indices like R
        let mut work = vec![(0.0, 0.0); p]; // (current reduced norm, original norm)

        // Compute the norms of the columns of x (initialization)
        if n > 0 {
            for j in 0..p {
                let mut norm = 0.0;
                for i in 0..n {
                    norm += x.get(i, j) * x.get(i, j);
                }
                norm = norm.sqrt();
                qraux[j] = norm;
                let original_norm = if norm == 0.0 { 1.0 } else { norm };
                work[j] = (norm, original_norm);
            }
        }

        let mut k = p + 1; // Will be decremented to get the final rank

        // Perform the Householder reduction of x
        for l in 0..lup {
            // Cycle columns from l to p until one with non-negligible norm is found.
            // A column is negligible if its norm has fallen below tol * original_norm.
            while l < k - 1 && qraux[l] < work[l].1 * tol {
                // Move column l to the end (it's negligible)
                let lp1 = l + 1;

                // Shift columns in x: x(i, l..p-1) = x(i, l+1..p)
                for i in 0..n {
                    let t = x.get(i, l);
                    for j in lp1..p {
                        x.set(i, j - 1, x.get(i, j));
                    }
                    x.set(i, p - 1, t);
                }

                // Shift pivot, qraux, and work arrays
                let saved_pivot = pivot[l];
                let saved_qraux = qraux[l];
                let saved_work = work[l];

                for j in lp1..p {
                    pivot[j - 1] = pivot[j];
                    qraux[j - 1] = qraux[j];
                    work[j - 1] = work[j];
                }

                pivot[p - 1] = saved_pivot;
                qraux[p - 1] = saved_qraux;
                work[p - 1] = saved_work;

                k -= 1;
            }

            if l == n - 1 {
                // Last row - no further transformations are needed
                break;
            }

            // Compute the Householder transformation for column l
            // nrmxl = norm of x[l:, l]
            let mut nrmxl = 0.0;
            for i in l..n {
                let val = x.get(i, l);
                nrmxl += val * val;
            }
            nrmxl = nrmxl.sqrt();

            if nrmxl == 0.0 {
                // Zero column - continue to the next
                continue;
            }

            // Apply sign for numerical stability
            let x_ll = x.get(l, l);
            if x_ll != 0.0 {
                nrmxl = nrmxl.copysign(x_ll);
            }

            // Scale the column
            let scale = 1.0 / nrmxl;
            for i in l..n {
                x.set(i, l, x.get(i, l) * scale);
            }
            x.set(l, l, 1.0 + x.get(l, l));

            // Apply the transformation to the remaining columns, updating the norms
            let lp1 = l + 1;
            if p > lp1 {
                for j in lp1..p {
                    // Compute t = -dot(x[l:, l], x[l:, j]) / x(l, l)
                    let mut dot = 0.0;
                    for i in l..n {
                        dot += x.get(i, l) * x.get(i, j);
                    }
                    let t = -dot / x.get(l, l);

                    // x[l:, j] = x[l:, j] + t * x[l:, l]
                    for i in l..n {
                        let val = x.get(i, j) + t * x.get(i, l);
                        x.set(i, j, val);
                    }

                    // Update the norm
                    if qraux[j] != 0.0 {
                        // tt = 1.0 - (x(l, j) / qraux[j])^2
                        let x_lj = x.get(l, j).abs();
                        let mut tt = 1.0 - (x_lj / qraux[j]).powi(2);
                        tt = tt.max(0.0);

                        // Recompute the norm directly if there was a large reduction
                        // (BDR mod 9/99); the tolerance here is on the squared norm
                        if tt.abs() < 1e-6 {
                            let mut new_norm = 0.0;
                            for i in (l + 1)..n {
                                let val = x.get(i, j);
                                new_norm += val * val;
                            }
                            new_norm = new_norm.sqrt();
                            qraux[j] = new_norm;
                            work[j].0 = new_norm;
                        } else {
                            qraux[j] *= tt.sqrt();
                        }
                    }
                }
            }

            // Save the transformation
            qraux[l] = x.get(l, l);
            x.set(l, l, -nrmxl);
        }

        // Compute the final rank
        let rank = (k - 1).min(n);

        QRLinpack {
            qr: x,
            qraux,
            pivot,
            rank,
        }
    }

    /// Solves a linear system using the QR decomposition with column pivoting.
    ///
    /// This implements a least squares solver using the pivoted QR decomposition.
    /// For rank-deficient cases, coefficients corresponding to linearly dependent
    /// columns are set to `f64::NAN`.
    ///
    /// # Arguments
    ///
    /// * `qr_result` - QR decomposition from [`Matrix::qr_linpack`]
    /// * `y` - Right-hand side vector
    ///
    /// # Returns
    ///
    /// A vector of coefficients, or `None` if the system is exactly singular.
    ///
    /// # Algorithm
    ///
    /// This solver uses the standard QR approach:
    /// 1. Compute qty = Q^T * y by applying the stored Householder transformations
    /// 2. Solve R * coef = qty by back substitution on the upper triangle of `qr`
    /// 3. Apply the pivot permutation to restore the original column order
    ///
    /// # Note
    ///
    /// The LINPACK QR algorithm stores R with mixed signs on the diagonal.
    /// The diagonal entries are used as-is: the signs introduced in Qᵀy and in
    /// the back substitution cancel, so no sign correction is required.
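    ///
    /// # Example
    ///
    /// A doctest sketch on an exactly consistent system (y = 2x, zero intercept):
    ///
    /// ```
    /// # use linreg_core::linalg::Matrix;
    /// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
    /// let y = vec![2.0, 4.0, 6.0];
    /// let qr = x.qr_linpack(None);
    /// let beta = x.qr_solve_linpack(&qr, &y).unwrap();
    /// assert!(beta[0].abs() < 1e-10);         // intercept ≈ 0
    /// assert!((beta[1] - 2.0).abs() < 1e-10); // slope ≈ 2
    /// ```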
    #[allow(clippy::needless_range_loop)]
    pub fn qr_solve_linpack(&self, qr_result: &QRLinpack, y: &[f64]) -> Option<Vec<f64>> {
        let n = self.rows;
        let p = self.cols;
        let k = qr_result.rank;

        if y.len() != n {
            return None;
        }

        if k == 0 {
            return None;
        }

        // Step 1: Compute Q^T * y using the Householder vectors directly.
        // This is more efficient than reconstructing the full Q matrix.
        let mut qty = y.to_vec();

        for j in 0..k {
            // Check if this Householder transformation is valid
            let r_jj = qr_result.qr.get(j, j);
            if r_jj == 0.0 {
                continue;
            }

            // Reconstruct the scaled Householder vector v_j from qr[j:, j].
            // The storage convention:
            //   - qr[j,j] = -nrmxl (after the final overwrite)
            //   - qr[i,j] for i > j is the scaled Householder vector element
            //   - qraux[j] = 1 + original_x[j,j]/nrmxl (the overwritten first element)
            // The actual unit vector is v / ||v||.
            let mut v = vec![0.0; n - j];
            for i in j..n {
                v[i - j] = qr_result.qr.get(i, j);
            }

            // The j-th element was overwritten during the decomposition;
            // restore it from qraux
            let alpha = qr_result.qraux[j];
            if alpha != 0.0 {
                v[0] = alpha;
            }

            // Compute the norm of v
            let v_norm: f64 = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if v_norm < 1e-14 {
                continue;
            }

            // Compute dot = v^T * qty[j:]
            let mut dot = 0.0;
            for i in j..n {
                dot += v[i - j] * qty[i];
            }

            // Apply the Householder transformation:
            // qty[j:] = qty[j:] - 2 * v * (v^T * qty[j:]) / (v^T * v)
            let t = 2.0 * dot / (v_norm * v_norm);

            for i in j..n {
                qty[i] -= t * v[i - j];
            }
        }

        // Step 2: Back substitution on R (solve R * coef = qty).
        // The R matrix is stored in the upper triangle of qr. The diagonal
        // elements of R may be negative (from -nrmxl); they are used as-is
        // since the signs cancel out in the computation.
        let mut coef_permuted = vec![f64::NAN; p];

        // Relative tolerance for the singularity check
        let max_abs = (0..k)
            .map(|i| qr_result.qr.get(i, i).abs())
            .fold(0.0_f64, f64::max);
        let tolerance = 1e-14 * max_abs.max(1.0);

        for row in (0..k).rev() {
            let r_diag = qr_result.qr.get(row, row);
            if r_diag.abs() < tolerance {
                return None; // Singular
            }

            let mut sum = qty[row];
            for col in (row + 1)..k {
                sum -= qr_result.qr.get(row, col) * coef_permuted[col];
            }
            coef_permuted[row] = sum / r_diag;
        }

        // Step 3: Apply the pivot permutation to get coefficients in the original order.
        // pivot[j] is 1-based, indicating which original column is now in position j.
        let mut result = vec![0.0; p];
        for j in 0..p {
            let original_col = qr_result.pivot[j] - 1; // Convert to 0-based
            result[original_col] = coef_permuted[j];
        }

        Some(result)
    }
}

/// Performs OLS regression using R's LINPACK QR algorithm.
///
/// This function is a drop-in replacement for `fit_ols` that uses the
/// R-compatible QR decomposition with column pivoting. It handles
/// rank-deficient matrices more gracefully than the standard QR decomposition.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - OLS coefficient vector (p elements)
/// * `None` - If the matrix is exactly singular or dimensions don't match
///
/// # Note
///
/// For rank-deficient systems, this function uses the pivoted QR, which
/// automatically handles multicollinearity by selecting a linearly
/// independent subset of columns.
///
/// # Example
///
/// ```
/// # use linreg_core::linalg::{fit_ols_linpack, Matrix};
/// let y = vec![2.0, 4.0, 6.0];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
///
/// let beta = fit_ols_linpack(&y, &x).unwrap();
/// assert_eq!(beta.len(), 2); // Intercept and slope
/// ```
pub fn fit_ols_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let qr_result = x.qr_linpack(None);
    x.qr_solve_linpack(&qr_result, y)
}

/// Fits OLS and predicts using R's LINPACK QR with rank-deficient handling.
///
/// This function matches R's `lm.fit` behavior for rank-deficient cases:
/// coefficients for linearly dependent columns are set to NA, and predictions
/// are computed using only the valid (non-NA) coefficients and their corresponding
/// columns. This matches how R handles rank-deficient models in prediction.
///
/// # Arguments
///
/// * `y` - Response variable (n observations)
/// * `x` - Design matrix (n rows, p columns including intercept)
///
/// # Returns
///
/// * `Some(Vec<f64>)` - Predictions (n elements)
/// * `None` - If the matrix is exactly singular or dimensions don't match
///
/// # Algorithm
///
/// For rank-deficient systems (rank < p):
/// 1. Compute the QR decomposition with column pivoting
/// 2. Get coefficients (linearly dependent columns will have NaN)
/// 3. Compute predictions using only the columns with valid coefficients
///
/// This matches R's behavior, where `predict(lm.fit(...))` handles NA coefficients
/// by excluding the corresponding columns from the prediction.
///
/// # Example
///
/// ```
/// # use linreg_core::linalg::{fit_and_predict_linpack, Matrix};
/// let y = vec![2.0, 4.0, 6.0];
/// let x = Matrix::new(3, 2, vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);
///
/// let preds = fit_and_predict_linpack(&y, &x).unwrap();
/// assert_eq!(preds.len(), 3); // One prediction per observation
/// ```
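///
/// A rank-deficient doctest sketch (the third column duplicates the second, so
/// its coefficient becomes NaN and is skipped in prediction):
///
/// ```
/// # use linreg_core::linalg::{fit_and_predict_linpack, Matrix};
/// let y = vec![4.0, 6.0, 8.0];
/// let x = Matrix::new(3, 3, vec![
///     1.0, 2.0, 2.0,
///     1.0, 3.0, 3.0,
///     1.0, 4.0, 4.0,
/// ]);
/// let preds = fit_and_predict_linpack(&y, &x).unwrap();
/// for (p, t) in preds.iter().zip(y.iter()) {
///     assert!((p - t).abs() < 1e-8);
/// }
/// ```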
#[allow(clippy::needless_range_loop)]
pub fn fit_and_predict_linpack(y: &[f64], x: &Matrix) -> Option<Vec<f64>> {
    let n = x.rows;
    let p = x.cols;

    // Compute the QR decomposition
    let qr_result = x.qr_linpack(None);
    let k = qr_result.rank;

    // Solve for coefficients (already mapped back to the original column order)
    let beta = x.qr_solve_linpack(&qr_result, y)?;

    // Check for rank deficiency
    if k == p {
        // Full rank - use standard prediction
        return Some(x.mul_vec(&beta));
    }

    // Rank-deficient case: some columns are collinear and have NaN coefficients.
    // We compute predictions using only columns with valid (non-NaN) coefficients,
    // matching R's behavior where NA coefficients exclude columns from prediction.
    let mut pred = vec![0.0; n];

    for row in 0..n {
        let mut sum = 0.0;
        for j in 0..p {
            let b_val = beta[j];
            if b_val.is_nan() {
                continue; // Skip collinear columns (matches R's NA coefficient behavior)
            }
            sum += x.get(row, j) * b_val;
        }
        pred[row] = sum;
    }

    Some(pred)
}