1use crate::matrix::{DenseMatrix, Matrix};
8use crate::Scalar;
9use faer::linalg::solvers::PartialPivLu;
10use faer::prelude::*;
11use faer::{ComplexField, Conjugate, Entity, Mat, SimpleEntity};
12use numra_core::LinalgError;
13
14pub struct LUFactorization<S: Scalar + Entity> {
19 lu: PartialPivLu<S>,
21 n: usize,
22}
23
24impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> LUFactorization<S> {
25 pub fn new(matrix: &DenseMatrix<S>) -> Result<Self, LinalgError> {
29 if !matrix.is_square() {
30 return Err(LinalgError::NotSquare {
31 nrows: matrix.rows(),
32 ncols: matrix.cols(),
33 });
34 }
35
36 let n = matrix.rows();
37 let lu = matrix.as_faer().partial_piv_lu();
39
40 Ok(Self { lu, n })
41 }
42
43 pub fn dim(&self) -> usize {
45 self.n
46 }
47
48 pub fn solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
50 if b.len() != self.n {
51 return Err(LinalgError::DimensionMismatch {
52 expected: (self.n, 1),
53 actual: (b.len(), 1),
54 });
55 }
56
57 let mut b_mat = Mat::zeros(self.n, 1);
59 for (i, &val) in b.iter().enumerate() {
60 b_mat.write(i, 0, val);
61 }
62
63 let x_mat = self.lu.solve(&b_mat);
65
66 let mut x = Vec::with_capacity(self.n);
68 for i in 0..self.n {
69 x.push(x_mat.read(i, 0));
70 }
71
72 Ok(x)
73 }
74
75 pub fn solve_inplace(&self, b: &mut [S]) -> Result<(), LinalgError> {
77 let x = self.solve(b)?;
78 b.copy_from_slice(&x);
79 Ok(())
80 }
81
82 pub fn solve_multi(&self, b: &[S], nrhs: usize) -> Result<Vec<S>, LinalgError> {
84 if b.len() != self.n * nrhs {
85 return Err(LinalgError::DimensionMismatch {
86 expected: (self.n, nrhs),
87 actual: (b.len(), 1),
88 });
89 }
90
91 let mut b_mat = Mat::zeros(self.n, nrhs);
93 for j in 0..nrhs {
94 for i in 0..self.n {
95 b_mat.write(i, j, b[j * self.n + i]);
96 }
97 }
98
99 let x_mat = self.lu.solve(&b_mat);
101
102 let mut x = vec![S::ZERO; self.n * nrhs];
104 for j in 0..nrhs {
105 for i in 0..self.n {
106 x[j * self.n + i] = x_mat.read(i, j);
107 }
108 }
109
110 Ok(x)
111 }
112}
113
114pub trait LUSolver<S: Scalar> {
116 fn lu_solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError>;
118
119 fn lu_factor(&self) -> Result<LUFactorization<S>, LinalgError>
121 where
122 S: Entity + SimpleEntity + Conjugate<Canonical = S> + ComplexField;
123}
124
125impl<S: Scalar + SimpleEntity + Conjugate<Canonical = S> + ComplexField> LUSolver<S>
126 for DenseMatrix<S>
127{
128 fn lu_solve(&self, b: &[S]) -> Result<Vec<S>, LinalgError> {
129 self.solve(b)
130 }
131
132 fn lu_factor(&self) -> Result<LUFactorization<S>, LinalgError> {
133 LUFactorization::new(self)
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Matrix;

    #[test]
    fn test_lu_factorization() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(3, 3);
        m.set(0, 0, 2.0);
        m.set(1, 1, 3.0);
        m.set(2, 2, 4.0);

        let lu = LUFactorization::new(&m).unwrap();
        assert_eq!(lu.dim(), 3);
    }

    #[test]
    fn test_lu_rejects_non_square() {
        let m: DenseMatrix<f64> = DenseMatrix::zeros(2, 3);
        assert!(matches!(
            LUFactorization::new(&m),
            Err(LinalgError::NotSquare { .. })
        ));
        assert!(matches!(
            LUSolver::lu_factor(&m),
            Err(LinalgError::NotSquare { .. })
        ));
    }

    #[test]
    fn test_lu_solve() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let lu = LUFactorization::new(&m).unwrap();
        let b = vec![5.0, 11.0];
        let x = lu.solve(&b).unwrap();

        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solve_dimension_mismatch() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(1, 1, 1.0);
        let lu = LUFactorization::new(&m).unwrap();

        assert!(matches!(
            lu.solve(&[1.0, 2.0, 3.0]),
            Err(LinalgError::DimensionMismatch {
                expected: (2, 1),
                actual: (3, 1)
            })
        ));

        // On error the in-place buffer must be left untouched.
        let mut short = vec![7.0];
        assert!(lu.solve_inplace(&mut short).is_err());
        assert_eq!(short, vec![7.0]);

        assert!(matches!(
            lu.solve_multi(&[1.0, 2.0, 3.0], 2),
            Err(LinalgError::DimensionMismatch {
                expected: (2, 2),
                actual: (3, 1)
            })
        ));
    }

    #[test]
    fn test_lu_solve_inplace() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 2.0);
        m.set(0, 1, 0.0);
        m.set(1, 0, 0.0);
        m.set(1, 1, 3.0);

        let lu = LUFactorization::new(&m).unwrap();
        let mut b = vec![4.0, 9.0];
        lu.solve_inplace(&mut b).unwrap();

        assert!((b[0] - 2.0).abs() < 1e-10);
        assert!((b[1] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solve_multi() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 2.0);
        m.set(1, 1, 3.0);

        let lu = LUFactorization::new(&m).unwrap();

        // Column-major: first right-hand side is [2, 6], second is [4, 9].
        let b = vec![2.0, 6.0, 4.0, 9.0];
        let x = lu.solve_multi(&b, 2).unwrap();

        // First solution column: [1, 2].
        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);
        // Second solution column: [2, 3].
        assert!((x[2] - 2.0).abs() < 1e-10);
        assert!((x[3] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_solver_trait() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let b = vec![5.0, 11.0];
        let x = m.lu_solve(&b).unwrap();

        assert!((x[0] - 1.0).abs() < 1e-10);
        assert!((x[1] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_lu_repeated_solve() {
        let mut m: DenseMatrix<f64> = DenseMatrix::zeros(2, 2);
        m.set(0, 0, 1.0);
        m.set(0, 1, 2.0);
        m.set(1, 0, 3.0);
        m.set(1, 1, 4.0);

        let lu = LUFactorization::new(&m).unwrap();

        let b1 = vec![5.0, 11.0];
        let x1 = lu.solve(&b1).unwrap();
        assert!((x1[0] - 1.0).abs() < 1e-10);
        assert!((x1[1] - 2.0).abs() < 1e-10);

        // Same factorization, different right-hand side.
        let b2 = vec![5.0, 13.0];
        let x2 = lu.solve(&b2).unwrap();
        assert!((x2[0] - 3.0).abs() < 1e-10);
        assert!((x2[1] - 1.0).abs() < 1e-10);
    }
}