Trait BatchMatmul

pub trait BatchMatmul<MP: MatmulPrecision>:
    'static
    + Send
    + Sync {
    type Config: BatchConfig;

    // Required methods
    fn execute(
        lhs: VirtualTensor<MP::EI>,
        rhs: VirtualTensor<MP::EI>,
        out: VirtualTensor<MP::EO, ReadWrite>,
        quantization: CubeOption<Quantization<MP>>,
        cube_count_args: CubeCountInput,
        config: Self::Config,
    );
    fn __expand_execute(
        scope: &mut Scope,
        lhs: <VirtualTensor<MP::EI> as CubeType>::ExpandType,
        rhs: <VirtualTensor<MP::EI> as CubeType>::ExpandType,
        out: <VirtualTensor<MP::EO, ReadWrite> as CubeType>::ExpandType,
        quantization: <CubeOption<Quantization<MP>> as CubeType>::ExpandType,
        cube_count_args: <CubeCountInput as CubeType>::ExpandType,
        config: Self::Config,
    ) -> <() as CubeType>::ExpandType;
}

Provides matrix multiplication operations at the batch level.

At the batch level:

  • Inputs are whole tensors in global memory.
  • All Cubes are used to solve the problem.
  • Dimensions M, N and K can be arbitrarily large, as can the number of batches.
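One way to picture "all Cubes are used" with arbitrary M, N and batch counts is to give each cube one output tile of one batch and size the grid with ceiling division. The sketch below is a minimal CPU-side illustration under that assumption; Problem, TileShape, tiles_for, cubes_needed and cube_to_work are hypothetical names, not CubeCL APIs, and a real implementor's partitioning is decided by its own Config.

```rust
/// Shape of a batched problem: `batches` independent (m x k) * (k x n) products.
/// Hypothetical type; K does not affect how output tiles are assigned to cubes.
#[derive(Clone, Copy, Debug)]
struct Problem {
    batches: usize,
    m: usize,
    n: usize,
}

/// Output tile computed by a single cube (hypothetical).
#[derive(Clone, Copy, Debug)]
struct TileShape {
    tile_m: usize,
    tile_n: usize,
}

/// Number of tiles of size `tile` needed to cover `len` elements (ceiling division).
fn tiles_for(len: usize, tile: usize) -> usize {
    (len + tile - 1) / tile
}

/// Cubes required so that every output tile of every batch has an owner.
fn cubes_needed(p: Problem, t: TileShape) -> usize {
    p.batches * tiles_for(p.m, t.tile_m) * tiles_for(p.n, t.tile_n)
}

/// Maps a linear cube position to (batch index, tile row, tile column).
fn cube_to_work(pos: usize, p: Problem, t: TileShape) -> (usize, usize, usize) {
    let tiles_m = tiles_for(p.m, t.tile_m);
    let tiles_n = tiles_for(p.n, t.tile_n);
    let per_batch = tiles_m * tiles_n;
    let rem = pos % per_batch;
    (pos / per_batch, rem / tiles_n, rem % tiles_n)
}

fn main() {
    // M and N are not multiples of the tile size; ceiling division covers the remainder.
    let p = Problem { batches: 3, m: 100, n: 70 };
    let t = TileShape { tile_m: 32, tile_n: 32 };
    let total = cubes_needed(p, t);
    println!("cubes needed: {total}");
    println!("last cube works on {:?}", cube_to_work(total - 1, p, t));
}
```

Here 100 rows need 4 row tiles and 70 columns need 3 column tiles, so 3 batches require 36 cubes in total.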

Assumptions

  • Line sizes of the inputs evenly divide the dimension they are aligned with.
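The line-size assumption means a contiguous dimension must split into whole lines, with no partial line at the end. A minimal sketch of checking it, and of falling back to a smaller line size when it fails, follows; line_size_divides and pick_line_size are hypothetical helpers for illustration, not functions from this crate.

```rust
/// True if `line_size` evenly divides `dim`, so the dimension can be read
/// as whole lines with no partial line at the end.
fn line_size_divides(dim: usize, line_size: usize) -> bool {
    line_size != 0 && dim % line_size == 0
}

/// Hypothetical helper: the largest candidate line size that divides the
/// contiguous dimension, falling back to 1 (scalar access).
fn pick_line_size(dim: usize, candidates: &[usize]) -> usize {
    candidates
        .iter()
        .copied()
        .filter(|&ls| line_size_divides(dim, ls))
        .max()
        .unwrap_or(1)
}

fn main() {
    let candidates = [1, 2, 4, 8];
    // 96 is a multiple of 8, so 8-wide lines satisfy the assumption.
    println!("{}", pick_line_size(96, &candidates));
    // 100 is divisible by 4 but not by 8.
    println!("{}", pick_line_size(100, &candidates));
    // 7 is odd, so only line size 1 satisfies the assumption.
    println!("{}", pick_line_size(7, &candidates));
}
```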

Safety

  • It is not assumed that the matmul’s dimensions exactly match its inputs’ dimensions. It is therefore important to use an underlying global matmul that performs bounds checks.
  • Launching more Cubes than necessary is acceptable, provided that a CubeCountInput stating the maximum cube position is supplied.
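Both safety points can be illustrated on the CPU: surplus cubes exit as soon as their position reaches the stated maximum, and edge tiles skip elements that fall outside M or N. This is a minimal sketch assuming contiguous row-major tensors and one output tile per cube; Dims, run_cube and launch are hypothetical, and max_cube_pos only stands in for the bound a CubeCountInput would carry.

```rust
/// Hypothetical dimensions of a contiguous, row-major batched matmul.
#[derive(Clone, Copy)]
struct Dims {
    batches: usize,
    m: usize,
    n: usize,
    k: usize,
}

fn tiles_for(len: usize, tile: usize) -> usize {
    (len + tile - 1) / tile
}

/// CPU stand-in for one cube computing a `tile x tile` block of the output.
/// `max_cube_pos` plays the role of the CubeCountInput bound (an assumption).
fn run_cube(pos: usize, max_cube_pos: usize, d: Dims, tile: usize, lhs: &[f32], rhs: &[f32], out: &mut [f32]) {
    // Surplus cubes exit immediately instead of touching memory.
    if pos >= max_cube_pos {
        return;
    }
    let (tiles_m, tiles_n) = (tiles_for(d.m, tile), tiles_for(d.n, tile));
    let b = pos / (tiles_m * tiles_n);
    let rem = pos % (tiles_m * tiles_n);
    let (row0, col0) = ((rem / tiles_n) * tile, (rem % tiles_n) * tile);
    for i in row0..row0 + tile {
        for j in col0..col0 + tile {
            // Edge tiles can extend past M or N: bounds-check every element.
            if i >= d.m || j >= d.n {
                continue;
            }
            let mut acc = 0.0;
            for p in 0..d.k {
                acc += lhs[(b * d.m + i) * d.k + p] * rhs[(b * d.k + p) * d.n + j];
            }
            out[(b * d.m + i) * d.n + j] = acc;
        }
    }
}

/// Launches `launched` cubes, which may exceed the number actually needed.
fn launch(launched: usize, d: Dims, tile: usize, lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    let needed = d.batches * tiles_for(d.m, tile) * tiles_for(d.n, tile);
    let mut out = vec![0.0; d.batches * d.m * d.n];
    for pos in 0..launched {
        run_cube(pos, needed, d, tile, lhs, rhs, &mut out);
    }
    out
}

fn main() {
    let d = Dims { batches: 2, m: 5, n: 3, k: 4 };
    let lhs = vec![1.0; d.batches * d.m * d.k];
    let rhs = vec![1.0; d.batches * d.k * d.n];
    // 12 cubes are needed (2 batches x 3 x 2 tiles); launch 20 anyway.
    let out = launch(20, d, 2, &lhs, &rhs);
    println!("{out:?}");
}
```

Without the early return, cubes 12 to 19 would compute a batch index of 2 and index past the end of the buffers.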

Required Associated Types

type Config: BatchConfig

The configuration type passed to execute and __expand_execute.

Required Methods


fn execute(
    lhs: VirtualTensor<MP::EI>,
    rhs: VirtualTensor<MP::EI>,
    out: VirtualTensor<MP::EO, ReadWrite>,
    quantization: CubeOption<Quantization<MP>>,
    cube_count_args: CubeCountInput,
    config: Self::Config,
)

Performs batchwise matrix multiplication over tensors.
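As a reference for what "batchwise" means, the sketch below computes out[b] = lhs[b] · rhs[b] independently for every batch b. It assumes contiguous row-major tensors of shapes [batches, m, k], [batches, k, n] and [batches, m, n], and it ignores the quantization and cube-count arguments; batched_matmul_reference is a hypothetical plain-Rust function, not how execute is implemented.

```rust
/// Reference semantics (sketch): out[b][i][j] = sum over p of lhs[b][i][p] * rhs[b][p][j].
/// Assumes contiguous row-major tensors; strides and quantization are ignored.
fn batched_matmul_reference(lhs: &[f32], rhs: &[f32], batches: usize, m: usize, n: usize, k: usize) -> Vec<f32> {
    assert_eq!(lhs.len(), batches * m * k);
    assert_eq!(rhs.len(), batches * k * n);
    let mut out = vec![0.0f32; batches * m * n];
    for b in 0..batches {
        // Each batch is an independent (m x k) * (k x n) product.
        let l = &lhs[b * m * k..(b + 1) * m * k];
        let r = &rhs[b * k * n..(b + 1) * k * n];
        let o = &mut out[b * m * n..(b + 1) * m * n];
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0;
                for p in 0..k {
                    acc += l[i * k + p] * r[p * n + j];
                }
                o[i * n + j] = acc;
            }
        }
    }
    out
}

fn main() {
    // Two batches of 2x2 matrices; the second lhs is the identity.
    let lhs: [f32; 8] = [1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0];
    let rhs: [f32; 8] = [5.0, 6.0, 7.0, 8.0, 2.0, 3.0, 4.0, 5.0];
    let out = batched_matmul_reference(&lhs, &rhs, 2, 2, 2, 2);
    println!("{out:?}");
}
```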


fn __expand_execute(
    scope: &mut Scope,
    lhs: <VirtualTensor<MP::EI> as CubeType>::ExpandType,
    rhs: <VirtualTensor<MP::EI> as CubeType>::ExpandType,
    out: <VirtualTensor<MP::EO, ReadWrite> as CubeType>::ExpandType,
    quantization: <CubeOption<Quantization<MP>> as CubeType>::ExpandType,
    cube_count_args: <CubeCountInput as CubeType>::ExpandType,
    config: Self::Config,
) -> <() as CubeType>::ExpandType

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
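Because execute takes no self receiver, it cannot be called through a vtable, which is one reason a trait shaped like this one is not dyn compatible; implementations are instead selected statically through a type parameter. The sketch below shows that pattern with a simplified, hypothetical trait MiniBatchMatmul and a toy implementor Naive; it mirrors the shape of BatchMatmul but is not the real trait.

```rust
/// Simplified, hypothetical stand-in shaped like BatchMatmul:
/// an associated Config type plus a function with no `self` receiver.
trait MiniBatchMatmul: 'static + Send + Sync {
    type Config: Copy;
    // No `self` receiver: this alone makes the trait not dyn compatible.
    fn execute(lhs: &[f32], rhs: &[f32], out: &mut [f32], config: Self::Config);
}

struct Naive;

#[derive(Clone, Copy)]
struct NaiveConfig {
    m: usize,
    n: usize,
    k: usize,
}

impl MiniBatchMatmul for Naive {
    type Config = NaiveConfig;
    fn execute(lhs: &[f32], rhs: &[f32], out: &mut [f32], c: NaiveConfig) {
        for i in 0..c.m {
            for j in 0..c.n {
                out[i * c.n + j] = (0..c.k).map(|p| lhs[i * c.k + p] * rhs[p * c.n + j]).sum::<f32>();
            }
        }
    }
}

/// Static dispatch: the implementation is chosen by the type parameter `B`.
fn run<B: MiniBatchMatmul>(lhs: &[f32], rhs: &[f32], out: &mut [f32], config: B::Config) {
    B::execute(lhs, rhs, out, config);
}

// By contrast, `let b: Box<dyn MiniBatchMatmul<Config = NaiveConfig>> = Box::new(Naive);`
// is rejected by the compiler because the trait is not dyn compatible.

fn main() {
    let mut out = [0.0f32; 1];
    run::<Naive>(&[1.0, 2.0], &[3.0, 4.0], &mut out, NaiveConfig { m: 1, n: 1, k: 2 });
    println!("{out:?}");
}
```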

Implementors