cubecl_matmul/components/global/memory/
layout.rs

use cubecl::prelude::*;
use cubecl_common::quant::scheme::{QuantLevel, QuantScheme};
use cubecl_core::{self as cubecl};
use cubecl_std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{
        Coords1d, Coords2d, Coords3d, Layout, LayoutExpand, VirtualLayout, VirtualLayoutLaunch,
    },
};

use crate::components::{MatmulProblem, MatrixLayout, global::memory::GlobalMemoryConfig};

/// Global layout for TMA tensor maps that uses the last two dimensions and ignores all others.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct SimpleTmaGlobalLayout {
    #[cube(comptime)]
    transposed: bool,
    shape: Coords3d,
}

#[cube]
impl SimpleTmaGlobalLayout {
    /// Creates a new layout from the 3D `shape` and the matrix `layout`.
    pub fn new(shape: Coords3d, #[comptime] layout: MatrixLayout) -> Self {
        let transposed = comptime![matches!(layout, MatrixLayout::ColMajor)];
        SimpleTmaGlobalLayout { shape, transposed }
    }
}

#[cube]
impl Layout for SimpleTmaGlobalLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords3d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> Coords3d {
        let (batch, row, col) = coords;
        // Tensor maps are required to have a stride of 1 on the last dim, so their shape is
        // transposed for col-major matrices. Need to compensate by swapping the coordinates.
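        // Example (hypothetical shapes): a col-major 64x32 matrix is stored in
        // the tensor map with its dims swapped, as shape (batch, 32, 64), so the
        // logical coordinate (b, row, col) loads from (b, col, row).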
        if comptime![self.transposed] {
            (batch, col, row)
        } else {
            (batch, row, col)
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (Coords3d, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        self.shape
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        // No need to bounds check TMA loads
        true.runtime()
    }
}

#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)]
pub struct GlobalLayoutConfig {
    pub matrix_layout: MatrixLayout,
    pub check_row_bounds: bool,
    pub check_col_bounds: bool,
}

impl From<GlobalMemoryConfig> for GlobalLayoutConfig {
    fn from(gmem_config: GlobalMemoryConfig) -> Self {
        GlobalLayoutConfig {
            matrix_layout: gmem_config.matrix_layout,
            check_row_bounds: gmem_config.check_row_bounds,
            check_col_bounds: gmem_config.check_col_bounds,
        }
    }
}

/// Global layout that uses the last two dimensions and delegates the remaining
/// batch dimensions to an inner batch layout.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct GlobalLayout {
    batch_layout: VirtualLayout<Coords1d, Coords1d>,
    rows: u32,
    cols: u32,

    stride_row: u32,
    stride_col: u32,

    #[cube(comptime)]
    line_size: u32,
    #[cube(comptime)]
    packing: u32,
    #[cube(comptime)]
    config: GlobalLayoutConfig,
}

#[cube]
impl GlobalLayout {
    /// Create a new batched global layout. `batch_layout` should be based on the output shape.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        batch_layout: VirtualLayout<Coords1d, Coords1d>,
        shape_row: u32,
        shape_col: u32,
        stride_row: u32,
        stride_col: u32,
        #[comptime] line_size: u32,
        #[comptime] packing: u32,
        #[comptime] config: GlobalLayoutConfig,
    ) -> Self {
        GlobalLayout {
            batch_layout,
            rows: shape_row,
            cols: shape_col,
            stride_row,
            stride_col,
            line_size,
            packing,
            config,
        }
    }
}

#[cube]
impl Layout for GlobalLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords1d;

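    /// Converts `(batch, row, col)` to a lined linear index. Worked example
    /// (hypothetical values): row-major, `stride_row = 128`, `stride_col = 1`,
    /// `line_size = 4`, `packing = 1`, zero batch offset; coords `(0, 2, 12)`
    /// give `(2 * 128 + 12) / 4 = 67`.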
    fn to_source_pos(&self, coords: Self::Coordinates) -> u32 {
        let line_size = comptime![self.line_size];
        let (batch, row, col) = coords;
        let batch_offs = self.batch_layout.to_source_pos(batch);

        let (row, col) = match comptime![self.config.matrix_layout] {
            MatrixLayout::RowMajor => (row, col / self.packing),
            MatrixLayout::ColMajor => (row / self.packing, col),
        };

        let idx = batch_offs + row * self.stride_row + col * self.stride_col;

        idx / line_size
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        (u32::MAX.runtime(), self.rows, self.cols)
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, row, col) = pos;

        match comptime!((self.config.check_row_bounds, self.config.check_col_bounds)) {
            (true, true) => row < self.rows && col < self.cols,
            (true, false) => row < self.rows,
            (false, true) => col < self.cols,
            (false, false) => true,
        }
    }
}

impl<'a, R: Runtime> GlobalLayoutLaunch<'a, R> {
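    /// Builds launch arguments for a non-batched tensor, using a no-op batch
    /// layout and a packing factor of 1.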
    pub fn from_handle(
        handle: &TensorHandleRef<'a, R>,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> Self {
        let rank = handle.shape.len();
        let rows = handle.shape[rank - 2];
        let cols = handle.shape[rank - 1];
        let stride_row = handle.strides[rank - 2];
        let stride_col = handle.strides[rank - 1];

        GlobalLayoutLaunch::new(
            VirtualLayoutLaunch::new::<NoopLayout>(NoopLayoutLaunch::new()),
            ScalarArg::new(rows as u32),
            ScalarArg::new(cols as u32),
            ScalarArg::new(stride_row as u32),
            ScalarArg::new(stride_col as u32),
            line_size as u32,
            1,
            config,
        )
    }

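    /// Builds launch arguments for a batched tensor, deriving the batch layout
    /// from the handle's strides and the problem's output batch shape.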
    pub fn from_handle_batched(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'a, R>,
        problem: &MatmulProblem,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> Self {
        let rank = handle.shape.len();
        let rows = handle.shape[rank - 2];
        let cols = handle.shape[rank - 1];
        let stride_row = handle.strides[rank - 2];
        let stride_col = handle.strides[rank - 1];

        let batch_layout = BatchLayoutLaunch::from_handle(client, handle, problem);

        GlobalLayoutLaunch::new(
            VirtualLayoutLaunch::new::<BatchLayout>(batch_layout),
            ScalarArg::new(rows as u32),
            ScalarArg::new(cols as u32),
            ScalarArg::new(stride_row as u32),
            ScalarArg::new(stride_col as u32),
            line_size as u32,
            1,
            config,
        )
    }

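    /// Builds launch arguments for a quantized tensor along with a layout for
    /// its scales. The values layout packs `num_quants` values per storage
    /// element; the scales layout depends on the quantization level.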
    #[allow(clippy::too_many_arguments)]
    pub fn from_quantized_handle(
        client: &ComputeClient<R::Server>,
        values: &TensorHandleRef<'a, R>,
        scales: &TensorHandleRef<'a, R>,
        shape: &'a [usize],
        problem: &MatmulProblem,
        scheme: QuantScheme,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> (GlobalLayoutLaunch<'a, R>, GlobalScaleLayoutArgs<'a, R>) {
        let rank = values.shape.len();
        let (rows, cols) = (shape[rank - 2], shape[rank - 1]);
        let values_layout = {
            let (stride_row, stride_col) = (values.strides[rank - 2], values.strides[rank - 1]);

            let batch_layout = BatchLayoutLaunch::from_handle(client, values, problem);

            GlobalLayoutLaunch::new(
                VirtualLayoutLaunch::new::<BatchLayout>(batch_layout),
                ScalarArg::new(rows as u32),
                ScalarArg::new(cols as u32),
                ScalarArg::new(stride_row as u32),
                ScalarArg::new(stride_col as u32),
                line_size as u32,
                scheme.num_quants() as u32,
                config,
            )
        };

        let scales_layout = {
            let shape = (ScalarArg::new(rows as u32), ScalarArg::new(cols as u32));

            match scheme.level {
                QuantLevel::Tensor => GlobalScaleLayoutArgs::PerTensor { shape },
                QuantLevel::Block(block_size) => {
                    let [block_row, block_col] = block_size.as_dim();
                    // Scales are never vectorized because we require that `block_size >= line_size * num_quants`.
                    let scales_layout =
                        GlobalLayoutLaunch::from_handle_batched(client, scales, problem, 1, config);
                    GlobalScaleLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new(
                        shape,
                        scales_layout,
                        (block_row as u32, block_col as u32),
                    ))
                }
            }
        };

        (values_layout, scales_layout)
    }
}

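/// Layout mapping a linear batch index to a tensor offset using the batch
/// shape and strides of the underlying tensor.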
#[derive(CubeType, CubeLaunch)]
pub struct BatchLayout {
    batch_shape: Sequence<FastDivmod>,
    batch_strides: Sequence<u32>,
}

#[cube]
impl BatchLayout {
    pub fn new(batch_strides: Sequence<u32>, batch_shape: Sequence<FastDivmod>) -> Self {
        BatchLayout {
            batch_shape,
            batch_strides,
        }
    }
}

#[cube]
impl Layout for BatchLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

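    /// Decomposes the linear batch index into per-dimension indices, innermost
    /// first, and accumulates the strided offset. Worked example (hypothetical
    /// values): with batch shape `[2, 3]`, `pos = 4` splits as `4 = 1 * 3 + 1`,
    /// giving indices `(1, 1)` and an offset of `strides[0] + strides[1]`.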
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let mut batch = pos;
        let mut batch_offs = 0;
        let batch_shape = self.batch_shape.rev();
        let batch_strides = self.batch_strides.rev();

        #[unroll]
        for i in 0..batch_shape.len() {
            let (rem, local_pos) = batch_shape.index(i).div_mod(batch);
            batch = rem;
            batch_offs += local_pos * *batch_strides.index(i);
        }

        batch_offs
    }

    fn shape(&self) -> Self::Coordinates {
        u32::MAX.runtime()
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}

/// Layout that passes the coordinates through with no checks or modification.
#[derive(CubeType, CubeLaunch)]
pub struct NoopLayout {}

#[cube]
impl NoopLayout {
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        NoopLayout {}
    }
}

#[cube]
impl Layout for NoopLayout {
    type Coordinates = Coords1d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        pos
    }

    fn shape(&self) -> Self::Coordinates {
        u32::MAX.runtime()
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}

impl<'a, R: Runtime> BatchLayoutLaunch<'a, R> {
    pub fn from_handle(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'a, R>,
        problem: &MatmulProblem,
    ) -> Self {
        let rank = handle.shape.len();
        let batch_shape = problem
            .out_batches
            .iter()
            .map(|shape| FastDivmodArgs::new(client, *shape as u32))
            .collect();
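        // Dims of size 1 get a stride of 0 so they broadcast against the
        // output batches.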
        let batch_strides = handle.strides[..rank - 2]
            .iter()
            .zip(&handle.shape[..rank - 2])
            .map(|(stride, shape)| if *shape == 1 { 0 } else { *stride })
            .map(|stride| ScalarArg::new(stride as u32))
            .collect();
        BatchLayoutLaunch::new(batch_shape, batch_strides)
    }
}

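/// Layout for quantization scales: either a single scale for the whole tensor,
/// or one scale per block of values.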
#[derive(CubeType, CubeLaunch)]
pub enum GlobalScaleLayout {
    PerTensor { shape: Coords2d },
    BlockScaled(BlockScaledLayout),
}

/// Workaround for enums not supporting `comptime` fields; this should be fixed
/// in the future.
#[derive(CubeType, CubeLaunch)]
pub struct BlockScaledLayout {
    shape: Coords2d,
    scales_layout: GlobalLayout,
    #[cube(comptime)]
    block_size: Coords2d,
}

#[cube]
impl BlockScaledLayout {
    pub fn new(
        shape: Coords2d,
        scales_layout: GlobalLayout,
        #[comptime] block_size: Coords2d,
    ) -> Self {
        BlockScaledLayout {
            shape,
            scales_layout,
            block_size,
        }
    }
}

#[cube]
impl Layout for GlobalScaleLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords1d;

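    /// Per-tensor scales always read position 0; block scales map the value
    /// coordinate to its block, e.g. with `block_size = (32, 32)` the value at
    /// `(row, col) = (70, 40)` reads the scale at block `(2, 1)`.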
    fn to_source_pos(&self, coords: Self::Coordinates) -> u32 {
        match self {
            GlobalScaleLayout::PerTensor { .. } => 0u32.runtime(),
            GlobalScaleLayout::BlockScaled(layout) => {
                let BlockScaledLayout {
                    scales_layout,
                    block_size,
                    ..
                } = layout;

                let (batch, row, col) = coords;
                let (block_row, block_col) = block_size;
                let (row, col) = (row / block_row, col / block_col);
                scales_layout.to_source_pos((batch, row, col))
            }
        }
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        match self {
            GlobalScaleLayout::PerTensor { shape } => (u32::MAX.runtime(), shape.0, shape.1),
            GlobalScaleLayout::BlockScaled(layout) => {
                let (row, col) = layout.shape;
                (u32::MAX.runtime(), row, col)
            }
        }
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        match self {
            GlobalScaleLayout::PerTensor { .. } => true.runtime(),
            GlobalScaleLayout::BlockScaled(layout) => {
                let (_, row, col) = pos;
                let l = &layout.scales_layout;
                let (rows, cols) = layout.shape;

                match comptime!((l.config.check_row_bounds, l.config.check_col_bounds)) {
                    (true, true) => row < rows && col < cols,
                    (true, false) => row < rows,
                    (false, true) => col < cols,
                    (false, false) => true,
                }
            }
        }
    }
}