use burn::prelude::*;
use burn::nn::Linear;
use burn::tensor::activation::softmax;
use crate::model::rope::apply_rope;
use crate::model::linear_zeros;
/// Multi-head cross-attention layer: queries are projected from one input
/// sequence, keys and values from another.
///
/// All four projections are created by `new` through `linear_zeros`
/// (presumably zero-initialised, with the `false` argument presumably
/// disabling the bias — confirm against `crate::model::linear_zeros`).
#[derive(Module, Debug)]
pub struct CrossAttention<B: Backend> {
/// Query projection; `new` sizes it as `dim -> n_heads * head_dim`.
pub wq: Linear<B>,
/// Key projection; `new` sizes it as `dim -> n_kv_heads * head_dim`.
pub wk: Linear<B>,
/// Value projection; `new` sizes it as `dim -> n_kv_heads * head_dim`.
pub wv: Linear<B>,
/// Output projection; `new` sizes it as `n_heads * head_dim -> dim`.
pub wo: Linear<B>,
/// Number of query heads.
pub n_heads: usize,
/// Number of key/value heads; determines the output width of `wk` and `wv`.
pub n_kv_heads: usize,
/// Width of a single attention head.
pub head_dim: usize,
}
impl<B: Backend> CrossAttention<B> {
    /// Creates a cross-attention layer.
    ///
    /// * `dim` - width of the input and output features.
    /// * `head_dim` - width of a single attention head.
    /// * `n_heads` - number of query heads.
    /// * `n_kv_heads` - number of key/value heads. [`Self::forward`] requires
    ///   this to be non-zero and to divide `n_heads`.
    /// * `device` - device on which the projection weights are allocated.
    ///
    /// Every projection is built with `linear_zeros(_, _, false, device)`
    /// (presumably zero-initialised and bias-free — confirm against
    /// `crate::model::linear_zeros`).
    pub fn new(
        dim: usize,
        head_dim: usize,
        n_heads: usize,
        n_kv_heads: usize,
        device: &B::Device,
    ) -> Self {
        let z = |i, o| linear_zeros(i, o, false, device);
        Self {
            wq: z(dim, n_heads * head_dim),
            wk: z(dim, n_kv_heads * head_dim),
            wv: z(dim, n_kv_heads * head_dim),
            wo: z(n_heads * head_dim, dim),
            n_heads,
            n_kv_heads,
            head_dim,
        }
    }

    /// Attends from `xq` (queries) over `xkv` (keys and values).
    ///
    /// * `xq` - query-side input, `[batch, s_q, dim]`.
    /// * `xkv` - key/value-side input, `[batch, s_kv, dim]`; its batch size is
    ///   assumed to equal that of `xq`.
    /// * `freqs_q`, `freqs_kv` - rotary frequencies handed to `apply_rope` for
    ///   the query and key sequences respectively. Their layout is defined by
    ///   `apply_rope`, which is not visible here.
    ///
    /// Returns a tensor of shape `[batch, s_q, dim]`.
    ///
    /// # Panics
    ///
    /// Panics if `n_kv_heads` is zero or does not divide `n_heads`.
    pub fn forward(
        &self,
        xq: Tensor<B, 3>,
        xkv: Tensor<B, 3>,
        freqs_q: Tensor<B, 4>,
        freqs_kv: Tensor<B, 4>,
    ) -> Tensor<B, 3> {
        let [b, s_q, _] = xq.dims();
        let [_, s_kv, _] = xkv.dims();
        let (h, h_kv, dh) = (self.n_heads, self.n_kv_heads, self.head_dim);
        assert!(
            h_kv > 0 && h % h_kv == 0,
            "n_kv_heads ({h_kv}) must be non-zero and divide n_heads ({h})"
        );
        let group = h / h_kv;
        let device = xq.device();

        let q = self.wq.forward(xq).reshape([b, s_q, h, dh]);
        // `wk`/`wv` emit `n_kv_heads * head_dim` features, so the split into
        // heads must use `n_kv_heads`; reshaping with `n_heads` only works
        // when both counts happen to be equal.
        let k = self.wk.forward(xkv.clone()).reshape([b, s_kv, h_kv, dh]);
        let v = self.wv.forward(xkv).reshape([b, s_kv, h_kv, dh]);

        // Replicate KV heads up to `n_heads` *before* RoPE so that
        // `apply_rope` keeps receiving `[b, s_kv, n_heads, head_dim]`.
        let k = Self::expand_kv_heads(k, group);
        let v = Self::expand_kv_heads(v, group);

        // `apply_rope` takes a (q, k) pair and a single frequency tensor.
        // The query and key sequences have different lengths here, so each is
        // rotated in its own call with a zero placeholder in the other slot;
        // the placeholder's result is discarded.
        let (q_rot, _) = apply_rope(
            q,
            Tensor::zeros([b, s_q, h, dh], &device),
            freqs_q,
        );
        let (_, k_rot) = apply_rope(
            Tensor::zeros([b, s_kv, h, dh], &device),
            k,
            freqs_kv,
        );

        // Move heads in front of the sequence axis: [b, h, seq, dh].
        let q_t = q_rot.swap_dims(1, 2);
        let k_t = k_rot.swap_dims(1, 2);
        let v_t = v.swap_dims(1, 2);

        // Scaled dot-product attention; softmax runs over the key axis (dim 3).
        let scale = (dh as f64).powf(-0.5) as f32;
        let attn = softmax(q_t.matmul(k_t.transpose()).mul_scalar(scale), 3);
        let out = attn.matmul(v_t);

        // Back to [b, s_q, h * dh], then project to `dim`.
        self.wo.forward(out.swap_dims(1, 2).reshape([b, s_q, h * dh]))
    }

    /// Repeats every key/value head `group` times along the head axis so a
    /// `[b, s, n_kv_heads, dh]` tensor becomes `[b, s, n_kv_heads * group, dh]`.
    ///
    /// Copies of one KV head are adjacent, so query head `i` is paired with
    /// KV head `i / group`.
    fn expand_kv_heads(x: Tensor<B, 4>, group: usize) -> Tensor<B, 4> {
        if group == 1 {
            return x;
        }
        let [b, s, h_kv, dh] = x.dims();
        let grouped: Tensor<B, 5> = x.reshape([b, s, h_kv, 1, dh]);
        Tensor::cat(vec![grouped; group], 3).reshape([b, s, h_kv * group, dh])
    }
}