differential_equations/linalg/matrix/
linear.rs

1//! Linear solves: A x = b via LU with partial pivoting on a dense copy.
2
3use crate::{
4    error::Error,
5    traits::{Real, State},
6};
7
8use super::base::{Matrix, MatrixStorage};
9
10impl<T: Real> Matrix<T> {
11    /// Solve A x = b, Returns Err if the matrix is singular or dimensions are incompatible
12    pub fn lin_solve<Y>(&self, b: Y) -> Result<Y, Error<T, Y>>
13    where
14        Y: State<T>,
15    {
16        let n = self.n;
17        if b.len() != n {
18            return Err(Error::BadInput {
19                msg: "Incompatible vector length".into(),
20            });
21        }
22
23        // 1) Densify A into a Vec<T> of size n*n (row-major)
24        let mut a = vec![T::zero(); n * n];
25        match &self.storage {
26            MatrixStorage::Identity => {
27                for i in 0..n {
28                    a[i * n + i] = T::one();
29                }
30            }
31            MatrixStorage::Full => {
32                a.copy_from_slice(&self.data[0..n * n]);
33            }
34            MatrixStorage::Banded { ml, mu, .. } => {
35                let rows = *ml + *mu + 1;
36                for j in 0..self.m {
37                    for r in 0..rows {
38                        let k = r as isize - *mu as isize; // i - j
39                        let i_signed = j as isize + k;
40                        if i_signed >= 0 && (i_signed as usize) < self.n {
41                            let i = i_signed as usize;
42                            a[i * self.m + j] += self.data[r * self.m + j];
43                        }
44                    }
45                }
46            }
47        }
48
49        // 2) Copy b into a dense vector x and perform solve
50        let mut x = b;
51
52        // 3) LU factorization with partial pivoting and singularity checking
53        let mut piv: Vec<usize> = (0..n).collect();
54        let eps = T::from_f64(1e-14).unwrap(); // Singularity threshold
55
56        let mut swapper;
57        for k in 0..n {
58            // Find pivot row
59            let mut pivot_row = k;
60            let mut pivot_val = a[k * n + k].abs();
61            for i in (k + 1)..n {
62                let val = a[i * n + k].abs();
63                if val > pivot_val {
64                    pivot_val = val;
65                    pivot_row = i;
66                }
67            }
68
69            // Check for singularity
70            if pivot_val <= eps {
71                // Note the t, y are not known here and should be updated by caller before returning to user
72                return Err(Error::LinearAlgebra {
73                    msg: "Singular matrix encountered".into(),
74                });
75            }
76
77            if pivot_row != k {
78                // swap rows in A
79                for j in 0..n {
80                    a.swap(k * n + j, pivot_row * n + j);
81                }
82                // swap entries in x
83                swapper = x.get(k);
84                x.set(k, x.get(pivot_row));
85                x.set(pivot_row, swapper);
86                piv.swap(k, pivot_row);
87            }
88
89            // Eliminate below the pivot
90            let akk = a[k * n + k];
91            for i in (k + 1)..n {
92                let factor = a[i * n + k] / akk;
93                a[i * n + k] = factor; // store L(i,k)
94                for j in (k + 1)..n {
95                    a[i * n + j] = a[i * n + j] - factor * a[k * n + j];
96                }
97            }
98        }
99
100        // Forward solve Ly = Pb (x currently holds permuted b)
101        for i in 0..n {
102            let mut sum = x.get(i);
103            for k in 0..i {
104                sum -= a[i * n + k] * x.get(k);
105            }
106            x.set(i, sum); // since L has ones on diagonal
107        }
108
109        // Backward solve Ux = y
110        for i in (0..n).rev() {
111            let mut sum = x.get(i);
112            for k in (i + 1)..n {
113                sum -= a[i * n + k] * x.get(k);
114            }
115            x.set(i, sum / a[i * n + i]);
116        }
117
118        // Build output State from x
119        let mut out = Y::zeros();
120        for i in 0..n {
121            out.set(i, x.get(i));
122        }
123        Ok(out)
124    }
125
126    /// In-place solve: overwrites `b` with `x`.
127    pub fn lin_solve_mut(&self, b: &mut [T]) {
128        let n = self.n;
129        assert_eq!(
130            b.len(),
131            n,
132            "dimension mismatch in solve: A is {}x{}, b has length {}",
133            n,
134            n,
135            b.len()
136        );
137
138        // Densify A into row-major Vec<T>
139        let mut a = vec![T::zero(); n * n];
140        match &self.storage {
141            MatrixStorage::Identity => {
142                for i in 0..n {
143                    a[i * n + i] = T::one();
144                }
145            }
146            MatrixStorage::Full => {
147                a.copy_from_slice(&self.data[0..n * n]);
148            }
149            MatrixStorage::Banded { ml, mu, .. } => {
150                let rows = *ml + *mu + 1;
151                for j in 0..self.m {
152                    for r in 0..rows {
153                        let k = r as isize - *mu as isize;
154                        let i_signed = j as isize + k;
155                        if i_signed >= 0 && (i_signed as usize) < self.n {
156                            let i = i_signed as usize;
157                            a[i * self.m + j] += self.data[r * self.m + j];
158                        }
159                    }
160                }
161            }
162        }
163
164        // LU with partial pivoting, applying permutations to b
165        for k in 0..n {
166            // pivot
167            let mut pivot_row = k;
168            let mut pivot_val = a[k * n + k].abs();
169            for i in (k + 1)..n {
170                let val = a[i * n + k].abs();
171                if val > pivot_val {
172                    pivot_val = val;
173                    pivot_row = i;
174                }
175            }
176            if pivot_val == T::zero() {
177                panic!("singular matrix in solve");
178            }
179            if pivot_row != k {
180                for j in 0..n {
181                    a.swap(k * n + j, pivot_row * n + j);
182                }
183                b.swap(k, pivot_row);
184            }
185            // Eliminate below the pivot
186            let akk = a[k * n + k];
187            for i in (k + 1)..n {
188                let factor = a[i * n + k] / akk;
189                a[i * n + k] = factor;
190                for j in (k + 1)..n {
191                    a[i * n + j] = a[i * n + j] - factor * a[k * n + j];
192                }
193            }
194        }
195
196        // Forward solve Ly = Pb (b is permuted)
197        for i in 0..n {
198            let mut sum = b[i];
199            for k in 0..i {
200                sum -= a[i * n + k] * b[k];
201            }
202            b[i] = sum;
203        }
204        // Backward solve Ux = y
205        for i in (0..n).rev() {
206            let mut sum = b[i];
207            for k in (i + 1)..n {
208                sum -= a[i * n + k] * b[k];
209            }
210            b[i] = sum / a[i * n + i];
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    use crate::linalg::matrix::Matrix;
    use nalgebra::Vector2;

    /// Builds a dense 2x2 matrix from its rows.
    fn full_2x2(rows: [[f64; 2]; 2]) -> Matrix<f64> {
        let mut a: Matrix<f64> = Matrix::full(2, 2);
        for (i, row) in rows.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                a[(i, j)] = value;
            }
        }
        a
    }

    #[test]
    fn solve_full_2x2() {
        // [[3,2],[1,4]] x = [5,6] => x = [(20-12)/10, (18-5)/10] = [0.8, 1.3]
        let a = full_2x2([[3.0, 2.0], [1.0, 4.0]]);
        let b = Vector2::new(5.0, 6.0);
        let x = a.lin_solve(b).unwrap();
        assert!((x.x - 0.8).abs() < 1e-12);
        assert!((x.y - 1.3).abs() < 1e-12);
    }

    #[test]
    fn solve_requires_row_exchange() {
        // The leading entry is zero, so the solve only succeeds if rows are swapped.
        // [[0,1],[1,0]] x = [2,3] => x = [3, 2]
        let a = full_2x2([[0.0, 1.0], [1.0, 0.0]]);
        let x = a.lin_solve(Vector2::new(2.0, 3.0)).unwrap();
        assert!((x.x - 3.0).abs() < 1e-12);
        assert!((x.y - 2.0).abs() < 1e-12);
    }

    #[test]
    fn solve_reports_singular_matrix() {
        // Second row is twice the first, so elimination yields an exact zero pivot.
        let a = full_2x2([[1.0, 2.0], [2.0, 4.0]]);
        assert!(a.lin_solve(Vector2::new(1.0, 2.0)).is_err());
    }

    #[test]
    fn solve_mut_full_2x2() {
        let a = full_2x2([[3.0, 2.0], [1.0, 4.0]]);
        let mut b = [5.0, 6.0];
        a.lin_solve_mut(&mut b);
        assert!((b[0] - 0.8).abs() < 1e-12);
        assert!((b[1] - 1.3).abs() < 1e-12);
    }

    #[test]
    #[should_panic(expected = "singular matrix in solve")]
    fn solve_mut_panics_on_singular_matrix() {
        let a = full_2x2([[1.0, 2.0], [2.0, 4.0]]);
        let mut b = [1.0, 2.0];
        a.lin_solve_mut(&mut b);
    }
}