Skip to main content

cubek_attention/definition/
spec.rs

1use cubecl::prelude::*;
2use half::{bf16, f16};
3
4use crate::{
5    definition::{AccumulatorPrecision, AttentionGlobalTypes},
6    launch::{AttentionArgs, TensorArgs},
7};
8
9/// Attention spec defining each element types used in the computation as well as
10/// how the arguments are passed to the kernel.
11pub trait AttentionSpec: Send + Sync + Clone + 'static {
12    type Precision: AttentionPrecision;
13    /// How the input and output tensors are passed as arguments.
14    type Args: AttentionArgs;
15}
16
17impl<AP: AttentionPrecision, Args: AttentionArgs> AttentionSpec for (AP, Args) {
18    type Precision = AP;
19    type Args = Args;
20}
21
22// A simple default for TensorArgs
23impl<AP: AttentionPrecision> AttentionSpec for AP {
24    type Precision = AP;
25    type Args = TensorArgs;
26}
27
28pub trait QueryPrecision: Send + Sync + Copy + 'static {
29    type Global: Float;
30    type GlobalSize: Size;
31    type Tile: Float;
32}
33
34pub trait StagedMatrixPrecision: Send + Sync + Copy + 'static {
35    type Global: Float;
36    type GlobalSize: Size;
37    type Stage: Float;
38    type StageSize: Size;
39}
40
41pub trait AttentionPrecision: Send + Sync + Copy + 'static {
42    type Query: QueryPrecision;
43    type Key: StagedMatrixPrecision;
44    type Value: StagedMatrixPrecision;
45    type KVTile: Float;
46    type SoftmaxAcc: Float;
47    type SoftmaxLhs: Float;
48    type Accumulator: Float;
49    type Mask: Numeric;
50    type MaskSize: Size;
51    type Out: StagedMatrixPrecision;
52}
53
54impl QueryPrecision for f16 {
55    type Global = f16;
56    type GlobalSize = Const<0>;
57    type Tile = f16;
58}
59
60impl QueryPrecision for bf16 {
61    type Global = bf16;
62    type GlobalSize = Const<0>;
63    type Tile = bf16;
64}
65
66impl QueryPrecision for flex32 {
67    type Global = f32;
68    type GlobalSize = Const<0>;
69    type Tile = f16;
70}
71
72impl QueryPrecision for f32 {
73    type Global = f32;
74    type GlobalSize = Const<0>;
75    type Tile = f32;
76}
77
78impl QueryPrecision for f64 {
79    type Global = f64;
80    type GlobalSize = Const<0>;
81    type Tile = f32;
82}
83
84impl<G: Float, GS: Size, T: Float> QueryPrecision for (G, GS, T) {
85    type Global = G;
86    type GlobalSize = GS;
87    type Tile = T;
88}
89
90impl StagedMatrixPrecision for f16 {
91    type Global = f16;
92    type GlobalSize = Const<0>;
93    type Stage = f16;
94    type StageSize = Const<0>;
95}
96
97impl StagedMatrixPrecision for bf16 {
98    type Global = bf16;
99    type GlobalSize = Const<0>;
100    type Stage = bf16;
101    type StageSize = Const<0>;
102}
103
104impl StagedMatrixPrecision for flex32 {
105    type Global = f32;
106    type GlobalSize = Const<0>;
107    type Stage = f16;
108    type StageSize = Const<0>;
109}
110
111impl StagedMatrixPrecision for f32 {
112    type Global = f32;
113    type GlobalSize = Const<0>;
114    type Stage = f32;
115    type StageSize = Const<0>;
116}
117
118impl StagedMatrixPrecision for f64 {
119    type Global = f64;
120    type GlobalSize = Const<0>;
121    type Stage = f32;
122    type StageSize = Const<0>;
123}
124
125impl<G: Float, GS: Size, S: Float, SS: Size> StagedMatrixPrecision for (G, GS, S, SS) {
126    type Global = G;
127    type GlobalSize = GS;
128    type Stage = S;
129    type StageSize = SS;
130}
131
132impl AttentionPrecision for f16 {
133    type Query = f16;
134    type Key = f16;
135    type Value = f16;
136    type KVTile = f16;
137    type SoftmaxLhs = f16;
138    #[cfg(target_os = "macos")]
139    type SoftmaxAcc = f16;
140    #[cfg(target_os = "macos")]
141    type Accumulator = f16;
142    #[cfg(not(target_os = "macos"))]
143    type SoftmaxAcc = f32;
144    #[cfg(not(target_os = "macos"))]
145    type Accumulator = f32;
146    type Mask = u8;
147    type MaskSize = Const<0>;
148    type Out = f16;
149}
150
151impl AttentionPrecision for flex32 {
152    type Query = flex32;
153    type Key = flex32;
154    type Value = flex32;
155    type KVTile = f16;
156    type SoftmaxLhs = f16;
157    #[cfg(target_os = "macos")]
158    type SoftmaxAcc = f16;
159    #[cfg(target_os = "macos")]
160    type Accumulator = f16;
161    #[cfg(not(target_os = "macos"))]
162    type SoftmaxAcc = f32;
163    #[cfg(not(target_os = "macos"))]
164    type Accumulator = f32;
165    type Mask = u8;
166    type MaskSize = Const<0>;
167    type Out = f32;
168}
169
170impl AttentionPrecision for bf16 {
171    type Query = bf16;
172    type Key = bf16;
173    type Value = bf16;
174    type KVTile = bf16;
175    type SoftmaxLhs = bf16;
176    #[cfg(target_os = "macos")]
177    type SoftmaxAcc = bf16;
178    #[cfg(target_os = "macos")]
179    type Accumulator = bf16;
180    #[cfg(not(target_os = "macos"))]
181    type SoftmaxAcc = f32;
182    #[cfg(not(target_os = "macos"))]
183    type Accumulator = f32;
184    type Mask = u8;
185    type MaskSize = Const<0>;
186    type Out = bf16;
187}
188
189impl AttentionPrecision for f32 {
190    type Query = f32;
191    type Key = f32;
192    type Value = f32;
193    type KVTile = f32;
194    type SoftmaxAcc = f32;
195    type SoftmaxLhs = f32;
196    type Accumulator = f32;
197    type Mask = u8;
198    type MaskSize = Const<0>;
199    type Out = f32;
200}
201
202impl AttentionPrecision for f64 {
203    type Query = f64;
204    type Key = f64;
205    type Value = f64;
206    type KVTile = f32;
207    type SoftmaxAcc = f32;
208    type SoftmaxLhs = f32;
209    type Accumulator = f32;
210    type Mask = u8;
211    type MaskSize = Const<0>;
212    type Out = f64;
213}
214
215impl<
216    QG: Float,
217    QGS: Size,
218    QT: Float,
219    KG: Float,
220    KGS: Size,
221    KS: Float,
222    KSS: Size,
223    VG: Float,
224    VGS: Size,
225    VS: Float,
226    VSS: Size,
227    KVT: Float,
228    SM: Float,
229    SML: Float,
230    ACC: Float,
231    MSK: Numeric,
232    MSKS: Size,
233    OG: Float,
234    OGS: Size,
235    OS: Float,
236    OSS: Size,
237> AttentionPrecision
238    for (
239        (QG, QGS, QT),
240        (KG, KGS, KS, KSS),
241        (VG, VGS, VS, VSS),
242        KVT,
243        SM,
244        SML,
245        ACC,
246        MSK,
247        MSKS,
248        (OG, OGS, OS, OSS),
249    )
250{
251    type Query = (QG, QGS, QT);
252    type Key = (KG, KGS, KS, KSS);
253    type Value = (VG, VGS, VS, VSS);
254    type KVTile = KVT;
255    type SoftmaxAcc = SM;
256    type SoftmaxLhs = SML;
257    type Accumulator = ACC;
258    type Mask = MSK;
259    type MaskSize = MSKS;
260    type Out = (OG, OGS, OS, OSS);
261}
262
263pub mod launch_types {
264    use super::*;
265
266    define_scalar!(pub QG);
267    define_scalar!(pub KG);
268    define_scalar!(pub VG);
269    define_scalar!(pub MSK);
270    define_scalar!(pub OG);
271
272    define_size!(pub QGS);
273    define_size!(pub KGS);
274    define_size!(pub VGS);
275    define_size!(pub MSKS);
276    define_size!(pub OGS);
277
278    /// Input argument
279    pub type InputArg<AA> =
280        <AA as AttentionArgs>::Input<(QG, QGS), (KG, KGS), (VG, VGS), (MSK, MSKS)>;
281
282    /// Output argument
283    pub type OutputArg<AA> = <AA as AttentionArgs>::Output<(OG, OGS)>;
284}
285
286pub use launch_types::{InputArg, OutputArg};
287
288/// Input runtime argument
289pub type InputRuntimeArg<AA, R> = <InputArg<AA> as LaunchArg>::RuntimeArg<R>;
290
291/// Output runtime argument
292pub type OutputRuntimeArg<AA, R> = <OutputArg<AA> as LaunchArg>::RuntimeArg<R>;
293
294pub mod attention_types {
295    use cubecl::prelude::*;
296
297    use crate::definition::{
298        AttentionPrecision, AttentionSpec, QueryPrecision, StagedMatrixPrecision,
299    };
300
301    // ==================== Per-operand precision grouping ====================
302
303    pub type Query<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Query;
304    pub type Key<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Key;
305    pub type Value<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Value;
306    pub type Out<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Out;
307
308    // ==================== QUERY ====================
309
310    // Element / Size splits
311    pub type QG<AS> = <Query<AS> as QueryPrecision>::Global;
312    pub type QGS<AS> = <Query<AS> as QueryPrecision>::GlobalSize;
313    pub type QT<AS> = <Query<AS> as QueryPrecision>::Tile;
314
315    // Vector form
316    pub type QGV<AS> = Vector<QG<AS>, QGS<AS>>;
317
318    // ==================== KEY ====================
319
320    // Element / Size splits
321    pub type KG<AS> = <Key<AS> as StagedMatrixPrecision>::Global;
322    pub type KGS<AS> = <Key<AS> as StagedMatrixPrecision>::GlobalSize;
323    pub type KS<AS> = <Key<AS> as StagedMatrixPrecision>::Stage;
324    pub type KSS<AS> = <Key<AS> as StagedMatrixPrecision>::StageSize;
325
326    // Vector forms
327    pub type KGV<AS> = Vector<KG<AS>, KGS<AS>>;
328    pub type KSV<AS> = Vector<KS<AS>, KSS<AS>>;
329
330    // ==================== VALUE ====================
331
332    // Element / Size splits
333    pub type VG<AS> = <Value<AS> as StagedMatrixPrecision>::Global;
334    pub type VGS<AS> = <Value<AS> as StagedMatrixPrecision>::GlobalSize;
335    pub type VS<AS> = <Value<AS> as StagedMatrixPrecision>::Stage;
336    pub type VSS<AS> = <Value<AS> as StagedMatrixPrecision>::StageSize;
337
338    // Vector forms
339    pub type VGV<AS> = Vector<VG<AS>, VGS<AS>>;
340    pub type VSV<AS> = Vector<VS<AS>, VSS<AS>>;
341
342    // ==================== KV TILE / SOFTMAX / ACCUMULATOR ====================
343
344    pub type KVT<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::KVTile;
345    pub type SM<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::SoftmaxAcc;
346    pub type SML<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::SoftmaxLhs;
347    pub type ACC<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Accumulator;
348
349    // ==================== MASK ====================
350
351    pub type MSK<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Mask;
352    pub type MSKS<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::MaskSize;
353
354    // Vector form
355    pub type MSKV<AS> = Vector<MSK<AS>, MSKS<AS>>;
356
357    // ==================== OUT ====================
358
359    // Element / Size splits
360    pub type OG<AS> = <Out<AS> as StagedMatrixPrecision>::Global;
361    pub type OGS<AS> = <Out<AS> as StagedMatrixPrecision>::GlobalSize;
362    pub type OS<AS> = <Out<AS> as StagedMatrixPrecision>::Stage;
363    pub type OSS<AS> = <Out<AS> as StagedMatrixPrecision>::StageSize;
364
365    // Vector forms
366    pub type OGV<AS> = Vector<OG<AS>, OGS<AS>>;
367    pub type OSV<AS> = Vector<OS<AS>, OSS<AS>>;
368}
369
370pub type Args<MS> = <MS as AttentionSpec>::Args;
371
372#[derive(Debug, Clone, Eq, PartialEq, Hash)]
373pub struct AttentionElems {
374    pub query_global: StorageType,
375    pub query_tile: StorageType,
376    pub key_global: StorageType,
377    pub key_stage: StorageType,
378    pub value_global: StorageType,
379    pub value_stage: StorageType,
380    pub key_value_tile: StorageType,
381    pub softmax_acc: StorageType,
382    pub softmax_lhs: StorageType,
383    pub accumulator: StorageType,
384    pub mask: StorageType,
385    pub out_global: StorageType,
386    pub out_stage: StorageType,
387}
388
389impl AttentionElems {
390    pub fn from_global_types(
391        global_dtypes: &AttentionGlobalTypes,
392        tile_type: StorageType,
393        accumulator_precision: &AccumulatorPrecision,
394    ) -> AttentionElems {
395        let accumulator = match accumulator_precision {
396            AccumulatorPrecision::Strict(storage_type) => *storage_type,
397            AccumulatorPrecision::Loose => AccumulatorPrecision::default_accumulator_type(),
398        };
399
400        Self {
401            query_global: global_dtypes.query,
402            query_tile: tile_type,
403            key_global: global_dtypes.key,
404            key_stage: tile_type,
405            value_global: global_dtypes.value,
406            value_stage: tile_type,
407            key_value_tile: tile_type,
408            softmax_acc: accumulator,
409            softmax_lhs: tile_type,
410            accumulator,
411            mask: global_dtypes.mask,
412            out_global: global_dtypes.out,
413            out_stage: global_dtypes.out,
414        }
415    }
416}
417
418impl From<&AttentionElems> for [StorageType; 5] {
419    fn from(elems: &AttentionElems) -> Self {
420        [
421            elems.query_global,
422            elems.key_global,
423            elems.value_global,
424            elems.mask,
425            elems.out_global,
426        ]
427    }
428}