cubek_std/tile/ops/
tile_ops.rs1use cubecl;
2use cubecl::prelude::*;
3
4use crate::tile::ops::broadcast_reducer::{local_row_max, local_row_sum};
5use crate::tile::ops::{LOGIT_MASKED, Mask, MaskExpand, RowWise};
6use crate::tile::variants::BounceTile;
7use crate::tile::{Plane, Tile, TileExpand};
8
9#[cube]
21impl<E: Float> Tile<E, Plane, ReadWrite> {
22 pub fn row_max(&self, acc: &mut RowWise<E>, base: &RowWise<E>) {
23 match self {
24 Tile::Unit(t) => {
25 acc.copy_from(base);
26 let m = comptime!(t.layout.num_rows);
27 let n = comptime!(t.layout.num_cols);
28 for r in 0..m as usize {
29 let row_offset = r as u32 * n;
30 let mut val = E::min_value();
31 for c in 0..n {
32 val = max(val, t.data[(row_offset + c) as usize]);
33 }
34 acc.vals[r] = max(acc.vals[r], val);
35 }
36 }
37 Tile::Local(t) => {
38 local_row_max::<E>(acc, base, t);
39 }
40 Tile::Bounce(b) => {
41 local_row_max::<E>(acc, base, &b.local);
42 }
43 Tile::Register(t) => {
44 acc.copy_from(base);
45 let m = comptime!(t.config.tile_size.m());
46 let n = comptime!(t.config.tile_size.n());
47 for r in 0..m as usize {
48 let row_offset = r as u32 * n;
49 let mut val = E::min_value();
50 for c in 0..n {
51 val = max(val, t.data[(row_offset + c) as usize]);
52 }
53 acc.vals[r] = max(acc.vals[r], val);
54 }
55 }
56 _ => panic!("row_max: unsupported tile variant"),
57 }
58 }
59
60 pub fn row_sum(&self, acc: &mut RowWise<E>) {
61 match self {
62 Tile::Unit(t) => {
63 acc.fill(E::from_int(0));
64 let m = comptime!(t.layout.num_rows);
65 let n = comptime!(t.layout.num_cols);
66 for r in 0..m as usize {
67 let row_offset = r as u32 * n;
68 let mut val = E::from_int(0);
69 for c in 0..n {
70 val += t.data[(row_offset + c) as usize];
71 }
72 acc.vals[r] += val;
73 }
74 }
75 Tile::Local(t) => {
76 local_row_sum::<E>(acc, t);
77 }
78 Tile::Bounce(b) => {
79 local_row_sum::<E>(acc, &b.local);
80 }
81 Tile::Register(t) => {
82 acc.fill(E::from_int(0));
83 let m = comptime!(t.config.tile_size.m());
84 let n = comptime!(t.config.tile_size.n());
85 for r in 0..m as usize {
86 let row_offset = r as u32 * n;
87 let mut val = E::from_int(0);
88 for c in 0..n {
89 val += t.data[(row_offset + c) as usize];
90 }
91 acc.vals[r] += val;
92 }
93 }
94 _ => panic!("row_sum: unsupported tile variant"),
95 }
96 }
97
98 pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
99 match self {
100 Tile::Unit(t) => t.exp_diff(rowwise),
101 Tile::Local(t) => t.exp_diff(rowwise),
102 Tile::Bounce(b) => b.local.exp_diff(rowwise),
103 Tile::Register(t) => {
104 let m = comptime!(t.config.tile_size.m());
105 let n = comptime!(t.config.tile_size.n());
106 let threshold = E::new(LOGIT_MASKED);
107 for r in 0..m as usize {
108 let row_offset = r as u32 * n;
109 let val = rowwise.vals[r];
110 let safe_val = clamp_min(val, threshold);
111 let not_masked = E::cast_from(val >= threshold);
112 for c in 0..n {
113 let idx = (row_offset + c) as usize;
114 t.data[idx] = not_masked * (t.data[idx] - safe_val).exp();
115 }
116 }
117 }
118 _ => panic!("exp_diff: unsupported tile variant"),
119 }
120 }
121
122 pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
123 match self {
124 Tile::Unit(t) => t.rowwise_scale(scale),
125 Tile::Local(t) => t.rowwise_scale(scale),
126 Tile::Bounce(b) => b.local.rowwise_scale(scale),
127 Tile::Register(t) => {
128 let m = comptime!(t.config.tile_size.m());
129 let n = comptime!(t.config.tile_size.n());
130 for r in 0..m as usize {
131 let row_offset = r as u32 * n;
132 for c in 0..n {
133 let idx = (row_offset + c) as usize;
134 t.data[idx] = t.data[idx] * scale.vals[r];
135 }
136 }
137 }
138 _ => panic!("rowwise_scale: unsupported tile variant"),
139 }
140 }
141
142 pub fn scale_and_mask<M: Mask>(&mut self, scale: E, mask: &M) {
143 match self {
144 Tile::Unit(t) => t.scale_and_mask::<M>(scale, mask),
145 Tile::Local(t) => t.scale_and_mask::<M>(scale, mask),
146 Tile::Bounce(b) => b.local.scale_and_mask::<M>(scale, mask),
147 Tile::Register(t) => {
148 let m = comptime!(t.config.tile_size.m());
149 let n = comptime!(t.config.tile_size.n());
150 for r in 0..m {
151 let row_offset = r * n;
152 for c in 0..n {
153 let idx = (row_offset + c) as usize;
154 t.data[idx] = t.data[idx] * scale
155 + E::cast_from(mask.should_mask((r, c))) * E::min_value();
156 }
157 }
158 }
159 _ => panic!("scale_and_mask: unsupported tile variant"),
160 }
161 }
162
163 pub fn fill_zero(&mut self) {
164 match self {
165 Tile::Register(t) => {
166 let m = comptime!(t.config.tile_size.m());
167 let n = comptime!(t.config.tile_size.n());
168 for i in 0..m * n {
169 t.data[i as usize] = E::from_int(0);
170 }
171 }
172 Tile::Unit(t) => t.zero(),
173 Tile::Local(t) => t.zero(),
174 Tile::Bounce(b) => {
175 cubecl::cmma::fill(&b.cmma.matrix, E::from_int(0));
176 }
177 Tile::Cmma(t) => {
178 cubecl::cmma::fill(&t.matrix, E::from_int(0));
179 }
180 _ => panic!("fill_zero: unsupported tile variant"),
181 }
182 }
183}
184
185#[cube]
189pub(crate) fn cmma_to_local<E: Float>(b: &mut BounceTile<E>) {
190 let stride = comptime!(b.cmma.tile_size.n());
191 cubecl::cmma::store(
192 &mut b.smem,
193 &b.cmma.matrix,
194 stride,
195 cubecl::cmma::MatrixLayout::RowMajor,
196 );
197 sync_cube();
198 b.local.load_from_slice(&b.smem.to_slice());
199 sync_cube();
200}
201
202#[cube]
205pub(crate) fn local_to_cmma<E: Float>(b: &mut BounceTile<E>) {
206 let stride = comptime!(b.cmma.tile_size.n());
207 b.local.store_to(&mut b.smem);
208 sync_cube();
209 cubecl::cmma::load_with_layout(
210 &b.cmma.matrix,
211 &b.smem.to_slice(),
212 stride,
213 cubecl::cmma::MatrixLayout::RowMajor,
214 );
215}