burn_cubecl/kernel/attention/base.rs

use crate::{
    CubeBackend, CubeRuntime, kernel::attention::attention_autotune,
    ops::numeric::empty_device_dtype, tensor::CubeTensor,
};
use burn_backend::{
    DType, Shape,
    ops::{AttentionModuleOptions, attention::attention_fallback},
};
use cubek::attention::{
    definition::{
        AccumulatorPrecision, AttentionGlobalTypes, AttentionOptions, AttentionSetupError,
    },
    launch,
    routines::blackbox_accelerated::BlackboxAcceleratedStrategy,
};

/// Strategy used to select which attention implementation to run.
#[derive(Debug)]
pub enum AttentionStrategy {
    /// Flash Attention using accelerated inner matmuls.
    FlashBlackboxAccelerated(BlackboxAcceleratedStrategy),

    /// Flash Attention using unit inner matmuls.
    FlashUnit,

    /// Fallback implementation using multiple separate kernels.
    Fallback,

    /// Automatically benchmark and select the best strategy at runtime.
    #[cfg(feature = "autotune")]
    Autotune,
}

impl Default for AttentionStrategy {
    fn default() -> Self {
        // With autotune enabled, benchmark and pick the best strategy at runtime.
        #[cfg(feature = "autotune")]
        return AttentionStrategy::Autotune;

        // Without autotune, fall back to the multi-kernel implementation so
        // attention is always able to run.
        #[cfg(not(feature = "autotune"))]
        AttentionStrategy::Fallback
    }
}

/// Launch an attention kernel with the given strategy.
#[allow(clippy::too_many_arguments)]
pub fn attention<R: CubeRuntime>(
    query: CubeTensor<R>,
    key: CubeTensor<R>,
    value: CubeTensor<R>,
    mask: Option<CubeTensor<R>>,
    attn_bias: Option<CubeTensor<R>>,
    options: AttentionModuleOptions,
    strategy: AttentionStrategy,
) -> Result<CubeTensor<R>, AttentionSetupError> {
    match strategy {
        AttentionStrategy::FlashBlackboxAccelerated(strategy) => flash_attention(
            query,
            key,
            value,
            mask,
            attn_bias,
            options,
            launch::Strategy::BlackboxAccelerated(
                cubek::attention::launch::BlueprintStrategy::Inferred(strategy),
            ),
        ),
        AttentionStrategy::FlashUnit => flash_attention(
            query,
            key,
            value,
            mask,
            attn_bias,
            options,
            launch::Strategy::Unit(cubek::attention::launch::BlueprintStrategy::Inferred(())),
        ),
        AttentionStrategy::Fallback => Ok(attention_fallback::<CubeBackend<R, f32, i32, u8>>(
            query, key, value, mask, attn_bias, options,
        )),
        #[cfg(feature = "autotune")]
        AttentionStrategy::Autotune => Ok(attention_autotune(
            query, key, value, mask, attn_bias, options,
        )),
    }
}

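// Convenience sketch (not part of the original file): a hypothetical wrapper
// showing the common way to drive `attention`, letting the strategy resolve
// from the enabled features. `query`, `key`, and `value` are expected to be
// rank-4 tensors shaped [batch, num_heads, seq_len, head_dim].
#[allow(dead_code)]
pub(crate) fn attention_default<R: CubeRuntime>(
    query: CubeTensor<R>,
    key: CubeTensor<R>,
    value: CubeTensor<R>,
    options: AttentionModuleOptions,
) -> Result<CubeTensor<R>, AttentionSetupError> {
    // No mask or bias; `AttentionStrategy::default()` is `Autotune` when the
    // "autotune" feature is enabled, `Fallback` otherwise.
    attention(query, key, value, None, None, options, AttentionStrategy::default())
}
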
/// Launch a flash attention kernel.
#[allow(clippy::too_many_arguments)]
pub fn flash_attention<R: CubeRuntime>(
    query: CubeTensor<R>,
    key: CubeTensor<R>,
    value: CubeTensor<R>,
    mask: Option<CubeTensor<R>>,
    // Unused: the flash path does not currently apply an attention bias.
    _attn_bias: Option<CubeTensor<R>>,
    options: AttentionModuleOptions,
    strategy: launch::Strategy,
) -> Result<CubeTensor<R>, AttentionSetupError> {
    let client = query.client.clone();
    let out = init_attention_output(&query, &value);

    // Collect the element types seen by the kernel; when no mask is
    // provided, U8 is reported as its element type.
    let dtypes = AttentionGlobalTypes {
        query: query.dtype.into(),
        key: key.dtype.into(),
        value: value.dtype.into(),
        mask: mask.as_ref().map(|m| m.dtype).unwrap_or(DType::U8).into(),
        out: out.dtype.into(),
    };

    cubek::attention::launch::launch_ref::<R>(
        strategy,
        &client,
        query.binding(),
        key.binding(),
        value.binding(),
        mask.map(|mask| mask.binding()),
        out.clone().binding(),
        &dtypes,
        AttentionOptions {
            causal: options.is_causal,
            // Force a strict f32 accumulator, independent of the input dtypes.
            accumulator_precision: AccumulatorPrecision::Strict(cubecl::ir::StorageType::Scalar(
                cubecl::ir::ElemType::Float(cubecl::ir::FloatKind::F32),
            )),
        },
    )?;

    Ok(out)
}

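// Example (sketch): bypassing `attention`'s dispatch to force the unit flash
// kernel directly. `q`, `k`, `v`, `mask`, and `opts` are hypothetical values,
// not defined in this file.
//
//     let out = flash_attention::<R>(
//         q, k, v, Some(mask), None, opts,
//         launch::Strategy::Unit(cubek::attention::launch::BlueprintStrategy::Inferred(())),
//     )?;
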
/// Allocate the output tensor for attention.
///
/// The output of `softmax(Q K^T / sqrt(d)) V` has shape
/// `[batch, num_heads, seq_q, val_dim]`: batch, head, and query sequence
/// dimensions come from `query`, while the last dimension comes from `value`.
pub(crate) fn init_attention_output<R: CubeRuntime>(
    query: &CubeTensor<R>,
    value: &CubeTensor<R>,
) -> CubeTensor<R> {
    let num_batches = query.meta.shape[0];
    let num_heads = query.meta.shape[1];
    let seq_q = query.meta.shape[2];
    let val_dim = value.meta.shape[3];
    let out_shape = Shape::new([num_batches, num_heads, seq_q, val_dim]);

    empty_device_dtype::<R>(
        query.client.clone(),
        query.device.clone(),
        out_shape,
        query.dtype,
    )
}
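
// Worked example (hypothetical shapes): with `query` of shape [2, 8, 128, 64]
// and `value` of shape [2, 8, 128, 96], the tensor allocated above has shape
// [2, 8, 128, 96]; only the last dimension follows `value`.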