1use cubecl;
2use cubecl::prelude::*;
3
4use crate::StageIdent;
5use crate::tile::compute::copy::{cmma_to_whitebox_fragment, whitebox_fragment_to_cmma};
6use crate::tile::compute::mask::{Mask, MaskExpand};
7use crate::tile::compute::rowwise::reducer::{fragment_row_max, fragment_row_sum};
8use crate::tile::data::{
9 BounceTile, InnerLayout, RegisterTile, RowWise, RowWiseExpand, UnitTile, WhiteboxFragment,
10};
11use crate::tile::{Plane, Tile, TileExpand};
12
/// Cut-off below which a logit / row maximum is considered masked.
///
/// `register_softmax` in this file converts it to the accumulator type and
/// zeroes every row whose running maximum is below it.
///
/// NOTE(review): the magnitude looks chosen to stay inside the finite range of
/// `f16` (largest finite value 65504) — confirm.
pub const LOGIT_MASKED: f32 = -6.0e4;
16
/// Smallest row value for which `RowWise::recip_inplace` still produces a
/// reciprocal; rows below it are written as zero instead.
pub const FULLY_MASKED_ROW_THRESHOLD: f32 = 1.0e-4;
21
22#[cube]
23impl<E: Float> RowWise<E> {
24 pub fn recip_inplace(&mut self) {
30 for i in 0..self.num_rows {
31 let row_val = self.vals[i];
32
33 let epsilon = E::new(FULLY_MASKED_ROW_THRESHOLD);
34 let not_masked = E::cast_from(row_val >= epsilon);
35 let safe_val = clamp_min(row_val, epsilon);
36 let recip = safe_val.recip();
37 self.vals[i] = not_masked * recip;
38 }
39 }
40}
41
42#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
50pub enum SoftmaxKind {
51 Direct { num_rows_per_unit: u32 },
52 Plane { inner_layout: InnerLayout },
53}
54
55impl SoftmaxKind {
56 pub const fn num_rows_per_unit(&self) -> u32 {
57 match self {
58 SoftmaxKind::Direct { num_rows_per_unit } => *num_rows_per_unit,
59 SoftmaxKind::Plane { inner_layout } => match inner_layout {
60 InnerLayout::Contiguous => 1,
61 InnerLayout::SplitRows => 2,
62 },
63 }
64 }
65}
66
67#[cube]
69pub fn softmax_init_state<E: Float>(
70 #[comptime] num_rows_per_unit: u32,
71) -> (RowWise<E>, RowWise<E>) {
72 (
73 RowWise::<E>::new_min_value(num_rows_per_unit as usize),
74 RowWise::<E>::new_zero(num_rows_per_unit as usize),
75 )
76}
77
78#[cube]
79impl<Acc: Float> Tile<Acc, Plane, ReadWrite> {
80 pub fn softmax<Lhs: Float, M: Mask>(
90 &mut self,
91 mask: &M,
92 softmaxed_tile: &mut Tile<Lhs, Plane, ReadWrite>,
93 state: &mut (RowWise<Acc>, RowWise<Acc>),
94 head_dim_factor: Acc,
95 ) -> RowWise<Acc> {
96 match self {
97 Tile::Bounce(s) => {
98 bounce_softmax::<Acc, Lhs, M>(s, softmaxed_tile, mask, state, head_dim_factor)
99 }
100 Tile::WhiteboxFragment(s) => {
101 fragment_softmax::<Acc, Lhs, M>(s, softmaxed_tile, mask, state, head_dim_factor)
102 }
103 Tile::Unit(s) => {
104 unit_softmax::<Acc, Lhs, M>(s, softmaxed_tile, mask, state, head_dim_factor)
105 }
106 Tile::Register(s) => {
107 register_softmax::<Acc, Lhs, M>(s, softmaxed_tile, mask, state, head_dim_factor)
108 }
109 _ => panic!("softmax: unsupported score variant"),
110 }
111 }
112
113 pub fn scale_mul<SM: Float>(&mut self, scale: &RowWise<SM>) {
117 let scale_acc = RowWise::<SM>::cast_from::<Acc>(scale);
118 match self {
119 Tile::Bounce(b) => {
120 cmma_to_whitebox_fragment::<Acc>(b);
121 b.fragment.rowwise_scale(&scale_acc);
122 whitebox_fragment_to_cmma::<Acc>(b);
123 }
124 Tile::WhiteboxFragment(t) => t.rowwise_scale(&scale_acc),
125 Tile::Unit(t) => t.rowwise_scale(&scale_acc),
126 Tile::Register(t) => register_rowwise_scale::<Acc>(t, &scale_acc),
127 _ => panic!("scale_mul: unsupported tile variant"),
128 }
129 }
130
131 pub fn scale_div<SM: Float>(&mut self, running_state_l: &RowWise<SM>) {
134 let mut scale = RowWise::<SM>::cast_from::<Acc>(running_state_l);
135 scale.recip_inplace();
136 match self {
137 Tile::Bounce(b) => {
138 cmma_to_whitebox_fragment::<Acc>(b);
139 b.fragment.rowwise_scale(&scale);
140 whitebox_fragment_to_cmma::<Acc>(b);
141 }
142 Tile::WhiteboxFragment(t) => t.rowwise_scale(&scale),
143 Tile::Unit(t) => t.rowwise_scale(&scale),
144 Tile::Register(t) => register_rowwise_scale::<Acc>(t, &scale),
145 _ => panic!("scale_div: unsupported tile variant"),
146 }
147 }
148
149 pub fn write_results<DE: Float, DS: Size>(&self, dest: &mut Tile<DE, Plane, ReadWrite>) {
152 dest.copy_from::<Acc, DS, Acc, Acc, Acc, ReadWrite>(self, StageIdent::Out);
153 }
154}
155
156#[cube]
157fn bounce_softmax<Acc: Float, Lhs: Float, M: Mask>(
158 score: &mut BounceTile<Acc>,
159 softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
160 mask: &M,
161 state: &mut (RowWise<Acc>, RowWise<Acc>),
162 head_dim_factor: Acc,
163) -> RowWise<Acc> {
164 let num_rows = comptime!(state.0.num_rows);
165 let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
166 let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);
167
168 cmma_to_whitebox_fragment::<Acc>(score);
171
172 score.fragment.scale_and_mask::<M>(head_dim_factor, mask);
173 fragment_row_max::<Acc>(&mut max_buf, &state.0, &score.fragment);
174 score.fragment.exp_diff(&max_buf);
175 fragment_row_sum::<Acc>(&mut sum_buf, &score.fragment);
176
177 let exp_m_diff = state.0.exp_diff(&max_buf);
178 let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);
179
180 write_fragment_into::<Acc, Lhs>(&score.fragment, softmaxed);
184
185 RowWise::copy_from(&mut state.0, &max_buf);
186 RowWise::copy_from(&mut state.1, &new_l);
187
188 exp_m_diff
189}
190
191#[cube]
192fn fragment_softmax<Acc: Float, Lhs: Float, M: Mask>(
193 score: &mut WhiteboxFragment<Acc>,
194 softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
195 mask: &M,
196 state: &mut (RowWise<Acc>, RowWise<Acc>),
197 head_dim_factor: Acc,
198) -> RowWise<Acc> {
199 let num_rows = comptime!(state.0.num_rows);
200 let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
201 let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);
202
203 score.scale_and_mask::<M>(head_dim_factor, mask);
204 fragment_row_max::<Acc>(&mut max_buf, &state.0, score);
205 score.exp_diff(&max_buf);
206 fragment_row_sum::<Acc>(&mut sum_buf, score);
207
208 let exp_m_diff = state.0.exp_diff(&max_buf);
209 let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);
210
211 write_fragment_into::<Acc, Lhs>(score, softmaxed);
212
213 RowWise::copy_from(&mut state.0, &max_buf);
214 RowWise::copy_from(&mut state.1, &new_l);
215
216 exp_m_diff
217}
218
219#[cube]
220fn unit_softmax<Acc: Float, Lhs: Float, M: Mask>(
221 score: &mut UnitTile<Acc>,
222 softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
223 mask: &M,
224 state: &mut (RowWise<Acc>, RowWise<Acc>),
225 head_dim_factor: Acc,
226) -> RowWise<Acc> {
227 let num_rows = comptime!(state.0.num_rows);
228 let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
229 let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);
230
231 score.scale_and_mask::<M>(head_dim_factor, mask);
232
233 max_buf.copy_from(&state.0);
234 max_buf.max_inplace(&score.rowwise_max());
235
236 score.exp_diff(&max_buf);
237
238 sum_buf.add_inplace(&score.rowwise_sum());
239
240 let exp_m_diff = state.0.exp_diff(&max_buf);
241 let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);
242
243 match softmaxed {
244 Tile::Unit(d) => write_unit_into::<Acc, Lhs>(score, d),
245 Tile::Bounce(_) => panic!("unit_softmax: Bounce destination not supported"),
246 Tile::WhiteboxFragment(_) => {
247 panic!("unit_softmax: WhiteboxFragment destination not supported")
248 }
249 Tile::Register(_) => panic!("unit_softmax: Register destination not supported"),
250 _ => panic!("unit_softmax: unsupported softmaxed variant"),
251 }
252
253 RowWise::copy_from(&mut state.0, &max_buf);
254 RowWise::copy_from(&mut state.1, &new_l);
255
256 exp_m_diff
257}
258
259#[cube]
260fn register_softmax<Acc: Float, Lhs: Float, M: Mask>(
261 score: &mut RegisterTile<Acc>,
262 softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
263 mask: &M,
264 state: &mut (RowWise<Acc>, RowWise<Acc>),
265 head_dim_factor: Acc,
266) -> RowWise<Acc> {
267 let m = comptime!(score.config.tile_size.m());
268 let n = comptime!(score.config.tile_size.n());
269 let num_rows = comptime!(state.0.num_rows);
270 let threshold = Acc::new(LOGIT_MASKED);
271
272 let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
273 let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);
274
275 for r in 0..m {
276 let row_offset = r * n;
277 for c in 0..n {
278 let idx = (row_offset + c) as usize;
279 score.data[idx] = score.data[idx] * head_dim_factor
280 + Acc::cast_from(mask.should_mask((r, c))) * Acc::min_value();
281 }
282 }
283
284 max_buf.copy_from(&state.0);
285 for r in 0..m as usize {
286 let row_offset = r as u32 * n;
287 let mut val = Acc::min_value();
288 for c in 0..n {
289 val = max(val, score.data[(row_offset + c) as usize]);
290 }
291 max_buf.vals[r] = max(max_buf.vals[r], val);
292 }
293
294 for r in 0..m as usize {
295 let row_offset = r as u32 * n;
296 let val = max_buf.vals[r];
297 let safe_val = clamp_min(val, threshold);
298 let not_masked = Acc::cast_from(val >= threshold);
299 for c in 0..n {
300 let idx = (row_offset + c) as usize;
301 score.data[idx] = not_masked * (score.data[idx] - safe_val).exp();
302 }
303 }
304
305 for r in 0..m as usize {
306 let row_offset = r as u32 * n;
307 let mut val = Acc::from_int(0);
308 for c in 0..n {
309 val += score.data[(row_offset + c) as usize];
310 }
311 sum_buf.vals[r] += val;
312 }
313
314 let exp_m_diff = state.0.exp_diff(&max_buf);
315 let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);
316
317 match softmaxed {
318 Tile::Register(d) => write_register_into::<Acc, Lhs>(score, d),
319 Tile::Bounce(_) => panic!("register_softmax: Bounce destination not supported"),
320 Tile::WhiteboxFragment(_) => {
321 panic!("register_softmax: WhiteboxFragment destination not supported")
322 }
323 Tile::Unit(_) => panic!("register_softmax: Unit destination not supported"),
324 _ => panic!("register_softmax: unsupported softmaxed variant"),
325 }
326
327 RowWise::copy_from(&mut state.0, &max_buf);
328 RowWise::copy_from(&mut state.1, &new_l);
329
330 exp_m_diff
331}
332
333#[cube]
338fn write_fragment_into<Acc: Float, Lhs: Float>(
339 src: &WhiteboxFragment<Acc>,
340 softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
341) {
342 match softmaxed {
343 Tile::Bounce(d) => {
344 let stride = comptime!(d.cmma.tile_size.n());
345 src.store_to(&mut d.smem);
346 sync_cube();
347 cubecl::cmma::load(&d.cmma.matrix, &d.smem.to_slice(), stride);
348 }
349 Tile::WhiteboxFragment(d) => {
350 let total = comptime!(src.layout.unit_size.0 * src.layout.unit_size.1);
351 for i in 0..total {
352 d.array[i as usize] = Lhs::cast_from(src.array[i as usize]);
353 }
354 }
355 _ => panic!("write_fragment_into: unsupported softmaxed variant"),
356 }
357}
358
359#[cube]
360fn write_unit_into<Acc: Float, Lhs: Float>(src: &UnitTile<Acc>, dest: &mut UnitTile<Lhs>) {
361 let total = comptime!(src.layout.num_rows * src.layout.num_cols);
362 for i in 0..total {
363 dest.data[i as usize] = Lhs::cast_from(src.data[i as usize]);
364 }
365}
366
367#[cube]
368fn write_register_into<Acc: Float, Lhs: Float>(
369 src: &RegisterTile<Acc>,
370 dest: &mut RegisterTile<Lhs>,
371) {
372 let m = comptime!(src.config.tile_size.m());
373 let n = comptime!(src.config.tile_size.n());
374 for i in 0..m * n {
375 dest.data[i as usize] = Lhs::cast_from(src.data[i as usize]);
376 }
377}
378
379#[cube]
380fn register_rowwise_scale<E: Float>(tile: &mut RegisterTile<E>, scale: &RowWise<E>) {
381 let m = comptime!(tile.config.tile_size.m());
382 let n = comptime!(tile.config.tile_size.n());
383 for r in 0..m as usize {
384 let row_offset = r as u32 * n;
385 for c in 0..n {
386 let idx = (row_offset + c) as usize;
387 tile.data[idx] = tile.data[idx] * scale.vals[r];
388 }
389 }
390}