use burn::prelude::*;
use burn::nn::Linear;
use burn::tensor::activation::softmax;
use crate::model::rope::apply_rope;
use crate::model::linear_zeros;
/// Multi-head attention module: four linear projections plus the head
/// layout needed to split and merge their outputs.
///
/// The query and key/value head counts are stored separately; `new` sizes
/// `wk`/`wv` from `n_kv_heads` and `wq`/`wo` from `n_heads`.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
/// Query projection, built as `dim -> n_heads * head_dim`.
pub wq: Linear<B>,
/// Key projection, built as `dim -> n_kv_heads * head_dim`.
pub wk: Linear<B>,
/// Value projection, built as `dim -> n_kv_heads * head_dim`.
pub wv: Linear<B>,
/// Output projection, built as `n_heads * head_dim -> dim`.
pub wo: Linear<B>,
/// Number of query heads.
pub n_heads: usize,
/// Number of key/value heads used to size `wk` and `wv`.
pub n_kv_heads: usize,
/// Feature width of a single head.
pub head_dim: usize,
}
impl<B: Backend> Attention<B> {
pub fn new(
dim: usize,
head_dim: usize,
n_heads: usize,
n_kv_heads: usize,
device: &B::Device,
) -> Self {
let z = |i, o| linear_zeros(i, o, false, device);
Self {
wq: z(dim, n_heads * head_dim),
wk: z(dim, n_kv_heads * head_dim),
wv: z(dim, n_kv_heads * head_dim),
wo: z(n_heads * head_dim, dim),
n_heads, n_kv_heads, head_dim,
}
}
pub fn forward(
&self,
x: Tensor<B, 3>,
freqs_4d: Tensor<B, 4>,
) -> Tensor<B, 3> {
let [b, s, _] = x.dims();
let (h, dh) = (self.n_heads, self.head_dim);
let xq = self.wq.forward(x.clone()).reshape([b, s, h, dh]);
let xk = self.wk.forward(x.clone()).reshape([b, s, h, dh]);
let xv = self.wv.forward(x).reshape([b, s, h, dh]);
let (xq, xk) = apply_rope(xq, xk, freqs_4d);
let xq = xq.swap_dims(1, 2); let xk = xk.swap_dims(1, 2);
let xv = xv.swap_dims(1, 2);
let scale = (dh as f64).powf(-0.5) as f32;
let attn = softmax(xq.matmul(xk.transpose()).mul_scalar(scale), 3);
let out = attn.matmul(xv);
self.wo.forward(out.swap_dims(1, 2).reshape([b, s, h * dh]))
}
}