cubecl_linalg/matmul/components/global/tensor_view.rs

use crate::matmul::components::config::InputIdent;
use crate::matmul::components::global;
use crate::matmul::components::{Ident, MatrixLayout};
use cubecl_core as cubecl;
use cubecl_core::io::read_masked;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::{ReadWrite, VirtualTensor};

#[derive(Clone, CubeType)]
/// A view of a tensor that starts reading data from a specified offset.
/// Ensures safe access by preventing out-of-bounds errors.
/// Includes pre-fetched shapes and strides for optimized performance.
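///
/// Elements are addressed as `x * stride_x + y * stride_y + batch_offset`
/// (all in element units, before division by the line size). For example,
/// for a row-major 4x8 matrix (`stride_x = 8`, `stride_y = 1`), element
/// `(3, 5)` sits at linear index `3 * 8 + 5 = 29`.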
pub struct TensorReader<EI: Numeric> {
    pub tensor: VirtualTensor<EI>,
    pub x_offset: u32,
    pub y_offset: u32,
    pub stride_x: u32,
    pub stride_y: u32,
    pub shape_x: u32,
    pub shape_y: u32,
    pub batch_offset: u32,
}

#[derive(CubeType)]
/// A view of a tensor that starts reading data from a specified offset.
/// Uses a [`TensorMap`] to actually execute the load.
pub struct MappedTensorReader<EG: Numeric> {
    pub tensor: TensorMap<EG>,
    pub tile_x: u32,
    pub tile_y: u32,
    pub batch: u32,
}

#[derive(CubeType)]
/// A view of a tensor that starts writing data at a specified offset.
/// Ensures safe access by preventing out-of-bounds errors.
/// Includes pre-fetched shapes and strides for optimized performance.
pub struct TensorWriter<EO: Numeric> {
    pub tensor: VirtualTensor<EO, ReadWrite>,
    pub x_offset: u32,
    pub y_offset: u32,
    pub stride_x: u32,
    pub stride_y: u32,
    pub shape_x: u32,
    pub shape_y: u32,
    pub batch_offset: u32,
}

unsafe impl<EG: Numeric> Sync for TensorReader<EG> {}
unsafe impl<EG: Numeric> Send for TensorReader<EG> {}
unsafe impl<EG: Numeric> Sync for MappedTensorReader<EG> {}
unsafe impl<EG: Numeric> Send for MappedTensorReader<EG> {}
unsafe impl<EG: Numeric> Sync for TensorWriter<EG> {}
unsafe impl<EG: Numeric> Send for TensorWriter<EG> {}

#[derive(CubeType)]
/// Contiguous slice wrapper for `memcpy_async` loading.
pub struct Window<EG: Numeric> {
    /// Contiguous slice containing all and only the window's data.
    pub slice: Slice<Line<EG>>,
    /// Number of lines in the slice.
    pub size: u32,
}

#[cube]
impl<EG: Numeric> MappedTensorReader<EG> {
    /// Instantiate a read view over the given tensor map.
    pub fn new(tensor: TensorMap<EG>, tile_x: u32, tile_y: u32, batch: u32) -> Self {
        MappedTensorReader::<EG> {
            tensor,
            tile_x,
            tile_y,
            batch,
        }
    }

    /// Advance the view along the k dimension by a specified offset, `k_offset`.
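    ///
    /// For `Lhs` the k dimension is the tile's column axis (`tile_y`);
    /// for `Rhs` it is the row axis (`tile_x`).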
    pub fn update_view(&mut self, k_offset: u32, #[comptime] ident: Ident) {
        match ident.as_input_ident() {
            InputIdent::Lhs => {
                self.tile_y += k_offset;
            }
            InputIdent::Rhs => {
                self.tile_x += k_offset;
            }
        }
    }
}

#[cube]
impl<EG: Numeric> TensorReader<EG> {
    /// Instantiate a read view over the given tensor, pre-fetching needed strides and shapes
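    ///
    /// # Example
    ///
    /// A minimal sketch of constructing a reader inside a `#[cube]` function;
    /// `lhs`, `row_start`, `k_start` and `batch_offset` are hypothetical
    /// caller-side values:
    ///
    /// ```rust,ignore
    /// let reader = TensorReader::new(lhs, row_start, k_start, batch_offset);
    /// ```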
    pub fn new(tensor: VirtualTensor<EG>, x_offset: u32, y_offset: u32, batch_offset: u32) -> Self {
        let rank = tensor.rank();
        let stride_x = tensor.stride(rank - 2);
        let stride_y = tensor.stride(rank - 1);
        let shape_x = tensor.shape(rank - 2);
        let shape_y = tensor.shape(rank - 1);

        TensorReader::<EG> {
            tensor,
            x_offset,
            y_offset,
            stride_x,
            stride_y,
            shape_x,
            shape_y,
            batch_offset,
        }
    }

    /// Advance the view along the k dimension by a specified offset, `k_offset`.
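    ///
    /// For `Lhs` the k dimension is the column axis (`y_offset`);
    /// for `Rhs` it is the row axis (`x_offset`).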
    pub fn update_view(&mut self, k_offset: u32, #[comptime] ident: InputIdent) {
        match ident {
            InputIdent::Lhs => {
                self.y_offset += k_offset;
            }
            InputIdent::Rhs => {
                self.x_offset += k_offset;
            }
        }
    }

    /// Reads data from the tensor view as a window, i.e. a contiguous slice
    /// of global memory, along with the length of that slice.
    ///
    /// The nominal length of the slice is the width of the tile.
    ///
    /// # Note
    ///
    /// If the slice would be partly out-of-bounds, it is simply shorter.
    /// The caller must do the padding if necessary.
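    ///
    /// For example, with a row-major 16x32 tile and `line_size = 4`, each
    /// window covers one tile row of `32 / 4 = 8` lines, and `nth_window`
    /// selects which of the 16 rows is read.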
    pub fn load_window_in_tile<G: global::GlobalConfig>(
        &self,
        tile: (u32, u32),
        nth_window: u32,
        #[comptime] input_ident: InputIdent,
        #[comptime] config: G,
    ) -> Window<EG> {
        let line_size = config.global_line_size(input_ident);
        let tiling_dimensions = config.tiling_dimensions(input_ident);
        let matrix_layout = config.matrix_layout(input_ident);

        let tile_size_x = tiling_dimensions.tile_shape_row();
        let tile_size_y = tiling_dimensions.tile_shape_col();

        let num_lines_in_window = comptime! {match matrix_layout {
            MatrixLayout::RowMajor => tile_size_y / line_size,
            MatrixLayout::ColMajor => tile_size_x / line_size,
        }};

        self.load_window::<G>(
            nth_window,
            (tile.0 * tile_size_x, tile.1 * tile_size_y),
            num_lines_in_window,
            input_ident,
            config,
        )
    }

    /// Reads data from the tensor view as a window, i.e. a contiguous slice
    /// of global memory.
    ///
    /// The nominal length of the slice is the width of the stage.
    ///
    /// # Note
    ///
    /// If the slice would be partly out-of-bounds, it is simply shorter.
    /// The caller must do the padding if necessary.
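    ///
    /// For example, with a row-major stage of total shape 64x128 and
    /// `line_size = 4`, each window covers one full stage row of
    /// `128 / 4 = 32` lines.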
    pub fn load_window_in_stage<G: global::GlobalConfig>(
        &self,
        nth_window: u32,
        #[comptime] input_ident: InputIdent,
        #[comptime] config: G,
    ) -> Window<EG> {
        let line_size = config.global_line_size(input_ident);
        let tiling_dimensions = config.tiling_dimensions(input_ident);
        let matrix_layout = config.matrix_layout(input_ident);

        let num_lines_in_window = comptime! {match matrix_layout {
            MatrixLayout::RowMajor => tiling_dimensions.total_col() / line_size,
            MatrixLayout::ColMajor => tiling_dimensions.total_row() / line_size,
        }};

        self.load_window::<G>(
            nth_window,
            (0u32, 0u32).runtime(),
            num_lines_in_window,
            input_ident,
            config,
        )
    }

    fn load_window<G: global::GlobalConfig>(
        &self,
        nth_window: u32,
        tile_offsets: (u32, u32),
        #[comptime] num_lines_in_window: u32,
        #[comptime] ident: InputIdent,
        #[comptime] config: G,
    ) -> Window<EG> {
        let line_size = config.global_line_size(ident);
        let matrix_layout = config.matrix_layout(ident);

        let (load_x, load_y) = match matrix_layout {
            MatrixLayout::RowMajor => (nth_window, 0),
            MatrixLayout::ColMajor => (0, nth_window),
        };

        let view_tile_x = tile_offsets.0 + self.x_offset;
        let view_tile_y = tile_offsets.1 + self.y_offset;

        let view_x = view_tile_x + load_x;
        let view_y = view_tile_y + load_y;

        let read_pos =
            (view_x * self.stride_x + view_y * self.stride_y + self.batch_offset) / line_size;

        let (check_h_bounds, view_h, shape_h, check_w_bounds, view_w, shape_w) =
            match config.matrix_layout(ident) {
                MatrixLayout::RowMajor => (
                    config.check_row_bounds(ident),
                    view_x,
                    self.shape_x,
                    config.check_col_bounds(ident),
                    view_y,
                    self.shape_y,
                ),
                MatrixLayout::ColMajor => (
                    config.check_col_bounds(ident),
                    view_y,
                    self.shape_y,
                    config.check_row_bounds(ident),
                    view_x,
                    self.shape_x,
                ),
            };

        // There are 0 lines if out-of-bounds vertically
        let max_lines_in_window = if comptime!(check_h_bounds) {
            num_lines_in_window * u32::cast_from(view_h < shape_h)
        } else {
            num_lines_in_window.runtime()
        };

        // Window is clamped if partially out-of-bounds horizontally
        let size = if comptime!(check_w_bounds) {
            slice_length_clamp(shape_w / line_size, view_w / line_size, max_lines_in_window)
        } else {
            max_lines_in_window
        };

        Window::<EG> {
            slice: self.tensor.as_slice(read_pos, read_pos + size),
            size,
        }
    }

    /// Reads data from the tensor view at the specified tile coordinates (tile_x, tile_y).
    ///
    /// Each unit loads one line in a coalesced manner for improved efficiency.
    /// For row-major tensors, subsequent units read lines horizontally within the tile,
    /// while for column-major tensors, they read lines vertically.
    ///
    /// # Note
    ///
    /// Out-of-bounds reads will be translated to zeros.
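    ///
    /// For example, in a row-major 8x16 tile, `position = 37` maps to
    /// `(load_x, load_y) = (37 / 16, 37 % 16) = (2, 5)` within the tile.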
    pub fn load_coalesced_in_tile<G: global::GlobalConfig>(
        &self,
        tile_x: u32,
        tile_y: u32,
        position: u32,
        #[comptime] input_ident: InputIdent,
        #[comptime] config: G,
    ) -> Line<EG> {
        let tile_shape_x = config.tiling_dimensions(input_ident).tile_shape_row();
        let tile_shape_y = config.tiling_dimensions(input_ident).tile_shape_col();

        let view_tile_x = tile_x * tile_shape_x;
        let view_tile_y = tile_y * tile_shape_y;

        let (load_x, load_y) = match config.matrix_layout(input_ident) {
            MatrixLayout::RowMajor => (position / tile_shape_y, position % tile_shape_y),
            MatrixLayout::ColMajor => (position % tile_shape_x, position / tile_shape_x),
        };

        self.load_coalesced::<G>(
            (load_x + view_tile_x, load_y + view_tile_y),
            input_ident,
            config,
        )
    }

    /// Reads data from the tensor view at the specified index within the whole view,
    /// without regard to tiles.
    ///
    /// Each unit loads one line in a coalesced manner for improved efficiency.
    /// For row-major tensors, subsequent units read lines horizontally within the stage,
    /// while for column-major tensors, they read lines vertically.
    ///
    /// # Note
    ///
    /// Out-of-bounds reads will be translated to zeros.
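    ///
    /// For example, in a row-major 32x64 stage, `position = 70` maps to
    /// `(70 / 64, 70 % 64) = (1, 6)` within the stage.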
    pub fn load_coalesced_in_stage<G: global::GlobalConfig>(
        &self,
        position: u32,
        #[comptime] input_ident: InputIdent,
        #[comptime] config: G,
    ) -> Line<EG> {
        let stage_shape_x = config.tiling_dimensions(input_ident).total_row();
        let stage_shape_y = config.tiling_dimensions(input_ident).total_col();

        let load_offsets = match config.matrix_layout(input_ident) {
            MatrixLayout::RowMajor => (position / stage_shape_y, position % stage_shape_y),
            MatrixLayout::ColMajor => (position % stage_shape_x, position / stage_shape_x),
        };

        self.load_coalesced::<G>(load_offsets, input_ident, config)
    }

    fn load_coalesced<G: global::GlobalConfig>(
        &self,
        load_offsets: (u32, u32),
        #[comptime] input_ident: InputIdent,
        #[comptime] config: G,
    ) -> Line<EG> {
        let line_size = config.global_line_size(input_ident);

        let view_x = load_offsets.0 + self.x_offset;
        let view_y = load_offsets.1 + self.y_offset;

        let read_pos =
            (view_x * self.stride_x + view_y * self.stride_y + self.batch_offset) / line_size;

        match comptime!((
            config.check_row_bounds(input_ident),
            config.check_col_bounds(input_ident)
        )) {
            (true, true) => read_masked::<Line<EG>>(
                view_x < self.shape_x && view_y < self.shape_y,
                self.tensor.as_slice(0, self.tensor.len()),
                read_pos,
                Line::cast_from(0),
            ),
            (true, false) => read_masked::<Line<EG>>(
                view_x < self.shape_x,
                self.tensor.as_slice(0, self.tensor.len()),
                read_pos,
                Line::cast_from(0),
            ),
            (false, true) => read_masked::<Line<EG>>(
                view_y < self.shape_y,
                self.tensor.as_slice(0, self.tensor.len()),
                read_pos,
                Line::cast_from(0),
            ),
            (false, false) => self.tensor.read(read_pos),
        }
    }
}

#[cube]
impl<EG: Numeric> TensorWriter<EG> {
    /// Instantiate a write view over the given tensor, pre-fetching needed strides and shapes
    pub fn new(
        tensor: VirtualTensor<EG, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
        batch_offset: u32,
    ) -> Self {
        let rank = tensor.rank();
        let stride_x = tensor.stride(rank - 2);
        let stride_y = tensor.stride(rank - 1);
        let shape_x = tensor.shape(rank - 2);
        let shape_y = tensor.shape(rank - 1);

        TensorWriter::<EG> {
            tensor,
            x_offset,
            y_offset,
            stride_x,
            stride_y,
            shape_x,
            shape_y,
            batch_offset,
        }
    }

    /// Writes data into the tensor view at the specified tile coordinates (tile_x, tile_y).
    ///
    /// Each unit writes one line in a coalesced manner for improved efficiency, assuming row-major layout.
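    ///
    /// For example, with an output tile of shape 8x16, `unit_id = 21` writes
    /// the line at row `21 / 16 = 1`, column `21 % 16 = 5` inside tile
    /// `(tile_x, tile_y)`.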
    pub fn write_coalesced<ES: Numeric, G: global::GlobalConfig>(
        &mut self,
        tile_x: u32,
        tile_y: u32,
        unit_id: u32,
        value: Line<ES>,
        #[comptime] config: G,
    ) {
        let tiling = config.tiling_dimensions(Ident::Out);

        let view_x =
            tile_x * tiling.tile_shape_row() + unit_id / tiling.tile_shape_col() + self.x_offset;
        let view_y =
            tile_y * tiling.tile_shape_col() + unit_id % tiling.tile_shape_col() + self.y_offset;

        let write_position = (view_x * self.stride_x + view_y * self.stride_y + self.batch_offset)
            / config.global_line_size(Ident::Out);

        match comptime!((
            config.check_row_bounds(Ident::Out),
            config.check_col_bounds(Ident::Out)
        )) {
            (true, true) => {
                if view_x < self.shape_x && view_y < self.shape_y {
                    self.write(write_position, Line::cast_from(value));
                }
            }
            (true, false) => {
                if view_x < self.shape_x {
                    self.write(write_position, Line::cast_from(value));
                }
            }
            (false, true) => {
                if view_y < self.shape_y {
                    self.write(write_position, Line::cast_from(value));
                }
            }
            (false, false) => {
                self.write(write_position, Line::cast_from(value));
            }
        }
    }

    fn write(&mut self, position: u32, value: Line<EG>) {
        self.tensor.write(position, value)
    }
}

#[cube]
/// Gives the largest slice length starting at `offset` that does not exceed `shape`,
/// capped at `max_length`.
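/// For example, `slice_length_clamp(10, 7, 8)` returns `min(10 - 7, 8) = 3`;
/// an `offset` at or past `shape` yields 0.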
fn slice_length_clamp(shape: u32, offset: u32, max_length: u32) -> u32 {
    Min::min(select(shape > offset, shape - offset, 0), max_length)
}