cubecl_matmul/kernels/layered/algorithm/
base.rs1use crate::components::batch::BatchMatmulFamily;
2use crate::components::global::GlobalMatmulFamily;
3use crate::components::stage::StageMatmulFamily;
4use crate::components::tile::TileMatmulFamily;
5use crate::components::{
6 AvailableLineSizes, MatmulElems, MatmulLineSizes, MatmulPrecision, MatmulProblem,
7 MatmulSelection, MatmulSetupError,
8};
9use cubecl_core::prelude::*;
10
11pub trait Algorithm {
13 type SelectionArgs: Default + Clone;
14 type TileMatmul: TileMatmulFamily;
15 type StageMatmul: StageMatmulFamily;
16 type GlobalMatmul: GlobalMatmulFamily;
17 type BatchMatmul: BatchMatmulFamily;
18
19 fn setup<MP: MatmulPrecision, R: Runtime>(
20 client: &ComputeClient<R::Server>,
21 problem: &MatmulProblem,
22 selection: &MatmulSelection,
23 line_sizes: &MatmulLineSizes,
24 ) -> Result<<Self::BatchMatmul as BatchMatmulFamily>::Config, MatmulSetupError> {
25 Self::BatchMatmul::setup::<MP, R>(client, problem, selection, line_sizes)
26 }
27
28 fn selection<R: Runtime>(
29 client: &ComputeClient<R::Server>,
30 problem: &MatmulProblem,
31 plane_dim: u32,
32 line_sizes: &MatmulLineSizes,
33 elems: MatmulElems,
34 args: &Self::SelectionArgs,
35 ) -> Result<MatmulSelection, MatmulSetupError>;
36
37 fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
38 Self::BatchMatmul::filter_line_sizes(Self::GlobalMatmul::filter_line_sizes(
39 Self::StageMatmul::filter_line_sizes(Self::TileMatmul::filter_line_sizes(
40 available_line_sizes,
41 )),
42 ))
43 }
44
45 fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server>) -> u32 {
46 client.properties().hardware.plane_size_max
47 }
48}