Skip to main content

cubek_std/tile/variants/
local_tile.rs

1use cubecl;
2use cubecl::{prelude::*, std::tensor::layout::Coords2d};
3
4use crate::tile::ops::{LOGIT_MASKED, Mask, MaskExpand, RowWise};
5use crate::tile::scope::{Scope, assert_plane_scope};
6use crate::tile::{StridedTile, Tile};
7
#[derive(CubeType)]
/// A tile whose elements are distributed over the units of a plane, each unit
/// holding its own share in `array`.
///
/// Which elements a unit owns is decided by the comptime `layout`; see
/// `local_layout_absolute_pos` for the mapping from a unit-local `(row, col)`
/// to a position in the full tile.
///
/// Assumes:
/// - unit_size * plane_dim = total_size (not dim wise but in total count)
pub struct LocalTile<E: Numeric> {
    /// This unit's elements, stored row-major with `layout.unit_size.1`
    /// elements per local row (`unit_size.0 * unit_size.1` elements in total).
    pub array: Array<E>,
    /// Comptime description of how the whole tile is split among units.
    #[cube(comptime)]
    pub layout: LocalTileLayout,
}
16
/// Arrangement of the elements owned by one unit inside a `LocalTile`.
///
/// Both diagrams below show an 8x8 tile shared by a plane of 32 units; each
/// entry is the index of the unit that owns that element.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InnerLayout {
    /// A unit's elements sit next to each other, all within a single row.
    ///
    /// ```text
    ///  0,  0,  1,  1,  2,  2,  3,  3,
    ///  4,  4,  5,  5,  6,  6,  7,  7,
    ///  8,  8,  9,  9, 10, 10, 11, 11,
    /// 12, 12, 13, 13, 14, 14, 15, 15,
    /// 16, 16, 17, 17, 18, 18, 19, 19,
    /// 20, 20, 21, 21, 22, 22, 23, 23,
    /// 24, 24, 25, 25, 26, 26, 27, 27,
    /// 28, 28, 29, 29, 30, 30, 31, 31,
    /// ```
    Contiguous,
    /// A unit's elements are spread over two rows (in this example, half the
    /// tile height apart).
    ///
    /// ```text
    ///  0,  1,  2,  3,  4,  5,  6,  7,
    ///  8,  9, 10, 11, 12, 13, 14, 15,
    /// 16, 17, 18, 19, 20, 21, 22, 23,
    /// 24, 25, 26, 27, 28, 29, 30, 31,
    ///  0,  1,  2,  3,  4,  5,  6,  7,
    ///  8,  9, 10, 11, 12, 13, 14, 15,
    /// 16, 17, 18, 19, 20, 21, 22, 23,
    /// 24, 25, 26, 27, 28, 29, 30, 31,
    /// ```
    SplitRows,
}
42
43#[cube]
44impl<E: Numeric> LocalTile<E> {
45    pub fn new(#[comptime] layout: LocalTileLayout) -> LocalTile<E> {
46        let array = Array::<E>::new(comptime!(layout.unit_size.0 * layout.unit_size.1) as usize);
47
48        LocalTile::<E> { array, layout }
49    }
50
51    pub fn zero(&mut self) {
52        for i in 0..self.layout.unit_size.0 * self.layout.unit_size.1 {
53            self.array[i as usize] = E::from_int(0);
54        }
55    }
56
57    pub fn load_from_slice(&mut self, smem_slice: &Slice<E>) {
58        for r in 0..self.layout.unit_size.0 {
59            for c in 0..self.layout.unit_size.1 {
60                let (row, col) = local_layout_absolute_pos(self.layout, (r, c));
61                let index = row * self.layout.total_size.1 + col;
62
63                self.array[(r * self.layout.unit_size.1 + c) as usize] = smem_slice[index as usize];
64            }
65        }
66    }
67
68    pub fn load_from_strided_tile<E2: Numeric, N: Size>(
69        &mut self,
70        strided_tile: &StridedTile<E2, N>,
71    ) {
72        // Assumes vector size == 1
73        for r in 0..self.layout.unit_size.0 {
74            for c in 0..self.layout.unit_size.1 {
75                let (row, col) = local_layout_absolute_pos(self.layout, (r, c));
76                self.array[(r * self.layout.unit_size.1 + c) as usize] =
77                    E::cast_from(strided_tile.get_vector(row, col))
78            }
79        }
80    }
81
82    pub fn store_to<F: Float>(&self, smem_slice: &mut SliceMut<F>) {
83        for r in 0..self.layout.unit_size.0 {
84            for c in 0..self.layout.unit_size.1 {
85                let (row, col) = local_layout_absolute_pos(self.layout, (r, c));
86                let index = row * self.layout.total_size.1 + col;
87
88                smem_slice[index as usize] =
89                    F::cast_from(self.array[(r * self.layout.unit_size.1 + c) as usize]);
90            }
91        }
92    }
93
94    pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
95        for r in 0..self.layout.unit_size.0 as usize {
96            let row_offset = r as u32 * self.layout.unit_size.1;
97            for c in 0..self.layout.unit_size.1 {
98                let index = row_offset + c;
99                self.array[index as usize] = self.array[index as usize] * scale.vals[r];
100            }
101        }
102    }
103
104    pub fn rowwise_max(&self) -> RowWise<E> {
105        let num_rows = comptime!(self.layout.unit_size.0) as usize;
106        let num_cols = comptime!(self.layout.unit_size.1) as usize;
107        let mut vals = Array::new(num_rows);
108
109        for r in 0..num_rows {
110            let row_offset = r * num_cols;
111            let mut val = E::min_value();
112
113            for c in 0..num_cols {
114                let index = row_offset + c;
115                val = max(val, self.array[index]);
116            }
117
118            vals[r] = val;
119        }
120
121        RowWise::<E> { num_rows, vals }
122    }
123
124    pub fn rowwise_sum(&self) -> RowWise<E> {
125        let num_rows = comptime!(self.layout.unit_size.0) as usize;
126        let num_cols = comptime!(self.layout.unit_size.1) as usize;
127        let mut vals = Array::new(num_rows);
128
129        for r in 0..num_rows {
130            let row_offset = r * num_cols;
131            let mut val = E::from_int(0);
132
133            for c in 0..num_cols {
134                let index = row_offset + c;
135                val += self.array[index];
136            }
137
138            vals[r] = val;
139        }
140
141        RowWise::<E> { num_rows, vals }
142    }
143
144    pub fn num_units_per_row(&self) -> comptime_type!(u32) {
145        comptime!(self.layout.total_size.1 / self.layout.unit_size.1)
146    }
147
148    pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
149        for r in 0..self.layout.unit_size.0 {
150            let row_offset = r * self.layout.unit_size.1;
151            for c in 0..self.layout.unit_size.1 {
152                let index = row_offset + c;
153                self.array[index as usize] = self.array[index as usize] * scale
154                    + E::cast_from(mask.should_mask((r, c))) * E::min_value();
155            }
156        }
157    }
158}
159
160#[cube]
161impl<E: Float> LocalTile<E> {
162    pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
163        let num_rows = comptime!(self.layout.unit_size.0) as usize;
164        let num_cols = comptime!(self.layout.unit_size.1) as usize;
165        let threshold = E::new(LOGIT_MASKED);
166
167        for r in 0..num_rows {
168            let row_offset = r * num_cols;
169
170            let val = rowwise.vals[r];
171            let safe_val = clamp_min(val, threshold);
172            let not_masked = E::cast_from(val >= threshold);
173
174            for c in 0..num_cols {
175                let index = row_offset + c;
176
177                self.array[index] = not_masked * (self.array[index] - safe_val).exp();
178            }
179        }
180    }
181}
182
183#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
184pub struct LocalTileLayout {
185    pub total_size: Coords2d,
186    pub unit_size: Coords2d,
187    pub num_units_per_row: u32,
188    pub plane_dim: u32,
189}
190
191impl LocalTileLayout {
192    pub const fn new(
193        total_size: Coords2d,
194        plane_dim: u32,
195        inner_layout: InnerLayout,
196    ) -> LocalTileLayout {
197        let total_elements = total_size.0 * total_size.1;
198        let elements_per_unit = total_elements.div_ceil(plane_dim);
199
200        let (num_rows_per_unit, num_cols_per_unit) = match inner_layout {
201            InnerLayout::Contiguous => (1u32, elements_per_unit),
202            InnerLayout::SplitRows => (2u32, elements_per_unit / 2u32),
203        };
204        let unit_size = (num_rows_per_unit, num_cols_per_unit);
205        let num_units_per_row = total_size.1 / unit_size.1;
206
207        LocalTileLayout {
208            total_size,
209            unit_size,
210            num_units_per_row,
211            plane_dim,
212        }
213    }
214
215    pub const fn num_units_per_row(&self) -> u32 {
216        self.total_size.1 / self.unit_size.1
217    }
218}
219
220#[cube]
221/// Allocates a `Tile::Local` for the given scope. Panics at expansion time
222/// unless `Sc = Plane`.
223pub fn allocate_local_tile<E: Numeric, Sc: Scope>(
224    #[comptime] layout: LocalTileLayout,
225) -> Tile<E, Sc, ReadWrite> {
226    comptime!(assert_plane_scope(Sc::KIND));
227    Tile::new_Local(LocalTile::<E>::new(layout))
228}
229
230/// Maps a per-unit `(row, col)` to its absolute position within the tile
231/// described by `layout`.
232#[cube]
233pub fn local_layout_absolute_pos(
234    #[comptime] layout: LocalTileLayout,
235    local_pos: Coords2d,
236) -> Coords2d {
237    let abs_row_index = {
238        let row_0 = UNIT_POS_X / layout.num_units_per_row;
239        let row_jump = comptime!(layout.plane_dim / layout.num_units_per_row);
240        local_pos.0 * row_jump + row_0
241    };
242    let abs_col_index = layout.unit_size.1 * (UNIT_POS_X % layout.num_units_per_row) + local_pos.1;
243    (abs_row_index, abs_col_index)
244}
245
246/// Zeroes a slice giving responsibility to units following `layout`.
247#[cube]
248pub fn local_layout_zero_slice<E: Numeric>(
249    #[comptime] layout: LocalTileLayout,
250    slice: &mut SliceMut<E>,
251) {
252    for r in 0..layout.unit_size.0 {
253        for c in 0..layout.unit_size.1 {
254            let (row, col) = local_layout_absolute_pos(layout, (r, c));
255            let index = row * layout.total_size.1 + col;
256
257            slice[index as usize] = E::from_int(0);
258        }
259    }
260}