// cubek_attention/definition/spec.rs
use cubecl::prelude::*;
use half::{bf16, f16};

use crate::{
    definition::{AccumulatorPrecision, AttentionGlobalTypes},
    launch::{AttentionArgs, TensorArgs},
};
9pub trait AttentionSpec: Send + Sync + Clone + 'static {
12 type Precision: AttentionPrecision;
13 type Args: AttentionArgs;
15}
16
17impl<AP: AttentionPrecision, Args: AttentionArgs> AttentionSpec for (AP, Args) {
18 type Precision = AP;
19 type Args = Args;
20}
21
22impl<AP: AttentionPrecision> AttentionSpec for AP {
24 type Precision = AP;
25 type Args = TensorArgs;
26}
27
28pub trait QueryPrecision: Send + Sync + Copy + 'static {
29 type Global: Float;
30 type Tile: Float;
31}
32
33pub trait StagedMatrixPrecision: Send + Sync + Copy + 'static {
34 type Global: Float;
35 type Stage: Float;
36}
37
38pub trait AttentionPrecision: Send + Sync + Copy + 'static {
39 type Query: QueryPrecision;
40 type Key: StagedMatrixPrecision;
41 type Value: StagedMatrixPrecision;
42 type KVTile: Float;
43 type Softmax: Float;
44 type Accumulator: Float;
45 type Mask: Numeric;
46 type Out: StagedMatrixPrecision;
47}
48
49impl QueryPrecision for f16 {
50 type Global = f16;
51 type Tile = f16;
52}
53
54impl QueryPrecision for bf16 {
55 type Global = bf16;
56 type Tile = bf16;
57}
58
59impl QueryPrecision for flex32 {
60 type Global = f32;
61 type Tile = f16;
62}
63
64impl QueryPrecision for f32 {
65 type Global = f32;
66 type Tile = f32;
67}
68
69impl QueryPrecision for f64 {
70 type Global = f64;
71 type Tile = f32;
72}
73
74impl<G: Float, T: Float> QueryPrecision for (G, T) {
75 type Global = G;
76 type Tile = T;
77}
78
79impl StagedMatrixPrecision for f16 {
80 type Global = f16;
81 type Stage = f16;
82}
83
84impl StagedMatrixPrecision for bf16 {
85 type Global = bf16;
86 type Stage = bf16;
87}
88
89impl StagedMatrixPrecision for flex32 {
90 type Global = f32;
91 type Stage = f16;
92}
93
94impl StagedMatrixPrecision for f32 {
95 type Global = f32;
96 type Stage = f32;
97}
98
99impl StagedMatrixPrecision for f64 {
100 type Global = f64;
101 type Stage = f32;
102}
103
104impl<G: Float, S: Float> StagedMatrixPrecision for (G, S) {
105 type Global = G;
106 type Stage = S;
107}
108
109impl AttentionPrecision for f16 {
110 type Query = f16;
111 type Key = f16;
112 type Value = f16;
113 type KVTile = f16;
114 #[cfg(target_os = "macos")]
115 type Softmax = f16;
116 #[cfg(target_os = "macos")]
117 type Accumulator = f16;
118 #[cfg(not(target_os = "macos"))]
119 type Softmax = f32;
120 #[cfg(not(target_os = "macos"))]
121 type Accumulator = f32;
122 type Mask = u8;
123 type Out = f16;
124}
125
126impl AttentionPrecision for flex32 {
127 type Query = flex32;
128 type Key = flex32;
129 type Value = flex32;
130 type KVTile = f16;
131 #[cfg(target_os = "macos")]
132 type Softmax = f16;
133 #[cfg(target_os = "macos")]
134 type Accumulator = f16;
135 #[cfg(not(target_os = "macos"))]
136 type Softmax = f32;
137 #[cfg(not(target_os = "macos"))]
138 type Accumulator = f32;
139 type Mask = u8;
140 type Out = f32;
141}
142
143impl AttentionPrecision for bf16 {
144 type Query = bf16;
145 type Key = bf16;
146 type Value = bf16;
147 type KVTile = bf16;
148 #[cfg(target_os = "macos")]
149 type Softmax = bf16;
150 #[cfg(target_os = "macos")]
151 type Accumulator = bf16;
152 #[cfg(not(target_os = "macos"))]
153 type Softmax = f32;
154 #[cfg(not(target_os = "macos"))]
155 type Accumulator = f32;
156 type Mask = u8;
157 type Out = bf16;
158}
159
160impl AttentionPrecision for f32 {
161 type Query = f32;
162 type Key = f32;
163 type Value = f32;
164 type KVTile = f32;
165 type Softmax = f32;
166 type Accumulator = f32;
167 type Mask = u8;
168 type Out = f32;
169}
170
171impl AttentionPrecision for f64 {
172 type Query = f64;
173 type Key = f64;
174 type Value = f64;
175 type KVTile = f32;
176 type Softmax = f32;
177 type Accumulator = f32;
178 type Mask = u8;
179 type Out = f64;
180}
181
182impl<
183 QG: Float,
184 QT: Float,
185 KG: Float,
186 KS: Float,
187 VG: Float,
188 VS: Float,
189 KVT: Float,
190 SM: Float,
191 ACC: Float,
192 MSK: Numeric,
193 OG: Float,
194 OS: Float,
195> AttentionPrecision for (QG, QT, KG, KS, VG, VS, KVT, SM, ACC, MSK, OG, OS)
196{
197 type Query = (QG, QT);
198 type Key = (KG, KS);
199 type Value = (VG, VS);
200 type KVTile = KVT;
201 type Softmax = SM;
202 type Accumulator = ACC;
203 type Mask = MSK;
204 type Out = (OG, OS);
205}
206
207pub type InputArg<AA> = <AA as AttentionArgs>::Input<
211 NumericExpand<0>,
212 NumericExpand<2>,
213 NumericExpand<4>,
214 NumericExpand<9>,
215>;
216
217pub type OutputArg<AA> = <AA as AttentionArgs>::Output<NumericExpand<10>>;
219
220pub type InputRuntimeArg<'a, AA, R> = <InputArg<AA> as LaunchArg>::RuntimeArg<'a, R>;
222
223pub type OutputRuntimeArg<'a, AA, R> = <OutputArg<AA> as LaunchArg>::RuntimeArg<'a, R>;
225
226pub mod attention_types {
227 use crate::definition::{
228 AttentionPrecision, AttentionSpec, QueryPrecision, StagedMatrixPrecision,
229 };
230
231 pub type QG<AS> =
232 <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Query as QueryPrecision>::Global;
233 pub type QT<AS> =
234 <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Query as QueryPrecision>::Tile;
235 pub type KG<AS> =
236 <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Key as StagedMatrixPrecision>::Global;
237 pub type KS<AS> =
238 <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Key as StagedMatrixPrecision>::Stage;
239 pub type VG<AS> =
240 <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Value as StagedMatrixPrecision>::Global;
241 pub type VS<AS> =
242 <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Value as StagedMatrixPrecision>::Stage;
243
244 pub type KVT<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::KVTile;
245 pub type SM<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Softmax;
246 pub type ACC<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Accumulator;
247 pub type MSK<AS> = <<AS as AttentionSpec>::Precision as AttentionPrecision>::Mask;
248
249 pub type OG<AS> = <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Out as StagedMatrixPrecision>::Global;
250 pub type OS<AS> = <<<AS as AttentionSpec>::Precision as AttentionPrecision>::Out as StagedMatrixPrecision>::Stage;
251}
252
253pub type Args<MS> = <MS as AttentionSpec>::Args;
254
255#[derive(Debug, Clone, Eq, PartialEq, Hash)]
256pub struct AttentionElems {
257 pub query_global: StorageType,
258 pub query_tile: StorageType,
259 pub key_global: StorageType,
260 pub key_stage: StorageType,
261 pub value_global: StorageType,
262 pub value_stage: StorageType,
263 pub key_value_tile: StorageType,
264 pub softmax: StorageType,
265 pub accumulator: StorageType,
266 pub mask: StorageType,
267 pub out_global: StorageType,
268 pub out_stage: StorageType,
269}
270
271impl AttentionElems {
272 pub fn from_global_types(
273 global_dtypes: &AttentionGlobalTypes,
274 accumulator_precision: &AccumulatorPrecision,
275 ) -> AttentionElems {
276 let accumulator = match accumulator_precision {
277 AccumulatorPrecision::Strict(storage_type) => *storage_type,
278 AccumulatorPrecision::Loose => AccumulatorPrecision::default_accumulator_type(),
279 };
280
281 Self {
282 query_global: global_dtypes.query,
283 query_tile: global_dtypes.query,
284 key_global: global_dtypes.key,
285 key_stage: global_dtypes.key,
286 value_global: global_dtypes.value,
287 value_stage: global_dtypes.value,
288 key_value_tile: global_dtypes.value,
289 softmax: accumulator,
290 accumulator,
291 mask: global_dtypes.mask,
292 out_global: global_dtypes.out,
293 out_stage: global_dtypes.out,
294 }
295 }
296
297 pub fn from_define_array(elem_types: [StorageType; 12]) -> AttentionElems {
298 AttentionElems {
299 query_global: elem_types[0],
300 query_tile: elem_types[1],
301 key_global: elem_types[2],
302 key_stage: elem_types[3],
303 value_global: elem_types[4],
304 value_stage: elem_types[5],
305 key_value_tile: elem_types[6],
306 softmax: elem_types[7],
307 accumulator: elem_types[8],
308 mask: elem_types[9],
309 out_global: elem_types[10],
310 out_stage: elem_types[11],
311 }
312 }
313}
314
315impl From<&AttentionElems> for [StorageType; 12] {
316 fn from(elems: &AttentionElems) -> Self {
317 [
318 elems.query_global,
319 elems.query_tile,
320 elems.key_global,
321 elems.key_stage,
322 elems.value_global,
323 elems.value_stage,
324 elems.key_value_tile,
325 elems.softmax,
326 elems.accumulator,
327 elems.mask,
328 elems.out_global,
329 elems.out_stage,
330 ]
331 }
332}