cubek_attention/definition/
spec.rs

1use cubecl::prelude::*;
2use half::{bf16, f16};
3
4use crate::{
5    definition::{AccumulatorPrecision, AttentionGlobalTypes},
6    launch::{AttentionArgs, TensorArgs},
7};
8
9/// Attention spec defining each element types used in the computation as well as
10/// how the arguments are passed to the kernel.
11pub trait AttentionSpec: Send + Sync + Clone + 'static {
12    type Precision: AttentionPrecision;
13    /// How the input and output tensors are passed as arguments.
14    type Args: AttentionArgs;
15}
16
17impl<AP: AttentionPrecision, Args: AttentionArgs> AttentionSpec for (AP, Args) {
18    type Precision = AP;
19    type Args = Args;
20}
21
22// A simple default for TensorArgs
23impl<AP: AttentionPrecision> AttentionSpec for AP {
24    type Precision = AP;
25    type Args = TensorArgs;
26}
27
28pub trait QueryPrecision: Send + Sync + Copy + 'static {
29    type Global: Float;
30    type Tile: Float;
31}
32
33pub trait StagedMatrixPrecision: Send + Sync + Copy + 'static {
34    type Global: Float;
35    type Stage: Float;
36}
37
38pub trait AttentionPrecision: Send + Sync + Copy + 'static {
39    type Query: QueryPrecision;
40    type Key: StagedMatrixPrecision;
41    type Value: StagedMatrixPrecision;
42    type KVTile: Float;
43    type Softmax: Float;
44    type Accumulator: Float;
45    type Mask: Numeric;
46    type Out: StagedMatrixPrecision;
47}
48
49impl QueryPrecision for f16 {
50    type Global = f16;
51    type Tile = f16;
52}
53
54impl QueryPrecision for bf16 {
55    type Global = bf16;
56    type Tile = bf16;
57}
58
59impl QueryPrecision for flex32 {
60    type Global = f32;
61    type Tile = f16;
62}
63
64impl QueryPrecision for f32 {
65    type Global = f32;
66    type Tile = f32;
67}
68
69impl QueryPrecision for f64 {
70    type Global = f64;
71    type Tile = f32;
72}
73
74impl<G: Float, T: Float> QueryPrecision for (G, T) {
75    type Global = G;
76    type Tile = T;
77}
78
79impl StagedMatrixPrecision for f16 {
80    type Global = f16;
81    type Stage = f16;
82}
83
84impl StagedMatrixPrecision for bf16 {
85    type Global = bf16;
86    type Stage = bf16;
87}
88
89impl StagedMatrixPrecision for flex32 {
90    type Global = f32;
91    type Stage = f16;
92}
93
94impl StagedMatrixPrecision for f32 {
95    type Global = f32;
96    type Stage = f32;
97}
98
99impl StagedMatrixPrecision for f64 {
100    type Global = f64;
101    type Stage = f32;
102}
103
104impl<G: Float, S: Float> StagedMatrixPrecision for (G, S) {
105    type Global = G;
106    type Stage = S;
107}
108
109impl AttentionPrecision for f16 {
110    type Query = f16;
111    type Key = f16;
112    type Value = f16;
113    type KVTile = f16;
114    #[cfg(target_os = "macos")]
115    type Softmax = f16;
116    #[cfg(target_os = "macos")]
117    type Accumulator = f16;
118    #[cfg(not(target_os = "macos"))]
119    type Softmax = f32;
120    #[cfg(not(target_os = "macos"))]
121    type Accumulator = f32;
122    type Mask = u8;
123    type Out = f16;
124}
125
126impl AttentionPrecision for flex32 {
127    type Query = flex32;
128    type Key = flex32;
129    type Value = flex32;
130    type KVTile = f16;
131    #[cfg(target_os = "macos")]
132    type Softmax = f16;
133    #[cfg(target_os = "macos")]
134    type Accumulator = f16;
135    #[cfg(not(target_os = "macos"))]
136    type Softmax = f32;
137    #[cfg(not(target_os = "macos"))]
138    type Accumulator = f32;
139    type Mask = u8;
140    type Out = f32;
141}
142
143impl AttentionPrecision for bf16 {
144    type Query = bf16;
145    type Key = bf16;
146    type Value = bf16;
147    type KVTile = bf16;
148    #[cfg(target_os = "macos")]
149    type Softmax = bf16;
150    #[cfg(target_os = "macos")]
151    type Accumulator = bf16;
152    #[cfg(not(target_os = "macos"))]
153    type Softmax = f32;
154    #[cfg(not(target_os = "macos"))]
155    type Accumulator = f32;
156    type Mask = u8;
157    type Out = bf16;
158}
159
160impl AttentionPrecision for f32 {
161    type Query = f32;
162    type Key = f32;
163    type Value = f32;
164    type KVTile = f32;
165    type Softmax = f32;
166    type Accumulator = f32;
167    type Mask = u8;
168    type Out = f32;
169}
170
171impl AttentionPrecision for f64 {
172    type Query = f64;
173    type Key = f64;
174    type Value = f64;
175    type KVTile = f32;
176    type Softmax = f32;
177    type Accumulator = f32;
178    type Mask = u8;
179    type Out = f64;
180}
181
182impl<
183    QG: Float,
184    QT: Float,
185    KG: Float,
186    KS: Float,
187    VG: Float,
188    VS: Float,
189    KVT: Float,
190    SM: Float,
191    ACC: Float,
192    MSK: Numeric,
193    OG: Float,
194    OS: Float,
195> AttentionPrecision for (QG, QT, KG, KS, VG, VS, KVT, SM, ACC, MSK, OG, OS)
196{
197    type Query = (QG, QT);
198    type Key = (KG, KS);
199    type Value = (VG, VS);
200    type KVTile = KVT;
201    type Softmax = SM;
202    type Accumulator = ACC;
203    type Mask = MSK;
204    type Out = (OG, OS);
205}
206
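// TODO: confirm that the `NumericExpand` indices below are the right ones. They
// line up with the element order of `AttentionElems::from_define_array`
// (0 = query_global, 2 = key_global, 4 = value_global, 9 = mask, 10 = out_global).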
208
209/// Input argument
210pub type InputArg<AA> = <AA as AttentionArgs>::Input<
211    NumericExpand<0>,
212    NumericExpand<2>,
213    NumericExpand<4>,
214    NumericExpand<9>,
215>;
216
217/// Output argument
218pub type OutputArg<AA> = <AA as AttentionArgs>::Output<NumericExpand<10>>;
219
220/// Input runtime argument
221pub type InputRuntimeArg<'a, AA, R> = <InputArg<AA> as LaunchArg>::RuntimeArg<'a, R>;
222
223/// Output runtime argument
224pub type OutputRuntimeArg<'a, AA, R> = <OutputArg<AA> as LaunchArg>::RuntimeArg<'a, R>;
225
226pub mod attention_types {
227    use crate::definition::{
228        AttentionPrecision, AttentionSpec, QueryPrecision, StagedMatrixPrecision,
229    };
230
231    pub type QG<AS> =
232        <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Query as QueryPrecision>::Global;
233    pub type QT<AS> =
234        <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Query as QueryPrecision>::Tile;
235    pub type KG<AS> =
236    <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Key as StagedMatrixPrecision>::Global;
237    pub type KS<AS> =
238        <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Key as StagedMatrixPrecision>::Stage;
239    pub type VG<AS> =
240    <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Value as StagedMatrixPrecision>::Global;
241    pub type VS<AS> =
242    <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Value as StagedMatrixPrecision>::Stage;
243
244    pub type KVT<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::KVTile;
245    pub type SM<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Softmax;
246    pub type ACC<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Accumulator;
247    pub type MSK<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Mask;
248
249    pub type OG<AS> = <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Out as StagedMatrixPrecision>::Global;
250    pub type OS<AS> = <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Out as StagedMatrixPrecision>::Stage;
251}
252
253pub type Args<MS> = <MS as AttentionSpec>::Args;
254
255#[derive(Debug, Clone, Eq, PartialEq, Hash)]
256pub struct AttentionElems {
257    pub query_global: StorageType,
258    pub query_tile: StorageType,
259    pub key_global: StorageType,
260    pub key_stage: StorageType,
261    pub value_global: StorageType,
262    pub value_stage: StorageType,
263    pub key_value_tile: StorageType,
264    pub softmax: StorageType,
265    pub accumulator: StorageType,
266    pub mask: StorageType,
267    pub out_global: StorageType,
268    pub out_stage: StorageType,
269}
270
271impl AttentionElems {
272    pub fn from_global_types(
273        global_dtypes: &AttentionGlobalTypes,
274        accumulator_precision: &AccumulatorPrecision,
275    ) -> AttentionElems {
276        let accumulator = match accumulator_precision {
277            AccumulatorPrecision::Strict(storage_type) => *storage_type,
278            AccumulatorPrecision::Loose => AccumulatorPrecision::default_accumulator_type(),
279        };
280
281        Self {
282            query_global: global_dtypes.query,
283            query_tile: global_dtypes.query,
284            key_global: global_dtypes.key,
285            key_stage: global_dtypes.key,
286            value_global: global_dtypes.value,
287            value_stage: global_dtypes.value,
288            key_value_tile: global_dtypes.value,
289            softmax: accumulator,
290            accumulator,
291            mask: global_dtypes.mask,
292            out_global: global_dtypes.out,
293            out_stage: global_dtypes.out,
294        }
295    }
296
297    pub fn from_define_array(elem_types: [StorageType; 12]) -> AttentionElems {
298        AttentionElems {
299            query_global: elem_types[0],
300            query_tile: elem_types[1],
301            key_global: elem_types[2],
302            key_stage: elem_types[3],
303            value_global: elem_types[4],
304            value_stage: elem_types[5],
305            key_value_tile: elem_types[6],
306            softmax: elem_types[7],
307            accumulator: elem_types[8],
308            mask: elem_types[9],
309            out_global: elem_types[10],
310            out_stage: elem_types[11],
311        }
312    }
313}
314
315impl From<&AttentionElems> for [StorageType; 12] {
316    fn from(elems: &AttentionElems) -> Self {
317        [
318            elems.query_global,
319            elems.query_tile,
320            elems.key_global,
321            elems.key_stage,
322            elems.value_global,
323            elems.value_stage,
324            elems.key_value_tile,
325            elems.softmax,
326            elems.accumulator,
327            elems.mask,
328            elems.out_global,
329            elems.out_stage,
330        ]
331    }
332}