cubecl_matmul/components/global/memory/
window.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::{View, layout::Coords2d};
4
5use crate::components::{MatrixLayout, global::memory::GlobalMemoryConfig};
6
7#[cube]
17pub fn load_window_in_tile<EG: Numeric>(
18 view: &View<Line<EG>, Coords2d>,
19 tile: Coords2d,
20 nth_window: u32,
21 #[comptime] config: GlobalMemoryConfig,
22) -> Slice<Line<EG>> {
23 let (tile_row, tile_col) = tile;
24 let tile_size_row = config.elements_in_tile_row;
25 let tile_size_col = config.elements_in_tile_col;
26
27 let size = match config.matrix_layout {
28 MatrixLayout::RowMajor => (1u32, tile_size_col).runtime(),
29 MatrixLayout::ColMajor => (tile_size_row, 1u32).runtime(),
30 };
31
32 let offset = (tile_row * tile_size_row, tile_col * tile_size_col);
33 let tile_size = (tile_size_row, tile_size_col).runtime();
34
35 load_window(&view.slice(offset, tile_size), nth_window, size, config)
36}
37
38#[cube]
47pub fn load_window_in_stage<EG: Numeric>(
48 view: &View<Line<EG>, Coords2d>,
49 nth_window: u32,
50 #[comptime] config: GlobalMemoryConfig,
51) -> Slice<Line<EG>> {
52 let size = match config.matrix_layout {
53 MatrixLayout::RowMajor => (1u32, config.elements_in_stage_col).runtime(),
54 MatrixLayout::ColMajor => (config.elements_in_stage_row, 1u32).runtime(),
55 };
56
57 load_window(view, nth_window, size, config)
58}
59
60#[cube]
61fn load_window<EG: Numeric>(
62 view: &View<Line<EG>, Coords2d>,
63 nth_window: u32,
64 size: Coords2d,
65 #[comptime] config: GlobalMemoryConfig,
66) -> Slice<Line<EG>> {
67 let offset = match config.matrix_layout {
68 MatrixLayout::RowMajor => (nth_window, 0),
69 MatrixLayout::ColMajor => (0, nth_window),
70 };
71
72 if comptime![config.check_row_bounds || config.check_col_bounds] {
73 view.slice(offset, size).to_linear_slice()
74 } else {
75 view.slice_unchecked(offset, size).to_linear_slice()
76 }
77}