cubecl_matmul/kernels/layered/algorithm/
simple_unit.rs1use cubecl_core::{Runtime, client::ComputeClient, ir::Elem};
2
3use std::marker::PhantomData;
4
5use crate::{
6 components::{
7 MatmulProblem, MatmulSelection,
8 batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
9 global::{
10 load::{SyncFullLoadingStrategy, sync_full_cyclic::SyncFullCyclicLoading},
11 single_stage::simple::SimpleMatmulFamily,
12 },
13 stage::{ColMajorTilingOrder, FullReaderFamily, RowMajorTilingOrder, UnitMatmulFamily},
14 tile::register::RegisterMatmul,
15 },
16 kernels::layered::{
17 TileSizeSelection,
18 selector::{
19 PartitionScaling, StageScaling, UnitMatmulSelectionOptions, unit_matmul_selection,
20 },
21 },
22};
23
24use super::Algorithm;
25
26pub struct SimpleUnitAlgorithm<
28 LL = SyncFullCyclicLoading<ColMajorTilingOrder>,
29 RL = SyncFullCyclicLoading<RowMajorTilingOrder>,
30> {
31 pub _ll: PhantomData<LL>,
32 pub _rl: PhantomData<RL>,
33}
34
35#[derive(Default, Clone, Debug)]
36pub struct SimpleUnitSelectionArgs {
37 pub tile_size: TileSizeSelection,
38}
39
40impl<LL, RL> Algorithm for SimpleUnitAlgorithm<LL, RL>
41where
42 LL: SyncFullLoadingStrategy,
43 RL: SyncFullLoadingStrategy,
44{
45 type SelectionArgs = SimpleUnitSelectionArgs;
46 type TileMatmul = RegisterMatmul;
47 type StageMatmul = UnitMatmulFamily<Self::TileMatmul, FullReaderFamily>;
48 type GlobalMatmul = SimpleMatmulFamily<Self::StageMatmul, LL, RL>;
49
50 type BatchMatmul =
51 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
52
53 fn selection<R: Runtime>(
54 client: &ComputeClient<R::Server, R::Channel>,
55 problem: &MatmulProblem,
56 plane_dim: u32,
57 _elem_stage: Elem,
58 _elem_acc: Elem,
59 args: &Self::SelectionArgs,
60 ) -> MatmulSelection {
61 unit_matmul_selection::<R>(
62 client,
63 problem,
64 plane_dim,
65 false,
66 UnitMatmulSelectionOptions {
67 tile: args.tile_size,
68 stage: match args.tile_size {
69 TileSizeSelection::MinTileSize => StageScaling::Enabled(2),
70 TileSizeSelection::MaxTileSize => StageScaling::Disabled,
71 },
72 partition: match args.tile_size {
73 TileSizeSelection::MinTileSize => PartitionScaling::Disabled,
74 TileSizeSelection::MaxTileSize => PartitionScaling::Enabled,
75 },
76 },
77 )
78 }
79
80 fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> u32 {
81 client.properties().hardware.plane_size_min
82 }
83}