mdarray_linalg/testing/tensordot/
mod.rs1use mdarray::tensor;
2use crate::matmul::{MatMul, ContractBuilder};
3
4pub fn tensordot_all_axes_impl(backend: &impl MatMul<f64>) {
7 let a = tensor![[1., 2.], [3., 4.]].into_dyn();
9 let b = tensor![[5., 6.], [7., 8.]].into_dyn();
10 let expected = tensor![[70.0]].into_dyn();
11 let result = backend.contract_all(&a, &b).eval();
12 assert_eq!(result, expected);
13}
14
15pub fn tensordot_contract_k_2_should_match_all_axes_impl(backend: &impl MatMul<f64>) {
16 let a = tensor![[1., 2.], [3., 4.]].into_dyn();
18 let b = tensor![[5., 6.], [7., 8.]].into_dyn();
19 let expected = tensor![[70.0]].into_dyn();
20 let result = backend.contract_n(&a, &b, 2).eval();
21 assert_eq!(result, expected);
22}
23
24pub fn tensordot_specific_axes_matrix_multiplication_impl(backend: &impl MatMul<f64>) {
25 let a = tensor![[1., 2.], [3., 4.]].into_dyn();
27 let b = tensor![[5., 6.], [7., 8.]].into_dyn();
28 let expected = tensor![[19., 22.], [43., 50.]].into_dyn();
29 let result = backend.contract(&a, &b, vec![1], vec![0]).eval();
30 assert_eq!(result, expected);
31}
32
33pub fn tensordot_specific_empty_axes_should_outer_product_impl(backend: &impl MatMul<f64>) {
34 let a = tensor![[1., 2.], [3., 4.]].into_dyn();
36 let b = tensor![[5., 6.], [7., 8.]].into_dyn();
37 let expected = tensor![
38 [[[5.0, 6.0], [7.0, 8.0]], [[10.0, 12.0], [14.0, 16.0]]],
39 [[[15.0, 18.0], [21.0, 24.0]], [[20.0, 24.0], [28.0, 32.0]]]
40 ]
41 .into_dyn();
42 let result = backend.contract_n(&a, &b, 0).eval();
43 assert_eq!(result, expected);
44}
45
46pub fn tensordot_scalar_inputs_should_multiply_impl(backend: &impl MatMul<f64>) {
49 let a = tensor![3.].into_dyn();
50 let b = tensor![5.].into_dyn();
51 let expected = tensor![[15.0]].into_dyn();
52 let result = backend.contract_all(&a, &b).eval();
53 assert_eq!(result, expected);
54}
55
56pub fn tensordot_increase_deep_impl(backend: &impl MatMul<f64>) {
57 let r = tensor![[[1.]]].into_dyn();
58 let mps = tensor![[[1.], [0.]]].into_dyn();
59 let expected = tensor![[[[1.0], [0.]]]].into_dyn();
60 let result = backend.contract(&r, &mps, vec![1], vec![0]).eval();
61 assert_eq!(result, expected);
62}
63
64pub fn tensordot_vector_dot_product_impl(backend: &impl MatMul<f64>) {
65 let a = tensor![1., 2., 3.].into_dyn();
67 let b = tensor![4., 5., 6.].into_dyn();
68 let expected = tensor![[32.0]].into_dyn(); let result = backend.contract_all(&a, &b).eval();
70 assert_eq!(result, expected);
71}
72
73pub fn tensordot_mismatched_dimensions_should_panic_impl(
74 backend: &(impl MatMul<f64> + std::panic::RefUnwindSafe),
75) {
76 let a = tensor![[1., 2.], [3., 4.]].into_dyn();
78 let b = tensor![[1., 2., 3.]].into_dyn(); let result = std::panic::catch_unwind(|| backend.contract_all(&a, &b).eval());
80 assert!(result.is_err());
81}
82
83pub fn tensordot_outer_should_match_manual_kronecker_impl(backend: &impl MatMul<f64>) {
86 let a = tensor![1., 2.].into_dyn();
88 let b = tensor![3., 4.].into_dyn();
89 let expected = tensor![[3., 4.], [6., 8.]].into_dyn();
90 let result = backend.contract_n(&a, &b, 0).eval();
91 assert_eq!(result, expected);
92}
93
94