use crate::cache::TensorCache;
use crate::modules::{Linear, LinearConfig};
use crate::{Dropout, DropoutConfig};
use burn_core as burn;
use burn::{
config::Config,
module::{Initializer, Module},
tensor::{
Bool, Tensor,
activation::{quiet_softmax, softmax},
backend::Backend,
},
};
#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;
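
/// Configuration to create a [cross-attention](CrossAttention) layer
/// using the [init function](CrossAttentionConfig::init).
///
/// A minimal usage sketch; the backend, device, and input tensors are placeholders,
/// so the snippet is not compiled as a doc-test:
///
/// ```ignore
/// // Required fields, in declaration order: d_model, d_context, n_heads, n_heads_kv, d_head.
/// let config = CrossAttentionConfig::new(512, 768, 8, 8, 64).with_dropout(0.0);
/// let attention = config.init::<MyBackend>(&device);
/// // query: [batch, seq_len_q, 512], context: [batch, seq_len_kv, 768]
/// let output = attention.forward(query, context, None);
/// ```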
#[derive(Config, Debug)]
pub struct CrossAttentionConfig {
    /// The size of the query (model) embeddings.
    pub d_model: usize,
    /// The size of the context (key/value) embeddings.
    pub d_context: usize,
    /// The number of query heads.
    pub n_heads: usize,
    /// The number of key/value heads. Must evenly divide `n_heads`: equal to `n_heads`
    /// for standard multi-head attention, smaller for grouped-query attention, or `1`
    /// for multi-query attention.
    pub n_heads_kv: usize,
    /// The size of each attention head.
    pub d_head: usize,
    /// The dropout rate applied to the attention weights. Default: 0.1.
    #[config(default = 0.1)]
    pub dropout: f64,
    /// The value used to fill masked attention scores. Default: -1.0e4.
    #[config(default = -1.0e4)]
    pub min_float: f64,
    /// Use the "quiet softmax" instead of the regular softmax. Default: false.
    #[config(default = false)]
    pub quiet_softmax: bool,
}
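
/// The cross-attention layer.
///
/// Queries come from the input sequence while keys and values are projected from a
/// separate context (e.g. an encoder output or conditioning embeddings). Supports
/// standard multi-head attention (`n_heads == n_heads_kv`), grouped-query attention
/// (`n_heads_kv < n_heads`), and multi-query attention (`n_heads_kv == 1`).
///
/// Should be created with [CrossAttentionConfig].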
#[derive(Module, Debug)]
pub struct CrossAttention<B: Backend> {
query: Linear<B>,
key: Linear<B>,
value: Linear<B>,
output: Linear<B>,
dropout: Dropout,
n_heads: usize,
n_heads_kv: usize,
d_head: usize,
scale: f64,
min_float: f64,
quiet_softmax: bool,
}
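
/// Cache for the key and value projections of [cross-attention](CrossAttention),
/// so a fixed context only has to be projected once (e.g. during autoregressive
/// decoding). Used by [CrossAttention::forward_cache].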
pub struct CrossAttentionCache<B: Backend> {
    /// Cached key projection of the context, already split into heads.
    pub k: TensorCache<B, 4>,
    /// Cached value projection of the context, already split into heads.
    pub v: TensorCache<B, 4>,
}
impl<B: Backend> CrossAttentionCache<B> {
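    /// Creates a new empty cache.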
pub fn new() -> Self {
Self {
k: TensorCache::empty(),
v: TensorCache::empty(),
}
}
}
impl<B: Backend> Default for CrossAttentionCache<B> {
fn default() -> Self {
Self::new()
}
}
impl CrossAttentionConfig {
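    /// Initialize a new [cross-attention](CrossAttention) module.
    ///
    /// # Panics
    ///
    /// Panics if `n_heads` is not divisible by `n_heads_kv`.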
pub fn init<B: Backend>(&self, device: &B::Device) -> CrossAttention<B> {
assert_eq!(
self.n_heads % self.n_heads_kv,
0,
"Query heads must be divisible by KV heads"
);
let init_linear = |in_dim, out_dim| {
LinearConfig::new(in_dim, out_dim)
.with_initializer(Initializer::KaimingUniform {
gain: 1.0 / (self.d_head as f64).sqrt(),
fan_out_only: false,
})
.init(device)
};
CrossAttention {
query: init_linear(self.d_model, self.n_heads * self.d_head),
key: init_linear(self.d_context, self.n_heads_kv * self.d_head),
value: init_linear(self.d_context, self.n_heads_kv * self.d_head),
output: init_linear(self.n_heads * self.d_head, self.d_model),
dropout: DropoutConfig::new(self.dropout).init(),
n_heads: self.n_heads,
n_heads_kv: self.n_heads_kv,
d_head: self.d_head,
scale: (self.d_head as f64).sqrt().recip(),
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
}
}
}
impl<B: Backend> CrossAttention<B> {
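    /// Applies the forward pass.
    ///
    /// Shapes:
    /// - `query`: `[batch_size, seq_len_q, d_model]`
    /// - `context`: `[batch_size, seq_len_kv, d_context]`
    /// - `mask` (optional): `[batch_size, seq_len_kv]`, where `true` marks context
    ///   positions to be ignored (e.g. padding).
    /// - output: `[batch_size, seq_len_q, d_model]`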
pub fn forward(
&self,
query: Tensor<B, 3>,
context: Tensor<B, 3>,
mask: Option<Tensor<B, 2, Bool>>,
) -> Tensor<B, 3> {
let [batch, l_q, _] = query.dims();
let [_, l_k, _] = context.dims();
let q = self.query.forward(query);
let k = self.key.forward(context.clone());
let v = self.value.forward(context);
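        // Split each projection into heads: [batch, heads, seq_len, d_head].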
let q = q
.reshape([batch, l_q, self.n_heads, self.d_head])
.swap_dims(1, 2);
let k = k
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
.swap_dims(1, 2);
let v = v
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
.swap_dims(1, 2);
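        // Grouped-query / multi-query attention: repeat each key/value head so that it
        // is shared by `n_heads / n_heads_kv` query heads.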
let (k, v) = if self.n_heads != self.n_heads_kv {
let n_rep = self.n_heads / self.n_heads_kv;
(self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
} else {
(k, v)
};
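        // Scaled dot-product attention scores: [batch, n_heads, seq_len_q, seq_len_kv].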
let scores = q.matmul(k.transpose()) * self.scale;
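        // Fill masked context positions with a large negative value so they contribute
        // (almost) nothing after the softmax.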
let scores = if let Some(mask) = mask {
let mask = mask.reshape([batch, 1, 1, l_k]);
scores.mask_fill(mask, self.min_float)
} else {
scores
};
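        // Normalize the scores into attention weights; `quiet_softmax` adds one to the
        // softmax denominator so a head can attend to nothing.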
let weights = if self.quiet_softmax {
quiet_softmax(scores, 3)
} else {
softmax(scores, 3)
};
let weights = self.dropout.forward(weights);
let output = weights.matmul(v);
let output = output
.swap_dims(1, 2)
.reshape([batch, l_q, self.n_heads * self.d_head]);
self.output.forward(output)
}
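
    /// Applies the forward pass, caching the key and value projections of the context.
    ///
    /// Intended for repeated queries against a fixed context (e.g. autoregressive
    /// decoding): the context projections are stored in `cache` and reused on later
    /// calls, so the same cache must not be used with a different context.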
pub fn forward_cache(
&self,
query: Tensor<B, 3>,
context: Tensor<B, 3>,
mask: Option<Tensor<B, 2, Bool>>,
cache: &mut CrossAttentionCache<B>,
) -> Tensor<B, 3> {
let [batch, l_q, _] = query.dims();
let q = self.query.forward(query);
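        // Head-split key/value projections of the context, obtained through the cache
        // so they are not recomputed on every call.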
let k_compute = |context: Tensor<B, 3>| {
let [batch, l_k, _] = context.dims();
self.key
.forward(context)
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
.swap_dims(1, 2)
};
let v_compute = |context: Tensor<B, 3>| {
let [batch, l_k, _] = context.dims();
self.value
.forward(context)
.reshape([batch, l_k, self.n_heads_kv, self.d_head])
.swap_dims(1, 2)
};
let k = cache.k.forward_full(context.clone(), k_compute);
let v = cache.v.forward_full(context, v_compute);
let [_, _, l_k, _] = k.dims();
let q = q
.reshape([batch, l_q, self.n_heads, self.d_head])
.swap_dims(1, 2);
let (k, v) = if self.n_heads != self.n_heads_kv {
let n_rep = self.n_heads / self.n_heads_kv;
(self.repeat_kv(k, n_rep), self.repeat_kv(v, n_rep))
} else {
(k, v)
};
let scores = q.matmul(k.transpose()) * self.scale;
let scores = if let Some(mask) = mask {
let mask = mask.reshape([batch, 1, 1, l_k]);
scores.mask_fill(mask, self.min_float)
} else {
scores
};
let weights = if self.quiet_softmax {
quiet_softmax(scores, 3)
} else {
softmax(scores, 3)
};
let weights = self.dropout.forward(weights);
let output = weights.matmul(v);
let output = output
.swap_dims(1, 2)
.reshape([batch, l_q, self.n_heads * self.d_head]);
self.output.forward(output)
}
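
    /// Repeats each key/value head `n_rep` times along the head dimension:
    /// `[batch, n_heads_kv, seq_len, d_head]` -> `[batch, n_heads_kv * n_rep, seq_len, d_head]`.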
fn repeat_kv(&self, x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {
let [b, h, l, d] = x.dims();
x.reshape([b, h, 1, l, d])
.expand([b, h, n_rep, l, d])
.reshape([b, h * n_rep, l, d])
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
use burn::tensor::{Distribution, Int, Shape, Tensor, Tolerance};
#[test]
fn test_cross_attention_mha_shapes() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
d_head,
] = [7, 13, 15, 32, 40, 4, 8];
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
            n_heads_kv: n_heads,
            d_head,
dropout: 0.1,
min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
let query = Tensor::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context = Tensor::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
let output = cross_attn.forward(query, context, None);
assert_eq!(
output.shape(),
Shape::new([batch_size, seq_len_query, d_model]),
"Output should have the correct shape",
);
}
#[test]
fn test_cross_attention_gqa_shapes() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
n_heads_kv,
d_head,
] = [7, 13, 15, 32, 40, 4, 2, 8];
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
            n_heads_kv,
            d_head,
dropout: 0.1,
min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
let query = Tensor::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context = Tensor::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
let output = cross_attn.forward(query, context, None);
assert_eq!(
output.shape(),
Shape::new([batch_size, seq_len_query, d_model]),
"Output should have the correct shape",
);
}
#[test]
fn test_cross_attention_mqa_shapes() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
d_head,
] = [7, 13, 15, 32, 40, 4, 8];
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
            n_heads_kv: 1,
            d_head,
dropout: 0.1,
min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
let query = Tensor::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context = Tensor::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
let output = cross_attn.forward(query, context, None);
assert_eq!(
output.shape(),
Shape::new([batch_size, seq_len_query, d_model]),
"Output should have the correct shape",
);
}
#[test]
fn test_cross_attention_mask() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
d_head,
] = [3, 6, 8, 12, 16, 4, 3];
let num_padded = 2;
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
n_heads_kv: n_heads,
d_head,
            dropout: 0.0,
            min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
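        // Build a padding mask that marks the last `num_padded` context positions as `true`.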
let mut mask: Tensor<TestBackend, 2, Int> =
Tensor::zeros([batch_size, seq_len_context], &device);
mask = mask.slice_assign(
[0..batch_size, seq_len_context - num_padded..seq_len_context],
Tensor::ones([batch_size, num_padded], &device),
);
let mask_bool = mask.equal_elem(1);
let query = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context_1 = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
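        // `context_2` differs from `context_1` only in the masked (padded) positions,
        // so the attention outputs should be identical.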
let context_2 = context_1.clone().slice_assign(
[
0..batch_size,
seq_len_context - num_padded..seq_len_context,
0..d_context,
],
Tensor::random(
[batch_size, num_padded, d_context],
Distribution::Default,
&device,
),
);
let output_1 = cross_attn.forward(query.clone(), context_1, Some(mask_bool.clone()));
let output_2 = cross_attn.forward(query, context_2, Some(mask_bool));
output_1
.into_data()
.assert_approx_eq(&output_2.into_data(), Tolerance::<f32>::default());
}
#[test]
#[should_panic]
fn test_gqa_panic_if_n_heads_not_divisible_by_n_heads_kv() {
let device = Default::default();
let config = CrossAttentionConfig {
d_model: 32,
d_context: 32,
n_heads: 5,
n_heads_kv: 2,
d_head: 8,
dropout: 0.1,
min_float: -1.0e4,
quiet_softmax: false,
};
config.init::<TestBackend>(&device);
}
#[test]
fn test_cross_attention_cache() {
let [
batch_size,
seq_len_query,
seq_len_context,
d_model,
d_context,
n_heads,
d_head,
] = [3, 6, 8, 12, 16, 4, 3];
let device = Default::default();
let config = CrossAttentionConfig {
d_model,
d_context,
n_heads,
n_heads_kv: n_heads,
d_head,
            dropout: 0.0,
            min_float: -1.0e4,
quiet_softmax: false,
};
let cross_attn = config.init::<TestBackend>(&device);
let query1 = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let context = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_context, d_context],
Distribution::Default,
&device,
);
let output1 = cross_attn.forward(query1.clone(), context.clone(), None);
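        // The cached forward pass must match the regular forward pass, both when the
        // cache is empty and after it has been populated.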
let mut cache = CrossAttentionCache::new();
let output2 = cross_attn.forward_cache(query1.clone(), context.clone(), None, &mut cache);
output1
.into_data()
.assert_approx_eq(&output2.into_data(), Tolerance::<f32>::default());
let query2 = Tensor::<TestBackend, 3>::random(
[batch_size, seq_len_query, d_model],
Distribution::Default,
&device,
);
let output3 = cross_attn.forward_cache(query2.clone(), context.clone(), None, &mut cache);
let output4 = cross_attn.forward(query2.clone(), context.clone(), None);
output3
.into_data()
.assert_approx_eq(&output4.into_data(), Tolerance::<f32>::default());
}
}