opensrdk_linear_algebra/matrix/gt/mod.rs

1use crate::number::Number;
2use crate::{ge::Matrix, matrix::*};
3use rayon::prelude::*;
4
5pub mod trf;
6pub mod trs;
7
8#[derive(Clone, Debug, Default, PartialEq, Hash)]
9pub struct TridiagonalMatrix<T = f64>
10where
11    T: Number,
12{
13    dl: Vec<T>,
14    d: Vec<T>,
15    du: Vec<T>,
16}
17
18impl<T> TridiagonalMatrix<T>
19where
20    T: Number,
21{
22    pub fn new(dim: usize) -> Self {
23        let e = vec![T::default(); dim.max(1) - 1];
24        Self {
25            dl: e.clone(),
26            d: vec![T::default(); dim],
27            du: e,
28        }
29    }
30
31    /// - `dl`: Lower diagonal elements. The length must be `dimension - 1`.
32    /// - `d`: Diagonal elements. The length must be `dimension`.
33    /// - `du`: Upper diagonal elements. The length must be `dimension - 1`.
34    pub fn from(dl: Vec<T>, d: Vec<T>, du: Vec<T>) -> Result<Self, MatrixError> {
35        let n_1 = d.len().max(1) - 1;
36        if n_1 != dl.len() || n_1 != du.len() {
37            return Err(MatrixError::DimensionMismatch);
38        }
39
40        Ok(Self { dl, d, du })
41    }
42
43    /// Dimension.
44    pub fn dim(&self) -> usize {
45        self.d.len()
46    }
47
48    /// Lower diagonal elements.
49    pub fn dl(&self) -> &[T] {
50        &self.dl
51    }
52
53    /// Diagonal elements.
54    pub fn d(&self) -> &[T] {
55        &self.d
56    }
57
58    /// Lower diagonal elements.
59    pub fn du(&self) -> &[T] {
60        &self.du
61    }
62
63    /// Returns `(self.dl, self.d, self.du)`
64    pub fn eject(self) -> (Vec<T>, Vec<T>, Vec<T>) {
65        (self.dl, self.d, self.du)
66    }
67
68    pub fn mat(&self) -> Matrix<T> {
69        let n = self.d.len();
70        let mut mat = Matrix::new(n, n);
71
72        // for i in 0..n {
73        //     mat[i][i] = self.d[i];
74        // }
75
76        // for i in 0..n - 1 {
77        //     mat[i][i + 1] = self.du[i];
78        //     mat[i + 1][i] = self.dl[i];
79        // }
80
81        mat.elems_mut()
82            .par_iter_mut()
83            .enumerate()
84            .map(|(k, elem)| ((k / n, k % n), elem))
85            .for_each(|((i, j), elem)| {
86                if i == j {
87                    *elem = self.d[i];
88                } else if i + 1 == j {
89                    *elem = self.du[i];
90                } else if i == j + 1 {
91                    *elem = self.dl[j];
92                }
93            });
94
95        mat
96    }
97}