Function dfdx::tensor_ops::matmul_transpose
pub fn matmul_transpose<A, B, C>(a: A, b: &B) -> <A as MatMulTrTyping<B>>::C where
A: Tensor<Dtype = f32> + MatMulTrTyping<B, C = C>,
B: 'static + Tensor<Dtype = f32> + Clone,
C: Tensor<Dtype = f32, Tape = A::Tape>,
A::Array: Transpose,
B::Array: Transpose,
C::Array: Transpose,
A::Device: MatMulOp<A::Array, <B::Array as Transpose>::T, C::Array>,
Expand description
Matrix multiplication with the transpose of `rhs`. Equivalent to `matmul(lhs, transpose(rhs))`, where `lhs` is the argument `a` and `rhs` is the argument `b`.
This supports the same variants as matmul (broadcasted, batched, etc).
Examples
- Normal matmul
let x: Tensor2D<3, 2> = Tensor2D::zeros();
let y: Tensor2D<4, 2> = Tensor2D::zeros();
let result: Tensor2D<3, 4> = matmul_transpose(x, &y);
- Batched matmul
let x: Tensor3D<10, 3, 2> = Tensor3D::zeros();
let y: Tensor3D<10, 4, 2> = Tensor3D::zeros();
let result: Tensor3D<10, 3, 4> = matmul_transpose(x, &y);
- Broadcasted matmul
let x: Tensor3D<10, 3, 2> = Tensor3D::zeros();
let y: Tensor2D<4, 2> = Tensor2D::zeros();
let result: Tensor3D<10, 3, 4> = matmul_transpose(x, &y);