differential_equations/linalg/matrix/
linear.rs1use crate::{
4 error::Error,
5 traits::{Real, State},
6};
7
8use super::base::{Matrix, MatrixStorage};
9
10impl<T: Real> Matrix<T> {
11 pub fn lin_solve<Y>(&self, b: Y) -> Result<Y, Error<T, Y>>
13 where
14 Y: State<T>,
15 {
16 let n = self.n;
17 if b.len() != n {
18 return Err(Error::BadInput {
19 msg: "Incompatible vector length".into(),
20 });
21 }
22
23 let mut a = vec![T::zero(); n * n];
25 match &self.storage {
26 MatrixStorage::Identity => {
27 for i in 0..n {
28 a[i * n + i] = T::one();
29 }
30 }
31 MatrixStorage::Full => {
32 a.copy_from_slice(&self.data[0..n * n]);
33 }
34 MatrixStorage::Banded { ml, mu, .. } => {
35 let rows = *ml + *mu + 1;
36 for j in 0..self.m {
37 for r in 0..rows {
38 let k = r as isize - *mu as isize; let i_signed = j as isize + k;
40 if i_signed >= 0 && (i_signed as usize) < self.n {
41 let i = i_signed as usize;
42 a[i * self.m + j] += self.data[r * self.m + j];
43 }
44 }
45 }
46 }
47 }
48
49 let mut x = b;
51
52 let mut piv: Vec<usize> = (0..n).collect();
54 let eps = T::from_f64(1e-14).unwrap(); let mut swapper;
57 for k in 0..n {
58 let mut pivot_row = k;
60 let mut pivot_val = a[k * n + k].abs();
61 for i in (k + 1)..n {
62 let val = a[i * n + k].abs();
63 if val > pivot_val {
64 pivot_val = val;
65 pivot_row = i;
66 }
67 }
68
69 if pivot_val <= eps {
71 return Err(Error::LinearAlgebra {
73 msg: "Singular matrix encountered".into(),
74 });
75 }
76
77 if pivot_row != k {
78 for j in 0..n {
80 a.swap(k * n + j, pivot_row * n + j);
81 }
82 swapper = x.get(k);
84 x.set(k, x.get(pivot_row));
85 x.set(pivot_row, swapper);
86 piv.swap(k, pivot_row);
87 }
88
89 let akk = a[k * n + k];
91 for i in (k + 1)..n {
92 let factor = a[i * n + k] / akk;
93 a[i * n + k] = factor; for j in (k + 1)..n {
95 a[i * n + j] = a[i * n + j] - factor * a[k * n + j];
96 }
97 }
98 }
99
100 for i in 0..n {
102 let mut sum = x.get(i);
103 for k in 0..i {
104 sum -= a[i * n + k] * x.get(k);
105 }
106 x.set(i, sum); }
108
109 for i in (0..n).rev() {
111 let mut sum = x.get(i);
112 for k in (i + 1)..n {
113 sum -= a[i * n + k] * x.get(k);
114 }
115 x.set(i, sum / a[i * n + i]);
116 }
117
118 let mut out = Y::zeros();
120 for i in 0..n {
121 out.set(i, x.get(i));
122 }
123 Ok(out)
124 }
125
126 pub fn lin_solve_mut(&self, b: &mut [T]) {
128 let n = self.n;
129 assert_eq!(
130 b.len(),
131 n,
132 "dimension mismatch in solve: A is {}x{}, b has length {}",
133 n,
134 n,
135 b.len()
136 );
137
138 let mut a = vec![T::zero(); n * n];
140 match &self.storage {
141 MatrixStorage::Identity => {
142 for i in 0..n {
143 a[i * n + i] = T::one();
144 }
145 }
146 MatrixStorage::Full => {
147 a.copy_from_slice(&self.data[0..n * n]);
148 }
149 MatrixStorage::Banded { ml, mu, .. } => {
150 let rows = *ml + *mu + 1;
151 for j in 0..self.m {
152 for r in 0..rows {
153 let k = r as isize - *mu as isize;
154 let i_signed = j as isize + k;
155 if i_signed >= 0 && (i_signed as usize) < self.n {
156 let i = i_signed as usize;
157 a[i * self.m + j] += self.data[r * self.m + j];
158 }
159 }
160 }
161 }
162 }
163
164 for k in 0..n {
166 let mut pivot_row = k;
168 let mut pivot_val = a[k * n + k].abs();
169 for i in (k + 1)..n {
170 let val = a[i * n + k].abs();
171 if val > pivot_val {
172 pivot_val = val;
173 pivot_row = i;
174 }
175 }
176 if pivot_val == T::zero() {
177 panic!("singular matrix in solve");
178 }
179 if pivot_row != k {
180 for j in 0..n {
181 a.swap(k * n + j, pivot_row * n + j);
182 }
183 b.swap(k, pivot_row);
184 }
185 let akk = a[k * n + k];
187 for i in (k + 1)..n {
188 let factor = a[i * n + k] / akk;
189 a[i * n + k] = factor;
190 for j in (k + 1)..n {
191 a[i * n + j] = a[i * n + j] - factor * a[k * n + j];
192 }
193 }
194 }
195
196 for i in 0..n {
198 let mut sum = b[i];
199 for k in 0..i {
200 sum -= a[i * n + k] * b[k];
201 }
202 b[i] = sum;
203 }
204 for i in (0..n).rev() {
206 let mut sum = b[i];
207 for k in (i + 1)..n {
208 sum -= a[i * n + k] * b[k];
209 }
210 b[i] = sum / a[i * n + i];
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use crate::linalg::matrix::Matrix;
    use nalgebra::Vector2;

    /// Builds the full 2x2 matrix [[3, 2], [1, 4]].
    fn sample_2x2() -> Matrix<f64> {
        let mut a: Matrix<f64> = Matrix::full(2, 2);
        a[(0, 0)] = 3.0;
        a[(0, 1)] = 2.0;
        a[(1, 0)] = 1.0;
        a[(1, 1)] = 4.0;
        a
    }

    /// Builds the singular full 2x2 matrix [[1, 2], [2, 4]] (second row = 2 * first).
    fn singular_2x2() -> Matrix<f64> {
        let mut a: Matrix<f64> = Matrix::full(2, 2);
        a[(0, 0)] = 1.0;
        a[(0, 1)] = 2.0;
        a[(1, 0)] = 2.0;
        a[(1, 1)] = 4.0;
        a
    }

    #[test]
    fn solve_full_2x2() {
        let a = sample_2x2();
        let b = Vector2::new(5.0, 6.0);
        let x = a.lin_solve(b).unwrap();
        // Exact solution: 3*0.8 + 2*1.3 = 5, 1*0.8 + 4*1.3 = 6.
        assert!((x.x - 0.8).abs() < 1e-12);
        assert!((x.y - 1.3).abs() < 1e-12);
    }

    #[test]
    fn solve_mut_full_2x2() {
        let a = sample_2x2();
        let mut b = [5.0, 6.0];
        a.lin_solve_mut(&mut b);
        assert!((b[0] - 0.8).abs() < 1e-12);
        assert!((b[1] - 1.3).abs() < 1e-12);
    }

    #[test]
    fn solve_requires_row_pivoting() {
        // a[0][0] == 0, so elimination only succeeds if rows are swapped.
        let mut a: Matrix<f64> = Matrix::full(2, 2);
        a[(0, 0)] = 0.0;
        a[(0, 1)] = 1.0;
        a[(1, 0)] = 2.0;
        a[(1, 1)] = 3.0;
        // Exact solution x = (1, 2): 0*1 + 1*2 = 2, 2*1 + 3*2 = 8.
        let x = a.lin_solve(Vector2::new(2.0, 8.0)).unwrap();
        assert!((x.x - 1.0).abs() < 1e-12);
        assert!((x.y - 2.0).abs() < 1e-12);
    }

    #[test]
    fn solve_singular_is_error() {
        let a = singular_2x2();
        assert!(a.lin_solve(Vector2::new(1.0, 2.0)).is_err());
    }

    #[test]
    #[should_panic(expected = "singular matrix in solve")]
    fn solve_mut_singular_panics() {
        let a = singular_2x2();
        let mut b = [1.0, 2.0];
        a.lin_solve_mut(&mut b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch in solve")]
    fn solve_mut_rejects_length_mismatch() {
        let a = sample_2x2();
        let mut b = [1.0, 2.0, 3.0];
        a.lin_solve_mut(&mut b);
    }
}