cubecl_matmul/components/global/memory/
layout.rs

use cubecl::prelude::*;
use cubecl_common::quant::scheme::{QuantLevel, QuantScheme};
use cubecl_core::{self as cubecl};
use cubecl_std::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{Coords1d, Coords2d, Coords3d, Layout, LayoutExpand},
};

use crate::components::{MatmulProblem, MatrixLayout, global::memory::GlobalMemoryConfig};

/// Global layout for TMA tensor maps. Coordinates are passed through unchanged, except that the
/// row and column are swapped for column-major matrices.
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct SimpleTmaGlobalLayout {
    #[cube(comptime)]
    transposed: bool,
    shape: Coords3d,
}

#[cube]
impl SimpleTmaGlobalLayout {
    /// Creates a new 3D layout from the tensor map's shape and the matrix layout of the
    /// underlying tensor.
    pub fn new(shape: Coords3d, #[comptime] layout: MatrixLayout) -> Self {
        let transposed = comptime![matches!(layout, MatrixLayout::ColMajor)];
        SimpleTmaGlobalLayout { shape, transposed }
    }
}

#[cube]
impl Layout for SimpleTmaGlobalLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords3d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> Coords3d {
        let (batch, row, col) = coords;
        // Tensor maps are required to have a stride of 1 on the last dim, so their shape is
        // transposed for col-major matrices. Need to compensate by swapping the coordinates.
        if comptime![self.transposed] {
            (batch, col, row)
        } else {
            (batch, row, col)
        }
    }
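
    // Worked example: a col-major matrix with logical shape (b, rows, cols) is registered as a
    // (b, cols, rows) tensor map, so the logical coordinate (b, row, col) above maps to the
    // source coordinate (b, col, row).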

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (Coords3d, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        self.shape
    }

    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        // No need to bounds check TMA loads
        true.runtime()
    }
}

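/// Comptime configuration for [`BatchedGlobalLayout`], describing the matrix layout and which
/// dimensions require bounds checks.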
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, Default)]
pub struct GlobalLayoutConfig {
    pub matrix_layout: MatrixLayout,
    pub check_row_bounds: bool,
    pub check_col_bounds: bool,
}

impl From<GlobalMemoryConfig> for GlobalLayoutConfig {
    fn from(value: GlobalMemoryConfig) -> Self {
        GlobalLayoutConfig {
            matrix_layout: value.matrix_layout,
            check_row_bounds: value.check_row_bounds,
            check_col_bounds: value.check_col_bounds,
        }
    }
}

/// Global layout that decomposes a linear batch index into the tensor's batch dimensions and
/// indexes the last two dimensions as a strided matrix.
#[derive(CubeType, CubeLaunch, Clone)]
pub struct BatchedGlobalLayout {
    batch_shape: Sequence<FastDivmod>,
    batch_strides: Sequence<u32>,

    rows: u32,
    cols: u32,

    stride_row: u32,
    stride_col: u32,

    /// Line size (vectorization) of the backing tensor.
    #[cube(comptime)]
    line_size: u32,
    /// Number of quantized values packed into each storage element (`1` for unquantized tensors).
    #[cube(comptime)]
    packing: u32,
    #[cube(comptime)]
    config: GlobalLayoutConfig,
}

#[cube]
impl BatchedGlobalLayout {
    /// Creates a new batched global layout. `batch_shape` should be based on the output shape.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        batch_strides: Sequence<u32>,
        batch_shape: Sequence<FastDivmod>,
        shape_row: u32,
        shape_col: u32,
        stride_row: u32,
        stride_col: u32,
        #[comptime] line_size: u32,
        #[comptime] packing: u32,
        #[comptime] config: GlobalLayoutConfig,
    ) -> Self {
        BatchedGlobalLayout {
            batch_shape,
            batch_strides,
            rows: shape_row,
            cols: shape_col,
            stride_row,
            stride_col,
            line_size,
            packing,
            config,
        }
    }
}

#[cube]
impl Layout for BatchedGlobalLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> u32 {
        let line_size = comptime![self.line_size];
        let (mut batch, row, col) = coords;

        // For packed (quantized) tensors, several values share one storage element along the
        // contiguous dimension, so the coordinate along that dimension is scaled down.
        let (row, col) = match comptime![self.config.matrix_layout] {
            MatrixLayout::RowMajor => (row, col / self.packing),
            MatrixLayout::ColMajor => (row / self.packing, col),
        };

        // This looks expensive to calculate each time, but the batch is constant across all loop
        // iterations, so it'll get pulled out by the compiler and only calculated once. It will
        // generate more code for unrolled loops, but should be fine.
        // TODO: VALIDATE WITH PROFILER
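        //
        // Worked example: with batch dims [2, 3] flattened so `batch` ranges over 0..6, a linear
        // `batch = 5` decomposes (innermost dim first) as 5 % 3 = 2 and (5 / 3) % 2 = 1, giving
        // the multi-index (1, 2) and the offset `2 * strides[1] + 1 * strides[0]`.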
        let mut batch_offs = 0;
        let batch_shape = self.batch_shape.rev();
        let batch_strides = self.batch_strides.rev();

        #[unroll]
        for i in 0..batch_shape.len() {
            let (rem, local_pos) = batch_shape.index(i).div_mod(batch);
            batch = rem;
            batch_offs += local_pos * *batch_strides.index(i);
        }

        let idx = batch_offs + row * self.stride_row + col * self.stride_col;

        idx / line_size
    }

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        (u32::MAX.runtime(), self.rows, self.cols)
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let (_, row, col) = pos;

        match comptime!((self.config.check_row_bounds, self.config.check_col_bounds)) {
            (true, true) => row < self.rows && col < self.cols,
            (true, false) => row < self.rows,
            (false, true) => col < self.cols,
            (false, false) => true,
        }
    }
}

impl<'a, R: Runtime> BatchedGlobalLayoutLaunch<'a, R> {
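    /// Creates the launch arguments for a [`BatchedGlobalLayout`] from a tensor handle, reading
    /// the matrix shape and strides from its last two dimensions. Batch strides for dimensions
    /// of size 1 are set to 0 so that broadcasting works.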
    pub fn from_handle(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'a, R>,
        problem: &MatmulProblem,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> Self {
        let rank = handle.shape.len();
        let rows = handle.shape[rank - 2];
        let cols = handle.shape[rank - 1];
        let stride_row = handle.strides[rank - 2];
        let stride_col = handle.strides[rank - 1];

        let batch_shape = problem
            .out_batches
            .iter()
            .map(|shape| FastDivmodArgs::new(client, *shape as u32))
            .collect();
        // Use a stride of 0 for batch dimensions of size 1, so they broadcast.
        let batch_strides = handle.strides[..rank - 2]
            .iter()
            .zip(&handle.shape[..rank - 2])
            .map(|(stride, shape)| if *shape == 1 { 0 } else { *stride })
            .map(|stride| ScalarArg::new(stride as u32))
            .collect();

        BatchedGlobalLayoutLaunch::new(
            batch_shape,
            batch_strides,
            ScalarArg::new(rows as u32),
            ScalarArg::new(cols as u32),
            ScalarArg::new(stride_row as u32),
            ScalarArg::new(stride_col as u32),
            line_size as u32,
            1,
            config,
        )
    }

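    /// Creates the launch arguments for both the values layout and the scales layout of a
    /// quantized tensor. `shape` is the unpacked (logical) shape of the values tensor.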
    #[allow(clippy::too_many_arguments)]
    pub fn from_quantized_handle(
        client: &ComputeClient<R::Server>,
        values: &TensorHandleRef<'a, R>,
        scales: &TensorHandleRef<'a, R>,
        shape: &'a [usize],
        problem: &MatmulProblem,
        scheme: QuantScheme,
        line_size: u8,
        config: GlobalLayoutConfig,
    ) -> (
        BatchedGlobalLayoutLaunch<'a, R>,
        BatchedGlobalScaleLayoutArgs<'a, R>,
    ) {
        let rank = values.shape.len();
        let (rows, cols) = (shape[rank - 2], shape[rank - 1]);
        let values_layout = {
            let (stride_row, stride_col) = (values.strides[rank - 2], values.strides[rank - 1]);

            let batch_shape = problem
                .out_batches
                .iter()
                .map(|shape| FastDivmodArgs::new(client, *shape as u32))
                .collect();
            let batch_strides = values.strides[..rank - 2]
                .iter()
                .zip(&values.shape[..rank - 2])
                .map(|(stride, shape)| if *shape == 1 { 0 } else { *stride })
                .map(|stride| ScalarArg::new(stride as u32))
                .collect();

            BatchedGlobalLayoutLaunch::new(
                batch_shape,
                batch_strides,
                ScalarArg::new(rows as u32),
                ScalarArg::new(cols as u32),
                ScalarArg::new(stride_row as u32),
                ScalarArg::new(stride_col as u32),
                line_size as u32,
                scheme.num_quants() as u32,
                config,
            )
        };

        let scales_layout = {
            let shape = (ScalarArg::new(rows as u32), ScalarArg::new(cols as u32));

            match scheme.level {
                QuantLevel::Tensor => BatchedGlobalScaleLayoutArgs::PerTensor { shape },
                QuantLevel::Block(block_size) => {
                    let [block_row, block_col] = block_size.as_dim();
                    // Scales are never vectorized because we require that `block_size >= line_size * num_quants`.
                    let scales_layout =
                        BatchedGlobalLayoutLaunch::from_handle(client, scales, problem, 1, config);
                    BatchedGlobalScaleLayoutArgs::BlockScaled(BlockScaledLayoutLaunch::new(
                        shape,
                        scales_layout,
                        (block_row as u32, block_col as u32),
                    ))
                }
            }
        };

        (values_layout, scales_layout)
    }
}
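
// Illustrative host-side sketch (hypothetical `client`, `handle`, and `problem` values), showing
// how the layout is typically built as a launch argument:
//
//     let layout = BatchedGlobalLayoutLaunch::from_handle(
//         &client,
//         &handle,
//         &problem,
//         4, // line size
//         GlobalLayoutConfig::default(),
//     );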

/// Layout for quantization scales: either a single scale for the whole tensor, or one scale per
/// block of values.
#[derive(CubeType, CubeLaunch)]
pub enum BatchedGlobalScaleLayout {
    PerTensor { shape: Coords2d },
    BlockScaled(BlockScaledLayout),
}

/// Workaround for enums not supporting `comptime` fields; this should be fixed in the future.
#[derive(CubeType, CubeLaunch)]
pub struct BlockScaledLayout {
    shape: Coords2d,
    scales_layout: BatchedGlobalLayout,
    /// Size of each quantization block along (rows, cols).
    #[cube(comptime)]
    block_size: Coords2d,
}

#[cube]
impl BlockScaledLayout {
    pub fn new(
        shape: Coords2d,
        scales_layout: BatchedGlobalLayout,
        #[comptime] block_size: Coords2d,
    ) -> Self {
        BlockScaledLayout {
            shape,
            scales_layout,
            block_size,
        }
    }
}

#[cube]
impl Layout for BatchedGlobalScaleLayout {
    type Coordinates = Coords3d;
    type SourceCoordinates = Coords1d;

    fn to_source_pos(&self, coords: Self::Coordinates) -> u32 {
        match self {
            // Per-tensor quantization has a single scale, always at index 0.
            BatchedGlobalScaleLayout::PerTensor { .. } => 0u32.runtime(),
            BatchedGlobalScaleLayout::BlockScaled(layout) => {
                let BlockScaledLayout {
                    scales_layout,
                    block_size,
                    ..
                } = layout;

                // Each block of values shares one scale, so map the element coordinates onto
                // the block grid before indexing the scales tensor.
                let (batch, row, col) = coords;
                let (block_row, block_col) = block_size;
                let (row, col) = (row / block_row, col / block_col);
                scales_layout.to_source_pos((batch, row, col))
            }
        }
    }
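
    // Worked example: with `block_size = (32, 32)`, the element at (b, 40, 70) reads its scale
    // at (b, 1, 2) in the scales tensor.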

    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (u32, bool) {
        (self.to_source_pos(coords), self.is_in_bounds(coords))
    }

    fn shape(&self) -> Self::Coordinates {
        match self {
            BatchedGlobalScaleLayout::PerTensor { shape } => (u32::MAX.runtime(), shape.0, shape.1),
            BatchedGlobalScaleLayout::BlockScaled(layout) => {
                let (row, col) = layout.shape;
                (u32::MAX.runtime(), row, col)
            }
        }
    }

    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        match self {
            BatchedGlobalScaleLayout::PerTensor { .. } => true.runtime(),
            BatchedGlobalScaleLayout::BlockScaled(layout) => {
                let (_, row, col) = pos;
                let l = &layout.scales_layout;
                let (rows, cols) = layout.shape;

                match comptime!((l.config.check_row_bounds, l.config.check_col_bounds)) {
                    (true, true) => row < rows && col < cols,
                    (true, false) => row < rows,
                    (false, true) => col < cols,
                    (false, false) => true,
                }
            }
        }
    }
}