cubecl_matmul/components/batch/partitioned_matmul/partition/
matmul.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::components::{
5 AccG, LhsG, MatmulPrecision, RhsG,
6 batch::SliceIndex,
7 global::{self, GlobalConfig},
8};
9use cubecl_std::{
10 CubeOption, CubeOptionExpand,
11 tensor::{View, layout::Coords3d},
12};
13
14#[derive(CubeType)]
15pub struct PartitionRanges {
17 row: PartitionRangeDim,
18 col: PartitionRangeDim,
19 batch: PartitionRangeDim,
20}
21
22#[derive(CubeType)]
23pub struct PartitionRangeDim {
24 start: u32,
25 #[cube(comptime)]
26 step: u32,
27 #[cube(comptime)]
28 num_steps: u32,
29}
30
31#[cube]
32pub trait GlobalPartitionMatmul: 'static + Send + Sync {
34 fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
35 a: View<Line<LhsG<MP>>, Coords3d>,
36 b: View<Line<RhsG<MP>>, Coords3d>,
37 c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
38 out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
39 partition_ranges: PartitionRanges,
40 acc: GMM::Accumulators,
41 k_range: (u32, u32),
42 #[comptime] config: GMM::Config,
43 );
44}
45
46#[derive(CubeType)]
47pub struct RowMajorGlobalPartitionMatmul {}
49
50#[derive(CubeType)]
51pub struct ColMajorGlobalPartitionMatmul {}
53
54#[cube]
55impl PartitionRanges {
56 pub fn new(
58 row: PartitionRangeDim,
59 col: PartitionRangeDim,
60 batch: PartitionRangeDim,
61 ) -> PartitionRanges {
62 PartitionRanges { row, col, batch }
63 }
64}
65
66#[cube]
67impl PartitionRangeDim {
68 pub fn new(
70 cube_pos: u32,
71 #[comptime] stage_dim: u32,
72 #[comptime] global_partition_dim: u32,
73 ) -> PartitionRangeDim {
74 let start = cube_pos * global_partition_dim;
75 PartitionRangeDim {
76 start,
77 step: stage_dim,
78 num_steps: global_partition_dim.div_ceil(stage_dim),
79 }
80 }
81}
82
83#[cube]
84impl GlobalPartitionMatmul for RowMajorGlobalPartitionMatmul {
85 fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
86 a: View<Line<LhsG<MP>>, Coords3d>,
87 b: View<Line<RhsG<MP>>, Coords3d>,
88 c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
89 out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
90 ranges: PartitionRanges,
91 mut acc: GMM::Accumulators,
92 k_range: (u32, u32),
93 #[comptime] config: GMM::Config,
94 ) {
95 let num_steps_batch = comptime!(ranges.batch.num_steps);
97 let num_steps_row = comptime!(ranges.row.num_steps);
98 let num_steps_col = comptime!(ranges.col.num_steps);
99
100 #[unroll(num_steps_batch == 1)]
101 for batch in 0..num_steps_batch {
102 let batch_iter = ranges.batch.start + batch * ranges.batch.step;
103
104 #[unroll(num_steps_row == 1)]
105 for row in 0..num_steps_row {
106 let row_offset = ranges.row.start + row * ranges.row.step;
107
108 #[unroll(num_steps_col == 1)]
109 for col in 0..num_steps_col {
110 let col_offset = ranges.col.start + col * ranges.col.step;
111
112 execute_global_matmul::<MP, GMM>(
113 a, b, c, out, batch_iter, row_offset, col_offset, &mut acc, k_range, config,
114 );
115 }
116 }
117 }
118 }
119}
120
121#[cube]
122impl GlobalPartitionMatmul for ColMajorGlobalPartitionMatmul {
123 fn execute<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
124 a: View<Line<LhsG<MP>>, Coords3d>,
125 b: View<Line<RhsG<MP>>, Coords3d>,
126 c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
127 out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
128 ranges: PartitionRanges,
129 mut acc: GMM::Accumulators,
130 k_range: (u32, u32),
131 #[comptime] config: GMM::Config,
132 ) {
133 let num_steps_batch = comptime!(ranges.batch.num_steps);
135 let num_steps_row = comptime!(ranges.row.num_steps);
136 let num_steps_col = comptime!(ranges.col.num_steps);
137
138 #[unroll(num_steps_batch == 1)]
139 for batch in 0..num_steps_batch {
140 let batch_iter = ranges.batch.start + batch * ranges.batch.step;
141
142 #[unroll(num_steps_col == 1)]
143 for col in 0..num_steps_col {
144 let col_offset = ranges.col.start + col * ranges.col.step;
145
146 #[unroll(num_steps_row == 1)]
147 for row in 0..num_steps_row {
148 let row_offset = ranges.row.start + row * ranges.row.step;
149
150 execute_global_matmul::<MP, GMM>(
151 a, b, c, out, batch_iter, row_offset, col_offset, &mut acc, k_range, config,
152 );
153 }
154 }
155 }
156 }
157}
158
159#[cube]
160pub(crate) fn execute_global_matmul<MP: MatmulPrecision, GMM: global::GlobalMatmul<MP>>(
163 a: View<Line<LhsG<MP>>, Coords3d>,
164 b: View<Line<RhsG<MP>>, Coords3d>,
165 c: CubeOption<View<Line<AccG<MP>>, Coords3d>>,
166 out: View<Line<AccG<MP>>, Coords3d, ReadWrite>,
167 nth_batch: u32,
168 m_offset: u32,
169 n_offset: u32,
170 acc: &mut GMM::Accumulators,
171 k_range: (u32, u32),
172 #[comptime] config: GMM::Config,
173) {
174 let tiling = config.tiling_scheme();
175 let stage_m = tiling.elements_in_stage_m().runtime();
176 let stage_n = tiling.elements_in_stage_n().runtime();
177 let k_size = k_range.1 - k_range.0;
178
179 let a = a.view(SliceIndex::new(nth_batch, a.shape()));
180 let b = b.view(SliceIndex::new(nth_batch, b.shape()));
181 let c = match c {
182 CubeOption::Some(c) => {
183 let c = c.view(SliceIndex::new(nth_batch, c.shape()));
184 CubeOption::new_Some(c.slice_unchecked((m_offset, n_offset), (stage_m, stage_n)))
185 }
186 CubeOption::None => CubeOption::new_None(),
187 };
188 let out = out.view_mut(SliceIndex::new(nth_batch, out.shape()));
189
190 GMM::execute(
191 GMM::init_lhs_global_reader(
192 a.slice_unchecked((m_offset, k_range.0), (stage_m, k_size)),
193 config,
194 ),
195 GMM::init_rhs_global_reader(
196 b.slice_unchecked((k_range.0, n_offset), (k_size, stage_n)),
197 config,
198 ),
199 GMM::init_acc_global_reader(c, config),
200 GMM::init_global_writer(
201 out.slice_mut_unchecked((m_offset, n_offset), (stage_m, stage_n)),
202 config,
203 ),
204 acc,
205 k_range,
206 config,
207 );
208}