cubek_std/tile/ops/
mask.rs1use cubecl;
2use cubecl::{prelude::*, std::tensor::layout::Coords2d};
3
4use crate::tile::scope::Scope;
5use crate::tile::variants::{InnerLayout, LocalTileLayout, UnitTileLayout};
6use crate::tile::{Tile, TileExpand};
7
/// Per-element masking query over a tile held by the calling unit.
///
/// Implementors answer, for a coordinate in their local view of the tile,
/// whether the element at that coordinate is masked.
#[cube]
pub trait Mask: CubeType {
    /// Returns `true` if the element at `local_pos` should be masked.
    ///
    /// `local_pos` is `(row, col)` in the implementor's local coordinates
    /// (the `Tile` implementation in this file indexes its own storage
    /// row-major with it). For layouts where each unit only holds part of
    /// the tile, `mask_layout_absolute_pos` translates a local coordinate
    /// into a tile coordinate.
    fn should_mask(&self, local_pos: Coords2d) -> bool;
}
15
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
/// Comptime description of how a mask tile is stored.
///
/// Passed as a `#[comptime]` argument to `mask_layout_num_units_per_row` and
/// `mask_layout_absolute_pos`; build it with `MaskLayout::unit` or
/// `MaskLayout::local`.
pub enum MaskLayout {
    /// The tile is addressed directly: `mask_layout_absolute_pos` returns
    /// local coordinates unchanged and `mask_layout_num_units_per_row`
    /// reports 1.
    Unit(UnitTileLayout),
    /// The tile is spread over the units of a plane; each unit owns a
    /// `unit_size` sub-block placed row-major across the tile.
    Local(LocalTileLayout),
}
26
27impl MaskLayout {
28 pub const fn unit(num_rows: u32, num_cols: u32) -> MaskLayout {
29 MaskLayout::Unit(UnitTileLayout {
30 num_rows,
31 num_cols,
32 transposed_load: false,
33 })
34 }
35
36 pub const fn local(
37 tile_shape: Coords2d,
38 plane_dim: u32,
39 inner_layout: InnerLayout,
40 ) -> MaskLayout {
41 let total_elements = tile_shape.0 * tile_shape.1;
42 let elements_per_unit = total_elements.div_ceil(plane_dim);
43 let (num_rows_per_unit, num_cols_per_unit) = match inner_layout {
44 InnerLayout::Contiguous => (1u32, elements_per_unit),
45 InnerLayout::SplitRows => (2u32, elements_per_unit / 2u32),
46 };
47 let unit_size = (num_rows_per_unit, num_cols_per_unit);
48 let num_units_per_row = tile_shape.1 / unit_size.1;
49
50 MaskLayout::Local(LocalTileLayout {
51 total_size: tile_shape,
52 unit_size,
53 num_units_per_row,
54 plane_dim,
55 })
56 }
57}
58
/// Number of units that share one row of the tile under `layout`.
///
/// `Unit` layouts report 1. `Local` layouts report
/// `total_size.1 / unit_size.1`, evaluated at comptime; for layouts built by
/// `MaskLayout::local` this is the same value stored in `num_units_per_row`.
#[cube]
pub fn mask_layout_num_units_per_row(#[comptime] layout: MaskLayout) -> comptime_type!(u32) {
    match layout {
        MaskLayout::Unit(_) => 1u32,
        MaskLayout::Local(l) => comptime!(l.total_size.1 / l.unit_size.1),
    }
}
67
68#[cube]
69pub fn mask_layout_absolute_pos(#[comptime] layout: MaskLayout, local_pos: Coords2d) -> Coords2d {
71 match layout {
72 MaskLayout::Unit(_) => local_pos,
73 MaskLayout::Local(l) => {
74 let abs_row_index = {
75 let row_0 = UNIT_POS_X / l.num_units_per_row;
76 let row_jump = comptime!(l.plane_dim / l.num_units_per_row);
77 local_pos.0 * row_jump + row_0
78 };
79 let abs_col_index = l.unit_size.1 * (UNIT_POS_X % l.num_units_per_row) + local_pos.1;
80 (abs_row_index, abs_col_index)
81 }
82 }
83}
84
85#[cube]
86impl<E: Numeric, Sc: Scope, IO: SliceVisibility> Mask for Tile<E, Sc, IO> {
87 fn should_mask(&self, local_pos: Coords2d) -> bool {
88 match self {
89 Tile::Unit(t) => {
90 bool::cast_from(t.data[(local_pos.0 * t.layout.num_cols + local_pos.1) as usize])
91 }
92 Tile::Local(t) => bool::cast_from(
93 t.array[(local_pos.0 * t.layout.unit_size.1 + local_pos.1) as usize],
94 ),
95 _ => panic!("Mask::should_mask is only defined for Tile::Unit and Tile::Local"),
96 }
97 }
98}
99
#[cube]
impl<N: Numeric, Sc: Scope> Tile<N, Sc, ReadWrite> {
    /// Fills this mask tile from a strided tile.
    ///
    /// Dispatches to the variant's own `load_from_strided_tile`, forwarding the
    /// source tile's generic parameters `E` and `ES` unchanged.
    ///
    /// # Panics
    /// If `self` is neither `Tile::Unit` nor `Tile::Local`.
    pub fn load_mask_from_strided_tile<E: Numeric, ES: Size>(
        &mut self,
        tile: &crate::tile::StridedTile<E, ES>,
    ) {
        match self {
            Tile::Unit(t) => t.load_from_strided_tile::<E, ES>(tile),
            Tile::Local(t) => t.load_from_strided_tile::<E, ES>(tile),
            _ => panic!("load_mask_from_strided_tile: unsupported tile variant"),
        }
    }
}