cubek_std/tile/data/
unit.rs1use cubecl::{self, prelude::*};
2
3use crate::tile::{
4 LOGIT_MASKED, Tile,
5 compute::{Mask, MaskExpand},
6 data::{rowwise::RowWise, strided::StridedTile},
7 scope::TileScope,
8};
9
/// A matrix tile stored as a flat `Array` in row-major order.
///
/// Element `(row, col)` lives at flat index `row * layout.num_cols + col`
/// (see `get` / `accumulate`). NOTE(review): the name suggests each tile is
/// owned by a single unit — confirm against callers.
#[derive(CubeType)]
pub struct UnitTile<E: Numeric> {
    /// Backing storage; `UnitTile::new` allocates `num_rows * num_cols` elements.
    pub data: Array<E>,
    /// Shape and load mode, marked `#[cube(comptime)]` so the cube macro treats
    /// it as a compile-time value rather than a kernel variable.
    #[cube(comptime)]
    pub layout: UnitTileLayout,
}
16
/// Shape description of a `UnitTile`, passed around as a comptime value.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct UnitTileLayout {
    /// Number of rows held by the tile.
    pub num_rows: u32,
    /// Number of columns held by the tile; also the row stride of the flat storage.
    pub num_cols: u32,
    /// When `true`, loading from a strided tile transposes the source while copying.
    pub transposed_load: bool,
}

impl UnitTileLayout {
    /// Builds a layout from its three components. Being a `const fn`, it can be
    /// used to define layouts in `const` items.
    pub const fn new(num_rows: u32, num_cols: u32, transposed_load: bool) -> Self {
        Self {
            num_rows,
            num_cols,
            transposed_load,
        }
    }
}
34
35#[cube]
36impl<E: Numeric> UnitTile<E> {
37 pub fn new(#[comptime] layout: UnitTileLayout) -> UnitTile<E> {
38 let data = Array::<E>::new(comptime!(layout.num_rows * layout.num_cols) as usize);
39 UnitTile::<E> { data, layout }
40 }
41
42 pub fn zero(&mut self) {
43 for i in 0..self.layout.num_rows * self.layout.num_cols {
44 self.data[i as usize] = E::from_int(0);
45 }
46 }
47
48 pub fn get(&self, row: u32, col: u32) -> E {
49 self.data[(row * self.layout.num_cols + col) as usize]
50 }
51
52 pub fn accumulate(&mut self, row: u32, col: u32, val: E) {
53 self.data[(row * self.layout.num_cols + col) as usize] += val;
54 }
55
56 pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
57 for r in 0..self.layout.num_rows as usize {
58 let row_offset = r as u32 * self.layout.num_cols;
59 for c in 0..self.layout.num_cols {
60 let index = row_offset + c;
61 self.data[index as usize] = self.data[index as usize] * scale.vals[r];
62 }
63 }
64 }
65
66 pub fn rowwise_max(&self) -> RowWise<E> {
67 let num_rows = comptime!(self.layout.num_rows) as usize;
68 let num_cols = comptime!(self.layout.num_cols) as usize;
69 let mut vals = Array::new(num_rows);
70
71 for r in 0..num_rows {
72 let row_offset = r * num_cols;
73 let mut val = E::min_value();
74
75 for c in 0..num_cols {
76 let index = row_offset + c;
77 val = max(val, self.data[index]);
78 }
79
80 vals[r] = val;
81 }
82
83 RowWise::<E> { num_rows, vals }
84 }
85
86 pub fn rowwise_sum(&self) -> RowWise<E> {
87 let num_rows = comptime!(self.layout.num_rows) as usize;
88 let num_cols = comptime!(self.layout.num_cols) as usize;
89 let mut vals = Array::new(num_rows);
90
91 for r in 0..num_rows {
92 let row_offset = r * num_cols;
93 let mut val = E::from_int(0);
94
95 for c in 0..num_cols {
96 let index = row_offset + c;
97 val += self.data[index];
98 }
99
100 vals[r] = val;
101 }
102
103 RowWise::<E> { num_rows, vals }
104 }
105
106 pub fn copy_from<E2: Numeric>(&mut self, other: &UnitTile<E2>) {
109 for r in 0..self.layout.num_rows as usize {
112 let row_offset = r as u32 * self.layout.num_cols;
113 for c in 0..self.layout.num_cols {
114 let index = row_offset + c;
115 self.data[index as usize] = E::cast_from(other.data[index as usize]);
116 }
117 }
118 }
119
120 pub fn load_from_strided_tile<E2: Numeric, ES: Size>(&mut self, tile: &StridedTile<E2, ES>) {
121 if comptime!(self.layout.transposed_load) {
122 strided_tile_to_transposed_unit_tile(tile, self)
123 } else {
124 strided_tile_to_unit_tile(tile, self)
125 }
126 }
127
128 pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
129 for r in 0..self.layout.num_rows {
130 let row_offset = r * self.layout.num_cols;
131 for c in 0..self.layout.num_cols {
132 let index = row_offset + c;
133 self.data[index as usize] = self.data[index as usize] * scale
134 + E::cast_from(mask.should_mask((r, c))) * E::min_value();
135 }
136 }
137 }
138}
139
140#[cube]
141impl<E: Float> UnitTile<E> {
142 pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
143 let num_rows = comptime!(self.layout.num_rows) as usize;
144 let num_cols = comptime!(self.layout.num_cols) as usize;
145 let threshold = E::new(LOGIT_MASKED);
146
147 for r in 0..num_rows {
148 let row_offset = r * num_cols;
149
150 let val = rowwise.vals[r];
151
152 for c in 0..num_cols {
153 let index = row_offset + c;
154
155 let safe_val = clamp_min(val, threshold);
156 let not_masked = E::cast_from(val >= threshold);
157 self.data[index] = not_masked * (self.data[index] - safe_val).exp();
158 }
159 }
160 }
161}
162
163#[cube]
164pub fn allocate_unit_tile<E: Numeric, Sc: TileScope>(
167 #[comptime] layout: UnitTileLayout,
168) -> Tile<E, Sc, ReadWrite> {
169 Tile::new_Unit(UnitTile::<E>::new(layout))
170}
171
172#[cube]
173fn strided_tile_to_unit_tile<E: Numeric, N: Size, E2: Numeric>(
174 strided_tile: &StridedTile<E, N>,
175 unit_tile: &mut UnitTile<E2>,
176) {
177 let vector_size = N::value().comptime() as u32;
178 assert!(unit_tile.layout.num_cols.is_multiple_of(vector_size));
179
180 let col_iterations = comptime!(unit_tile.layout.num_cols / vector_size);
181
182 for row in 0..unit_tile.layout.num_rows {
183 for col in 0..col_iterations {
184 let line_read = strided_tile.get_vector(row, col);
185 #[unroll]
186 for i in 0..vector_size {
187 unit_tile.data
188 [(row * unit_tile.layout.num_cols + col * vector_size + i) as usize] =
189 E2::cast_from(line_read[i as usize]);
190 }
191 }
192 }
193}
194
195#[cube]
196fn strided_tile_to_transposed_unit_tile<E: Numeric, N: Size, E2: Numeric>(
197 strided_tile: &StridedTile<E, N>,
198 unit_tile: &mut UnitTile<E2>,
199) {
200 let vector_size = N::value().comptime() as u32;
201 assert!(unit_tile.layout.num_cols.is_multiple_of(vector_size));
202
203 let input_num_rows = unit_tile.layout.num_cols.comptime();
204 let input_num_cols = unit_tile.layout.num_rows.comptime();
205 let vector_iterations = input_num_cols / vector_size;
206
207 for input_row in 0..input_num_rows {
208 for input_col_vector in 0..vector_iterations {
209 let vector_read = strided_tile.get_vector(input_row, input_col_vector);
210
211 #[unroll]
212 for i in 0..vector_size {
213 unit_tile.data[((input_col_vector + i) * input_num_rows + input_row) as usize] =
214 E2::cast_from(vector_read[i as usize]);
215 }
216 }
217 }
218}