differential_equations/linalg/matrix/
linear.rs1use crate::traits::{Real, State};
4
5use super::base::{Matrix, MatrixStorage};
6
7impl<T: Real> Matrix<T> {
8 pub fn lin_solve<V>(&self, b: V) -> V
10 where
11 V: State<T>,
12 {
13 let n = self.n();
14 assert_eq!(
15 b.len(),
16 n,
17 "dimension mismatch in solve: A is {}x{}, b has length {}",
18 n,
19 n,
20 b.len()
21 );
22
23 let mut a = vec![T::zero(); n * n];
25 match &self.storage {
26 MatrixStorage::Identity => {
27 for i in 0..n {
28 a[i * n + i] = T::one();
29 }
30 }
31 MatrixStorage::Full => {
32 a.copy_from_slice(&self.data[0..n * n]);
33 }
34 MatrixStorage::Banded { ml, mu, .. } => {
35 let rows = *ml + *mu + 1;
36 for j in 0..self.ncols {
37 for r in 0..rows {
38 let k = r as isize - *mu as isize; let i_signed = j as isize + k;
40 if i_signed >= 0 && (i_signed as usize) < self.nrows {
41 let i = i_signed as usize;
42 a[i * self.ncols + j] =
43 a[i * self.ncols + j] + self.data[r * self.ncols + j];
44 }
45 }
46 }
47 }
48 }
49
50 let mut x = vec![T::zero(); n];
52 for i in 0..n {
53 x[i] = b.get(i);
54 }
55
56 let mut piv: Vec<usize> = (0..n).collect();
58 for k in 0..n {
59 let mut pivot_row = k;
61 let mut pivot_val = a[k * n + k].abs();
62 for i in (k + 1)..n {
63 let val = a[i * n + k].abs();
64 if val > pivot_val {
65 pivot_val = val;
66 pivot_row = i;
67 }
68 }
69 if pivot_val == T::zero() {
70 panic!("singular matrix in solve");
71 }
72 if pivot_row != k {
73 for j in 0..n {
75 a.swap(k * n + j, pivot_row * n + j);
76 }
77 x.swap(k, pivot_row);
79 piv.swap(k, pivot_row);
80 }
81 let akk = a[k * n + k];
83 for i in (k + 1)..n {
84 let factor = a[i * n + k] / akk;
85 a[i * n + k] = factor; for j in (k + 1)..n {
87 a[i * n + j] = a[i * n + j] - factor * a[k * n + j];
88 }
89 }
90 }
91
92 for i in 0..n {
94 let mut sum = x[i];
95 for k in 0..i {
96 sum = sum - a[i * n + k] * x[k];
97 }
98 x[i] = sum; }
100
101 for i in (0..n).rev() {
103 let mut sum = x[i];
104 for k in (i + 1)..n {
105 sum = sum - a[i * n + k] * x[k];
106 }
107 x[i] = sum / a[i * n + i];
108 }
109
110 let mut out = V::zeros();
112 for i in 0..n {
113 out.set(i, x[i]);
114 }
115 out
116 }
117
118 pub fn lin_solve_mut(&self, b: &mut [T]) {
120 let n = self.n();
121 assert_eq!(
122 b.len(),
123 n,
124 "dimension mismatch in solve: A is {}x{}, b has length {}",
125 n,
126 n,
127 b.len()
128 );
129
130 let mut a = vec![T::zero(); n * n];
132 match &self.storage {
133 MatrixStorage::Identity => {
134 for i in 0..n {
135 a[i * n + i] = T::one();
136 }
137 }
138 MatrixStorage::Full => {
139 a.copy_from_slice(&self.data[0..n * n]);
140 }
141 MatrixStorage::Banded { ml, mu, .. } => {
142 let rows = *ml + *mu + 1;
143 for j in 0..self.ncols {
144 for r in 0..rows {
145 let k = r as isize - *mu as isize;
146 let i_signed = j as isize + k;
147 if i_signed >= 0 && (i_signed as usize) < self.nrows {
148 let i = i_signed as usize;
149 a[i * self.ncols + j] =
150 a[i * self.ncols + j] + self.data[r * self.ncols + j];
151 }
152 }
153 }
154 }
155 }
156
157 for k in 0..n {
159 let mut pivot_row = k;
161 let mut pivot_val = a[k * n + k].abs();
162 for i in (k + 1)..n {
163 let val = a[i * n + k].abs();
164 if val > pivot_val {
165 pivot_val = val;
166 pivot_row = i;
167 }
168 }
169 if pivot_val == T::zero() {
170 panic!("singular matrix in solve");
171 }
172 if pivot_row != k {
173 for j in 0..n {
174 a.swap(k * n + j, pivot_row * n + j);
175 }
176 b.swap(k, pivot_row);
177 }
178 let akk = a[k * n + k];
180 for i in (k + 1)..n {
181 let factor = a[i * n + k] / akk;
182 a[i * n + k] = factor;
183 for j in (k + 1)..n {
184 a[i * n + j] = a[i * n + j] - factor * a[k * n + j];
185 }
186 }
187 }
188
189 for i in 0..n {
191 let mut sum = b[i];
192 for k in 0..i {
193 sum = sum - a[i * n + k] * b[k];
194 }
195 b[i] = sum;
196 }
197 for i in (0..n).rev() {
199 let mut sum = b[i];
200 for k in (i + 1)..n {
201 sum = sum - a[i * n + k] * b[k];
202 }
203 b[i] = sum / a[i * n + i];
204 }
205 }
206}
207
#[cfg(test)]
mod tests {
    use crate::linalg::matrix::Matrix;
    use nalgebra::Vector2;

    const TOL: f64 = 1e-12;

    #[test]
    fn solve_full_2x2() {
        let a: Matrix<f64> = Matrix::full(2, vec![3.0, 2.0, 1.0, 4.0]);
        let b = Vector2::new(5.0, 6.0);
        let x = a.lin_solve(b);
        assert!((x.x - 0.8).abs() < TOL);
        assert!((x.y - 1.3).abs() < TOL);
    }

    /// A zero on the leading diagonal can only be solved if partial pivoting swaps rows.
    #[test]
    fn solve_requires_row_swap() {
        let a: Matrix<f64> = Matrix::full(2, vec![0.0, 1.0, 1.0, 0.0]);
        let x = a.lin_solve(Vector2::new(2.0, 3.0));
        assert!((x.x - 3.0).abs() < TOL);
        assert!((x.y - 2.0).abs() < TOL);
    }

    /// Exercises the in-place API on a 3x3 system whose elimination swaps rows.
    #[test]
    fn solve_mut_full_3x3() {
        let a: Matrix<f64> = Matrix::full(
            3,
            vec![2.0, 1.0, 1.0, 4.0, -6.0, 0.0, -2.0, 7.0, 2.0],
        );
        let mut b = [5.0, -2.0, 9.0];
        a.lin_solve_mut(&mut b);
        assert!((b[0] - 1.0).abs() < TOL);
        assert!((b[1] - 1.0).abs() < TOL);
        assert!((b[2] - 2.0).abs() < TOL);
    }

    #[test]
    #[should_panic(expected = "singular matrix in solve")]
    fn solve_singular_panics() {
        let a: Matrix<f64> = Matrix::full(2, vec![1.0, 2.0, 2.0, 4.0]);
        let mut b = [1.0, 1.0];
        a.lin_solve_mut(&mut b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch in solve")]
    fn solve_dimension_mismatch_panics() {
        let a: Matrix<f64> = Matrix::full(2, vec![1.0, 0.0, 0.0, 1.0]);
        let mut b = [1.0, 2.0, 3.0];
        a.lin_solve_mut(&mut b);
    }
}
223}