differential_equations/linalg/matrix/
linear.rs

1//! Linear solves: A x = b via LU with partial pivoting on a dense copy.
2
3use crate::traits::{Real, State};
4
5use super::base::{Matrix, MatrixStorage};
6
7impl<T: Real> Matrix<T> {
8    /// Solve A x = b (returns same `State<T>` type as `b`). Densifies `A` as needed.
9    pub fn lin_solve<V>(&self, b: V) -> V
10    where
11        V: State<T>,
12    {
13        let n = self.n();
14        assert_eq!(
15            b.len(),
16            n,
17            "dimension mismatch in solve: A is {}x{}, b has length {}",
18            n,
19            n,
20            b.len()
21        );
22
23        // 1) Densify A into a Vec<T> of size n*n (row-major)
24        let mut a = vec![T::zero(); n * n];
25        match &self.storage {
26            MatrixStorage::Identity => {
27                for i in 0..n {
28                    a[i * n + i] = T::one();
29                }
30            }
31            MatrixStorage::Full => {
32                a.copy_from_slice(&self.data[0..n * n]);
33            }
34            MatrixStorage::Banded { ml, mu, .. } => {
35                let rows = *ml + *mu + 1;
36                for j in 0..self.ncols {
37                    for r in 0..rows {
38                        let k = r as isize - *mu as isize; // i - j
39                        let i_signed = j as isize + k;
40                        if i_signed >= 0 && (i_signed as usize) < self.nrows {
41                            let i = i_signed as usize;
42                            a[i * self.ncols + j] =
43                                a[i * self.ncols + j] + self.data[r * self.ncols + j];
44                        }
45                    }
46                }
47            }
48        }
49
50        // 2) Copy b into a dense vector x and perform in-place solve
51        let mut x = vec![T::zero(); n];
52        for i in 0..n {
53            x[i] = b.get(i);
54        }
55
56        // 3) LU factorization with partial pivoting (Doolittle-style)
57        let mut piv: Vec<usize> = (0..n).collect();
58        for k in 0..n {
59            // Find pivot row
60            let mut pivot_row = k;
61            let mut pivot_val = a[k * n + k].abs();
62            for i in (k + 1)..n {
63                let val = a[i * n + k].abs();
64                if val > pivot_val {
65                    pivot_val = val;
66                    pivot_row = i;
67                }
68            }
69            if pivot_val == T::zero() {
70                panic!("singular matrix in solve");
71            }
72            if pivot_row != k {
73                // swap rows in A
74                for j in 0..n {
75                    a.swap(k * n + j, pivot_row * n + j);
76                }
77                // swap entries in x (we'll apply permutation to RHS)
78                x.swap(k, pivot_row);
79                piv.swap(k, pivot_row);
80            }
81            // Eliminate below the pivot
82            let akk = a[k * n + k];
83            for i in (k + 1)..n {
84                let factor = a[i * n + k] / akk;
85                a[i * n + k] = factor; // store L(i,k)
86                for j in (k + 1)..n {
87                    a[i * n + j] = a[i * n + j] - factor * a[k * n + j];
88                }
89            }
90        }
91
92        // Forward solve Ly = Pb (x currently holds permuted b)
93        for i in 0..n {
94            let mut sum = x[i];
95            for k in 0..i {
96                sum = sum - a[i * n + k] * x[k];
97            }
98            x[i] = sum; // since L has ones on diagonal
99        }
100
101        // Backward solve Ux = y
102        for i in (0..n).rev() {
103            let mut sum = x[i];
104            for k in (i + 1)..n {
105                sum = sum - a[i * n + k] * x[k];
106            }
107            x[i] = sum / a[i * n + i];
108        }
109
110        // 6) Build output State from x
111        let mut out = V::zeros();
112        for i in 0..n {
113            out.set(i, x[i]);
114        }
115        out
116    }
117
118    /// In-place solve: overwrites `b` with `x`.
119    pub fn lin_solve_mut(&self, b: &mut [T]) {
120        let n = self.n();
121        assert_eq!(
122            b.len(),
123            n,
124            "dimension mismatch in solve: A is {}x{}, b has length {}",
125            n,
126            n,
127            b.len()
128        );
129
130        // Densify A into row-major Vec<T>
131        let mut a = vec![T::zero(); n * n];
132        match &self.storage {
133            MatrixStorage::Identity => {
134                for i in 0..n {
135                    a[i * n + i] = T::one();
136                }
137            }
138            MatrixStorage::Full => {
139                a.copy_from_slice(&self.data[0..n * n]);
140            }
141            MatrixStorage::Banded { ml, mu, .. } => {
142                let rows = *ml + *mu + 1;
143                for j in 0..self.ncols {
144                    for r in 0..rows {
145                        let k = r as isize - *mu as isize;
146                        let i_signed = j as isize + k;
147                        if i_signed >= 0 && (i_signed as usize) < self.nrows {
148                            let i = i_signed as usize;
149                            a[i * self.ncols + j] =
150                                a[i * self.ncols + j] + self.data[r * self.ncols + j];
151                        }
152                    }
153                }
154            }
155        }
156
157        // LU with partial pivoting, applying permutations to b
158        for k in 0..n {
159            // pivot
160            let mut pivot_row = k;
161            let mut pivot_val = a[k * n + k].abs();
162            for i in (k + 1)..n {
163                let val = a[i * n + k].abs();
164                if val > pivot_val {
165                    pivot_val = val;
166                    pivot_row = i;
167                }
168            }
169            if pivot_val == T::zero() {
170                panic!("singular matrix in solve");
171            }
172            if pivot_row != k {
173                for j in 0..n {
174                    a.swap(k * n + j, pivot_row * n + j);
175                }
176                b.swap(k, pivot_row);
177            }
178            // Eliminate below the pivot
179            let akk = a[k * n + k];
180            for i in (k + 1)..n {
181                let factor = a[i * n + k] / akk;
182                a[i * n + k] = factor;
183                for j in (k + 1)..n {
184                    a[i * n + j] = a[i * n + j] - factor * a[k * n + j];
185                }
186            }
187        }
188
189        // Forward solve Ly = Pb (b is permuted)
190        for i in 0..n {
191            let mut sum = b[i];
192            for k in 0..i {
193                sum = sum - a[i * n + k] * b[k];
194            }
195            b[i] = sum;
196        }
197        // Backward solve Ux = y
198        for i in (0..n).rev() {
199            let mut sum = b[i];
200            for k in (i + 1)..n {
201                sum = sum - a[i * n + k] * b[k];
202            }
203            b[i] = sum / a[i * n + i];
204        }
205    }
206}
207
#[cfg(test)]
mod tests {
    use crate::linalg::matrix::Matrix;
    use nalgebra::Vector2;

    /// By-value solve of a well-conditioned full 2x2 system.
    #[test]
    fn solve_full_2x2() {
        // A = [[3, 2], [1, 4]], b = [5, 6]. By Cramer's rule with det(A) = 10:
        //   x0 = (5*4 - 2*6) / 10 = (20 - 12) / 10 = 0.8
        //   x1 = (3*6 - 1*5) / 10 = (18 - 5) / 10 = 1.3
        let a: Matrix<f64> = Matrix::full(2, vec![3.0, 2.0, 1.0, 4.0]);
        let b = Vector2::new(5.0, 6.0);
        let x = a.lin_solve(b);
        assert!((x.x - 0.8).abs() < 1e-12);
        assert!((x.y - 1.3).abs() < 1e-12);
    }

    /// In-place solve where the leading entry is zero, so the factorisation
    /// has to swap rows before it can eliminate.
    #[test]
    fn solve_mut_full_2x2_with_row_swap() {
        // A = [[0, 2], [3, 1]], b = [4, 5]:
        //   2*x1 = 4        => x1 = 2
        //   3*x0 + x1 = 5   => x0 = 1
        let a: Matrix<f64> = Matrix::full(2, vec![0.0, 2.0, 3.0, 1.0]);
        let mut b: [f64; 2] = [4.0, 5.0];
        a.lin_solve_mut(&mut b);
        assert!((b[0] - 1.0).abs() < 1e-12);
        assert!((b[1] - 2.0).abs() < 1e-12);
    }
}
223}