// cubek_std/tile/ops/tile_ops.rs

1use cubecl;
2use cubecl::prelude::*;
3
4use crate::tile::ops::broadcast_reducer::{local_row_max, local_row_sum};
5use crate::tile::ops::{LOGIT_MASKED, Mask, MaskExpand, RowWise};
6use crate::tile::variants::BounceTile;
7use crate::tile::{Plane, Tile, TileExpand};
8
9/// Row-wise primitives on a `Tile<E, Plane, ReadWrite>` used for attention's
10/// online softmax and output scaling. Dispatch happens per-variant:
11/// - `Tile::Unit` — each unit holds its own copy of the tile, ops run in
12///   registers.
13/// - `Tile::Local` — the tile is fragmented across plane units, row-reductions
14///   use `plane_shuffle`.
15/// - `Tile::Bounce` — same as `Local` but the underlying compute fragment
16///   (cmma) is opaque. The row-wise ops here read/write the inner `LocalTile`;
17///   the smem ↔ cmma synchronization is driven by the higher-level
18///   `softmax` / `scale_mul` / `scale_div` methods (see `ops/softmax.rs`).
19/// - `Tile::Register` — kept for the legacy direct-register attention path.
20#[cube]
21impl<E: Float> Tile<E, Plane, ReadWrite> {
22    pub fn row_max(&self, acc: &mut RowWise<E>, base: &RowWise<E>) {
23        match self {
24            Tile::Unit(t) => {
25                acc.copy_from(base);
26                let m = comptime!(t.layout.num_rows);
27                let n = comptime!(t.layout.num_cols);
28                for r in 0..m as usize {
29                    let row_offset = r as u32 * n;
30                    let mut val = E::min_value();
31                    for c in 0..n {
32                        val = max(val, t.data[(row_offset + c) as usize]);
33                    }
34                    acc.vals[r] = max(acc.vals[r], val);
35                }
36            }
37            Tile::Local(t) => {
38                local_row_max::<E>(acc, base, t);
39            }
40            Tile::Bounce(b) => {
41                local_row_max::<E>(acc, base, &b.local);
42            }
43            Tile::Register(t) => {
44                acc.copy_from(base);
45                let m = comptime!(t.config.tile_size.m());
46                let n = comptime!(t.config.tile_size.n());
47                for r in 0..m as usize {
48                    let row_offset = r as u32 * n;
49                    let mut val = E::min_value();
50                    for c in 0..n {
51                        val = max(val, t.data[(row_offset + c) as usize]);
52                    }
53                    acc.vals[r] = max(acc.vals[r], val);
54                }
55            }
56            _ => panic!("row_max: unsupported tile variant"),
57        }
58    }
59
60    pub fn row_sum(&self, acc: &mut RowWise<E>) {
61        match self {
62            Tile::Unit(t) => {
63                acc.fill(E::from_int(0));
64                let m = comptime!(t.layout.num_rows);
65                let n = comptime!(t.layout.num_cols);
66                for r in 0..m as usize {
67                    let row_offset = r as u32 * n;
68                    let mut val = E::from_int(0);
69                    for c in 0..n {
70                        val += t.data[(row_offset + c) as usize];
71                    }
72                    acc.vals[r] += val;
73                }
74            }
75            Tile::Local(t) => {
76                local_row_sum::<E>(acc, t);
77            }
78            Tile::Bounce(b) => {
79                local_row_sum::<E>(acc, &b.local);
80            }
81            Tile::Register(t) => {
82                acc.fill(E::from_int(0));
83                let m = comptime!(t.config.tile_size.m());
84                let n = comptime!(t.config.tile_size.n());
85                for r in 0..m as usize {
86                    let row_offset = r as u32 * n;
87                    let mut val = E::from_int(0);
88                    for c in 0..n {
89                        val += t.data[(row_offset + c) as usize];
90                    }
91                    acc.vals[r] += val;
92                }
93            }
94            _ => panic!("row_sum: unsupported tile variant"),
95        }
96    }
97
98    pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
99        match self {
100            Tile::Unit(t) => t.exp_diff(rowwise),
101            Tile::Local(t) => t.exp_diff(rowwise),
102            Tile::Bounce(b) => b.local.exp_diff(rowwise),
103            Tile::Register(t) => {
104                let m = comptime!(t.config.tile_size.m());
105                let n = comptime!(t.config.tile_size.n());
106                let threshold = E::new(LOGIT_MASKED);
107                for r in 0..m as usize {
108                    let row_offset = r as u32 * n;
109                    let val = rowwise.vals[r];
110                    let safe_val = clamp_min(val, threshold);
111                    let not_masked = E::cast_from(val >= threshold);
112                    for c in 0..n {
113                        let idx = (row_offset + c) as usize;
114                        t.data[idx] = not_masked * (t.data[idx] - safe_val).exp();
115                    }
116                }
117            }
118            _ => panic!("exp_diff: unsupported tile variant"),
119        }
120    }
121
122    pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
123        match self {
124            Tile::Unit(t) => t.rowwise_scale(scale),
125            Tile::Local(t) => t.rowwise_scale(scale),
126            Tile::Bounce(b) => b.local.rowwise_scale(scale),
127            Tile::Register(t) => {
128                let m = comptime!(t.config.tile_size.m());
129                let n = comptime!(t.config.tile_size.n());
130                for r in 0..m as usize {
131                    let row_offset = r as u32 * n;
132                    for c in 0..n {
133                        let idx = (row_offset + c) as usize;
134                        t.data[idx] = t.data[idx] * scale.vals[r];
135                    }
136                }
137            }
138            _ => panic!("rowwise_scale: unsupported tile variant"),
139        }
140    }
141
142    pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
143        match self {
144            Tile::Unit(t) => t.scale_and_mask::<M>(scale, mask),
145            Tile::Local(t) => t.scale_and_mask::<M>(scale, mask),
146            Tile::Bounce(b) => b.local.scale_and_mask::<M>(scale, mask),
147            Tile::Register(t) => {
148                let m = comptime!(t.config.tile_size.m());
149                let n = comptime!(t.config.tile_size.n());
150                for r in 0..m {
151                    let row_offset = r * n;
152                    for c in 0..n {
153                        let idx = (row_offset + c) as usize;
154                        t.data[idx] = t.data[idx] * scale
155                            + E::cast_from(mask.should_mask((r, c))) * E::min_value();
156                    }
157                }
158            }
159            _ => panic!("scale_and_mask: unsupported tile variant"),
160        }
161    }
162
163    pub fn fill_zero(&mut self) {
164        match self {
165            Tile::Register(t) => {
166                let m = comptime!(t.config.tile_size.m());
167                let n = comptime!(t.config.tile_size.n());
168                for i in 0..m * n {
169                    t.data[i as usize] = E::from_int(0);
170                }
171            }
172            Tile::Unit(t) => t.zero(),
173            Tile::Local(t) => t.zero(),
174            Tile::Bounce(b) => {
175                cubecl::cmma::fill(&b.cmma.matrix, E::from_int(0));
176            }
177            Tile::Cmma(t) => {
178                cubecl::cmma::fill(&t.matrix, E::from_int(0));
179            }
180            _ => panic!("fill_zero: unsupported tile variant"),
181        }
182    }
183}
184
185/// Internal `copy_from` between the `cmma` and `local` parts of a [`BounceTile`]:
186/// cmma -> smem -> local. Used by the high-level `softmax` / `scale_mul` /
187/// `scale_div` methods to make the local view current.
188#[cube]
189pub(crate) fn cmma_to_local<E: Float>(b: &mut BounceTile<E>) {
190    let stride = comptime!(b.cmma.tile_size.n());
191    cubecl::cmma::store(
192        &mut b.smem,
193        &b.cmma.matrix,
194        stride,
195        cubecl::cmma::MatrixLayout::RowMajor,
196    );
197    sync_cube();
198    b.local.load_from_slice(&b.smem.to_slice());
199    sync_cube();
200}
201
202/// Internal `copy_from` between the `local` and `cmma` parts of a [`BounceTile`]:
203/// local -> smem -> cmma. Reverses [`cmma_to_local`].
204#[cube]
205pub(crate) fn local_to_cmma<E: Float>(b: &mut BounceTile<E>) {
206    let stride = comptime!(b.cmma.tile_size.n());
207    b.local.store_to(&mut b.smem);
208    sync_cube();
209    cubecl::cmma::load_with_layout(
210        &b.cmma.matrix,
211        &b.smem.to_slice(),
212        stride,
213        cubecl::cmma::MatrixLayout::RowMajor,
214    );
215}