// cubecl_matmul/kernels/layered/algorithm/ordered_double_buffering.rs

1use std::marker::PhantomData;
2
3use cubecl_core::Runtime;
4use cubecl_core::client::ComputeClient;
5use cubecl_core::ir::Elem;
6
7use crate::components::batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul};
8use crate::components::global::load::sync_partial_cyclic::SyncPartialCyclicLoading;
9use crate::components::global::multi_stage::ordered::OrderedDoubleBufferingMatmulFamily;
10use crate::components::stage::{
11    FullReaderFamily, PartialReaderFamily, PlaneMatmulFamily, RowMajorTilingOrder,
12};
13use crate::components::{MatmulProblem, MatmulSelection};
14use crate::components::{MultiRowStrategy, tile};
15use crate::kernels::layered::Algorithm;
16use crate::kernels::layered::selector::{PlaneMatmulSelectionOptions, plane_matmul_selection};
17
/// Plane-accelerated, double-buffered matmul algorithm that is ordered on the
/// Lhs side and uses a cyclic loader on the Rhs side.
///
/// The type holds no runtime data: `TMM` only selects which tile matmul
/// family the [`Algorithm`] implementation is built from.
pub struct OrderedDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker tying the `TMM` type parameter to this struct.
    pub _phantom: PhantomData<TMM>,
}
22
/// Optional overrides for the kernel selection performed by
/// [`OrderedDoubleBufferingAlgorithm`].
///
/// `Default` leaves every field `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OrderedSelectionArgs {
    /// Forwarded unchanged as `partition_k` to the plane matmul selection options.
    pub partition_k: Option<u32>,
    /// Forwarded unchanged as `row_count` to the plane matmul selection options.
    pub row_count: Option<u32>,
    /// `Some(n)` requests `MultiRowStrategy::Always(n)`; `None` makes the algorithm
    /// use `MultiRowStrategy::Adaptive` with `minimum_stage_count: 8`.
    pub rows_per_plane: Option<u32>,
}
29
30impl<TMM> Algorithm for OrderedDoubleBufferingAlgorithm<TMM>
31where
32    TMM: tile::TileMatmulFamily,
33{
34    type SelectionArgs = OrderedSelectionArgs;
35    type TileMatmul = TMM;
36    type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily, PartialReaderFamily>;
37    type GlobalMatmul = OrderedDoubleBufferingMatmulFamily<
38        Self::StageMatmul,
39        SyncPartialCyclicLoading<RowMajorTilingOrder>,
40    >;
41    type BatchMatmul =
42        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
43
44    fn selection<R: Runtime>(
45        client: &ComputeClient<R::Server, R::Channel>,
46        problem: &MatmulProblem,
47        plane_dim: u32,
48        elem_stage: Elem,
49        elem_acc: Elem,
50        args: &Self::SelectionArgs,
51    ) -> MatmulSelection {
52        plane_matmul_selection::<TMM, R>(
53            client,
54            problem,
55            plane_dim,
56            elem_stage,
57            elem_acc,
58            PlaneMatmulSelectionOptions {
59                partition_k: args.partition_k,
60                row_count: args.row_count,
61                multi_row_strategy: args
62                    .rows_per_plane
63                    .map(MultiRowStrategy::Always)
64                    .unwrap_or_else(|| MultiRowStrategy::Adaptive {
65                        minimum_stage_count: 8,
66                    }),
67                ..Default::default()
68            },
69        )
70    }
71}