cubek_std/tile/compute/
elementwise.rs1use cubecl;
2use cubecl::prelude::*;
3
4use crate::tile::compute::{Mask, MaskExpand};
5use crate::tile::{Plane, Tile, TileExpand};
6
7#[cube]
12impl<E: Float> Tile<E, Plane, ReadWrite> {
13 pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
16 match self {
17 Tile::Unit(t) => t.scale_and_mask::<M>(scale, mask),
18 Tile::WhiteboxFragment(t) => t.scale_and_mask::<M>(scale, mask),
19 Tile::Bounce(b) => b.fragment.scale_and_mask::<M>(scale, mask),
20 Tile::Register(t) => {
21 let m = comptime!(t.config.tile_size.m());
22 let n = comptime!(t.config.tile_size.n());
23 for r in 0..m {
24 let row_offset = r * n;
25 for c in 0..n {
26 let idx = (row_offset + c) as usize;
27 t.data[idx] = t.data[idx] * scale
28 + E::cast_from(mask.should_mask((r, c))) * E::min_value();
29 }
30 }
31 }
32 _ => panic!("scale_and_mask: unsupported tile variant"),
33 }
34 }
35
36 pub fn fill_zero(&mut self) {
38 match self {
39 Tile::Register(t) => {
40 let m = comptime!(t.config.tile_size.m());
41 let n = comptime!(t.config.tile_size.n());
42 for i in 0..m * n {
43 t.data[i as usize] = E::from_int(0);
44 }
45 }
46 Tile::Unit(t) => t.zero(),
47 Tile::WhiteboxFragment(t) => t.zero(),
48 Tile::Bounce(b) => {
49 cubecl::cmma::fill(&b.cmma.matrix, E::from_int(0));
50 }
51 Tile::Cmma(t) => {
52 cubecl::cmma::fill(&t.matrix, E::from_int(0));
53 }
54 _ => panic!("fill_zero: unsupported tile variant"),
55 }
56 }
57}