// File: mathhook_core/matrices/decomposition/decomposition_tests.rs

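//! Tests for matrix decomposition algorithms
//!
//! This module checks LU, QR, Cholesky, and SVD decompositions against their defining
//! properties: the shapes of the reconstructed products and the structure of the factors
//! (triangular, diagonal, or special variants such as identity). It also covers the rank,
//! positive-definiteness, and condition-number helpers.
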
#[cfg(test)]
mod tests {
    use crate::core::Expression;
    use crate::matrices::{CoreMatrixOps, Matrix, MatrixDecomposition};

    /// Assert that every entry strictly above the main diagonal of `m` is zero.
    ///
    /// `what` names the factor being checked and is used in the failure message.
    fn assert_lower_triangular(m: &Matrix, what: &str) {
        let (rows, cols) = m.dimensions();
        for i in 0..rows {
            for j in (i + 1)..cols {
                assert!(
                    m.get_element(i, j).is_zero(),
                    "{what}: entry ({i}, {j}) above the diagonal should be zero"
                );
            }
        }
    }

    /// Assert that every entry strictly below the main diagonal of `m` is zero.
    ///
    /// `what` names the factor being checked and is used in the failure message.
    fn assert_upper_triangular(m: &Matrix, what: &str) {
        let (rows, cols) = m.dimensions();
        for i in 0..rows {
            for j in 0..i.min(cols) {
                assert!(
                    m.get_element(i, j).is_zero(),
                    "{what}: entry ({i}, {j}) below the diagonal should be zero"
                );
            }
        }
    }

    /// Assert that every main-diagonal entry of `m` equals the integer 1.
    ///
    /// `what` names the factor being checked and is used in the failure message.
    fn assert_unit_diagonal(m: &Matrix, what: &str) {
        let (rows, cols) = m.dimensions();
        for i in 0..rows.min(cols) {
            assert_eq!(
                m.get_element(i, i),
                Expression::integer(1),
                "{what}: diagonal entry ({i}, {i}) should be 1"
            );
        }
    }

    /// Assert that every entry of `m` off the main diagonal is zero.
    ///
    /// `what` names the factor being checked and is used in the failure message.
    fn assert_off_diagonal_zero(m: &Matrix, what: &str) {
        let (rows, cols) = m.dimensions();
        for i in 0..rows {
            for j in (0..cols).filter(|&j| j != i) {
                assert!(
                    m.get_element(i, j).is_zero(),
                    "{what}: off-diagonal entry ({i}, {j}) should be zero"
                );
            }
        }
    }

    /// Test LU decomposition correctness
    #[test]
    fn test_lu_decomposition_correctness() {
        // Test 2x2 matrix
        let matrix = Matrix::dense(vec![
            vec![Expression::integer(2), Expression::integer(1)],
            vec![Expression::integer(4), Expression::integer(3)],
        ]);

        let lu = matrix
            .lu_decomposition()
            .expect("LU decomposition of a nonsingular 2x2 matrix should succeed");

        let lu_product = lu.l.multiply(&lu.u).expect("L * U should succeed");

        // Compare PA (or A itself when no permutation is returned) against LU.
        // NOTE(review): only shapes are compared; an entry-wise PA == LU check is TODO.
        let lhs_dims = match &lu.p {
            Some(p) => p.multiply(&matrix).expect("P * A should succeed").dimensions(),
            None => matrix.dimensions(),
        };
        assert_eq!(lhs_dims, lu_product.dimensions());

        // L must be unit lower triangular; U must be upper triangular.
        assert_lower_triangular(&lu.l, "L");
        assert_unit_diagonal(&lu.l, "L");
        assert_upper_triangular(&lu.u, "U");
    }

    /// Test LU decomposition for special matrices
    #[test]
    fn test_lu_decomposition_special_cases() {
        // Identity matrix
        let identity = Matrix::identity(3);
        let lu = identity.lu_decomposition().unwrap();
        assert!(matches!(lu.l, Matrix::Identity(_)));
        assert!(matches!(lu.u, Matrix::Identity(_)));

        // Diagonal matrix: already upper triangular, so L = I and U = A.
        let diagonal = Matrix::diagonal(vec![
            Expression::integer(2),
            Expression::integer(3),
            Expression::integer(4),
        ]);
        let lu = diagonal.lu_decomposition().unwrap();
        assert!(matches!(lu.l, Matrix::Identity(_)));
        assert_eq!(lu.u, diagonal);
    }

    /// Test QR decomposition correctness
    #[test]
    fn test_qr_decomposition_correctness() {
        // Test with simple 2x2 matrix
        let matrix = Matrix::dense(vec![
            vec![Expression::integer(1), Expression::integer(1)],
            vec![Expression::integer(0), Expression::integer(1)],
        ]);

        let qr = matrix.qr_decomposition().unwrap();

        // NOTE(review): only the shape of QR is compared with A; entry-wise A == QR is TODO.
        let qr_product = qr.q.multiply(&qr.r).unwrap();
        assert_eq!(matrix.dimensions(), qr_product.dimensions());

        // R must be upper triangular.
        assert_upper_triangular(&qr.r, "R");

        // NOTE(review): orthogonality of Q (Q^T * Q = I) is not asserted yet; only its shape.
        assert_eq!(qr.q.dimensions(), matrix.dimensions());
    }

    /// Test QR decomposition for special matrices
    #[test]
    fn test_qr_decomposition_special_cases() {
        // Identity matrix
        let identity = Matrix::identity(2);
        let qr = identity.qr_decomposition().unwrap();
        assert!(matches!(qr.q, Matrix::Identity(_)));
        assert!(matches!(qr.r, Matrix::Identity(_)));

        // Zero matrix
        let zero = Matrix::zero(2, 2);
        let qr = zero.qr_decomposition().unwrap();
        assert!(matches!(qr.q, Matrix::Identity(_)));
        assert!(matches!(qr.r, Matrix::Zero(_)));
    }

    /// Test Cholesky decomposition correctness
    #[test]
    fn test_cholesky_decomposition_correctness() {
        // [[4, 2], [2, 3]] is symmetric positive definite (leading minors 4 and 8 are
        // positive), so the decomposition must succeed. A `None` here is a failure, not
        // a reason to skip the checks.
        let matrix = Matrix::dense(vec![
            vec![Expression::integer(4), Expression::integer(2)],
            vec![Expression::integer(2), Expression::integer(3)],
        ]);

        let chol = matrix
            .cholesky_decomposition()
            .expect("Cholesky decomposition of a positive definite matrix should succeed");

        // NOTE(review): only the shape of L * L^T is compared with A; entry-wise equality is TODO.
        let l_transpose = chol.l.transpose();
        let llt_product = chol.l.multiply(&l_transpose).unwrap();
        assert_eq!(matrix.dimensions(), llt_product.dimensions());

        // L must be lower triangular.
        assert_lower_triangular(&chol.l, "L");
    }

    /// Test Cholesky decomposition for special matrices
    #[test]
    fn test_cholesky_decomposition_special_cases() {
        // Identity matrix
        let identity = Matrix::identity(3);
        let chol = identity.cholesky_decomposition().unwrap();
        assert!(matches!(chol.l, Matrix::Identity(_)));

        // Scalar matrix 4*I: its factor is expected to be the scalar matrix sqrt(4)*I = 2*I.
        // Only the variant is checked here, not the diagonal value.
        let scalar = Matrix::scalar(2, Expression::integer(4));
        let chol = scalar.cholesky_decomposition().unwrap();
        assert!(matches!(chol.l, Matrix::Scalar(_)));

        // Diagonal matrix
        let diagonal = Matrix::diagonal(vec![Expression::integer(4), Expression::integer(9)]);
        let chol = diagonal.cholesky_decomposition().unwrap();
        assert!(matches!(chol.l, Matrix::Diagonal(_)));
    }

    /// Test SVD decomposition correctness
    #[test]
    fn test_svd_decomposition_correctness() {
        // Test 2x2 matrix
        let matrix = Matrix::dense(vec![
            vec![Expression::integer(1), Expression::integer(2)],
            vec![Expression::integer(3), Expression::integer(4)],
        ]);

        let svd = matrix.svd_decomposition().unwrap();

        // NOTE(review): only the shape of U * Sigma * V^T is compared with A; entry-wise
        // equality is TODO.
        let sigma_vt = svd.sigma.multiply(&svd.vt).unwrap();
        let usvt_product = svd.u.multiply(&sigma_vt).unwrap();
        assert_eq!(matrix.dimensions(), usvt_product.dimensions());

        // Sigma must be diagonal. NOTE(review): non-negativity of the singular values is
        // not asserted here.
        assert_off_diagonal_zero(&svd.sigma, "Sigma");
    }

    /// Test SVD for special matrices
    #[test]
    fn test_svd_special_cases() {
        // Identity matrix
        let identity = Matrix::identity(2);
        let svd = identity.svd_decomposition().unwrap();
        assert!(matches!(svd.u, Matrix::Identity(_)));
        assert!(matches!(svd.sigma, Matrix::Identity(_)));
        assert!(matches!(svd.vt, Matrix::Identity(_)));

        // Zero matrix
        let zero = Matrix::zero(2, 2);
        let svd = zero.svd_decomposition().unwrap();
        assert!(matches!(svd.sigma, Matrix::Zero(_)));

        // Diagonal matrix
        let diagonal = Matrix::diagonal(vec![Expression::integer(3), Expression::integer(4)]);
        let svd = diagonal.svd_decomposition().unwrap();
        assert!(matches!(svd.sigma, Matrix::Diagonal(_)));
    }

    /// Test matrix rank computation
    #[test]
    fn test_matrix_rank() {
        // Identity matrix has full rank
        let identity = Matrix::identity(3);
        assert_eq!(identity.rank(), 3);

        // Zero matrix has rank 0
        let zero = Matrix::zero(3, 3);
        assert_eq!(zero.rank(), 0);

        // Diagonal matrix rank equals number of non-zero diagonal elements
        let diagonal = Matrix::diagonal(vec![
            Expression::integer(1),
            Expression::integer(0),
            Expression::integer(3),
        ]);
        assert_eq!(diagonal.rank(), 2);
    }

    /// Test positive definite check
    #[test]
    fn test_positive_definite_check() {
        // Identity is positive definite
        let identity = Matrix::identity(2);
        assert!(identity.is_positive_definite());

        // Positive scalar matrix is positive definite
        let pos_scalar = Matrix::scalar(2, Expression::integer(5));
        assert!(pos_scalar.is_positive_definite());

        // Diagonal with positive elements is positive definite
        let pos_diagonal = Matrix::diagonal(vec![
            Expression::integer(1),
            Expression::integer(2),
            Expression::integer(3),
        ]);
        assert!(pos_diagonal.is_positive_definite());
    }

    /// Test condition number computation
    #[test]
    fn test_condition_number() {
        // Identity matrix has condition number 1
        let identity = Matrix::identity(2);
        let cond = identity.condition_number();
        assert_eq!(cond, Expression::integer(1));

        // Well-conditioned diagonal matrix
        let diagonal = Matrix::diagonal(vec![Expression::integer(2), Expression::integer(2)]);
        let cond = diagonal.condition_number();
        assert_eq!(cond, Expression::integer(1)); // 2/2 = 1
    }
}