use burn::nn::Linear;
use burn::prelude::*;
use burn::tensor::activation::softmax;
use crate::model::linear_zeros;
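
/// Queries are processed in tiles of this many positions, bounding the live
/// attention-score matrix to [batch, heads, ATTN_TILE, seq] instead of the
/// full [batch, heads, seq, seq].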
const ATTN_TILE: usize = 1024;
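
/// Multi-head self-attention with a fused QKV projection, evaluated tile by
/// tile over the query positions to cap peak memory.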
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    pub qkv: Linear<B>,
    pub proj: Linear<B>,
    pub num_heads: usize,
    pub head_dim: usize,
    pub scale: f32,
}

impl<B: Backend> Attention<B> {
    pub fn new(dim: usize, num_heads: usize, qkv_bias: bool, device: &B::Device) -> Self {
        assert!(
            dim % num_heads == 0,
            "embedding dim {dim} must be divisible by num_heads {num_heads}"
        );
        let head_dim = dim / num_heads;
        // Standard 1/sqrt(head_dim) attention scaling, folded into Q in forward().
        let scale = (head_dim as f64).powf(-0.5) as f32;
        Self {
            qkv: linear_zeros(dim, dim * 3, qkv_bias, device),
            proj: linear_zeros(dim, dim, true, device),
            num_heads,
            head_dim,
            scale,
        }
    }
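
    /// `x` is `[batch, seq, dim]`; `attn_mask` is an optional `[batch, seq]`
    /// float mask over key positions (1.0 = attend, 0.0 = masked out).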
    pub fn forward(&self, x: Tensor<B, 3>, attn_mask: Option<&Tensor<B, 2>>) -> Tensor<B, 3> {
        let [b, n, c] = x.dims();
        let h = self.num_heads;
        let dh = self.head_dim;
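
        // Fused QKV projection, reshaped to [b, n, 3, h, dh] and split into
        // per-head Q, K, V of shape [b, n, h, dh].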
        let qkv = self.qkv.forward(x);
        let qkv = qkv.reshape([b, n, 3, h, dh]);
        let q = qkv.clone().narrow(2, 0, 1).reshape([b, n, h, dh]);
        let k = qkv.clone().narrow(2, 1, 1).reshape([b, n, h, dh]);
        let v = qkv.narrow(2, 2, 1).reshape([b, n, h, dh]);
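
        // Move heads ahead of the sequence dim ([b, h, n, dh]), fold the scale
        // into Q, and transpose K's last two dims for the score matmul
        // ([b, h, dh, n]).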
        let q = q.swap_dims(1, 2).mul_scalar(self.scale);
        let k = k.swap_dims(1, 2);
        let v = v.swap_dims(1, 2);
        let k_t = k.transpose();
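
        // Build the broadcastable key mask once, outside the tile loop: m is
        // [b, 1, 1, n], and neg_inf holds the -1e9 fill for masked positions.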
        let mask_4d = attn_mask.map(|mask| {
            let m = mask.clone().unsqueeze_dim::<3>(1).unsqueeze_dim::<4>(2);
            let neg_inf = Tensor::<B, 4>::full([b, 1, 1, n], -1e9, &m.device());
            (m, neg_inf)
        });
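
        // Process the queries in tiles of at most ATTN_TILE positions so only
        // a [b, h, tile_len, n] score matrix is live at any one time.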
        let mut tiles: Vec<Tensor<B, 4>> = Vec::with_capacity(n.div_ceil(ATTN_TILE));
        let mut offset = 0;
        while offset < n {
            let tile_len = (n - offset).min(ATTN_TILE);
            let q_tile = q.clone().narrow(2, offset, tile_len);
            let mut scores = q_tile.matmul(k_t.clone());
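            // Keep scores where m == 1 and substitute -1e9 where m == 0:
            // scores * m + neg_inf * (1 - m), broadcast over heads and rows.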
            if let Some((ref m, ref neg_inf)) = mask_4d {
                scores = scores * m.clone() + neg_inf.clone() * (m.clone().mul_scalar(-1.0) + 1.0);
            }
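            // Softmax over the key dim, then weight the values: [b, h, tile_len, dh].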
            tiles.push(softmax(scores, 3).matmul(v.clone()));
            offset += tile_len;
        }
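
        // Reassemble the tiles along the sequence dim, then merge the heads
        // back into the embedding dim.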
        let out = Tensor::cat(tiles, 2);
        let out = out.swap_dims(1, 2).reshape([b, n, c]);
        self.proj.forward(out)
    }
}
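
// A minimal shape check, sketched under the assumption that the crate enables
// Burn's `ndarray` backend feature; swap `NdArray` for whichever backend the
// crate actually builds with.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn forward_preserves_shape() {
        let device = Default::default();
        let attn = Attention::<NdArray>::new(64, 8, true, &device);
        let x = Tensor::<NdArray, 3>::zeros([2, 10, 64], &device);
        // An all-ones mask leaves every key position visible.
        let mask = Tensor::<NdArray, 2>::ones([2, 10], &device);
        let out = attn.forward(x, Some(&mask));
        assert_eq!(out.dims(), [2, 10, 64]);
    }
}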