#[cfg(any(feature = "wgpu", feature = "wgpu-f16"))]
pub mod gpu {
    use burn::prelude::*;
    use burn::tensor::TensorPrimitive;
    use burn_backend::Shape;
    use burn_cubecl::tensor::CubeTensor;
    use burn_cubecl::{CubeRuntime, ops::numeric::empty_device_dtype};
    use cubek::attention::{
        definition::{AccumulatorPrecision, AttentionGlobalTypes, AttentionOptions},
        launch::Strategy,
    };
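
    /// Launches the non-causal flash attention kernel from `cubek` directly on
    /// raw `CubeTensor` handles.
    ///
    /// Inputs are expected in `[batch, heads, seq_len, head_dim]` layout, all on
    /// the same device and client; the output shares the query's dtype and takes
    /// the value tensor's head dimension.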
    pub fn flash_attention_noncausal<R: CubeRuntime>(
        query: CubeTensor<R>,
        key: CubeTensor<R>,
        value: CubeTensor<R>,
    ) -> CubeTensor<R> {
        let client = &query.client;
        let device = &query.device;
        let out_dtype = query.dtype;

        // Shapes follow the [batch, heads, seq_len, head_dim] convention; the
        // output keeps the query's batch/heads/seq and the value's head dim.
        let num_batches = query.shape.dims[0];
        let num_heads = query.shape.dims[1];
        let seq_q = query.shape.dims[2];
        let val_dim = value.shape.dims[3];
        let out_shape = Shape::new([num_batches, num_heads, seq_q, val_dim]);
        let out = empty_device_dtype::<R>(client.clone(), device.clone(), out_shape, out_dtype);
        // Element types for each operand. No mask tensor is passed to the
        // launch below, but the descriptor still carries a mask dtype; a
        // scalar U8 stands in.
        let dtypes = AttentionGlobalTypes {
            query: query.dtype.into(),
            key: key.dtype.into(),
            value: value.dtype.into(),
            mask: cubecl::ir::StorageType::Scalar(cubecl::ir::ElemType::UInt(
                cubecl::ir::UIntKind::U8,
            )),
            out: out.dtype.into(),
        };
        cubek::attention::launch::launch_ref::<R>(
            // Unit strategy with an inferred blueprint: let cubek work out the
            // tiling configuration for this problem size.
            Strategy::Unit(cubek::attention::launch::BlueprintStrategy::Inferred(())),
            client,
            &query.as_handle_ref(),
            &key.as_handle_ref(),
            &value.as_handle_ref(),
            // No attention mask.
            &None,
            &out.as_handle_ref(),
            &dtypes,
            AttentionOptions {
                causal: false,
                // Accumulate in f32 even for f16 inputs, for numerical
                // stability of the softmax.
                accumulator_precision: AccumulatorPrecision::Strict(
                    cubecl::ir::StorageType::Scalar(cubecl::ir::ElemType::Float(
                        cubecl::ir::FloatKind::F32,
                    )),
                ),
            },
        )
        .expect("non-causal flash attention kernel launch failed");
        out
    }
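
    /// Backend-level wrapper: unwraps burn `Tensor`s into their `CubeTensor`
    /// primitives, runs the kernel, and re-wraps the result. It only compiles
    /// for backends whose float primitive is a `CubeTensor`, such as the wgpu
    /// backend.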
    pub fn flash_attention_tensor<B, R>(
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
    ) -> Tensor<B, 4>
    where
        B: Backend<FloatTensorPrimitive = CubeTensor<R>>,
        R: CubeRuntime,
    {
        let q_prim = q.into_primitive().tensor();
        let k_prim = k.into_primitive().tensor();
        let v_prim = v.into_primitive().tensor();
        let out_prim = flash_attention_noncausal(q_prim, k_prim, v_prim);
        Tensor::from_primitive(TensorPrimitive::Float(out_prim))
    }
}
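
// A minimal usage sketch (not from the original source): it assumes burn's
// `Wgpu` backend, whose float primitive is a `CubeTensor` over the wgpu
// runtime, so it satisfies the bounds on `flash_attention_tensor`. Shapes
// follow the [batch, heads, seq_len, head_dim] convention used above.
#[cfg(all(test, feature = "wgpu"))]
mod tests {
    use super::gpu::flash_attention_tensor;
    use burn::backend::Wgpu;
    use burn::prelude::*;
    use burn::tensor::Distribution;

    #[test]
    fn non_causal_flash_attention_shape() {
        let device = burn::backend::wgpu::WgpuDevice::default();

        // [batch, heads, seq_len, head_dim]
        let q = Tensor::<Wgpu, 4>::random([2, 4, 128, 64], Distribution::Default, &device);
        let k = Tensor::<Wgpu, 4>::random([2, 4, 128, 64], Distribution::Default, &device);
        let v = Tensor::<Wgpu, 4>::random([2, 4, 128, 64], Distribution::Default, &device);

        // The output keeps the query's batch/heads/seq and the value's head dim.
        let out = flash_attention_tensor(q, k, v);
        assert_eq!(out.dims(), [2, 4, 128, 64]);
    }
}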