use burn::prelude::*;
use burn::module::Param;
use burn::nn::Linear;
use super::linear_zeros;
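
/// Multi-head self-attention whose fused QKV projection carries separate
/// learnable Q and V biases, while the K bias is fixed at zero.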
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    /// Fused projection producing Q, K and V in a single matmul.
    pub qkv: Linear<B>,
    /// Learnable bias for the Q slice of the fused projection.
    pub q_bias: Param<Tensor<B, 1>>,
    /// Learnable bias for the V slice; the K slice stays bias-free.
    pub v_bias: Param<Tensor<B, 1>>,
    /// Output projection back to the model dimension.
    pub proj: Linear<B>,
    pub num_heads: usize,
    pub head_dim: usize,
    /// Softmax temperature: `1 / sqrt(head_dim)`.
    pub scale: f32,
}
impl<B: Backend> Attention<B> {
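    /// Builds the attention block; `head_dim` is `dim / num_heads` and the
    /// softmax scale is `1 / sqrt(head_dim)`.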
    pub fn new(dim: usize, num_heads: usize, device: &B::Device) -> Self {
        let head_dim = dim / num_heads;
        let all_head_dim = head_dim * num_heads;
        Self {
            // The QKV projection is created without a built-in bias; Q and V
            // biases live as separate parameters and the K bias is implicitly zero.
            qkv: linear_zeros::<B>(dim, all_head_dim * 3, false, device),
            q_bias: Param::initialized(
                burn::module::ParamId::new(),
                Tensor::zeros([all_head_dim], device),
            ),
            v_bias: Param::initialized(
                burn::module::ParamId::new(),
                Tensor::zeros([all_head_dim], device),
            ),
            proj: linear_zeros::<B>(all_head_dim, dim, true, device),
            num_heads,
            head_dim,
            scale: (head_dim as f32).powf(-0.5),
        }
    }
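    /// Folds the separate Q/V biases (plus an all-zero K bias) into a single
    /// bias on the fused QKV projection, so `forward` can skip the manual add.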
    pub fn fuse_qkv_bias(&mut self) {
        let dim = self.num_heads * self.head_dim;
        let device = self.q_bias.val().device();
        // The K slice never learns a bias, so it is padded with zeros.
        let k_bias = Tensor::<B, 1>::zeros([dim], &device);
        let fused = Tensor::cat(vec![self.q_bias.val(), k_bias, self.v_bias.val()], 0);
        self.qkv.bias = Some(Param::initialized(burn::module::ParamId::new(), fused));
    }
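    /// Scaled dot-product attention over `x` of shape `[batch, tokens, dim]`;
    /// the output has the same shape.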
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [b, n, _] = x.dims();
        let h = self.num_heads;
        let d = self.head_dim;
        // Project to Q, K, V. If the biases have been fused into `qkv`, the
        // linear layer already applies them; otherwise add them manually,
        // using a zero bias for the K slice.
        let qkv = if self.qkv.bias.is_some() {
            self.qkv.forward(x)
        } else {
            let device = x.device();
            let k_bias = Tensor::<B, 1>::zeros([h * d], &device);
            let bias = Tensor::cat(vec![self.q_bias.val(), k_bias, self.v_bias.val()], 0);
            self.qkv.forward(x) + bias.unsqueeze::<3>()
        };
        // [b, n, 3 * h * d] -> [3, b, h, n, d], then split into Q, K, V.
        let qkv = qkv.reshape([b, n, 3, h, d]).permute([2, 0, 3, 1, 4]);
        let q = qkv.clone().narrow(0, 0, 1).reshape([b, h, n, d]);
        let k = qkv.clone().narrow(0, 1, 1).reshape([b, h, n, d]);
        let v = qkv.narrow(0, 2, 1).reshape([b, h, n, d]);
        // softmax(Q K^T / sqrt(d)) over the key dimension.
        let attn = burn::tensor::activation::softmax(
            (q * self.scale).matmul(k.transpose()),
            3,
        );
        // Merge heads and project back to the model dimension.
        self.proj.forward(attn.matmul(v).permute([0, 2, 1, 3]).reshape([b, n, h * d]))
    }
}
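
// Minimal usage sketch (not part of the original source): it assumes the crate
// enables Burn's `ndarray` feature and that `linear_zeros` returns a
// zero-initialized `Linear`; the dimensions below are arbitrary.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_preserves_token_shape() {
        let device = Default::default();
        let mut attn = Attention::<NdArray>::new(64, 8, &device);
        // Optionally fold the separate Q/V biases into the fused QKV bias.
        attn.fuse_qkv_bias();
        let x = Tensor::<NdArray, 3>::zeros([2, 16, 64], &device);
        let y = attn.forward(x);
        assert_eq!(y.dims(), [2, 16, 64]);
    }
}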