cubecl_matmul/kernels/layered/algorithm/
vecmat.rs1use cubecl_core::{Runtime, client::ComputeClient};
2
3use crate::{
4 components::{
5 MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
6 PartitionSize, TileSize, TilingScheme,
7 batch::{
8 CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection,
9 PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul, SmAllocation,
10 },
11 global::{
12 PlaneWriterFamily,
13 multi_stage::double_buffering::DoubleBufferingMatmulFamily,
14 read::{
15 sync_full_cyclic::SyncFullCyclicLoading,
16 sync_partial_cyclic::SyncPartialCyclicLoading,
17 },
18 single_stage::simple::SimpleMatmulFamily,
19 },
20 stage::{
21 ColMajorTilingOrder, FilledStageFamily, PartitionBuffering, PlaneMatmulFamily,
22 RowMajorTilingOrder, StridedStageFamily,
23 },
24 tile::{io::Filled, plane_vec_mat_inner_product::PlaneVecMatInnerProduct},
25 },
26 kernels::layered::Algorithm,
27};
28
29pub struct SimpleVecMatAlgorithm {}
30
31impl Algorithm for SimpleVecMatAlgorithm {
32 type SelectionArgs = ();
33 type TileMatmul = PlaneVecMatInnerProduct<Filled>;
34 type StageMatmul = PlaneMatmulFamily<
35 Self::TileMatmul,
36 StridedStageFamily,
37 StridedStageFamily,
38 FilledStageFamily,
39 >;
40 type GlobalMatmul = SimpleMatmulFamily<
41 Self::StageMatmul,
42 SyncFullCyclicLoading<RowMajorTilingOrder>,
43 SyncFullCyclicLoading<ColMajorTilingOrder>,
44 PlaneWriterFamily,
45 >;
46 type BatchMatmul =
47 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
48
49 fn selection<R: Runtime>(
50 client: &ComputeClient<R::Server>,
51 problem: &MatmulProblem,
52 plane_dim: u32,
53 line_sizes: &MatmulLineSizes,
54 _elems: MatmulElems,
55 _args: &Self::SelectionArgs,
56 ) -> Result<MatmulSelection, MatmulSetupError> {
57 Ok(selection_vecmat::<R>(
58 client,
59 problem,
60 (1, line_sizes.out as u32, plane_dim * line_sizes.lhs as u32).into(),
61 plane_dim,
62 ))
63 }
64}
65
66pub struct DoubleVecMatAlgorithm {}
67
68impl Algorithm for DoubleVecMatAlgorithm {
69 type SelectionArgs = ();
70 type TileMatmul = PlaneVecMatInnerProduct<Filled>;
71 type StageMatmul = PlaneMatmulFamily<
72 Self::TileMatmul,
73 StridedStageFamily,
74 StridedStageFamily,
75 FilledStageFamily,
76 >;
77 type GlobalMatmul = DoubleBufferingMatmulFamily<
78 Self::StageMatmul,
79 SyncPartialCyclicLoading<RowMajorTilingOrder>,
80 SyncPartialCyclicLoading<ColMajorTilingOrder>,
81 PlaneWriterFamily,
82 >;
83 type BatchMatmul =
84 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
85
86 fn selection<R: Runtime>(
87 client: &ComputeClient<R::Server>,
88 problem: &MatmulProblem,
89 plane_dim: u32,
90 line_sizes: &MatmulLineSizes,
91 _elems: MatmulElems,
92 _args: &Self::SelectionArgs,
93 ) -> Result<MatmulSelection, MatmulSetupError> {
94 Ok(selection_vecmat::<R>(
95 client,
96 problem,
97 (1, line_sizes.out as u32, plane_dim * line_sizes.lhs as u32).into(),
98 plane_dim,
99 ))
100 }
101}
102
103fn selection_vecmat<R: Runtime>(
104 client: &ComputeClient<R::Server>,
105 problem: &MatmulProblem,
106 tile_size: TileSize,
107 plane_dim: u32,
108) -> MatmulSelection {
109 let tiling_scheme = TilingScheme::builder()
110 .with_tile_size(tile_size)
111 .with_partition_size(PartitionSize::new(1, 1, 1))
112 .with_stage_size((1, 1, 1).into())
113 .build()
114 .unwrap();
115 let cube_count_plan = match client.properties().hardware.num_streaming_multiprocessors {
116 Some(num_sms) => CubeCountPlanSelection::Sm {
117 num_sms,
118 sm_usage: SmAllocation::Exact,
119 cubes_first: true,
120 },
121 None => CubeCountPlanSelection::FromProblem,
122 };
123
124 let hypercube = HypercubeSelection::builder(&tiling_scheme)
125 .global_order(GlobalOrderSelection::SwizzleRow {
126 m: problem.m as u32,
127 w: 2,
128 })
129 .cube_count_plan(cube_count_plan)
130 .build();
131
132 MatmulSelection::builder(tiling_scheme, plane_dim)
133 .partition_buffering(PartitionBuffering::Single)
134 .hypercube_config(hypercube)
135 .build()
136}