use crate as burn;
use crate::nn::cache::TensorCache;
use crate::{
config::Config,
module::Module,
nn,
tensor::{activation, backend::Backend, Bool, Tensor},
};
use libm::sqrtf;
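
/// Configuration to create a [multi-head attention](MultiHeadAttention) module.
///
/// A minimal usage sketch (not compiled as a doctest): `MyBackend` stands in for any
/// [`Backend`] implementation and `tensor` for an input of shape
/// `[batch_size, seq_length, d_model]`; neither is defined in this module.
///
/// ```ignore
/// // Build an 8-head attention layer over a 512-dimensional model.
/// let mha = MultiHeadAttentionConfig::new(512, 8)
///     .with_dropout(0.1)
///     .init::<MyBackend>();
///
/// let output = mha.forward(MhaInput::self_attn(tensor));
/// ```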
#[derive(Config)]
pub struct MultiHeadAttentionConfig {
    /// The size of each linear layer (the model dimension).
    d_model: usize,
    /// The number of attention heads.
    n_heads: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    dropout: f64,
    /// The value used to fill masked attention scores before the softmax.
    /// It must be negative enough that the masked weights become approximately zero.
    /// Default: -1.0e4
    #[config(default = -1.0e4)]
    min_float: f64,
}

/// The multi-head attention module, as described in
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
#[derive(Module, Debug)]
pub struct MultiHeadAttention<B: Backend> {
    query: nn::Linear<B>,
    key: nn::Linear<B>,
    value: nn::Linear<B>,
    output: nn::Linear<B>,
    dropout: nn::Dropout,
    activation: nn::GELU,
    n_heads: usize,
    d_k: usize,
    min_float: f64,
}
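
/// The input of the [multi-head attention](MultiHeadAttention) forward pass, built from a
/// query, key and value tensor plus optional padding and attention masks.
///
/// A minimal construction sketch (not compiled as a doctest): `tensor`, `mask_pad` and
/// `mask_attn` are placeholders for values created by the caller.
///
/// ```ignore
/// // Self-attention: the same tensor is used for the query, key and value.
/// let input = MhaInput::self_attn(tensor)
///     .mask_pad(mask_pad)    // [batch_size, seq_length] boolean padding mask
///     .mask_attn(mask_attn); // [batch_size, seq_length, seq_length] boolean attention mask
/// ```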
#[derive(Debug, Clone)]
pub struct MhaInput<B: Backend> {
    query: Tensor<B, 3>,
    key: Tensor<B, 3>,
    value: Tensor<B, 3>,
    mask_pad: Option<Tensor<B, 2, Bool>>,
    mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl MultiHeadAttentionConfig {
    /// Initialize a new [multi-head attention](MultiHeadAttention) module.
    pub fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
        let linear = |config: &Self| nn::LinearConfig::new(config.d_model, config.d_model).init();

        MultiHeadAttention {
            query: linear(self),
            key: linear(self),
            value: linear(self),
            output: linear(self),
            dropout: nn::DropoutConfig::new(self.dropout).init(),
            activation: nn::GELU::new(),
            n_heads: self.n_heads,
            d_k: self.d_model / self.n_heads,
            min_float: self.min_float,
        }
    }

    /// Initialize a new [multi-head attention](MultiHeadAttention) module with
    /// parameters loaded from the given record.
    pub fn init_with<B: Backend>(
        &self,
        record: MultiHeadAttentionRecord<B>,
    ) -> MultiHeadAttention<B> {
        let linear = |config: &Self, record| {
            nn::LinearConfig::new(config.d_model, config.d_model).init_with(record)
        };

        MultiHeadAttention {
            query: linear(self, record.query),
            key: linear(self, record.key),
            value: linear(self, record.value),
            output: linear(self, record.output),
            dropout: nn::DropoutConfig::new(self.dropout).init(),
            activation: nn::GELU::new(),
            n_heads: self.n_heads,
            d_k: self.d_model / self.n_heads,
            min_float: self.min_float,
        }
    }
}

impl<B: Backend> MhaInput<B> {
    /// Create a self-attention input, using the same tensor for the query, key and value.
    pub fn self_attn(tensor: Tensor<B, 3>) -> Self {
        Self {
            query: tensor.clone(),
            key: tensor.clone(),
            value: tensor,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Create a multi-head attention input with distinct query, key and value tensors.
    pub fn new(query: Tensor<B, 3>, key: Tensor<B, 3>, value: Tensor<B, 3>) -> Self {
        Self {
            query,
            key,
            value,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Register a padding mask of shape `[batch_size, seq_length]`, where `true` marks
    /// padded positions that should be ignored.
    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.mask_pad = Some(mask_pad);
        self
    }

    /// Register an attention mask of shape `[batch_size, seq_length_1, seq_length_2]`,
    /// where `true` marks positions that must not be attended to.
    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.mask_attn = Some(mask_attn);
        self
    }
}

/// The output of the [multi-head attention](MultiHeadAttention) forward pass.
#[derive(Debug, Clone)]
pub struct MhaOutput<B: Backend> {
    /// The attention weights, of shape `[batch_size, n_heads, seq_length_1, seq_length_2]`.
    pub weights: Tensor<B, 4>,
    /// The context tensor, of shape `[batch_size, seq_length_1, d_model]`.
    pub context: Tensor<B, 3>,
}

impl<B: Backend> MultiHeadAttention<B> {
    /// Applies the forward pass on the input tensors.
    ///
    /// # Shapes
    ///
    /// - query: `[batch_size, seq_length_1, d_model]`
    /// - key: `[batch_size, seq_length_2, d_model]`
    /// - value: `[batch_size, seq_length_2, d_model]`
    /// - output context: `[batch_size, seq_length_1, d_model]`
    pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();

        // Project the inputs and split them into heads.
        let query = self.attention_linear(input.query, &self.query);
        let key = self.attention_linear(input.key, &self.key);
        let value = self.attention_linear(input.value, &self.value);

        // Scaled dot-product attention per head.
        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);

        // Merge the heads back into a single tensor and apply the output projection.
        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);
        let context = self.output.forward(context);

        MhaOutput { weights, context }
    }

    /// Applies the forward pass using a [cache](MhaCache), so that linear projections
    /// computed for previous tokens can be reused during autoregressive inference.
    pub fn forward_cache(&self, input: MhaInput<B>, cache: &mut MhaCache<B>) -> MhaOutput<B> {
        let [batch_size, seq_length_1, d_model] = input.query.dims();

        let query = cache
            .query
            .forward(input.query, |t| self.attention_linear(t, &self.query));
        let key = cache
            .key
            .forward(input.key, |t| self.attention_linear(t, &self.key));
        let value = cache
            .value
            .forward(input.value, |t| self.attention_linear(t, &self.value));

        let attn_scores = self.attn_scores(query, key);
        let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);

        let context = weights.clone().matmul(value);
        let context = context
            .swap_dims(1, 2)
            .reshape([batch_size, seq_length_1, d_model]);
        let context = cache.output.forward(context, |t| self.output.forward(t));

        MhaOutput { weights, context }
    }

    /// Computes the scaled dot-product scores `Q K^T / sqrt(d_k)` for every head and
    /// applies dropout to them.
    fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
        let attn_scores = query
            .matmul(key.transpose())
            .div_scalar(sqrtf(self.d_k as f32));

        self.dropout.forward(attn_scores)
    }

    /// Fills masked score positions with `min_float`, then applies a softmax over the key
    /// dimension to obtain the attention weights.
    fn attn_weights(
        &self,
        mut attn_scores: Tensor<B, 4>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 4> {
        if let Some(mask_pad) = mask_pad {
            let [batch_size, seq_length] = mask_pad.dims();

            attn_scores = attn_scores.mask_fill(
                mask_pad.reshape([batch_size, 1, 1, seq_length]),
                self.min_float,
            );
        }

        if let Some(mask_attn) = mask_attn {
            let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();

            attn_scores = attn_scores.mask_fill(
                mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),
                self.min_float,
            );
        }

        activation::softmax(attn_scores, 3)
    }

    /// Projects the input with the given linear layer and splits it into heads:
    /// `[batch_size, seq_length, d_model]` -> `[batch_size, n_heads, seq_length, d_k]`.
    fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
        let [batch_size, seq_length, _d_model] = x.dims();

        linear
            .forward(x)
            .reshape([batch_size, seq_length, self.n_heads, self.d_k])
            .swap_dims(1, 2)
    }
}
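
/// Cache for the [multi-head attention](MultiHeadAttention) linear projections, used to
/// speed up autoregressive decoding with [MultiHeadAttention::forward_cache].
///
/// A minimal decoding-loop sketch (not compiled as a doctest): `mha`, `tokens` and the
/// size variables are placeholders created by the caller, with `tokens` of shape
/// `[batch_size, seq_length, d_model]`.
///
/// ```ignore
/// let mut cache = MhaCache::autoregressive();
///
/// for i in 1..seq_length + 1 {
///     // Feed the tokens generated so far; previously computed projections are reused.
///     let input = MhaInput::self_attn(tokens.clone().index([0..batch_size, 0..i, 0..d_model]));
///     let output = mha.forward_cache(input, &mut cache);
/// }
/// ```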
pub struct MhaCache<B: Backend> {
    query: MhaLinearCache<B, 4>,
    key: MhaLinearCache<B, 4>,
    value: MhaLinearCache<B, 4>,
    output: MhaLinearCache<B, 3>,
}

/// Cache for a single linear projection: either growing along a sequence dimension, one
/// token at a time, or computed once and reused as is.
enum MhaLinearCache<B: Backend, const D: usize> {
    Autoregressive(TensorCache<B, D>, usize),
    Full(TensorCache<B, D>),
}

impl<B: Backend> MhaCache<B> {
    /// Create a cache for autoregressive self-attention, where the query, key and value
    /// projections all grow by one token per decoding step.
    pub fn autoregressive() -> Self {
        Self {
            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            key: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            value: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
        }
    }

    /// Create a cache for autoregressive cross-attention, where only the query projection
    /// grows per decoding step while the key and value projections of the encoder memory
    /// are computed once and reused.
    pub fn autoregressive_cross_attention() -> Self {
        Self {
            query: MhaLinearCache::Autoregressive(TensorCache::empty(), 2),
            key: MhaLinearCache::Full(TensorCache::empty()),
            value: MhaLinearCache::Full(TensorCache::empty()),
            output: MhaLinearCache::Autoregressive(TensorCache::empty(), 1),
        }
    }
}

impl<B: Backend, const D: usize> MhaLinearCache<B, D> {
    /// Apply `func` through the cache, reusing previously computed values when possible.
    pub fn forward<F: Fn(Tensor<B, 3>) -> Tensor<B, D>>(
        &mut self,
        tensor: Tensor<B, 3>,
        func: F,
    ) -> Tensor<B, D> {
        match self {
            MhaLinearCache::Autoregressive(cache, dim) => {
                cache.forward_autoregressive(tensor, *dim, func)
            }
            MhaLinearCache::Full(cache) => cache.forward_full(tensor, func),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Distribution, Int, Shape};
    use crate::{nn::attention::generate_autoregressive_mask, TestBackend};
    use alloc::vec::Vec;

    #[test]
    fn test_self_attention_shapes() {
        let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4];
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();
        let input = MhaInput::self_attn(Tensor::random(
            [batch_size, seq_length, d_model],
            Distribution::Standard,
        ));

        let output = mha.forward(input);

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length, seq_length]),
            "Weights should have the correct shape",
        );
    }

    #[test]
    fn test_generic_mha_shapes() {
        let [batch_size, seq_length_1, seq_length_2, d_model, n_heads] = [7, 13, 15, 32, 4];
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();
        let input = MhaInput::new(
            Tensor::random([batch_size, seq_length_1, d_model], Distribution::Standard),
            Tensor::random([batch_size, seq_length_2, d_model], Distribution::Standard),
            Tensor::random([batch_size, seq_length_2, d_model], Distribution::Standard),
        );

        let output = mha.forward(input);

        assert_eq!(
            output.context.shape(),
            Shape::new([batch_size, seq_length_1, d_model]),
            "Context should have the correct shape",
        );
        assert_eq!(
            output.weights.shape(),
            Shape::new([batch_size, n_heads, seq_length_1, seq_length_2]),
            "Weights should have the correct shape",
        );
    }
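
    // Sanity-check sketch: the softmax in `attn_weights` should make the weights of every
    // query position sum to ~1 over the key dimension. Assumes `sum_dim` keeps the reduced
    // dimension with size 1, as elsewhere in burn-tensor.
    #[test]
    fn test_attention_weights_sum_to_one_over_keys() {
        let [batch_size, seq_length, d_model, n_heads] = [2, 5, 16, 4];
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();
        let input = MhaInput::self_attn(Tensor::random(
            [batch_size, seq_length, d_model],
            Distribution::Standard,
        ));

        let weights = mha.forward(input).weights;

        weights.sum_dim(3).into_data().assert_approx_eq(
            &Tensor::<TestBackend, 4>::ones([batch_size, n_heads, seq_length, 1]).into_data(),
            3,
        );
    }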

    #[test]
    fn test_self_attention_mask_pad() {
        let [batch_size, seq_length, d_model, n_heads, num_padded] = [3, 6, 32, 2, 2];
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();

        // Mark the last `num_padded` positions of every sequence as padding.
        let mask_pad: Tensor<TestBackend, 2, Int> = Tensor::zeros([batch_size, seq_length]);
        let mask_pad = mask_pad.index_assign(
            [0..batch_size, seq_length - num_padded..seq_length],
            Tensor::ones([batch_size, num_padded]),
        );
        let mask_pad = mask_pad.equal_elem(1);

        let tensor_1 = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Standard,
        );
        // Change the values on the padded positions only.
        let tensor_2 = tensor_1.clone().index_assign(
            [
                0..batch_size,
                seq_length - num_padded..seq_length,
                0..d_model,
            ],
            Tensor::random([batch_size, num_padded, d_model], Distribution::Standard),
        );
        let input_1 = MhaInput::self_attn(tensor_1).mask_pad(mask_pad.clone());
        let input_2 = MhaInput::self_attn(tensor_2).mask_pad(mask_pad);

        let output_1 = mha.forward(input_1);
        let output_2 = mha.forward(input_2);

        // The non-padded positions must not be affected by the padded ones.
        output_1
            .context
            .index([0..batch_size, 0..seq_length - num_padded, 0..d_model])
            .into_data()
            .assert_approx_eq(
                &output_2
                    .context
                    .index([0..batch_size, 0..seq_length - num_padded, 0..d_model])
                    .into_data(),
                3,
            );
    }
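
    // Sanity-check sketch, mirroring `test_self_attention_mask_pad`: padded key columns
    // should receive approximately zero attention weight, since their scores are filled
    // with `min_float` before the softmax.
    #[test]
    fn test_mask_pad_zeroes_weights_on_padded_keys() {
        let [batch_size, seq_length, d_model, n_heads, num_padded] = [2, 6, 16, 2, 2];
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();
        let mask_pad: Tensor<TestBackend, 2, Int> = Tensor::zeros([batch_size, seq_length]);
        let mask_pad = mask_pad
            .index_assign(
                [0..batch_size, seq_length - num_padded..seq_length],
                Tensor::ones([batch_size, num_padded]),
            )
            .equal_elem(1);
        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Standard,
        );

        let output = mha.forward(MhaInput::self_attn(tensor).mask_pad(mask_pad));

        // The weights over the padded key columns should be ~0 for every head and query position.
        output
            .weights
            .index([
                0..batch_size,
                0..n_heads,
                0..seq_length,
                seq_length - num_padded..seq_length,
            ])
            .into_data()
            .assert_approx_eq(
                &Tensor::<TestBackend, 4>::zeros([batch_size, n_heads, seq_length, num_padded])
                    .into_data(),
                3,
            );
    }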

    #[test]
    fn test_autoregressive_mask_should_have_same_output_as_autoregressive_decoding() {
        let [batch_size, seq_length, d_model, n_heads] = [3, 4, 12, 2];
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();
        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Standard,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = MhaInput::self_attn(tensor.clone()).mask_attn(mask_attn);

        let output_1 = mha.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = MhaCache::autoregressive();

        for i in 1..seq_length + 1 {
            let tensor = tensor.clone().index([0..batch_size, 0..i, 0..d_model]);
            let input = MhaInput::self_attn(tensor);
            let next_tok = mha.forward_cache(input, &mut cache).context.index([
                0..batch_size,
                i - 1..i,
                0..d_model,
            ]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .context
            .into_data()
            .assert_approx_eq(&output_2.into_data(), 3);
    }
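
    // Sketch of the cross-attention counterpart of the test above: only the query grows one
    // token per step, while the key/value projections of a fixed "encoder" memory are cached
    // in full. The step-by-step outputs should match a single full forward pass.
    #[test]
    fn test_autoregressive_cross_attention_cache_should_match_full_forward() {
        let [batch_size, seq_length_q, seq_length_kv, d_model, n_heads] = [2, 4, 5, 12, 2];
        let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>();
        let query = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length_q, d_model],
            Distribution::Standard,
        );
        let memory = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length_kv, d_model],
            Distribution::Standard,
        );

        // Full forward pass over the whole query at once.
        let output_full = mha.forward(MhaInput::new(query.clone(), memory.clone(), memory.clone()));

        // Token-by-token decoding with the cross-attention cache.
        let mut cache = MhaCache::autoregressive_cross_attention();
        let mut outputs = Vec::new();

        for i in 1..seq_length_q + 1 {
            let query_step = query.clone().index([0..batch_size, 0..i, 0..d_model]);
            let input = MhaInput::new(query_step, memory.clone(), memory.clone());
            let next_tok = mha.forward_cache(input, &mut cache).context.index([
                0..batch_size,
                i - 1..i,
                0..d_model,
            ]);
            outputs.push(next_tok);
        }

        let output_step = Tensor::cat(outputs, 1);

        output_full
            .context
            .into_data()
            .assert_approx_eq(&output_step.into_data(), 3);
    }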
}