cubecl_matmul/kernels/layered/algorithm/
double_unit.rs1use cubecl_core::{Runtime, client::ComputeClient};
2
3use crate::{
4 components::{
5 MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
6 batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
7 global::{
8 UnitWriterFamily, multi_stage::double_buffering::DoubleBufferingMatmulFamily,
9 read::sync_partial_cyclic::SyncPartialCyclicLoading,
10 },
11 stage::{FilledStageFamily, RowMajorTilingOrder, StridedStageFamily, UnitMatmulFamily},
12 tile::{io::Filled, register::RegisterMatmul},
13 },
14 kernels::layered::{
15 Algorithm,
16 selector::{TileSizeSelection, UnitMatmulSelectionOptions, unit_matmul_selection},
17 },
18};
19
/// Marker type selecting the unit-level, double-buffered layered matmul.
///
/// It carries no data. The concrete component stack is declared by its
/// [`Algorithm`] implementation: `RegisterMatmul<Filled>` tiles, a
/// `UnitMatmulFamily` stage matmul, a `DoubleBufferingMatmulFamily` global
/// matmul fed by `SyncPartialCyclicLoading` readers, and a
/// `PartitionedBatchMatmulFamily` batch matmul.
pub struct DoubleUnitAlgorithm {
}

23#[derive(Default, Clone, Debug)]
24pub struct DoubleUnitSelectionArgs {
25 pub tile_size: TileSizeSelection,
26}
27
28impl Algorithm for DoubleUnitAlgorithm {
29 type SelectionArgs = DoubleUnitSelectionArgs;
30 type TileMatmul = RegisterMatmul<Filled>;
31 type StageMatmul = UnitMatmulFamily<Self::TileMatmul, StridedStageFamily, FilledStageFamily>;
32 type GlobalMatmul = DoubleBufferingMatmulFamily<
33 Self::StageMatmul,
34 SyncPartialCyclicLoading<RowMajorTilingOrder>,
35 SyncPartialCyclicLoading<RowMajorTilingOrder>,
36 UnitWriterFamily,
37 >;
38 type BatchMatmul =
39 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
40
41 fn selection<R: Runtime>(
42 client: &ComputeClient<R::Server>,
43 problem: &MatmulProblem,
44 plane_dim: u32,
45 line_sizes: &MatmulLineSizes,
46 _elems: MatmulElems,
47 args: &Self::SelectionArgs,
48 ) -> Result<MatmulSelection, MatmulSetupError> {
49 Ok(unit_matmul_selection::<R>(
50 client,
51 problem,
52 plane_dim,
53 true,
54 line_sizes,
55 UnitMatmulSelectionOptions {
56 tile: args.tile_size,
57 ..Default::default()
58 },
59 ))
60 }
61
62 fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server>) -> u32 {
63 client.properties().hardware.plane_size_min
64 }
65}