Function dfdx::tensor_ops::matmul

source ·
pub fn matmul<Lhs, Rhs>(lhs: Lhs, rhs: Rhs) -> Lhs::Output
where
    Lhs: TryMatMul<Rhs>,
Expand description

Matrix * Matrix, Vector * Matrix, Vector * Vector, and broadcasted/batched versions.

Examples

  1. Matrix x Matrix
let x: Tensor<Rank2<3, 10>, f32, _> = dev.zeros();
let y: Tensor<Rank2<10, 5>, f32, _> = dev.zeros();
let _: Tensor<Rank2<3, 5>, f32, _> = x.matmul(y);
  2. Vector x Matrix
let x: Tensor<Rank1<2>, f32, _> = dev.zeros();
let y: Tensor<Rank2<2, 4>, f32, _> = dev.zeros();
let _: Tensor<Rank1<4>, f32, _> = x.matmul(y);
  3. Vector x Vector (outer product)
let x: Tensor<Rank1<2>, f32, _> = dev.zeros();
let y: Tensor<Rank1<4>, f32, _> = dev.zeros();
let _: Tensor<Rank2<2, 4>, f32, _> = x.matmul(y);
  4. Batched matmul
let x: Tensor<Rank3<10, 3, 2>, f32, _> = dev.zeros();
let y: Tensor<Rank3<10, 2, 4>, f32, _> = dev.zeros();
let _: Tensor<Rank3<10, 3, 4>, f32, _> = x.matmul(y);
  5. Broadcasted matmul
let x: Tensor<Rank3<10, 3, 2>, f32, _> = dev.zeros();
let y: Tensor<Rank2<2, 4>, f32, _> = dev.zeros();
let _: Tensor<Rank3<10, 3, 4>, f32, _> = x.matmul(y);