te1d/
linalg.rs

//! Solves linear systems and handles matrices.
//! * The utilities in this module are not at all fast. For massive computations, use other specialized linear algebra packages.
//! * By matrices, we mean [`ndarray::Array2`].
4
5use std::fmt;
6
7use ndarray::prelude::*;
8
9use crate::calc::ColVec;
10
11
/// Error type for linear algebraic computations.
///
/// Wraps a [`LinalgErrorKind`] describing what went wrong. It implements
/// [`std::error::Error`], so it can be propagated with `?` into a
/// `Box<dyn std::error::Error>`.
///
/// # Examples
/// ```
/// use te1d::linalg::{LinalgError, LinalgErrorKind};
///
/// let err = LinalgError::new(LinalgErrorKind::SingularMatrixError);
/// assert_eq!(err.kind(), LinalgErrorKind::SingularMatrixError);
/// assert_eq!(err.to_string(), "singular matrix");
/// ```
pub struct LinalgError {
    repr: LinalgErrorKind,
}

impl fmt::Display for LinalgErrorKind {
    /// Shows a human-readable description of the [`LinalgErrorKind`].
    ///
    /// This is the same text as the `Display` output of a [`LinalgError`] of this kind,
    /// but doesn't require constructing a [`LinalgError`] first.
    ///
    /// # Examples
    /// ```
    /// use te1d::linalg::LinalgErrorKind;
    /// assert_eq!("singular matrix", LinalgErrorKind::SingularMatrixError.to_string());
    /// ```
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}

impl fmt::Debug for LinalgError {
    /// Formats the error exactly like the `Debug` output of its [`LinalgErrorKind`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.repr, f)
    }
}

impl fmt::Display for LinalgError {
    /// Shows the human-readable description of the underlying [`LinalgErrorKind`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.repr, f)
    }
}

// Lets callers use `LinalgError` with `?`, `Box<dyn Error>` and error-handling crates.
impl std::error::Error for LinalgError {}

impl LinalgError {
    /// Creates a new error of the given kind.
    pub fn new(error_kind: LinalgErrorKind) -> Self {
        LinalgError { repr: error_kind }
    }

    /// Returns the category of this error.
    #[inline]
    pub fn kind(&self) -> LinalgErrorKind {
        self.repr
    }
}

/// A list specifying general categories of errors in linear algebraic computation
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LinalgErrorKind {
    /// the matrix is singular, that is, not invertible
    SingularMatrixError,
}

impl LinalgErrorKind {
    /// Returns the static, human-readable description used by the `Display` impl.
    pub(crate) fn as_str(&self) -> &'static str {
        // Keep the arms strictly alphabetical when adding variants.
        match self {
            LinalgErrorKind::SingularMatrixError => "singular matrix",
        }
    }
}
65
66
67/// Find the solution `x` of the linear system `mat * x = b` using Gaussian elimination.
68/// 
69/// # Arguments
70/// * `square_mat`: a square matrix.
71/// 
72/// # Examples
73/// ```
74/// use ndarray::{Array2, array};
75/// use te1d::prelude::*;
76/// use te1d::linalg::{solve_gauss, LinalgErrorKind};
77/// 
78/// let mat: Array2<f64> = array![[1.0, 0.0, 1.0], [0.0, -3.0, 1.0], [2.0, 1.0, 3.0]];
79/// let b: ColVec = array![6.0, 7.0, 15.0];
80/// let x = solve_gauss(&mat, &b).unwrap();
81/// let x_exact: ColVec = array![2.0, -1.0, 4.0];
82/// 
83/// assert!(all_close(&x, &x_exact, 1e-05, 1e-08));
84/// 
85/// // Error for singular matrix
86/// let mat: Array2<f64> = array![[0.0, 0.0], [1.0, 2.0]];
87/// let b: ColVec = array![1.0, 2.0];
88/// let msg = match solve_gauss(&mat, &b) {
89///     Ok(b) => String::from("no error"),
90///     Err(err) => err.kind().to_string(),
91/// };
92/// assert_eq!(msg, LinalgErrorKind::SingularMatrixError.to_string());
93/// ```
94/// 
95/// # Panics
96/// * when `mat` is not square.
97/// * when the number of rows of `mat` and the length of `b` are not the same.
98pub fn solve_gauss(square_mat: &Array2<f64>, b: &ColVec) -> Result<ColVec, LinalgError> {
99    if !square_mat.is_square() {
100        panic!("`mat` must be a square matrix!");
101    }
102    let n = square_mat.shape()[0];
103    if b.len() != n {
104        panic!("the number of rows of `mat` and the length of `b` must be the same!");
105    }
106
107    let mut mat = square_mat.to_owned();
108    let mut b = b.to_owned();
109    let mut pivot_row: usize;
110    let mut pivot_value: f64;
111    let mut target_row_pivot_value: f64;
112
113    // perform upper triangulation
114    for row in 0..n {
115        // set the pivot row
116        pivot_row = row_of_abs_max(&mat, row, row, n-1);
117        if row != pivot_row {
118            swap_rows(&mut mat, row, pivot_row);
119            b.swap(row, pivot_row);
120        }
121        // normalize the pivot row and `b`
122        pivot_value = mat[[row, row]];
123        if pivot_value.abs() <= f64::EPSILON {
124            return Err(LinalgError::new(LinalgErrorKind::SingularMatrixError));
125        }
126        mat[[row, row]] = 1.0;
127        mat.row_mut(row).slice_mut(s![row+1..]).mapv_inplace(|e| { e/pivot_value });
128        b[row] /= pivot_value;
129        // make upper triangular
130        for target_row in row+1..n {
131            target_row_pivot_value = mat[[target_row, row]];
132            mat[[target_row, row]] = 0.0;
133            for col in row+1..n {
134                mat[[target_row, col]] -= target_row_pivot_value * mat[[row, col]];
135            }
136            b[target_row] -= target_row_pivot_value * b[row];
137        }
138    }
139
140    // perform diagonalization
141    for row in (0..n).rev() {
142        for target_row in 0..row {
143            target_row_pivot_value = mat[[target_row, row]];
144            mat[[target_row, row]] = 0.0;
145            for col in row+1..n {
146                mat[[target_row, col]] -= target_row_pivot_value * mat[[row, col]];
147            }
148            b[target_row] -= target_row_pivot_value * b[row];
149        }
150    }
151
152    Ok(b)
153}
154
155
156/// Return the index of the element having the largest absolute value among a given column.
157fn row_of_abs_max(mat: &Array2<f64>, col: usize, start_row: usize, end_row: usize) -> usize {
158    let mut max_value: f64 = mat[[start_row, col]].abs();
159    let mut result: usize = start_row;
160    for i in start_row+1..end_row+1 {
161        let cur_value = mat[[i, col]].abs();
162        if cur_value > max_value {
163            max_value = cur_value;
164            result = i;
165        }
166    }
167
168    result
169}
170
171
172/// Swap two rows of a matrix.
173/// 
174/// # Examples
175/// ```
176/// use ndarray::prelude::*;
177/// use te1d::prelude::*;
178/// use te1d::linalg::swap_rows;
179/// 
180/// let mut mat: Array2<f64> = array![[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]];
181/// let result: Array2<f64> = array![[2.0, 3.0], [1.0, 2.0], [0.0, 1.0]];
182/// 
183/// swap_rows(&mut mat, 0, 2);
184/// assert_eq!(mat, result);
185/// ```
186#[inline]
187pub fn swap_rows(mat: &mut Array2<f64>, row1: usize, row2: usize) {
188    for i in 0..mat.shape()[1] {
189        mat.swap([row1, i], [row2, i]);
190    }
191}
192
193
194/// Swap two columns of a matrix.
195/// 
196/// # Examples
197/// ```
198/// use ndarray::prelude::*;
199/// use te1d::prelude::*;
200/// use te1d::linalg::swap_cols;
201/// 
202/// let mut mat: Array2<f64> = array![[0.0, 1.0, 2.0], [1.0, 2.0, 3.0]];
203/// let result: Array2<f64> = array![[2.0, 1.0, 0.0], [3.0, 2.0, 1.0]];
204/// 
205/// swap_cols(&mut mat, 0, 2);
206/// assert_eq!(mat, result);
207/// ```
208#[inline]
209pub fn swap_cols(mat: &mut Array2<f64>, col1: usize, col2: usize) {
210    for i in 0..mat.shape()[0] {
211        mat.swap([i, col1], [i, col2]);
212    }
213}
214
215
216/// Compute the outer product of two [`ColVec`]s.
217/// For two column vectors `u` and `v`, the result is `u*v^T`.
218/// 
219/// # Examples
220/// ```
221/// use ndarray::prelude::*;
222/// use te1d::prelude::*;
223/// use te1d::linalg::outer_product;
224/// 
225/// let u: ColVec = array![1.0, 2.0, 3.0];
226/// let v: ColVec = array![4.0, 5.0];
227/// let result: Array2<f64> = outer_product(&u, &v);
228/// let sol: Array2<f64> = array![[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]];
229/// 
230/// assert_eq!(&result, &sol);
231/// ```
232pub fn outer_product(u: &ColVec, v: &ColVec) -> Array2<f64> {
233    let m: usize = u.len();
234    let n: usize = v.len();
235    let mut result: Array2<f64> = Array::default((m, n));
236    for i in 0..m {
237        for j in 0..n {
238            result[[i, j]] = u[i] * v[j];
239        }
240    }
241
242    result
243}
244
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_row_of_abs_max() {
        let mat: Array2<f64> = array![[1.0, 2.0], [-5.0, -2.0], [2.0, -7.0], [3.0, 9.0]];
        // Only rows 1..=2 of column 1 are searched, so the 9.0 in row 3 must be ignored.
        assert_eq!(row_of_abs_max(&mat, 1, 1, 2), 2);
    }

    #[test]
    fn test_solve_gauss_regular_system() {
        let mat: Array2<f64> = array![[1.0, 0.0, 1.0], [0.0, -3.0, 1.0], [2.0, 1.0, 3.0]];
        let b: ColVec = array![6.0, 7.0, 15.0];
        let expected = [2.0, -1.0, 4.0];

        let x = solve_gauss(&mat, &b).unwrap();

        assert_eq!(x.len(), expected.len());
        for (computed, exact) in x.iter().zip(expected.iter()) {
            assert!((computed - exact).abs() < 1e-9);
        }
    }

    #[test]
    fn test_solve_gauss_singular_matrix() {
        let mat: Array2<f64> = array![[0.0, 0.0], [1.0, 2.0]];
        let b: ColVec = array![1.0, 2.0];

        let err = solve_gauss(&mat, &b).unwrap_err();

        assert_eq!(err.kind(), LinalgErrorKind::SingularMatrixError);
    }

    // `expected` pins the panic to the intended precondition check, so an unrelated panic
    // (e.g. an out-of-bounds index) does not make these tests pass by accident.
    #[test]
    #[should_panic(expected = "must be a square matrix")]
    fn test_panic_solve_gauss_when_not_square() {
        let mat: Array2<f64> = array![[1.0, 2.0], [-5.0, -2.0], [2.0, -7.0], [3.0, 9.0]];
        let b = array![0.0, 1.0, 2.0, 3.0];
        solve_gauss(&mat, &b).unwrap(); // must panic: not a square matrix
    }

    #[test]
    #[should_panic(expected = "must be the same")]
    fn test_panic_solve_gauss_when_size_mismatch() {
        let mat: Array2<f64> = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![0.0, 1.0, 2.0, 3.0];
        solve_gauss(&mat, &b).unwrap(); // must panic: size does not match
    }
}