cubecl_matmul/kernels/layered/algorithm/
simple_unit.rs

1use cubecl_core::{Runtime, client::ComputeClient, ir::Elem};
2
3use std::marker::PhantomData;
4
5use crate::{
6    components::{
7        MatmulProblem, MatmulSelection,
8        batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
9        global::{
10            load::{SyncFullLoadingStrategy, sync_full_cyclic::SyncFullCyclicLoading},
11            single_stage::simple::SimpleMatmulFamily,
12        },
13        stage::{ColMajorTilingOrder, FullReaderFamily, RowMajorTilingOrder, UnitMatmulFamily},
14        tile::register::RegisterMatmul,
15    },
16    kernels::layered::{
17        TileSizeSelection,
18        selector::{
19            PartitionScaling, StageScaling, UnitMatmulSelectionOptions, unit_matmul_selection,
20        },
21    },
22};
23
24use super::Algorithm;
25
26/// Unit single stage matmul with configurable loaders (default to cyclic)
27pub struct SimpleUnitAlgorithm<
28    LL = SyncFullCyclicLoading<ColMajorTilingOrder>,
29    RL = SyncFullCyclicLoading<RowMajorTilingOrder>,
30> {
31    pub _ll: PhantomData<LL>,
32    pub _rl: PhantomData<RL>,
33}
34
35#[derive(Default, Clone, Debug)]
36pub struct SimpleUnitSelectionArgs {
37    pub tile_size: TileSizeSelection,
38}
39
40impl<LL, RL> Algorithm for SimpleUnitAlgorithm<LL, RL>
41where
42    LL: SyncFullLoadingStrategy,
43    RL: SyncFullLoadingStrategy,
44{
45    type SelectionArgs = SimpleUnitSelectionArgs;
46    type TileMatmul = RegisterMatmul;
47    type StageMatmul = UnitMatmulFamily<Self::TileMatmul, FullReaderFamily>;
48    type GlobalMatmul = SimpleMatmulFamily<Self::StageMatmul, LL, RL>;
49
50    type BatchMatmul =
51        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
52
53    fn selection<R: Runtime>(
54        client: &ComputeClient<R::Server, R::Channel>,
55        problem: &MatmulProblem,
56        plane_dim: u32,
57        _elem_stage: Elem,
58        _elem_acc: Elem,
59        args: &Self::SelectionArgs,
60    ) -> MatmulSelection {
61        unit_matmul_selection::<R>(
62            client,
63            problem,
64            plane_dim,
65            false,
66            UnitMatmulSelectionOptions {
67                tile: args.tile_size,
68                stage: match args.tile_size {
69                    TileSizeSelection::MinTileSize => StageScaling::Enabled(2),
70                    TileSizeSelection::MaxTileSize => StageScaling::Disabled,
71                },
72                partition: match args.tile_size {
73                    TileSizeSelection::MinTileSize => PartitionScaling::Disabled,
74                    TileSizeSelection::MaxTileSize => PartitionScaling::Enabled,
75                },
76            },
77        )
78    }
79
80    fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> u32 {
81        client.properties().hardware.plane_size_min
82    }
83}