use burn::tensor::{backend::Backend, Tensor};
/// Configuration for [`FlashAttentionV3::forward_with_config`].
#[derive(Debug, Clone)]
pub struct FlashAttentionV3Config {
    /// When `true`, position `i` may only attend to positions `<= i`
    /// (a lower-triangular additive mask is applied to the scores).
    pub causal: bool,
    /// Dropout probability applied to the attention weights.
    pub dropout_p: f32,
    /// Softmax temperature; when `None`, `1 / sqrt(head_dim)` is used.
    pub softmax_scale: Option<f32>,
    /// Query tile size. NOTE(review): not consumed by the current
    /// implementation, which materializes the full score matrix —
    /// presumably reserved for a future tiled kernel; confirm.
    pub block_size_q: usize,
    /// Key tile size. NOTE(review): likewise currently unused.
    pub block_size_k: usize,
}
impl Default for FlashAttentionV3Config {
fn default() -> Self {
Self {
causal: false,
dropout_p: 0.0,
softmax_scale: None,
block_size_q: 128,
block_size_k: 128,
}
}
}
/// Stateless entry point for the attention forward pass; see
/// [`FlashAttentionV3::forward`] and [`FlashAttentionV3::forward_with_config`].
pub struct FlashAttentionV3;
impl FlashAttentionV3 {
    /// Computes scaled dot-product attention with the default configuration.
    ///
    /// * `query`, `key`, `value` — `[batch, heads, seq_len, head_dim]` tensors.
    /// * `attn_mask` — optional additive mask combined into the score matrix
    ///   (use `-inf` at positions to exclude).
    /// * `causal` — when `true`, position `i` only attends to positions `<= i`.
    ///
    /// Returns a `[batch, heads, seq_len_q, head_dim]` tensor.
    pub fn forward<B: Backend>(
        query: Tensor<B, 4>,
        key: Tensor<B, 4>,
        value: Tensor<B, 4>,
        attn_mask: Option<Tensor<B, 4>>,
        causal: bool,
    ) -> Tensor<B, 4> {
        Self::forward_with_config(
            query,
            key,
            value,
            attn_mask,
            FlashAttentionV3Config {
                causal,
                ..Default::default()
            },
        )
    }

    /// Computes attention with an explicit [`FlashAttentionV3Config`].
    ///
    /// This is a naive reference implementation: the full
    /// `[seq_len_q, seq_len_k]` score matrix is materialized, so
    /// `block_size_q` / `block_size_k` are currently ignored.
    ///
    /// NOTE(review): there is no train/eval flag on this API, so callers
    /// must pass `dropout_p = 0.0` at inference time.
    pub fn forward_with_config<B: Backend>(
        query: Tensor<B, 4>,
        key: Tensor<B, 4>,
        value: Tensor<B, 4>,
        attn_mask: Option<Tensor<B, 4>>,
        config: FlashAttentionV3Config,
    ) -> Tensor<B, 4> {
        let [_batch_size, _num_heads, seq_len_q, head_dim] = query.dims();
        let [_, _, seq_len_k, _] = key.dims();

        // Standard softmax temperature: 1 / sqrt(d_k) unless overridden.
        let scale = config
            .softmax_scale
            .unwrap_or_else(|| 1.0 / (head_dim as f32).sqrt());

        // scores[b, h, i, j] = <q_i, k_j> * scale
        let scores = query.matmul(key.transpose()) * scale;

        let scores = if config.causal {
            Self::apply_causal_mask(scores, seq_len_q, seq_len_k)
        } else {
            scores
        };

        // Additive user-supplied mask (e.g. a padding mask with -inf entries).
        let scores = match attn_mask {
            Some(mask) => scores + mask,
            None => scores,
        };

        // Normalize over the key dimension (last axis).
        let attn_weights = burn::tensor::activation::softmax(scores, 3);

        // BUG FIX: the previous version had `if dropout_p > 0.0` with two
        // identical arms, so dropout was silently ignored. Apply inverted
        // dropout instead: keep each weight with probability `1 - p` and
        // rescale survivors by `1 / (1 - p)` so the expectation is unchanged.
        let attn_weights = if config.dropout_p > 0.0 {
            let keep_prob = 1.0 - config.dropout_p;
            let keep_mask = Tensor::<B, 4>::random(
                attn_weights.dims(),
                burn::tensor::Distribution::Bernoulli(keep_prob as f64),
                &attn_weights.device(),
            );
            attn_weights * keep_mask * (1.0 / keep_prob)
        } else {
            attn_weights
        };

        attn_weights.matmul(value)
    }

    /// Adds `-inf` above the diagonal so softmax assigns zero weight to
    /// future positions.
    ///
    /// NOTE(review): the diagonal is anchored at (0, 0) (row `i` sees keys
    /// `0..=i`). When `seq_len_k > seq_len_q` (e.g. KV-cache decoding),
    /// flash-attention kernels usually anchor the diagonal bottom-right —
    /// confirm which alignment callers expect.
    fn apply_causal_mask<B: Backend>(
        scores: Tensor<B, 4>,
        seq_len_q: usize,
        seq_len_k: usize,
    ) -> Tensor<B, 4> {
        let device = scores.device();
        let [batch_size, num_heads, _, _] = scores.dims();

        // Build a [seq_len_q, seq_len_k] additive mask: 0 where attention is
        // allowed, -inf where it is forbidden.
        let mut mask_data = vec![-f32::INFINITY; seq_len_q * seq_len_k];
        for i in 0..seq_len_q {
            for j in 0..=i.min(seq_len_k - 1) {
                mask_data[i * seq_len_k + j] = 0.0;
            }
        }

        let mask = Tensor::<B, 1>::from_floats(mask_data.as_slice(), &device)
            .reshape([1, 1, seq_len_q, seq_len_k]);
        // Expand over batch and head dimensions to match `scores`.
        let mask = mask.repeat(&[batch_size, num_heads, 1, 1]);
        scores + mask
    }

    /// Placeholder: gradients flow through the forward ops via Burn's
    /// autodiff backend, so no manual backward pass is implemented.
    #[allow(dead_code)]
    fn backward<B: Backend>(
        _grad_output: Tensor<B, 4>,
        _query: Tensor<B, 4>,
        _key: Tensor<B, 4>,
        _value: Tensor<B, 4>,
    ) -> (Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 4>) {
        unimplemented!("Backward pass is handled by Burn's autodiff")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type TestBackend = NdArray;

    /// Builds a standard-normal random tensor of the given shape.
    fn randn(
        shape: [usize; 4],
        device: &<TestBackend as Backend>::Device,
    ) -> Tensor<TestBackend, 4> {
        Tensor::random(shape, burn::tensor::Distribution::Normal(0.0, 1.0), device)
    }

    /// Non-causal attention preserves the input shape.
    #[test]
    fn test_flash_attention_basic() {
        let device = Default::default();
        let shape = [2, 4, 8, 16]; // [batch, heads, seq_len, head_dim]

        let query = randn(shape, &device);
        let key = randn(shape, &device);
        let value = randn(shape, &device);

        let output = FlashAttentionV3::forward(query, key, value, None, false);
        assert_eq!(output.dims(), shape);
    }

    /// Causal attention also preserves the input shape.
    #[test]
    fn test_flash_attention_causal() {
        let device = Default::default();
        let shape = [1, 1, 4, 8]; // [batch, heads, seq_len, head_dim]

        let query = randn(shape, &device);
        let key = randn(shape, &device);
        let value = randn(shape, &device);

        let output = FlashAttentionV3::forward(query, key, value, None, true);
        assert_eq!(output.dims(), shape);
    }
}