use alloc::vec;
use alloc::vec::Vec;
use burn_backend::DType;
use burn_backend::ops::AttentionModuleOptions;
use burn_std::Bytes;
use bytemuck::Pod;
use num_traits::Float;
use crate::{FlexTensor, Layout};
// Width of one K/V tile processed per flash-attention step; sized down on
// wasm to keep the `seq_q * TILE_KV` score scratch buffer small.
#[cfg(target_family = "wasm")]
const TILE_KV: usize = 32;
#[cfg(not(target_family = "wasm"))]
const TILE_KV: usize = 64;
// Upper bound on `seq_q * seq_kv` at or below which the naive kernel (full
// score matrix, no tiling overhead) is preferred over the flash kernel.
const NAIVE_SCORE_BUDGET: usize = 256 * 1024;
/// Scaled dot-product attention entry point.
///
/// Chooses between the naive single-pass kernel (cheaper for small score
/// matrices) and the tiled flash kernel based on the per-head score-matrix
/// size. `query` and `key` are expected to be 4D `[batch, heads, seq, dim]`.
pub fn attention(
    query: FlexTensor,
    key: FlexTensor,
    value: FlexTensor,
    mask: Option<FlexTensor>,
    attn_bias: Option<FlexTensor>,
    options: AttentionModuleOptions,
) -> FlexTensor {
    debug_assert!(
        query.layout().shape().num_dims() == 4,
        "attention: query must be 4D, got {}D",
        query.layout().shape().num_dims()
    );
    debug_assert!(
        key.layout().shape().num_dims() == 4,
        "attention: key must be 4D, got {}D",
        key.layout().shape().num_dims()
    );
    // Per-head score-matrix footprint decides which kernel wins.
    let score_elems = query.layout().shape()[2] * key.layout().shape()[2];
    if score_elems <= NAIVE_SCORE_BUDGET {
        attention_naive(query, key, value, mask, attn_bias, options)
    } else {
        attention_flash(query, key, value, mask, attn_bias, options)
    }
}
// Dispatches an attention implementation over the runtime dtype of `query`.
//
// f32/f64 run natively; f16/bf16 are widened to f32, computed, and narrowed
// back afterwards. The boolean mask is passed through unchanged (it is not a
// float tensor, so no cast applies).
macro_rules! dispatch_attention_dtype {
($query:expr, $key:expr, $value:expr, $mask:expr, $attn_bias:expr, $options:expr, $impl_fn:ident) => {{
// Bind the argument expressions once so each is evaluated exactly one time.
let query = $query;
let key = $key;
let value = $value;
let mask = $mask;
let attn_bias = $attn_bias;
let options = $options;
let dtype = query.dtype();
// Q/K/V (and the bias, if present) must share one float dtype.
debug_assert_eq!(key.dtype(), dtype, "attention: key dtype mismatch");
debug_assert_eq!(value.dtype(), dtype, "attention: value dtype mismatch");
if let Some(ref b) = attn_bias {
debug_assert_eq!(b.dtype(), dtype, "attention: attn_bias dtype mismatch");
}
match dtype {
DType::F32 => $impl_fn::<f32>(query, key, value, mask, attn_bias, options),
DType::F64 => $impl_fn::<f64>(query, key, value, mask, attn_bias, options),
DType::F16 => {
use burn_std::f16;
// Widen to f32, compute, then narrow the result back to f16.
let r = $impl_fn::<f32>(
cast_to_f32(query, f16::to_f32),
cast_to_f32(key, f16::to_f32),
cast_to_f32(value, f16::to_f32),
mask,
attn_bias.map(|b| cast_to_f32(b, f16::to_f32)),
options,
);
cast_from_f32(r, f16::from_f32)
}
DType::BF16 => {
use burn_std::bf16;
// Same widening strategy as F16.
let r = $impl_fn::<f32>(
cast_to_f32(query, bf16::to_f32),
cast_to_f32(key, bf16::to_f32),
cast_to_f32(value, bf16::to_f32),
mask,
attn_bias.map(|b| cast_to_f32(b, bf16::to_f32)),
options,
);
cast_from_f32(r, bf16::from_f32)
}
dtype => panic!("attention: unsupported dtype {:?}", dtype),
}
}};
}
/// Tiled (flash-style) attention: K/V are processed in `TILE_KV`-wide blocks
/// with an online softmax, so scratch memory stays `O(seq_q * TILE_KV)` per
/// head instead of `O(seq_q * seq_kv)`.
pub fn attention_flash(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor {
dispatch_attention_dtype!(query, key, value, mask, attn_bias, options, attention_impl)
}
/// Widens a half-precision tensor (`f16`/`bf16`) into a contiguous `f32`
/// tensor, converting element by element with the supplied function.
fn cast_to_f32<E: burn_backend::Element + Pod + Copy>(
    tensor: FlexTensor,
    to_f32: fn(E) -> f32,
) -> FlexTensor {
    let contiguous = tensor.to_contiguous();
    let shape = contiguous.layout().shape().clone();
    let elems: &[E] = contiguous.storage();
    // Pre-size the destination: one f32 per source element.
    let mut widened: Vec<f32> = Vec::with_capacity(elems.len());
    widened.extend(elems.iter().map(|&e| to_f32(e)));
    FlexTensor::new(
        Bytes::from_elems(widened),
        Layout::contiguous(shape),
        DType::F32,
    )
}
/// Narrows a contiguous `f32` tensor back to half precision, tagging the
/// result with `E`'s dtype (`f16` or `bf16`).
fn cast_from_f32<E: burn_backend::Element + Pod + Copy>(
    tensor: FlexTensor,
    from_f32: fn(f32) -> E,
) -> FlexTensor {
    let contiguous = tensor.to_contiguous();
    let shape = contiguous.layout().shape().clone();
    let wide: &[f32] = contiguous.storage();
    // Pre-size the destination: one narrowed element per f32.
    let mut narrowed: Vec<E> = Vec::with_capacity(wide.len());
    for &x in wide {
        narrowed.push(from_f32(x));
    }
    FlexTensor::new(
        Bytes::from_elems(narrowed),
        Layout::contiguous(shape),
        E::dtype(),
    )
}
/// Flash-attention driver for a concrete element type `T`.
///
/// Validates the 4D `[batch, heads, seq, dim]` shapes, precomputes contiguous
/// row-major strides, then runs the tiled kernel one (batch, head) slice at a
/// time, reusing one scratch allocation for the whole call.
///
/// # Panics
/// On inconsistent shapes or a non-positive `options.softcap`.
fn attention_impl<T>(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor
where
T: FlashGemm + burn_backend::Element,
{
if let Some(softcap) = options.softcap {
assert!(softcap > 0.0, "softcap must be positive, got {softcap}");
}
// The kernel assumes dense row-major storage.
let query = query.to_contiguous();
let key = key.to_contiguous();
let value = value.to_contiguous();
let mask_tensor = mask.map(|m| m.to_contiguous());
let bias_tensor = attn_bias.map(|b| b.to_contiguous());
let q_shape = query.layout().shape();
let k_shape = key.layout().shape();
let v_shape = value.layout().shape();
assert!(q_shape.num_dims() == 4, "attention: query must be 4D");
assert!(k_shape.num_dims() == 4, "attention: key must be 4D");
assert!(v_shape.num_dims() == 4, "attention: value must be 4D");
let batch = q_shape[0];
let heads = q_shape[1];
let seq_q = q_shape[2];
let head_dim = q_shape[3];
assert!(head_dim > 0, "attention: head_dim must be non-zero");
let seq_kv = k_shape[2];
let val_dim = v_shape[3];
assert_eq!(k_shape[0], batch, "attention: key batch mismatch");
assert_eq!(k_shape[1], heads, "attention: key heads mismatch");
assert_eq!(k_shape[3], head_dim, "attention: key head_dim mismatch");
assert_eq!(v_shape[0], batch, "attention: value batch mismatch");
assert_eq!(v_shape[1], heads, "attention: value heads mismatch");
assert_eq!(v_shape[2], seq_kv, "attention: value seq_kv mismatch");
// Mask and bias, when present, must cover the full score matrix.
if let Some(ref m) = mask_tensor {
let ms = m.layout().shape();
assert_eq!(
ms[..],
[batch, heads, seq_q, seq_kv],
"attention: mask shape mismatch"
);
}
if let Some(ref b) = bias_tensor {
let bs = b.layout().shape();
assert_eq!(
bs[..],
[batch, heads, seq_q, seq_kv],
"attention: bias shape mismatch"
);
}
// Default scale is the standard 1 / sqrt(head_dim).
let scale = T::from(
options
.scale
.unwrap_or_else(|| 1.0 / (head_dim as f64).sqrt()),
)
.unwrap();
let softcap: Option<T> = options.softcap.map(|s| T::from(s).unwrap());
// End-aligned causal diagonal: query `qi` may see key `ki` iff
// `ki <= qi + offset`, so the last query attends to every key.
let causal_offset = if options.is_causal {
Some(seq_kv as isize - seq_q as isize)
} else {
None
};
let q_data: &[T] = query.storage();
let k_data: &[T] = key.storage();
let v_data: &[T] = value.storage();
// The mask is read as raw bytes; the kernel treats nonzero as "masked out".
let mask_data: Option<&[u8]> = mask_tensor.as_ref().map(|m| m.bytes());
let bias_data: Option<&[T]> = bias_tensor.as_ref().map(|b| b.storage());
// The kernel accumulates into `output`, so it must start zeroed.
let mut output = vec![T::zero(); batch * heads * seq_q * val_dim];
// Contiguous row-major strides for slicing out one (batch, head) plane.
let q_head_stride = seq_q * head_dim;
let q_batch_stride = heads * q_head_stride;
let k_head_stride = seq_kv * head_dim;
let k_batch_stride = heads * k_head_stride;
let v_head_stride = seq_kv * val_dim;
let v_batch_stride = heads * v_head_stride;
let o_head_stride = seq_q * val_dim;
let o_batch_stride = heads * o_head_stride;
let mask_head_stride = seq_q * seq_kv;
let mask_batch_stride = heads * mask_head_stride;
let params = AttentionParams {
scale,
softcap,
causal_offset,
seq_q,
seq_kv,
head_dim,
val_dim,
};
// One scratch allocation shared by every (batch, head) slice.
let mut scratch = ScratchBuffers {
row_max: vec![T::neg_infinity(); seq_q],
row_sum: vec![T::zero(); seq_q],
scores: vec![T::zero(); seq_q * TILE_KV],
};
for b in 0..batch {
for h in 0..heads {
let q_off = b * q_batch_stride + h * q_head_stride;
let k_off = b * k_batch_stride + h * k_head_stride;
let v_off = b * v_batch_stride + h * v_head_stride;
let o_off = b * o_batch_stride + h * o_head_stride;
let m_off = b * mask_batch_stride + h * mask_head_stride;
flash_attention_head(
&q_data[q_off..q_off + q_head_stride],
&k_data[k_off..k_off + k_head_stride],
&v_data[v_off..v_off + v_head_stride],
&mut output[o_off..o_off + o_head_stride],
mask_data.map(|m| &m[m_off..m_off + mask_head_stride]),
bias_data.map(|b| &b[m_off..m_off + mask_head_stride]),
&params,
&mut scratch,
);
}
}
let shape = burn_std::Shape::from(vec![batch, heads, seq_q, val_dim]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(shape),
T::dtype(),
)
}
/// Float element type able to run the block GEMMs used by both attention
/// kernels (implemented below for `f32` and `f64` via the `gemm` crate).
trait FlashGemm: Float + Pod + Copy + core::ops::AddAssign {
/// Strided matrix multiply; the argument layout mirrors `gemm::gemm`.
///
/// # Safety
/// Every pointer in `args` must be valid for the `m`/`n`/`k` extents at the
/// given strides, and `dst` must not alias `lhs` or `rhs`.
unsafe fn block_gemm(args: BlockGemmArgs<Self>);
}
/// Arguments for [`FlashGemm::block_gemm`], mirroring `gemm::gemm`'s parameter
/// order. `*_cs` / `*_rs` are column / row strides measured in elements.
struct BlockGemmArgs<T> {
/// Rows of `dst` (and of `lhs`).
m: usize,
/// Columns of `dst` (and of `rhs`).
n: usize,
/// Shared inner dimension.
k: usize,
/// Output matrix pointer.
dst: *mut T,
dst_cs: isize,
dst_rs: isize,
/// When true, existing `dst` contents are scaled by `alpha` and accumulated
/// into; when false, `dst` is overwritten.
read_dst: bool,
/// Left operand pointer.
lhs: *const T,
lhs_cs: isize,
lhs_rs: isize,
/// Right operand pointer.
rhs: *const T,
rhs_cs: isize,
rhs_rs: isize,
/// Scale applied to the existing `dst` (only meaningful when `read_dst`).
alpha: T,
/// Scale applied to the `lhs * rhs` product.
beta: T,
}
// Implements `FlashGemm` for a native float type by forwarding to the `gemm`
// crate (single-threaded, no conjugation — the trailing `false` flags).
macro_rules! impl_flash_gemm {
($ty:ty) => {
impl FlashGemm for $ty {
unsafe fn block_gemm(a: BlockGemmArgs<Self>) {
// SAFETY: the caller upholds `block_gemm`'s contract (valid pointers
// for the given extents/strides, `dst` not aliasing `lhs`/`rhs`).
unsafe {
gemm::gemm(
a.m,
a.n,
a.k,
a.dst,
a.dst_cs,
a.dst_rs,
a.read_dst,
a.lhs,
a.lhs_cs,
a.lhs_rs,
a.rhs,
a.rhs_cs,
a.rhs_rs,
a.alpha,
a.beta,
false,
false,
false,
gemm::Parallelism::None,
);
}
}
}
};
}
impl_flash_gemm!(f32);
impl_flash_gemm!(f64);
/// Per-call scratch reused across all (batch, head) slices so the hot loop
/// never reallocates.
struct ScratchBuffers<T> {
/// Running per-query maximum of the processed scores (online-softmax state).
row_max: Vec<T>,
/// Running per-query sum of exponentials (online-softmax state).
row_sum: Vec<T>,
/// Score tile of up to `seq_q * TILE_KV` elements, row-major.
scores: Vec<T>,
}
/// Precomputed per-head constants shared by the flash and naive kernels.
struct AttentionParams<T> {
/// Score scale; defaults to `1 / sqrt(head_dim)` when the caller sets none.
scale: T,
/// Optional tanh soft-capping applied to the scaled scores.
softcap: Option<T>,
/// `Some(seq_kv - seq_q)` when causal: key `ki` is visible to query `qi`
/// iff `ki <= qi + offset`, aligning the diagonal with the sequence end.
causal_offset: Option<isize>,
seq_q: usize,
seq_kv: usize,
head_dim: usize,
val_dim: usize,
}
/// Tiled attention over one (batch, head) slice using an online softmax.
///
/// K/V are consumed in `TILE_KV`-wide chunks: each chunk's score tile
/// `Q · K_tileᵀ` is folded into running per-row max/sum statistics, the
/// output accumulated so far is rescaled to the new max, and the chunk's
/// exponential-weighted values are added. A final pass divides each output
/// row by its softmax denominator.
///
/// `output` must arrive zero-filled; fully masked rows remain zero.
#[allow(clippy::too_many_arguments)]
fn flash_attention_head<T: FlashGemm>(
q: &[T],
k: &[T],
v: &[T],
output: &mut [T],
mask: Option<&[u8]>,
bias: Option<&[T]>,
p: &AttentionParams<T>,
scratch: &mut ScratchBuffers<T>,
) {
debug_assert_eq!(q.len(), p.seq_q * p.head_dim);
debug_assert_eq!(k.len(), p.seq_kv * p.head_dim);
debug_assert_eq!(v.len(), p.seq_kv * p.val_dim);
debug_assert_eq!(output.len(), p.seq_q * p.val_dim);
let neg_inf = T::neg_infinity();
let AttentionParams {
scale,
softcap,
causal_offset,
seq_q,
seq_kv,
head_dim,
val_dim,
} = *p;
// Online-softmax state: per-query running max and sum of exponentials.
let row_max = &mut scratch.row_max;
row_max.fill(neg_inf);
let row_sum = &mut scratch.row_sum;
row_sum.fill(T::zero());
let scores = &mut scratch.scores;
let num_kv_tiles = seq_kv.div_ceil(TILE_KV);
for tile_idx in 0..num_kv_tiles {
let kv_start = tile_idx * TILE_KV;
let kv_end = (kv_start + TILE_KV).min(seq_kv);
// The final tile may be narrower than TILE_KV.
let tile_kv = kv_end - kv_start;
// Score tile: scores[qi][ki] = Q[qi] · K[kv_start + ki]; K is consumed
// transposed by swapping its row/column strides.
// SAFETY: extents/strides describe in-bounds, non-aliasing row-major
// views of the slices checked by the debug asserts above.
unsafe {
T::block_gemm(BlockGemmArgs {
m: seq_q,
n: tile_kv,
k: head_dim,
dst: scores.as_mut_ptr(),
dst_cs: 1,
dst_rs: tile_kv as isize,
read_dst: false,
lhs: q.as_ptr(),
lhs_cs: 1,
lhs_rs: head_dim as isize,
rhs: k.as_ptr().add(kv_start * head_dim),
rhs_cs: head_dim as isize,
rhs_rs: 1,
alpha: T::zero(),
beta: T::one(),
});
}
for qi in 0..seq_q {
let score_row = &mut scores[qi * tile_kv..(qi + 1) * tile_kv];
let mut tile_max = neg_inf;
for (ki, score) in score_row.iter_mut().enumerate() {
let kv_idx = kv_start + ki;
let mut val = *score * scale;
// Soft-capping squashes the scaled score into (-cap, cap).
if let Some(cap) = softcap {
val = cap * (val / cap).tanh();
}
// A nonzero mask byte removes this position from the softmax.
if let Some(m) = mask
&& m[qi * seq_kv + kv_idx] != 0
{
val = neg_inf;
}
// Causal: hide keys past the end-aligned diagonal.
if let Some(offset) = causal_offset
&& (kv_idx as isize) > (qi as isize) + offset
{
val = neg_inf;
}
// Additive bias cannot resurrect masked scores: -inf + finite = -inf.
if let Some(b) = bias {
val += b[qi * seq_kv + kv_idx];
}
*score = val;
if val > tile_max {
tile_max = val;
}
}
// Entire row of this tile is masked out: zero its contribution and leave
// the running statistics untouched (folding -inf in would poison them).
if tile_max == neg_inf {
for score in score_row.iter_mut() {
*score = T::zero();
}
continue;
}
let new_max = if row_max[qi] > tile_max {
row_max[qi]
} else {
tile_max
};
// Exponentiate against the updated max; tile_sum is this tile's share
// of the softmax denominator.
let mut tile_sum = T::zero();
for score in score_row.iter_mut() {
let e = (*score - new_max).exp();
*score = e;
tile_sum += e;
}
// Rescale what was accumulated under the old max to the new max; the
// first contributing tile uses 0 since nothing is accumulated yet.
let correction = if row_max[qi] == neg_inf {
T::zero()
} else {
(row_max[qi] - new_max).exp()
};
let out_row = &mut output[qi * val_dim..(qi + 1) * val_dim];
for o in out_row.iter_mut() {
*o = *o * correction;
}
row_sum[qi] = row_sum[qi] * correction + tile_sum;
row_max[qi] = new_max;
}
// output += exp-weights · V_tile (read_dst accumulates into the partial sum).
// SAFETY: same in-bounds, non-aliasing argument as the score GEMM above.
unsafe {
T::block_gemm(BlockGemmArgs {
m: seq_q,
n: val_dim,
k: tile_kv,
dst: output.as_mut_ptr(),
dst_cs: 1,
dst_rs: val_dim as isize,
read_dst: true,
lhs: scores.as_ptr(),
lhs_cs: 1,
lhs_rs: tile_kv as isize,
rhs: v.as_ptr().add(kv_start * val_dim),
rhs_cs: 1,
rhs_rs: val_dim as isize,
alpha: T::one(),
beta: T::one(),
});
}
}
// Normalize by the softmax denominator; sum == 0 means every position was
// masked, so that row is deliberately left at zero.
for qi in 0..seq_q {
let sum = row_sum[qi];
if sum > T::zero() {
let inv_sum = T::one() / sum;
let out_row = &mut output[qi * val_dim..(qi + 1) * val_dim];
for o in out_row.iter_mut() {
*o = *o * inv_sum;
}
}
}
}
/// Reference attention that materializes the full `[seq_q, seq_kv]` score
/// matrix per head; cheaper than the flash kernel for small problems and used
/// as the correctness baseline in the tests.
pub fn attention_naive(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor {
dispatch_attention_dtype!(
query,
key,
value,
mask,
attn_bias,
options,
attention_naive_impl
)
}
/// Naive attention driver for a concrete element type `T`.
///
/// Same validation and per-head slicing as the flash driver, but each head is
/// handled by `naive_attention_head`, which materializes the entire
/// `[seq_q, seq_kv]` score matrix in one reusable buffer.
///
/// # Panics
/// On inconsistent shapes or a non-positive `options.softcap`.
fn attention_naive_impl<T>(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor
where
T: FlashGemm + burn_backend::Element,
{
if let Some(softcap) = options.softcap {
assert!(softcap > 0.0, "softcap must be positive, got {softcap}");
}
// The kernel assumes dense row-major storage.
let query = query.to_contiguous();
let key = key.to_contiguous();
let value = value.to_contiguous();
let mask_tensor = mask.map(|m| m.to_contiguous());
let bias_tensor = attn_bias.map(|b| b.to_contiguous());
let q_shape = query.layout().shape();
let k_shape = key.layout().shape();
let v_shape = value.layout().shape();
assert!(q_shape.num_dims() == 4, "attention_naive: query must be 4D");
assert!(k_shape.num_dims() == 4, "attention_naive: key must be 4D");
assert!(v_shape.num_dims() == 4, "attention_naive: value must be 4D");
let batch = q_shape[0];
let heads = q_shape[1];
let seq_q = q_shape[2];
let head_dim = q_shape[3];
assert!(head_dim > 0, "attention_naive: head_dim must be non-zero");
let seq_kv = k_shape[2];
let val_dim = v_shape[3];
assert_eq!(k_shape[0], batch, "attention_naive: key batch mismatch");
assert_eq!(k_shape[1], heads, "attention_naive: key heads mismatch");
assert_eq!(
k_shape[3], head_dim,
"attention_naive: key head_dim mismatch"
);
assert_eq!(v_shape[0], batch, "attention_naive: value batch mismatch");
assert_eq!(v_shape[1], heads, "attention_naive: value heads mismatch");
assert_eq!(v_shape[2], seq_kv, "attention_naive: value seq_kv mismatch");
// Mask and bias, when present, must cover the full score matrix.
if let Some(ref m) = mask_tensor {
let ms = m.layout().shape();
assert_eq!(
ms[..],
[batch, heads, seq_q, seq_kv],
"attention_naive: mask shape mismatch"
);
}
if let Some(ref b) = bias_tensor {
let bs = b.layout().shape();
assert_eq!(
bs[..],
[batch, heads, seq_q, seq_kv],
"attention_naive: bias shape mismatch"
);
}
// Default scale is the standard 1 / sqrt(head_dim).
let scale = T::from(
options
.scale
.unwrap_or_else(|| 1.0 / (head_dim as f64).sqrt()),
)
.unwrap();
let softcap: Option<T> = options.softcap.map(|s| T::from(s).unwrap());
// End-aligned causal diagonal: query `qi` sees key `ki` iff `ki <= qi + offset`.
let causal_offset = if options.is_causal {
Some(seq_kv as isize - seq_q as isize)
} else {
None
};
let q_data: &[T] = query.storage();
let k_data: &[T] = key.storage();
let v_data: &[T] = value.storage();
// The mask is read as raw bytes; the kernel treats nonzero as "masked out".
let mask_data: Option<&[u8]> = mask_tensor.as_ref().map(|m| m.bytes());
let bias_data: Option<&[T]> = bias_tensor.as_ref().map(|b| b.storage());
let mut output = vec![T::zero(); batch * heads * seq_q * val_dim];
// Full score matrix for one head, reused across all (batch, head) pairs.
let mut scores = vec![T::zero(); seq_q * seq_kv];
// Contiguous row-major strides for slicing out one (batch, head) plane.
let q_head_stride = seq_q * head_dim;
let q_batch_stride = heads * q_head_stride;
let k_head_stride = seq_kv * head_dim;
let k_batch_stride = heads * k_head_stride;
let v_head_stride = seq_kv * val_dim;
let v_batch_stride = heads * v_head_stride;
let o_head_stride = seq_q * val_dim;
let o_batch_stride = heads * o_head_stride;
let mask_head_stride = seq_q * seq_kv;
let mask_batch_stride = heads * mask_head_stride;
let params = AttentionParams {
scale,
softcap,
causal_offset,
seq_q,
seq_kv,
head_dim,
val_dim,
};
for b in 0..batch {
for h in 0..heads {
let q_off = b * q_batch_stride + h * q_head_stride;
let k_off = b * k_batch_stride + h * k_head_stride;
let v_off = b * v_batch_stride + h * v_head_stride;
let o_off = b * o_batch_stride + h * o_head_stride;
let m_off = b * mask_batch_stride + h * mask_head_stride;
naive_attention_head(
&q_data[q_off..q_off + q_head_stride],
&k_data[k_off..k_off + k_head_stride],
&v_data[v_off..v_off + v_head_stride],
&mut output[o_off..o_off + o_head_stride],
&mut scores,
&params,
(
mask_data.map(|m| &m[m_off..m_off + mask_head_stride]),
bias_data.map(|b| &b[m_off..m_off + mask_head_stride]),
),
);
}
}
let shape = burn_std::Shape::from(vec![batch, heads, seq_q, val_dim]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(shape),
T::dtype(),
)
}
/// Single-pass attention for one (batch, head) slice.
///
/// Materializes the full `[seq_q, seq_kv]` score matrix in `scores`, applies
/// scale / softcap / mask / causal / bias, row-softmaxes it in place, then
/// multiplies by `V`. Fully masked rows yield all-zero output rows.
fn naive_attention_head<T: FlashGemm>(
q: &[T],
k: &[T],
v: &[T],
output: &mut [T],
scores: &mut [T],
p: &AttentionParams<T>,
mask_bias: (Option<&[u8]>, Option<&[T]>),
) {
let (mask, bias) = mask_bias;
let neg_inf = T::neg_infinity();
let AttentionParams {
scale,
softcap,
causal_offset,
seq_q,
seq_kv,
head_dim,
val_dim,
} = *p;
// scores = Q · Kᵀ (K transposed by swapping its strides).
// SAFETY: slice lengths match the m/n/k extents and strides; dst does not
// alias lhs/rhs.
unsafe {
T::block_gemm(BlockGemmArgs {
m: seq_q,
n: seq_kv,
k: head_dim,
dst: scores.as_mut_ptr(),
dst_cs: 1,
dst_rs: seq_kv as isize,
read_dst: false,
lhs: q.as_ptr(),
lhs_cs: 1,
lhs_rs: head_dim as isize,
rhs: k.as_ptr(),
rhs_cs: head_dim as isize,
rhs_rs: 1,
alpha: T::zero(),
beta: T::one(),
});
}
for qi in 0..seq_q {
let row = &mut scores[qi * seq_kv..(qi + 1) * seq_kv];
let mut row_max = neg_inf;
for (ki, s) in row.iter_mut().enumerate() {
let mut val = *s * scale;
// Soft-capping squashes the scaled score into (-cap, cap).
if let Some(cap) = softcap {
val = cap * (val / cap).tanh();
}
// A nonzero mask byte removes this position from the softmax.
if let Some(m) = mask
&& m[qi * seq_kv + ki] != 0
{
val = neg_inf;
}
// Causal: hide keys past the end-aligned diagonal.
if let Some(offset) = causal_offset
&& (ki as isize) > (qi as isize) + offset
{
val = neg_inf;
}
// Additive bias cannot resurrect masked scores: -inf + finite = -inf.
if let Some(b) = bias {
val += b[qi * seq_kv + ki];
}
*s = val;
if val > row_max {
row_max = val;
}
}
// Fully masked row: emit zero probabilities instead of the 0/0 NaN a
// softmax over all -inf would produce.
if row_max == neg_inf {
row.fill(T::zero());
continue;
}
// Numerically stable softmax in place (shift by the row max).
let mut sum = T::zero();
for s in row.iter_mut() {
let e = (*s - row_max).exp();
*s = e;
sum += e;
}
let inv_sum = T::one() / sum;
for s in row.iter_mut() {
*s = *s * inv_sum;
}
}
// output = softmax(scores) · V.
// SAFETY: same in-bounds, non-aliasing argument as the score GEMM above.
unsafe {
T::block_gemm(BlockGemmArgs {
m: seq_q,
n: val_dim,
k: seq_kv,
dst: output.as_mut_ptr(),
dst_cs: 1,
dst_rs: val_dim as isize,
read_dst: false,
lhs: scores.as_ptr(),
lhs_cs: 1,
lhs_rs: seq_kv as isize,
rhs: v.as_ptr(),
rhs_cs: 1,
rhs_rs: val_dim as isize,
alpha: T::zero(),
beta: T::one(),
});
}
}
// Unit tests: small end-to-end cases through the public `burn_tensor` API,
// plus flash-vs-naive equivalence checks on deterministic pseudo-random data.
#[cfg(test)]
mod tests {
use burn_backend::ops::AttentionModuleOptions;
use burn_tensor::{Tensor, TensorData};
use crate::Flex;
// Builds 4D `[1, 1, seq, dim]` Q/K/V tensors from per-row slices.
fn make_qkv(
q: &[&[f32]],
k: &[&[f32]],
v: &[&[f32]],
) -> (Tensor<Flex, 4>, Tensor<Flex, 4>, Tensor<Flex, 4>) {
let seq_q = q.len();
let seq_k = k.len();
let head_dim = q[0].len();
let val_dim = v[0].len();
let q_flat: Vec<f32> = q.iter().flat_map(|r| r.iter().copied()).collect();
let k_flat: Vec<f32> = k.iter().flat_map(|r| r.iter().copied()).collect();
let v_flat: Vec<f32> = v.iter().flat_map(|r| r.iter().copied()).collect();
let dev = Default::default();
let qt = Tensor::from_data(TensorData::new(q_flat, [1, 1, seq_q, head_dim]), &dev);
let kt = Tensor::from_data(TensorData::new(k_flat, [1, 1, seq_k, head_dim]), &dev);
let vt = Tensor::from_data(TensorData::new(v_flat, [1, 1, seq_k, val_dim]), &dev);
(qt, kt, vt)
}
// Plain 2-token self-attention; expected values from a hand-computed softmax.
#[test]
fn test_basic() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), 2);
assert!((data[0] - 13.30).abs() < 0.1, "got {}", data[0]);
assert!((data[1] - 16.70).abs() < 0.1, "got {}", data[1]);
}
// Causal masking: row 0 can only attend to key 0, so it returns v[0] exactly.
#[test]
fn test_causal_mask() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let opts = AttentionModuleOptions {
is_causal: true,
..Default::default()
};
let result = burn_tensor::module::attention(q, k, v, None, None, opts);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 10.0).abs() < 1e-5, "got {}", data[0]);
assert!((data[1] - 16.70).abs() < 0.1, "got {}", data[1]);
}
// Boolean mask semantics: `true` masks a position OUT. Key 0 is hidden for
// both rows, so both attend only to key 1 (v = 20).
#[test]
fn test_bool_mask() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let dev = Default::default();
use burn_tensor::Bool;
let mask: Tensor<Flex, 4, Bool> =
Tensor::from_data(TensorData::from([[[[true, false], [true, false]]]]), &dev);
let result = burn_tensor::module::attention(q, k, v, Some(mask), None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 20.0).abs() < 1e-4, "got {}", data[0]);
assert!((data[1] - 20.0).abs() < 1e-4, "got {}", data[1]);
}
// A +100 additive bias on key 1 pushes essentially all weight onto it.
#[test]
fn test_additive_bias() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let dev = Default::default();
let bias: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(vec![0.0f32, 100.0, 0.0, 100.0], [1, 1, 2, 2]),
&dev,
);
let result = burn_tensor::module::attention(q, k, v, None, Some(bias), Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 20.0).abs() < 0.1, "got {}", data[0]);
assert!((data[1] - 20.0).abs() < 0.1, "got {}", data[1]);
}
// A huge scale saturates the softmax into an argmax over the matching key.
#[test]
fn test_custom_scale() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let opts = AttentionModuleOptions {
scale: Some(100.0),
..Default::default()
};
let result = burn_tensor::module::attention(q, k, v, None, None, opts);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 10.0).abs() < 0.1, "got {}", data[0]);
assert!((data[1] - 20.0).abs() < 0.1, "got {}", data[1]);
}
// A tiny softcap flattens the scores, making attention nearly uniform.
#[test]
fn test_softcap() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let opts = AttentionModuleOptions {
softcap: Some(0.1),
..Default::default()
};
let result = burn_tensor::module::attention(q, k, v, None, None, opts);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 15.0).abs() < 0.5, "got {}", data[0]);
assert!((data[1] - 15.0).abs() < 0.5, "got {}", data[1]);
}
// seq_kv != seq_q: only sanity-check that output stays within the V range.
#[test]
fn test_cross_attention() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0], &[0.5, 0.5]],
&[&[10.0], &[20.0], &[30.0]],
);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), 2);
for &val in &data {
assert!(val >= 9.0 && val <= 31.0, "unexpected value {val}");
}
}
// seq_kv > seq_q with causal masking: the diagonal is end-aligned, so the
// last query sees every key (matches non-causal), earlier queries see fewer.
#[test]
fn test_causal_cross_attention() {
let dev = Default::default();
let q: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(vec![1.0f32, 0.0, 0.0, 1.0], [1, 1, 2, 2]),
&dev,
);
let k: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(
vec![1.0f32, 0.0, 0.0, 1.0, 0.5, 0.5, 0.5, 0.5],
[1, 1, 4, 2],
),
&dev,
);
let v: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(vec![10.0f32, 20.0, 30.0, 40.0], [1, 1, 4, 1]),
&dev,
);
let opts = AttentionModuleOptions {
is_causal: true,
..Default::default()
};
let result_causal =
burn_tensor::module::attention(q.clone(), k.clone(), v.clone(), None, None, opts);
let data_causal: Vec<f32> = result_causal.into_data().to_vec().unwrap();
let result_full = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data_full: Vec<f32> = result_full.into_data().to_vec().unwrap();
assert_eq!(data_causal.len(), 2);
assert!(
data_causal[0] < data_full[0],
"expected causal[0] < full[0], got {} vs {}",
data_causal[0],
data_full[0]
);
assert!(
(data_causal[1] - data_full[1]).abs() < 1e-5,
"expected causal[1] ~= full[1], got {} vs {}",
data_causal[1],
data_full[1]
);
}
// Rows with every position masked must come out as zeros, never NaN.
#[test]
fn test_all_masked_produces_zeros() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let dev = Default::default();
use burn_tensor::Bool;
let mask: Tensor<Flex, 4, Bool> =
Tensor::from_data(TensorData::from([[[[true, true], [true, true]]]]), &dev);
let result = burn_tensor::module::attention(q, k, v, Some(mask), None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
for (i, &val) in data.iter().enumerate() {
assert!(!val.is_nan(), "output[{i}] is NaN");
assert!((val - 0.0).abs() < 1e-6, "expected 0.0, got {val}");
}
}
// Constant V = 10 means every output is 10 regardless of attention weights;
// exercises the batch/head stride arithmetic.
#[test]
fn test_multi_batch_multi_head() {
let dev = Default::default();
let q: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(
vec![
1.0f32, 0.0, 0.0, 1.0, 0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0,
],
[2, 2, 2, 2],
),
&dev,
);
let k = q.clone();
let v: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(vec![10.0f32; 16], [2, 2, 2, 2]), &dev);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), 16);
for (i, &val) in data.iter().enumerate() {
assert!(
(val - 10.0).abs() < 1e-4,
"output[{i}] = {val}, expected 10.0"
);
}
}
// Softmax over a single key is the identity: the output is v[0].
#[test]
fn test_single_element() {
let (q, k, v) = make_qkv(&[&[1.0, 0.0]], &[&[1.0, 0.0]], &[&[42.0]]);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), 1);
assert!((data[0] - 42.0).abs() < 1e-5, "got {}", data[0]);
}
// seq_kv = 128 spans multiple KV tiles; uniform scores average V, giving the
// mean index 63.5 in both value columns.
#[test]
fn test_multi_tile_seq_kv() {
let dev = <crate::FlexDevice as Default>::default();
let seq_q = 2;
let seq_kv = 128;
let head_dim = 4;
let val_dim = 2;
let mut q_data = vec![0.0f32; seq_q * head_dim];
q_data[0] = 1.0; q_data[head_dim + 1] = 1.0;
let k_data = vec![0.1f32; seq_kv * head_dim];
let mut v_data = vec![0.0f32; seq_kv * val_dim];
for i in 0..seq_kv {
v_data[i * val_dim] = i as f32;
v_data[i * val_dim + 1] = (seq_kv - 1 - i) as f32;
}
let q: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(q_data, [1, 1, seq_q, head_dim]), &dev);
let k: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(k_data, [1, 1, seq_kv, head_dim]), &dev);
let v: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(v_data, [1, 1, seq_kv, val_dim]), &dev);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), seq_q * val_dim);
assert!(
(data[0] - 63.5).abs() < 0.1,
"expected ~63.5, got {}",
data[0]
);
assert!(
(data[1] - 63.5).abs() < 0.1,
"expected ~63.5, got {}",
data[1]
);
}
// Causal with seq_kv > seq_q across tile boundaries: query i averages keys
// 0..=(124 + i) (end-aligned diagonal), giving means 62.0 .. 63.5.
#[test]
fn test_multi_tile_causal() {
let dev = <crate::FlexDevice as Default>::default();
let seq_q = 4;
let seq_kv = 128;
let head_dim = 2;
let val_dim = 1;
let mut q_data = vec![0.0f32; seq_q * head_dim];
for i in 0..seq_q {
q_data[i * head_dim] = 1.0;
}
let mut k_data = vec![0.0f32; seq_kv * head_dim];
for i in 0..seq_kv {
k_data[i * head_dim] = 1.0;
}
let v_data: Vec<f32> = (0..seq_kv).map(|i| i as f32).collect();
let q: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(q_data, [1, 1, seq_q, head_dim]), &dev);
let k: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(k_data, [1, 1, seq_kv, head_dim]), &dev);
let v: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(v_data, [1, 1, seq_kv, val_dim]), &dev);
let opts = AttentionModuleOptions {
is_causal: true,
..Default::default()
};
let result = burn_tensor::module::attention(q, k, v, None, None, opts);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), seq_q);
assert!((data[0] - 62.0).abs() < 0.1, "q0: got {}", data[0]);
assert!((data[1] - 62.5).abs() < 0.1, "q1: got {}", data[1]);
assert!((data[2] - 63.0).abs() < 0.1, "q2: got {}", data[2]);
assert!((data[3] - 63.5).abs() < 0.1, "q3: got {}", data[3]);
}
// Mask flips exactly at the 64-element tile boundary: the first tile is
// fully masked (true), so only keys 64..128 contribute (mean 95.5).
#[test]
fn test_tile_boundary_mask() {
let dev = <crate::FlexDevice as Default>::default();
let seq_q = 1;
let seq_kv = 128;
let head_dim = 2;
let val_dim = 1;
let q_data = vec![1.0f32, 0.0];
let k_data = vec![1.0f32, 0.0].repeat(seq_kv);
let v_data: Vec<f32> = (0..seq_kv).map(|i| i as f32).collect();
let mask_data: Vec<bool> = (0..seq_kv).map(|i| i < 64).collect();
let q: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(q_data, [1, 1, seq_q, head_dim]), &dev);
let k: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(k_data, [1, 1, seq_kv, head_dim]), &dev);
let v: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(v_data, [1, 1, seq_kv, val_dim]), &dev);
use burn_tensor::Bool;
let mask: Tensor<Flex, 4, Bool> =
Tensor::from_data(TensorData::new(mask_data, [1, 1, seq_q, seq_kv]), &dev);
let result = burn_tensor::module::attention(q, k, v, Some(mask), None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!(
(data[0] - 95.5).abs() < 0.1,
"expected ~95.5, got {}",
data[0]
);
}
// The second tile dominates the softmax (scores 5.0 vs 0.1), exercising the
// online-softmax max/sum correction across tiles: output should be ~1.0.
#[test]
fn test_non_uniform_scores_across_tiles() {
let dev = <crate::FlexDevice as Default>::default();
let seq_q = 1;
let seq_kv = 128;
let head_dim = 1;
let val_dim = 1;
let q_data = vec![1.0f32];
let mut k_data = vec![0.0f32; seq_kv];
for i in 0..64 {
k_data[i] = 0.1;
}
for i in 64..128 {
k_data[i] = 5.0;
}
let mut v_data = vec![0.0f32; seq_kv];
for i in 64..128 {
v_data[i] = 1.0;
}
let q: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(q_data, [1, 1, seq_q, head_dim]), &dev);
let k: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(k_data, [1, 1, seq_kv, head_dim]), &dev);
let v: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(v_data, [1, 1, seq_kv, val_dim]), &dev);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!(data[0] > 0.99, "expected ~1.0, got {}", data[0]);
}
// seq_kv = 100 is not a multiple of TILE_KV, exercising the narrower final
// tile; uniform scores average V (mean index 49.5).
#[test]
fn test_partial_last_tile() {
let dev = <crate::FlexDevice as Default>::default();
let seq_q = 2;
let seq_kv = 100;
let head_dim = 2;
let val_dim = 1;
let q_data = vec![0.1f32, 0.1].repeat(seq_q);
let k_data = vec![0.1f32, 0.1].repeat(seq_kv);
let v_data: Vec<f32> = (0..seq_kv).map(|i| i as f32).collect();
let q: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(q_data, [1, 1, seq_q, head_dim]), &dev);
let k: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(k_data, [1, 1, seq_kv, head_dim]), &dev);
let v: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(v_data, [1, 1, seq_kv, val_dim]), &dev);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), seq_q);
assert!(
(data[0] - 49.5).abs() < 0.1,
"expected ~49.5, got {}",
data[0]
);
assert!(
(data[1] - 49.5).abs() < 0.1,
"expected ~49.5, got {}",
data[1]
);
}
// Cross-checks the flash kernel against the naive baseline element-wise on
// deterministic pseudo-random inputs over a grid of shapes and options.
#[test]
fn test_naive_matches_flash() {
use crate::Layout;
use burn_backend::ops::AttentionModuleOptions;
use burn_std::{Bytes, Shape};
// Deterministic data generator: floats in [-0.5, 0.5) or ~30% true masks.
fn make_tensor(shape: &[usize], dtype: burn_backend::DType) -> crate::FlexTensor {
let len: usize = shape.iter().product();
let layout = Layout::contiguous(Shape::from(shape.to_vec()));
match dtype {
burn_backend::DType::F32 => {
let data: Vec<f32> =
(0..len).map(|i| ((i % 997) as f32 / 997.0) - 0.5).collect();
crate::FlexTensor::new(Bytes::from_elems(data), layout, dtype)
}
burn_backend::DType::Bool(_) => {
let data: Vec<u8> = (0..len)
.map(|i| (i.wrapping_mul(997) % 100 < 30) as u8)
.collect();
crate::FlexTensor::new(Bytes::from_elems(data), layout, dtype)
}
_ => unreachable!(),
}
}
// Runs both kernels on one configuration and compares outputs within a
// relative tolerance.
fn run_both(
batch: usize,
heads: usize,
seq_q: usize,
seq_kv: usize,
head_dim: usize,
val_dim: usize,
with_mask: bool,
with_bias: bool,
options: AttentionModuleOptions,
label: &str,
) {
let f32_dt = burn_backend::DType::F32;
let bool_dt = burn_backend::DType::Bool(burn_std::BoolStore::Native);
let q = make_tensor(&[batch, heads, seq_q, head_dim], f32_dt);
let k = make_tensor(&[batch, heads, seq_kv, head_dim], f32_dt);
let v = make_tensor(&[batch, heads, seq_kv, val_dim], f32_dt);
let score_shape = [batch, heads, seq_q, seq_kv];
let mask = with_mask.then(|| make_tensor(&score_shape, bool_dt));
let bias = with_bias.then(|| make_tensor(&score_shape, f32_dt));
let flash = super::attention_flash(
q.clone(),
k.clone(),
v.clone(),
mask.clone(),
bias.clone(),
options,
);
let naive = super::attention_naive(q, k, v, mask, bias, options);
let flash_data: &[f32] = flash.storage();
let naive_data: &[f32] = naive.storage();
assert_eq!(
flash_data.len(),
naive_data.len(),
"{label}: length mismatch"
);
for (i, (&f, &n)) in flash_data.iter().zip(naive_data.iter()).enumerate() {
let diff = (f - n).abs();
let tol = 1e-4 * f.abs().max(n.abs()).max(1.0);
assert!(
diff < tol,
"{label}: position {i}: flash={f} vs naive={n}, diff={diff}"
);
}
}
let default = AttentionModuleOptions::default();
let causal = AttentionModuleOptions {
is_causal: true,
..Default::default()
};
let all_opts = AttentionModuleOptions {
scale: Some(0.05),
softcap: Some(30.0),
is_causal: true,
};
run_both(1, 1, 4, 4, 8, 8, false, false, default, "basic_4x4");
run_both(
2,
4,
8,
8,
16,
16,
false,
false,
default,
"multi_head_batch",
);
run_both(1, 2, 4, 32, 16, 16, false, false, default, "cross_attn");
run_both(1, 1, 4, 128, 16, 16, false, false, default, "multi_tile");
run_both(1, 2, 16, 16, 32, 32, false, false, causal, "causal");
run_both(2, 2, 16, 16, 32, 32, false, false, all_opts, "all_options");
run_both(1, 1, 32, 256, 64, 64, false, false, causal, "large_causal");
run_both(1, 2, 8, 8, 16, 16, true, false, default, "with_mask");
run_both(1, 2, 8, 8, 16, 16, false, true, default, "with_bias");
run_both(
2,
2,
16,
128,
32,
32,
true,
true,
causal,
"mask_bias_causal",
);
run_both(1, 1, 4, 100, 16, 16, false, false, default, "partial_tile");
run_both(
1,
2,
8,
100,
16,
16,
false,
false,
causal,
"partial_tile_causal",
);
}
// f64 goes through the native F64 dispatch arm (no widening cast).
#[test]
fn test_f64_flash_attention() {
use crate::Layout;
use burn_std::{Bytes, Shape};
let q = crate::FlexTensor::new(
Bytes::from_elems(vec![1.0f64, 0.0, 0.0, 1.0]),
Layout::contiguous(Shape::from(vec![1, 1, 2, 2])),
burn_backend::DType::F64,
);
let k = q.clone();
let v = crate::FlexTensor::new(
Bytes::from_elems(vec![10.0f64, 20.0]),
Layout::contiguous(Shape::from(vec![1, 1, 2, 1])),
burn_backend::DType::F64,
);
let result = super::attention(q, k, v, None, None, Default::default());
let data: &[f64] = result.storage();
assert!((data[0] - 13.30).abs() < 0.1, "got {}", data[0]);
assert!((data[1] - 16.70).abs() < 0.1, "got {}", data[1]);
}
// f16 exercises the widen-to-f32 / narrow-back path in both kernels.
#[test]
fn test_f16_attention() {
use crate::Layout;
use burn_std::{Bytes, Shape, f16};
let q_f16: Vec<f16> = [1.0f32, 0.0, 0.0, 1.0]
.iter()
.map(|&v| f16::from_f32(v))
.collect();
let v_f16: Vec<f16> = [10.0f32, 20.0].iter().map(|&v| f16::from_f32(v)).collect();
let q = crate::FlexTensor::new(
Bytes::from_elems(q_f16.clone()),
Layout::contiguous(Shape::from(vec![1, 1, 2, 2])),
burn_backend::DType::F16,
);
let k = q.clone();
let v = crate::FlexTensor::new(
Bytes::from_elems(v_f16),
Layout::contiguous(Shape::from(vec![1, 1, 2, 1])),
burn_backend::DType::F16,
);
let flash = super::attention_flash(
q.clone(),
k.clone(),
v.clone(),
None,
None,
Default::default(),
);
let naive = super::attention_naive(q, k, v, None, None, Default::default());
let flash_data: &[f16] = flash.storage();
let naive_data: &[f16] = naive.storage();
for (label, data) in [("flash", flash_data), ("naive", naive_data)] {
let r0 = data[0].to_f32();
let r1 = data[1].to_f32();
assert!((r0 - 13.30).abs() < 0.2, "{label} row0: got {r0}");
assert!((r1 - 16.70).abs() < 0.2, "{label} row1: got {r1}");
}
}
}