Function dfdx::tensor_ops::matmul
pub fn matmul<Lhs, Rhs>(lhs: Lhs, rhs: Rhs) -> Lhs::Output
where
    Lhs: TryMatMul<Rhs>,
Expand description
Matrix multiplication: Matrix * Matrix, Vector * Matrix, Vector * Vector (an outer product), and batched/broadcasted versions.
Examples
- Matrix x Matrix
let x: Tensor<Rank2<3, 10>, f32, _> = dev.zeros();
let y: Tensor<Rank2<10, 5>, f32, _> = dev.zeros();
let _: Tensor<Rank2<3, 5>, f32, _> = x.matmul(y);
- Vector x Matrix
let x: Tensor<Rank1<2>, f32, _> = dev.zeros();
let y: Tensor<Rank2<2, 4>, f32, _> = dev.zeros();
let _: Tensor<Rank1<4>, f32, _> = x.matmul(y);
- Vector x Vector (outer product)
let x: Tensor<Rank1<2>, f32, _> = dev.zeros();
let y: Tensor<Rank1<4>, f32, _> = dev.zeros();
let _: Tensor<Rank2<2, 4>, f32, _> = x.matmul(y);
- Batched matmul
let x: Tensor<Rank3<10, 3, 2>, f32, _> = dev.zeros();
let y: Tensor<Rank3<10, 2, 4>, f32, _> = dev.zeros();
let _: Tensor<Rank3<10, 3, 4>, f32, _> = x.matmul(y);
- Broadcasted matmul
let x: Tensor<Rank3<10, 3, 2>, f32, _> = dev.zeros();
let y: Tensor<Rank2<2, 4>, f32, _> = dev.zeros();
let _: Tensor<Rank3<10, 3, 4>, f32, _> = x.matmul(y);