cubek_std/tile/variants/
local_tile.rs1use cubecl;
2use cubecl::{prelude::*, std::tensor::layout::Coords2d};
3
4use crate::tile::ops::{LOGIT_MASKED, Mask, MaskExpand, RowWise};
5use crate::tile::scope::{Scope, assert_plane_scope};
6use crate::tile::{StridedTile, Tile};
7
/// The part of a tile owned by the calling unit, stored in a flat local array.
///
/// The split of the full tile across units is described by `layout`; the array holds
/// `layout.unit_size.0 * layout.unit_size.1` elements.
#[derive(CubeType)]
pub struct LocalTile<E: Numeric> {
    /// This unit's elements, row-major: unit-local `(r, c)` lives at index
    /// `r * layout.unit_size.1 + c`.
    pub array: Array<E>,
    /// How the full tile is distributed over the units; a comptime value.
    #[cube(comptime)]
    pub layout: LocalTileLayout,
}
16
/// Shape of the block of elements owned by each unit inside a [`LocalTileLayout`].
///
/// `LocalTileLayout::new` first gives every unit `ceil(total_elements / plane_dim)` elements,
/// then uses this enum to decide how those elements are arranged.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum InnerLayout {
    /// One row per unit, holding all of the unit's elements as consecutive columns.
    Contiguous,
    /// Two rows per unit, each holding half of the unit's elements (integer division).
    /// The two rows are `plane_dim / num_units_per_row` tile rows apart.
    SplitRows,
}
42
43#[cube]
44impl<E: Numeric> LocalTile<E> {
45 pub fn new(#[comptime] layout: LocalTileLayout) -> LocalTile<E> {
46 let array = Array::<E>::new(comptime!(layout.unit_size.0 * layout.unit_size.1) as usize);
47
48 LocalTile::<E> { array, layout }
49 }
50
51 pub fn zero(&mut self) {
52 for i in 0..self.layout.unit_size.0 * self.layout.unit_size.1 {
53 self.array[i as usize] = E::from_int(0);
54 }
55 }
56
57 pub fn load_from_slice(&mut self, smem_slice: &Slice<E>) {
58 for r in 0..self.layout.unit_size.0 {
59 for c in 0..self.layout.unit_size.1 {
60 let (row, col) = local_layout_absolute_pos(self.layout, (r, c));
61 let index = row * self.layout.total_size.1 + col;
62
63 self.array[(r * self.layout.unit_size.1 + c) as usize] = smem_slice[index as usize];
64 }
65 }
66 }
67
68 pub fn load_from_strided_tile<E2: Numeric, N: Size>(
69 &mut self,
70 strided_tile: &StridedTile<E2, N>,
71 ) {
72 for r in 0..self.layout.unit_size.0 {
74 for c in 0..self.layout.unit_size.1 {
75 let (row, col) = local_layout_absolute_pos(self.layout, (r, c));
76 self.array[(r * self.layout.unit_size.1 + c) as usize] =
77 E::cast_from(strided_tile.get_vector(row, col))
78 }
79 }
80 }
81
82 pub fn store_to<F: Float>(&self, smem_slice: &mut SliceMut<F>) {
83 for r in 0..self.layout.unit_size.0 {
84 for c in 0..self.layout.unit_size.1 {
85 let (row, col) = local_layout_absolute_pos(self.layout, (r, c));
86 let index = row * self.layout.total_size.1 + col;
87
88 smem_slice[index as usize] =
89 F::cast_from(self.array[(r * self.layout.unit_size.1 + c) as usize]);
90 }
91 }
92 }
93
94 pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
95 for r in 0..self.layout.unit_size.0 as usize {
96 let row_offset = r as u32 * self.layout.unit_size.1;
97 for c in 0..self.layout.unit_size.1 {
98 let index = row_offset + c;
99 self.array[index as usize] = self.array[index as usize] * scale.vals[r];
100 }
101 }
102 }
103
104 pub fn rowwise_max(&self) -> RowWise<E> {
105 let num_rows = comptime!(self.layout.unit_size.0) as usize;
106 let num_cols = comptime!(self.layout.unit_size.1) as usize;
107 let mut vals = Array::new(num_rows);
108
109 for r in 0..num_rows {
110 let row_offset = r * num_cols;
111 let mut val = E::min_value();
112
113 for c in 0..num_cols {
114 let index = row_offset + c;
115 val = max(val, self.array[index]);
116 }
117
118 vals[r] = val;
119 }
120
121 RowWise::<E> { num_rows, vals }
122 }
123
124 pub fn rowwise_sum(&self) -> RowWise<E> {
125 let num_rows = comptime!(self.layout.unit_size.0) as usize;
126 let num_cols = comptime!(self.layout.unit_size.1) as usize;
127 let mut vals = Array::new(num_rows);
128
129 for r in 0..num_rows {
130 let row_offset = r * num_cols;
131 let mut val = E::from_int(0);
132
133 for c in 0..num_cols {
134 let index = row_offset + c;
135 val += self.array[index];
136 }
137
138 vals[r] = val;
139 }
140
141 RowWise::<E> { num_rows, vals }
142 }
143
144 pub fn num_units_per_row(&self) -> comptime_type!(u32) {
145 comptime!(self.layout.total_size.1 / self.layout.unit_size.1)
146 }
147
148 pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
149 for r in 0..self.layout.unit_size.0 {
150 let row_offset = r * self.layout.unit_size.1;
151 for c in 0..self.layout.unit_size.1 {
152 let index = row_offset + c;
153 self.array[index as usize] = self.array[index as usize] * scale
154 + E::cast_from(mask.should_mask((r, c))) * E::min_value();
155 }
156 }
157 }
158}
159
160#[cube]
161impl<E: Float> LocalTile<E> {
162 pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
163 let num_rows = comptime!(self.layout.unit_size.0) as usize;
164 let num_cols = comptime!(self.layout.unit_size.1) as usize;
165 let threshold = E::new(LOGIT_MASKED);
166
167 for r in 0..num_rows {
168 let row_offset = r * num_cols;
169
170 let val = rowwise.vals[r];
171 let safe_val = clamp_min(val, threshold);
172 let not_masked = E::cast_from(val >= threshold);
173
174 for c in 0..num_cols {
175 let index = row_offset + c;
176
177 self.array[index] = not_masked * (self.array[index] - safe_val).exp();
178 }
179 }
180 }
181}
182
183#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
184pub struct LocalTileLayout {
185 pub total_size: Coords2d,
186 pub unit_size: Coords2d,
187 pub num_units_per_row: u32,
188 pub plane_dim: u32,
189}
190
191impl LocalTileLayout {
192 pub const fn new(
193 total_size: Coords2d,
194 plane_dim: u32,
195 inner_layout: InnerLayout,
196 ) -> LocalTileLayout {
197 let total_elements = total_size.0 * total_size.1;
198 let elements_per_unit = total_elements.div_ceil(plane_dim);
199
200 let (num_rows_per_unit, num_cols_per_unit) = match inner_layout {
201 InnerLayout::Contiguous => (1u32, elements_per_unit),
202 InnerLayout::SplitRows => (2u32, elements_per_unit / 2u32),
203 };
204 let unit_size = (num_rows_per_unit, num_cols_per_unit);
205 let num_units_per_row = total_size.1 / unit_size.1;
206
207 LocalTileLayout {
208 total_size,
209 unit_size,
210 num_units_per_row,
211 plane_dim,
212 }
213 }
214
215 pub const fn num_units_per_row(&self) -> u32 {
216 self.total_size.1 / self.unit_size.1
217 }
218}
219
220#[cube]
221pub fn allocate_local_tile<E: Numeric, Sc: Scope>(
224 #[comptime] layout: LocalTileLayout,
225) -> Tile<E, Sc, ReadWrite> {
226 comptime!(assert_plane_scope(Sc::KIND));
227 Tile::new_Local(LocalTile::<E>::new(layout))
228}
229
230#[cube]
233pub fn local_layout_absolute_pos(
234 #[comptime] layout: LocalTileLayout,
235 local_pos: Coords2d,
236) -> Coords2d {
237 let abs_row_index = {
238 let row_0 = UNIT_POS_X / layout.num_units_per_row;
239 let row_jump = comptime!(layout.plane_dim / layout.num_units_per_row);
240 local_pos.0 * row_jump + row_0
241 };
242 let abs_col_index = layout.unit_size.1 * (UNIT_POS_X % layout.num_units_per_row) + local_pos.1;
243 (abs_row_index, abs_col_index)
244}
245
246#[cube]
248pub fn local_layout_zero_slice<E: Numeric>(
249 #[comptime] layout: LocalTileLayout,
250 slice: &mut SliceMut<E>,
251) {
252 for r in 0..layout.unit_size.0 {
253 for c in 0..layout.unit_size.1 {
254 let (row, col) = local_layout_absolute_pos(layout, (r, c));
255 let index = row * layout.total_size.1 + col;
256
257 slice[index as usize] = E::from_int(0);
258 }
259 }
260}