Skip to main content

cubek_std/tile/compute/
softmax.rs

1use cubecl;
2use cubecl::prelude::*;
3
4use crate::StageIdent;
5use crate::tile::compute::copy::{cmma_to_whitebox_fragment, whitebox_fragment_to_cmma};
6use crate::tile::compute::mask::{Mask, MaskExpand};
7use crate::tile::compute::rowwise::reducer::{fragment_row_max, fragment_row_sum};
8use crate::tile::data::{
9    BounceTile, InnerLayout, RegisterTile, RowWise, RowWiseExpand, UnitTile, WhiteboxFragment,
10};
11use crate::tile::{Plane, Tile, TileExpand};
12
/// Logits below this are considered masked (effectively `-inf`).
///
/// Chosen to stay representable in `f16`: its largest finite magnitude is
/// 65,504, so `-6e4` lies just inside the most negative finite `f16` value.
pub const LOGIT_MASKED: f32 = -6e4;

/// Any value smaller than this is considered numerically zero (used for
/// fully-masked rows or tiny contributions).
///
/// Chosen to lie above the smallest positive normal `f16` value
/// (`2^-14`, ~6.1e-5).
pub const FULLY_MASKED_ROW_THRESHOLD: f32 = 1e-4;
21
22#[cube]
23impl<E: Float> RowWise<E> {
24    /// Replaces each value `v` (v >= 0) in a row with `1/v`.
25    ///
26    /// If `v = 0`, the result is set to `0` instead of `1/0`.
27    /// This occurs when the entire row is masked, meaning it should
28    /// contribute no information, and ensures numerical stability.
29    pub fn recip_inplace(&mut self) {
30        for i in 0..self.num_rows {
31            let row_val = self.vals[i];
32
33            let epsilon = E::new(FULLY_MASKED_ROW_THRESHOLD);
34            let not_masked = E::cast_from(row_val >= epsilon);
35            let safe_val = clamp_min(row_val, epsilon);
36            let recip = safe_val.recip();
37            self.vals[i] = not_masked * recip;
38        }
39    }
40}
41
42/// Comptime descriptor for the row-shape used by online softmax. Determines
43/// how many rows per unit each running-state vector holds.
44///
45/// - `Direct { num_rows_per_unit }` — used with `Tile::Unit` or `Tile::Register`
46///   when each unit owns its own copy of the tile.
47/// - `Plane { inner_layout }` — used with `Tile::WhiteboxFragment` or `Tile::Bounce`,
48///   where the inner layout determines how many rows each unit covers.
49#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
50pub enum SoftmaxKind {
51    Direct { num_rows_per_unit: u32 },
52    Plane { inner_layout: InnerLayout },
53}
54
55impl SoftmaxKind {
56    pub const fn num_rows_per_unit(&self) -> u32 {
57        match self {
58            SoftmaxKind::Direct { num_rows_per_unit } => *num_rows_per_unit,
59            SoftmaxKind::Plane { inner_layout } => match inner_layout {
60                InnerLayout::Contiguous => 1,
61                InnerLayout::SplitRows => 2,
62            },
63        }
64    }
65}
66
67/// Initial running state `(m, l)` for the online softmax over a single tile row.
68#[cube]
69pub fn softmax_init_state<E: Float>(
70    #[comptime] num_rows_per_unit: u32,
71) -> (RowWise<E>, RowWise<E>) {
72    (
73        RowWise::<E>::new_min_value(num_rows_per_unit as usize),
74        RowWise::<E>::new_zero(num_rows_per_unit as usize),
75    )
76}
77
78#[cube]
79impl<Acc: Float> Tile<Acc, Plane, ReadWrite> {
80    /// Online softmax update over a single attention tile, fused with the
81    /// precision-cast write into a value-matmul lhs tile. Dispatches on the
82    /// score variant — each variant owns the algorithm best suited to its
83    /// storage and is polymorphic in the destination: a `Bounce` score can
84    /// be written into any compatible softmaxed tile (Bounce, fragment, …),
85    /// not just another `Bounce`.
86    ///
87    /// Returns the per-row scaling factor `α_i = e^(m_old - m_new)` used by the
88    /// caller to rescale running output accumulators.
89    pub fn softmax<Lhs: Float, M: Mask>(
90        &mut self,
91        mask: &M,
92        softmaxed_tile: &mut Tile<Lhs, Plane, ReadWrite>,
93        state: &mut (RowWise<Acc>, RowWise<Acc>),
94        head_dim_factor: Acc,
95    ) -> RowWise<Acc> {
96        match self {
97            Tile::Bounce(s) => {
98                bounce_softmax::<Acc, Lhs, M>(s, softmaxed_tile, mask, state, head_dim_factor)
99            }
100            Tile::WhiteboxFragment(s) => {
101                fragment_softmax::<Acc, Lhs, M>(s, softmaxed_tile, mask, state, head_dim_factor)
102            }
103            Tile::Unit(s) => {
104                unit_softmax::<Acc, Lhs, M>(s, softmaxed_tile, mask, state, head_dim_factor)
105            }
106            Tile::Register(s) => {
107                register_softmax::<Acc, Lhs, M>(s, softmaxed_tile, mask, state, head_dim_factor)
108            }
109            _ => panic!("softmax: unsupported score variant"),
110        }
111    }
112
113    /// Multiplies each row of `self` by the corresponding `scale[r]`. The
114    /// `Bounce` arm round-trips through smem so the cmma fragment is current
115    /// for the next mma; the others operate in place on their native storage.
116    pub fn scale_mul<SM: Float>(&mut self, scale: &RowWise<SM>) {
117        let scale_acc = RowWise::<SM>::cast_from::<Acc>(scale);
118        match self {
119            Tile::Bounce(b) => {
120                cmma_to_whitebox_fragment::<Acc>(b);
121                b.fragment.rowwise_scale(&scale_acc);
122                whitebox_fragment_to_cmma::<Acc>(b);
123            }
124            Tile::WhiteboxFragment(t) => t.rowwise_scale(&scale_acc),
125            Tile::Unit(t) => t.rowwise_scale(&scale_acc),
126            Tile::Register(t) => register_rowwise_scale::<Acc>(t, &scale_acc),
127            _ => panic!("scale_mul: unsupported tile variant"),
128        }
129    }
130
131    /// Divides each row of `self` by the corresponding `running_state_l[r]`,
132    /// guarding against zero (a fully-masked row stays zero).
133    pub fn scale_div<SM: Float>(&mut self, running_state_l: &RowWise<SM>) {
134        let mut scale = RowWise::<SM>::cast_from::<Acc>(running_state_l);
135        scale.recip_inplace();
136        match self {
137            Tile::Bounce(b) => {
138                cmma_to_whitebox_fragment::<Acc>(b);
139                b.fragment.rowwise_scale(&scale);
140                whitebox_fragment_to_cmma::<Acc>(b);
141            }
142            Tile::WhiteboxFragment(t) => t.rowwise_scale(&scale),
143            Tile::Unit(t) => t.rowwise_scale(&scale),
144            Tile::Register(t) => register_rowwise_scale::<Acc>(t, &scale),
145            _ => panic!("scale_div: unsupported tile variant"),
146        }
147    }
148
149    /// Copies `self` into `dest` (a stage-side strided/shared tile in the
150    /// caller's downstream write path).
151    pub fn write_results<DE: Float, DS: Size>(&self, dest: &mut Tile<DE, Plane, ReadWrite>) {
152        dest.copy_from::<Acc, DS, Acc, Acc, Acc, ReadWrite>(self, StageIdent::Out);
153    }
154}
155
156#[cube]
157fn bounce_softmax<Acc: Float, Lhs: Float, M: Mask>(
158    score: &mut BounceTile<Acc>,
159    softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
160    mask: &M,
161    state: &mut (RowWise<Acc>, RowWise<Acc>),
162    head_dim_factor: Acc,
163) -> RowWise<Acc> {
164    let num_rows = comptime!(state.0.num_rows);
165    let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
166    let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);
167
168    // cmma → fragment once at entry so all subsequent ops read/write the
169    // fragment view.
170    cmma_to_whitebox_fragment::<Acc>(score);
171
172    score.fragment.scale_and_mask::<M>(head_dim_factor, mask);
173    fragment_row_max::<Acc>(&mut max_buf, &state.0, &score.fragment);
174    score.fragment.exp_diff(&max_buf);
175    fragment_row_sum::<Acc>(&mut sum_buf, &score.fragment);
176
177    let exp_m_diff = state.0.exp_diff(&max_buf);
178    let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);
179
180    // The post-exp values are still in `score.fragment` — we skip
181    // `whitebox_fragment_to_cmma` on score (its cmma is cleared next
182    // iteration) and stream the values straight into `softmaxed`.
183    write_fragment_into::<Acc, Lhs>(&score.fragment, softmaxed);
184
185    RowWise::copy_from(&mut state.0, &max_buf);
186    RowWise::copy_from(&mut state.1, &new_l);
187
188    exp_m_diff
189}
190
191#[cube]
192fn fragment_softmax<Acc: Float, Lhs: Float, M: Mask>(
193    score: &mut WhiteboxFragment<Acc>,
194    softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
195    mask: &M,
196    state: &mut (RowWise<Acc>, RowWise<Acc>),
197    head_dim_factor: Acc,
198) -> RowWise<Acc> {
199    let num_rows = comptime!(state.0.num_rows);
200    let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
201    let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);
202
203    score.scale_and_mask::<M>(head_dim_factor, mask);
204    fragment_row_max::<Acc>(&mut max_buf, &state.0, score);
205    score.exp_diff(&max_buf);
206    fragment_row_sum::<Acc>(&mut sum_buf, score);
207
208    let exp_m_diff = state.0.exp_diff(&max_buf);
209    let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);
210
211    write_fragment_into::<Acc, Lhs>(score, softmaxed);
212
213    RowWise::copy_from(&mut state.0, &max_buf);
214    RowWise::copy_from(&mut state.1, &new_l);
215
216    exp_m_diff
217}
218
219#[cube]
220fn unit_softmax<Acc: Float, Lhs: Float, M: Mask>(
221    score: &mut UnitTile<Acc>,
222    softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
223    mask: &M,
224    state: &mut (RowWise<Acc>, RowWise<Acc>),
225    head_dim_factor: Acc,
226) -> RowWise<Acc> {
227    let num_rows = comptime!(state.0.num_rows);
228    let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
229    let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);
230
231    score.scale_and_mask::<M>(head_dim_factor, mask);
232
233    max_buf.copy_from(&state.0);
234    max_buf.max_inplace(&score.rowwise_max());
235
236    score.exp_diff(&max_buf);
237
238    sum_buf.add_inplace(&score.rowwise_sum());
239
240    let exp_m_diff = state.0.exp_diff(&max_buf);
241    let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);
242
243    match softmaxed {
244        Tile::Unit(d) => write_unit_into::<Acc, Lhs>(score, d),
245        Tile::Bounce(_) => panic!("unit_softmax: Bounce destination not supported"),
246        Tile::WhiteboxFragment(_) => {
247            panic!("unit_softmax: WhiteboxFragment destination not supported")
248        }
249        Tile::Register(_) => panic!("unit_softmax: Register destination not supported"),
250        _ => panic!("unit_softmax: unsupported softmaxed variant"),
251    }
252
253    RowWise::copy_from(&mut state.0, &max_buf);
254    RowWise::copy_from(&mut state.1, &new_l);
255
256    exp_m_diff
257}
258
259#[cube]
260fn register_softmax<Acc: Float, Lhs: Float, M: Mask>(
261    score: &mut RegisterTile<Acc>,
262    softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
263    mask: &M,
264    state: &mut (RowWise<Acc>, RowWise<Acc>),
265    head_dim_factor: Acc,
266) -> RowWise<Acc> {
267    let m = comptime!(score.config.tile_size.m());
268    let n = comptime!(score.config.tile_size.n());
269    let num_rows = comptime!(state.0.num_rows);
270    let threshold = Acc::new(LOGIT_MASKED);
271
272    let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
273    let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);
274
275    for r in 0..m {
276        let row_offset = r * n;
277        for c in 0..n {
278            let idx = (row_offset + c) as usize;
279            score.data[idx] = score.data[idx] * head_dim_factor
280                + Acc::cast_from(mask.should_mask((r, c))) * Acc::min_value();
281        }
282    }
283
284    max_buf.copy_from(&state.0);
285    for r in 0..m as usize {
286        let row_offset = r as u32 * n;
287        let mut val = Acc::min_value();
288        for c in 0..n {
289            val = max(val, score.data[(row_offset + c) as usize]);
290        }
291        max_buf.vals[r] = max(max_buf.vals[r], val);
292    }
293
294    for r in 0..m as usize {
295        let row_offset = r as u32 * n;
296        let val = max_buf.vals[r];
297        let safe_val = clamp_min(val, threshold);
298        let not_masked = Acc::cast_from(val >= threshold);
299        for c in 0..n {
300            let idx = (row_offset + c) as usize;
301            score.data[idx] = not_masked * (score.data[idx] - safe_val).exp();
302        }
303    }
304
305    for r in 0..m as usize {
306        let row_offset = r as u32 * n;
307        let mut val = Acc::from_int(0);
308        for c in 0..n {
309            val += score.data[(row_offset + c) as usize];
310        }
311        sum_buf.vals[r] += val;
312    }
313
314    let exp_m_diff = state.0.exp_diff(&max_buf);
315    let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);
316
317    match softmaxed {
318        Tile::Register(d) => write_register_into::<Acc, Lhs>(score, d),
319        Tile::Bounce(_) => panic!("register_softmax: Bounce destination not supported"),
320        Tile::WhiteboxFragment(_) => {
321            panic!("register_softmax: WhiteboxFragment destination not supported")
322        }
323        Tile::Unit(_) => panic!("register_softmax: Unit destination not supported"),
324        _ => panic!("register_softmax: unsupported softmaxed variant"),
325    }
326
327    RowWise::copy_from(&mut state.0, &max_buf);
328    RowWise::copy_from(&mut state.1, &new_l);
329
330    exp_m_diff
331}
332
333/// Writes a `WhiteboxFragment` of post-softmax values into `softmaxed`,
334/// dispatching on the destination variant. The source is plane-fragmented so
335/// each unit only writes its slice; for a `Bounce` destination this routes
336/// directly through its smem into its cmma fragment.
337#[cube]
338fn write_fragment_into<Acc: Float, Lhs: Float>(
339    src: &WhiteboxFragment<Acc>,
340    softmaxed: &mut Tile<Lhs, Plane, ReadWrite>,
341) {
342    match softmaxed {
343        Tile::Bounce(d) => {
344            let stride = comptime!(d.cmma.tile_size.n());
345            src.store_to(&mut d.smem);
346            sync_cube();
347            cubecl::cmma::load(&d.cmma.matrix, &d.smem.to_slice(), stride);
348        }
349        Tile::WhiteboxFragment(d) => {
350            let total = comptime!(src.layout.unit_size.0 * src.layout.unit_size.1);
351            for i in 0..total {
352                d.array[i as usize] = Lhs::cast_from(src.array[i as usize]);
353            }
354        }
355        _ => panic!("write_fragment_into: unsupported softmaxed variant"),
356    }
357}
358
359#[cube]
360fn write_unit_into<Acc: Float, Lhs: Float>(src: &UnitTile<Acc>, dest: &mut UnitTile<Lhs>) {
361    let total = comptime!(src.layout.num_rows * src.layout.num_cols);
362    for i in 0..total {
363        dest.data[i as usize] = Lhs::cast_from(src.data[i as usize]);
364    }
365}
366
367#[cube]
368fn write_register_into<Acc: Float, Lhs: Float>(
369    src: &RegisterTile<Acc>,
370    dest: &mut RegisterTile<Lhs>,
371) {
372    let m = comptime!(src.config.tile_size.m());
373    let n = comptime!(src.config.tile_size.n());
374    for i in 0..m * n {
375        dest.data[i as usize] = Lhs::cast_from(src.data[i as usize]);
376    }
377}
378
379#[cube]
380fn register_rowwise_scale<E: Float>(tile: &mut RegisterTile<E>, scale: &RowWise<E>) {
381    let m = comptime!(tile.config.tile_size.m());
382    let n = comptime!(tile.config.tile_size.n());
383    for r in 0..m as usize {
384        let row_offset = r as u32 * n;
385        for c in 0..n {
386            let idx = (row_offset + c) as usize;
387            tile.data[idx] = tile.data[idx] * scale.vals[r];
388        }
389    }
390}