// cubecl_matmul/components/global/memory/window.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::{View, layout::Coords2d};
4
5use crate::components::{MatrixLayout, global::memory::GlobalMemoryConfig};
6
7/// Reads data from the tensor view as a window, i.e. a slice of global memory
8/// Also returns the length of the slice
9///
10/// The length of the slice is the width of the tile
11///
12/// # Note
13///
14/// If the slice would be partly out-of-bounds, it will simply be shorter.
15/// The caller must do the padding if necessary.
16#[cube]
17pub fn load_window_in_tile<EG: Numeric>(
18    view: &View<Line<EG>, Coords2d>,
19    tile: Coords2d,
20    nth_window: u32,
21    #[comptime] config: GlobalMemoryConfig,
22) -> Slice<Line<EG>> {
23    let (tile_row, tile_col) = tile;
24    let tile_size_row = config.elements_in_tile_row;
25    let tile_size_col = config.elements_in_tile_col;
26
27    let size = match config.matrix_layout {
28        MatrixLayout::RowMajor => (1u32, tile_size_col).runtime(),
29        MatrixLayout::ColMajor => (tile_size_row, 1u32).runtime(),
30    };
31
32    let offset = (tile_row * tile_size_row, tile_col * tile_size_col);
33    let tile_size = (tile_size_row, tile_size_col).runtime();
34
35    load_window(&view.slice(offset, tile_size), nth_window, size, config)
36}
37
38/// Reads data from the tensor view as a window, i.e. a slice of global memory
39///
40/// The length of the slice is the width of the tile
41///
42/// # Note
43///
44/// If the slice would be partly out-of-bounds, it will simply be shorter.
45/// The caller must do the padding if necessary.
46#[cube]
47pub fn load_window_in_stage<EG: Numeric>(
48    view: &View<Line<EG>, Coords2d>,
49    nth_window: u32,
50    #[comptime] config: GlobalMemoryConfig,
51) -> Slice<Line<EG>> {
52    let size = match config.matrix_layout {
53        MatrixLayout::RowMajor => (1u32, config.elements_in_stage_col).runtime(),
54        MatrixLayout::ColMajor => (config.elements_in_stage_row, 1u32).runtime(),
55    };
56
57    load_window(view, nth_window, size, config)
58}
59
60#[cube]
61fn load_window<EG: Numeric>(
62    view: &View<Line<EG>, Coords2d>,
63    nth_window: u32,
64    size: Coords2d,
65    #[comptime] config: GlobalMemoryConfig,
66) -> Slice<Line<EG>> {
67    let offset = match config.matrix_layout {
68        MatrixLayout::RowMajor => (nth_window, 0),
69        MatrixLayout::ColMajor => (0, nth_window),
70    };
71
72    if comptime![config.check_row_bounds || config.check_col_bounds] {
73        view.slice(offset, size).to_linear_slice()
74    } else {
75        view.slice_unchecked(offset, size).to_linear_slice()
76    }
77}