Trait TensorDot 

pub trait TensorDot<RHS = Self> {
    type Output;

    // Required method
    fn tensordot<const N: usize>(
        &self,
        rhs: &RHS,
        axes: ([i64; N], [i64; N]),
    ) -> Result<Self::Output, TensorError>;
}

A trait for tensor dot operations on tensors.

Required Associated Types§

type Output

The output tensor type.

Required Methods§

fn tensordot<const N: usize>(&self, rhs: &RHS, axes: ([i64; N], [i64; N])) -> Result<Self::Output, TensorError>

Computes the tensor dot product of self and rhs by contracting (summing over) the specified pairs of axes. This generalizes matrix multiplication to higher dimensions.

§Parameters:

rhs: The right-hand side tensor.

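axes: A tuple of two arrays specifying the axes to contract over:

  • The first array lists axes of self (the left-hand tensor)
  • The second array lists axes of rhs (the right-hand tensor)
  • Both arrays must have the same length N; the i-th entries of the two arrays form one contracted pair, so the paired axes must have equal size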
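
To make the axis pairing concrete, the following is a minimal reference sketch in plain Rust, independent of this crate. The function `tensordot_ref` and its helpers are hypothetical names, not part of the library. It assumes row-major flat buffers, non-negative `usize` axes (the trait itself takes `i64`), and the NumPy-style convention that the output's axes are the uncontracted axes of the left operand followed by those of the right operand; the crate's actual output layout is not stated here.

```rust
/// Row-major strides for `shape`.
fn strides(shape: &[usize]) -> Vec<usize> {
    let mut s = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        s[i] = s[i + 1] * shape[i + 1];
    }
    s
}

/// Converts a flat row-major position into a multi-index for `shape`.
fn unravel(mut flat: usize, shape: &[usize]) -> Vec<usize> {
    let mut idx = vec![0; shape.len()];
    for i in (0..shape.len()).rev() {
        idx[i] = flat % shape[i];
        flat /= shape[i];
    }
    idx
}

/// Hypothetical reference contraction: sums over the paired axes of `a` and `b`.
/// Returns the flat output buffer and its shape.
fn tensordot_ref<const N: usize>(
    a: &[f64],
    a_shape: &[usize],
    b: &[f64],
    b_shape: &[usize],
    axes: ([usize; N], [usize; N]),
) -> (Vec<f64>, Vec<usize>) {
    let (a_axes, b_axes) = axes;
    // Each contracted pair must have the same extent.
    for k in 0..N {
        assert_eq!(a_shape[a_axes[k]], b_shape[b_axes[k]], "paired axes differ in size");
    }
    // Free (uncontracted) axes keep their original order.
    let a_free: Vec<usize> = (0..a_shape.len()).filter(|i| !a_axes.contains(i)).collect();
    let b_free: Vec<usize> = (0..b_shape.len()).filter(|i| !b_axes.contains(i)).collect();
    let out_shape: Vec<usize> = a_free
        .iter()
        .map(|&i| a_shape[i])
        .chain(b_free.iter().map(|&i| b_shape[i]))
        .collect();
    let con_shape: Vec<usize> = a_axes.iter().map(|&i| a_shape[i]).collect();
    let (sa, sb) = (strides(a_shape), strides(b_shape));
    let out_len: usize = out_shape.iter().product();
    let con_len: usize = con_shape.iter().product();

    let mut out = vec![0.0; out_len];
    for (o, slot) in out.iter_mut().enumerate() {
        let oi = unravel(o, &out_shape);
        for c in 0..con_len {
            let ci = unravel(c, &con_shape);
            let (mut ia, mut ib) = (0, 0);
            // Offsets contributed by the free axes of each operand.
            for (k, &ax) in a_free.iter().enumerate() {
                ia += oi[k] * sa[ax];
            }
            for (k, &ax) in b_free.iter().enumerate() {
                ib += oi[a_free.len() + k] * sb[ax];
            }
            // Offsets contributed by the contracted pairs (shared index).
            for k in 0..N {
                ia += ci[k] * sa[a_axes[k]];
                ib += ci[k] * sb[b_axes[k]];
            }
            *slot += a[ia] * b[ib];
        }
    }
    (out, out_shape)
}

fn main() {
    // 2x2 matrix product expressed as a contraction of axis 1 with axis 0.
    let a = [1., 2., 3., 4.];
    let b = [5., 6., 7., 8.];
    let (c, c_shape) = tensordot_ref(&a, &[2, 2], &b, &[2, 2], ([1], [0]));
    println!("shape {:?}, data {:?}", c_shape, c);

    // Shapes [2, 3, 4] and [4, 3, 2], contracting two axis pairs.
    let d = vec![1.0; 24];
    let e = vec![1.0; 24];
    let (f, f_shape) = tensordot_ref(&d, &[2, 3, 4], &e, &[4, 3, 2], ([1, 2], [1, 0]));
    println!("shape {:?}, data {:?}", f_shape, f);
}
```

The sketch favors clarity over speed: it recomputes multi-indices for every element, whereas a real implementation would typically permute and reshape both operands and delegate to a matrix multiplication.
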
§Example:
// Matrix multiplication (2D tensordot)
let a = Tensor::new(&[[1., 2.], [3., 4.]]);
let b = Tensor::new(&[[5., 6.], [7., 8.]]);
let c = a.tensordot(&b, ([1], [0]))?; // Contract last axis of a with first axis of b
// c holds the matrix product [[19, 22], [43, 50]]
println!("Matrix multiplication:\n{}", c);

// Higher dimensional example
let d = Tensor::<f32>::ones(&[2, 3, 4])?;
let e = Tensor::<f32>::ones(&[4, 3, 2])?;
// Pair axis 1 of d with axis 1 of e (size 3) and axis 2 of d with axis 0 of e (size 4)
let f = d.tensordot(&e, ([1, 2], [1, 0]))?;
// f has shape [2, 2]; every element is 3 * 4 = 12
println!("Higher dimensional result:\n{}", f);

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
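
The likely cause is the const generic parameter N on tensordot: a trait whose method takes type or const generic parameters cannot be called through dyn. The sketch below uses a hypothetical stand-in trait, `CountAxes`, that mirrors only the shape of the axes argument (it is not part of this crate), and shows that static dispatch through a generic bound still works.

```rust
/// Hypothetical stand-in trait with a const generic method, like `tensordot`.
trait CountAxes {
    fn count_axes<const N: usize>(&self, axes: ([i64; N], [i64; N])) -> usize;
}

struct Dummy;

impl CountAxes for Dummy {
    // Returns the number of contracted axis pairs.
    fn count_axes<const N: usize>(&self, _axes: ([i64; N], [i64; N])) -> usize {
        N
    }
}

/// Static dispatch: the compiler monomorphizes this for each `T` and `N`.
fn contracted_pairs<T: CountAxes>(t: &T) -> usize {
    t.count_axes(([1, 2], [1, 0]))
}

// The following would be rejected by the compiler, because `CountAxes`
// is not dyn compatible:
// fn through_dyn(t: &dyn CountAxes) -> usize { t.count_axes(([1], [0])) }

fn main() {
    println!("{}", contracted_pairs(&Dummy));
}
```

In practice this means code that is generic over TensorDot should take a type parameter bounded by the trait rather than a trait object.
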

Implementors§

impl<A, B> TensorDot<Tensor<B>> for Tensor<A>
where
    _Tensor<A>: TensorDot<_Tensor<B>, Output = _Tensor<<A as NormalOut<B>>::Output>>,
    A: CommonBounds + NormalOut<B>,
    B: CommonBounds,
    <A as NormalOut<B>>::Output: CommonBounds,