cubek_std/tile/ops/
softmax.rs1use cubecl;
2use cubecl::prelude::*;
3
4use crate::StageIdent;
5use crate::tile::ops::tile_ops::{cmma_to_local, local_to_cmma};
6use crate::tile::ops::{Mask, RowWise};
7use crate::tile::variants::InnerLayout;
8use crate::tile::{Plane, Tile, TileExpand};
9
10#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
18pub enum SoftmaxKind {
19 Direct { num_rows_per_unit: u32 },
20 Plane { inner_layout: InnerLayout },
21}
22
23impl SoftmaxKind {
24 pub const fn num_rows_per_unit(&self) -> u32 {
25 match self {
26 SoftmaxKind::Direct { num_rows_per_unit } => *num_rows_per_unit,
27 SoftmaxKind::Plane { inner_layout } => match inner_layout {
28 InnerLayout::Contiguous => 1,
29 InnerLayout::SplitRows => 2,
30 },
31 }
32 }
33}
34
35#[cube]
37pub fn softmax_init_state<E: Float>(
38 #[comptime] num_rows_per_unit: u32,
39) -> (RowWise<E>, RowWise<E>) {
40 (
41 RowWise::<E>::new_min_value(num_rows_per_unit as usize),
42 RowWise::<E>::new_zero(num_rows_per_unit as usize),
43 )
44}
45
/// Softmax-related operations on an accumulator tile (`Acc` elements, plane scope).
#[cube]
impl<Acc: Float> Tile<Acc, Plane, ReadWrite> {
    /// One online-softmax step over this score tile.
    ///
    /// `state` is the running `(max, sum)` pair (initialised by
    /// `softmax_init_state` to min-value / zero); it is overwritten with the
    /// updated values on return.
    ///
    /// Order of operations:
    /// 1. `bounce_in` (for `Bounce` tiles this runs `cmma_to_local`).
    /// 2. `scale_and_mask` with `head_dim_factor` and `mask`.
    /// 3. `row_max` into a fresh buffer, presumably seeded with the previous
    ///    running max `state.0`.
    /// 4. `exp_diff` of the tile against the new max, then `row_sum`.
    /// 5. `exp_m_diff = state.0.exp_diff(new_max)` and
    ///    `new_l = exp_m_diff * state.1 + row_sum`.
    /// 6. `write_softmaxed` casts the processed scores into `softmaxed_tile`.
    /// 7. `state.0 <- new max`, `state.1 <- new_l`.
    ///
    /// Returns `exp_m_diff` (one factor per row). Presumably the caller uses it
    /// to rescale previously accumulated output. TODO confirm at call site.
    ///
    /// NOTE(review): `bounce_out` is not called here. For `Bounce` tiles, the
    /// score tile's cmma side keeps the pre-softmax values, and
    /// `write_softmaxed` reads from the local side instead.
    pub fn softmax<Lhs: Float, M: Mask>(
        &mut self,
        mask: &M,
        softmaxed_tile: &mut Tile<Lhs, Plane, ReadWrite>,
        state: &mut (RowWise<Acc>, RowWise<Acc>),
        head_dim_factor: Acc,
    ) -> RowWise<Acc> {
        // Row count is taken from the running state so the scratch buffers match it.
        let num_rows = comptime!(state.0.num_rows);
        let mut max_buf = RowWise::<Acc>::new_min_value(num_rows);
        let mut sum_buf = RowWise::<Acc>::new_zero(num_rows);

        bounce_in(self);

        self.scale_and_mask::<M>(head_dim_factor, mask);
        self.row_max(&mut max_buf, &state.0);
        self.exp_diff(&max_buf);
        self.row_sum(&mut sum_buf);

        // Must read the old running max (state.0) before it is overwritten below.
        let exp_m_diff = state.0.exp_diff(&max_buf);
        let new_l = exp_m_diff.mul(&state.1).add(&sum_buf);

        write_softmaxed(self, softmaxed_tile);

        RowWise::copy_from(&mut state.0, &max_buf);
        RowWise::copy_from(&mut state.1, &new_l);

        exp_m_diff
    }

    /// Scales each row of this tile by the matching entry of `scale`.
    ///
    /// `scale` is first cast to `Acc`. The tile is bounced in before
    /// `rowwise_scale` and bounced out afterwards. For `Bounce` tiles,
    /// `local_to_cmma` therefore writes the result back.
    pub fn scale_mul<SM: Float>(&mut self, scale: &RowWise<SM>) {
        let scale_acc = RowWise::<SM>::cast_from::<Acc>(scale);
        bounce_in(self);
        self.rowwise_scale(&scale_acc);
        bounce_out(self);
    }

    /// Divides each row of this tile by the matching entry of `running_state_l`.
    ///
    /// The values are cast to `Acc`, inverted with `recip_inplace`, and applied
    /// via `rowwise_scale`. The tile is bounced in and out around the scaling.
    pub fn scale_div<SM: Float>(&mut self, running_state_l: &RowWise<SM>) {
        let mut scale = RowWise::<SM>::cast_from::<Acc>(running_state_l);
        scale.recip_inplace();
        bounce_in(self);
        self.rowwise_scale(&scale);
        bounce_out(self);
    }

    /// Copies this tile into `dest` using `StageIdent::Out`.
    ///
    /// All intermediate element type parameters of `copy_from` are `Acc`, and
    /// `DS` is forwarded as its size parameter.
    ///
    /// NOTE(review): there is no `bounce_out` here. For `Bounce` tiles this
    /// presumably relies on a prior `scale_div`/`scale_mul` having synced the
    /// cmma side. Confirm.
    pub fn write_results<DE: Float, DS: Size>(&self, dest: &mut Tile<DE, Plane, ReadWrite>) {
        dest.copy_from::<Acc, DS, Acc, Acc, Acc, ReadWrite>(self, StageIdent::Out);
    }
}
114
115#[cube]
116fn bounce_in<E: Float>(tile: &mut Tile<E, Plane, ReadWrite>) {
117 match tile {
118 Tile::Bounce(b) => {
119 cmma_to_local::<E>(b);
120 }
121 Tile::Unit(_) => {}
122 Tile::Local(_) => {}
123 Tile::Register(_) => {}
124 _ => panic!("bounce_in: unsupported tile variant"),
125 }
126}
127
128#[cube]
129fn bounce_out<E: Float>(tile: &mut Tile<E, Plane, ReadWrite>) {
130 match tile {
131 Tile::Bounce(b) => {
132 local_to_cmma::<E>(b);
133 }
134 Tile::Unit(_) => {}
135 Tile::Local(_) => {}
136 Tile::Register(_) => {}
137 _ => panic!("bounce_out: unsupported tile variant"),
138 }
139}
140
/// Writes the processed score tile into `softmaxed_tile`, casting `Acc` -> `Lhs`.
///
/// Both tiles must be the same variant (`Register`, `Unit` or `Bounce`).
///
/// # Panics
/// Panics on any other variant combination.
#[cube]
fn write_softmaxed<Acc: Float, Lhs: Float>(
    score_tile: &Tile<Acc, Plane, ReadWrite>,
    softmaxed_tile: &mut Tile<Lhs, Plane, ReadWrite>,
) {
    match (score_tile, softmaxed_tile) {
        (Tile::Register(s), Tile::Register(d)) => {
            // Element-wise cast of m * n values. The size comes from the source
            // tile's config, so the destination is assumed to hold at least as many.
            let m = comptime!(s.config.tile_size.m());
            let n = comptime!(s.config.tile_size.n());
            for i in 0..m * n {
                d.data[i as usize] = Lhs::cast_from(s.data[i as usize]);
            }
        }
        (Tile::Unit(s), Tile::Unit(d)) => {
            // Same element-wise cast. Here the size comes from the source tile's layout.
            let m = comptime!(s.layout.num_rows);
            let n = comptime!(s.layout.num_cols);
            for i in 0..m * n {
                d.data[i as usize] = Lhs::cast_from(s.data[i as usize]);
            }
        }
        (Tile::Bounce(s), Tile::Bounce(d)) => {
            // Stage the source's local values through the destination's smem,
            // then load them into the destination cmma matrix. The load stride
            // is the destination tile's n.
            let stride = comptime!(d.cmma.tile_size.n());
            s.local.store_to(&mut d.smem);
            // Barrier: all stores must complete before any unit reads smem in the load.
            sync_cube();
            // NOTE(review): there is no barrier after the load. If d.smem is
            // rewritten before every unit has loaded, a trailing sync may be needed.
            // Confirm.
            cubecl::cmma::load(&d.cmma.matrix, &d.smem.to_slice(), stride);
        }
        _ => panic!("write_softmaxed: incompatible tile pair"),
    }
}