cubecl_matmul/kernels/layered/algorithm/simple_unit.rs

1use cubecl_core::{Runtime, client::ComputeClient};
2
3use std::marker::PhantomData;
4
5use crate::{
6    components::{
7        MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
8        batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
9        global::{
10            UnitWriterFamily,
11            read::{SyncFullLoadingStrategy, sync_full_cyclic::SyncFullCyclicLoading},
12            single_stage::simple::SimpleMatmulFamily,
13        },
14        stage::{
15            ColMajorTilingOrder, FilledStageFamily, RowMajorTilingOrder, StridedStageFamily,
16            UnitMatmulFamily,
17        },
18        tile::{io::Filled, register::RegisterMatmul},
19    },
20    kernels::layered::{
21        TileSizeSelection,
22        selector::{
23            PartitionScaling, StageScaling, UnitMatmulSelectionOptions, unit_matmul_selection,
24        },
25    },
26};
27
28use super::Algorithm;
29
30/// Unit single stage matmul with configurable readers (default to cyclic)
31pub struct SimpleUnitAlgorithm<
32    LL = SyncFullCyclicLoading<ColMajorTilingOrder>,
33    RL = SyncFullCyclicLoading<RowMajorTilingOrder>,
34> {
35    pub _ll: PhantomData<LL>,
36    pub _rl: PhantomData<RL>,
37}
38
39#[derive(Default, Clone, Debug)]
40pub struct SimpleUnitSelectionArgs {
41    pub tile_size: TileSizeSelection,
42}
43
44impl<LL, RL> Algorithm for SimpleUnitAlgorithm<LL, RL>
45where
46    LL: SyncFullLoadingStrategy,
47    RL: SyncFullLoadingStrategy,
48{
49    type SelectionArgs = SimpleUnitSelectionArgs;
50    type TileMatmul = RegisterMatmul<Filled>;
51    type StageMatmul = UnitMatmulFamily<Self::TileMatmul, StridedStageFamily, FilledStageFamily>;
52    type GlobalMatmul = SimpleMatmulFamily<Self::StageMatmul, LL, RL, UnitWriterFamily>;
53
54    type BatchMatmul =
55        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
56
57    fn selection<R: Runtime>(
58        client: &ComputeClient<R::Server>,
59        problem: &MatmulProblem,
60        plane_dim: u32,
61        line_sizes: &MatmulLineSizes,
62        _elems: MatmulElems,
63        args: &Self::SelectionArgs,
64    ) -> Result<MatmulSelection, MatmulSetupError> {
65        Ok(unit_matmul_selection::<R>(
66            client,
67            problem,
68            plane_dim,
69            false,
70            line_sizes,
71            UnitMatmulSelectionOptions {
72                tile: args.tile_size,
73                stage: match args.tile_size {
74                    TileSizeSelection::MinTileSize => StageScaling::Enabled(2),
75                    TileSizeSelection::MaxTileSize => StageScaling::Disabled,
76                },
77                partition: match args.tile_size {
78                    TileSizeSelection::MinTileSize => PartitionScaling::Disabled,
79                    TileSizeSelection::MaxTileSize => PartitionScaling::Enabled,
80                },
81            },
82        ))
83    }
84
85    fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server>) -> u32 {
86        client.properties().hardware.plane_size_min
87    }
88}