use burn::prelude::*;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation::softmax;
use crate::model::rms_norm::RmsNorm;

/// Multi-head self-attention with pre-normalization (RMSNorm) and a fused
/// QKV projection.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    pub norm: RmsNorm<B>,
    pub to_qkv: Linear<B>,
    pub to_out: Linear<B>,
    pub n_heads: usize,
    pub head_dim: usize,
}

impl<B: Backend> Attention<B> {
    /// Builds the attention block. The fused QKV projection maps `dim` to
    /// `3 * heads * head_dim`; the output projection maps back to `dim`.
    pub fn new(dim: usize, heads: usize, head_dim: usize, device: &B::Device) -> Self {
        let inner_dim = head_dim * heads;
        Self {
            norm: RmsNorm::new(dim, 1e-6, device),
            to_qkv: LinearConfig::new(dim, inner_dim * 3).with_bias(false).init(device),
            to_out: LinearConfig::new(inner_dim, dim).with_bias(false).init(device),
            n_heads: heads,
            head_dim,
        }
    }

    /// Applies scaled dot-product attention to `x` of shape
    /// `[batch, seq, dim]`, returning a tensor of the same shape.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [b, s, _] = x.dims();
        let (h, dh) = (self.n_heads, self.head_dim);

        // Pre-norm, then project to the fused QKV tensor: [b, s, 3 * h * dh].
        let normed = self.norm.forward(x);
        let qkv = self.to_qkv.forward(normed);

        // Split the fused projection into Q, K, V along the feature axis.
        let inner = h * dh;
        let q = qkv.clone().narrow(2, 0, inner);
        let k = qkv.clone().narrow(2, inner, inner);
        let v = qkv.narrow(2, inner * 2, inner);

        // [b, s, h * dh] -> [b, h, s, dh] so each head attends independently.
        let q = q.reshape([b, s, h, dh]).swap_dims(1, 2);
        let k = k.reshape([b, s, h, dh]).swap_dims(1, 2);
        let v = v.reshape([b, s, h, dh]).swap_dims(1, 2);

        // Attention weights: softmax(Q K^T / sqrt(dh)) over the key axis.
        let scale = (dh as f64).powf(-0.5) as f32;
        let attn = softmax(q.matmul(k.transpose()).mul_scalar(scale), 3);
        let out = attn.matmul(v);

        // Merge heads back: [b, h, s, dh] -> [b, s, h * dh], then project out.
        let out = out.swap_dims(1, 2).reshape([b, s, h * dh]);
        self.to_out.forward(out)
    }
}
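
// A minimal shape check for the forward pass. This is a sketch, not part of
// the original module: it assumes the crate builds with Burn's `ndarray`
// feature enabled; substitute whichever backend this project actually uses.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_preserves_shape() {
        // Assumed backend; swap for the project's real one if different.
        type B = burn::backend::NdArray;
        let device = Default::default();

        // dim = 64, heads = 4, head_dim = 16, so inner_dim = 64 and the
        // output projection maps back to dim.
        let attn = Attention::<B>::new(64, 4, 16, &device);
        let x = Tensor::<B, 3>::zeros([2, 8, 64], &device);
        let out = attn.forward(x);

        // Attention preserves the [batch, seq, dim] shape.
        assert_eq!(out.dims(), [2, 8, 64]);
    }
}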