// cubecl_matmul/kernels/layered/algorithm/double_unit.rs
use cubecl_core::{Runtime, client::ComputeClient, ir::Elem};
2
3use crate::{
4 components::{
5 MatmulProblem, MatmulSelection,
6 batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
7 global::{
8 load::sync_partial_cyclic::SyncPartialCyclicLoading,
9 multi_stage::double_buffering::DoubleBufferingMatmulFamily,
10 },
11 stage::{PartialReaderFamily, RowMajorTilingOrder, UnitMatmulFamily},
12 tile::register::RegisterMatmul,
13 },
14 kernels::layered::{
15 Algorithm,
16 selector::{TileSizeSelection, UnitMatmulSelectionOptions, unit_matmul_selection},
17 },
18};
19
/// Field-less marker type that names the "double unit" matmul algorithm.
///
/// It carries no data; all of its meaning comes from its `Algorithm`
/// implementation, which selects the tile, stage, global and batch matmul
/// component types used for this configuration.
pub struct DoubleUnitAlgorithm {}
22
23#[derive(Default, Clone, Debug)]
24pub struct DoubleUnitSelectionArgs {
25 pub tile_size: TileSizeSelection,
26}
27
28impl Algorithm for DoubleUnitAlgorithm {
29 type SelectionArgs = DoubleUnitSelectionArgs;
30 type TileMatmul = RegisterMatmul;
31 type StageMatmul = UnitMatmulFamily<Self::TileMatmul, PartialReaderFamily>;
32 type GlobalMatmul = DoubleBufferingMatmulFamily<
33 Self::StageMatmul,
34 SyncPartialCyclicLoading<RowMajorTilingOrder>,
35 SyncPartialCyclicLoading<RowMajorTilingOrder>,
36 >;
37 type BatchMatmul =
38 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
39
40 fn selection<R: Runtime>(
41 client: &ComputeClient<R::Server, R::Channel>,
42 problem: &MatmulProblem,
43 plane_dim: u32,
44 _elem_stage: Elem,
45 _elem_acc: Elem,
46 args: &Self::SelectionArgs,
47 ) -> MatmulSelection {
48 unit_matmul_selection::<R>(
49 client,
50 problem,
51 plane_dim,
52 true,
53 UnitMatmulSelectionOptions {
54 tile: args.tile_size,
55 ..Default::default()
56 },
57 )
58 }
59
60 fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> u32 {
61 client.properties().hardware.plane_size_min
62 }
63}