use alloc::vec;
use alloc::vec::Vec;
use burn_backend::DType;
use burn_backend::ops::AttentionModuleOptions;
use burn_std::Bytes;
use bytemuck::Pod;
use num_traits::Float;
use crate::{FlexTensor, Layout};
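// KV tile width for the flash path. Kept smaller on wasm, presumably to
// bound the per-head scratch buffers.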
#[cfg(target_family = "wasm")]
const TILE_KV: usize = 32;
#[cfg(not(target_family = "wasm"))]
const TILE_KV: usize = 64;
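// Largest score matrix (seq_q * seq_kv elements) for which the naive path,
// which materializes the whole matrix at once, is preferred over tiling.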
const NAIVE_SCORE_BUDGET: usize = 256 * 1024;
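/// Scaled dot-product attention over 4D `[batch, heads, seq, dim]` tensors.
///
/// Small problems (score matrix within `NAIVE_SCORE_BUDGET`) take the naive
/// path; larger ones take the tiled flash path, which never materializes the
/// full `seq_q x seq_kv` score matrix. Nonzero `mask` entries exclude the
/// corresponding key position.
///
/// A minimal usage sketch (`flex_f32` is the test-only constructor defined
/// in the `tests` module below, shown here purely for illustration):
///
/// ```ignore
/// let q = flex_f32(vec![0.5f32; 2 * 4], &[1, 1, 2, 4]);
/// let k = q.clone();
/// let v = flex_f32(vec![1.0f32; 2 * 4], &[1, 1, 2, 4]);
/// let out = attention(q, k, v, None, None, Default::default());
/// assert_eq!(out.layout().shape()[2], 2); // [1, 1, seq_q, val_dim]
/// ```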
pub fn attention(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor {
debug_assert!(
query.layout().shape().num_dims() == 4,
"attention: query must be 4D, got {}D",
query.layout().shape().num_dims()
);
debug_assert!(
key.layout().shape().num_dims() == 4,
"attention: key must be 4D, got {}D",
key.layout().shape().num_dims()
);
let seq_q = query.layout().shape()[2];
let seq_kv = key.layout().shape()[2];
if seq_q * seq_kv <= NAIVE_SCORE_BUDGET {
return attention_naive(query, key, value, mask, attn_bias, options);
}
attention_flash(query, key, value, mask, attn_bias, options)
}
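// Dispatches an attention kernel on the runtime dtype. f32/f64 run natively;
// f16/bf16 are upcast to f32, computed, and cast back. The optional bool
// mask is never cast: both kernels consume it as raw `u8` bytes.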
macro_rules! dispatch_attention_dtype {
($query:expr, $key:expr, $value:expr, $mask:expr, $attn_bias:expr, $options:expr, $impl_fn:ident) => {{
let query = $query;
let key = $key;
let value = $value;
let mask = $mask;
let attn_bias = $attn_bias;
let options = $options;
let dtype = query.dtype();
debug_assert_eq!(key.dtype(), dtype, "attention: key dtype mismatch");
debug_assert_eq!(value.dtype(), dtype, "attention: value dtype mismatch");
if let Some(ref b) = attn_bias {
debug_assert_eq!(b.dtype(), dtype, "attention: attn_bias dtype mismatch");
}
match dtype {
DType::F32 => $impl_fn::<f32>(query, key, value, mask, attn_bias, options),
DType::F64 => $impl_fn::<f64>(query, key, value, mask, attn_bias, options),
DType::F16 => {
use burn_std::f16;
let r = $impl_fn::<f32>(
cast_to_f32(query, f16::to_f32),
cast_to_f32(key, f16::to_f32),
cast_to_f32(value, f16::to_f32),
mask,
attn_bias.map(|b| cast_to_f32(b, f16::to_f32)),
options,
);
cast_from_f32(r, f16::from_f32)
}
DType::BF16 => {
use burn_std::bf16;
let r = $impl_fn::<f32>(
cast_to_f32(query, bf16::to_f32),
cast_to_f32(key, bf16::to_f32),
cast_to_f32(value, bf16::to_f32),
mask,
attn_bias.map(|b| cast_to_f32(b, bf16::to_f32)),
options,
);
cast_from_f32(r, bf16::from_f32)
}
dtype => panic!("attention: unsupported dtype {:?}", dtype),
}
}};
}
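/// A mask or bias tensor made contiguous for per-(batch, head) slicing,
/// plus the element offsets to step between batches and heads. A step of
/// zero means that dimension is broadcast (every batch/head reuses the same
/// `seq_q * seq_kv` tile).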
struct BroadcastMaskBias {
tensor: FlexTensor,
batch_step: usize,
head_step: usize,
}
fn broadcast_attn_mask_bias(
tensor: FlexTensor,
target: [usize; 4],
name: &'static str,
) -> BroadcastMaskBias {
let ndim = tensor.layout().shape().num_dims();
assert!(ndim == 4, "attention: {name} must be 4D, got {ndim}D");
let shape = tensor.layout().shape();
let src = [shape[0], shape[1], shape[2], shape[3]];
for i in 0..4 {
assert!(
src[i] == target[i] || src[i] == 1,
"attention: {name} dim {i} must be {} or 1, got {}",
target[i],
src[i]
);
}
let tile_len = target[2] * target[3];
if src[2] != target[2] || src[3] != target[3] {
let expanded = crate::ops::expand::expand(tensor, burn_std::Shape::new(target));
return BroadcastMaskBias {
tensor: expanded.to_contiguous(),
batch_step: target[1] * tile_len,
head_step: tile_len,
};
}
BroadcastMaskBias {
tensor: tensor.to_contiguous(),
batch_step: if src[0] == 1 { 0 } else { src[1] * tile_len },
head_step: if src[1] == 1 { 0 } else { tile_len },
}
}
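/// Tiled ("flash") attention. Each head is processed in KV tiles of width
/// `TILE_KV` using an online softmax, so scratch memory stays
/// O(seq_q * TILE_KV) instead of O(seq_q * seq_kv).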
pub fn attention_flash(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor {
dispatch_attention_dtype!(query, key, value, mask, attn_bias, options, attention_impl)
}
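/// Upcasts a half-precision tensor to a contiguous f32 tensor via the given
/// per-element conversion.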
fn cast_to_f32<E: burn_backend::Element + Pod + Copy>(
tensor: FlexTensor,
to_f32: fn(E) -> f32,
) -> FlexTensor {
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape().clone();
let data: &[E] = tensor.storage();
let f32_data: Vec<f32> = data.iter().map(|&v| to_f32(v)).collect();
FlexTensor::new(
Bytes::from_elems(f32_data),
Layout::contiguous(shape),
DType::F32,
)
}
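/// Downcasts an f32 tensor back to element type `E` (the inverse of
/// [`cast_to_f32`]).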
fn cast_from_f32<E: burn_backend::Element + Pod + Copy>(
tensor: FlexTensor,
from_f32: fn(f32) -> E,
) -> FlexTensor {
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape().clone();
let data: &[f32] = tensor.storage();
let half_data: Vec<E> = data.iter().map(|&v| from_f32(v)).collect();
FlexTensor::new(
Bytes::from_elems(half_data),
Layout::contiguous(shape),
E::dtype(),
)
}
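/// Flash-attention driver: validates shapes, resolves mask/bias
/// broadcasting, then runs [`flash_attention_head`] on the contiguous
/// per-head slices of each (batch, head) pair, reusing a single scratch
/// allocation throughout.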
fn attention_impl<T>(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor
where
T: FlashGemm + burn_backend::Element,
{
if let Some(softcap) = options.softcap {
assert!(softcap > 0.0, "softcap must be positive, got {softcap}");
}
let query = query.to_contiguous();
let key = key.to_contiguous();
let value = value.to_contiguous();
let q_shape = query.layout().shape();
let k_shape = key.layout().shape();
let v_shape = value.layout().shape();
assert!(q_shape.num_dims() == 4, "attention: query must be 4D");
assert!(k_shape.num_dims() == 4, "attention: key must be 4D");
assert!(v_shape.num_dims() == 4, "attention: value must be 4D");
let batch = q_shape[0];
let heads = q_shape[1];
let seq_q = q_shape[2];
let head_dim = q_shape[3];
assert!(head_dim > 0, "attention: head_dim must be non-zero");
let seq_kv = k_shape[2];
let val_dim = v_shape[3];
assert_eq!(k_shape[0], batch, "attention: key batch mismatch");
assert_eq!(k_shape[1], heads, "attention: key heads mismatch");
assert_eq!(k_shape[3], head_dim, "attention: key head_dim mismatch");
assert_eq!(v_shape[0], batch, "attention: value batch mismatch");
assert_eq!(v_shape[1], heads, "attention: value heads mismatch");
assert_eq!(v_shape[2], seq_kv, "attention: value seq_kv mismatch");
let target = [batch, heads, seq_q, seq_kv];
let mask_bcast = mask.map(|m| broadcast_attn_mask_bias(m, target, "mask"));
let bias_bcast = attn_bias.map(|b| broadcast_attn_mask_bias(b, target, "bias"));
let scale = T::from(
options
.scale
.unwrap_or_else(|| 1.0 / (head_dim as f64).sqrt()),
)
.unwrap();
let softcap: Option<T> = options.softcap.map(|s| T::from(s).unwrap());
let causal_offset = if options.is_causal {
Some(seq_kv as isize - seq_q as isize)
} else {
None
};
let q_data: &[T] = query.storage();
let k_data: &[T] = key.storage();
let v_data: &[T] = value.storage();
let mask_data: Option<&[u8]> = mask_bcast.as_ref().map(|b| b.tensor.bytes());
let bias_data: Option<&[T]> = bias_bcast.as_ref().map(|b| b.tensor.storage());
let (mask_batch_step, mask_head_step) = mask_bcast
.as_ref()
.map(|b| (b.batch_step, b.head_step))
.unwrap_or((0, 0));
let (bias_batch_step, bias_head_step) = bias_bcast
.as_ref()
.map(|b| (b.batch_step, b.head_step))
.unwrap_or((0, 0));
let mut output = vec![T::zero(); batch * heads * seq_q * val_dim];
let q_head_stride = seq_q * head_dim;
let q_batch_stride = heads * q_head_stride;
let k_head_stride = seq_kv * head_dim;
let k_batch_stride = heads * k_head_stride;
let v_head_stride = seq_kv * val_dim;
let v_batch_stride = heads * v_head_stride;
let o_head_stride = seq_q * val_dim;
let o_batch_stride = heads * o_head_stride;
let mask_tile_len = seq_q * seq_kv;
let params = AttentionParams {
scale,
softcap,
causal_offset,
seq_q,
seq_kv,
head_dim,
val_dim,
};
let mut scratch = ScratchBuffers {
row_max: vec![T::neg_infinity(); seq_q],
row_sum: vec![T::zero(); seq_q],
scores: vec![T::zero(); seq_q * TILE_KV],
};
for b in 0..batch {
for h in 0..heads {
let q_off = b * q_batch_stride + h * q_head_stride;
let k_off = b * k_batch_stride + h * k_head_stride;
let v_off = b * v_batch_stride + h * v_head_stride;
let o_off = b * o_batch_stride + h * o_head_stride;
let mask_off = b * mask_batch_step + h * mask_head_step;
let bias_off = b * bias_batch_step + h * bias_head_step;
flash_attention_head(
&q_data[q_off..q_off + q_head_stride],
&k_data[k_off..k_off + k_head_stride],
&v_data[v_off..v_off + v_head_stride],
&mut output[o_off..o_off + o_head_stride],
mask_data.map(|m| &m[mask_off..mask_off + mask_tile_len]),
bias_data.map(|b| &b[bias_off..bias_off + mask_tile_len]),
&params,
&mut scratch,
);
}
}
let shape = burn_std::Shape::from(vec![batch, heads, seq_q, val_dim]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(shape),
T::dtype(),
)
}
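/// Strided matrix multiply backing both attention kernels, implemented for
/// `f32`/`f64` via the `gemm` crate.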
trait FlashGemm: Float + Pod + Copy + core::ops::AddAssign {
unsafe fn block_gemm(args: BlockGemmArgs<Self>);
}
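/// Arguments for one strided GEMM. As used here, the `gemm` crate computes
/// `dst = alpha * dst + beta * lhs @ rhs` when `read_dst` is set and
/// `dst = beta * lhs @ rhs` otherwise; `*_cs`/`*_rs` are column/row strides
/// in elements.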
struct BlockGemmArgs<T> {
m: usize,
n: usize,
k: usize,
dst: *mut T,
dst_cs: isize,
dst_rs: isize,
read_dst: bool,
lhs: *const T,
lhs_cs: isize,
lhs_rs: isize,
rhs: *const T,
rhs_cs: isize,
rhs_rs: isize,
alpha: T,
beta: T,
}
macro_rules! impl_flash_gemm {
($ty:ty) => {
impl FlashGemm for $ty {
unsafe fn block_gemm(a: BlockGemmArgs<Self>) {
unsafe {
gemm::gemm(
a.m,
a.n,
a.k,
a.dst,
a.dst_cs,
a.dst_rs,
a.read_dst,
a.lhs,
a.lhs_cs,
a.lhs_rs,
a.rhs,
a.rhs_cs,
a.rhs_rs,
a.alpha,
a.beta,
false,
false,
false,
gemm::Parallelism::None,
);
}
}
}
};
}
impl_flash_gemm!(f32);
impl_flash_gemm!(f64);
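/// Per-head scratch reused across all (batch, head) iterations: running row
/// maxima, running softmax denominators, and the current score tile.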
struct ScratchBuffers<T> {
row_max: Vec<T>,
row_sum: Vec<T>,
scores: Vec<T>,
}
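/// Precomputed per-call parameters shared by every head. `causal_offset`
/// (`seq_kv - seq_q` when causal) aligns the causal diagonal so the last
/// query attends to the last key.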
struct AttentionParams<T> {
scale: T,
softcap: Option<T>,
causal_offset: Option<isize>,
seq_q: usize,
seq_kv: usize,
head_dim: usize,
val_dim: usize,
}
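/// Online-softmax ("flash") attention for a single head. For each KV tile
/// the running row maximum `m` and denominator `l` are updated; output rows
/// accumulated from earlier tiles are rescaled by `exp(m_old - m_new)`
/// before the new tile's contribution is added, and the division by `l`
/// happens once at the end. Fully masked rows keep `l == 0` and stay zero.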
#[allow(clippy::too_many_arguments)]
fn flash_attention_head<T: FlashGemm>(
q: &[T],
k: &[T],
v: &[T],
output: &mut [T],
mask: Option<&[u8]>,
bias: Option<&[T]>,
p: &AttentionParams<T>,
scratch: &mut ScratchBuffers<T>,
) {
debug_assert_eq!(q.len(), p.seq_q * p.head_dim);
debug_assert_eq!(k.len(), p.seq_kv * p.head_dim);
debug_assert_eq!(v.len(), p.seq_kv * p.val_dim);
debug_assert_eq!(output.len(), p.seq_q * p.val_dim);
let neg_inf = T::neg_infinity();
let AttentionParams {
scale,
softcap,
causal_offset,
seq_q,
seq_kv,
head_dim,
val_dim,
} = *p;
let row_max = &mut scratch.row_max;
row_max.fill(neg_inf);
let row_sum = &mut scratch.row_sum;
row_sum.fill(T::zero());
let scores = &mut scratch.scores;
let num_kv_tiles = seq_kv.div_ceil(TILE_KV);
for tile_idx in 0..num_kv_tiles {
let kv_start = tile_idx * TILE_KV;
let kv_end = (kv_start + TILE_KV).min(seq_kv);
let tile_kv = kv_end - kv_start;
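// QK^T for this tile: scores[qi][ki] = dot(q row qi, k row kv_start + ki),
// written row-major with row stride `tile_kv`.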
unsafe {
T::block_gemm(BlockGemmArgs {
m: seq_q,
n: tile_kv,
k: head_dim,
dst: scores.as_mut_ptr(),
dst_cs: 1,
dst_rs: tile_kv as isize,
read_dst: false,
lhs: q.as_ptr(),
lhs_cs: 1,
lhs_rs: head_dim as isize,
rhs: k.as_ptr().add(kv_start * head_dim),
rhs_cs: head_dim as isize,
rhs_rs: 1,
alpha: T::zero(),
beta: T::one(),
});
}
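// Per-row pass over the tile: scale, softcap, mask, and bias each score
// while tracking the tile's row maximum.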
for qi in 0..seq_q {
let score_row = &mut scores[qi * tile_kv..(qi + 1) * tile_kv];
let mut tile_max = neg_inf;
for (ki, score) in score_row.iter_mut().enumerate() {
let kv_idx = kv_start + ki;
let mut val = *score * scale;
if let Some(cap) = softcap {
val = cap * (val / cap).tanh();
}
if let Some(m) = mask
&& m[qi * seq_kv + kv_idx] != 0
{
val = neg_inf;
}
if let Some(offset) = causal_offset
&& (kv_idx as isize) > (qi as isize) + offset
{
val = neg_inf;
}
if let Some(b) = bias {
val += b[qi * seq_kv + kv_idx];
}
*score = val;
if val > tile_max {
tile_max = val;
}
}
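// Every position in this tile is masked out for row `qi`: zero the
// scores so the V gemm adds nothing, and leave the running stats alone.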
if tile_max == neg_inf {
for score in score_row.iter_mut() {
*score = T::zero();
}
continue;
}
let new_max = if row_max[qi] > tile_max {
row_max[qi]
} else {
tile_max
};
let mut tile_sum = T::zero();
for score in score_row.iter_mut() {
let e = (*score - new_max).exp();
*score = e;
tile_sum += e;
}
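// Rescale the output accumulated from earlier tiles to the new maximum;
// a previous max of -inf means nothing has been accumulated yet.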
let correction = if row_max[qi] == neg_inf {
T::zero()
} else {
(row_max[qi] - new_max).exp()
};
let out_row = &mut output[qi * val_dim..(qi + 1) * val_dim];
for o in out_row.iter_mut() {
*o = *o * correction;
}
row_sum[qi] = row_sum[qi] * correction + tile_sum;
row_max[qi] = new_max;
}
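// output += exp_scores @ v_tile (probabilities still unnormalized).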
unsafe {
T::block_gemm(BlockGemmArgs {
m: seq_q,
n: val_dim,
k: tile_kv,
dst: output.as_mut_ptr(),
dst_cs: 1,
dst_rs: val_dim as isize,
read_dst: true,
lhs: scores.as_ptr(),
lhs_cs: 1,
lhs_rs: tile_kv as isize,
rhs: v.as_ptr().add(kv_start * val_dim),
rhs_cs: 1,
rhs_rs: val_dim as isize,
alpha: T::one(),
beta: T::one(),
});
}
}
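// Final normalization by the softmax denominator. Fully masked rows have
// a zero denominator and stay zero.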
for qi in 0..seq_q {
let sum = row_sum[qi];
if sum > T::zero() {
let inv_sum = T::one() / sum;
let out_row = &mut output[qi * val_dim..(qi + 1) * val_dim];
for o in out_row.iter_mut() {
*o = *o * inv_sum;
}
}
}
}
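/// Reference attention: materializes the full `seq_q x seq_kv` score matrix
/// per head and applies a plain two-pass softmax. Cheaper than tiling for
/// small problems, and used by the tests as a correctness oracle for the
/// flash path.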
pub fn attention_naive(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor {
dispatch_attention_dtype!(
query,
key,
value,
mask,
attn_bias,
options,
attention_naive_impl
)
}
fn attention_naive_impl<T>(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor
where
T: FlashGemm + burn_backend::Element,
{
if let Some(softcap) = options.softcap {
assert!(softcap > 0.0, "softcap must be positive, got {softcap}");
}
let query = query.to_contiguous();
let key = key.to_contiguous();
let value = value.to_contiguous();
let q_shape = query.layout().shape();
let k_shape = key.layout().shape();
let v_shape = value.layout().shape();
assert!(q_shape.num_dims() == 4, "attention_naive: query must be 4D");
assert!(k_shape.num_dims() == 4, "attention_naive: key must be 4D");
assert!(v_shape.num_dims() == 4, "attention_naive: value must be 4D");
let batch = q_shape[0];
let heads = q_shape[1];
let seq_q = q_shape[2];
let head_dim = q_shape[3];
assert!(head_dim > 0, "attention_naive: head_dim must be non-zero");
let seq_kv = k_shape[2];
let val_dim = v_shape[3];
assert_eq!(k_shape[0], batch, "attention_naive: key batch mismatch");
assert_eq!(k_shape[1], heads, "attention_naive: key heads mismatch");
assert_eq!(
k_shape[3], head_dim,
"attention_naive: key head_dim mismatch"
);
assert_eq!(v_shape[0], batch, "attention_naive: value batch mismatch");
assert_eq!(v_shape[1], heads, "attention_naive: value heads mismatch");
assert_eq!(v_shape[2], seq_kv, "attention_naive: value seq_kv mismatch");
let target = [batch, heads, seq_q, seq_kv];
let mask_bcast = mask.map(|m| broadcast_attn_mask_bias(m, target, "mask"));
let bias_bcast = attn_bias.map(|b| broadcast_attn_mask_bias(b, target, "bias"));
let scale = T::from(
options
.scale
.unwrap_or_else(|| 1.0 / (head_dim as f64).sqrt()),
)
.unwrap();
let softcap: Option<T> = options.softcap.map(|s| T::from(s).unwrap());
let causal_offset = if options.is_causal {
Some(seq_kv as isize - seq_q as isize)
} else {
None
};
let q_data: &[T] = query.storage();
let k_data: &[T] = key.storage();
let v_data: &[T] = value.storage();
let mask_data: Option<&[u8]> = mask_bcast.as_ref().map(|b| b.tensor.bytes());
let bias_data: Option<&[T]> = bias_bcast.as_ref().map(|b| b.tensor.storage());
let (mask_batch_step, mask_head_step) = mask_bcast
.as_ref()
.map(|b| (b.batch_step, b.head_step))
.unwrap_or((0, 0));
let (bias_batch_step, bias_head_step) = bias_bcast
.as_ref()
.map(|b| (b.batch_step, b.head_step))
.unwrap_or((0, 0));
let mut output = vec![T::zero(); batch * heads * seq_q * val_dim];
let mut scores = vec![T::zero(); seq_q * seq_kv];
let q_head_stride = seq_q * head_dim;
let q_batch_stride = heads * q_head_stride;
let k_head_stride = seq_kv * head_dim;
let k_batch_stride = heads * k_head_stride;
let v_head_stride = seq_kv * val_dim;
let v_batch_stride = heads * v_head_stride;
let o_head_stride = seq_q * val_dim;
let o_batch_stride = heads * o_head_stride;
let mask_tile_len = seq_q * seq_kv;
let params = AttentionParams {
scale,
softcap,
causal_offset,
seq_q,
seq_kv,
head_dim,
val_dim,
};
for b in 0..batch {
for h in 0..heads {
let q_off = b * q_batch_stride + h * q_head_stride;
let k_off = b * k_batch_stride + h * k_head_stride;
let v_off = b * v_batch_stride + h * v_head_stride;
let o_off = b * o_batch_stride + h * o_head_stride;
let mask_off = b * mask_batch_step + h * mask_head_step;
let bias_off = b * bias_batch_step + h * bias_head_step;
naive_attention_head(
&q_data[q_off..q_off + q_head_stride],
&k_data[k_off..k_off + k_head_stride],
&v_data[v_off..v_off + v_head_stride],
&mut output[o_off..o_off + o_head_stride],
&mut scores,
&params,
(
mask_data.map(|m| &m[mask_off..mask_off + mask_tile_len]),
bias_data.map(|b| &b[bias_off..bias_off + mask_tile_len]),
),
);
}
}
let shape = burn_std::Shape::from(vec![batch, heads, seq_q, val_dim]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(shape),
T::dtype(),
)
}
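/// Single-head attention over the full score matrix: one QK^T gemm, an
/// in-place row softmax, then one scores @ V gemm.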
fn naive_attention_head<T: FlashGemm>(
q: &[T],
k: &[T],
v: &[T],
output: &mut [T],
scores: &mut [T],
p: &AttentionParams<T>,
mask_bias: (Option<&[u8]>, Option<&[T]>),
) {
let (mask, bias) = mask_bias;
let neg_inf = T::neg_infinity();
let AttentionParams {
scale,
softcap,
causal_offset,
seq_q,
seq_kv,
head_dim,
val_dim,
} = *p;
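// scores = q @ k^T, row-major with row stride `seq_kv`.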
unsafe {
T::block_gemm(BlockGemmArgs {
m: seq_q,
n: seq_kv,
k: head_dim,
dst: scores.as_mut_ptr(),
dst_cs: 1,
dst_rs: seq_kv as isize,
read_dst: false,
lhs: q.as_ptr(),
lhs_cs: 1,
lhs_rs: head_dim as isize,
rhs: k.as_ptr(),
rhs_cs: head_dim as isize,
rhs_rs: 1,
alpha: T::zero(),
beta: T::one(),
});
}
for qi in 0..seq_q {
let row = &mut scores[qi * seq_kv..(qi + 1) * seq_kv];
let mut row_max = neg_inf;
for (ki, s) in row.iter_mut().enumerate() {
let mut val = *s * scale;
if let Some(cap) = softcap {
val = cap * (val / cap).tanh();
}
if let Some(m) = mask
&& m[qi * seq_kv + ki] != 0
{
val = neg_inf;
}
if let Some(offset) = causal_offset
&& (ki as isize) > (qi as isize) + offset
{
val = neg_inf;
}
if let Some(b) = bias {
val += b[qi * seq_kv + ki];
}
*s = val;
if val > row_max {
row_max = val;
}
}
if row_max == neg_inf {
row.fill(T::zero());
continue;
}
let mut sum = T::zero();
for s in row.iter_mut() {
let e = (*s - row_max).exp();
*s = e;
sum += e;
}
let inv_sum = T::one() / sum;
for s in row.iter_mut() {
*s = *s * inv_sum;
}
}
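// output = softmax(scores) @ v.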
unsafe {
T::block_gemm(BlockGemmArgs {
m: seq_q,
n: val_dim,
k: seq_kv,
dst: output.as_mut_ptr(),
dst_cs: 1,
dst_rs: val_dim as isize,
read_dst: false,
lhs: scores.as_ptr(),
lhs_cs: 1,
lhs_rs: seq_kv as isize,
rhs: v.as_ptr(),
rhs_cs: 1,
rhs_rs: val_dim as isize,
alpha: T::zero(),
beta: T::one(),
});
}
}
#[cfg(test)]
mod tests {
use alloc::{vec, vec::Vec};
use burn_backend::DType;
use burn_backend::ops::AttentionModuleOptions;
use burn_std::{BoolStore, Bytes, Shape, f16};
use crate::{FlexTensor, Layout};
fn flex_f32(data: Vec<f32>, shape: &[usize]) -> FlexTensor {
FlexTensor::new(
Bytes::from_elems(data),
Layout::contiguous(Shape::from(shape.to_vec())),
DType::F32,
)
}
fn flex_f64(data: Vec<f64>, shape: &[usize]) -> FlexTensor {
FlexTensor::new(
Bytes::from_elems(data),
Layout::contiguous(Shape::from(shape.to_vec())),
DType::F64,
)
}
fn flex_f16(data: Vec<f16>, shape: &[usize]) -> FlexTensor {
FlexTensor::new(
Bytes::from_elems(data),
Layout::contiguous(Shape::from(shape.to_vec())),
DType::F16,
)
}
fn flex_bool(data: Vec<u8>, shape: &[usize]) -> FlexTensor {
FlexTensor::new(
Bytes::from_elems(data),
Layout::contiguous(Shape::from(shape.to_vec())),
DType::Bool(BoolStore::Native),
)
}
fn assert_attention_outputs_close(bcast: &[f32], full: &[f32], label: &str) {
assert_eq!(bcast.len(), full.len(), "{label}: length mismatch");
for (i, (&a, &b)) in bcast.iter().zip(full).enumerate() {
assert!((a - b).abs() < 1e-5, "{label} mismatch at {i}: {a} vs {b}");
}
}
#[test]
fn test_flash_bias_broadcast_across_batch_and_heads() {
let batch = 2;
let heads = 2;
let seq_q = 3;
let seq_kv = 5;
let head_dim = 4;
let mk = |shape: &[usize], g: &dyn Fn(usize) -> f32| -> FlexTensor {
let len: usize = shape.iter().product();
flex_f32((0..len).map(g).collect(), shape)
};
let q = mk(&[batch, heads, seq_q, head_dim], &|i| {
(i as f32 * 0.1).sin()
});
let k = mk(&[batch, heads, seq_kv, head_dim], &|i| {
(i as f32 * 0.1 + 1.0).sin()
});
let v = mk(&[batch, heads, seq_kv, head_dim], &|i| {
(i as f32 * 0.1 + 2.0).sin()
});
let bias_tile: Vec<f32> = (0..seq_q * seq_kv)
.map(|i| (i as f32 * 0.4).sin())
.collect();
let bias_bcast = flex_f32(bias_tile.clone(), &[1, 1, seq_q, seq_kv]);
let bias_full_vec: Vec<f32> = bias_tile
.iter()
.cloned()
.cycle()
.take(batch * heads * seq_q * seq_kv)
.collect();
let bias_full = flex_f32(bias_full_vec, &[batch, heads, seq_q, seq_kv]);
let out_bcast = super::attention_flash(
q.clone(),
k.clone(),
v.clone(),
None,
Some(bias_bcast),
Default::default(),
);
let out_full = super::attention_flash(q, k, v, None, Some(bias_full), Default::default());
let bcast: &[f32] = out_bcast.storage();
let full: &[f32] = out_full.storage();
assert_attention_outputs_close(bcast, full, "flash bias[1,1,sq,skv]");
}
#[test]
fn test_flash_bool_mask_broadcast_across_batch_and_heads() {
let batch = 2;
let heads = 2;
let seq_q = 3;
let seq_kv = 5;
let head_dim = 4;
let mk = |shape: &[usize], g: &dyn Fn(usize) -> f32| -> FlexTensor {
let len: usize = shape.iter().product();
flex_f32((0..len).map(g).collect(), shape)
};
let q = mk(&[batch, heads, seq_q, head_dim], &|i| {
(i as f32 * 0.1).sin()
});
let k = mk(&[batch, heads, seq_kv, head_dim], &|i| {
(i as f32 * 0.1 + 1.0).sin()
});
let v = mk(&[batch, heads, seq_kv, head_dim], &|i| {
(i as f32 * 0.1 + 2.0).sin()
});
let mask_tile: Vec<u8> = (0..seq_q * seq_kv)
.map(|i| if (i % seq_kv) >= 3 { 1u8 } else { 0u8 })
.collect();
let mask_bcast = flex_bool(mask_tile.clone(), &[1, 1, seq_q, seq_kv]);
let mask_full_vec: Vec<u8> = mask_tile
.iter()
.copied()
.cycle()
.take(batch * heads * seq_q * seq_kv)
.collect();
let mask_full = flex_bool(mask_full_vec, &[batch, heads, seq_q, seq_kv]);
let out_bcast = super::attention_flash(
q.clone(),
k.clone(),
v.clone(),
Some(mask_bcast),
None,
Default::default(),
);
let out_full = super::attention_flash(q, k, v, Some(mask_full), None, Default::default());
let bcast: &[f32] = out_bcast.storage();
let full: &[f32] = out_full.storage();
assert_attention_outputs_close(bcast, full, "flash bool mask[1,1,sq,skv]");
}
#[test]
#[should_panic(expected = "must be 4D")]
fn test_mask_wrong_rank_panics() {
let mask = flex_bool(vec![0u8; 6], &[2, 3]);
super::broadcast_attn_mask_bias(mask, [1, 1, 2, 3], "mask");
}
#[test]
#[should_panic(expected = "bias dim 1 must be 3 or 1, got 2")]
fn test_bias_incompatible_dim_panics() {
let bias = flex_f32(vec![0.0f32; 2 * 2 * 4 * 5], &[2, 2, 4, 5]);
super::broadcast_attn_mask_bias(bias, [2, 3, 4, 5], "bias");
}
#[test]
fn test_multi_tile_seq_kv() {
let seq_q = 2;
let seq_kv = 128;
let head_dim = 4;
let val_dim = 2;
let mut q_data = vec![0.0f32; seq_q * head_dim];
q_data[0] = 1.0;
q_data[head_dim + 1] = 1.0;
let k_data = vec![0.1f32; seq_kv * head_dim];
let mut v_data = vec![0.0f32; seq_kv * val_dim];
for i in 0..seq_kv {
v_data[i * val_dim] = i as f32;
v_data[i * val_dim + 1] = (seq_kv - 1 - i) as f32;
}
let q = flex_f32(q_data, &[1, 1, seq_q, head_dim]);
let k = flex_f32(k_data, &[1, 1, seq_kv, head_dim]);
let v = flex_f32(v_data, &[1, 1, seq_kv, val_dim]);
let result = super::attention(q, k, v, None, None, Default::default());
let data: &[f32] = result.storage();
assert_eq!(data.len(), seq_q * val_dim);
assert!(
(data[0] - 63.5).abs() < 0.1,
"expected ~63.5, got {}",
data[0]
);
assert!(
(data[1] - 63.5).abs() < 0.1,
"expected ~63.5, got {}",
data[1]
);
}
#[test]
fn test_multi_tile_causal() {
let seq_q = 4;
let seq_kv = 128;
let head_dim = 2;
let val_dim = 1;
let mut q_data = vec![0.0f32; seq_q * head_dim];
for i in 0..seq_q {
q_data[i * head_dim] = 1.0;
}
let mut k_data = vec![0.0f32; seq_kv * head_dim];
for i in 0..seq_kv {
k_data[i * head_dim] = 1.0;
}
let v_data: Vec<f32> = (0..seq_kv).map(|i| i as f32).collect();
let q = flex_f32(q_data, &[1, 1, seq_q, head_dim]);
let k = flex_f32(k_data, &[1, 1, seq_kv, head_dim]);
let v = flex_f32(v_data, &[1, 1, seq_kv, val_dim]);
let opts = AttentionModuleOptions {
is_causal: true,
..Default::default()
};
let result = super::attention(q, k, v, None, None, opts);
let data: &[f32] = result.storage();
assert_eq!(data.len(), seq_q);
assert!((data[0] - 62.0).abs() < 0.1, "q0: got {}", data[0]);
assert!((data[1] - 62.5).abs() < 0.1, "q1: got {}", data[1]);
assert!((data[2] - 63.0).abs() < 0.1, "q2: got {}", data[2]);
assert!((data[3] - 63.5).abs() < 0.1, "q3: got {}", data[3]);
}
#[test]
fn test_tile_boundary_mask() {
let seq_q = 1;
let seq_kv = 128;
let head_dim = 2;
let val_dim = 1;
let q_data = vec![1.0f32, 0.0];
let k_data = vec![1.0f32, 0.0].repeat(seq_kv);
let v_data: Vec<f32> = (0..seq_kv).map(|i| i as f32).collect();
let mask_data: Vec<u8> = (0..seq_kv).map(|i| (i < 64) as u8).collect();
let q = flex_f32(q_data, &[1, 1, seq_q, head_dim]);
let k = flex_f32(k_data, &[1, 1, seq_kv, head_dim]);
let v = flex_f32(v_data, &[1, 1, seq_kv, val_dim]);
let mask = flex_bool(mask_data, &[1, 1, seq_q, seq_kv]);
let result = super::attention(q, k, v, Some(mask), None, Default::default());
let data: &[f32] = result.storage();
assert!(
(data[0] - 95.5).abs() < 0.1,
"expected ~95.5, got {}",
data[0]
);
}
#[test]
fn test_non_uniform_scores_across_tiles() {
let seq_q = 1;
let seq_kv = 128;
let head_dim = 1;
let val_dim = 1;
let q_data = vec![1.0f32];
let mut k_data = vec![0.0f32; seq_kv];
for k in k_data.iter_mut().take(64) {
*k = 0.1;
}
for k in k_data.iter_mut().skip(64) {
*k = 5.0;
}
let mut v_data = vec![0.0f32; seq_kv];
for v in v_data.iter_mut().skip(64) {
*v = 1.0;
}
let q = flex_f32(q_data, &[1, 1, seq_q, head_dim]);
let k = flex_f32(k_data, &[1, 1, seq_kv, head_dim]);
let v = flex_f32(v_data, &[1, 1, seq_kv, val_dim]);
let result = super::attention(q, k, v, None, None, Default::default());
let data: &[f32] = result.storage();
assert!(data[0] > 0.99, "expected ~1.0, got {}", data[0]);
}
#[test]
fn test_partial_last_tile() {
let seq_q = 2;
let seq_kv = 100;
let head_dim = 2;
let val_dim = 1;
let q_data = vec![0.1f32, 0.1].repeat(seq_q);
let k_data = vec![0.1f32, 0.1].repeat(seq_kv);
let v_data: Vec<f32> = (0..seq_kv).map(|i| i as f32).collect();
let q = flex_f32(q_data, &[1, 1, seq_q, head_dim]);
let k = flex_f32(k_data, &[1, 1, seq_kv, head_dim]);
let v = flex_f32(v_data, &[1, 1, seq_kv, val_dim]);
let result = super::attention(q, k, v, None, None, Default::default());
let data: &[f32] = result.storage();
assert_eq!(data.len(), seq_q);
assert!(
(data[0] - 49.5).abs() < 0.1,
"expected ~49.5, got {}",
data[0]
);
assert!(
(data[1] - 49.5).abs() < 0.1,
"expected ~49.5, got {}",
data[1]
);
}
#[test]
fn test_naive_matches_flash() {
fn make_tensor(shape: &[usize], dtype: DType) -> FlexTensor {
let len: usize = shape.iter().product();
match dtype {
DType::F32 => {
let data: Vec<f32> =
(0..len).map(|i| ((i % 997) as f32 / 997.0) - 0.5).collect();
flex_f32(data, shape)
}
DType::Bool(_) => {
let data: Vec<u8> = (0..len)
.map(|i| (i.wrapping_mul(997) % 100 < 30) as u8)
.collect();
flex_bool(data, shape)
}
_ => unreachable!(),
}
}
fn run_both(
batch: usize,
heads: usize,
seq_q: usize,
seq_kv: usize,
head_dim: usize,
val_dim: usize,
with_mask: bool,
with_bias: bool,
options: AttentionModuleOptions,
label: &str,
) {
let f32_dt = DType::F32;
let bool_dt = DType::Bool(BoolStore::Native);
let q = make_tensor(&[batch, heads, seq_q, head_dim], f32_dt);
let k = make_tensor(&[batch, heads, seq_kv, head_dim], f32_dt);
let v = make_tensor(&[batch, heads, seq_kv, val_dim], f32_dt);
let score_shape = [batch, heads, seq_q, seq_kv];
let mask = with_mask.then(|| make_tensor(&score_shape, bool_dt));
let bias = with_bias.then(|| make_tensor(&score_shape, f32_dt));
let flash = super::attention_flash(
q.clone(),
k.clone(),
v.clone(),
mask.clone(),
bias.clone(),
options,
);
let naive = super::attention_naive(q, k, v, mask, bias, options);
let flash_data: &[f32] = flash.storage();
let naive_data: &[f32] = naive.storage();
assert_eq!(
flash_data.len(),
naive_data.len(),
"{label}: length mismatch"
);
for (i, (&f, &n)) in flash_data.iter().zip(naive_data.iter()).enumerate() {
let diff = (f - n).abs();
let tol = 1e-4 * f.abs().max(n.abs()).max(1.0);
assert!(
diff < tol,
"{label}: position {i}: flash={f} vs naive={n}, diff={diff}"
);
}
}
let default = AttentionModuleOptions::default();
let causal = AttentionModuleOptions {
is_causal: true,
..Default::default()
};
let all_opts = AttentionModuleOptions {
scale: Some(0.05),
softcap: Some(30.0),
is_causal: true,
};
run_both(1, 1, 4, 4, 8, 8, false, false, default, "basic_4x4");
run_both(
2,
4,
8,
8,
16,
16,
false,
false,
default,
"multi_head_batch",
);
run_both(1, 2, 4, 32, 16, 16, false, false, default, "cross_attn");
run_both(1, 1, 4, 128, 16, 16, false, false, default, "multi_tile");
run_both(1, 2, 16, 16, 32, 32, false, false, causal, "causal");
run_both(2, 2, 16, 16, 32, 32, false, false, all_opts, "all_options");
run_both(1, 1, 32, 256, 64, 64, false, false, causal, "large_causal");
run_both(1, 2, 8, 8, 16, 16, true, false, default, "with_mask");
run_both(1, 2, 8, 8, 16, 16, false, true, default, "with_bias");
run_both(
2,
2,
16,
128,
32,
32,
true,
true,
causal,
"mask_bias_causal",
);
run_both(1, 1, 4, 100, 16, 16, false, false, default, "partial_tile");
run_both(
1,
2,
8,
100,
16,
16,
false,
false,
causal,
"partial_tile_causal",
);
}
#[test]
fn test_f64_flash_attention() {
let q = flex_f64(vec![1.0f64, 0.0, 0.0, 1.0], &[1, 1, 2, 2]);
let k = q.clone();
let v = flex_f64(vec![10.0f64, 20.0], &[1, 1, 2, 1]);
let result = super::attention_flash(q, k, v, None, None, Default::default());
let data: &[f64] = result.storage();
assert!((data[0] - 13.30).abs() < 0.1, "got {}", data[0]);
assert!((data[1] - 16.70).abs() < 0.1, "got {}", data[1]);
}
#[test]
fn test_f16_attention() {
let q_vec: Vec<f16> = [1.0f32, 0.0, 0.0, 1.0]
.iter()
.map(|&v| f16::from_f32(v))
.collect();
let v_vec: Vec<f16> = [10.0f32, 20.0].iter().map(|&v| f16::from_f32(v)).collect();
let q = flex_f16(q_vec, &[1, 1, 2, 2]);
let k = q.clone();
let v = flex_f16(v_vec, &[1, 1, 2, 1]);
let flash = super::attention_flash(
q.clone(),
k.clone(),
v.clone(),
None,
None,
Default::default(),
);
let naive = super::attention_naive(q, k, v, None, None, Default::default());
let flash_data: &[f16] = flash.storage();
let naive_data: &[f16] = naive.storage();
for (label, data) in [("flash", flash_data), ("naive", naive_data)] {
let r0 = data[0].to_f32();
let r1 = data[1].to_f32();
assert!((r0 - 13.30).abs() < 0.2, "{label} row0: got {r0}");
assert!((r1 - 16.70).abs() < 0.2, "{label} row1: got {r1}");
}
}
}