use burn::{
module::Module,
nn::{Linear, LinearConfig},
tensor::{Tensor, activation::softmax, backend::Backend},
};
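/// Multi-head attention with learned query, key, value, and output
/// projections. Supports both self-attention and cross-attention over an
/// optional context tensor.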
#[derive(Module, Debug)]
pub struct MultiHeadAttention<B: Backend> {
n_heads: usize,
d_head: usize,
q_proj: Linear<B>,
k_proj: Linear<B>,
v_proj: Linear<B>,
out_proj: Linear<B>,
}
impl<B: Backend> MultiHeadAttention<B> {
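    /// Builds the attention layer. `d_model` must be divisible by `n_heads`;
    /// each head works in `d_model / n_heads` dimensions.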
pub fn new(d_model: usize, n_heads: usize, device: &B::Device) -> Self {
assert!(
d_model % n_heads == 0,
"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
);
let d_head = d_model / n_heads;
let query = LinearConfig::new(d_model, d_model).init(device);
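        // The key projection is configured without a bias term; the query,
        // value, and output projections keep theirs.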
let key = LinearConfig::new(d_model, d_model)
.with_bias(false)
.init(device);
let value = LinearConfig::new(d_model, d_model).init(device);
let out = LinearConfig::new(d_model, d_model).init(device);
Self {
n_heads,
d_head,
q_proj: query,
k_proj: key,
v_proj: value,
out_proj: out,
}
}
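    /// Runs attention over `x` of shape `[batch, seq_len, d_model]`.
    /// With `context`, keys and values are computed from it (cross-attention);
    /// without it, they come from `x` (self-attention). `mask`, if given, is
    /// added to the attention scores (broadcast over batch and heads) before
    /// the softmax.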
pub fn forward(
&self,
x: Tensor<B, 3>,
context: Option<Tensor<B, 3>>,
mask: Option<Tensor<B, 2>>,
) -> Tensor<B, 3> {
let dims = x.shape().dims;
let batch = dims[0];
let seq_len = dims[1];
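        // Queries always come from `x`; keys and values come from `context`
        // when it is provided, otherwise from `x` itself.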
let q = self.q_proj.forward(x.clone());
let kv_src = context.unwrap_or(x);
let k = self.k_proj.forward(kv_src.clone());
let v = self.v_proj.forward(kv_src);
let ctx_len = k.shape().dims[1];
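        // Split the model dimension into heads:
        // [batch, len, d_model] -> [batch, n_heads, len, d_head].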
let q = q
.reshape([batch, seq_len, self.n_heads, self.d_head])
.swap_dims(1, 2);
let k = k
.reshape([batch, ctx_len, self.n_heads, self.d_head])
.swap_dims(1, 2);
let v = v
.reshape([batch, ctx_len, self.n_heads, self.d_head])
.swap_dims(1, 2);
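        // Scaled dot-product scores: [batch, n_heads, seq_len, ctx_len].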
let scale = 1.0 / (self.d_head as f64).sqrt();
let scores = q.matmul(k.transpose()) * scale;
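        // Add the optional mask, broadcast over the batch and head dimensions.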
let scores = if let Some(m) = mask {
scores + m.unsqueeze::<4>()
} else {
scores
};
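        // Softmax over the context dimension, then weight the values.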
let weights = softmax(scores, 3);
let attended = weights.matmul(v);
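        // Merge the heads back into a single model dimension:
        // [batch, n_heads, seq_len, d_head] -> [batch, seq_len, d_model].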
let output = attended
.swap_dims(1, 2)
.reshape([batch, seq_len, self.n_heads * self.d_head]);
self.out_proj.forward(output)
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
type TestBackend = NdArray;
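    // Self-attention: the output shape matches the input shape.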
#[test]
fn test_self_attention_shape() {
let device = Default::default();
let attn = MultiHeadAttention::<TestBackend>::new(384, 6, &device);
let x = Tensor::<TestBackend, 3>::random(
[2, 10, 384],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let output = attn.forward(x, None, None);
assert_eq!(output.shape().dims, [2, 10, 384]);
}
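    // Cross-attention: the output keeps the query sequence length (5),
    // not the context length (20).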
#[test]
fn test_cross_attention_shape() {
let device = Default::default();
let attn = MultiHeadAttention::<TestBackend>::new(384, 6, &device);
let x = Tensor::<TestBackend, 3>::random(
[2, 5, 384],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let context = Tensor::<TestBackend, 3>::random(
[2, 20, 384],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let output = attn.forward(x, Some(context), None);
assert_eq!(output.shape().dims, [2, 5, 384]);
}
}