use alloc::vec;
use alloc::vec::Vec;
use burn_backend::DType;
use burn_backend::ops::AttentionModuleOptions;
use burn_std::Bytes;
use bytemuck::Pod;
use num_traits::Float;
use crate::{FlexTensor, Layout};
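/// Scaled dot-product attention over the last two dimensions of `query`,
/// `key`, and `value`.
///
/// Dispatches on the query dtype: `f32` and `f64` run natively, while `f16`
/// and `bf16` are upcast to `f32`, computed, and cast back down. The optional
/// boolean `mask` marks positions to exclude, and `attn_bias` is added to the
/// scaled logits before the softmax.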
pub fn attention(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor {
let dtype = query.dtype();
debug_assert_eq!(key.dtype(), dtype, "attention: key dtype mismatch");
debug_assert_eq!(value.dtype(), dtype, "attention: value dtype mismatch");
if let Some(ref b) = attn_bias {
debug_assert_eq!(b.dtype(), dtype, "attention: attn_bias dtype mismatch");
}
match dtype {
DType::F32 => attention_impl::<f32>(query, key, value, mask, attn_bias, options),
DType::F64 => attention_impl::<f64>(query, key, value, mask, attn_bias, options),
DType::F16 => {
use burn_std::f16;
let result = attention_impl::<f32>(
cast_to_f32(query, f16::to_f32),
cast_to_f32(key, f16::to_f32),
cast_to_f32(value, f16::to_f32),
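                // The boolean mask carries no floating-point payload, so it
                // is forwarded without casting.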
mask,
attn_bias.map(|b| cast_to_f32(b, f16::to_f32)),
options,
);
cast_from_f32(result, f16::from_f32)
}
DType::BF16 => {
use burn_std::bf16;
let result = attention_impl::<f32>(
cast_to_f32(query, bf16::to_f32),
cast_to_f32(key, bf16::to_f32),
cast_to_f32(value, bf16::to_f32),
mask,
attn_bias.map(|b| cast_to_f32(b, bf16::to_f32)),
options,
);
cast_from_f32(result, bf16::from_f32)
}
        other => panic!("attention: unsupported dtype {other:?}"),
}
}
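/// Upcasts a half-precision tensor into a freshly allocated, contiguous `f32`
/// tensor using the provided element conversion.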
fn cast_to_f32<E: burn_backend::Element + Pod + Copy>(
tensor: FlexTensor,
to_f32: fn(E) -> f32,
) -> FlexTensor {
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape().clone();
let data: &[E] = tensor.storage();
let f32_data: Vec<f32> = data.iter().map(|&v| to_f32(v)).collect();
FlexTensor::new(
Bytes::from_elems(f32_data),
Layout::contiguous(shape),
DType::F32,
)
}
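/// Downcasts a contiguous `f32` tensor back to the half-precision element
/// type `E`, restoring the original dtype.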
fn cast_from_f32<E: burn_backend::Element + Pod + Copy>(
tensor: FlexTensor,
from_f32: fn(f32) -> E,
) -> FlexTensor {
let tensor = tensor.to_contiguous();
let shape = tensor.layout().shape().clone();
let data: &[f32] = tensor.storage();
let half_data: Vec<E> = data.iter().map(|&v| from_f32(v)).collect();
FlexTensor::new(
Bytes::from_elems(half_data),
Layout::contiguous(shape),
E::dtype(),
)
}
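/// Reference attention kernel: computes `Q K^T`, applies the fused softmax
/// (scaling, softcap, masking, bias), and multiplies the weights by `V`.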
fn attention_impl<T>(
query: FlexTensor,
key: FlexTensor,
value: FlexTensor,
mask: Option<FlexTensor>,
attn_bias: Option<FlexTensor>,
options: AttentionModuleOptions,
) -> FlexTensor
where
T: Float + Pod + Copy + burn_backend::Element,
T: core::ops::AddAssign,
{
if let Some(softcap) = options.softcap {
assert!(softcap > 0.0, "softcap must be positive, got {softcap}");
}
let q_shape = query.layout().shape();
let ndims = q_shape.num_dims();
let head_dim = q_shape[ndims - 1];
let transposed_key = key.transpose(ndims - 2, ndims - 1);
let scores = crate::ops::matmul::matmul(query, transposed_key);
let weights = fused_softmax::<T>(
scores,
mask.as_ref(),
attn_bias.as_ref(),
&options,
head_dim,
);
crate::ops::matmul::matmul(weights, value)
}
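/// Row-wise softmax fused with all score transformations: scaling, optional
/// softcap, boolean and causal masking, and additive bias are applied in the
/// same pass that tracks the row maximum for numerical stability.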
fn fused_softmax<T>(
scores: FlexTensor,
mask: Option<&FlexTensor>,
attn_bias: Option<&FlexTensor>,
options: &AttentionModuleOptions,
head_dim: usize,
) -> FlexTensor
where
T: Float + Pod + Copy + burn_backend::Element,
T: core::ops::AddAssign,
{
let scores = scores.to_contiguous();
let shape = scores.layout().shape().clone();
let ndims = shape.num_dims();
assert!(ndims >= 2, "scores must be at least 2D");
let seq_q = shape[ndims - 2];
let seq_k = shape[ndims - 1];
let num_rows_total: usize = shape[..ndims - 2].iter().product::<usize>() * seq_q;
let scores_data: &[T] = scores.storage();
let scores_numel = scores_data.len();
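    // Materialize the boolean mask contiguously and read it as one `u8` per
    // element: a nonzero byte marks a key position excluded from attention.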
let mask_tensor = mask.map(|m| {
let m = m.to_contiguous();
debug_assert_eq!(
m.layout().num_elements(),
scores_numel,
"attention: mask shape must match scores shape"
);
m
});
let mask_data: Option<&[u8]> = mask_tensor.as_ref().map(|m| m.bytes());
let bias_tensor = attn_bias.map(|b| {
let b = b.to_contiguous();
debug_assert_eq!(
b.layout().num_elements(),
scores_numel,
"attention: attn_bias shape must match scores shape"
);
b
});
let bias_data: Option<&[T]> = bias_tensor.as_ref().map(|b| b.storage());
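    // Fall back to the conventional 1/sqrt(head_dim) scaling when no explicit
    // scale is provided.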
let scale = T::from(
options
.scale
.unwrap_or_else(|| 1.0 / (head_dim as f64).sqrt()),
)
.unwrap();
let softcap: Option<T> = options.softcap.map(|s| T::from(s).unwrap());
let neg_inf = T::neg_infinity();
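    // Align the causal mask to the bottom-right corner: query position `q`
    // may attend key positions `k <= q + (seq_k - seq_q)`, so the last query
    // row always sees every key.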
let causal_offset = if options.is_causal {
Some(seq_k as isize - seq_q as isize)
} else {
None
};
let params = SoftmaxParams {
scale,
softcap,
neg_inf,
causal_offset,
};
let mut output = vec![T::zero(); scores_data.len()];
for row_idx in 0..num_rows_total {
let q_pos = row_idx % seq_q;
let row_start = row_idx * seq_k;
let scores_row = &scores_data[row_start..row_start + seq_k];
let out_row = &mut output[row_start..row_start + seq_k];
let mask_row = mask_data.map(|m| &m[row_start..row_start + seq_k]);
let bias_row = bias_data.map(|b| &b[row_start..row_start + seq_k]);
        fused_softmax_row(scores_row, out_row, mask_row, bias_row, &params, q_pos);
}
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(shape),
T::dtype(),
)
}
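/// Constants shared by every row processed by [`fused_softmax_row`].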
struct SoftmaxParams<T> {
scale: T,
softcap: Option<T>,
neg_inf: T,
causal_offset: Option<isize>,
}
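/// Transforms and normalizes a single softmax row in place over three passes:
/// transform the logits while tracking the maximum, exponentiate and sum,
/// then normalize.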
#[inline]
fn fused_softmax_row<T>(
scores: &[T],
output: &mut [T],
mask: Option<&[u8]>,
bias: Option<&[T]>,
params: &SoftmaxParams<T>,
q_pos: usize,
) where
T: Float + Copy + core::ops::AddAssign,
{
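    // Pass 1: scale each logit, apply the optional softcap, boolean mask, and
    // causal mask, then add the bias, tracking the running row maximum. The
    // bias is added after masking, so a finite bias leaves masked logits at
    // negative infinity.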
let mut row_max = params.neg_inf;
for (k, (out, &score)) in output.iter_mut().zip(scores.iter()).enumerate() {
let mut val = score * params.scale;
if let Some(cap) = params.softcap {
val = cap * (val / cap).tanh();
}
if let Some(m) = mask
&& m[k] != 0
{
val = params.neg_inf;
}
if let Some(offset) = params.causal_offset
&& (k as isize) > (q_pos as isize) + offset
{
val = params.neg_inf;
}
if let Some(b) = bias {
val += b[k];
}
*out = val;
if val > row_max {
row_max = val;
}
}
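    // A fully masked row has no finite logit; emit zeros instead of the NaNs
    // that a 0/0 normalization would produce.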
if row_max == params.neg_inf {
for out in output.iter_mut() {
*out = T::zero();
}
return;
}
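    // Pass 2: exponentiate relative to the row maximum and accumulate the sum.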
let mut sum = T::zero();
for out in output.iter_mut() {
let e = (*out - row_max).exp();
*out = e;
sum += e;
}
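    // Pass 3: scale by the reciprocal of the sum so each row sums to one.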
let inv_sum = T::one() / sum;
for out in output.iter_mut() {
*out = *out * inv_sum;
}
}
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;
    use burn_backend::ops::AttentionModuleOptions;
use burn_tensor::{Tensor, TensorData};
use crate::Flex;
fn make_qkv(
q: &[&[f32]],
k: &[&[f32]],
v: &[&[f32]],
) -> (Tensor<Flex, 4>, Tensor<Flex, 4>, Tensor<Flex, 4>) {
let seq_q = q.len();
let seq_k = k.len();
let head_dim = q[0].len();
let val_dim = v[0].len();
let q_flat: Vec<f32> = q.iter().flat_map(|r| r.iter().copied()).collect();
let k_flat: Vec<f32> = k.iter().flat_map(|r| r.iter().copied()).collect();
let v_flat: Vec<f32> = v.iter().flat_map(|r| r.iter().copied()).collect();
let dev = Default::default();
let qt = Tensor::from_data(TensorData::new(q_flat, [1, 1, seq_q, head_dim]), &dev);
let kt = Tensor::from_data(TensorData::new(k_flat, [1, 1, seq_k, head_dim]), &dev);
let vt = Tensor::from_data(TensorData::new(v_flat, [1, 1, seq_k, val_dim]), &dev);
(qt, kt, vt)
}
#[test]
fn test_basic() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), 2);
assert!((data[0] - 13.30).abs() < 0.1, "got {}", data[0]);
assert!((data[1] - 16.70).abs() < 0.1, "got {}", data[1]);
}
#[test]
fn test_causal_mask() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let opts = AttentionModuleOptions {
is_causal: true,
..Default::default()
};
let result = burn_tensor::module::attention(q, k, v, None, None, opts);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 10.0).abs() < 1e-5, "got {}", data[0]);
assert!((data[1] - 16.70).abs() < 0.1, "got {}", data[1]);
}
#[test]
fn test_bool_mask() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let dev = Default::default();
use burn_tensor::Bool;
let mask: Tensor<Flex, 4, Bool> =
Tensor::from_data(TensorData::from([[[[true, false], [true, false]]]]), &dev);
let result = burn_tensor::module::attention(q, k, v, Some(mask), None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 20.0).abs() < 1e-4, "got {}", data[0]);
assert!((data[1] - 20.0).abs() < 1e-4, "got {}", data[1]);
}
#[test]
fn test_additive_bias() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let dev = Default::default();
let bias: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(vec![0.0f32, 100.0, 0.0, 100.0], [1, 1, 2, 2]),
&dev,
);
let result = burn_tensor::module::attention(q, k, v, None, Some(bias), Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 20.0).abs() < 0.1, "got {}", data[0]);
assert!((data[1] - 20.0).abs() < 0.1, "got {}", data[1]);
}
#[test]
fn test_custom_scale() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let opts = AttentionModuleOptions {
scale: Some(100.0),
..Default::default()
};
let result = burn_tensor::module::attention(q, k, v, None, None, opts);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 10.0).abs() < 0.1, "got {}", data[0]);
assert!((data[1] - 20.0).abs() < 0.1, "got {}", data[1]);
}
#[test]
fn test_softcap() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let opts = AttentionModuleOptions {
softcap: Some(0.1),
..Default::default()
};
let result = burn_tensor::module::attention(q, k, v, None, None, opts);
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert!((data[0] - 15.0).abs() < 0.5, "got {}", data[0]);
assert!((data[1] - 15.0).abs() < 0.5, "got {}", data[1]);
}
#[test]
fn test_cross_attention() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0], &[0.5, 0.5]],
&[&[10.0], &[20.0], &[30.0]],
);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), 2);
for &val in &data {
            assert!((9.0..=31.0).contains(&val), "unexpected value {val}");
}
}
#[test]
fn test_causal_cross_attention() {
let dev = Default::default();
let q: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(vec![1.0f32, 0.0, 0.0, 1.0], [1, 1, 2, 2]),
&dev,
);
let k: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(
vec![1.0f32, 0.0, 0.0, 1.0, 0.5, 0.5, 0.5, 0.5],
[1, 1, 4, 2],
),
&dev,
);
let v: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(vec![10.0f32, 20.0, 30.0, 40.0], [1, 1, 4, 1]),
&dev,
);
let opts = AttentionModuleOptions {
is_causal: true,
..Default::default()
};
let result_causal =
burn_tensor::module::attention(q.clone(), k.clone(), v.clone(), None, None, opts);
let data_causal: Vec<f32> = result_causal.into_data().to_vec().unwrap();
let result_full = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data_full: Vec<f32> = result_full.into_data().to_vec().unwrap();
assert_eq!(data_causal.len(), 2);
assert!(
data_causal[0] < data_full[0],
"expected causal[0] < full[0], got {} vs {}",
data_causal[0],
data_full[0]
);
assert!(
(data_causal[1] - data_full[1]).abs() < 1e-5,
"expected causal[1] ~= full[1], got {} vs {}",
data_causal[1],
data_full[1]
);
}
#[test]
fn test_all_masked_produces_zeros() {
let (q, k, v) = make_qkv(
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[1.0, 0.0], &[0.0, 1.0]],
&[&[10.0], &[20.0]],
);
let dev = Default::default();
use burn_tensor::Bool;
let mask: Tensor<Flex, 4, Bool> =
Tensor::from_data(TensorData::from([[[[true, true], [true, true]]]]), &dev);
let result = burn_tensor::module::attention(q, k, v, Some(mask), None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
for (i, &val) in data.iter().enumerate() {
assert!(!val.is_nan(), "output[{i}] is NaN");
            assert!(val.abs() < 1e-6, "expected 0.0, got {val}");
}
}
#[test]
fn test_multi_batch_multi_head() {
let dev = Default::default();
let q: Tensor<Flex, 4> = Tensor::from_data(
TensorData::new(
vec![
1.0f32, 0.0, 0.0, 1.0, 0.5, 0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0,
],
[2, 2, 2, 2],
),
&dev,
);
let k = q.clone();
let v: Tensor<Flex, 4> =
Tensor::from_data(TensorData::new(vec![10.0f32; 16], [2, 2, 2, 2]), &dev);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), 16);
for (i, &val) in data.iter().enumerate() {
assert!(
(val - 10.0).abs() < 1e-4,
"output[{i}] = {val}, expected 10.0"
);
}
}
#[test]
fn test_single_element() {
let (q, k, v) = make_qkv(&[&[1.0, 0.0]], &[&[1.0, 0.0]], &[&[42.0]]);
let result = burn_tensor::module::attention(q, k, v, None, None, Default::default());
let data: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(data.len(), 1);
assert!((data[0] - 42.0).abs() < 1e-5, "got {}", data[0]);
}
}