use cubecl::{
prelude::*,
std::tensor::{View, layout::Coords2d},
};
use cubek_std::{MatrixLayout, stage::StageMemoryConfig};
use crate::components::global::memory::GlobalMemoryConfig;
#[cube]
pub fn load_window_in_tile<EG: Numeric>(
view: &View<EG, Coords2d>,
tile: Coords2d,
nth_window: u32,
#[comptime] smem_config: StageMemoryConfig,
#[comptime] gmem_config: GlobalMemoryConfig,
) -> Slice<EG> {
let (tile_row, tile_col) = tile;
let tile_size_row = smem_config.elements_per_tile_along_row;
let tile_size_col = smem_config.elements_per_tile_along_col;
let size = match smem_config.matrix_layout {
MatrixLayout::RowMajor => (1u32, tile_size_col).runtime(),
MatrixLayout::ColMajor => (tile_size_row, 1u32).runtime(),
};
let offset = (tile_row * tile_size_row, tile_col * tile_size_col);
let tile_size = (tile_size_row, tile_size_col).runtime();
load_window(
&view.slice(offset, tile_size),
nth_window,
size,
gmem_config,
)
}
/// Loads the `nth_window`-th window spanning the whole stage described by `smem_config`.
///
/// The window covers one full line of the stage: a single row of
/// `elements_per_stage_along_col()` elements for a row-major stage, or a single column of
/// `elements_per_stage_along_row()` elements for a column-major stage. Where the window
/// starts in `view`, and whether bounds checks apply, is decided by `load_window` from
/// `gmem_config`.
#[cube]
pub fn load_window_in_stage<EG: CubePrimitive>(
    view: &View<EG, Coords2d>,
    nth_window: u32,
    #[comptime] smem_config: StageMemoryConfig,
    #[comptime] gmem_config: GlobalMemoryConfig,
) -> Slice<EG> {
    let window_shape = match smem_config.matrix_layout {
        MatrixLayout::ColMajor => (smem_config.elements_per_stage_along_row(), 1u32).runtime(),
        MatrixLayout::RowMajor => (1u32, smem_config.elements_per_stage_along_col()).runtime(),
    };

    load_window(view, nth_window, window_shape, gmem_config)
}
/// Extracts the `nth_window`-th window of `view`, of extent `size`, as a linear slice.
///
/// The window's origin depends on the global-memory layout. It is `(nth_window, 0)` when
/// `gmem_config.matrix_layout` is row-major, and `(0, nth_window)` when it is column-major.
///
/// The checked `View::slice` is used when `gmem_config` enables row or column bounds
/// checking. Otherwise `View::slice_unchecked` is used.
#[cube]
fn load_window<EG: CubePrimitive>(
    view: &View<EG, Coords2d>,
    nth_window: u32,
    size: Coords2d,
    #[comptime] gmem_config: GlobalMemoryConfig,
) -> Slice<EG> {
    // `nth_window` indexes rows for row-major global memory and columns for column-major.
    let offset = match gmem_config.matrix_layout {
        MatrixLayout::RowMajor => (nth_window, 0),
        MatrixLayout::ColMajor => (0, nth_window),
    };
    // `gmem_config` is a comptime parameter, so this condition is presumably resolved at
    // kernel-expansion time rather than branched on at runtime — keep it in this form.
    if gmem_config.check_row_bounds || gmem_config.check_col_bounds {
        view.slice(offset, size).to_linear_slice()
    } else {
        view.slice_unchecked(offset, size).to_linear_slice()
    }
}