/// Provides matrix multiplication operations at the global level.
///
/// At the global level:
/// - Inputs are views over global memory, meaning access is given to only
///   parts of the global memory inputs at once.
/// - All planes within a Cube are used to solve the problem.
/// - Dimensions M and N are fixed, but K is arbitrarily large: the product
///   has the shape (M, K) · (K, N) = (M, N). M and N should match the
///   underlying Stage matmul's M and N.
///
/// # Assumptions
/// - Line sizes of the inputs evenly divide the dimension they are aligned with.
///
/// # Safety
/// It is not assumed that the matmul's dimensions match its inputs'
/// dimensions perfectly, so Readers and Writers must perform their own
/// out-of-bounds checks.
pub trait GlobalMatmul<MP: MatmulPrecision>:
'static
+ Send
+ Sync {
/// Configuration type; every method of this trait receives it by value.
type Config: GlobalConfig;
/// Global reader for matrix A (Lhs).
type LhsGlobalReader: CubeType;
/// Global reader for matrix B (Rhs).
type RhsGlobalReader: CubeType;
/// Global reader for matrix C (Accumulator/Bias).
type AccGlobalReader: CubeType;
/// Writer to store the output stage into global memory.
type GlobalWriter: CubeType;
/// The accumulator type for the tile matmul.
type Accumulators: CubeType;
// Required methods
/// Performs the matrix multiplication over data loaded by the Lhs and Rhs
/// readers, over the range given for K, and stores the result using the
/// output writer.
///
/// To compute the whole range of k values, use `k_range = (0, K)` where K
/// is the K dimension of Lhs and Rhs.
fn execute(
lhs_reader: Self::LhsGlobalReader,
rhs_reader: Self::RhsGlobalReader,
acc_reader: Self::AccGlobalReader,
writer: Self::GlobalWriter,
acc: &mut Self::Accumulators,
k_range: (u32, u32),
config: Self::Config,
);
/// Initialize the global reader for Lhs, starting at row m and column k.
fn init_lhs_global_reader(
lhs: View<Line<LhsG<MP>>, Coords2d>,
config: Self::Config,
) -> Self::LhsGlobalReader;
/// Initialize the global reader for Rhs, starting at row k and column n.
fn init_rhs_global_reader(
rhs: View<Line<RhsG<MP>>, Coords2d>,
config: Self::Config,
) -> Self::RhsGlobalReader;
/// Initialize the global reader for the accumulator input. The view is
/// wrapped in a `CubeOption`, so it may be absent.
fn init_acc_global_reader(
acc: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
config: Self::Config,
) -> Self::AccGlobalReader;
/// Initialize the accumulator without data.
fn init_accumulators(config: Self::Config) -> Self::Accumulators;
/// Initialize the global writer at row m and column n.
fn init_global_writer(
out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
config: Self::Config,
) -> Self::GlobalWriter;
// The `__expand_*` items below mirror the methods above one-to-one: each
// takes an extra `&mut Scope` and uses the `CubeType::ExpandType` of every
// argument and of the return value, while `config` is still passed as-is.
// NOTE(review): presumably generated by the cube macro rather than
// hand-written — confirm before editing them manually.
/// Expand-form counterpart of `execute`.
fn __expand_execute(
scope: &mut Scope,
lhs_reader: <Self::LhsGlobalReader as CubeType>::ExpandType,
rhs_reader: <Self::RhsGlobalReader as CubeType>::ExpandType,
acc_reader: <Self::AccGlobalReader as CubeType>::ExpandType,
writer: <Self::GlobalWriter as CubeType>::ExpandType,
acc: <Self::Accumulators as CubeType>::ExpandType,
k_range: <(u32, u32) as CubeType>::ExpandType,
config: Self::Config,
) -> <() as CubeType>::ExpandType;
/// Expand-form counterpart of `init_lhs_global_reader`.
fn __expand_init_lhs_global_reader(
scope: &mut Scope,
lhs: <View<Line<LhsG<MP>>, Coords2d> as CubeType>::ExpandType,
config: Self::Config,
) -> <Self::LhsGlobalReader as CubeType>::ExpandType;
/// Expand-form counterpart of `init_rhs_global_reader`.
fn __expand_init_rhs_global_reader(
scope: &mut Scope,
rhs: <View<Line<RhsG<MP>>, Coords2d> as CubeType>::ExpandType,
config: Self::Config,
) -> <Self::RhsGlobalReader as CubeType>::ExpandType;
/// Expand-form counterpart of `init_acc_global_reader`.
fn __expand_init_acc_global_reader(
scope: &mut Scope,
acc: <CubeOption<View<Line<AccG<MP>>, Coords2d>> as CubeType>::ExpandType,
config: Self::Config,
) -> <Self::AccGlobalReader as CubeType>::ExpandType;
/// Expand-form counterpart of `init_accumulators`.
fn __expand_init_accumulators(
scope: &mut Scope,
config: Self::Config,
) -> <Self::Accumulators as CubeType>::ExpandType;
/// Expand-form counterpart of `init_global_writer`.
fn __expand_init_global_writer(
scope: &mut Scope,
out: <View<Line<AccG<MP>>, Coords2d, ReadWrite> as CubeType>::ExpandType,
config: Self::Config,
) -> <Self::GlobalWriter as CubeType>::ExpandType;
}Expand description
Provides matrix multiplication operations at the global level.
At the global level,
- Inputs are views over global memory, meaning access is given to only parts of the global memory inputs at once.
- All planes within a Cube are used to solve the problem
- Dimensions M and N are fixed to an integer, but K is arbitrarily large. The matrix multiplication works only for size (M, K) · (K, N) = (M, N). M and N should match the underlying Stage matmul’s M and N.
§Assumptions
- Line sizes of the inputs evenly divide the dimension they are aligned with.
§Safety
It is not assumed that the matmul’s dimensions match its inputs’ dimensions perfectly. It is therefore important that Readers and Writers perform checks to avoid out-of-bounds accesses before reading or writing data.
Required Associated Types§
type Config: GlobalConfig
Sourcetype LhsGlobalReader: CubeType
type LhsGlobalReader: CubeType
Global reader for matrix A (Lhs)
Sourcetype RhsGlobalReader: CubeType
type RhsGlobalReader: CubeType
Global reader for matrix B (Rhs)
Sourcetype AccGlobalReader: CubeType
type AccGlobalReader: CubeType
Global reader for matrix C (Accumulator/Bias)
Sourcetype GlobalWriter: CubeType
type GlobalWriter: CubeType
Writer to store the output stage into global memory
Sourcetype Accumulators: CubeType
type Accumulators: CubeType
The accumulator type for the tile matmul
Required Methods§
Sourcefn execute(
lhs_reader: Self::LhsGlobalReader,
rhs_reader: Self::RhsGlobalReader,
acc_reader: Self::AccGlobalReader,
writer: Self::GlobalWriter,
acc: &mut Self::Accumulators,
k_range: (u32, u32),
config: Self::Config,
)
fn execute( lhs_reader: Self::LhsGlobalReader, rhs_reader: Self::RhsGlobalReader, acc_reader: Self::AccGlobalReader, writer: Self::GlobalWriter, acc: &mut Self::Accumulators, k_range: (u32, u32), config: Self::Config, )
Performs the matrix multiplication over data loaded by the Lhs and Rhs readers, over the range given for K, and stores the result using the output writer.
To compute the whole range of k values, use k_range=(0, K) where K is the K dimension of Lhs and Rhs.
Sourcefn init_lhs_global_reader(
lhs: View<Line<LhsG<MP>>, Coords2d>,
config: Self::Config,
) -> Self::LhsGlobalReader
fn init_lhs_global_reader( lhs: View<Line<LhsG<MP>>, Coords2d>, config: Self::Config, ) -> Self::LhsGlobalReader
Initialize the global reader for Lhs, starting at row m and column k
Sourcefn init_rhs_global_reader(
rhs: View<Line<RhsG<MP>>, Coords2d>,
config: Self::Config,
) -> Self::RhsGlobalReader
fn init_rhs_global_reader( rhs: View<Line<RhsG<MP>>, Coords2d>, config: Self::Config, ) -> Self::RhsGlobalReader
Initialize the global reader for Rhs, starting at row k and column n
Sourcefn init_acc_global_reader(
acc: CubeOption<View<Line<AccG<MP>>, Coords2d>>,
config: Self::Config,
) -> Self::AccGlobalReader
fn init_acc_global_reader( acc: CubeOption<View<Line<AccG<MP>>, Coords2d>>, config: Self::Config, ) -> Self::AccGlobalReader
Initialize the global reader for the accumulator (Acc), starting at row m and column n
Sourcefn init_accumulators(config: Self::Config) -> Self::Accumulators
fn init_accumulators(config: Self::Config) -> Self::Accumulators
Initialize the accumulator without data
Sourcefn init_global_writer(
out: View<Line<AccG<MP>>, Coords2d, ReadWrite>,
config: Self::Config,
) -> Self::GlobalWriter
fn init_global_writer( out: View<Line<AccG<MP>>, Coords2d, ReadWrite>, config: Self::Config, ) -> Self::GlobalWriter
Initialize the global writer at row m and column n
fn __expand_execute( scope: &mut Scope, lhs_reader: <Self::LhsGlobalReader as CubeType>::ExpandType, rhs_reader: <Self::RhsGlobalReader as CubeType>::ExpandType, acc_reader: <Self::AccGlobalReader as CubeType>::ExpandType, writer: <Self::GlobalWriter as CubeType>::ExpandType, acc: <Self::Accumulators as CubeType>::ExpandType, k_range: <(u32, u32) as CubeType>::ExpandType, config: Self::Config, ) -> <() as CubeType>::ExpandType
fn __expand_init_lhs_global_reader( scope: &mut Scope, lhs: <View<Line<LhsG<MP>>, Coords2d> as CubeType>::ExpandType, config: Self::Config, ) -> <Self::LhsGlobalReader as CubeType>::ExpandType
fn __expand_init_rhs_global_reader( scope: &mut Scope, rhs: <View<Line<RhsG<MP>>, Coords2d> as CubeType>::ExpandType, config: Self::Config, ) -> <Self::RhsGlobalReader as CubeType>::ExpandType
fn __expand_init_acc_global_reader( scope: &mut Scope, acc: <CubeOption<View<Line<AccG<MP>>, Coords2d>> as CubeType>::ExpandType, config: Self::Config, ) -> <Self::AccGlobalReader as CubeType>::ExpandType
fn __expand_init_accumulators( scope: &mut Scope, config: Self::Config, ) -> <Self::Accumulators as CubeType>::ExpandType
fn __expand_init_global_writer( scope: &mut Scope, out: <View<Line<AccG<MP>>, Coords2d, ReadWrite> as CubeType>::ExpandType, config: Self::Config, ) -> <Self::GlobalWriter as CubeType>::ExpandType
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.