cubek_std/tile/compute/
mask.rs1use cubecl;
2use cubecl::{prelude::*, std::tensor::layout::Coords2d};
3
4use crate::tile::{
5 Tile, TileExpand, TileScope,
6 data::{InnerLayout, StridedTile, UnitTileLayout, WhiteboxFragmentLayout},
7};
8
9#[cube]
10pub trait Mask: CubeType {
14 fn should_mask(&self, local_pos: Coords2d) -> bool;
15}
16
17#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
20pub enum MaskLayout {
21 Unit(UnitTileLayout),
23 WhiteboxFragment(WhiteboxFragmentLayout),
26}
27
28impl MaskLayout {
29 pub const fn unit(num_rows: u32, num_cols: u32) -> MaskLayout {
30 MaskLayout::Unit(UnitTileLayout {
31 num_rows,
32 num_cols,
33 transposed_load: false,
34 })
35 }
36
37 pub const fn whitebox_fragment(
38 tile_shape: Coords2d,
39 plane_dim: u32,
40 inner_layout: InnerLayout,
41 ) -> MaskLayout {
42 let total_elements = tile_shape.0 * tile_shape.1;
43 let elements_per_unit = total_elements.div_ceil(plane_dim);
44 let (num_rows_per_unit, num_cols_per_unit) = match inner_layout {
45 InnerLayout::Contiguous => (1u32, elements_per_unit),
46 InnerLayout::SplitRows => (2u32, elements_per_unit / 2u32),
47 };
48 let unit_size = (num_rows_per_unit, num_cols_per_unit);
49 let num_units_per_row = tile_shape.1 / unit_size.1;
50
51 MaskLayout::WhiteboxFragment(WhiteboxFragmentLayout {
52 total_size: tile_shape,
53 unit_size,
54 num_units_per_row,
55 plane_dim,
56 })
57 }
58}
59
60#[cube]
61pub fn mask_layout_num_units_per_row(#[comptime] layout: MaskLayout) -> comptime_type!(u32) {
63 match layout {
64 MaskLayout::Unit(_) => 1u32,
65 MaskLayout::WhiteboxFragment(l) => comptime!(l.total_size.1 / l.unit_size.1),
66 }
67}
68
69#[cube]
70pub fn mask_layout_absolute_pos(#[comptime] layout: MaskLayout, local_pos: Coords2d) -> Coords2d {
72 match layout {
73 MaskLayout::Unit(_) => local_pos,
74 MaskLayout::WhiteboxFragment(l) => {
75 let abs_row_index = {
76 let row_0 = UNIT_POS_X / l.num_units_per_row;
77 let row_jump = comptime!(l.plane_dim / l.num_units_per_row);
78 local_pos.0 * row_jump + row_0
79 };
80 let abs_col_index = l.unit_size.1 * (UNIT_POS_X % l.num_units_per_row) + local_pos.1;
81 (abs_row_index, abs_col_index)
82 }
83 }
84}
85
86#[cube]
87impl<E: Numeric, Sc: TileScope, IO: SliceVisibility> Mask for Tile<E, Sc, IO> {
88 fn should_mask(&self, local_pos: Coords2d) -> bool {
89 match self {
90 Tile::Unit(t) => {
91 bool::cast_from(t.data[(local_pos.0 * t.layout.num_cols + local_pos.1) as usize])
92 }
93 Tile::WhiteboxFragment(t) => bool::cast_from(
94 t.array[(local_pos.0 * t.layout.unit_size.1 + local_pos.1) as usize],
95 ),
96 _ => panic!(
97 "Mask::should_mask is only defined for Tile::Unit and Tile::WhiteboxFragment"
98 ),
99 }
100 }
101}
102
103#[cube]
104impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
105 pub fn load_mask_from_strided_tile<E: Numeric, ES: Size>(&mut self, tile: &StridedTile<E, ES>) {
108 match self {
109 Tile::Unit(t) => t.load_from_strided_tile::<E, ES>(tile),
110 Tile::WhiteboxFragment(t) => t.load_from_strided_tile::<E, ES>(tile),
111 _ => panic!("load_mask_from_strided_tile: unsupported tile variant"),
112 }
113 }
114}