cubecl_matmul/kernels/layered/algorithm/
simple_unit.rs1use cubecl_core::{Runtime, client::ComputeClient};
2
3use std::marker::PhantomData;
4
5use crate::{
6 components::{
7 MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
8 batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
9 global::{
10 UnitWriterFamily,
11 read::{SyncFullLoadingStrategy, sync_full_cyclic::SyncFullCyclicLoading},
12 single_stage::simple::SimpleMatmulFamily,
13 },
14 stage::{
15 ColMajorTilingOrder, FilledStageFamily, RowMajorTilingOrder, StridedStageFamily,
16 UnitMatmulFamily,
17 },
18 tile::{io::Filled, register::RegisterMatmul},
19 },
20 kernels::layered::{
21 TileSizeSelection,
22 selector::{
23 PartitionScaling, StageScaling, UnitMatmulSelectionOptions, unit_matmul_selection,
24 },
25 },
26};
27
28use super::Algorithm;
29
30pub struct SimpleUnitAlgorithm<
32 LL = SyncFullCyclicLoading<ColMajorTilingOrder>,
33 RL = SyncFullCyclicLoading<RowMajorTilingOrder>,
34> {
35 pub _ll: PhantomData<LL>,
36 pub _rl: PhantomData<RL>,
37}
38
39#[derive(Default, Clone, Debug)]
40pub struct SimpleUnitSelectionArgs {
41 pub tile_size: TileSizeSelection,
42}
43
44impl<LL, RL> Algorithm for SimpleUnitAlgorithm<LL, RL>
45where
46 LL: SyncFullLoadingStrategy,
47 RL: SyncFullLoadingStrategy,
48{
49 type SelectionArgs = SimpleUnitSelectionArgs;
50 type TileMatmul = RegisterMatmul<Filled>;
51 type StageMatmul = UnitMatmulFamily<Self::TileMatmul, StridedStageFamily, FilledStageFamily>;
52 type GlobalMatmul = SimpleMatmulFamily<Self::StageMatmul, LL, RL, UnitWriterFamily>;
53
54 type BatchMatmul =
55 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
56
57 fn selection<R: Runtime>(
58 client: &ComputeClient<R::Server>,
59 problem: &MatmulProblem,
60 plane_dim: u32,
61 line_sizes: &MatmulLineSizes,
62 _elems: MatmulElems,
63 args: &Self::SelectionArgs,
64 ) -> Result<MatmulSelection, MatmulSetupError> {
65 Ok(unit_matmul_selection::<R>(
66 client,
67 problem,
68 plane_dim,
69 false,
70 line_sizes,
71 UnitMatmulSelectionOptions {
72 tile: args.tile_size,
73 stage: match args.tile_size {
74 TileSizeSelection::MinTileSize => StageScaling::Enabled(2),
75 TileSizeSelection::MaxTileSize => StageScaling::Disabled,
76 },
77 partition: match args.tile_size {
78 TileSizeSelection::MinTileSize => PartitionScaling::Disabled,
79 TileSizeSelection::MaxTileSize => PartitionScaling::Enabled,
80 },
81 },
82 ))
83 }
84
85 fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server>) -> u32 {
86 client.properties().hardware.plane_size_min
87 }
88}