cubecl_matmul/kernels/layered/algorithm/ordered_double_buffering.rs

1use std::marker::PhantomData;
2
3use cubecl_core::Runtime;
4use cubecl_core::client::ComputeClient;
5
6use crate::components::stage::{PlaneMatmulFamily, RowMajorTilingOrder};
7use crate::components::{
8    MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
9    global::PlaneWriterFamily,
10};
11use crate::components::{MultiRowStrategy, tile};
12use crate::components::{
13    batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
14    stage::{FilledStageFamily, StridedStageFamily},
15};
16use crate::components::{
17    global::multi_stage::ordered::OrderedDoubleBufferingMatmulFamily, tile::io::Filled,
18};
19use crate::components::{
20    global::read::sync_partial_cyclic::SyncPartialCyclicLoading, tile::io::Strided,
21};
22use crate::kernels::layered::Algorithm;
23use crate::kernels::layered::selector::{PlaneMatmulSelectionOptions, plane_matmul_selection};
24
/// Layered matmul algorithm built on plane-accelerated tile matmuls and ordered
/// double buffering: loading is ordered on the Lhs side, with a cyclic reader on Rhs.
///
/// This type carries no data. It exists only to select the tile matmul family `TMM`
/// at the type level.
pub struct OrderedDoubleBufferingAlgorithm<TMM> {
    /// Type-level marker for `TMM`; occupies no space.
    pub _phantom: PhantomData<TMM>,
}
29
/// User-tunable knobs for choosing the matmul selection of the ordered
/// double-buffering algorithm.
///
/// Every field is optional; `Default` leaves all of them unset.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OrderedSelectionArgs {
    /// Forwarded unchanged as the `partition_k` selection option.
    pub partition_k: Option<u32>,
    /// Forwarded unchanged as the `row_count` selection option.
    pub row_count: Option<u32>,
    /// `Some(n)` forces `MultiRowStrategy::Always(n)`. `None` falls back to
    /// `MultiRowStrategy::Adaptive` with a minimum stage count of 8.
    pub rows_per_plane: Option<u32>,
}
36
37impl<TMM> Algorithm for OrderedDoubleBufferingAlgorithm<TMM>
38where
39    TMM: tile::TileMatmulFamily<
40            LhsTile = Strided,
41            RhsTile = Strided,
42            AccTile = Filled,
43            OutTile = Strided,
44        >,
45{
46    type SelectionArgs = OrderedSelectionArgs;
47    type TileMatmul = TMM;
48    type StageMatmul = PlaneMatmulFamily<
49        Self::TileMatmul,
50        StridedStageFamily,
51        StridedStageFamily,
52        FilledStageFamily,
53    >;
54    type GlobalMatmul = OrderedDoubleBufferingMatmulFamily<
55        Self::StageMatmul,
56        SyncPartialCyclicLoading<RowMajorTilingOrder>,
57        PlaneWriterFamily,
58    >;
59    type BatchMatmul =
60        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
61
62    fn selection<R: Runtime>(
63        client: &ComputeClient<R::Server>,
64        problem: &MatmulProblem,
65        plane_dim: u32,
66        _line_sizes: &MatmulLineSizes,
67        elems: MatmulElems,
68        args: &Self::SelectionArgs,
69    ) -> Result<MatmulSelection, MatmulSetupError> {
70        plane_matmul_selection::<TMM, R>(
71            client,
72            problem,
73            plane_dim,
74            elems,
75            PlaneMatmulSelectionOptions {
76                partition_k: args.partition_k,
77                row_count: args.row_count,
78                multi_row_strategy: args
79                    .rows_per_plane
80                    .map(MultiRowStrategy::Always)
81                    .unwrap_or_else(|| MultiRowStrategy::Adaptive {
82                        minimum_stage_count: 8,
83                    }),
84                ..Default::default()
85            },
86        )
87    }
88}