Function dfdx::tensor_ops::matmul
pub fn matmul<A, B, C>(a: A, b: B) -> <A as MatMulTyping<B>>::C
where
    A: Tensor<Dtype = f32> + MatMulTyping<B, C = C>,
    B: Tensor<Dtype = f32>,
    C: Tensor<Dtype = f32, Tape = A::Tape>,
    A::Tape: Merge<B::Tape>,
    A::Array: Transpose,
    B::Array: Transpose,
    C::Array: Transpose,
    A::Device: MatMulOp<A::Array, B::Array, C::Array>,
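Matrix multiplication. Also supports batched matrix multiplication (a stack of matrices multiplied pairwise along a leading batch dimension) and broadcasted matrix multiplication (a single 2D right-hand matrix reused for every matrix in the batch).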
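To make the three modes concrete, here is a minimal sketch in plain Rust that does not use dfdx at all. It assumes fixed-size f32 arrays with const generics in place of dfdx tensors, and the names matmul_2d, matmul_batched and matmul_broadcast are hypothetical. It is not dfdx's implementation (the signature above routes the actual work through a MatMulOp bound on the device); it only spells out the arithmetic c[i][j] = sum over k of a[i][k] * b[k][j] and how the batch dimension is treated in each mode.

```rust
/// Hypothetical plain 2D matmul: a is (M, K), b is (K, N), result is (M, N).
/// Computes c[i][j] = sum over k of a[i][k] * b[k][j].
fn matmul_2d<const M: usize, const K: usize, const N: usize>(
    a: &[[f32; K]; M],
    b: &[[f32; N]; K],
) -> [[f32; N]; M] {
    let mut c = [[0.0f32; N]; M];
    for i in 0..M {
        for j in 0..N {
            let mut acc = 0.0f32;
            for k in 0..K {
                acc += a[i][k] * b[k][j];
            }
            c[i][j] = acc;
        }
    }
    c
}

/// Hypothetical batched matmul: (B, M, K) x (B, K, N) -> (B, M, N).
/// Each batch entry of `a` is multiplied by the matching batch entry of `b`.
fn matmul_batched<const B: usize, const M: usize, const K: usize, const N: usize>(
    a: &[[[f32; K]; M]; B],
    b: &[[[f32; N]; K]; B],
) -> [[[f32; N]; M]; B] {
    let mut c = [[[0.0f32; N]; M]; B];
    for batch in 0..B {
        c[batch] = matmul_2d(&a[batch], &b[batch]);
    }
    c
}

/// Hypothetical broadcasted matmul: (B, M, K) x (K, N) -> (B, M, N).
/// The same 2D matrix `b` is reused for every batch entry of `a`.
fn matmul_broadcast<const B: usize, const M: usize, const K: usize, const N: usize>(
    a: &[[[f32; K]; M]; B],
    b: &[[f32; N]; K],
) -> [[[f32; N]; M]; B] {
    let mut c = [[[0.0f32; N]; M]; B];
    for batch in 0..B {
        c[batch] = matmul_2d(&a[batch], b);
    }
    c
}

fn main() {
    let a: [[f32; 2]; 3] = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]; // (3, 2)
    let b: [[f32; 4]; 2] = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 3.0]]; // (2, 4)
    let c = matmul_2d(&a, &b); // (3, 4)
    assert_eq!(c[0], [1.0, 2.0, 4.0, 7.0]);

    // A batch of two (3, 2) matrices; the second is the first scaled by 2.
    let a2 = a.map(|row| row.map(|v| v * 2.0));
    let batch = [a, a2];

    // Broadcasting `b` gives the same result as passing an explicit copy per batch entry.
    let shared = matmul_broadcast(&batch, &b); // (2, 3, 4)
    let per_item = matmul_batched(&batch, &[b, b]); // (2, 3, 4)
    assert_eq!(shared, per_item);
    assert_eq!(shared[1][2], [10.0, 12.0, 32.0, 46.0]);
}
```

Broadcasting here is purely a convenience: the shared right-hand matrix is not copied per batch entry, yet the result matches the batched form with an explicit copy.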
Examples
- Normal matmul
use dfdx::prelude::*;
let x: Tensor2D<3, 2> = TensorCreator::zeros();
let y: Tensor2D<2, 4> = TensorCreator::zeros();
let result: Tensor2D<3, 4> = matmul(x, y);
- Batched matmul
use dfdx::prelude::*;
let x: Tensor3D<10, 3, 2> = TensorCreator::zeros();
let y: Tensor3D<10, 2, 4> = TensorCreator::zeros();
let result: Tensor3D<10, 3, 4> = matmul(x, y);
- Broadcasted matmul
use dfdx::prelude::*;
let x: Tensor3D<10, 3, 2> = TensorCreator::zeros();
let y: Tensor2D<2, 4> = TensorCreator::zeros();
let result: Tensor3D<10, 3, 4> = matmul(x, y);
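In all three examples the caller does not pick the return type by hand: the signature computes it as <A as MatMulTyping<B>>::C. The sketch below imitates that idea with a hypothetical MatMulSketch trait over plain f32 arrays (not dfdx types), so that one generic call, matmul_any, produces (3, 4), (10, 3, 4) and (10, 3, 4) outputs for the three shape combinations above. It is an illustration under those assumptions, not how dfdx actually defines MatMulTyping.

```rust
/// Hypothetical trait: the output type is an associated type computed from
/// the operand types, loosely analogous to `<A as MatMulTyping<B>>::C`.
trait MatMulSketch<Rhs> {
    type Output;
    fn matmul_sketch(self, rhs: Rhs) -> Self::Output;
}

/// Plain (M, K) x (K, N) -> (M, N) kernel shared by all impls below.
fn mm<const M: usize, const K: usize, const N: usize>(
    a: &[[f32; K]; M],
    b: &[[f32; N]; K],
) -> [[f32; N]; M] {
    let mut c = [[0.0f32; N]; M];
    for i in 0..M {
        for j in 0..N {
            for k in 0..K {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    c
}

// Normal: (M, K) x (K, N) -> (M, N). Reusing K in both types enforces agreement.
impl<const M: usize, const K: usize, const N: usize> MatMulSketch<[[f32; N]; K]>
    for [[f32; K]; M]
{
    type Output = [[f32; N]; M];
    fn matmul_sketch(self, rhs: [[f32; N]; K]) -> Self::Output {
        mm(&self, &rhs)
    }
}

// Batched: (B, M, K) x (B, K, N) -> (B, M, N).
impl<const B: usize, const M: usize, const K: usize, const N: usize>
    MatMulSketch<[[[f32; N]; K]; B]> for [[[f32; K]; M]; B]
{
    type Output = [[[f32; N]; M]; B];
    fn matmul_sketch(self, rhs: [[[f32; N]; K]; B]) -> Self::Output {
        let mut out = [[[0.0f32; N]; M]; B];
        for i in 0..B {
            out[i] = mm(&self[i], &rhs[i]);
        }
        out
    }
}

// Broadcasted: (B, M, K) x (K, N) -> (B, M, N), reusing `rhs` for every batch entry.
impl<const B: usize, const M: usize, const K: usize, const N: usize>
    MatMulSketch<[[f32; N]; K]> for [[[f32; K]; M]; B]
{
    type Output = [[[f32; N]; M]; B];
    fn matmul_sketch(self, rhs: [[f32; N]; K]) -> Self::Output {
        let mut out = [[[0.0f32; N]; M]; B];
        for i in 0..B {
            out[i] = mm(&self[i], &rhs);
        }
        out
    }
}

/// Hypothetical generic entry point: the output type comes from the trait impl.
fn matmul_any<A: MatMulSketch<Rhs>, Rhs>(a: A, b: Rhs) -> A::Output {
    a.matmul_sketch(b)
}

fn main() {
    // Normal: (3, 2) x (2, 4) -> (3, 4)
    let r: [[f32; 4]; 3] = matmul_any([[0.0f32; 2]; 3], [[0.0f32; 4]; 2]);
    assert_eq!(r.len(), 3);

    // Batched: (10, 3, 2) x (10, 2, 4) -> (10, 3, 4)
    let r: [[[f32; 4]; 3]; 10] = matmul_any([[[0.0f32; 2]; 3]; 10], [[[0.0f32; 4]; 2]; 10]);
    assert_eq!(r.len(), 10);

    // Broadcasted: (10, 3, 2) x (2, 4) -> (10, 3, 4)
    let r: [[[f32; 4]; 3]; 10] = matmul_any([[[1.0f32; 2]; 3]; 10], [[1.0f32; 4]; 2]);
    // Each entry sums K = 2 products of 1.0 * 1.0.
    assert_eq!(r[9][2][3], 2.0);
}
```

Because shapes live in the types, an inner-dimension mismatch such as (3, 2) x (3, 4) fails to compile in this sketch (no impl matches) instead of panicking at run time.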