pub fn tensordot<A, S, S2, D, E>(
lhs: &ArrayBase<S, D>,
rhs: &ArrayBase<S2, E>,
lhs_axes: &[Axis],
rhs_axes: &[Axis],
) -> ArrayD<A>
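Computes the tensor dot product of two tensors by contracting the axes `lhs_axes` of `lhs` pairwise against the axes `rhs_axes` of `rhs`. The result's axes are the remaining axes of `lhs`, followed by the remaining axes of `rhs`.
Similar to the NumPy function of the same name (`numpy.tensordot`).
It is easiest to explain by showing the equivalent `einsum` calls: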
// m1 has shape (3, 4, 5, 6) and m2 has shape (4, 5, 6, 7), both filled with 0, 1, 2, ...
let m1 = Array::range(0., (3 * 4 * 5 * 6) as f64, 1.)
    .into_shape((3, 4, 5, 6))
    .unwrap();
let m2 = Array::range(0., (4 * 5 * 6 * 7) as f64, 1.)
    .into_shape((4, 5, 6, 7))
    .unwrap();

// Contract the last three axes of m1 (j, k, l) with the first three axes of m2.
// The i-th entry of the lhs axis list is paired with the i-th entry of the rhs list.
let via_einsum = einsum("ijkl,jklm->im", &[&m1, &m2]).unwrap();
let via_tensordot = tensordot(
    &m1,
    &m2,
    &[Axis(1), Axis(2), Axis(3)],
    &[Axis(0), Axis(1), Axis(2)],
);
assert_eq!(via_einsum, via_tensordot);

// Contract a single axis: axis 2 of m1 (length 5) with axis 1 of m2 (length 5).
// The result keeps m1's free axes (a, b, c) followed by m2's free axes (d, e, f).
let via_einsum = einsum("abic,dief->abcdef", &[&m1, &m2]).unwrap();
let via_tensordot = tensordot(&m1, &m2, &[Axis(2)], &[Axis(1)]);
assert_eq!(via_einsum, via_tensordot);