Skip to main content

cubek_std/tile/variants/
unit_tile.rs

1use cubecl::{
2    prelude::*,
3    {self},
4};
5
6use crate::tile::ops::{LOGIT_MASKED, Mask, MaskExpand, RowWise};
7use crate::tile::scope::Scope;
8use crate::tile::{StridedTile, Tile};
9
/// A tile owned entirely by a single unit, stored as one flat row-major array.
///
/// Element (`row`, `col`) lives at index `row * layout.num_cols + col` of `data`.
#[derive(CubeType)]
pub struct UnitTile<E: Numeric> {
    /// Flat element storage of length `layout.num_rows * layout.num_cols`
    /// (sized by `UnitTile::new`), indexed row-major.
    pub data: Array<E>,
    /// Shape and load orientation of the tile; a comptime field, so it is known
    /// when the kernel is expanded rather than read at runtime.
    #[cube(comptime)]
    pub layout: UnitTileLayout,
}
16
/// Comptime description of a [`UnitTile`]'s shape.
///
/// The unit tile itself is always stored row-major. When the source being
/// loaded is column-major, set `transposed_load` so the load path transposes
/// it on the way in.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct UnitTileLayout {
    /// Number of rows in the unit tile.
    pub num_rows: u32,
    /// Number of columns in the unit tile (the row stride of its storage).
    pub num_cols: u32,
    /// Whether loads from a strided tile should transpose the source.
    pub transposed_load: bool,
}

impl UnitTileLayout {
    /// Builds a layout from its three components; usable in `const` contexts.
    pub const fn new(num_rows: u32, num_cols: u32, transposed_load: bool) -> Self {
        Self {
            num_rows,
            num_cols,
            transposed_load,
        }
    }
}
34
35#[cube]
36impl<E: Numeric> UnitTile<E> {
37    pub fn new(#[comptime] layout: UnitTileLayout) -> UnitTile<E> {
38        let data = Array::<E>::new(comptime!(layout.num_rows * layout.num_cols) as usize);
39        UnitTile::<E> { data, layout }
40    }
41
42    pub fn zero(&mut self) {
43        for i in 0..self.layout.num_rows * self.layout.num_cols {
44            self.data[i as usize] = E::from_int(0);
45        }
46    }
47
48    pub fn get(&self, row: u32, col: u32) -> E {
49        self.data[(row * self.layout.num_cols + col) as usize]
50    }
51
52    pub fn accumulate(&mut self, row: u32, col: u32, val: E) {
53        self.data[(row * self.layout.num_cols + col) as usize] += val;
54    }
55
56    pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
57        for r in 0..self.layout.num_rows as usize {
58            let row_offset = r as u32 * self.layout.num_cols;
59            for c in 0..self.layout.num_cols {
60                let index = row_offset + c;
61                self.data[index as usize] = self.data[index as usize] * scale.vals[r];
62            }
63        }
64    }
65
66    pub fn rowwise_max(&self) -> RowWise<E> {
67        let num_rows = comptime!(self.layout.num_rows) as usize;
68        let num_cols = comptime!(self.layout.num_cols) as usize;
69        let mut vals = Array::new(num_rows);
70
71        for r in 0..num_rows {
72            let row_offset = r * num_cols;
73            let mut val = E::min_value();
74
75            for c in 0..num_cols {
76                let index = row_offset + c;
77                val = max(val, self.data[index]);
78            }
79
80            vals[r] = val;
81        }
82
83        RowWise::<E> { num_rows, vals }
84    }
85
86    pub fn rowwise_sum(&self) -> RowWise<E> {
87        let num_rows = comptime!(self.layout.num_rows) as usize;
88        let num_cols = comptime!(self.layout.num_cols) as usize;
89        let mut vals = Array::new(num_rows);
90
91        for r in 0..num_rows {
92            let row_offset = r * num_cols;
93            let mut val = E::from_int(0);
94
95            for c in 0..num_cols {
96                let index = row_offset + c;
97                val += self.data[index];
98            }
99
100            vals[r] = val;
101        }
102
103        RowWise::<E> { num_rows, vals }
104    }
105
106    // TODO find a way to have this not necessary if E == E2
107    // TODO even if E != E2 it could be written as output to UnitTile::exp_diff rather than exp_diff being inplace
108    pub fn copy_from<E2: Numeric>(&mut self, other: &UnitTile<E2>) {
109        // Assume layouts are the same
110
111        for r in 0..self.layout.num_rows as usize {
112            let row_offset = r as u32 * self.layout.num_cols;
113            for c in 0..self.layout.num_cols {
114                let index = row_offset + c;
115                self.data[index as usize] = E::cast_from(other.data[index as usize]);
116            }
117        }
118    }
119
120    pub fn load_from_strided_tile<E2: Numeric, ES: Size>(&mut self, tile: &StridedTile<E2, ES>) {
121        if comptime!(self.layout.transposed_load) {
122            strided_tile_to_transposed_unit_tile(tile, self)
123        } else {
124            strided_tile_to_unit_tile(tile, self)
125        }
126    }
127
128    pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
129        for r in 0..self.layout.num_rows {
130            let row_offset = r * self.layout.num_cols;
131            for c in 0..self.layout.num_cols {
132                let index = row_offset + c;
133                self.data[index as usize] = self.data[index as usize] * scale
134                    + E::cast_from(mask.should_mask((r, c))) * E::min_value();
135            }
136        }
137    }
138}
139
140#[cube]
141impl<E: Float> UnitTile<E> {
142    pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
143        let num_rows = comptime!(self.layout.num_rows) as usize;
144        let num_cols = comptime!(self.layout.num_cols) as usize;
145        let threshold = E::new(LOGIT_MASKED);
146
147        for r in 0..num_rows {
148            let row_offset = r * num_cols;
149
150            let val = rowwise.vals[r];
151
152            for c in 0..num_cols {
153                let index = row_offset + c;
154
155                let safe_val = clamp_min(val, threshold);
156                let not_masked = E::cast_from(val >= threshold);
157                self.data[index] = not_masked * (self.data[index] - safe_val).exp();
158            }
159        }
160    }
161}
162
163#[cube]
164/// Allocates a `Tile::Unit`. The variant is valid in any scope — each unit
165/// just holds its own row-major copy of the tile.
166pub fn allocate_unit_tile<E: Numeric, Sc: Scope>(
167    #[comptime] layout: UnitTileLayout,
168) -> Tile<E, Sc, ReadWrite> {
169    Tile::new_Unit(UnitTile::<E>::new(layout))
170}
171
172#[cube]
173fn strided_tile_to_unit_tile<E: Numeric, N: Size, E2: Numeric>(
174    strided_tile: &StridedTile<E, N>,
175    unit_tile: &mut UnitTile<E2>,
176) {
177    let vector_size = N::value().comptime() as u32;
178    assert!(unit_tile.layout.num_cols.is_multiple_of(vector_size));
179
180    let col_iterations = comptime!(unit_tile.layout.num_cols / vector_size);
181
182    for row in 0..unit_tile.layout.num_rows {
183        for col in 0..col_iterations {
184            let line_read = strided_tile.get_vector(row, col);
185            #[unroll]
186            for i in 0..vector_size {
187                unit_tile.data
188                    [(row * unit_tile.layout.num_cols + col * vector_size + i) as usize] =
189                    E2::cast_from(line_read[i as usize]);
190            }
191        }
192    }
193}
194
195#[cube]
196fn strided_tile_to_transposed_unit_tile<E: Numeric, N: Size, E2: Numeric>(
197    strided_tile: &StridedTile<E, N>,
198    unit_tile: &mut UnitTile<E2>,
199) {
200    let vector_size = N::value().comptime() as u32;
201    assert!(unit_tile.layout.num_cols.is_multiple_of(vector_size));
202
203    let input_num_rows = unit_tile.layout.num_cols.comptime();
204    let input_num_cols = unit_tile.layout.num_rows.comptime();
205    let vector_iterations = input_num_cols / vector_size;
206
207    for input_row in 0..input_num_rows {
208        for input_col_vector in 0..vector_iterations {
209            let vector_read = strided_tile.get_vector(input_row, input_col_vector);
210
211            #[unroll]
212            for i in 0..vector_size {
213                unit_tile.data[((input_col_vector + i) * input_num_rows + input_row) as usize] =
214                    E2::cast_from(vector_read[i as usize]);
215            }
216        }
217    }
218}