cubek_attention/definition/
spec.rs1use cubecl::prelude::*;
2use half::{bf16, f16};
3
4use crate::{
5 definition::{AccumulatorPrecision, AttentionGlobalTypes},
6 launch::{AttentionArgs, TensorArgs},
7};
8
/// Ties together the numeric precisions and the argument-passing strategy
/// used to instantiate and launch an attention kernel.
pub trait AttentionSpec: Send + Sync + Clone + 'static {
    /// The full set of element types used throughout the kernel.
    type Precision: AttentionPrecision;
    /// How input/output tensors are passed at kernel launch.
    type Args: AttentionArgs;
}
16
// A (precision, args) pair is itself a spec: both components given explicitly.
impl<AP: AttentionPrecision, Args: AttentionArgs> AttentionSpec for (AP, Args) {
    type Precision = AP;
    type Args = Args;
}
21
// A bare precision acts as a spec with the default tensor-based argument strategy.
impl<AP: AttentionPrecision> AttentionSpec for AP {
    type Precision = AP;
    type Args = TensorArgs;
}
27
/// Element types for the query operand. Unlike key/value
/// ([`StagedMatrixPrecision`]) it declares no stage type — only a global and
/// a tile type.
pub trait QueryPrecision: Send + Sync + Copy + 'static {
    /// Element type of the query as stored in the global input tensor.
    type Global: Float;
    /// Size associated with the global type — presumably the vectorization
    /// (line) size; TODO confirm against the `Size` trait.
    type GlobalSize: Size;
    /// Element type used once the query is held in tiles for computation.
    type Tile: Float;
}
33
/// Element types for an operand that passes through a staging step
/// (key, value, and output), with distinct global and stage types.
pub trait StagedMatrixPrecision: Send + Sync + Copy + 'static {
    /// Element type as stored in the global tensor.
    type Global: Float;
    /// Size associated with the global type — presumably the vectorization
    /// (line) size; TODO confirm against the `Size` trait.
    type GlobalSize: Size;
    /// Element type of the staged copy used during computation.
    type Stage: Float;
    /// Size associated with the stage type.
    type StageSize: Size;
}
40
/// The complete set of element types an attention kernel is instantiated with.
pub trait AttentionPrecision: Send + Sync + Copy + 'static {
    /// Precisions for the query operand.
    type Query: QueryPrecision;
    /// Precisions for the key operand.
    type Key: StagedMatrixPrecision;
    /// Precisions for the value operand.
    type Value: StagedMatrixPrecision;
    /// Tile element type shared by key and value during computation.
    type KVTile: Float;
    /// Accumulator type for the softmax (score) computation.
    type SoftmaxAcc: Float;
    /// Lhs element type feeding the softmax matmul — TODO confirm which
    /// matmul side this names.
    type SoftmaxLhs: Float;
    /// Accumulator type — presumably for the probabilities × value matmul;
    /// confirm against the kernels that consume it.
    type Accumulator: Float;
    /// Element type of the attention mask tensor.
    type Mask: Numeric;
    /// Size associated with the mask's global layout.
    type MaskSize: Size;
    /// Precisions for the output operand.
    type Out: StagedMatrixPrecision;
}
53
// Scalar shorthand impls: using a float scalar as a `QueryPrecision` picks
// matching companion types. `Const<0>` for the size presumably stands for
// "unspecified / resolved at launch" — TODO confirm against the `Size` trait.

impl QueryPrecision for f16 {
    type Global = f16;
    type GlobalSize = Const<0>;
    type Tile = f16;
}

impl QueryPrecision for bf16 {
    type Global = bf16;
    type GlobalSize = Const<0>;
    type Tile = bf16;
}

// flex32: stored as f32 globally but computed in f16 tiles (relaxed precision).
impl QueryPrecision for flex32 {
    type Global = f32;
    type GlobalSize = Const<0>;
    type Tile = f16;
}

impl QueryPrecision for f32 {
    type Global = f32;
    type GlobalSize = Const<0>;
    type Tile = f32;
}

// f64 keeps full precision in global memory but computes in f32 tiles —
// presumably because f64 tile math is unsupported/slow on targets; confirm.
impl QueryPrecision for f64 {
    type Global = f64;
    type GlobalSize = Const<0>;
    type Tile = f32;
}

// Fully explicit form: a (global, global-size, tile) triple maps directly.
impl<G: Float, GS: Size, T: Float> QueryPrecision for (G, GS, T) {
    type Global = G;
    type GlobalSize = GS;
    type Tile = T;
}
89
// Scalar shorthand impls mirroring the `QueryPrecision` ones, with a stage
// type in addition to the global type. `Const<0>` sizes presumably mean
// "unspecified / resolved at launch" — TODO confirm against the `Size` trait.

impl StagedMatrixPrecision for f16 {
    type Global = f16;
    type GlobalSize = Const<0>;
    type Stage = f16;
    type StageSize = Const<0>;
}

impl StagedMatrixPrecision for bf16 {
    type Global = bf16;
    type GlobalSize = Const<0>;
    type Stage = bf16;
    type StageSize = Const<0>;
}

// flex32: stored as f32 globally, staged as f16 (relaxed precision).
impl StagedMatrixPrecision for flex32 {
    type Global = f32;
    type GlobalSize = Const<0>;
    type Stage = f16;
    type StageSize = Const<0>;
}

impl StagedMatrixPrecision for f32 {
    type Global = f32;
    type GlobalSize = Const<0>;
    type Stage = f32;
    type StageSize = Const<0>;
}

// f64 keeps full precision in global memory but stages down to f32.
impl StagedMatrixPrecision for f64 {
    type Global = f64;
    type GlobalSize = Const<0>;
    type Stage = f32;
    type StageSize = Const<0>;
}

// Fully explicit form: a (global, global-size, stage, stage-size) 4-tuple.
impl<G: Float, GS: Size, S: Float, SS: Size> StagedMatrixPrecision for (G, GS, S, SS) {
    type Global = G;
    type GlobalSize = GS;
    type Stage = S;
    type StageSize = SS;
}
131
// Scalar shorthand impls for whole-kernel precisions. For the half-precision
// variants, the accumulators stay in half precision on macOS but are widened
// to f32 everywhere else — presumably a Metal backend capability/performance
// constraint; TODO confirm.

impl AttentionPrecision for f16 {
    type Query = f16;
    type Key = f16;
    type Value = f16;
    type KVTile = f16;
    type SoftmaxLhs = f16;
    #[cfg(target_os = "macos")]
    type SoftmaxAcc = f16;
    #[cfg(target_os = "macos")]
    type Accumulator = f16;
    #[cfg(not(target_os = "macos"))]
    type SoftmaxAcc = f32;
    #[cfg(not(target_os = "macos"))]
    type Accumulator = f32;
    type Mask = u8;
    type MaskSize = Const<0>;
    type Out = f16;
}

// flex32 computes in f16 tiles but reads/writes f32 globally (via its
// QueryPrecision / StagedMatrixPrecision impls above).
impl AttentionPrecision for flex32 {
    type Query = flex32;
    type Key = flex32;
    type Value = flex32;
    type KVTile = f16;
    type SoftmaxLhs = f16;
    #[cfg(target_os = "macos")]
    type SoftmaxAcc = f16;
    #[cfg(target_os = "macos")]
    type Accumulator = f16;
    #[cfg(not(target_os = "macos"))]
    type SoftmaxAcc = f32;
    #[cfg(not(target_os = "macos"))]
    type Accumulator = f32;
    type Mask = u8;
    type MaskSize = Const<0>;
    type Out = f32;
}

impl AttentionPrecision for bf16 {
    type Query = bf16;
    type Key = bf16;
    type Value = bf16;
    type KVTile = bf16;
    type SoftmaxLhs = bf16;
    #[cfg(target_os = "macos")]
    type SoftmaxAcc = bf16;
    #[cfg(target_os = "macos")]
    type Accumulator = bf16;
    #[cfg(not(target_os = "macos"))]
    type SoftmaxAcc = f32;
    #[cfg(not(target_os = "macos"))]
    type Accumulator = f32;
    type Mask = u8;
    type MaskSize = Const<0>;
    type Out = f32;
}

impl AttentionPrecision for f32 {
    type Query = f32;
    type Key = f32;
    type Value = f32;
    type KVTile = f32;
    type SoftmaxAcc = f32;
    type SoftmaxLhs = f32;
    type Accumulator = f32;
    type Mask = u8;
    type MaskSize = Const<0>;
    type Out = f32;
}

// f64 is stored in full precision but all computation runs in f32.
impl AttentionPrecision for f64 {
    type Query = f64;
    type Key = f64;
    type Value = f64;
    type KVTile = f32;
    type SoftmaxAcc = f32;
    type SoftmaxLhs = f32;
    type Accumulator = f32;
    type Mask = u8;
    type MaskSize = Const<0>;
    type Out = f64;
}
214
// Fully explicit form: every element/size type spelled out as a 10-element
// tuple of (query triple, key 4-tuple, value 4-tuple, kv-tile, softmax acc,
// softmax lhs, accumulator, mask, mask size, out 4-tuple).
impl<
    QG: Float,
    QGS: Size,
    QT: Float,
    KG: Float,
    KGS: Size,
    KS: Float,
    KSS: Size,
    VG: Float,
    VGS: Size,
    VS: Float,
    VSS: Size,
    KVT: Float,
    SM: Float,
    SML: Float,
    ACC: Float,
    MSK: Numeric,
    MSKS: Size,
    OG: Float,
    OGS: Size,
    OS: Float,
    OSS: Size,
> AttentionPrecision
    for (
        (QG, QGS, QT),
        (KG, KGS, KS, KSS),
        (VG, VGS, VS, VSS),
        KVT,
        SM,
        SML,
        ACC,
        MSK,
        MSKS,
        (OG, OGS, OS, OSS),
    )
{
    type Query = (QG, QGS, QT);
    type Key = (KG, KGS, KS, KSS);
    type Value = (VG, VGS, VS, VSS);
    type KVTile = KVT;
    type SoftmaxAcc = SM;
    type SoftmaxLhs = SML;
    type Accumulator = ACC;
    type Mask = MSK;
    type MaskSize = MSKS;
    type Out = (OG, OGS, OS, OSS);
}
262
/// Placeholder scalar/size types used to name the launch-time input/output
/// argument types independently of any concrete `AttentionPrecision`.
pub mod launch_types {
    use super::*;

    // One placeholder scalar per global tensor: query, key, value, mask, out.
    define_scalar!(pub QG);
    define_scalar!(pub KG);
    define_scalar!(pub VG);
    define_scalar!(pub MSK);
    define_scalar!(pub OG);

    // Matching placeholder sizes for each global tensor.
    define_size!(pub QGS);
    define_size!(pub KGS);
    define_size!(pub VGS);
    define_size!(pub MSKS);
    define_size!(pub OGS);

    /// Input argument type of an [`AttentionArgs`] strategy, over the
    /// (query, key, value, mask) global scalar/size pairs.
    pub type InputArg<AA> =
        <AA as AttentionArgs>::Input<(QG, QGS), (KG, KGS), (VG, VGS), (MSK, MSKS)>;

    /// Output argument type of an [`AttentionArgs`] strategy.
    pub type OutputArg<AA> = <AA as AttentionArgs>::Output<(OG, OGS)>;
}
285
pub use launch_types::{InputArg, OutputArg};

/// Runtime argument type for the kernel inputs, for a given runtime `R`.
pub type InputRuntimeArg<AA, R> = <InputArg<AA> as LaunchArg>::RuntimeArg<R>;

/// Runtime argument type for the kernel output, for a given runtime `R`.
pub type OutputRuntimeArg<AA, R> = <OutputArg<AA> as LaunchArg>::RuntimeArg<R>;
293
/// Short aliases extracting every element/size type from an
/// [`AttentionSpec`], used to keep kernel signatures readable.
pub mod attention_types {
    use cubecl::prelude::*;

    use crate::definition::{
        AttentionPrecision, AttentionSpec, QueryPrecision, StagedMatrixPrecision,
    };

    // Per-operand precision bundles.
    pub type Query<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Query;
    pub type Key<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Key;
    pub type Value<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Value;
    pub type Out<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Out;

    // Query: global (QG), global size (QGS), tile (QT).
    pub type QG<AS> = <Query<AS> as QueryPrecision>::Global;
    pub type QGS<AS> = <Query<AS> as QueryPrecision>::GlobalSize;
    pub type QT<AS> = <Query<AS> as QueryPrecision>::Tile;

    // Vectorized query global type.
    pub type QGV<AS> = Vector<QG<AS>, QGS<AS>>;

    // Key: global/stage types and sizes.
    pub type KG<AS> = <Key<AS> as StagedMatrixPrecision>::Global;
    pub type KGS<AS> = <Key<AS> as StagedMatrixPrecision>::GlobalSize;
    pub type KS<AS> = <Key<AS> as StagedMatrixPrecision>::Stage;
    pub type KSS<AS> = <Key<AS> as StagedMatrixPrecision>::StageSize;

    // Vectorized key global/stage types.
    pub type KGV<AS> = Vector<KG<AS>, KGS<AS>>;
    pub type KSV<AS> = Vector<KS<AS>, KSS<AS>>;

    // Value: global/stage types and sizes.
    pub type VG<AS> = <Value<AS> as StagedMatrixPrecision>::Global;
    pub type VGS<AS> = <Value<AS> as StagedMatrixPrecision>::GlobalSize;
    pub type VS<AS> = <Value<AS> as StagedMatrixPrecision>::Stage;
    pub type VSS<AS> = <Value<AS> as StagedMatrixPrecision>::StageSize;

    // Vectorized value global/stage types.
    pub type VGV<AS> = Vector<VG<AS>, VGS<AS>>;
    pub type VSV<AS> = Vector<VS<AS>, VSS<AS>>;

    // Computation types: kv tile, softmax accumulator/lhs, output accumulator.
    pub type KVT<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::KVTile;
    pub type SM<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::SoftmaxAcc;
    pub type SML<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::SoftmaxLhs;
    pub type ACC<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Accumulator;

    // Mask type and size.
    pub type MSK<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Mask;
    pub type MSKS<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::MaskSize;

    // Vectorized mask type.
    pub type MSKV<AS> = Vector<MSK<AS>, MSKS<AS>>;

    // Output: global/stage types and sizes.
    pub type OG<AS> = <Out<AS> as StagedMatrixPrecision>::Global;
    pub type OGS<AS> = <Out<AS> as StagedMatrixPrecision>::GlobalSize;
    pub type OS<AS> = <Out<AS> as StagedMatrixPrecision>::Stage;
    pub type OSS<AS> = <Out<AS> as StagedMatrixPrecision>::StageSize;

    // Vectorized output global/stage types.
    pub type OGV<AS> = Vector<OG<AS>, OGS<AS>>;
    pub type OSV<AS> = Vector<OS<AS>, OSS<AS>>;
}
369
370pub type Args<MS> = <MS as AttentionSpec>::Args;
371
/// The concrete runtime storage types chosen for every tensor/stage/tile of an
/// attention kernel — the value-level counterpart of [`AttentionPrecision`].
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct AttentionElems {
    // Query: global storage type and tile compute type.
    pub query_global: StorageType,
    pub query_tile: StorageType,
    // Key: global storage type and staged copy type.
    pub key_global: StorageType,
    pub key_stage: StorageType,
    // Value: global storage type and staged copy type.
    pub value_global: StorageType,
    pub value_stage: StorageType,
    // Tile type shared by key and value during computation.
    pub key_value_tile: StorageType,
    // Softmax accumulator and lhs types.
    pub softmax_acc: StorageType,
    pub softmax_lhs: StorageType,
    // Accumulator type — presumably for the output matmul; confirm in kernels.
    pub accumulator: StorageType,
    // Mask element type.
    pub mask: StorageType,
    // Output: global storage type and staged copy type.
    pub out_global: StorageType,
    pub out_stage: StorageType,
}
388
389impl AttentionElems {
390 pub fn from_global_types(
391 global_dtypes: &AttentionGlobalTypes,
392 tile_type: StorageType,
393 accumulator_precision: &AccumulatorPrecision,
394 ) -> AttentionElems {
395 let accumulator = match accumulator_precision {
396 AccumulatorPrecision::Strict(storage_type) => *storage_type,
397 AccumulatorPrecision::Loose => AccumulatorPrecision::default_accumulator_type(),
398 };
399
400 Self {
401 query_global: global_dtypes.query,
402 query_tile: tile_type,
403 key_global: global_dtypes.key,
404 key_stage: tile_type,
405 value_global: global_dtypes.value,
406 value_stage: tile_type,
407 key_value_tile: tile_type,
408 softmax_acc: accumulator,
409 softmax_lhs: tile_type,
410 accumulator,
411 mask: global_dtypes.mask,
412 out_global: global_dtypes.out,
413 out_stage: global_dtypes.out,
414 }
415 }
416}
417
418impl From<&AttentionElems> for [StorageType; 5] {
419 fn from(elems: &AttentionElems) -> Self {
420 [
421 elems.query_global,
422 elems.key_global,
423 elems.value_global,
424 elems.mask,
425 elems.out_global,
426 ]
427 }
428}