// cubecl_matmul/kernels/layered/algorithm/double_buffering.rs

1use std::marker::PhantomData;
2
3use cubecl_core::Runtime;
4use cubecl_core::client::ComputeClient;
5use cubecl_core::ir::Elem;
6
7use crate::components::MatmulSelection;
8use crate::components::batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul};
9use crate::components::global::load::sync_partial_cyclic::SyncPartialCyclicLoading;
10use crate::components::global::load::sync_partial_tilewise::SyncPartialTilewiseLoading;
11use crate::components::global::multi_stage::double_buffering::DoubleBufferingMatmulFamily;
12use crate::components::stage::{
13    ColMajorTilingOrder, PartialReaderFamily, PlaneMatmulFamily, RowMajorTilingOrder,
14};
15use crate::components::{MatmulProblem, MultiRowStrategy, tile};
16use crate::kernels::layered::Algorithm;
17use crate::kernels::layered::algorithm::base;
18use crate::kernels::layered::selector::{PlaneMatmulSelectionOptions, plane_matmul_selection};
19
/// Double buffered matmul algorithm, accelerated at the plane level, in which both the
/// Lhs and the Rhs operands are brought in by cyclic loaders.
///
/// `TMM` is the tile matmul family used by this algorithm's `Algorithm` implementation.
pub struct CyclicDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker tying the tile matmul family `TMM` to this type.
    pub _phantom: PhantomData<TMM>,
}

/// Double buffered matmul algorithm, accelerated at the plane level, in which both the
/// Lhs and the Rhs operands are brought in by tilewise loaders.
///
/// `TMM` is the tile matmul family used by this algorithm's `Algorithm` implementation.
pub struct TilewiseDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker tying the tile matmul family `TMM` to this type.
    pub _phantom: PhantomData<TMM>,
}

/// Double buffered matmul algorithm, accelerated at the plane level, that mixes loaders:
/// a tilewise loader for the Lhs operand and a cyclic loader for the Rhs operand.
///
/// `TMM` is the tile matmul family used by this algorithm's `Algorithm` implementation.
pub struct HybridDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker tying the tile matmul family `TMM` to this type.
    pub _phantom: PhantomData<TMM>,
}

/// Selection arguments accepted by every double buffering algorithm in this module.
#[derive(Debug, Default, Clone, Copy)]
pub struct DoubleBufferingArgs {
    /// Passed through unchanged as the `specialized` field of the
    /// `PlaneMatmulSelectionOptions` used when computing a selection.
    pub specialized: bool,
}

40impl<TMM> base::Algorithm for CyclicDoubleBufferingAlgorithm<TMM>
41where
42    TMM: tile::TileMatmulFamily,
43{
44    type SelectionArgs = DoubleBufferingArgs;
45    type TileMatmul = TMM;
46    type StageMatmul =
47        PlaneMatmulFamily<Self::TileMatmul, PartialReaderFamily, PartialReaderFamily>;
48    type GlobalMatmul = DoubleBufferingMatmulFamily<
49        Self::StageMatmul,
50        SyncPartialCyclicLoading<RowMajorTilingOrder>,
51        SyncPartialCyclicLoading<RowMajorTilingOrder>,
52    >;
53    type BatchMatmul =
54        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
55
56    fn selection<R: Runtime>(
57        client: &ComputeClient<R::Server, R::Channel>,
58        problem: &MatmulProblem,
59        plane_dim: u32,
60        elem_stage: Elem,
61        elem_acc: Elem,
62        args: &Self::SelectionArgs,
63    ) -> MatmulSelection {
64        plane_matmul_selection::<TMM, R>(
65            client,
66            problem,
67            plane_dim,
68            elem_stage,
69            elem_acc,
70            PlaneMatmulSelectionOptions {
71                specialized: args.specialized,
72                multi_row_strategy: MultiRowStrategy::Adaptive {
73                    minimum_stage_count: 8,
74                },
75                ..Default::default()
76            },
77        )
78    }
79}
80
81impl<TMM> Algorithm for TilewiseDoubleBufferingAlgorithm<TMM>
82where
83    TMM: tile::TileMatmulFamily,
84{
85    type SelectionArgs = DoubleBufferingArgs;
86    type TileMatmul = TMM;
87    type StageMatmul =
88        PlaneMatmulFamily<Self::TileMatmul, PartialReaderFamily, PartialReaderFamily>;
89    type GlobalMatmul = DoubleBufferingMatmulFamily<
90        Self::StageMatmul,
91        // Other tiling orders are not supported
92        SyncPartialTilewiseLoading<RowMajorTilingOrder>,
93        SyncPartialTilewiseLoading<ColMajorTilingOrder>,
94    >;
95    type BatchMatmul =
96        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
97
98    fn selection<R: Runtime>(
99        client: &ComputeClient<R::Server, R::Channel>,
100        problem: &MatmulProblem,
101        plane_dim: u32,
102        elem_stage: Elem,
103        elem_acc: Elem,
104        args: &Self::SelectionArgs,
105    ) -> MatmulSelection {
106        plane_matmul_selection::<TMM, R>(
107            client,
108            problem,
109            plane_dim,
110            elem_stage,
111            elem_acc,
112            PlaneMatmulSelectionOptions {
113                specialized: args.specialized,
114                multi_row_strategy: MultiRowStrategy::Adaptive {
115                    minimum_stage_count: 8,
116                },
117                ..Default::default()
118            },
119        )
120    }
121}
122
123impl<TMM> base::Algorithm for HybridDoubleBufferingAlgorithm<TMM>
124where
125    TMM: tile::TileMatmulFamily,
126{
127    type SelectionArgs = DoubleBufferingArgs;
128    type TileMatmul = TMM;
129    type StageMatmul =
130        PlaneMatmulFamily<Self::TileMatmul, PartialReaderFamily, PartialReaderFamily>;
131    type GlobalMatmul = DoubleBufferingMatmulFamily<
132        Self::StageMatmul,
133        SyncPartialTilewiseLoading<RowMajorTilingOrder>,
134        SyncPartialCyclicLoading<RowMajorTilingOrder>,
135    >;
136    type BatchMatmul =
137        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
138
139    fn selection<R: Runtime>(
140        client: &ComputeClient<R::Server, R::Channel>,
141        problem: &MatmulProblem,
142        plane_dim: u32,
143        elem_stage: Elem,
144        elem_acc: Elem,
145        args: &Self::SelectionArgs,
146    ) -> MatmulSelection {
147        plane_matmul_selection::<TMM, R>(
148            client,
149            problem,
150            plane_dim,
151            elem_stage,
152            elem_acc,
153            PlaneMatmulSelectionOptions {
154                specialized: args.specialized,
155                multi_row_strategy: MultiRowStrategy::Adaptive {
156                    minimum_stage_count: 8,
157                },
158                ..Default::default()
159            },
160        )
161    }
162}