// differential_equations/linalg/matrix/base.rs
use crate::traits::Real;

5#[derive(Clone, Debug)]
7pub enum MatrixStorage<T: Real> {
8 Identity,
10 Full,
12 Banded { ml: usize, mu: usize, zero: T },
16}
17
18#[derive(Clone, Debug)]
20pub struct Matrix<T: Real> {
21 pub nrows: usize,
22 pub ncols: usize,
23 pub data: Vec<T>,
24 pub storage: MatrixStorage<T>,
25}
26
27impl<T: Real> Matrix<T> {
28 pub fn identity(n: usize) -> Self {
30 Matrix {
31 nrows: n,
32 ncols: n,
33 data: vec![T::one(), T::zero()],
35 storage: MatrixStorage::Identity,
36 }
37 }
38
39 pub fn full(n: usize, data: Vec<T>) -> Self {
41 assert_eq!(data.len(), n * n, "Matrix::full expects data of length n*n");
42 Matrix {
43 nrows: n,
44 ncols: n,
45 data,
46 storage: MatrixStorage::Full,
47 }
48 }
49
50 pub fn zeros(n: usize) -> Self {
52 Matrix {
53 nrows: n,
54 ncols: n,
55 data: vec![T::zero(); n * n],
56 storage: MatrixStorage::Full,
57 }
58 }
59
60 pub fn banded(n: usize, ml: usize, mu: usize) -> Self {
63 let rows = ml + mu + 1;
64 let data = vec![T::zero(); rows * n];
65 Matrix {
66 nrows: n,
67 ncols: n,
68 data,
69 storage: MatrixStorage::Banded {
70 ml,
71 mu,
72 zero: T::zero(),
73 },
74 }
75 }
76
77 pub fn diagonal(diag: Vec<T>) -> Self {
79 let n = diag.len();
80 Matrix {
82 nrows: n,
83 ncols: n,
84 data: diag,
85 storage: MatrixStorage::Banded {
86 ml: 0,
87 mu: 0,
88 zero: T::zero(),
89 },
90 }
91 }
92
93 pub fn lower_triangular(n: usize) -> Self {
95 Matrix::banded(n, n.saturating_sub(1), 0)
96 }
97
98 pub fn upper_triangular(n: usize) -> Self {
100 Matrix::banded(n, 0, n.saturating_sub(1))
101 }
102
103 pub fn dims(&self) -> (usize, usize) {
105 (self.nrows, self.ncols)
106 }
107
108 pub fn n(&self) -> usize {
110 self.dims().0
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::Matrix;

    #[test]
    fn diagonal_constructor_sets_diagonal() {
        let m = Matrix::diagonal(vec![1.0f64, 2.0, 3.0]);
        assert_eq!(m.dims(), (3, 3));
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(1, 1)], 2.0);
        assert_eq!(m[(2, 2)], 3.0);
        assert_eq!(m[(0, 1)], 0.0);
        assert_eq!(m[(2, 0)], 0.0);
    }

    #[test]
    fn triangular_constructors_shape() {
        let l: Matrix<f64> = Matrix::lower_triangular(4);
        assert_eq!(l.dims(), (4, 4));
        // ml = 3, mu = 0: (3 + 0 + 1) * 4 entries are allocated.
        assert_eq!(l.data.len(), 16);
        // Above the diagonal lies outside the band, so it reads back as zero.
        assert_eq!(l[(0, 3)], 0.0);

        let u: Matrix<f64> = Matrix::upper_triangular(4);
        assert_eq!(u.dims(), (4, 4));
        assert_eq!(u.data.len(), 16);
        // Below the diagonal lies outside the band.
        assert_eq!(u[(3, 0)], 0.0);
    }

    #[test]
    fn triangular_constructors_handle_empty() {
        // saturating_sub keeps the band widths at 0 instead of underflowing.
        let l: Matrix<f64> = Matrix::lower_triangular(0);
        assert_eq!(l.dims(), (0, 0));
        assert!(l.data.is_empty());

        let u: Matrix<f64> = Matrix::upper_triangular(0);
        assert_eq!(u.n(), 0);
        assert!(u.data.is_empty());
    }

    #[test]
    fn dense_constructors_allocate_n_squared() {
        let z: Matrix<f64> = Matrix::zeros(3);
        assert_eq!(z.dims(), (3, 3));
        assert_eq!(z.data, vec![0.0; 9]);

        let f = Matrix::full(2, vec![1.0f64, 2.0, 3.0, 4.0]);
        assert_eq!(f.n(), 2);
        assert_eq!(f.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn banded_constructor_allocates_band() {
        let b: Matrix<f64> = Matrix::banded(5, 2, 1);
        assert_eq!(b.dims(), (5, 5));
        // (ml + mu + 1) * n = (2 + 1 + 1) * 5.
        assert_eq!(b.data.len(), 20);
        assert!(b.data.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn identity_keeps_compact_buffer() {
        let i: Matrix<f64> = Matrix::identity(4);
        assert_eq!(i.dims(), (4, 4));
        // Only [one, zero] is stored, not an n*n buffer.
        assert_eq!(i.data, vec![1.0, 0.0]);
    }

    #[test]
    #[should_panic(expected = "Matrix::full expects data of length n*n")]
    fn full_rejects_wrong_length() {
        let _ = Matrix::full(2, vec![1.0f64, 2.0, 3.0]);
    }
}