Trait BatchMatmul 

pub trait BatchMatmul<MP: MatmulPrecision>:
    'static
    + Send
    + Sync {
    type Config: BatchConfig;

    // Required methods
    fn execute(
        a: View<Line<LhsG<MP>>, Coords3d>,
        b: View<Line<RhsG<MP>>, Coords3d>,
        c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
        out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
        cube_count_args: CubeCountInput,
        config: Self::Config,
    );
    fn __expand_execute(
        scope: &mut Scope,
        a: <View<Line<LhsG<MP>>, Coords3d> as CubeType>::ExpandType,
        b: <View<Line<RhsG<MP>>, Coords3d> as CubeType>::ExpandType,
        c: <CubeOption<View<Line<AccG<MP>>, Coords3d>> as CubeType>::ExpandType,
        out: <View<Line<AccG<MP>>, Coords3d, ReadWrite> as CubeType>::ExpandType,
        cube_count_args: <CubeCountInput as CubeType>::ExpandType,
        config: Self::Config,
    ) -> <() as CubeType>::ExpandType;
}

Provides matrix multiplication operations at the batch level.

At the batch level:

  • Inputs are whole tensors in global memory.
  • All Cubes are used to solve the problem.
  • Dimensions M, N and K, as well as the number of batches, can be arbitrarily large.
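As a rough illustration of how "all Cubes" can share the problem, the plain-Rust sketch below (no CubeCL types) assumes one common partitioning: each cube computes one tile_m x tile_n tile of one batch of the output, and cubes are numbered linearly. The names TileAssignment, cubes_needed and assign_tile are hypothetical and not part of the crate; a real BatchMatmul implementation may partition the work differently.

```rust
/// Hypothetical: the output tile a single cube is responsible for.
#[derive(Debug, Clone, Copy, PartialEq)]
struct TileAssignment {
    batch: usize,
    row_start: usize,
    col_start: usize,
}

/// Number of cubes needed so that every output element of every batch is
/// covered by exactly one `tile_m x tile_n` tile (partial edge tiles included).
fn cubes_needed(batches: usize, m: usize, n: usize, tile_m: usize, tile_n: usize) -> usize {
    batches * m.div_ceil(tile_m) * n.div_ceil(tile_n)
}

/// Maps a linear cube position to the tile it computes, or `None` when the
/// cube lies past the last tile (an over-launched cube that should do nothing).
fn assign_tile(
    cube_pos: usize,
    batches: usize,
    m: usize,
    n: usize,
    tile_m: usize,
    tile_n: usize,
) -> Option<TileAssignment> {
    let tiles_m = m.div_ceil(tile_m);
    let tiles_n = n.div_ceil(tile_n);
    let tiles_per_batch = tiles_m * tiles_n;
    if cube_pos >= batches * tiles_per_batch {
        return None;
    }
    // Batch first, then row-major order of tiles inside the batch.
    let batch = cube_pos / tiles_per_batch;
    let within = cube_pos % tiles_per_batch;
    Some(TileAssignment {
        batch,
        row_start: (within / tiles_n) * tile_m,
        col_start: (within % tiles_n) * tile_n,
    })
}
```

For two batches of a 100 x 70 output with 32 x 32 tiles, cubes_needed returns 24, and the tiles in the last tile row start at row 96, so they extend past M = 100. Edge tiles like these are one reason the underlying global matmul has to check bounds.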

Assumptions

  • Line sizes of the inputs evenly divide the dimension they are aligned with.
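This assumption can be checked on the host before launch. The sketch below assumes the candidate line sizes are known as a list of element counts and that the relevant dimension is the one the lines run along (for example the contiguous dimension of a row-major tensor). line_size_divides and pick_line_size are hypothetical helpers, not CubeCL functions.

```rust
/// True when vectorized accesses of `line_size` elements tile `dim` exactly.
fn line_size_divides(dim: usize, line_size: usize) -> bool {
    line_size != 0 && dim % line_size == 0
}

/// Hypothetical helper: picks the largest supported line size that satisfies
/// the assumption, falling back to 1 (scalar access), which always divides.
fn pick_line_size(dim: usize, supported: &[usize]) -> usize {
    supported
        .iter()
        .copied()
        .filter(|&ls| line_size_divides(dim, ls))
        .max()
        .unwrap_or(1)
}
```

With supported sizes 1, 2, 4 and 8, a dimension of 100 gets a line size of 4, because 100 is not a multiple of 8.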

Safety

  • The matmul's dimensions are not assumed to match its inputs' dimensions exactly. It is therefore important to use an underlying global matmul that performs bounds checks.
  • Launching more Cubes than necessary is acceptable, provided the CubeCountInput states the maximum cube position.
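Both safety points can be modeled on the CPU. In this sketch, which is an assumption rather than the crate's implementation, a bounds-checked load reads zeros outside the matrix, so a tile hanging past the edge adds nothing to the dot products, and an over-launched cube compares its position against a useful-cube count before doing any work. load_tile_checked, cube_is_active and useful_cubes are hypothetical names; whether the bound carried by CubeCountInput is inclusive or exclusive is not stated here, so the exclusive comparison is also an assumption.

```rust
/// Hypothetical bounds-checked load of a `tile_rows x tile_cols` tile from a
/// row-major `rows x cols` matrix. Out-of-range positions are left at 0.0.
fn load_tile_checked(
    data: &[f32],
    rows: usize,
    cols: usize,
    row_start: usize,
    col_start: usize,
    tile_rows: usize,
    tile_cols: usize,
) -> Vec<f32> {
    let mut tile = vec![0.0f32; tile_rows * tile_cols];
    for r in 0..tile_rows {
        for c in 0..tile_cols {
            let (gr, gc) = (row_start + r, col_start + c);
            // Only read inside the matrix; everything else stays zero-padded.
            if gr < rows && gc < cols {
                tile[r * tile_cols + c] = data[gr * cols + gc];
            }
        }
    }
    tile
}

/// Guard for over-launched cubes: cubes at or past `useful_cubes` return early.
fn cube_is_active(cube_pos: usize, useful_cubes: usize) -> bool {
    cube_pos < useful_cubes
}
```

Zero padding is one way to satisfy the bounds-check requirement; masking the writes to the output is needed as well, so that padded rows and columns are never stored.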

Required Associated Types

type Config: BatchConfig

The configuration type passed to execute.

Required Methods

fn execute(
    a: View<Line<LhsG<MP>>, Coords3d>,
    b: View<Line<RhsG<MP>>, Coords3d>,
    c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
    out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
    cube_count_args: CubeCountInput,
    config: Self::Config,
)

Performs batched matrix multiplication of a and b over all batches, writing the result into out.
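A CPU reference makes the expected result concrete when testing an implementation. This sketch assumes contiguous row-major tensors with shapes [batches, m, k], [batches, k, n] and [batches, m, n], and assumes that c, when present, is an initial accumulator added to the product; the documentation does not spell out the role of c, so treat that part as a guess. batch_matmul_ref is a hypothetical helper.

```rust
/// Hypothetical CPU reference:
/// out[p][i][j] = init[p][i][j] + sum over kk of a[p][i][kk] * b[p][kk][j],
/// where `init` is `c` when present (assumption) and zero otherwise.
fn batch_matmul_ref(
    a: &[f32],
    b: &[f32],
    c: Option<&[f32]>,
    batches: usize,
    m: usize,
    n: usize,
    k: usize,
) -> Vec<f32> {
    assert_eq!(a.len(), batches * m * k);
    assert_eq!(b.len(), batches * k * n);
    // Start from `c` if given, otherwise from zeros.
    let mut out = match c {
        Some(c) => {
            assert_eq!(c.len(), batches * m * n);
            c.to_vec()
        }
        None => vec![0.0f32; batches * m * n],
    };
    for batch in 0..batches {
        // Offsets of this batch's matrices inside the flat buffers.
        let (a_off, b_off, o_off) = (batch * m * k, batch * k * n, batch * m * n);
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for kk in 0..k {
                    acc += a[a_off + i * k + kk] * b[b_off + kk * n + j];
                }
                out[o_off + i * n + j] += acc;
            }
        }
    }
    out
}
```

For one batch with a = [1, 2, 3, 4] and b = [5, 6, 7, 8] (both 2 x 2) and no c, the result is [19, 22, 43, 50].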

fn __expand_execute(
    scope: &mut Scope,
    a: <View<Line<LhsG<MP>>, Coords3d> as CubeType>::ExpandType,
    b: <View<Line<RhsG<MP>>, Coords3d> as CubeType>::ExpandType,
    c: <CubeOption<View<Line<AccG<MP>>, Coords3d>> as CubeType>::ExpandType,
    out: <View<Line<AccG<MP>>, Coords3d, ReadWrite> as CubeType>::ExpandType,
    cube_count_args: <CubeCountInput as CubeType>::ExpandType,
    config: Self::Config,
) -> <() as CubeType>::ExpandType

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
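A likely reason for this is that execute (and __expand_execute) takes no self receiver, so there is no value to dispatch on at runtime; implementations are instead selected through a generic type parameter. The stand-in trait Kernel below is hypothetical and only mirrors that shape in plain Rust.

```rust
/// Hypothetical stand-in with the same shape as BatchMatmul: an associated
/// type plus an associated function without a `self` receiver. The
/// receiver-less function makes the trait non-dyn-compatible.
trait Kernel: 'static + Send + Sync {
    type Config: Copy;
    fn execute(input: &[f32], config: Self::Config) -> f32;
}

struct Sum;
struct ScaledSum;

impl Kernel for Sum {
    type Config = ();
    fn execute(input: &[f32], _config: ()) -> f32 {
        input.iter().sum()
    }
}

impl Kernel for ScaledSum {
    type Config = f32;
    fn execute(input: &[f32], scale: f32) -> f32 {
        scale * input.iter().sum::<f32>()
    }
}

/// Static dispatch: the implementation is chosen by the type parameter `K`
/// and resolved at compile time through monomorphization.
fn run<K: Kernel>(input: &[f32], config: K::Config) -> f32 {
    K::execute(input, config)
}
```

Here run::<Sum>(&[1.0, 2.0, 3.0], ()) returns 6.0, whereas a type such as Box<dyn Kernel<Config = ()>> is rejected by the compiler.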

Implementors