use cubecl::prelude::*;
use cubecl_common::quant::scheme::{QuantLevel, QuantScheme};
use cubecl_core::{self as cubecl};
use cubecl_std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{
        Coords1d, Coords2d, Coords3d, Layout, LayoutExpand, VirtualLayout, VirtualLayoutLaunch,
    },
};

use crate::components::{MatmulProblem, MatrixLayout, global::memory::GlobalMemoryConfig};

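/// Layout for tensors loaded through TMA. TMA performs its own stride
/// arithmetic, so coordinates are forwarded as-is; for column-major tensors
/// the `row`/`col` pair is swapped instead of being restrided.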
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct SimpleTmaGlobalLayout {
    #[cube(comptime)]
    transposed: bool,
    shape: Coords3d,
}

#[cube]
impl SimpleTmaGlobalLayout {
    pub fn new(shape: Coords3d, #[comptime] layout: MatrixLayout) -> Self {
        let transposed = comptime![matches!(layout, MatrixLayout::ColMajor)];
        SimpleTmaGlobalLayout { shape, transposed }
    }
}

#[cube]
impl Layout for SimpleTmaGlobalLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords3d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> Coords3d {
        let (batch, row, col) = coords;
        if comptime![self.transposed] {
            (batch, col, row)
        } else {
            (batch, row, col)
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (Coords3d, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        self.shape
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        // Out-of-bounds accesses are handled by the TMA unit itself, so this
        // layout never reports a position as out of bounds.
        true.runtime()
    }
}

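/// The subset of [`GlobalMemoryConfig`] that is relevant for constructing a
/// [`GlobalLayout`]: the matrix layout and which dimensions need bounds checks.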
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)]
pub struct GlobalLayoutConfig {
    pub matrix_layout: MatrixLayout,
    pub check_row_bounds: bool,
    pub check_col_bounds: bool,
}

impl From<GlobalMemoryConfig> for GlobalLayoutConfig {
    fn from(gmem_config: GlobalMemoryConfig) -> Self {
        GlobalLayoutConfig {
            matrix_layout: gmem_config.matrix_layout,
            check_row_bounds: gmem_config.check_row_bounds,
            check_col_bounds: gmem_config.check_col_bounds,
        }
    }
}

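/// Strided layout for linearized global memory tensors. Maps a
/// `(batch, row, col)` coordinate to a line index: the batch is resolved
/// through a pluggable inner layout, rows and columns are multiplied by their
/// strides, and the resulting element offset is divided by the line size.
/// `packing` accounts for quantized tensors that store several values per
/// element (it is `1` for regular tensors).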
#[derive(CubeType, CubeLaunch, Clone)]
pub struct GlobalLayout {
    batch_layout: VirtualLayout<Coords1d, Coords1d>,
    rows: u32,
    cols: u32,

    stride_row: u32,
    stride_col: u32,

    #[cube(comptime)]
    line_size: u32,
    /// Number of values packed into each stored element (`1` when unpacked).
    #[cube(comptime)]
    packing: u32,
    #[cube(comptime)]
    config: GlobalLayoutConfig,
}

#[cube]
impl GlobalLayout {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        batch_layout: VirtualLayout<Coords1d, Coords1d>,
        shape_row: u32,
        shape_col: u32,
        stride_row: u32,
        stride_col: u32,
        #[comptime] line_size: u32,
        #[comptime] packing: u32,
        #[comptime] config: GlobalLayoutConfig,
    ) -> Self {
        GlobalLayout {
            batch_layout,
            rows: shape_row,
            cols: shape_col,
            stride_row,
            stride_col,
            line_size,
            packing,
            config,
        }
    }
}

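// A worked example of the index arithmetic below (hypothetical values): for a
// row-major 128x64 tensor with `stride_row = 64`, `stride_col = 1`,
// `line_size = 4`, and `packing = 1`, the coordinate `(batch, 2, 8)` resolves
// to `batch_offs + 2 * 64 + 8 * 1 = batch_offs + 136`, i.e. line index
// `(batch_offs + 136) / 4`. With `packing = 2` (two quantized values per
// stored element), the column is first halved to `4` before applying strides.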
#[cube]
impl Layout for GlobalLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> u32 {
        let line_size = comptime![self.line_size];
        let (batch, row, col) = coords;
        let batch_offs = self.batch_layout.to_source_pos(batch);

        // Packed values share a stored element along the contiguous axis, so
        // that axis is divided by the packing factor before applying strides.
        let (row, col) = match comptime![self.config.matrix_layout] {
            MatrixLayout::RowMajor => (row, col / self.packing),
            MatrixLayout::ColMajor => (row / self.packing, col),
        };

        let idx = batch_offs + row * self.stride_row + col * self.stride_col;

        idx / line_size
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        (u32::MAX.runtime(), self.rows, self.cols)
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, row, col) = pos;

        match comptime!((self.config.check_row_bounds, self.config.check_col_bounds)) {
            (true, true) => row < self.rows && col < self.cols,
            (true, false) => row < self.rows,
            (false, true) => col < self.cols,
            (false, false) => true,
        }
    }
}

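// Host-side constructors for `GlobalLayout`. A minimal usage sketch (assuming
// a `TensorHandleRef` named `lhs` and a default config; the names are
// illustrative only):
//
//     let layout = GlobalLayoutLaunch::from_handle(
//         &lhs,
//         4, // line size
//         GlobalLayoutConfig::default(),
//     );
//
// `from_handle` treats all leading dimensions as already linearized (via
// `NoopLayout`), while `from_handle_batched` decomposes the batch index with a
// `BatchLayout` built from the matmul problem.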
impl<'a, R: Runtime> GlobalLayoutLaunch<'a, R> {
    /// Creates a layout from a tensor handle, treating batch offsets as
    /// already linearized.
    pub fn from_handle(
        handle: &TensorHandleRef<'a, R>,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> Self {
        let rank = handle.shape.len();
        let rows = handle.shape[rank - 2];
        let cols = handle.shape[rank - 1];
        let stride_row = handle.strides[rank - 2];
        let stride_col = handle.strides[rank - 1];

        GlobalLayoutLaunch::new(
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ScalarArg::new(rows as u32),
            ScalarArg::new(cols as u32),
            ScalarArg::new(stride_row as u32),
            ScalarArg::new(stride_col as u32),
            line_size as u32,
            1,
            config,
        )
    }

    /// Creates a layout from a tensor handle, decomposing the batch index
    /// according to the problem's output batch shape.
    pub fn from_handle_batched(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'a, R>,
        problem: &MatmulProblem,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> Self {
        let rank = handle.shape.len();
        let rows = handle.shape[rank - 2];
        let cols = handle.shape[rank - 1];
        let stride_row = handle.strides[rank - 2];
        let stride_col = handle.strides[rank - 1];

        let batch_layout = BatchLayoutLaunch::from_handle(client, handle, problem);

        GlobalLayoutLaunch::new(
            VirtualLayoutLaunch::new::<BatchLayout>(batch_layout),
            ScalarArg::new(rows as u32),
            ScalarArg::new(cols as u32),
            ScalarArg::new(stride_row as u32),
            ScalarArg::new(stride_col as u32),
            line_size as u32,
            1,
            config,
        )
    }

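    /// Creates layouts for a packed quantized tensor and its scales. The
    /// values layout packs `scheme.num_quants()` values per stored element;
    /// the scales layout is either a single per-tensor scale or a
    /// block-scaled lookup, depending on the quantization level.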
    #[allow(clippy::too_many_arguments)]
    pub fn from_quantized_handle(
        client: &ComputeClient<R::Server>,
        values: &TensorHandleRef<'a, R>,
        scales: &TensorHandleRef<'a, R>,
        shape: &'a [usize],
        problem: &MatmulProblem,
        scheme: QuantScheme,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> (GlobalLayoutLaunch<'a, R>, GlobalScaleLayoutArgs<'a, R>) {
        let rank = values.shape.len();
        let (rows, cols) = (shape[rank - 2], shape[rank - 1]);
        let values_layout = {
            let (stride_row, stride_col) = (values.strides[rank - 2], values.strides[rank - 1]);

            let batch_layout = BatchLayoutLaunch::from_handle(client, values, problem);

            GlobalLayoutLaunch::new(
                VirtualLayoutLaunch::new::<BatchLayout>(batch_layout),
                ScalarArg::new(rows as u32),
                ScalarArg::new(cols as u32),
                ScalarArg::new(stride_row as u32),
                ScalarArg::new(stride_col as u32),
                line_size as u32,
                scheme.num_quants() as u32,
                config,
            )
        };

        let scales_layout = {
            let shape = (ScalarArg::new(rows as u32), ScalarArg::new(cols as u32));

            match scheme.level {
                QuantLevel::Tensor => GlobalScaleLayoutArgs::PerTensor { shape },
                QuantLevel::Block(block_size) => {
                    let [block_row, block_col] = block_size.as_dim();
                    let scales_layout =
                        GlobalLayoutLaunch::from_handle_batched(client, scales, problem, 1, config);
                    GlobalScaleLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new(
                        shape,
                        scales_layout,
                        (block_row as u32, block_col as u32),
                    ))
                }
            }
        };

        (values_layout, scales_layout)
    }
}

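/// Layout that decomposes a linear batch index into per-dimension coordinates
/// (using [`FastDivmod`]) and re-applies the tensor's batch strides. This
/// handles non-contiguous and broadcast batch dimensions.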
#[derive(CubeType, CubeLaunch)]
pub struct BatchLayout {
    batch_shape: Sequence<FastDivmod>,
    batch_strides: Sequence<u32>,
}

#[cube]
impl BatchLayout {
    pub fn new(batch_strides: Sequence<u32>, batch_shape: Sequence<FastDivmod>) -> Self {
        BatchLayout {
            batch_shape,
            batch_strides,
        }
    }
}

#[cube]
impl Layout for BatchLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let mut batch = pos;
        let mut batch_offs = 0;
        let batch_shape = self.batch_shape.rev();
        let batch_strides = self.batch_strides.rev();

        // Peel off one batch dimension at a time, starting from the innermost:
        // the remainder is the coordinate in that dimension, the quotient is
        // carried over to the next one.
        #[unroll]
        for i in 0..batch_shape.len() {
            let (rem, local_pos) = batch_shape.index(i).div_mod(batch);
            batch = rem;
            batch_offs += local_pos * *batch_strides.index(i);
        }

        batch_offs
    }

    fn shape(&self) -> Self::Coordinates {
        u32::MAX.runtime()
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}

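/// Identity layout: forwards the position unchanged. Used when batch offsets
/// are already linearized and no decomposition is required.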
#[derive(CubeType, CubeLaunch)]
pub struct NoopLayout {}

#[cube]
impl NoopLayout {
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        NoopLayout {}
    }
}

#[cube]
impl Layout for NoopLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        pos
    }

    fn shape(&self) -> Self::Coordinates {
        u32::MAX.runtime()
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}

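// Host-side constructor for `BatchLayout`, built from the output batch shape
// of the matmul problem and the handle's own batch strides.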
impl<'a, R: Runtime> BatchLayoutLaunch<'a, R> {
    pub fn from_handle(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'a, R>,
        problem: &MatmulProblem,
    ) -> Self {
        let rank = handle.shape.len();
        let batch_shape = problem
            .out_batches
            .iter()
            .map(|shape| FastDivmodArgs::new(client, *shape as u32))
            .collect();
        // Broadcast dimensions (size 1) get a stride of zero, so every batch
        // maps to the same slice of the input.
        let batch_strides = handle.strides[..rank - 2]
            .iter()
            .zip(&handle.shape[..rank - 2])
            .map(|(stride, shape)| if *shape == 1 { 0 } else { *stride })
            .map(|stride| ScalarArg::new(stride as u32))
            .collect();
        BatchLayoutLaunch::new(batch_shape, batch_strides)
    }
}

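/// Layout for quantization scales. Per-tensor quantization has a single scale
/// at offset `0`; block quantization stores one scale per block and resolves
/// it through an inner [`GlobalLayout`].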
#[derive(CubeType, CubeLaunch)]
pub enum GlobalScaleLayout {
    PerTensor { shape: Coords2d },
    BlockScaled(BlockScaledLayout),
}

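/// Block-scaled lookup: holds the shape of the *values* tensor along with a
/// layout over the scales tensor and the comptime block size used to map a
/// value coordinate to its scale.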
#[derive(CubeType, CubeLaunch)]
pub struct BlockScaledLayout {
    shape: Coords2d,
    scales_layout: GlobalLayout,
    #[cube(comptime)]
    block_size: Coords2d,
}

#[cube]
impl BlockScaledLayout {
    pub fn new(
        shape: Coords2d,
        scales_layout: GlobalLayout,
        #[comptime] block_size: Coords2d,
    ) -> Self {
        BlockScaledLayout {
            shape,
            scales_layout,
            block_size,
        }
    }
}

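// A worked example of the block-scale lookup below (hypothetical values): with
// a block size of `(32, 32)`, the value coordinate `(batch, 70, 100)` falls in
// scale block `(70 / 32, 100 / 32) = (2, 3)`, which is then resolved through
// the inner scales `GlobalLayout` as an ordinary `(batch, row, col)` access.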
#[cube]
impl Layout for GlobalScaleLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> u32 {
        match self {
            // A single scale for the whole tensor always lives at offset 0.
            GlobalScaleLayout::PerTensor { .. } => 0u32.runtime(),
            GlobalScaleLayout::BlockScaled(layout) => {
                let BlockScaledLayout {
                    scales_layout,
                    block_size,
                    ..
                } = layout;

                let (batch, row, col) = coords;
                let (block_row, block_col) = block_size;
                let (row, col) = (row / block_row, col / block_col);
                scales_layout.to_source_pos((batch, row, col))
            }
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        match self {
            GlobalScaleLayout::PerTensor { shape } => (u32::MAX.runtime(), shape.0, shape.1),
            GlobalScaleLayout::BlockScaled(layout) => {
                let (row, col) = layout.shape;
                (u32::MAX.runtime(), row, col)
            }
        }
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        match self {
            GlobalScaleLayout::PerTensor { .. } => true.runtime(),
            GlobalScaleLayout::BlockScaled(layout) => {
                let (_, row, col) = pos;
                let l = &layout.scales_layout;
                let (rows, cols) = layout.shape;

                // Bounds are checked against the values shape, using the
                // check flags configured on the inner scales layout.
                match comptime!((l.config.check_row_bounds, l.config.check_col_bounds)) {
                    (true, true) => row < rows && col < cols,
                    (true, false) => row < rows,
                    (false, true) => col < cols,
                    (false, false) => true,
                }
            }
        }
    }
}