Skip to main content

cubek_std/tile/compute/rowwise/
dispatch.rs

1use cubecl::prelude::*;
2
3use crate::tile::{
4    LOGIT_MASKED, Plane, Tile, TileExpand,
5    compute::rowwise::reducer::{fragment_row_max, fragment_row_sum},
6    data::RowWise,
7};
8
9/// Row-wise primitives on a `Tile<E, Plane, ReadWrite>` used for attention's
10/// online softmax and output scaling. Dispatch happens per-variant:
11/// - `Tile::Unit` — each unit holds its own copy of the tile, ops run in
12///   registers.
13/// - `Tile::WhiteboxFragment` — the tile is fragmented across plane units with
14///   an exposed layout, row-reductions use `plane_shuffle`.
15/// - `Tile::Bounce` — same as `WhiteboxFragment` but the underlying compute
16///   fragment (cmma) is opaque. The row-wise ops here read/write the inner
17///   fragment view; the smem ↔ cmma synchronization is driven by the higher-
18///   level `softmax` / `scale_mul` / `scale_div` methods (see `softmax.rs`).
19/// - `Tile::Register` — kept for the legacy direct-register attention path.
20#[cube]
21impl<E: Float> Tile<E, Plane, ReadWrite> {
22    pub fn row_max(&self, acc: &mut RowWise<E>, base: &RowWise<E>) {
23        match self {
24            Tile::Unit(t) => {
25                acc.copy_from(base);
26                let m = comptime!(t.layout.num_rows);
27                let n = comptime!(t.layout.num_cols);
28                for r in 0..m as usize {
29                    let row_offset = r as u32 * n;
30                    let mut val = E::min_value();
31                    for c in 0..n {
32                        val = max(val, t.data[(row_offset + c) as usize]);
33                    }
34                    acc.vals[r] = max(acc.vals[r], val);
35                }
36            }
37            Tile::WhiteboxFragment(t) => {
38                fragment_row_max::<E>(acc, base, t);
39            }
40            Tile::Bounce(b) => {
41                fragment_row_max::<E>(acc, base, &b.fragment);
42            }
43            Tile::Register(t) => {
44                acc.copy_from(base);
45                let m = comptime!(t.config.tile_size.m());
46                let n = comptime!(t.config.tile_size.n());
47                for r in 0..m as usize {
48                    let row_offset = r as u32 * n;
49                    let mut val = E::min_value();
50                    for c in 0..n {
51                        val = max(val, t.data[(row_offset + c) as usize]);
52                    }
53                    acc.vals[r] = max(acc.vals[r], val);
54                }
55            }
56            _ => panic!("row_max: unsupported tile variant"),
57        }
58    }
59
60    pub fn row_sum(&self, acc: &mut RowWise<E>) {
61        match self {
62            Tile::Unit(t) => {
63                acc.fill(E::from_int(0));
64                let m = comptime!(t.layout.num_rows);
65                let n = comptime!(t.layout.num_cols);
66                for r in 0..m as usize {
67                    let row_offset = r as u32 * n;
68                    let mut val = E::from_int(0);
69                    for c in 0..n {
70                        val += t.data[(row_offset + c) as usize];
71                    }
72                    acc.vals[r] += val;
73                }
74            }
75            Tile::WhiteboxFragment(t) => {
76                fragment_row_sum::<E>(acc, t);
77            }
78            Tile::Bounce(b) => {
79                fragment_row_sum::<E>(acc, &b.fragment);
80            }
81            Tile::Register(t) => {
82                acc.fill(E::from_int(0));
83                let m = comptime!(t.config.tile_size.m());
84                let n = comptime!(t.config.tile_size.n());
85                for r in 0..m as usize {
86                    let row_offset = r as u32 * n;
87                    let mut val = E::from_int(0);
88                    for c in 0..n {
89                        val += t.data[(row_offset + c) as usize];
90                    }
91                    acc.vals[r] += val;
92                }
93            }
94            _ => panic!("row_sum: unsupported tile variant"),
95        }
96    }
97
98    pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
99        match self {
100            Tile::Unit(t) => t.exp_diff(rowwise),
101            Tile::WhiteboxFragment(t) => t.exp_diff(rowwise),
102            Tile::Bounce(b) => b.fragment.exp_diff(rowwise),
103            Tile::Register(t) => {
104                let m = comptime!(t.config.tile_size.m());
105                let n = comptime!(t.config.tile_size.n());
106                let threshold = E::new(LOGIT_MASKED);
107                for r in 0..m as usize {
108                    let row_offset = r as u32 * n;
109                    let val = rowwise.vals[r];
110                    let safe_val = clamp_min(val, threshold);
111                    let not_masked = E::cast_from(val >= threshold);
112                    for c in 0..n {
113                        let idx = (row_offset + c) as usize;
114                        t.data[idx] = not_masked * (t.data[idx] - safe_val).exp();
115                    }
116                }
117            }
118            _ => panic!("exp_diff: unsupported tile variant"),
119        }
120    }
121
122    pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
123        match self {
124            Tile::Unit(t) => t.rowwise_scale(scale),
125            Tile::WhiteboxFragment(t) => t.rowwise_scale(scale),
126            Tile::Bounce(b) => b.fragment.rowwise_scale(scale),
127            Tile::Register(t) => {
128                let m = comptime!(t.config.tile_size.m());
129                let n = comptime!(t.config.tile_size.n());
130                for r in 0..m as usize {
131                    let row_offset = r as u32 * n;
132                    for c in 0..n {
133                        let idx = (row_offset + c) as usize;
134                        t.data[idx] = t.data[idx] * scale.vals[r];
135                    }
136                }
137            }
138            _ => panic!("rowwise_scale: unsupported tile variant"),
139        }
140    }
141}