cubecl_matmul/kernels/layered/algorithm/double_unit.rs

1use cubecl_core::{Runtime, client::ComputeClient};
2
3use crate::{
4    components::{
5        MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
6        batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
7        global::{
8            UnitWriterFamily, multi_stage::double_buffering::DoubleBufferingMatmulFamily,
9            read::sync_partial_cyclic::SyncPartialCyclicLoading,
10        },
11        stage::{FilledStageFamily, RowMajorTilingOrder, StridedStageFamily, UnitMatmulFamily},
12        tile::{io::Filled, register::RegisterMatmul},
13    },
14    kernels::layered::{
15        Algorithm,
16        selector::{TileSizeSelection, UnitMatmulSelectionOptions, unit_matmul_selection},
17    },
18};
19
/// Layered matmul algorithm that pairs a unit-level stage matmul with
/// double-buffered global loading, using cyclic readers for the input tiles.
pub struct DoubleUnitAlgorithm {}
22
23#[derive(Default, Clone, Debug)]
24pub struct DoubleUnitSelectionArgs {
25    pub tile_size: TileSizeSelection,
26}
27
28impl Algorithm for DoubleUnitAlgorithm {
29    type SelectionArgs = DoubleUnitSelectionArgs;
30    type TileMatmul = RegisterMatmul<Filled>;
31    type StageMatmul = UnitMatmulFamily<Self::TileMatmul, StridedStageFamily, FilledStageFamily>;
32    type GlobalMatmul = DoubleBufferingMatmulFamily<
33        Self::StageMatmul,
34        SyncPartialCyclicLoading<RowMajorTilingOrder>,
35        SyncPartialCyclicLoading<RowMajorTilingOrder>,
36        UnitWriterFamily,
37    >;
38    type BatchMatmul =
39        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
40
41    fn selection<R: Runtime>(
42        client: &ComputeClient<R::Server>,
43        problem: &MatmulProblem,
44        plane_dim: u32,
45        line_sizes: &MatmulLineSizes,
46        _elems: MatmulElems,
47        args: &Self::SelectionArgs,
48    ) -> Result<MatmulSelection, MatmulSetupError> {
49        Ok(unit_matmul_selection::<R>(
50            client,
51            problem,
52            plane_dim,
53            true,
54            line_sizes,
55            UnitMatmulSelectionOptions {
56                tile: args.tile_size,
57                ..Default::default()
58            },
59        ))
60    }
61
62    fn select_plane_dim<R: Runtime>(client: &ComputeClient<R::Server>) -> u32 {
63        client.properties().hardware.plane_size_min
64    }
65}