// mathhook_core/matrices/decomposition/decomposition_tests.rs
//
// Tests for the matrix decomposition API: LU, QR, Cholesky and SVD
// factorizations, plus `rank`, `is_positive_definite` and `condition_number`.
#[cfg(test)]
mod tests {
    use crate::core::Expression;
    use crate::matrices::{CoreMatrixOps, Matrix, MatrixDecomposition};

    /// Asserts that every entry `(i, j)` of `matrix` for which `should_be_zero(i, j)`
    /// returns `true` is zero.
    ///
    /// `label` names the matrix in the failure message so a broken triangular or
    /// diagonal structure reports exactly which entry was non-zero.
    fn assert_zero_where(
        matrix: &Matrix,
        label: &str,
        should_be_zero: impl Fn(usize, usize) -> bool,
    ) {
        let (rows, cols) = matrix.dimensions();
        for i in 0..rows {
            for j in 0..cols {
                if should_be_zero(i, j) {
                    let elem = matrix.get_element(i, j);
                    assert!(
                        elem.is_zero(),
                        "{label}[{i}][{j}] should be zero, got {elem:?}"
                    );
                }
            }
        }
    }

    /// LU of a 2x2 dense matrix: `L * U` must have the shape of `A` (or of `P * A`
    /// when a permutation is returned), `L` must be unit lower-triangular and `U`
    /// must be upper-triangular.
    #[test]
    fn test_lu_decomposition_correctness() {
        let matrix = Matrix::dense(vec![
            vec![Expression::integer(2), Expression::integer(1)],
            vec![Expression::integer(4), Expression::integer(3)],
        ]);

        let lu = matrix.lu_decomposition().unwrap();
        let lu_product = lu.l.multiply(&lu.u).expect("L * U should succeed");

        // NOTE(review): only shapes are compared here. Element-wise equality would
        // depend on how `Expression` arithmetic normalizes its results — confirm
        // before tightening this to a value comparison.
        if let Some(p) = &lu.p {
            let pa = p.multiply(&matrix).expect("P * A should succeed");
            assert_eq!(pa.dimensions(), lu_product.dimensions());
        } else {
            assert_eq!(matrix.dimensions(), lu_product.dimensions());
        }

        // L: ones on the diagonal, zeros strictly above it.
        let (l_rows, l_cols) = lu.l.dimensions();
        for i in 0..l_rows.min(l_cols) {
            assert_eq!(lu.l.get_element(i, i), Expression::integer(1));
        }
        assert_zero_where(&lu.l, "L", |i, j| i < j);

        // U: zeros strictly below the diagonal.
        assert_zero_where(&lu.u, "U", |i, j| i > j);
    }

    /// LU of structured inputs keeps the structure: identity factors into two
    /// identities, and a diagonal matrix factors into `L = I`, `U = the diagonal`.
    #[test]
    fn test_lu_decomposition_special_cases() {
        let identity = Matrix::identity(3);
        let lu = identity.lu_decomposition().unwrap();
        assert!(matches!(lu.l, Matrix::Identity(_)));
        assert!(matches!(lu.u, Matrix::Identity(_)));

        let diagonal = Matrix::diagonal(vec![
            Expression::integer(2),
            Expression::integer(3),
            Expression::integer(4),
        ]);
        let lu = diagonal.lu_decomposition().unwrap();
        assert!(matches!(lu.l, Matrix::Identity(_)));
        assert_eq!(lu.u, diagonal);
    }

    /// QR of a 2x2 dense matrix: `Q * R` must have the shape of `A`, `R` must be
    /// upper-triangular, and `Q` must have the same shape as `A`.
    #[test]
    fn test_qr_decomposition_correctness() {
        let matrix = Matrix::dense(vec![
            vec![Expression::integer(1), Expression::integer(1)],
            vec![Expression::integer(0), Expression::integer(1)],
        ]);

        let qr = matrix.qr_decomposition().unwrap();

        let qr_product = qr.q.multiply(&qr.r).expect("Q * R should succeed");
        assert_eq!(matrix.dimensions(), qr_product.dimensions());

        // R: zeros strictly below the diagonal.
        assert_zero_where(&qr.r, "R", |i, j| i > j);

        let (q_rows, q_cols) = qr.q.dimensions();
        assert_eq!(q_rows, matrix.dimensions().0);
        assert_eq!(q_cols, matrix.dimensions().1);
    }

    /// QR of structured inputs: identity gives `Q = R = I`; the zero matrix gives
    /// `Q = I` and `R = 0`.
    #[test]
    fn test_qr_decomposition_special_cases() {
        let identity = Matrix::identity(2);
        let qr = identity.qr_decomposition().unwrap();
        assert!(matches!(qr.q, Matrix::Identity(_)));
        assert!(matches!(qr.r, Matrix::Identity(_)));

        let zero = Matrix::zero(2, 2);
        let qr = zero.qr_decomposition().unwrap();
        assert!(matches!(qr.q, Matrix::Identity(_)));
        assert!(matches!(qr.r, Matrix::Zero(_)));
    }

    /// Cholesky of a symmetric positive-definite 2x2 matrix: `L * Lᵀ` must have the
    /// shape of `A` and `L` must be lower-triangular.
    #[test]
    fn test_cholesky_decomposition_correctness() {
        let matrix = Matrix::dense(vec![
            vec![Expression::integer(4), Expression::integer(2)],
            vec![Expression::integer(2), Expression::integer(3)],
        ]);

        // The leading minors are 4 and 4*3 - 2*2 = 8, both positive, so the matrix is
        // symmetric positive definite and the factorization must exist. A `None` here
        // is a failure, not a reason to skip the checks below.
        let chol = matrix
            .cholesky_decomposition()
            .expect("[[4, 2], [2, 3]] is symmetric positive definite, so Cholesky should succeed");

        let l_transpose = chol.l.transpose();
        let llt_product = chol
            .l
            .multiply(&l_transpose)
            .expect("L * L^T should succeed");
        assert_eq!(matrix.dimensions(), llt_product.dimensions());

        // L: zeros strictly above the diagonal.
        assert_zero_where(&chol.l, "L", |i, j| i < j);
    }

    /// Cholesky of structured inputs keeps the structure: identity, scalar and
    /// diagonal matrices yield an `L` of the same kind.
    #[test]
    fn test_cholesky_decomposition_special_cases() {
        let identity = Matrix::identity(3);
        let chol = identity.cholesky_decomposition().unwrap();
        assert!(matches!(chol.l, Matrix::Identity(_)));

        let scalar = Matrix::scalar(2, Expression::integer(4));
        let chol = scalar.cholesky_decomposition().unwrap();
        assert!(matches!(chol.l, Matrix::Scalar(_)));

        let diagonal = Matrix::diagonal(vec![Expression::integer(4), Expression::integer(9)]);
        let chol = diagonal.cholesky_decomposition().unwrap();
        assert!(matches!(chol.l, Matrix::Diagonal(_)));
    }

    /// SVD of a 2x2 dense matrix: `U * (Σ * Vᵀ)` must have the shape of `A` and
    /// `Σ` must be diagonal.
    #[test]
    fn test_svd_decomposition_correctness() {
        let matrix = Matrix::dense(vec![
            vec![Expression::integer(1), Expression::integer(2)],
            vec![Expression::integer(3), Expression::integer(4)],
        ]);

        let svd = matrix.svd_decomposition().unwrap();

        let sigma_vt = svd
            .sigma
            .multiply(&svd.vt)
            .expect("Sigma * V^T should succeed");
        let usvt_product = svd
            .u
            .multiply(&sigma_vt)
            .expect("U * (Sigma * V^T) should succeed");
        assert_eq!(matrix.dimensions(), usvt_product.dimensions());

        // Sigma: every off-diagonal entry is zero.
        assert_zero_where(&svd.sigma, "Sigma", |i, j| i != j);
    }

    /// SVD of structured inputs: identity gives `U = Σ = Vᵀ = I`, the zero matrix
    /// gives `Σ = 0`, and a diagonal matrix gives a diagonal `Σ`.
    #[test]
    fn test_svd_special_cases() {
        let identity = Matrix::identity(2);
        let svd = identity.svd_decomposition().unwrap();
        assert!(matches!(svd.u, Matrix::Identity(_)));
        assert!(matches!(svd.sigma, Matrix::Identity(_)));
        assert!(matches!(svd.vt, Matrix::Identity(_)));

        let zero = Matrix::zero(2, 2);
        let svd = zero.svd_decomposition().unwrap();
        assert!(matches!(svd.sigma, Matrix::Zero(_)));

        let diagonal = Matrix::diagonal(vec![Expression::integer(3), Expression::integer(4)]);
        let svd = diagonal.svd_decomposition().unwrap();
        assert!(matches!(svd.sigma, Matrix::Diagonal(_)));
    }

    /// `rank` counts the independent rows: full for the identity, zero for the zero
    /// matrix, and the number of non-zero entries for a diagonal matrix.
    #[test]
    fn test_matrix_rank() {
        let identity = Matrix::identity(3);
        assert_eq!(identity.rank(), 3);

        let zero = Matrix::zero(3, 3);
        assert_eq!(zero.rank(), 0);

        let diagonal = Matrix::diagonal(vec![
            Expression::integer(1),
            Expression::integer(0),
            Expression::integer(3),
        ]);
        assert_eq!(diagonal.rank(), 2);
    }

    /// Identity, positive scalar and positive diagonal matrices are all positive
    /// definite.
    #[test]
    fn test_positive_definite_check() {
        let identity = Matrix::identity(2);
        assert!(identity.is_positive_definite());

        let pos_scalar = Matrix::scalar(2, Expression::integer(5));
        assert!(pos_scalar.is_positive_definite());

        let pos_diagonal = Matrix::diagonal(vec![
            Expression::integer(1),
            Expression::integer(2),
            Expression::integer(3),
        ]);
        assert!(pos_diagonal.is_positive_definite());
    }

    /// The identity and a uniformly scaled diagonal matrix both have condition
    /// number exactly 1.
    #[test]
    fn test_condition_number() {
        let identity = Matrix::identity(2);
        let cond = identity.condition_number();
        assert_eq!(cond, Expression::integer(1));

        let diagonal = Matrix::diagonal(vec![Expression::integer(2), Expression::integer(2)]);
        let cond = diagonal.condition_number();
        assert_eq!(cond, Expression::integer(1));
    }
}