te1d/linalg.rs
//! Solves linear systems and handles matrices.
//! * The utilities in this module are not optimized for speed. For large-scale computations,
//!   use a specialized linear algebra package instead.
//! * By matrices, we mean [`ndarray::Array2`].
4
5use std::fmt;
6
7use ndarray::prelude::*;
8
9use crate::calc::ColVec;
10
11
12/// Error type for linear algebraic computations
13pub struct LinalgError {
14 repr: LinalgErrorKind,
15}
16
17impl fmt::Display for LinalgErrorKind {
18 /// Shows a human-readable description of the [`LinalgErrorKind`].
19 ///
20 /// This is similar to `impl Display for Error`, but doesn't require first converting to Error.
21 ///
22 /// # Examples
23 /// ```
24 /// use te1d::linalg::LinalgErrorKind;
25 /// assert_eq!("singular matrix", LinalgErrorKind::SingularMatrixError.to_string());
26 /// ```
27 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
28 fmt.write_str(self.as_str())
29 }
30}
31
32impl fmt::Debug for LinalgError {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 fmt::Debug::fmt(&self.repr, f)
35 }
36}
37
38impl LinalgError {
39 pub fn new(error_kind: LinalgErrorKind) -> Self {
40 LinalgError { repr: error_kind }
41 }
42
43 #[inline]
44 pub fn kind(&self) -> LinalgErrorKind {
45 self.repr
46 }
47}
48
/// A list specifying general categories of errors in linear algebraic computation.
///
/// Used as the payload of [`LinalgError`]. Being a fieldless, `Copy` enum that is also
/// `Eq + Hash`, it can be compared, matched on, and stored in sets or as map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LinalgErrorKind {
    /// The matrix is singular, that is, not invertible.
    SingularMatrixError,
}

impl LinalgErrorKind {
    /// Returns a short, lowercase, human-readable description of the error kind.
    ///
    /// This is the text written by the `Display` implementation of this type.
    pub(crate) fn as_str(&self) -> &'static str {
        use LinalgErrorKind::*;
        // Keep the arms in strictly alphabetical order by variant name
        // (rustfmt cannot enforce this yet).
        match *self {
            SingularMatrixError => "singular matrix",
        }
    }
}
65
66
67/// Find the solution `x` of the linear system `mat * x = b` using Gaussian elimination.
68///
69/// # Arguments
70/// * `square_mat`: a square matrix.
71///
72/// # Examples
73/// ```
74/// use ndarray::{Array2, array};
75/// use te1d::prelude::*;
76/// use te1d::linalg::{solve_gauss, LinalgErrorKind};
77///
78/// let mat: Array2<f64> = array![[1.0, 0.0, 1.0], [0.0, -3.0, 1.0], [2.0, 1.0, 3.0]];
79/// let b: ColVec = array![6.0, 7.0, 15.0];
80/// let x = solve_gauss(&mat, &b).unwrap();
81/// let x_exact: ColVec = array![2.0, -1.0, 4.0];
82///
83/// assert!(all_close(&x, &x_exact, 1e-05, 1e-08));
84///
85/// // Error for singular matrix
86/// let mat: Array2<f64> = array![[0.0, 0.0], [1.0, 2.0]];
87/// let b: ColVec = array![1.0, 2.0];
88/// let msg = match solve_gauss(&mat, &b) {
89/// Ok(b) => String::from("no error"),
90/// Err(err) => err.kind().to_string(),
91/// };
92/// assert_eq!(msg, LinalgErrorKind::SingularMatrixError.to_string());
93/// ```
94///
95/// # Panics
96/// * when `mat` is not square.
97/// * when the number of rows of `mat` and the length of `b` are not the same.
98pub fn solve_gauss(square_mat: &Array2<f64>, b: &ColVec) -> Result<ColVec, LinalgError> {
99 if !square_mat.is_square() {
100 panic!("`mat` must be a square matrix!");
101 }
102 let n = square_mat.shape()[0];
103 if b.len() != n {
104 panic!("the number of rows of `mat` and the length of `b` must be the same!");
105 }
106
107 let mut mat = square_mat.to_owned();
108 let mut b = b.to_owned();
109 let mut pivot_row: usize;
110 let mut pivot_value: f64;
111 let mut target_row_pivot_value: f64;
112
113 // perform upper triangulation
114 for row in 0..n {
115 // set the pivot row
116 pivot_row = row_of_abs_max(&mat, row, row, n-1);
117 if row != pivot_row {
118 swap_rows(&mut mat, row, pivot_row);
119 b.swap(row, pivot_row);
120 }
121 // normalize the pivot row and `b`
122 pivot_value = mat[[row, row]];
123 if pivot_value.abs() <= f64::EPSILON {
124 return Err(LinalgError::new(LinalgErrorKind::SingularMatrixError));
125 }
126 mat[[row, row]] = 1.0;
127 mat.row_mut(row).slice_mut(s![row+1..]).mapv_inplace(|e| { e/pivot_value });
128 b[row] /= pivot_value;
129 // make upper triangular
130 for target_row in row+1..n {
131 target_row_pivot_value = mat[[target_row, row]];
132 mat[[target_row, row]] = 0.0;
133 for col in row+1..n {
134 mat[[target_row, col]] -= target_row_pivot_value * mat[[row, col]];
135 }
136 b[target_row] -= target_row_pivot_value * b[row];
137 }
138 }
139
140 // perform diagonalization
141 for row in (0..n).rev() {
142 for target_row in 0..row {
143 target_row_pivot_value = mat[[target_row, row]];
144 mat[[target_row, row]] = 0.0;
145 for col in row+1..n {
146 mat[[target_row, col]] -= target_row_pivot_value * mat[[row, col]];
147 }
148 b[target_row] -= target_row_pivot_value * b[row];
149 }
150 }
151
152 Ok(b)
153}
154
155
156/// Return the index of the element having the largest absolute value among a given column.
157fn row_of_abs_max(mat: &Array2<f64>, col: usize, start_row: usize, end_row: usize) -> usize {
158 let mut max_value: f64 = mat[[start_row, col]].abs();
159 let mut result: usize = start_row;
160 for i in start_row+1..end_row+1 {
161 let cur_value = mat[[i, col]].abs();
162 if cur_value > max_value {
163 max_value = cur_value;
164 result = i;
165 }
166 }
167
168 result
169}
170
171
172/// Swap two rows of a matrix.
173///
174/// # Examples
175/// ```
176/// use ndarray::prelude::*;
177/// use te1d::prelude::*;
178/// use te1d::linalg::swap_rows;
179///
180/// let mut mat: Array2<f64> = array![[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]];
181/// let result: Array2<f64> = array![[2.0, 3.0], [1.0, 2.0], [0.0, 1.0]];
182///
183/// swap_rows(&mut mat, 0, 2);
184/// assert_eq!(mat, result);
185/// ```
186#[inline]
187pub fn swap_rows(mat: &mut Array2<f64>, row1: usize, row2: usize) {
188 for i in 0..mat.shape()[1] {
189 mat.swap([row1, i], [row2, i]);
190 }
191}
192
193
194/// Swap two columns of a matrix.
195///
196/// # Examples
197/// ```
198/// use ndarray::prelude::*;
199/// use te1d::prelude::*;
200/// use te1d::linalg::swap_cols;
201///
202/// let mut mat: Array2<f64> = array![[0.0, 1.0, 2.0], [1.0, 2.0, 3.0]];
203/// let result: Array2<f64> = array![[2.0, 1.0, 0.0], [3.0, 2.0, 1.0]];
204///
205/// swap_cols(&mut mat, 0, 2);
206/// assert_eq!(mat, result);
207/// ```
208#[inline]
209pub fn swap_cols(mat: &mut Array2<f64>, col1: usize, col2: usize) {
210 for i in 0..mat.shape()[0] {
211 mat.swap([i, col1], [i, col2]);
212 }
213}
214
215
216/// Compute the outer product of two [`ColVec`]s.
217/// For two column vectors `u` and `v`, the result is `u*v^T`.
218///
219/// # Examples
220/// ```
221/// use ndarray::prelude::*;
222/// use te1d::prelude::*;
223/// use te1d::linalg::outer_product;
224///
225/// let u: ColVec = array![1.0, 2.0, 3.0];
226/// let v: ColVec = array![4.0, 5.0];
227/// let result: Array2<f64> = outer_product(&u, &v);
228/// let sol: Array2<f64> = array![[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]];
229///
230/// assert_eq!(&result, &sol);
231/// ```
232pub fn outer_product(u: &ColVec, v: &ColVec) -> Array2<f64> {
233 let m: usize = u.len();
234 let n: usize = v.len();
235 let mut result: Array2<f64> = Array::default((m, n));
236 for i in 0..m {
237 for j in 0..n {
238 result[[i, j]] = u[i] * v[j];
239 }
240 }
241
242 result
243}
244
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_row_of_abs_max() {
        let mat: Array2<f64> = array![[1.0, 2.0], [-5.0, -2.0], [2.0, -7.0], [3.0, 9.0]];
        // Only rows 1..=2 of column 1 are searched, so the 9.0 in row 3 is ignored.
        assert_eq!(row_of_abs_max(&mat, 1, 1, 2), 2);
    }

    #[test]
    fn test_row_of_abs_max_prefers_first_on_tie() {
        let mat: Array2<f64> = array![[3.0], [-3.0], [1.0]];
        assert_eq!(row_of_abs_max(&mat, 0, 0, 2), 0);
    }

    #[test]
    fn test_solve_gauss_general() {
        let mat: Array2<f64> = array![[2.0, 1.0], [1.0, 3.0]];
        let b: ColVec = array![3.0, 5.0];
        let x = solve_gauss(&mat, &b).unwrap();
        assert!((x[0] - 0.8).abs() < 1e-12);
        assert!((x[1] - 1.4).abs() < 1e-12);
    }

    #[test]
    fn test_solve_gauss_with_pivoting() {
        // The zero in the top-left corner forces a row swap.
        let mat: Array2<f64> = array![[0.0, 1.0], [1.0, 0.0]];
        let b: ColVec = array![2.0, 3.0];
        let x = solve_gauss(&mat, &b).unwrap();
        assert_eq!(x, array![3.0, 2.0]);
    }

    #[test]
    fn test_solve_gauss_singular() {
        let mat: Array2<f64> = array![[1.0, 2.0], [2.0, 4.0]];
        let b: ColVec = array![1.0, 2.0];
        let err = solve_gauss(&mat, &b).unwrap_err();
        assert_eq!(err.kind(), LinalgErrorKind::SingularMatrixError);
    }

    #[test]
    fn test_solve_gauss_empty_system() {
        let mat: Array2<f64> = Array2::zeros((0, 0));
        let b: ColVec = Array1::zeros(0);
        assert_eq!(solve_gauss(&mat, &b).unwrap().len(), 0);
    }

    #[test]
    #[should_panic(expected = "must be a square matrix")]
    fn test_panic_solve_gauss_when_not_square() {
        let mat: Array2<f64> = array![[1.0, 2.0], [-5.0, -2.0], [2.0, -7.0], [3.0, 9.0]];
        let b = array![0.0, 1.0, 2.0, 3.0];
        let _ = solve_gauss(&mat, &b); // must panic: not a square matrix
    }

    #[test]
    #[should_panic(expected = "must be the same")]
    fn test_panic_solve_gauss_when_size_mismatch() {
        let mat: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![0.0, 1.0, 2.0, 3.0];
        let _ = solve_gauss(&mat, &b); // must panic: size does not match
    }
}