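//! Backend-agnostic `tensordot` test cases (`mdarray_linalg/testing/tensordot/mod.rs`).
//!
//! Each `*_impl` function takes a `MatMul<f64>` backend, so the same checks can be run
//! against every backend implementation.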

1use mdarray::tensor;
2use crate::matmul::{MatMul, ContractBuilder};
3
// --- Basic functionality ---
5
6pub fn tensordot_all_axes_impl(backend: &impl MatMul<f64>) {
7    // np.tensordot(a, b, axes=2) -> [[70.0]]
8    let a = tensor![[1., 2.], [3., 4.]].into_dyn();
9    let b = tensor![[5., 6.], [7., 8.]].into_dyn();
10    let expected = tensor![[70.0]].into_dyn();
11    let result = backend.contract_all(&a, &b).eval();
12    assert_eq!(result, expected);
13}
14
15pub fn tensordot_contract_k_2_should_match_all_axes_impl(backend: &impl MatMul<f64>) {
16    // contract_k(2) is equivalent to All for 2D tensors
17    let a = tensor![[1., 2.], [3., 4.]].into_dyn();
18    let b = tensor![[5., 6.], [7., 8.]].into_dyn();
19    let expected = tensor![[70.0]].into_dyn();
20    let result = backend.contract_n(&a, &b, 2).eval();
21    assert_eq!(result, expected);
22}
23
24pub fn tensordot_specific_axes_matrix_multiplication_impl(backend: &impl MatMul<f64>) {
25    // tensordot(a, b, axes=([1], [0])) -> matrix product
26    let a = tensor![[1., 2.], [3., 4.]].into_dyn();
27    let b = tensor![[5., 6.], [7., 8.]].into_dyn();
28    let expected = tensor![[19., 22.], [43., 50.]].into_dyn();
29    let result = backend.contract(&a, &b, vec![1], vec![0]).eval();
30    assert_eq!(result, expected);
31}
32
33pub fn tensordot_specific_empty_axes_should_outer_product_impl(backend: &impl MatMul<f64>) {
34    // tensordot(a, b, axes=0) -> outer product
35    let a = tensor![[1., 2.], [3., 4.]].into_dyn();
36    let b = tensor![[5., 6.], [7., 8.]].into_dyn();
37    let expected = tensor![
38        [[[5.0, 6.0], [7.0, 8.0]], [[10.0, 12.0], [14.0, 16.0]]],
39        [[[15.0, 18.0], [21.0, 24.0]], [[20.0, 24.0], [28.0, 32.0]]]
40    ]
41    .into_dyn();
42    let result = backend.contract_n(&a, &b, 0).eval();
43    assert_eq!(result, expected);
44}
45
// --- Edge cases ---
47
48pub fn tensordot_scalar_inputs_should_multiply_impl(backend: &impl MatMul<f64>) {
49    let a = tensor![3.].into_dyn();
50    let b = tensor![5.].into_dyn();
51    let expected = tensor![[15.0]].into_dyn();
52    let result = backend.contract_all(&a, &b).eval();
53    assert_eq!(result, expected);
54}
55
56pub fn tensordot_increase_deep_impl(backend: &impl MatMul<f64>) {
57    let r = tensor![[[1.]]].into_dyn();
58    let mps = tensor![[[1.], [0.]]].into_dyn();
59    let expected = tensor![[[[1.0], [0.]]]].into_dyn();
60    let result = backend.contract(&r, &mps, vec![1], vec![0]).eval();
61    assert_eq!(result, expected);
62}
63
64pub fn tensordot_vector_dot_product_impl(backend: &impl MatMul<f64>) {
65    // tensordot(a, b, axes=1) -> scalar inner product
66    let a = tensor![1., 2., 3.].into_dyn();
67    let b = tensor![4., 5., 6.].into_dyn();
68    let expected = tensor![[32.0]].into_dyn(); // 1*4 + 2*5 + 3*6
69    let result = backend.contract_all(&a, &b).eval();
70    assert_eq!(result, expected);
71}
72
73pub fn tensordot_mismatched_dimensions_should_panic_impl(
74    backend: &(impl MatMul<f64> + std::panic::RefUnwindSafe),
75) {
76    // Should panic when dimensions are not aligned
77    let a = tensor![[1., 2.], [3., 4.]].into_dyn();
78    let b = tensor![[1., 2., 3.]].into_dyn(); // shape mismatch
79    let result = std::panic::catch_unwind(|| backend.contract_all(&a, &b).eval());
80    assert!(result.is_err());
81}
82
// --- Structural and mathematical properties ---
84
85pub fn tensordot_outer_should_match_manual_kronecker_impl(backend: &impl MatMul<f64>) {
86    // The outer product should be equal to np.kron(a,b)
87    let a = tensor![1., 2.].into_dyn();
88    let b = tensor![3., 4.].into_dyn();
89    let expected = tensor![[3., 4.], [6., 8.]].into_dyn();
90    let result = backend.contract_n(&a, &b, 0).eval();
91    assert_eq!(result, expected);
92}
93
// --- Test overwrite functionality (disabled: the cases below are commented out) ---
95
96// fn tensordot_overwrite_impl(backend: &impl MatMul<f64>) {
97//     let a = tensor![[1., 2.], [3., 4.]].into_dyn();
98//     let b = tensor![[5., 6.], [7., 8.]].into_dyn();
99//     let expected = tensor![[19., 22.], [43., 50.]].into_dyn();
100
101//     let mut c = tensor![[0., 0.], [0., 0.]].into_dyn();
102//     backend
103//         .contract(&a, &b, vec![1], vec![0])
104//         .overwrite(&mut c);
105
106//     assert_eq!(c, expected);
107// }
108
109// #[test]
110// fn tensordot_overwrite() {
111//     tensordot_overwrite_impl(&Naive);
112//     tensordot_overwrite_impl(&Blas);
113// }
114
115// fn tensordot_overwrite_all_axes_impl(backend: &impl MatMul<f64>) {
116//     let a = tensor![[1., 2.], [3., 4.]].into_dyn();
117//     let b = tensor![[5., 6.], [7., 8.]].into_dyn();
118//     let expected = tensor![[70.0]].into_dyn();
119
120//     let mut c = tensor![[0.0]].into_dyn();
121//     backend.contract_all(&a, &b).overwrite(&mut c);
122
123//     assert_eq!(c, expected);
124// }
125
126// #[test]
127// fn tensordot_overwrite_all_axes() {
128//     tensordot_overwrite_all_axes_impl(&Naive);
129//     tensordot_overwrite_all_axes_impl(&Blas);
130// }