use burn::prelude::*;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::softmax;

/// Multi-head self-attention with a fused QKV projection and an output projection.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    pub qkv: Linear<B>,
    pub proj: Linear<B>,
    pub n_heads: usize,
    pub head_dim: usize,
}

impl<B: Backend> Attention<B> {
    pub fn new(dim: usize, n_heads: usize, qkv_bias: bool, device: &B::Device) -> Self {
        debug_assert_eq!(dim % n_heads, 0, "dim must be divisible by n_heads");
        let head_dim = dim / n_heads;
        Self {
            // One fused projection producing Q, K and V concatenated along the feature axis.
            qkv: LinearConfig::new(dim, dim * 3).with_bias(qkv_bias).init(device),
            proj: LinearConfig::new(dim, dim).with_bias(true).init(device),
            n_heads,
            head_dim,
        }
    }

    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [b, s, _] = x.dims();
        let (h, dh) = (self.n_heads, self.head_dim);
        let dim = h * dh;

        // Fused projection: [b, s, dim] -> [b, s, 3 * dim].
        let qkv = self.qkv.forward(x);

        // Split into Q, K, V and reshape each to [b, h, s, dh].
        let q = qkv.clone().narrow(2, 0, dim).reshape([b, s, h, dh]).swap_dims(1, 2);
        let k = qkv.clone().narrow(2, dim, dim).reshape([b, s, h, dh]).swap_dims(1, 2);
        let v = qkv.narrow(2, dim * 2, dim).reshape([b, s, h, dh]).swap_dims(1, 2);

        // Scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V.
        let scale = (dh as f64).powf(-0.5) as f32;
        let attn = softmax(q.matmul(k.transpose()).mul_scalar(scale), 3);
        let out = attn.matmul(v);

        // Merge the heads back to [b, s, dim] and apply the output projection.
        let out = out.swap_dims(1, 2).reshape([b, s, dim]);
        self.proj.forward(out)
    }
}
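
// A minimal usage sketch, assuming burn's `ndarray` feature is enabled so that
// `burn::backend::NdArray` is available; the backend choice, dimensions, and test name
// here are illustrative only. It checks that the output shape matches the input shape.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Distribution;

    #[test]
    fn forward_preserves_shape() {
        let device = Default::default();
        // dim = 64 split across 8 heads of size 8.
        let attn = Attention::<NdArray>::new(64, 8, true, &device);
        let x = Tensor::<NdArray, 3>::random([2, 16, 64], Distribution::Default, &device);
        let out = attn.forward(x);
        assert_eq!(out.dims(), [2, 16, 64]);
    }
}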