use cubecl::prelude::*;
use cubecl_common::quant::scheme::{QuantLevel, QuantScheme};
use cubecl_core::{self as cubecl};
use cubecl_std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{Coords1d, Coords2d, Coords3d, Layout, LayoutExpand},
};

use crate::components::{MatmulProblem, MatrixLayout, global::memory::GlobalMemoryConfig};

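/// Global-memory layout for operands that are loaded through TMA.
///
/// Coordinates are forwarded to the copy engine unchanged, except that row and
/// column are swapped when the operand is column-major. `is_in_bounds` always
/// reports `true`; the assumption here is that out-of-bounds handling is left to
/// the TMA engine rather than done in-kernel.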
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct SimpleTmaGlobalLayout {
    #[cube(comptime)]
    transposed: bool,
    shape: Coords3d,
}

#[cube]
impl SimpleTmaGlobalLayout {
    pub fn new(shape: Coords3d, #[comptime] layout: MatrixLayout) -> Self {
        let transposed = comptime![matches!(layout, MatrixLayout::ColMajor)];
        SimpleTmaGlobalLayout { shape, transposed }
    }
}

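// Logical (batch, row, col) coordinates map 1:1 onto TMA source coordinates;
// only the row/col order changes for column-major operands.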
#[cube]
impl Layout for SimpleTmaGlobalLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords3d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> Coords3d {
        let (batch, row, col) = coords;
        if comptime![self.transposed] {
            (batch, col, row)
        } else {
            (batch, row, col)
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (Coords3d, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        self.shape
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }
}

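/// Host-side configuration for [`BatchedGlobalLayout`]: the memory order plus which of the
/// row/column bounds checks need to be generated. The `From<GlobalMemoryConfig>` impl below
/// keeps only the fields the layout actually needs.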
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)]
pub struct GlobalLayoutConfig {
    pub matrix_layout: MatrixLayout,
    pub check_row_bounds: bool,
    pub check_col_bounds: bool,
}

impl From<GlobalMemoryConfig> for GlobalLayoutConfig {
    fn from(value: GlobalMemoryConfig) -> Self {
        GlobalLayoutConfig {
            matrix_layout: value.matrix_layout,
            check_row_bounds: value.check_row_bounds,
            check_col_bounds: value.check_col_bounds,
        }
    }
}

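/// Linearizing layout for a batched, strided global-memory tensor.
///
/// `batch_shape`/`batch_strides` describe the leading (batch) dimensions, with a stride of
/// zero standing in for broadcast dimensions. `line_size` is the vectorization width of the
/// backing buffer and `packing` the number of quantized values stored per element (1 for
/// unquantized tensors), so the positions returned by this layout index lines, not scalars.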
#[derive(CubeType, CubeLaunch, Clone)]
pub struct BatchedGlobalLayout {
    batch_shape: Sequence<FastDivmod>,
    batch_strides: Sequence<u32>,

    rows: u32,
    cols: u32,

    stride_row: u32,
    stride_col: u32,

    #[cube(comptime)]
    line_size: u32,
    #[cube(comptime)]
    packing: u32,
    #[cube(comptime)]
    config: GlobalLayoutConfig,
}

#[cube]
impl BatchedGlobalLayout {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        batch_strides: Sequence<u32>,
        batch_shape: Sequence<FastDivmod>,
        shape_row: u32,
        shape_col: u32,
        stride_row: u32,
        stride_col: u32,
        #[comptime] line_size: u32,
        #[comptime] packing: u32,
        #[comptime] config: GlobalLayoutConfig,
    ) -> Self {
        BatchedGlobalLayout {
            batch_shape,
            batch_strides,
            rows: shape_row,
            cols: shape_col,
            stride_row,
            stride_col,
            line_size,
            packing,
            config,
        }
    }
}

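// Index computation: the packed (contiguous) axis is first rescaled by `packing`, then the
// flat batch index is decomposed innermost-first with `FastDivmod` (hence the `rev()` calls)
// and folded into an offset using the per-dimension strides. The final element offset is
// divided by `line_size` to yield a line index into the vectorized buffer.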
#[cube]
impl Layout for BatchedGlobalLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> u32 {
        let line_size = comptime![self.line_size];
        let (mut batch, row, col) = coords;

        let (row, col) = match comptime![self.config.matrix_layout] {
            MatrixLayout::RowMajor => (row, col / self.packing),
            MatrixLayout::ColMajor => (row / self.packing, col),
        };

        let mut batch_offs = 0;
        let batch_shape = self.batch_shape.rev();
        let batch_strides = self.batch_strides.rev();

        #[unroll]
        for i in 0..batch_shape.len() {
            let (rem, local_pos) = batch_shape.index(i).div_mod(batch);
            batch = rem;
            batch_offs += local_pos * *batch_strides.index(i);
        }

        let idx = batch_offs + row * self.stride_row + col * self.stride_col;

        idx / line_size
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        (u32::MAX.runtime(), self.rows, self.cols)
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, row, col) = pos;

        match comptime!((self.config.check_row_bounds, self.config.check_col_bounds)) {
            (true, true) => row < self.rows && col < self.cols,
            (true, false) => row < self.rows,
            (false, true) => col < self.cols,
            (false, false) => true,
        }
    }
}

impl<'a, R: Runtime> BatchedGlobalLayoutLaunch<'a, R> {
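    /// Builds the launch-time arguments for [`BatchedGlobalLayout`] from a plain tensor handle.
    ///
    /// Rough usage sketch (names such as `client`, `lhs`, and `problem` are illustrative and
    /// assumed to come from the surrounding launch code):
    ///
    /// ```ignore
    /// let layout = BatchedGlobalLayoutLaunch::from_handle(
    ///     &client,
    ///     &lhs,
    ///     &problem,
    ///     4, // line size of the backing buffer
    ///     GlobalLayoutConfig {
    ///         matrix_layout: MatrixLayout::RowMajor,
    ///         check_row_bounds: false,
    ///         check_col_bounds: true,
    ///     },
    /// );
    /// ```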
    pub fn from_handle(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'a, R>,
        problem: &MatmulProblem,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> Self {
        let rank = handle.shape.len();
        let rows = handle.shape[rank - 2];
        let cols = handle.shape[rank - 1];
        let stride_row = handle.strides[rank - 2];
        let stride_col = handle.strides[rank - 1];

        let batch_shape = problem
            .out_batches
            .iter()
            .map(|shape| FastDivmodArgs::new(client, *shape as u32))
            .collect();
        let batch_strides = handle.strides[..rank - 2]
            .iter()
            .zip(&handle.shape[..rank - 2])
            .map(|(stride, shape)| if *shape == 1 { 0 } else { *stride })
            .map(|stride| ScalarArg::new(stride as u32))
            .collect();

        BatchedGlobalLayoutLaunch::new(
            batch_shape,
            batch_strides,
            ScalarArg::new(rows as u32),
            ScalarArg::new(cols as u32),
            ScalarArg::new(stride_row as u32),
            ScalarArg::new(stride_col as u32),
            line_size as u32,
            1,
            config,
        )
    }

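    /// Builds layouts for a quantized operand: the values layout uses `scheme.num_quants()`
    /// as its packing factor, and the scales layout is either a single per-tensor scale or a
    /// block-scaled layout derived from the `scales` handle, depending on `scheme.level`.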
    #[allow(clippy::too_many_arguments)]
    pub fn from_quantized_handle(
        client: &ComputeClient<R::Server>,
        values: &TensorHandleRef<'a, R>,
        scales: &TensorHandleRef<'a, R>,
        shape: &'a [usize],
        problem: &MatmulProblem,
        scheme: QuantScheme,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> (
        BatchedGlobalLayoutLaunch<'a, R>,
        BatchedGlobalScaleLayoutArgs<'a, R>,
    ) {
        let rank = values.shape.len();
        let (rows, cols) = (shape[rank - 2], shape[rank - 1]);
        let values_layout = {
            let (stride_row, stride_col) = (values.strides[rank - 2], values.strides[rank - 1]);

            let batch_shape = problem
                .out_batches
                .iter()
                .map(|shape| FastDivmodArgs::new(client, *shape as u32))
                .collect();
            let batch_strides = values.strides[..rank - 2]
                .iter()
                .zip(&values.shape[..rank - 2])
                .map(|(stride, shape)| if *shape == 1 { 0 } else { *stride })
                .map(|stride| ScalarArg::new(stride as u32))
                .collect();

            BatchedGlobalLayoutLaunch::new(
                batch_shape,
                batch_strides,
                ScalarArg::new(rows as u32),
                ScalarArg::new(cols as u32),
                ScalarArg::new(stride_row as u32),
                ScalarArg::new(stride_col as u32),
                line_size as u32,
                scheme.num_quants() as u32,
                config,
            )
        };

        let scales_layout = {
            let shape = (ScalarArg::new(rows as u32), ScalarArg::new(cols as u32));

            match scheme.level {
                QuantLevel::Tensor => BatchedGlobalScaleLayoutArgs::PerTensor { shape },
                QuantLevel::Block(block_size) => {
                    let [block_row, block_col] = block_size.as_dim();
                    let scales_layout =
                        BatchedGlobalLayoutLaunch::from_handle(client, scales, problem, 1, config);
                    BatchedGlobalScaleLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new(
                        shape,
                        scales_layout,
                        (block_row as u32, block_col as u32),
                    ))
                }
            }
        };

        (values_layout, scales_layout)
    }
}

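/// Layout for quantization scales: either one scale for the whole tensor or one scale per
/// `block_size` tile of the values tensor.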
#[derive(CubeType, CubeLaunch)]
pub enum BatchedGlobalScaleLayout {
    PerTensor { shape: Coords2d },
    BlockScaled(BlockScaledLayout),
}

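/// Block-scaled variant: `shape` is the logical (rows, cols) of the values tensor, while
/// `scales_layout` addresses the smaller scales tensor, which holds one entry per block.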
#[derive(CubeType, CubeLaunch)]
pub struct BlockScaledLayout {
    shape: Coords2d,
    scales_layout: BatchedGlobalLayout,
    #[cube(comptime)]
    block_size: Coords2d,
}

#[cube]
impl BlockScaledLayout {
    pub fn new(
        shape: Coords2d,
        scales_layout: BatchedGlobalLayout,
        #[comptime] block_size: Coords2d,
    ) -> Self {
        BlockScaledLayout {
            shape,
            scales_layout,
            block_size,
        }
    }
}

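// Per-tensor scales always resolve to index 0; block-scaled lookups divide the value
// coordinates by the block size and delegate to the inner scales layout.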
#[cube]
impl Layout for BatchedGlobalScaleLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> u32 {
        match self {
            BatchedGlobalScaleLayout::PerTensor { .. } => 0u32.runtime(),
            BatchedGlobalScaleLayout::BlockScaled(layout) => {
                let BlockScaledLayout {
                    scales_layout,
                    block_size,
                    ..
                } = layout;

                let (batch, row, col) = coords;
                let (block_row, block_col) = block_size;
                let (row, col) = (row / block_row, col / block_col);
                scales_layout.to_source_pos((batch, row, col))
            }
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        match self {
            BatchedGlobalScaleLayout::PerTensor { shape } => (u32::MAX.runtime(), shape.0, shape.1),
            BatchedGlobalScaleLayout::BlockScaled(layout) => {
                let (row, col) = layout.shape;
                (u32::MAX.runtime(), row, col)
            }
        }
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        match self {
            BatchedGlobalScaleLayout::PerTensor { .. } => true.runtime(),
            BatchedGlobalScaleLayout::BlockScaled(layout) => {
                let (_, row, col) = pos;
                let l = &layout.scales_layout;
                let (rows, cols) = layout.shape;

                match comptime!((l.config.check_row_bounds, l.config.check_col_bounds)) {
                    (true, true) => row < rows && col < cols,
                    (true, false) => row < rows,
                    (false, true) => col < cols,
                    (false, false) => true,
                }
            }
        }
    }
}