cubecl_matmul/kernels/layered/algorithm/
double_unit.rs

1use cubecl_core::{Runtime, client::ComputeClient, ir::Elem};
2
3use crate::{
4    components::{
5        MatmulProblem, MatmulSelection,
6        batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
7        global::{
8            load::sync_partial_cyclic::SyncPartialCyclicLoading,
9            multi_stage::double_buffering::DoubleBufferingMatmulFamily,
10        },
11        stage::{PartialReaderFamily, RowMajorTilingOrder, UnitMatmulFamily},
12        tile::register::RegisterMatmul,
13    },
14    kernels::layered::{
15        Algorithm,
16        selector::{TileSizeSelection, UnitMatmulSelectionOptions, unit_matmul_selection},
17    },
18};
19
/// Unit-level matmul algorithm that uses double buffering at the global level.
///
/// Both loaders of the global matmul are `SyncPartialCyclicLoading` with a
/// `RowMajorTilingOrder`, and tiles are computed with `RegisterMatmul`
/// (see the `Algorithm` implementation for the full component stack).
pub struct DoubleUnitAlgorithm {}
22
23#[derive(Default, Clone, Debug)]
24pub struct DoubleUnitSelectionArgs {
25    pub tile_size: TileSizeSelection,
26}
27
28impl Algorithm for DoubleUnitAlgorithm {
29    type SelectionArgs = DoubleUnitSelectionArgs;
30    type TileMatmul = RegisterMatmul;
31    type StageMatmul = UnitMatmulFamily<Self::TileMatmul, PartialReaderFamily>;
32    type GlobalMatmul = DoubleBufferingMatmulFamily<
33        Self::StageMatmul,
34        SyncPartialCyclicLoading<RowMajorTilingOrder>,
35        SyncPartialCyclicLoading<RowMajorTilingOrder>,
36    >;
37    type BatchMatmul =
38        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
39
40    fn selection<R: Runtime>(
41        client: &ComputeClient<R::Server, R::Channel>,
42        problem: &MatmulProblem,
43        plane_dim: u32,
44        _elem_stage: Elem,
45        _elem_acc: Elem,
46        args: &Self::SelectionArgs,
47    ) -> MatmulSelection {
48        unit_matmul_selection::<R>(
49            client,
50            problem,
51            plane_dim,
52            true,
53            UnitMatmulSelectionOptions {
54                tile: args.tile_size,
55                ..Default::default()
56            },
57        )
58    }
59
60    fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> u32 {
61        client.properties().hardware.plane_size_min
62    }
63}