Skip to main content

cubek_std/tile/data/
whitebox_fragment.rs

1use cubecl;
2use cubecl::{prelude::*, std::tensor::layout::Coords2d};
3
4use crate::tile::LOGIT_MASKED;
5use crate::tile::{
6    Tile,
7    compute::{Mask, MaskExpand},
8    data::{rowwise::RowWise, strided::StridedTile},
9    scope::{TileScope, assert_plane_scope},
10};
11
#[derive(CubeType)]
/// The elements of a tile owned by a single unit of a plane.
///
/// `array` stores this unit's `layout.unit_size.0 x layout.unit_size.1` elements in
/// row-major order: local position `(r, c)` lives at index `r * unit_size.1 + c`.
/// The comptime `layout` describes how those local positions map onto the whole
/// `layout.total_size` tile.
///
/// Assumes:
/// - unit_size * plane_dim = total_size (not dim wise but in total count), i.e.
///   `unit_size.0 * unit_size.1 * plane_dim == total_size.0 * total_size.1`
pub struct WhiteboxFragment<E: Numeric> {
    /// This unit's elements, row-major over `layout.unit_size`.
    pub array: Array<E>,
    /// Mapping between unit-local and tile positions; a comptime value, not a runtime one.
    #[cube(comptime)]
    pub layout: WhiteboxFragmentLayout,
}
20
/// How the elements owned by one unit are arranged inside the tile.
///
/// The diagrams below show an 8x8 tile spread over a plane of 32 units. Each entry
/// is the index of the unit that owns that element.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum InnerLayout {
    /// Each unit has all its elements contiguous inside the same row
    /// (`unit_size = (1, elements_per_unit)`).
    ///
    ///  0,  0,  1,  1,  2,  2,  3,  3,
    ///  4,  4,  5,  5,  6,  6,  7,  7,
    ///  8,  8,  9,  9, 10, 10, 11, 11,
    /// 12, 12, 13, 13, 14, 14, 15, 15,
    /// 16, 16, 17, 17, 18, 18, 19, 19,
    /// 20, 20, 21, 21, 22, 22, 23, 23,
    /// 24, 24, 25, 25, 26, 26, 27, 27,
    /// 28, 28, 29, 29, 30, 30, 31, 31,
    Contiguous,
    /// Each unit spreads its elements along two rows
    /// (`unit_size = (2, elements_per_unit / 2)`). The second row segment sits
    /// `plane_dim / num_units_per_row` rows below the first (4 rows in this example).
    ///
    ///  0,  1,  2,  3,  4,  5,  6,  7,
    ///  8,  9, 10, 11, 12, 13, 14, 15,
    /// 16, 17, 18, 19, 20, 21, 22, 23,
    /// 24, 25, 26, 27, 28, 29, 30, 31,
    ///  0,  1,  2,  3,  4,  5,  6,  7,
    ///  8,  9, 10, 11, 12, 13, 14, 15,
    /// 16, 17, 18, 19, 20, 21, 22, 23,
    /// 24, 25, 26, 27, 28, 29, 30, 31,
    SplitRows,
}
46
47#[cube]
48impl<E: Numeric> WhiteboxFragment<E> {
49    pub fn new(#[comptime] layout: WhiteboxFragmentLayout) -> WhiteboxFragment<E> {
50        let array = Array::<E>::new(comptime!(layout.unit_size.0 * layout.unit_size.1) as usize);
51
52        WhiteboxFragment::<E> { array, layout }
53    }
54
55    pub fn zero(&mut self) {
56        for i in 0..self.layout.unit_size.0 * self.layout.unit_size.1 {
57            self.array[i as usize] = E::from_int(0);
58        }
59    }
60
61    pub fn load_from_slice(&mut self, smem_slice: &Slice<E>) {
62        for r in 0..self.layout.unit_size.0 {
63            for c in 0..self.layout.unit_size.1 {
64                let (row, col) = whitebox_fragment_absolute_pos(self.layout, (r, c));
65                let index = row * self.layout.total_size.1 + col;
66
67                self.array[(r * self.layout.unit_size.1 + c) as usize] = smem_slice[index as usize];
68            }
69        }
70    }
71
72    pub fn load_from_strided_tile<E2: Numeric, N: Size>(
73        &mut self,
74        strided_tile: &StridedTile<E2, N>,
75    ) {
76        // Assumes vector size == 1
77        for r in 0..self.layout.unit_size.0 {
78            for c in 0..self.layout.unit_size.1 {
79                let (row, col) = whitebox_fragment_absolute_pos(self.layout, (r, c));
80                self.array[(r * self.layout.unit_size.1 + c) as usize] =
81                    E::cast_from(strided_tile.get_vector(row, col))
82            }
83        }
84    }
85
86    pub fn store_to<F: Float>(&self, smem_slice: &mut SliceMut<F>) {
87        for r in 0..self.layout.unit_size.0 {
88            for c in 0..self.layout.unit_size.1 {
89                let (row, col) = whitebox_fragment_absolute_pos(self.layout, (r, c));
90                let index = row * self.layout.total_size.1 + col;
91
92                smem_slice[index as usize] =
93                    F::cast_from(self.array[(r * self.layout.unit_size.1 + c) as usize]);
94            }
95        }
96    }
97
98    pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
99        for r in 0..self.layout.unit_size.0 as usize {
100            let row_offset = r as u32 * self.layout.unit_size.1;
101            for c in 0..self.layout.unit_size.1 {
102                let index = row_offset + c;
103                self.array[index as usize] = self.array[index as usize] * scale.vals[r];
104            }
105        }
106    }
107
108    pub fn rowwise_max(&self) -> RowWise<E> {
109        let num_rows = comptime!(self.layout.unit_size.0) as usize;
110        let num_cols = comptime!(self.layout.unit_size.1) as usize;
111        let mut vals = Array::new(num_rows);
112
113        for r in 0..num_rows {
114            let row_offset = r * num_cols;
115            let mut val = E::min_value();
116
117            for c in 0..num_cols {
118                let index = row_offset + c;
119                val = max(val, self.array[index]);
120            }
121
122            vals[r] = val;
123        }
124
125        RowWise::<E> { num_rows, vals }
126    }
127
128    pub fn rowwise_sum(&self) -> RowWise<E> {
129        let num_rows = comptime!(self.layout.unit_size.0) as usize;
130        let num_cols = comptime!(self.layout.unit_size.1) as usize;
131        let mut vals = Array::new(num_rows);
132
133        for r in 0..num_rows {
134            let row_offset = r * num_cols;
135            let mut val = E::from_int(0);
136
137            for c in 0..num_cols {
138                let index = row_offset + c;
139                val += self.array[index];
140            }
141
142            vals[r] = val;
143        }
144
145        RowWise::<E> { num_rows, vals }
146    }
147
148    pub fn num_units_per_row(&self) -> comptime_type!(u32) {
149        comptime!(self.layout.total_size.1 / self.layout.unit_size.1)
150    }
151
152    pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
153        for r in 0..self.layout.unit_size.0 {
154            let row_offset = r * self.layout.unit_size.1;
155            for c in 0..self.layout.unit_size.1 {
156                let index = row_offset + c;
157                self.array[index as usize] = self.array[index as usize] * scale
158                    + E::cast_from(mask.should_mask((r, c))) * E::min_value();
159            }
160        }
161    }
162}
163
164#[cube]
165impl<E: Float> WhiteboxFragment<E> {
166    pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
167        let num_rows = comptime!(self.layout.unit_size.0) as usize;
168        let num_cols = comptime!(self.layout.unit_size.1) as usize;
169        let threshold = E::new(LOGIT_MASKED);
170
171        for r in 0..num_rows {
172            let row_offset = r * num_cols;
173
174            let val = rowwise.vals[r];
175            let safe_val = clamp_min(val, threshold);
176            let not_masked = E::cast_from(val >= threshold);
177
178            for c in 0..num_cols {
179                let index = row_offset + c;
180
181                self.array[index] = not_masked * (self.array[index] - safe_val).exp();
182            }
183        }
184    }
185}
186
/// Describes how a `total_size` tile is distributed over the `plane_dim` units of a plane.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct WhiteboxFragmentLayout {
    /// `(rows, cols)` of the whole tile. `cols` is also the row stride used when
    /// indexing a row-major slice holding the tile.
    pub total_size: Coords2d,
    /// `(rows, cols)` of the block of elements owned by each unit.
    pub unit_size: Coords2d,
    /// How many units share one tile row; `new` sets it to `total_size.1 / unit_size.1`.
    pub num_units_per_row: u32,
    /// Number of units in the plane.
    pub plane_dim: u32,
}
194
195impl WhiteboxFragmentLayout {
196    pub const fn new(
197        total_size: Coords2d,
198        plane_dim: u32,
199        inner_layout: InnerLayout,
200    ) -> WhiteboxFragmentLayout {
201        let total_elements = total_size.0 * total_size.1;
202        let elements_per_unit = total_elements.div_ceil(plane_dim);
203
204        let (num_rows_per_unit, num_cols_per_unit) = match inner_layout {
205            InnerLayout::Contiguous => (1u32, elements_per_unit),
206            InnerLayout::SplitRows => (2u32, elements_per_unit / 2u32),
207        };
208        let unit_size = (num_rows_per_unit, num_cols_per_unit);
209        let num_units_per_row = total_size.1 / unit_size.1;
210
211        WhiteboxFragmentLayout {
212            total_size,
213            unit_size,
214            num_units_per_row,
215            plane_dim,
216        }
217    }
218
219    pub const fn num_units_per_row(&self) -> u32 {
220        self.total_size.1 / self.unit_size.1
221    }
222}
223
224#[cube]
225/// Allocates a `Tile::WhiteboxFragment` for the given scope. Panics at expansion
226/// time unless `Sc = Plane`.
227pub fn allocate_whitebox_fragment<E: Numeric, Sc: TileScope>(
228    #[comptime] layout: WhiteboxFragmentLayout,
229) -> Tile<E, Sc, ReadWrite> {
230    comptime!(assert_plane_scope(Sc::KIND));
231    Tile::new_WhiteboxFragment(WhiteboxFragment::<E>::new(layout))
232}
233
234/// Maps a per-unit `(row, col)` to its absolute position within the tile
235/// described by `layout`.
236#[cube]
237pub fn whitebox_fragment_absolute_pos(
238    #[comptime] layout: WhiteboxFragmentLayout,
239    local_pos: Coords2d,
240) -> Coords2d {
241    let abs_row_index = {
242        let row_0 = UNIT_POS_X / layout.num_units_per_row;
243        let row_jump = comptime!(layout.plane_dim / layout.num_units_per_row);
244        local_pos.0 * row_jump + row_0
245    };
246    let abs_col_index = layout.unit_size.1 * (UNIT_POS_X % layout.num_units_per_row) + local_pos.1;
247    (abs_row_index, abs_col_index)
248}
249
250/// Zeroes a slice giving responsibility to units following `layout`.
251#[cube]
252pub fn whitebox_fragment_zero_slice<E: Numeric>(
253    #[comptime] layout: WhiteboxFragmentLayout,
254    slice: &mut SliceMut<E>,
255) {
256    for r in 0..layout.unit_size.0 {
257        for c in 0..layout.unit_size.1 {
258            let (row, col) = whitebox_fragment_absolute_pos(layout, (r, c));
259            let index = row * layout.total_size.1 + col;
260
261            slice[index as usize] = E::from_int(0);
262        }
263    }
264}