cubecl_matmul/kernels/layered/algorithm/
ordered_double_buffering.rs1use std::marker::PhantomData;
2
3use cubecl_core::Runtime;
4use cubecl_core::client::ComputeClient;
5use cubecl_core::ir::Elem;
6
7use crate::components::batch::{PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul};
8use crate::components::global::load::sync_partial_cyclic::SyncPartialCyclicLoading;
9use crate::components::global::multi_stage::ordered::OrderedDoubleBufferingMatmulFamily;
10use crate::components::stage::{
11 FullReaderFamily, PartialReaderFamily, PlaneMatmulFamily, RowMajorTilingOrder,
12};
13use crate::components::{MatmulProblem, MatmulSelection};
14use crate::components::{MultiRowStrategy, tile};
15use crate::kernels::layered::Algorithm;
16use crate::kernels::layered::selector::{PlaneMatmulSelectionOptions, plane_matmul_selection};
17
/// Zero-sized marker type naming the ordered double-buffering matmul algorithm.
///
/// The type stores no data: `TMM` is only carried at the type level through
/// [`PhantomData`]. The `Algorithm` implementation in this file uses `TMM` as
/// its tile matmul family.
pub struct OrderedDoubleBufferingAlgorithm<TMM> {
    /// Ties the otherwise unused type parameter `TMM` to the struct.
    pub _phantom: PhantomData<TMM>,
}
22
/// Optional overrides consumed by the `selection` method of
/// `OrderedDoubleBufferingAlgorithm`.
///
/// Every field defaults to `None` through the derived [`Default`].
#[derive(Debug, Clone, Default)]
pub struct OrderedSelectionArgs {
    /// Forwarded unchanged as the `partition_k` selection option.
    pub partition_k: Option<u32>,
    /// Forwarded unchanged as the `row_count` selection option.
    pub row_count: Option<u32>,
    /// When `Some(n)`, selection uses `MultiRowStrategy::Always(n)`; when
    /// `None`, it uses `MultiRowStrategy::Adaptive` with a
    /// `minimum_stage_count` of 8.
    pub rows_per_plane: Option<u32>,
}
29
30impl<TMM> Algorithm for OrderedDoubleBufferingAlgorithm<TMM>
31where
32 TMM: tile::TileMatmulFamily,
33{
34 type SelectionArgs = OrderedSelectionArgs;
35 type TileMatmul = TMM;
36 type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily, PartialReaderFamily>;
37 type GlobalMatmul = OrderedDoubleBufferingMatmulFamily<
38 Self::StageMatmul,
39 SyncPartialCyclicLoading<RowMajorTilingOrder>,
40 >;
41 type BatchMatmul =
42 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
43
44 fn selection<R: Runtime>(
45 client: &ComputeClient<R::Server, R::Channel>,
46 problem: &MatmulProblem,
47 plane_dim: u32,
48 elem_stage: Elem,
49 elem_acc: Elem,
50 args: &Self::SelectionArgs,
51 ) -> MatmulSelection {
52 plane_matmul_selection::<TMM, R>(
53 client,
54 problem,
55 plane_dim,
56 elem_stage,
57 elem_acc,
58 PlaneMatmulSelectionOptions {
59 partition_k: args.partition_k,
60 row_count: args.row_count,
61 multi_row_strategy: args
62 .rows_per_plane
63 .map(MultiRowStrategy::Always)
64 .unwrap_or_else(|| MultiRowStrategy::Adaptive {
65 minimum_stage_count: 8,
66 }),
67 ..Default::default()
68 },
69 )
70 }
71}