pub trait BatchMatmul<MP: MatmulPrecision>:
    'static
    + Send
    + Sync {
    type Config: BatchConfig;

    // Required methods
    fn execute(
        a: View<Line<LhsG<MP>>, Coords3d>,
        b: View<Line<RhsG<MP>>, Coords3d>,
        c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
        out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
        cube_count_args: CubeCountInput,
        config: Self::Config,
    );
    fn __expand_execute(
        scope: &mut Scope,
        a: <View<Line<LhsG<MP>>, Coords3d> as CubeType>::ExpandType,
        b: <View<Line<RhsG<MP>>, Coords3d> as CubeType>::ExpandType,
        c: <CubeOption<View<Line<AccG<MP>>, Coords3d>> as CubeType>::ExpandType,
        out: <View<Line<AccG<MP>>, Coords3d, ReadWrite> as CubeType>::ExpandType,
        cube_count_args: <CubeCountInput as CubeType>::ExpandType,
        config: Self::Config,
    ) -> <() as CubeType>::ExpandType;
}
Provides matrix multiplication operations at the batch level.
At the batch level:
- Inputs are whole tensors in global memory.
- All Cubes are used to solve the problem.
- Dimensions M, N and K can be arbitrarily large, as can the number of batches.
§Assumptions
- Line sizes of the inputs evenly divide the dimension they are aligned with.
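The divisibility assumption above can be checked on the host before launching. The following is a minimal sketch in plain Rust; `line_size_divides` and `largest_valid_line_size` are hypothetical helpers for illustration and are not part of this crate's API.

```rust
/// Hypothetical helper (not part of this crate): returns `true` when
/// `line_size` is non-zero and evenly divides `dim`.
fn line_size_divides(dim: usize, line_size: usize) -> bool {
    line_size != 0 && dim % line_size == 0
}

/// Hypothetical helper: picks the largest candidate line size that evenly
/// divides `dim`, falling back to 1 (no vectorization) when none does.
fn largest_valid_line_size(dim: usize, candidates: &[usize]) -> usize {
    candidates
        .iter()
        .copied()
        .filter(|&line_size| line_size_divides(dim, line_size))
        .max()
        .unwrap_or(1)
}
```

For example, a dimension of 130 elements with candidates 8, 4, 2 and 1 would select a line size of 2 under this sketch, since neither 8 nor 4 divides 130.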
§Safety
- The matmul's dimensions are not assumed to match its inputs' dimensions perfectly. It is therefore important to use an underlying global matmul that performs bounds checks.
- Launching more Cubes than necessary is allowed, as long as a CubeCountInput stating the maximum cube position is provided.
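The two safety points above can be illustrated with a host-side simulation. This is only a sketch in plain Rust under stated assumptions: the function names are hypothetical, the tiling is one-dimensional, and the maximum cube position is treated as an exclusive upper bound, which the documentation does not specify.

```rust
use std::ops::Range;

/// Hypothetical: number of cubes needed to cover `dim` elements when each
/// cube handles `tile` elements (ceiling division).
fn cubes_needed(dim: usize, tile: usize) -> usize {
    dim.div_ceil(tile)
}

/// Simulates an over-provisioned launch: `launched` cubes start, but only
/// those below `max_cube_pos` do any work. Treating the bound as exclusive
/// is an assumption made for this sketch.
fn active_cubes(launched: usize, max_cube_pos: usize) -> Vec<usize> {
    (0..launched).filter(|&pos| pos < max_cube_pos).collect()
}

/// Bounds check for a cube's tile: the row range is clamped to `m`, so a
/// partial or out-of-range tile never touches rows past the end.
fn cube_rows(cube_pos: usize, tile: usize, m: usize) -> Range<usize> {
    let start = (cube_pos * tile).min(m);
    let end = (start + tile).min(m);
    start..end
}
```

With `m = 100` and a tile of 32 rows, four cubes are needed; the last one covers only rows 96 to 99, and any extra cube launched beyond the bound does no work.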
Required Associated Types§
type Config: BatchConfig
Required Methods§
fn execute(
    a: View<Line<LhsG<MP>>, Coords3d>,
    b: View<Line<RhsG<MP>>, Coords3d>,
    c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
    out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
    cube_count_args: CubeCountInput,
    config: Self::Config,
)
Performs batchwise matrix multiplication over tensors.
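As a point of comparison, batchwise matrix multiplication can be written as a plain CPU reference. This is a sketch only, not the kernel behind `execute`: `batch_matmul_ref` is a hypothetical name, tensors are assumed to be contiguous and row-major with shapes `[batches, m, k]` and `[batches, k, n]`, and the optional `c` input, line vectorization and cube dispatch are all left out.

```rust
/// Hypothetical CPU reference (not part of this crate): for every batch,
/// computes out[batch] = lhs[batch] * rhs[batch].
/// Assumes contiguous row-major storage:
/// lhs is [batches, m, k], rhs is [batches, k, n], out is [batches, m, n].
fn batch_matmul_ref(
    lhs: &[f32],
    rhs: &[f32],
    batches: usize,
    m: usize,
    n: usize,
    k: usize,
) -> Vec<f32> {
    assert_eq!(lhs.len(), batches * m * k, "lhs length does not match its shape");
    assert_eq!(rhs.len(), batches * k * n, "rhs length does not match its shape");

    let mut out = vec![0.0f32; batches * m * n];
    for batch in 0..batches {
        for row in 0..m {
            for col in 0..n {
                // Dot product of one lhs row with one rhs column.
                let mut acc = 0.0f32;
                for inner in 0..k {
                    let lhs_idx = batch * m * k + row * k + inner;
                    let rhs_idx = batch * k * n + inner * n + col;
                    acc += lhs[lhs_idx] * rhs[rhs_idx];
                }
                out[batch * m * n + row * n + col] = acc;
            }
        }
    }
    out
}
```

A reference like this can be useful for validating the output of a batch matmul implementation on small shapes.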
fn __expand_execute(
    scope: &mut Scope,
    a: <View<Line<LhsG<MP>>, Coords3d> as CubeType>::ExpandType,
    b: <View<Line<RhsG<MP>>, Coords3d> as CubeType>::ExpandType,
    c: <CubeOption<View<Line<AccG<MP>>, Coords3d>> as CubeType>::ExpandType,
    out: <View<Line<AccG<MP>>, Coords3d, ReadWrite> as CubeType>::ExpandType,
    cube_count_args: <CubeCountInput as CubeType>::ExpandType,
    config: Self::Config,
) -> <() as CubeType>::ExpandType
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.