cubek_std/tile/data/
whitebox_fragment.rs1use cubecl;
2use cubecl::{prelude::*, std::tensor::layout::Coords2d};
3
4use crate::tile::LOGIT_MASKED;
5use crate::tile::{
6 Tile,
7 compute::{Mask, MaskExpand},
8 data::{rowwise::RowWise, strided::StridedTile},
9 scope::{TileScope, assert_plane_scope},
10};
11
12#[derive(CubeType)]
13pub struct WhiteboxFragment<E: Numeric> {
16 pub array: Array<E>,
17 #[cube(comptime)]
18 pub layout: WhiteboxFragmentLayout,
19}
20
/// How the elements owned by a single unit are arranged inside the tile.
///
/// `WhiteboxFragmentLayout::new` turns each variant into a concrete per-unit shape.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum InnerLayout {
    /// Each unit owns a single row segment of `elements_per_unit` consecutive columns.
    Contiguous,
    /// Each unit owns two rows of `elements_per_unit / 2` columns; the second row lies
    /// `plane_dim / num_units_per_row` tile rows below the first.
    SplitRows,
}
46
47#[cube]
48impl<E: Numeric> WhiteboxFragment<E> {
49 pub fn new(#[comptime] layout: WhiteboxFragmentLayout) -> WhiteboxFragment<E> {
50 let array = Array::<E>::new(comptime!(layout.unit_size.0 * layout.unit_size.1) as usize);
51
52 WhiteboxFragment::<E> { array, layout }
53 }
54
55 pub fn zero(&mut self) {
56 for i in 0..self.layout.unit_size.0 * self.layout.unit_size.1 {
57 self.array[i as usize] = E::from_int(0);
58 }
59 }
60
61 pub fn load_from_slice(&mut self, smem_slice: &Slice<E>) {
62 for r in 0..self.layout.unit_size.0 {
63 for c in 0..self.layout.unit_size.1 {
64 let (row, col) = whitebox_fragment_absolute_pos(self.layout, (r, c));
65 let index = row * self.layout.total_size.1 + col;
66
67 self.array[(r * self.layout.unit_size.1 + c) as usize] = smem_slice[index as usize];
68 }
69 }
70 }
71
72 pub fn load_from_strided_tile<E2: Numeric, N: Size>(
73 &mut self,
74 strided_tile: &StridedTile<E2, N>,
75 ) {
76 for r in 0..self.layout.unit_size.0 {
78 for c in 0..self.layout.unit_size.1 {
79 let (row, col) = whitebox_fragment_absolute_pos(self.layout, (r, c));
80 self.array[(r * self.layout.unit_size.1 + c) as usize] =
81 E::cast_from(strided_tile.get_vector(row, col))
82 }
83 }
84 }
85
86 pub fn store_to<F: Float>(&self, smem_slice: &mut SliceMut<F>) {
87 for r in 0..self.layout.unit_size.0 {
88 for c in 0..self.layout.unit_size.1 {
89 let (row, col) = whitebox_fragment_absolute_pos(self.layout, (r, c));
90 let index = row * self.layout.total_size.1 + col;
91
92 smem_slice[index as usize] =
93 F::cast_from(self.array[(r * self.layout.unit_size.1 + c) as usize]);
94 }
95 }
96 }
97
98 pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
99 for r in 0..self.layout.unit_size.0 as usize {
100 let row_offset = r as u32 * self.layout.unit_size.1;
101 for c in 0..self.layout.unit_size.1 {
102 let index = row_offset + c;
103 self.array[index as usize] = self.array[index as usize] * scale.vals[r];
104 }
105 }
106 }
107
108 pub fn rowwise_max(&self) -> RowWise<E> {
109 let num_rows = comptime!(self.layout.unit_size.0) as usize;
110 let num_cols = comptime!(self.layout.unit_size.1) as usize;
111 let mut vals = Array::new(num_rows);
112
113 for r in 0..num_rows {
114 let row_offset = r * num_cols;
115 let mut val = E::min_value();
116
117 for c in 0..num_cols {
118 let index = row_offset + c;
119 val = max(val, self.array[index]);
120 }
121
122 vals[r] = val;
123 }
124
125 RowWise::<E> { num_rows, vals }
126 }
127
128 pub fn rowwise_sum(&self) -> RowWise<E> {
129 let num_rows = comptime!(self.layout.unit_size.0) as usize;
130 let num_cols = comptime!(self.layout.unit_size.1) as usize;
131 let mut vals = Array::new(num_rows);
132
133 for r in 0..num_rows {
134 let row_offset = r * num_cols;
135 let mut val = E::from_int(0);
136
137 for c in 0..num_cols {
138 let index = row_offset + c;
139 val += self.array[index];
140 }
141
142 vals[r] = val;
143 }
144
145 RowWise::<E> { num_rows, vals }
146 }
147
148 pub fn num_units_per_row(&self) -> comptime_type!(u32) {
149 comptime!(self.layout.total_size.1 / self.layout.unit_size.1)
150 }
151
152 pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
153 for r in 0..self.layout.unit_size.0 {
154 let row_offset = r * self.layout.unit_size.1;
155 for c in 0..self.layout.unit_size.1 {
156 let index = row_offset + c;
157 self.array[index as usize] = self.array[index as usize] * scale
158 + E::cast_from(mask.should_mask((r, c))) * E::min_value();
159 }
160 }
161 }
162}
163
164#[cube]
165impl<E: Float> WhiteboxFragment<E> {
166 pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
167 let num_rows = comptime!(self.layout.unit_size.0) as usize;
168 let num_cols = comptime!(self.layout.unit_size.1) as usize;
169 let threshold = E::new(LOGIT_MASKED);
170
171 for r in 0..num_rows {
172 let row_offset = r * num_cols;
173
174 let val = rowwise.vals[r];
175 let safe_val = clamp_min(val, threshold);
176 let not_masked = E::cast_from(val >= threshold);
177
178 for c in 0..num_cols {
179 let index = row_offset + c;
180
181 self.array[index] = not_masked * (self.array[index] - safe_val).exp();
182 }
183 }
184 }
185}
186
187#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
188pub struct WhiteboxFragmentLayout {
189 pub total_size: Coords2d,
190 pub unit_size: Coords2d,
191 pub num_units_per_row: u32,
192 pub plane_dim: u32,
193}
194
195impl WhiteboxFragmentLayout {
196 pub const fn new(
197 total_size: Coords2d,
198 plane_dim: u32,
199 inner_layout: InnerLayout,
200 ) -> WhiteboxFragmentLayout {
201 let total_elements = total_size.0 * total_size.1;
202 let elements_per_unit = total_elements.div_ceil(plane_dim);
203
204 let (num_rows_per_unit, num_cols_per_unit) = match inner_layout {
205 InnerLayout::Contiguous => (1u32, elements_per_unit),
206 InnerLayout::SplitRows => (2u32, elements_per_unit / 2u32),
207 };
208 let unit_size = (num_rows_per_unit, num_cols_per_unit);
209 let num_units_per_row = total_size.1 / unit_size.1;
210
211 WhiteboxFragmentLayout {
212 total_size,
213 unit_size,
214 num_units_per_row,
215 plane_dim,
216 }
217 }
218
219 pub const fn num_units_per_row(&self) -> u32 {
220 self.total_size.1 / self.unit_size.1
221 }
222}
223
224#[cube]
225pub fn allocate_whitebox_fragment<E: Numeric, Sc: TileScope>(
228 #[comptime] layout: WhiteboxFragmentLayout,
229) -> Tile<E, Sc, ReadWrite> {
230 comptime!(assert_plane_scope(Sc::KIND));
231 Tile::new_WhiteboxFragment(WhiteboxFragment::<E>::new(layout))
232}
233
234#[cube]
237pub fn whitebox_fragment_absolute_pos(
238 #[comptime] layout: WhiteboxFragmentLayout,
239 local_pos: Coords2d,
240) -> Coords2d {
241 let abs_row_index = {
242 let row_0 = UNIT_POS_X / layout.num_units_per_row;
243 let row_jump = comptime!(layout.plane_dim / layout.num_units_per_row);
244 local_pos.0 * row_jump + row_0
245 };
246 let abs_col_index = layout.unit_size.1 * (UNIT_POS_X % layout.num_units_per_row) + local_pos.1;
247 (abs_row_index, abs_col_index)
248}
249
250#[cube]
252pub fn whitebox_fragment_zero_slice<E: Numeric>(
253 #[comptime] layout: WhiteboxFragmentLayout,
254 slice: &mut SliceMut<E>,
255) {
256 for r in 0..layout.unit_size.0 {
257 for c in 0..layout.unit_size.1 {
258 let (row, col) = whitebox_fragment_absolute_pos(layout, (r, c));
259 let index = row * layout.total_size.1 + col;
260
261 slice[index as usize] = E::from_int(0);
262 }
263 }
264}