cubecl_matmul/kernels/layered/algorithm/
double_buffering.rs

1use std::marker::PhantomData;
2
3use cubecl_core::Runtime;
4use cubecl_core::client::ComputeClient;
5
6use crate::components::global::read::sync_partial_cyclic::SyncPartialCyclicLoading;
7use crate::components::global::{
8    PlaneWriterFamily, read::sync_partial_tilewise::SyncPartialTilewiseLoading,
9};
10use crate::components::stage::{ColMajorTilingOrder, PlaneMatmulFamily, RowMajorTilingOrder};
11use crate::components::{MatmulElems, MatmulLineSizes, MatmulSelection, MatmulSetupError};
12use crate::components::{MatmulProblem, MultiRowStrategy, tile};
13use crate::components::{
14    batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul},
15    tile::io::{Filled, Strided},
16};
17use crate::components::{
18    global::multi_stage::double_buffering::DoubleBufferingMatmulFamily,
19    stage::{FilledStageFamily, StridedStageFamily},
20};
21use crate::kernels::layered::Algorithm;
22use crate::kernels::layered::algorithm::base;
23use crate::kernels::layered::selector::{PlaneMatmulSelectionOptions, plane_matmul_selection};
24
/// Double-buffered, plane-accelerated matmul in which both the Lhs and the Rhs stages
/// are filled by cyclic readers.
///
/// `TMM` selects the tile-level matmul family; the type itself holds no runtime data.
pub struct CyclicDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker binding the algorithm to its tile matmul family `TMM`.
    pub _phantom: PhantomData<TMM>,
}
29
/// Double-buffered, plane-accelerated matmul in which both the Lhs and the Rhs stages
/// are filled by tilewise readers.
///
/// `TMM` selects the tile-level matmul family; the type itself holds no runtime data.
pub struct TilewiseDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker binding the algorithm to its tile matmul family `TMM`.
    pub _phantom: PhantomData<TMM>,
}
34
/// Double-buffered, plane-accelerated matmul that mixes reader kinds: the Lhs stage is
/// filled by a tilewise reader and the Rhs stage by a cyclic reader.
///
/// `TMM` selects the tile-level matmul family; the type itself holds no runtime data.
pub struct HybridDoubleBufferingAlgorithm<TMM> {
    /// Zero-sized marker binding the algorithm to its tile matmul family `TMM`.
    pub _phantom: PhantomData<TMM>,
}
39
/// Selection arguments accepted by the double-buffering matmul algorithms.
///
/// `Default` yields `specialized: false`. The type is plain data, so it also derives
/// `PartialEq`, `Eq` and `Hash` and can be compared or used as a lookup key.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DoubleBufferingArgs {
    /// Forwarded unchanged as the `specialized` option of `PlaneMatmulSelectionOptions`
    /// during selection. Its precise effect is decided by `plane_matmul_selection`.
    /// NOTE(review): confirm the semantics there before documenting further.
    pub specialized: bool,
}
44
45impl<TMM> base::Algorithm for CyclicDoubleBufferingAlgorithm<TMM>
46where
47    TMM: tile::TileMatmulFamily<
48            LhsTile = Strided,
49            RhsTile = Strided,
50            AccTile = Filled,
51            OutTile = Strided,
52        >,
53{
54    type SelectionArgs = DoubleBufferingArgs;
55    type TileMatmul = TMM;
56    type StageMatmul = PlaneMatmulFamily<
57        Self::TileMatmul,
58        StridedStageFamily,
59        StridedStageFamily,
60        FilledStageFamily,
61    >;
62    type GlobalMatmul = DoubleBufferingMatmulFamily<
63        Self::StageMatmul,
64        SyncPartialCyclicLoading<RowMajorTilingOrder>,
65        SyncPartialCyclicLoading<RowMajorTilingOrder>,
66        PlaneWriterFamily,
67    >;
68    type BatchMatmul =
69        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
70
71    fn selection<R: Runtime>(
72        client: &ComputeClient<R::Server>,
73        problem: &MatmulProblem,
74        plane_dim: u32,
75        _line_sizes: &MatmulLineSizes,
76        elems: MatmulElems,
77        args: &Self::SelectionArgs,
78    ) -> Result<MatmulSelection, MatmulSetupError> {
79        plane_matmul_selection::<TMM, R>(
80            client,
81            problem,
82            plane_dim,
83            elems,
84            PlaneMatmulSelectionOptions {
85                specialized: args.specialized,
86                multi_row_strategy: MultiRowStrategy::Adaptive {
87                    minimum_stage_count: 8,
88                },
89                ..Default::default()
90            },
91        )
92    }
93}
94
95impl<TMM> Algorithm for TilewiseDoubleBufferingAlgorithm<TMM>
96where
97    TMM: tile::TileMatmulFamily<
98            LhsTile = Strided,
99            RhsTile = Strided,
100            AccTile = Filled,
101            OutTile = Strided,
102        >,
103{
104    type SelectionArgs = DoubleBufferingArgs;
105    type TileMatmul = TMM;
106    type StageMatmul = PlaneMatmulFamily<
107        Self::TileMatmul,
108        StridedStageFamily,
109        StridedStageFamily,
110        FilledStageFamily,
111    >;
112    type GlobalMatmul = DoubleBufferingMatmulFamily<
113        Self::StageMatmul,
114        // Other tiling orders are not supported
115        SyncPartialTilewiseLoading<RowMajorTilingOrder>,
116        SyncPartialTilewiseLoading<ColMajorTilingOrder>,
117        PlaneWriterFamily,
118    >;
119    type BatchMatmul =
120        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
121
122    fn selection<R: Runtime>(
123        client: &ComputeClient<R::Server>,
124        problem: &MatmulProblem,
125        plane_dim: u32,
126        _line_sizes: &MatmulLineSizes,
127        elems: MatmulElems,
128        args: &Self::SelectionArgs,
129    ) -> Result<MatmulSelection, MatmulSetupError> {
130        plane_matmul_selection::<TMM, R>(
131            client,
132            problem,
133            plane_dim,
134            elems,
135            PlaneMatmulSelectionOptions {
136                specialized: args.specialized,
137                multi_row_strategy: MultiRowStrategy::Adaptive {
138                    minimum_stage_count: 8,
139                },
140                ..Default::default()
141            },
142        )
143    }
144}
145
146impl<TMM> base::Algorithm for HybridDoubleBufferingAlgorithm<TMM>
147where
148    TMM: tile::TileMatmulFamily<
149            LhsTile = Strided,
150            RhsTile = Strided,
151            AccTile = Filled,
152            OutTile = Strided,
153        >,
154{
155    type SelectionArgs = DoubleBufferingArgs;
156    type TileMatmul = TMM;
157    type StageMatmul = PlaneMatmulFamily<
158        Self::TileMatmul,
159        StridedStageFamily,
160        StridedStageFamily,
161        FilledStageFamily,
162    >;
163    type GlobalMatmul = DoubleBufferingMatmulFamily<
164        Self::StageMatmul,
165        SyncPartialTilewiseLoading<RowMajorTilingOrder>,
166        SyncPartialCyclicLoading<RowMajorTilingOrder>,
167        PlaneWriterFamily,
168    >;
169    type BatchMatmul =
170        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
171
172    fn selection<R: Runtime>(
173        client: &ComputeClient<R::Server>,
174        problem: &MatmulProblem,
175        plane_dim: u32,
176        _line_sizes: &MatmulLineSizes,
177        elems: MatmulElems,
178        args: &Self::SelectionArgs,
179    ) -> Result<MatmulSelection, MatmulSetupError> {
180        plane_matmul_selection::<TMM, R>(
181            client,
182            problem,
183            plane_dim,
184            elems,
185            PlaneMatmulSelectionOptions {
186                specialized: args.specialized,
187                multi_row_strategy: MultiRowStrategy::Adaptive {
188                    minimum_stage_count: 8,
189                },
190                ..Default::default()
191            },
192        )
193    }
194}