cubecl_matmul/kernels/layered/algorithm/
base.rs

1use crate::components::batch::BatchMatmulFamily;
2use crate::components::global::GlobalMatmulFamily;
3use crate::components::stage::StageMatmulFamily;
4use crate::components::tile::TileMatmulFamily;
5use crate::components::{
6    AvailableLineSizes, MatmulElems, MatmulLineSizes, MatmulPrecision, MatmulProblem,
7    MatmulSelection, MatmulSetupError,
8};
9use cubecl_core::prelude::*;
10
11/// Specifications for a matmul algorithm
12pub trait Algorithm {
13    type SelectionArgs: Default + Clone;
14    type TileMatmul: TileMatmulFamily;
15    type StageMatmul: StageMatmulFamily;
16    type GlobalMatmul: GlobalMatmulFamily;
17    type BatchMatmul: BatchMatmulFamily;
18
19    fn setup<MP: MatmulPrecision, R: Runtime>(
20        client: &ComputeClient<R::Server>,
21        problem: &MatmulProblem,
22        selection: &MatmulSelection,
23        line_sizes: &MatmulLineSizes,
24    ) -> Result<<Self::BatchMatmul as BatchMatmulFamily>::Config, MatmulSetupError> {
25        Self::BatchMatmul::setup::<MP, R>(client, problem, selection, line_sizes)
26    }
27
28    fn selection<R: Runtime>(
29        client: &ComputeClient<R::Server>,
30        problem: &MatmulProblem,
31        plane_dim: u32,
32        line_sizes: &MatmulLineSizes,
33        elems: MatmulElems,
34        args: &Self::SelectionArgs,
35    ) -> Result<MatmulSelection, MatmulSetupError>;
36
37    fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
38        Self::BatchMatmul::filter_line_sizes(Self::GlobalMatmul::filter_line_sizes(
39            Self::StageMatmul::filter_line_sizes(Self::TileMatmul::filter_line_sizes(
40                available_line_sizes,
41            )),
42        ))
43    }
44
45    fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server>) -> u32 {
46        client.properties().hardware.plane_size_max
47    }
48}