Skip to main content

numra_linalg/
lu.rs

1//! LU factorization and solver.
2//!
3//! Author: Moussa Leblouba
4//! Date: 8 February 2026
5//! Modified: 2 May 2026
6
7use crate::matrix::{DenseMatrix, Matrix};
8use crate::Scalar;
9use faer::linalg::solvers::PartialPivLu;
10use faer::prelude::*;
11use faer::{ComplexField, Conjugate, Entity, Mat, SimpleEntity};
12use numra_core::LinalgError;
13
14/// LU factorization of a matrix.
15///
16/// Caches the decomposition computed in `new()` and reuses it in every `solve()` call,
17/// avoiding redundant recomputation.
18pub struct LUFactorization<S: Scalar + Entity> {
19    /// Cached LU decomposition from faer
20    lu: PartialPivLu<S>,
21    n: usize,
22}
23
24impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> LUFactorization<S> {
25    /// Compute LU factorization of a matrix.
26    ///
27    /// The decomposition is cached and reused for subsequent `solve()` calls.
28    pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
29        if !matrix.is_square() {
30            return Err(LinalgError::NotSquare {
31                nrows: matrix.rows(),
32                ncols: matrix.cols(),
33            });
34        }
35
36        let n = matrix.rows();
37        // Compute and cache the LU factorization
38        let lu = matrix.as_faer().partial_piv_lu();
39
40        Ok(Self { lu, n })
41    }
42
43    /// Dimension of the factorized matrix.
44    pub fn dim(&self) -> usize {
45        self.n
46    }
47
48    /// Solve Ax = b using the cached LU factorization.
49    pub fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
50        if b.len() != self.n {
51            return Err(LinalgError::DimensionMismatch {
52                expected: (self.n, 1),
53                actual: (b.len(), 1),
54            });
55        }
56
57        // Create column vector from b
58        let mut b_mat = Mat::zeros(self.n, 1);
59        for (i, &val) in b.iter().enumerate() {
60            b_mat.write(i, 0, val);
61        }
62
63        // Solve using the cached LU factorization (no recomputation)
64        let x_mat = self.lu.solve(&b_mat);
65
66        // Extract result
67        let mut x = Vec::with_capacity(self.n);
68        for i in 0..self.n {
69            x.push(x_mat.read(i, 0));
70        }
71
72        Ok(x)
73    }
74
75    /// Solve Ax = b in-place (b is overwritten with x).
76    pub fn solve_inplace(&self, b: &mut [S]) -> Result<(), LinalgError> {
77        let x = self.solve(b)?;
78        b.copy_from_slice(&x);
79        Ok(())
80    }
81
82    /// Solve multiple right-hand sides: AX = B.
83    pub fn solve_multi(&self, b: &[S], nrhs: usize) -> Result<Vec<S>, LinalgError> {
84        if b.len() != self.n * nrhs {
85            return Err(LinalgError::DimensionMismatch {
86                expected: (self.n, nrhs),
87                actual: (b.len(), 1),
88            });
89        }
90
91        // Create matrix from b (column-major)
92        let mut b_mat = Mat::zeros(self.n, nrhs);
93        for j in 0..nrhs {
94            for i in 0..self.n {
95                b_mat.write(i, j, b[j * self.n + i]);
96            }
97        }
98
99        // Solve using the cached LU factorization (no recomputation)
100        let x_mat = self.lu.solve(&b_mat);
101
102        // Extract result (column-major)
103        let mut x = vec![S::ZERO; self.n * nrhs];
104        for j in 0..nrhs {
105            for i in 0..self.n {
106                x[j * self.n + i] = x_mat.read(i, j);
107            }
108        }
109
110        Ok(x)
111    }
112}
113
114/// Trait for types that can solve linear systems.
115pub trait LUSolver<S: Scalar> {
116    /// Solve Ax = b.
117    fn lu_solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError>;
118
119    /// Compute and store LU factorization.
120    fn lu_factor(&self) -> Result<LUFactorization<S>, LinalgError>
121    where
122        S: Entity + SimpleEntity + Conjugate<Canonical = S> + ComplexField;
123}
124
125impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> LUSolver<S>
126    for DenseMatrix<S>
127{
128    fn lu_solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
129        self.solve(b)
130    }
131
132    fn lu_factor(&self) -> Result<LUFactorization<S>, LinalgError> {
133        LUFactorization::new(self)
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    #[test]
    fn test_lu_factorization() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(3, 3);
        m.set(0, 0, 2.0);
        m.set(1, 1, 3.0);
        m.set(2, 2, 4.0);

        let lu = LUFactorization::new(&m).unwrap();
        assert_eq!(lu.dim(), 3);
    }

    #[test]
    fn test_lu_solve() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let lu = LUFactorization::new(&m).unwrap();
        let b = vec![5.0, 11.0];
        let x = lu.solve(&b).unwrap();

        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solve_inplace() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 2.0);
        m.set(0, 1, 0.0);
        m.set(1, 0, 0.0);
        m.set(1, 1, 3.0);

        let lu = LUFactorization::new(&m).unwrap();
        let mut b = vec![4.0, 9.0];
        lu.solve_inplace(&mut b).unwrap();

        assert!((b[0] - 2.0).abs() < 1e-10);
        assert!((b[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solve_multi() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 2.0);
        m.set(1, 1, 3.0);

        let lu = LUFactorization::new(&m).unwrap();

        // Two right-hand sides (column-major): [2, 6] and [4, 9]
        let b = vec![2.0, 6.0, 4.0, 9.0];
        let x = lu.solve_multi(&b, 2).unwrap();

        // First solution: [1, 2]
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);
        // Second solution: [2, 3]
        assert!((x[2] - 2.0).abs() < 1e-10);
        assert!((x[3] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solver_trait() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let b = vec![5.0, 11.0];
        let x = m.lu_solve(&b).unwrap();

        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_repeated_solve() {
        // Verify that two different solve() calls on the same factorization
        // both produce correct results, confirming the cached decomposition works.
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let lu = LUFactorization::new(&m).unwrap();

        // First solve: A * [1, 2]^T = [5, 11]^T
        let b1 = vec![5.0, 11.0];
        let x1 = lu.solve(&b1).unwrap();
        assert!((x1[0] - 1.0).abs() < 1e-10);
        assert!((x1[1] - 2.0).abs() < 1e-10);

        // Second solve with different RHS: A * [3, 1]^T = [5, 13]^T
        let b2 = vec![5.0, 13.0];
        let x2 = lu.solve(&b2).unwrap();
        assert!((x2[0] - 3.0).abs() < 1e-10);
        assert!((x2[1] - 1.0).abs() < 1e-10);
    }

    /// Builds the 2x2 identity, shared by the error-path tests below.
    fn identity_2x2() -> DenseMatrix<f64> {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(1, 1, 1.0);
        m
    }

    #[test]
    fn test_lu_rejects_non_square() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        // `.err()` instead of `.unwrap_err()`: LUFactorization does not implement Debug.
        let err = LUFactorization::new(&m).err();
        assert!(matches!(
            err,
            Some(LinalgError::NotSquare { nrows: 2, ncols: 3, .. })
        ));
    }

    #[test]
    fn test_lu_solve_dimension_mismatch() {
        let lu = LUFactorization::new(&identity_2x2()).unwrap();
        let err = lu.solve(&[1.0, 2.0, 3.0]).err();
        assert!(matches!(
            err,
            Some(LinalgError::DimensionMismatch {
                expected: (2, 1),
                actual: (3, 1),
                ..
            })
        ));
    }

    #[test]
    fn test_lu_solve_inplace_mismatch_leaves_input_untouched() {
        let lu = LUFactorization::new(&identity_2x2()).unwrap();
        let mut b = vec![1.0, 2.0, 3.0];
        assert!(lu.solve_inplace(&mut b).is_err());
        assert_eq!(b, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_lu_solve_multi_dimension_mismatch() {
        let lu = LUFactorization::new(&identity_2x2()).unwrap();
        // Two right-hand sides need 4 entries; only 3 are given.
        let err = lu.solve_multi(&[1.0, 2.0, 3.0], 2).err();
        assert!(matches!(
            err,
            Some(LinalgError::DimensionMismatch {
                expected: (2, 2),
                actual: (3, 1),
                ..
            })
        ));
    }
}