use crate::{
CubeBackend, CubeRuntime, kernel::attention::attention_autotune,
ops::numeric::empty_device_dtype, tensor::CubeTensor,
};
use burn_backend::{
DType, Shape,
ops::{AttentionModuleOptions, attention::attention_fallback},
};
use cubek::attention::launch;
use cubek::attention::{
definition::{
AccumulatorPrecision, AttentionGlobalTypes, AttentionOptions, AttentionSetupError,
},
routines::blackbox_accelerated::BlackboxAcceleratedStrategy,
};
/// Selects which attention implementation [`attention`] launches.
#[derive(Debug)]
pub enum AttentionStrategy {
    /// Flash attention via cubek's blackbox accelerated routine, configured by
    /// the wrapped [`BlackboxAcceleratedStrategy`].
    FlashBlackboxAccelerated(BlackboxAcceleratedStrategy),
    /// Flash attention via cubek's unit routine (no extra configuration).
    FlashUnit,
    /// Reference fallback implementation from `burn_backend`.
    Fallback,
    /// Pick the best strategy by autotuning (only with the `autotune` feature).
    #[cfg(feature = "autotune")]
    Autotune,
}
impl Default for AttentionStrategy {
fn default() -> Self {
#[cfg(feature = "autotune")]
return AttentionStrategy::Autotune;
#[cfg(not(feature = "autotune"))]
AttentionStrategy::Fallback
}
}
/// Runs scaled-dot-product attention, dispatching on `strategy`.
///
/// Flash strategies are forwarded to [`flash_attention`] with the matching
/// cubek launch strategy; `Fallback` uses the reference implementation from
/// `burn_backend`; `Autotune` (feature-gated) benchmarks and picks the best.
///
/// Returns the output tensor, or an [`AttentionSetupError`] when the chosen
/// flash kernel cannot be set up for the given inputs.
#[allow(clippy::too_many_arguments)]
pub fn attention<R: CubeRuntime>(
    query: CubeTensor<R>,
    key: CubeTensor<R>,
    value: CubeTensor<R>,
    mask: Option<CubeTensor<R>>,
    attn_bias: Option<CubeTensor<R>>,
    options: AttentionModuleOptions,
    strategy: AttentionStrategy,
) -> Result<CubeTensor<R>, AttentionSetupError> {
    use cubek::attention::launch::BlueprintStrategy;

    match strategy {
        AttentionStrategy::FlashBlackboxAccelerated(inner) => {
            let launch_strategy =
                launch::Strategy::BlackboxAccelerated(BlueprintStrategy::Inferred(inner));
            flash_attention(query, key, value, mask, attn_bias, options, launch_strategy)
        }
        AttentionStrategy::FlashUnit => {
            let launch_strategy = launch::Strategy::Unit(BlueprintStrategy::Inferred(()));
            flash_attention(query, key, value, mask, attn_bias, options, launch_strategy)
        }
        // The fallback is infallible from the caller's perspective.
        AttentionStrategy::Fallback => Ok(attention_fallback::<CubeBackend<R, f32, i32, u8>>(
            query, key, value, mask, attn_bias, options,
        )),
        #[cfg(feature = "autotune")]
        AttentionStrategy::Autotune => Ok(attention_autotune(
            query, key, value, mask, attn_bias, options,
        )),
    }
}
/// Launches a cubek flash-attention kernel with the given launch `strategy`.
///
/// Allocates the output tensor from the query/value shapes, gathers the global
/// dtypes of all operands, and hands everything to `cubek`'s `launch_ref`.
/// The accumulator precision is pinned to strict `f32`.
///
/// NOTE(review): `_attn_bias` is accepted but ignored on this path — callers
/// passing a bias will not see it applied by the flash kernels; confirm this
/// limitation is intended.
#[allow(clippy::too_many_arguments)]
pub fn flash_attention<R: CubeRuntime>(
    query: CubeTensor<R>,
    key: CubeTensor<R>,
    value: CubeTensor<R>,
    mask: Option<CubeTensor<R>>,
    _attn_bias: Option<CubeTensor<R>>,
    options: AttentionModuleOptions,
    strategy: launch::Strategy,
) -> Result<CubeTensor<R>, AttentionSetupError> {
    let out = init_attention_output(&query, &value);
    let client = query.client.clone();

    // When no mask is provided, a placeholder dtype (U8) is still reported.
    let mask_dtype = match &mask {
        Some(m) => m.dtype,
        None => DType::U8,
    };
    let dtypes = AttentionGlobalTypes {
        query: query.dtype.into(),
        key: key.dtype.into(),
        value: value.dtype.into(),
        mask: mask_dtype.into(),
        out: out.dtype.into(),
    };

    let attention_options = AttentionOptions {
        causal: options.is_causal,
        // Accumulate strictly in f32 regardless of the operand dtypes.
        accumulator_precision: AccumulatorPrecision::Strict(cubecl::ir::StorageType::Scalar(
            cubecl::ir::ElemType::Float(cubecl::ir::FloatKind::F32),
        )),
    };

    cubek::attention::launch::launch_ref::<R>(
        strategy,
        &client,
        query.binding(),
        key.binding(),
        value.binding(),
        mask.map(|m| m.binding()),
        out.clone().binding(),
        &dtypes,
        attention_options,
    )?;

    Ok(out)
}
/// Allocates the attention output tensor (contents unspecified) with shape
/// `[batch, num_heads, seq_q, val_dim]`, where the first three dims come from
/// `query` and the last from `value`; the query's client, device and dtype are
/// reused for the allocation.
pub(crate) fn init_attention_output<R: CubeRuntime>(
    query: &CubeTensor<R>,
    value: &CubeTensor<R>,
) -> CubeTensor<R> {
    let q_shape = &query.meta.shape;
    let out_shape = Shape::new([q_shape[0], q_shape[1], q_shape[2], value.meta.shape[3]]);
    empty_device_dtype::<R>(
        query.client.clone(),
        query.device.clone(),
        out_shape,
        query.dtype,
    )
}