Skip to main content

cubek_std/tile/compute/
elementwise.rs

1use cubecl;
2use cubecl::prelude::*;
3
4use crate::tile::compute::{Mask, MaskExpand};
5use crate::tile::{Plane, Tile, TileExpand};
6
7/// Element-wise tile operations on `Tile<E, Plane, ReadWrite>`. Unlike the
8/// row-wise primitives in [`crate::tile::compute::rowwise`], these touch every
9/// element with no row-axis structure: a uniform scalar scale, a per-element
10/// mask bool, or a whole-tile fill.
11#[cube]
12impl<E: Float> Tile<E, Plane, ReadWrite> {
13    /// Multiplies each element by `scale` and adds `-inf` at masked positions.
14    /// `scale` is a scalar; `mask.should_mask((r, c))` is element-wise.
15    pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
16        match self {
17            Tile::Unit(t) => t.scale_and_mask::<M>(scale, mask),
18            Tile::WhiteboxFragment(t) => t.scale_and_mask::<M>(scale, mask),
19            Tile::Bounce(b) => b.fragment.scale_and_mask::<M>(scale, mask),
20            Tile::Register(t) => {
21                let m = comptime!(t.config.tile_size.m());
22                let n = comptime!(t.config.tile_size.n());
23                for r in 0..m {
24                    let row_offset = r * n;
25                    for c in 0..n {
26                        let idx = (row_offset + c) as usize;
27                        t.data[idx] = t.data[idx] * scale
28                            + E::cast_from(mask.should_mask((r, c))) * E::min_value();
29                    }
30                }
31            }
32            _ => panic!("scale_and_mask: unsupported tile variant"),
33        }
34    }
35
36    /// Zeros every element in the tile.
37    pub fn fill_zero(&mut self) {
38        match self {
39            Tile::Register(t) => {
40                let m = comptime!(t.config.tile_size.m());
41                let n = comptime!(t.config.tile_size.n());
42                for i in 0..m * n {
43                    t.data[i as usize] = E::from_int(0);
44                }
45            }
46            Tile::Unit(t) => t.zero(),
47            Tile::WhiteboxFragment(t) => t.zero(),
48            Tile::Bounce(b) => {
49                cubecl::cmma::fill(&b.cmma.matrix, E::from_int(0));
50            }
51            Tile::Cmma(t) => {
52                cubecl::cmma::fill(&t.matrix, E::from_int(0));
53            }
54            _ => panic!("fill_zero: unsupported tile variant"),
55        }
56    }
57}