mdarray_linalg/testing/common/
mod.rs1use num_complex::ComplexFloat;
4use num_traits::Zero;
5
6use mdarray::{DSlice, DTensor, expr, tensor};
7
8use rand::Rng;
9
10pub fn example_matrix(
11 shape: [usize; 2],
12) -> expr::FromFn<(usize, usize), impl FnMut(&[usize]) -> f64> {
13 expr::from_fn(shape, move |i| (shape[1] * i[0] + i[1] + 1) as f64)
14}
15
/// Asserts that two 2-D matrices have the same shape and element-wise equal
/// entries within a tolerance.
///
/// Usage: `assert_matrix_eq!(a, b)` (tolerance `1e-8`) or
/// `assert_matrix_eq!(a, b, eps)`.
///
/// Each operand expression is evaluated exactly once and borrowed. Passing a
/// call such as `naive_matmul(&x, &y)` therefore computes it once, and
/// temporaries live for the whole check. The expansion is a single block, so
/// the macro can also be used in expression position.
///
/// Entries are compared with `assert_relative_eq!(…, epsilon = eps)`. That
/// macro is called unqualified, so it must be in scope at the call site
/// (presumably from the `approx` crate — confirm against the test crate's imports).
///
/// # Panics
///
/// Panics with "Matrix shapes don't match" if the shapes differ. Also panics
/// (via `assert_relative_eq!`) on the first pair of entries outside the tolerance.
#[macro_export]
macro_rules! assert_matrix_eq {
    ($a:expr, $b:expr) => {
        $crate::assert_matrix_eq!($a, $b, 1e-8f64)
    };
    ($a:expr, $b:expr, $epsilon:expr) => {{
        // Bind each operand once. Re-expanding `$a`/`$b` per access would
        // recompute the expression and could borrow from dropped temporaries.
        let lhs = &$a;
        let rhs = &$b;
        assert_eq!(lhs.shape(), rhs.shape(), "Matrix shapes don't match");
        let shape = lhs.shape();
        for i in 0..shape.0 {
            for j in 0..shape.1 {
                assert_relative_eq!(lhs[[i, j]], rhs[[i, j]], epsilon = $epsilon);
            }
        }
    }};
}
31
/// Asserts that two 2-D complex matrices have the same shape and that
/// corresponding entries have equal moduli within a tolerance.
///
/// Usage: `assert_complex_matrix_eq!(a, b)` (tolerance `1e-8`) or
/// `assert_complex_matrix_eq!(a, b, eps)`.
///
/// NOTE(review): only `a[[i, j]].norm()` and `b[[i, j]].norm()` are compared.
/// Entries that differ only in phase (e.g. `1` vs `-1`, or `1` vs `i`)
/// therefore pass. This is weaker than element-wise equality. It is
/// presumably intentional, so that results defined only up to a phase can be
/// compared — confirm before relying on it as a strict equality check.
///
/// Each operand expression is evaluated exactly once and borrowed, and the
/// expansion is a single block, so it also works in expression position.
/// `.norm()` is called as a method, so no `Complex` import is needed at the
/// call site. `assert_relative_eq!` is called unqualified and must be in scope
/// there (presumably from `approx`).
///
/// # Panics
///
/// Panics with "Matrix shapes don't match" if the shapes differ. Also panics
/// (via `assert_relative_eq!`) on the first pair of moduli outside the tolerance.
#[macro_export]
macro_rules! assert_complex_matrix_eq {
    ($a:expr, $b:expr) => {
        $crate::assert_complex_matrix_eq!($a, $b, 1e-8)
    };
    ($a:expr, $b:expr, $epsilon:expr) => {{
        // Bind each operand once. Re-expanding `$a`/`$b` per access would
        // recompute the expression and could borrow from dropped temporaries.
        let lhs = &$a;
        let rhs = &$b;
        assert_eq!(lhs.shape(), rhs.shape(), "Matrix shapes don't match");
        let shape = lhs.shape();
        for i in 0..shape.0 {
            for j in 0..shape.1 {
                assert_relative_eq!(
                    lhs[[i, j]].norm(),
                    rhs[[i, j]].norm(),
                    epsilon = $epsilon
                );
            }
        }
    }};
}
51
52pub fn random_matrix(m: usize, n: usize) -> DTensor<f64, 2> {
54 let mut rng = rand::rng();
55 DTensor::<f64, 2>::from_fn([m, n], |_| rng.random_range(0.0..1.0))
56}
57
58pub fn rank_k_matrix(m: usize, n: usize, k: usize) -> DTensor<f64, 2> {
60 assert!(k <= n.min(m));
61
62 let a = random_matrix(m, k);
63 let b = random_matrix(k, n);
64
65 naive_matmul(&a, &b)
66}
67
68pub fn naive_matmul<T: ComplexFloat + Zero>(a: &DSlice<T, 2>, b: &DSlice<T, 2>) -> DTensor<T, 2> {
71 let (ma, na) = *a.shape();
72 let (mb, nb) = *b.shape();
73
74 if na != mb {
75 panic!("Shapes don't match");
76 }
77
78 let mut c = tensor![[T::zero();nb];ma];
79
80 for (mut ci, ai) in c.rows_mut().into_iter().zip(a.rows()) {
81 for (aik, bk) in ai.expr().into_iter().zip(b.rows()) {
82 for (cij, bkj) in ci.expr_mut().into_iter().zip(bk) {
83 *cij = (*aik) * (*bkj) + *cij;
84 }
85 }
86 }
87 c
88}