burn_cubecl/kernel/attention/base.rs

use crate::{
    CubeBackend, CubeRuntime, kernel::attention::attention_autotune,
    ops::numeric::empty_device_dtype, tensor::CubeTensor,
};
use burn_backend::{
    DType, Shape,
    ops::{AttentionModuleOptions, attention::attention_fallback},
};
use cubek::attention::launch;
use cubek::attention::{
    definition::{
        AccumulatorPrecision, AttentionGlobalTypes, AttentionOptions, AttentionSetupError,
    },
    routines::blackbox_accelerated::BlackboxAcceleratedStrategy,
};

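/// Selects which attention kernel implementation to launch.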
#[derive(Debug)]
pub enum AttentionStrategy {
    /// Fused flash attention using the blackbox-accelerated routine, with the
    /// provided routine strategy.
    FlashBlackboxAccelerated(BlackboxAcceleratedStrategy),

    /// Fused flash attention using the unit routine.
    FlashUnit,

    /// Non-fused implementation built from the backend's basic tensor ops.
    Fallback,

    /// Picks the fastest available strategy through autotuning.
    #[cfg(feature = "autotune")]
    Autotune,
}

impl Default for AttentionStrategy {
    fn default() -> Self {
        // Prefer autotuning when the feature is enabled; otherwise use the
        // non-fused fallback, which is always available.
        #[cfg(feature = "autotune")]
        return AttentionStrategy::Autotune;

        #[cfg(not(feature = "autotune"))]
        AttentionStrategy::Fallback
    }
}

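/// Computes attention over tensors laid out as `[batch, num_heads, seq_len, head_dim]`,
/// dispatching to the kernel selected by `strategy`.
///
/// A minimal call sketch; the tensors and `options` are assumed to be already
/// constructed, and the names are illustrative only:
///
/// ```ignore
/// // `query`, `key`, `value`: CubeTensor<R> of shape [batch, heads, seq, head_dim].
/// let out = attention::<R>(
///     query,
///     key,
///     value,
///     None,    // no mask
///     None,    // no attention bias
///     options, // AttentionModuleOptions, e.g. with `is_causal` set
///     AttentionStrategy::default(), // autotunes when the feature is enabled
/// )?;
/// ```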
#[allow(clippy::too_many_arguments)]
pub fn attention<R: CubeRuntime>(
    query: CubeTensor<R>,
    key: CubeTensor<R>,
    value: CubeTensor<R>,
    mask: Option<CubeTensor<R>>,
    attn_bias: Option<CubeTensor<R>>,
    options: AttentionModuleOptions,
    strategy: AttentionStrategy,
) -> Result<CubeTensor<R>, AttentionSetupError> {
    match strategy {
        AttentionStrategy::FlashBlackboxAccelerated(strategy) => flash_attention(
            query,
            key,
            value,
            mask,
            attn_bias,
            options,
            launch::Strategy::BlackboxAccelerated(
                cubek::attention::launch::BlueprintStrategy::Inferred(strategy),
            ),
        ),
        AttentionStrategy::FlashUnit => flash_attention(
            query,
            key,
            value,
            mask,
            attn_bias,
            options,
            launch::Strategy::Unit(cubek::attention::launch::BlueprintStrategy::Inferred(())),
        ),
        AttentionStrategy::Fallback => Ok(attention_fallback::<CubeBackend<R, f32, i32, u8>>(
            query, key, value, mask, attn_bias, options,
        )),
        #[cfg(feature = "autotune")]
        AttentionStrategy::Autotune => Ok(attention_autotune(
            query, key, value, mask, attn_bias, options,
        )),
    }
}

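/// Launches the fused flash attention kernel from `cubek` with the given
/// launch strategy.
///
/// `_attn_bias` is currently ignored on this path: only the optional mask and
/// the causal flag are forwarded to the kernel.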
#[allow(clippy::too_many_arguments)]
pub fn flash_attention<R: CubeRuntime>(
    query: CubeTensor<R>,
    key: CubeTensor<R>,
    value: CubeTensor<R>,
    mask: Option<CubeTensor<R>>,
    _attn_bias: Option<CubeTensor<R>>,
    options: AttentionModuleOptions,
    strategy: launch::Strategy,
) -> Result<CubeTensor<R>, AttentionSetupError> {
    let client = query.client.clone();
    let out = init_attention_output(&query, &value);

    // Describe the dtypes of every global tensor. The mask is optional, so its
    // dtype defaults to `U8` when absent.
    let dtypes = AttentionGlobalTypes {
        query: query.dtype.into(),
        key: key.dtype.into(),
        value: value.dtype.into(),
        mask: mask.as_ref().map(|m| m.dtype).unwrap_or(DType::U8).into(),
        out: out.dtype.into(),
    };

    cubek::attention::launch::launch_ref::<R>(
        strategy,
        &client,
        query.binding(),
        key.binding(),
        value.binding(),
        mask.map(|mask| mask.binding()),
        out.clone().binding(),
        &dtypes,
        AttentionOptions {
            causal: options.is_causal,
            // Accumulate in `f32` regardless of the input dtypes.
            accumulator_precision: AccumulatorPrecision::Strict(cubecl::ir::StorageType::Scalar(
                cubecl::ir::ElemType::Float(cubecl::ir::FloatKind::F32),
            )),
        },
    )?;

    Ok(out)
}

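/// Allocates an empty output tensor of shape `[batch, num_heads, seq_q, val_dim]`
/// on the query's device, with the query's dtype.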
pub(crate) fn init_attention_output<R: CubeRuntime>(
    query: &CubeTensor<R>,
    value: &CubeTensor<R>,
) -> CubeTensor<R> {
    let num_batches = query.meta.shape[0];
    let num_heads = query.meta.shape[1];
    let seq_q = query.meta.shape[2];
    let val_dim = value.meta.shape[3];
    let out_shape = Shape::new([num_batches, num_heads, seq_q, val_dim]);

    empty_device_dtype::<R>(
        query.client.clone(),
        query.device.clone(),
        out_shape,
        query.dtype,
    )
}