pub struct TensordotFixedPosition {
    len_uncontracted_lhs: usize,
    len_uncontracted_rhs: usize,
    len_contracted_axes: usize,
    output_shape: Vec<usize>,
}
Performs a tensor dot product of two tensors when no permutation of their axes is needed,
e.g. ijk,jkl->il or ijk,klm->ijlm.
The axes to be contracted must be the last axes of the LHS tensor and the first axes of the RHS tensor, in the same order in both. The axis order of the output tensor must be all the uncontracted axes of the LHS tensor followed by all the uncontracted axes of the RHS tensor, each group in the order in which it appears in its input tensor.
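The layout conditions above can be checked mechanically from einsum-style index strings. The sketch below is a hypothetical helper (`fixed_position_contracted_axes` is not part of this crate) and assumes single-character axis labels with no label repeated within an operand. It returns the number of contracted axes when no permutation is needed, and `None` otherwise.

```rust
/// Hypothetical helper, not part of the crate: returns `Some(k)` if
/// `lhs,rhs->out` fits the fixed-position layout with `k` contracted axes,
/// or `None` if a permutation would be needed first.
/// Assumes single-character labels, none repeated within one operand.
fn fixed_position_contracted_axes(lhs: &str, rhs: &str, out: &str) -> Option<usize> {
    let lhs: Vec<char> = lhs.chars().collect();
    let rhs: Vec<char> = rhs.chars().collect();
    let out: Vec<char> = out.chars().collect();

    // Contracted axes are the LHS labels that also occur in the RHS.
    let k = lhs.iter().filter(|&&c| rhs.contains(&c)).count();
    if k > rhs.len() {
        // Only reachable if the distinct-label assumption is violated.
        return None;
    }

    // They must be the last k LHS axes and the first k RHS axes, in the same order.
    if lhs[lhs.len() - k..] != rhs[..k] {
        return None;
    }

    // The output must be the uncontracted LHS labels followed by the uncontracted RHS labels.
    let expected: Vec<char> = lhs[..lhs.len() - k]
        .iter()
        .chain(&rhs[k..])
        .copied()
        .collect();
    if out == expected {
        Some(k)
    } else {
        None
    }
}
```

For example, both ijk,jkl->il (two contracted axes) and ijk,klm->ijlm (one) pass, while ijk,kjl->il is rejected because the contracted axes appear in a different order in the two operands.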
The contraction is performed by reshaping the LHS into a matrix (2-D tensor) of shape [len_uncontracted_lhs, len_contracted_axes], reshaping the RHS into shape [len_contracted_axes, len_uncontracted_rhs], matrix-multiplying the two reshaped tensors, and then reshaping the result into self.output_shape.
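A minimal sketch of the reshape, matrix-multiply, reshape strategy, assuming flat row-major `f64` buffers rather than the ndarray arrays the crate works with; the function name `tensordot_fixed_position` is hypothetical. In row-major layout both reshapes are free, because the same buffer can be read directly as an m-by-k or k-by-n matrix, so only the matrix multiplication does any work.

```rust
/// Hypothetical sketch, not the crate's code: contracts two row-major
/// tensors stored as flat slices, using the reshape -> matmul -> reshape idea.
/// m = len_uncontracted_lhs, k = len_contracted_axes, n = len_uncontracted_rhs.
fn tensordot_fixed_position(lhs: &[f64], rhs: &[f64], m: usize, k: usize, n: usize) -> Vec<f64> {
    assert_eq!(lhs.len(), m * k, "LHS must have m * k elements");
    assert_eq!(rhs.len(), k * n, "RHS must have k * n elements");

    // Row-major "reshape": lhs[i * k + p] is element (i, p) of the m-by-k matrix,
    // rhs[p * n + j] is element (p, j) of the k-by-n matrix.
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for p in 0..k {
            let a = lhs[i * k + p];
            for j in 0..n {
                out[i * n + j] += a * rhs[p * n + j];
            }
        }
    }
    // The m-by-n result is already the row-major data for output_shape.
    out
}
```

For ijk,klm->ijlm with shapes [2, 3, 4] and [4, 5, 6], this would be called with m = 6, k = 4, n = 30, and the returned buffer read as shape [2, 3, 5, 6]. With k = 1 it computes an outer product.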
Fields
len_uncontracted_lhs: usize
The product of the lengths of all the uncontracted axes in the LHS (or 1 if all of the LHS axes are contracted).
len_uncontracted_rhs: usize
The product of the lengths of all the uncontracted axes in the RHS (or 1 if all of the RHS axes are contracted).
len_contracted_axes: usize
The product of the lengths of all the contracted axes (or 1 if no axes are contracted, i.e. the outer product is computed).
output_shape: Vec<usize>
The shape that the result of the tensor dot product will be reshaped to.
Implementations
impl TensordotFixedPosition
pub fn new(sc: &SizedContraction) -> Self
pub fn from_shapes_and_number_of_contracted_axes(
    lhs_shape: &[usize],
    rhs_shape: &[usize],
    num_contracted_axes: usize,
) -> Self
Compute the uncontracted and contracted axis lengths and the output shape based on the input shapes and how many axes should be contracted from each tensor.
TODO: The assert_eq! here could be tightened up by verifying that the
last num_contracted_axes of the LHS match the first num_contracted_axes of the
RHS axis-by-axis (as opposed to only checking the product, as is done here).
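The computation described above, including the stricter axis-by-axis check that the TODO proposes, might look like the following sketch. `FixedPositionPlan` and `plan_from_shapes` are hypothetical names that mirror the struct's private fields; this is not the crate's implementation, which per the TODO only compares the products of the contracted lengths.

```rust
/// Hypothetical standalone mirror of the private fields of TensordotFixedPosition.
#[derive(Debug, PartialEq)]
struct FixedPositionPlan {
    len_uncontracted_lhs: usize,
    len_uncontracted_rhs: usize,
    len_contracted_axes: usize,
    output_shape: Vec<usize>,
}

/// Hypothetical sketch: splits each shape into contracted and uncontracted
/// parts and checks the contracted lengths axis by axis (the TODO's suggestion).
fn plan_from_shapes(lhs_shape: &[usize], rhs_shape: &[usize], num_contracted_axes: usize) -> FixedPositionPlan {
    assert!(
        num_contracted_axes <= lhs_shape.len() && num_contracted_axes <= rhs_shape.len(),
        "cannot contract more axes than an operand has"
    );

    // Contracted axes are the last axes of the LHS and the first axes of the RHS.
    let (uncontracted_lhs, contracted_lhs) = lhs_shape.split_at(lhs_shape.len() - num_contracted_axes);
    let (contracted_rhs, uncontracted_rhs) = rhs_shape.split_at(num_contracted_axes);

    // Stricter than comparing products: [3, 4] vs [4, 3] is rejected here.
    assert_eq!(contracted_lhs, contracted_rhs, "contracted axis lengths differ");

    // The product of an empty slice is 1, matching the documented fallbacks.
    let len_uncontracted_lhs: usize = uncontracted_lhs.iter().product();
    let len_uncontracted_rhs: usize = uncontracted_rhs.iter().product();
    let len_contracted_axes: usize = contracted_lhs.iter().product();

    // Output axes: uncontracted LHS axes, then uncontracted RHS axes.
    let output_shape: Vec<usize> = uncontracted_lhs.iter().chain(uncontracted_rhs).copied().collect();

    FixedPositionPlan { len_uncontracted_lhs, len_uncontracted_rhs, len_contracted_axes, output_shape }
}
```

For shapes [2, 3, 4] and [4, 5, 6] with one contracted axis, this yields 6, 30 and 4 for the three lengths and an output shape of [2, 3, 5, 6].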
Trait Implementations
impl Clone for TensordotFixedPosition
fn clone(&self) -> TensordotFixedPosition
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.