Skip to main content

cubek_std/tile/ops/
softmax.rs

1use cubecl;
2use cubecl::prelude::*;
3
4use crate::StageIdent;
5use crate::tile::ops::tile_ops::{cmma_to_local, local_to_cmma};
6use crate::tile::ops::{Mask, RowWise};
7use crate::tile::variants::InnerLayout;
8use crate::tile::{Plane, Tile, TileExpand};
9
10/// Comptime descriptor for the row-shape used by online softmax. Determines
11/// how many rows per unit each running-state vector holds.
12///
13/// - `Direct { num_rows_per_unit }` — used with `Tile::Unit` or `Tile::Register`
14///   when each unit owns its own copy of the tile.
15/// - `Plane { inner_layout }` — used with `Tile::Local` or `Tile::Bounce`,
16///   where the inner layout determines how many rows each unit covers.
17#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
18pub enum SoftmaxKind {
19    Direct { num_rows_per_unit: u32 },
20    Plane { inner_layout: InnerLayout },
21}
22
23impl SoftmaxKind {
24    pub const fn num_rows_per_unit(&self) -> u32 {
25        match self {
26            SoftmaxKind::Direct { num_rows_per_unit } => *num_rows_per_unit,
27            SoftmaxKind::Plane { inner_layout } => match inner_layout {
28                InnerLayout::Contiguous => 1,
29                InnerLayout::SplitRows => 2,
30            },
31        }
32    }
33}
34
35/// Initial running state `(m, l)` for the online softmax over a single tile row.
36#[cube]
37pub fn softmax_init_state<E: Float>(
38    #[comptime] num_rows_per_unit: u32,
39) -> (RowWise<E>, RowWise<E>) {
40    (
41        RowWise::<E>::new_min_value(num_rows_per_unit as usize),
42        RowWise::<E>::new_zero(num_rows_per_unit as usize),
43    )
44}
45
46#[cube]
47impl<Acc: Float> Tile<Acc, Plane, ReadWrite> {
48    /// Online softmax update over a single attention tile, fused with the
49    /// precision-cast write into a value-matmul lhs tile.
50    ///
51    /// For `Tile::Bounce`, the smem ↔ cmma sync is internal: `cmma_to_local`
52    /// once at the start so all subsequent ops read/write the local view, and
53    /// the softmaxed values are streamed straight into the destination's cmma
54    /// fragment via its own smem (no `local_to_cmma` for `self`, which is
55    /// cleared next iteration).
56    ///
57    /// Returns the per-row scaling factor `α_i = e^(m_old - m_new)` used by the
58    /// caller to rescale running output accumulators.
59    pub fn softmax<Lhs: Float, M: Mask>(
60        &mut self,
61        mask: &M,
62        softmaxed_tile: &mut Tile<Lhs, Plane, ReadWrite>,
63        state: &mut (RowWise<Acc>, RowWise<Acc>),
64        head_dim_factor: Acc,
65    ) -> RowWise<Acc> {
66        let num_rows = comptime!(state.0.num_rows);
67        let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
68        let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);
69
70        bounce_in(self);
71
72        self.scale_and_mask::<M>(head_dim_factor, mask);
73        self.row_max(&mut max_buf, &state.0);
74        self.exp_diff(&max_buf);
75        self.row_sum(&mut sum_buf);
76
77        let exp_m_diff = state.0.exp_diff(&max_buf);
78        let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);
79
80        write_softmaxed(self, softmaxed_tile);
81
82        RowWise::copy_from(&mut state.0, &max_buf);
83        RowWise::copy_from(&mut state.1, &new_l);
84
85        exp_m_diff
86    }
87
88    /// Multiplies each row of `self` by the corresponding `scale[r]`. For
89    /// `Tile::Bounce`, this round-trips through smem so the cmma fragment is
90    /// up to date for the next mma.
91    pub fn scale_mul<SM: Float>(&mut self, scale: &RowWise<SM>) {
92        let scale_acc = RowWise::<SM>::cast_from::<Acc>(scale);
93        bounce_in(self);
94        self.rowwise_scale(&scale_acc);
95        bounce_out(self);
96    }
97
98    /// Divides each row of `self` by the corresponding `running_state_l[r]`,
99    /// guarding against zero (a fully-masked row stays zero).
100    pub fn scale_div<SM: Float>(&mut self, running_state_l: &RowWise<SM>) {
101        let mut scale = RowWise::<SM>::cast_from::<Acc>(running_state_l);
102        scale.recip_inplace();
103        bounce_in(self);
104        self.rowwise_scale(&scale);
105        bounce_out(self);
106    }
107
108    /// Copies `self` into `dest` (a stage-side strided/shared tile in the
109    /// caller's downstream write path).
110    pub fn write_results<DE: Float, DS: Size>(&self, dest: &mut Tile<DE, Plane, ReadWrite>) {
111        dest.copy_from::<Acc, DS, Acc, Acc, Acc, ReadWrite>(self, StageIdent::Out);
112    }
113}
114
115#[cube]
116fn bounce_in<E: Float>(tile: &mut Tile<E, Plane, ReadWrite>) {
117    match tile {
118        Tile::Bounce(b) => {
119            cmma_to_local::<E>(b);
120        }
121        Tile::Unit(_) => {}
122        Tile::Local(_) => {}
123        Tile::Register(_) => {}
124        _ => panic!("bounce_in: unsupported tile variant"),
125    }
126}
127
128#[cube]
129fn bounce_out<E: Float>(tile: &mut Tile<E, Plane, ReadWrite>) {
130    match tile {
131        Tile::Bounce(b) => {
132            local_to_cmma::<E>(b);
133        }
134        Tile::Unit(_) => {}
135        Tile::Local(_) => {}
136        Tile::Register(_) => {}
137        _ => panic!("bounce_out: unsupported tile variant"),
138    }
139}
140
141#[cube]
142fn write_softmaxed<Acc: Float, Lhs: Float>(
143    score_tile: &Tile<Acc, Plane, ReadWrite>,
144    softmaxed_tile: &mut Tile<Lhs, Plane, ReadWrite>,
145) {
146    match (score_tile, softmaxed_tile) {
147        (Tile::Register(s), Tile::Register(d)) => {
148            let m = comptime!(s.config.tile_size.m());
149            let n = comptime!(s.config.tile_size.n());
150            for i in 0..m * n {
151                d.data[i as usize] = Lhs::cast_from(s.data[i as usize]);
152            }
153        }
154        (Tile::Unit(s), Tile::Unit(d)) => {
155            let m = comptime!(s.layout.num_rows);
156            let n = comptime!(s.layout.num_cols);
157            for i in 0..m * n {
158                d.data[i as usize] = Lhs::cast_from(s.data[i as usize]);
159            }
160        }
161        (Tile::Bounce(s), Tile::Bounce(d)) => {
162            // score's LocalTile already holds the post-exp_diff values; route
163            // through `softmaxed`'s smem to avoid clobbering score's smem and
164            // load directly into softmaxed's cmma fragment.
165            let stride = comptime!(d.cmma.tile_size.n());
166            s.local.store_to(&mut d.smem);
167            sync_cube();
168            cubecl::cmma::load(&d.cmma.matrix, &d.smem.to_slice(), stride);
169        }
170        _ => panic!("write_softmaxed: incompatible tile pair"),
171    }
172}