mdarray_linalg/testing/common/
mod.rs

1// Helper module with common code for integration tests.
2// See https://doc.rust-lang.org/rust-by-example/testing/integration_testing.html
3use num_complex::ComplexFloat;
4use num_traits::Zero;
5
6use mdarray::{DSlice, DTensor, expr, tensor};
7
8use rand::Rng;
9
10pub fn example_matrix(
11    shape: [usize; 2],
12) -> expr::FromFn<(usize, usize), impl FnMut(&[usize]) -> f64> {
13    expr::from_fn(shape, move |i| (shape[1] * i[0] + i[1] + 1) as f64)
14}
15
/// Asserts that two matrices have the same shape and that every pair of
/// corresponding elements satisfies `assert_relative_eq!` with the given
/// `epsilon` (default `1e-8f64`).
///
/// Each operand expression is evaluated exactly once, so an argument such as
/// `naive_matmul(&a, &b)` is not recomputed for every element.
///
/// `assert_relative_eq!` is not imported by this module; it must be in scope
/// at the call site (e.g. `use approx::assert_relative_eq;`).
///
/// # Panics
///
/// Panics if the shapes differ or if any pair of elements is not relatively equal.
#[macro_export]
macro_rules! assert_matrix_eq {
    ($a:expr, $b:expr) => {
        $crate::assert_matrix_eq!($a, $b, 1e-8f64)
    };
    ($a:expr, $b:expr, $epsilon:expr) => {{
        // Bind once: re-expanding `$a`/`$b` inside the loops would re-evaluate
        // the caller's expressions for every element.
        let lhs = &$a;
        let rhs = &$b;
        assert_eq!(lhs.shape(), rhs.shape(), "Matrix shapes don't match");
        let shape = lhs.shape();
        for i in 0..shape.0 {
            for j in 0..shape.1 {
                assert_relative_eq!(lhs[[i, j]], rhs[[i, j]], epsilon = $epsilon);
            }
        }
    }};
}
31
/// Asserts that two complex matrices have the same shape and that the
/// magnitudes (`Complex::norm`) of corresponding elements satisfy
/// `assert_relative_eq!` with the given `epsilon` (default `1e-8`).
///
/// NOTE(review): only magnitudes are compared, so elements that differ purely
/// in phase (e.g. `1`, `-1` and `i`) are treated as equal. This may be
/// intentional for results defined only up to a phase; confirm before relying
/// on this macro for exact element-wise equality.
///
/// Each operand expression is evaluated exactly once. `Complex` and
/// `assert_relative_eq!` are not imported by this module; both must be in
/// scope at the call site.
///
/// # Panics
///
/// Panics if the shapes differ or if any pair of magnitudes is not relatively equal.
#[macro_export]
macro_rules! assert_complex_matrix_eq {
    ($a:expr, $b:expr) => {
        $crate::assert_complex_matrix_eq!($a, $b, 1e-8)
    };
    ($a:expr, $b:expr, $epsilon:expr) => {{
        // Bind once: re-expanding `$a`/`$b` inside the loops would re-evaluate
        // the caller's expressions for every element.
        let lhs = &$a;
        let rhs = &$b;
        assert_eq!(lhs.shape(), rhs.shape(), "Matrix shapes don't match");
        let shape = lhs.shape();
        for i in 0..shape.0 {
            for j in 0..shape.1 {
                assert_relative_eq!(
                    Complex::norm(lhs[[i, j]]),
                    Complex::norm(rhs[[i, j]]),
                    epsilon = $epsilon
                );
            }
        }
    }};
}
51
52/// Generate a random matrix of size m x n
53pub fn random_matrix(m: usize, n: usize) -> DTensor<f64, 2> {
54    let mut rng = rand::rng();
55    DTensor::<f64, 2>::from_fn([m, n], |_| rng.random_range(0.0..1.0))
56}
57
58/// Generate a rank-k matrix by multiplying m×k and k×n matrices
59pub fn rank_k_matrix(m: usize, n: usize, k: usize) -> DTensor<f64, 2> {
60    assert!(k <= n.min(m));
61
62    let a = random_matrix(m, k);
63    let b = random_matrix(k, n);
64
65    naive_matmul(&a, &b)
66}
67
68/// Textbook implementation of matrix multiplication, in order for
69/// this crate to be independant of any backend.
70pub fn naive_matmul<T: ComplexFloat + Zero>(a: &DSlice<T, 2>, b: &DSlice<T, 2>) -> DTensor<T, 2> {
71    let (ma, na) = *a.shape();
72    let (mb, nb) = *b.shape();
73
74    if na != mb {
75        panic!("Shapes don't match");
76    }
77
78    let mut c = tensor![[T::zero();nb];ma];
79
80    for (mut ci, ai) in c.rows_mut().into_iter().zip(a.rows()) {
81        for (aik, bk) in ai.expr().into_iter().zip(b.rows()) {
82            for (cij, bkj) in ci.expr_mut().into_iter().zip(bk) {
83                *cij = (*aik) * (*bkj) + *cij;
84            }
85        }
86    }
87    c
88}