// mdarray_linalg/testing/matmul/mod.rs
use mdarray::{DTensor, Tensor, expr, expr::Expression as _};
use num_complex::Complex64;

use super::common::*;
use crate::prelude::*;
6
7pub fn create_test_matrix_f64(
8 shape: [usize; 2],
9) -> expr::FromFn<(usize, usize), impl FnMut(&[usize]) -> f64> {
10 expr::from_fn(shape, move |i| (shape[1] * i[0] + i[1] + 1) as f64)
11}
12
13pub fn create_test_matrix_complex(
14 shape: [usize; 2],
15) -> expr::FromFn<(usize, usize), impl FnMut(&[usize]) -> Complex64> {
16 expr::from_fn(shape, move |i| {
17 let val = (shape[1] * i[0] + i[1] + 1) as f64;
18 Complex64::new(val, val * 0.5)
19 })
20}
21
22pub fn test_matmul_complex_with_scaling_impl(backend: &impl MatMul<Complex64>) {
23 let a = create_test_matrix_complex([2, 3]).eval();
24 let b = create_test_matrix_complex([3, 2]).eval();
25 let scale_factor = Complex64::new(2.0, 1.5);
26
27 let result = backend.matmul(&a, &b).scale(scale_factor).eval();
28
29 let expected = naive_matmul(&a, &b);
30 let expected = (expr::fill(scale_factor) * &expected).eval();
31
32 assert_eq!(result, expected);
33}
34
35pub fn create_symmetric_matrix_f64(size: usize) -> DTensor<f64, 2> {
36 let mut matrix = Tensor::from_elem([size, size], 0.0);
37 for i in 0..size {
38 for j in 0..size {
39 let value = ((i + 1) * (j + 1)) as f64;
40 matrix[[i, j]] = value;
41 matrix[[j, i]] = value; }
43 }
44 matrix
45}
46
47pub fn create_upper_triangular_f64(size: usize) -> DTensor<f64, 2> {
48 let mut matrix = Tensor::from_elem([size, size], 0.0);
49 for i in 0..size {
50 for j in i..size {
51 matrix[[i, j]] = ((i + 1) * (j + 1)) as f64;
52 }
53 }
54 matrix
55}
56
57pub fn create_lower_triangular_f64(size: usize) -> DTensor<f64, 2> {
58 let mut matrix = Tensor::from_elem([size, size], 0.0);
59 for i in 0..size {
60 for j in 0..=i {
61 matrix[[i, j]] = ((i + 1) * (j + 1)) as f64;
62 }
63 }
64 matrix
65}
66
67pub fn create_hermitian_matrix_complex(size: usize) -> DTensor<Complex64, 2> {
68 let mut matrix = Tensor::from_elem([size, size], Complex64::new(0.0, 0.0));
69 for i in 0..size {
70 for j in 0..size {
71 if i == j {
72 matrix[[i, j]] = Complex64::new((i + 1) as f64, 0.0);
73 } else if i < j {
74 let real = ((i + 1) * (j + 1)) as f64;
75 let imag = (i + j + 1) as f64;
76 matrix[[i, j]] = Complex64::new(real, imag);
77 matrix[[j, i]] = Complex64::new(real, -imag);
78 }
79 }
80 }
81 matrix
82}