use cubecl;
use cubecl::{prelude::*, std::tensor::layout::Coords2d};
use crate::tile::{
Tile, TileExpand, TileScope,
data::{InnerLayout, StridedTile, UnitTileLayout, WhiteboxFragmentLayout},
};
/// A per-element mask that can be queried at tile-local coordinates.
#[cube]
pub trait Mask: CubeType {
/// Returns whether the element at `local_pos` should be masked.
///
/// `local_pos` is a `(row, col)` pair: the `Tile` implementation in this
/// file linearizes it as `row * <columns per row> + col`.
fn should_mask(&self, local_pos: Coords2d) -> bool;
}
/// Layout descriptor for a mask tile.
///
/// Each variant wraps the layout struct of the corresponding tile kind.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum MaskLayout {
// Mask described by a `UnitTileLayout` (rows, columns, transposed-load flag).
Unit(UnitTileLayout),
// Mask described by a `WhiteboxFragmentLayout` (total size, per-unit size,
// units per row, plane dimension).
WhiteboxFragment(WhiteboxFragmentLayout),
}
impl MaskLayout {
    /// Creates a [`MaskLayout::Unit`] covering `num_rows` x `num_cols` elements.
    ///
    /// The resulting layout always has `transposed_load` set to `false`.
    pub const fn unit(num_rows: u32, num_cols: u32) -> MaskLayout {
        let layout = UnitTileLayout {
            num_rows,
            num_cols,
            transposed_load: false,
        };
        Self::Unit(layout)
    }

    /// Creates a [`MaskLayout::WhiteboxFragment`] for a tile of `tile_shape`
    /// (`(rows, cols)`) whose elements are divided among `plane_dim` units.
    ///
    /// Each unit is assigned `ceil(rows * cols / plane_dim)` elements, arranged
    /// according to `inner_layout`:
    /// - `InnerLayout::Contiguous`: one row holding all of the unit's elements.
    /// - `InnerLayout::SplitRows`: two rows, each holding half of the unit's
    ///   elements (integer division).
    ///
    /// The number of units per row is the tile's column count divided by the
    /// unit's column count (integer division).
    ///
    /// # Panics
    ///
    /// Panics on division by zero when `plane_dim` is `0`, or when the derived
    /// per-unit column count is `0` (an empty tile, or `SplitRows` with fewer
    /// than two elements per unit).
    pub const fn whitebox_fragment(
        tile_shape: Coords2d,
        plane_dim: u32,
        inner_layout: InnerLayout,
    ) -> MaskLayout {
        // How many elements each unit of the plane is responsible for.
        let elements_per_unit = (tile_shape.0 * tile_shape.1).div_ceil(plane_dim);

        // (rows, cols) owned by a single unit.
        let unit_size: Coords2d = match inner_layout {
            InnerLayout::Contiguous => (1, elements_per_unit),
            InnerLayout::SplitRows => (2, elements_per_unit / 2),
        };

        Self::WhiteboxFragment(WhiteboxFragmentLayout {
            total_size: tile_shape,
            unit_size,
            num_units_per_row: tile_shape.1 / unit_size.1,
            plane_dim,
        })
    }
}
/// Compile-time number of units that span one row of the mask tile.
///
/// A `Unit` layout always yields `1`. A `WhiteboxFragment` layout yields its
/// total column count divided by the per-unit column count; this is
/// recomputed from the sizes rather than read from `num_units_per_row`
/// (the two agree for layouts built by `MaskLayout::whitebox_fragment`).
#[cube]
pub fn mask_layout_num_units_per_row(#[comptime] layout: MaskLayout) -> comptime_type!(u32) {
    match layout {
        MaskLayout::WhiteboxFragment(fragment) => comptime!(fragment.total_size.1 / fragment.unit_size.1),
        MaskLayout::Unit(_) => 1u32,
    }
}
/// Maps a unit-local `(row, col)` position to its position within the whole tile.
///
/// For a `Unit` layout the local position is returned unchanged.
///
/// For a `WhiteboxFragment` layout the position is derived from `UNIT_POS_X`:
/// - row: `local_pos.0 * (plane_dim / num_units_per_row) + UNIT_POS_X / num_units_per_row`
/// - col: `unit_size.1 * (UNIT_POS_X % num_units_per_row) + local_pos.1`
#[cube]
pub fn mask_layout_absolute_pos(#[comptime] layout: MaskLayout, local_pos: Coords2d) -> Coords2d {
    match layout {
        MaskLayout::Unit(_) => local_pos,
        MaskLayout::WhiteboxFragment(fragment) => {
            // Row of this unit's first local row.
            let first_row = UNIT_POS_X / fragment.num_units_per_row;
            // Distance, in tile rows, between two consecutive local rows of the same unit.
            let row_stride = comptime!(fragment.plane_dim / fragment.num_units_per_row);
            let row = local_pos.0 * row_stride + first_row;

            // Index of this unit among the units sharing its row.
            let unit_in_row = UNIT_POS_X % fragment.num_units_per_row;
            let col = fragment.unit_size.1 * unit_in_row + local_pos.1;

            (row, col)
        }
    }
}
#[cube]
impl<E: Numeric, Sc: TileScope, IO: SliceVisibility> Mask for Tile<E, Sc, IO> {
    /// Reads the mask value stored at `local_pos` and casts it to `bool`.
    ///
    /// The position is linearized row-major: `row * row_width + col`, where
    /// `row_width` is `layout.num_cols` for `Tile::Unit` and
    /// `layout.unit_size.1` for `Tile::WhiteboxFragment`.
    ///
    /// Panics for every other `Tile` variant.
    fn should_mask(&self, local_pos: Coords2d) -> bool {
        match self {
            Tile::Unit(unit) => {
                let offset = local_pos.0 * unit.layout.num_cols + local_pos.1;
                bool::cast_from(unit.data[offset as usize])
            }
            Tile::WhiteboxFragment(fragment) => {
                let offset = local_pos.0 * fragment.layout.unit_size.1 + local_pos.1;
                bool::cast_from(fragment.array[offset as usize])
            }
            _ => panic!(
                "Mask::should_mask is only defined for Tile::Unit and Tile::WhiteboxFragment"
            ),
        }
    }
}
#[cube]
impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
    /// Fills this tile with mask data taken from `tile`.
    ///
    /// Delegates to the inner tile's `load_from_strided_tile` for
    /// `Tile::Unit` and `Tile::WhiteboxFragment`; panics for every other
    /// variant.
    pub fn load_mask_from_strided_tile<E: Numeric, ES: Size>(&mut self, tile: &StridedTile<E, ES>) {
        match self {
            Tile::Unit(unit) => unit.load_from_strided_tile::<E, ES>(tile),
            Tile::WhiteboxFragment(fragment) => fragment.load_from_strided_tile::<E, ES>(tile),
            _ => panic!("load_mask_from_strided_tile: unsupported tile variant"),
        }
    }
}