Skip to main content

cubek_std/tile/compute/
mask.rs

1use cubecl;
2use cubecl::{prelude::*, std::tensor::layout::Coords2d};
3
4use crate::tile::{
5    Tile, TileExpand, TileScope,
6    data::{InnerLayout, StridedTile, UnitTileLayout, WhiteboxFragmentLayout},
7};
8
#[cube]
/// Minimal mask abstraction used by row-wise tile operations.
/// Returns `true` when the element at `local_pos` should be treated as masked
/// (i.e. driven to -inf by `Tile::scale_and_mask`).
pub trait Mask: CubeType {
    /// Returns `true` if the element addressed by `local_pos` is masked.
    ///
    /// `local_pos` is a `(row, col)` pair in the caller's per-unit coordinates;
    /// use `mask_layout_absolute_pos` to translate it to a tile position.
    fn should_mask(&self, local_pos: Coords2d) -> bool;
}
16
/// Layout of an attention-style mask fragment across the units of a plane.
/// Purely comptime — all variants carry only comptime data.
///
/// Convenience constructors: [`MaskLayout::unit`] and
/// [`MaskLayout::whitebox_fragment`].
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum MaskLayout {
    /// Each unit owns a full row-major copy of the tile.
    Unit(UnitTileLayout),
    /// The tile is fragmented across plane units, with the layout described by
    /// [`WhiteboxFragmentLayout`].
    WhiteboxFragment(WhiteboxFragmentLayout),
}
27
28impl MaskLayout {
29    pub const fn unit(num_rows: u32, num_cols: u32) -> MaskLayout {
30        MaskLayout::Unit(UnitTileLayout {
31            num_rows,
32            num_cols,
33            transposed_load: false,
34        })
35    }
36
37    pub const fn whitebox_fragment(
38        tile_shape: Coords2d,
39        plane_dim: u32,
40        inner_layout: InnerLayout,
41    ) -> MaskLayout {
42        let total_elements = tile_shape.0 * tile_shape.1;
43        let elements_per_unit = total_elements.div_ceil(plane_dim);
44        let (num_rows_per_unit, num_cols_per_unit) = match inner_layout {
45            InnerLayout::Contiguous => (1u32, elements_per_unit),
46            InnerLayout::SplitRows => (2u32, elements_per_unit / 2u32),
47        };
48        let unit_size = (num_rows_per_unit, num_cols_per_unit);
49        let num_units_per_row = tile_shape.1 / unit_size.1;
50
51        MaskLayout::WhiteboxFragment(WhiteboxFragmentLayout {
52            total_size: tile_shape,
53            unit_size,
54            num_units_per_row,
55            plane_dim,
56        })
57    }
58}
59
60#[cube]
61/// Returns how many units in a plane participate in the same row.
62pub fn mask_layout_num_units_per_row(#[comptime] layout: MaskLayout) -> comptime_type!(u32) {
63    match layout {
64        MaskLayout::Unit(_) => 1u32,
65        MaskLayout::WhiteboxFragment(l) => comptime!(l.total_size.1 / l.unit_size.1),
66    }
67}
68
69#[cube]
70/// Maps a per-unit `(row, col)` to its absolute position within the tile.
71pub fn mask_layout_absolute_pos(#[comptime] layout: MaskLayout, local_pos: Coords2d) -> Coords2d {
72    match layout {
73        MaskLayout::Unit(_) => local_pos,
74        MaskLayout::WhiteboxFragment(l) => {
75            let abs_row_index = {
76                let row_0 = UNIT_POS_X / l.num_units_per_row;
77                let row_jump = comptime!(l.plane_dim / l.num_units_per_row);
78                local_pos.0 * row_jump + row_0
79            };
80            let abs_col_index = l.unit_size.1 * (UNIT_POS_X % l.num_units_per_row) + local_pos.1;
81            (abs_row_index, abs_col_index)
82        }
83    }
84}
85
86#[cube]
87impl<E: Numeric, Sc: TileScope, IO: SliceVisibility> Mask for Tile<E, Sc, IO> {
88    fn should_mask(&self, local_pos: Coords2d) -> bool {
89        match self {
90            Tile::Unit(t) => {
91                bool::cast_from(t.data[(local_pos.0 * t.layout.num_cols + local_pos.1) as usize])
92            }
93            Tile::WhiteboxFragment(t) => bool::cast_from(
94                t.array[(local_pos.0 * t.layout.unit_size.1 + local_pos.1) as usize],
95            ),
96            _ => panic!(
97                "Mask::should_mask is only defined for Tile::Unit and Tile::WhiteboxFragment"
98            ),
99        }
100    }
101}
102
#[cube]
impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
    /// Loads the data from an external strided tile into the inner storage of a
    /// `Tile::Unit` or `Tile::WhiteboxFragment`. Used to materialize a mask fragment.
    ///
    /// Forwards to the variant's own `load_from_strided_tile`. Any other
    /// variant reaches `panic!`.
    pub fn load_mask_from_strided_tile<E: Numeric, ES: Size>(&mut self, tile: &StridedTile<E, ES>) {
        match self {
            Tile::Unit(t) => t.load_from_strided_tile::<E, ES>(tile),
            Tile::WhiteboxFragment(t) => t.load_from_strided_tile::<E, ES>(tile),
            _ => panic!("load_mask_from_strided_tile: unsupported tile variant"),
        }
    }
}