use burn::prelude::*;
use burn::module::{Param, ParamId};
use burn::nn;
use super::rope;
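/// Multi-head attention with grouped-query key/value heads and learned
/// attention "sinks": one extra logit per head that competes in the softmax
/// but contributes nothing to the output, letting a head attend to nothing.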
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    pub q_proj: nn::Linear<B>,
    pub k_proj: nn::Linear<B>,
    pub v_proj: nn::Linear<B>,
    pub o_proj: nn::Linear<B>,
    pub sinks: Param<Tensor<B, 1>>,
    pub num_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub num_kv_groups: usize,
    pub scaling: f32,
}
impl<B: Backend> Attention<B> {
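    /// Builds the projection layers and the per-head sink parameters.
    ///
    /// `num_heads` is expected to be a multiple of `num_kv_heads`; each group
    /// of `num_heads / num_kv_heads` query heads shares one key/value head.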
    pub fn new(
        hidden_size: usize,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        bias: bool,
        device: &B::Device,
    ) -> Self {
        let q_out = num_heads * head_dim;
        let kv_out = num_kv_heads * head_dim;
        let q_proj = nn::LinearConfig::new(hidden_size, q_out)
            .with_bias(bias)
            .init(device);
        let k_proj = nn::LinearConfig::new(hidden_size, kv_out)
            .with_bias(bias)
            .init(device);
        let v_proj = nn::LinearConfig::new(hidden_size, kv_out)
            .with_bias(bias)
            .init(device);
        let o_proj = nn::LinearConfig::new(q_out, hidden_size)
            .with_bias(bias)
            .init(device);
        // One learned sink logit per attention head, initialized to zero.
        let sinks_tensor = Tensor::zeros([num_heads], device);
        // head_dim^(-1/4) is applied to both q and k, so their product is
        // scaled by the usual 1/sqrt(head_dim).
        let scaling = (head_dim as f32).powf(-0.25);
        Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            sinks: Param::initialized(ParamId::new(), sinks_tensor),
            num_heads,
            num_kv_heads,
            head_dim,
            num_kv_groups: num_heads / num_kv_heads,
            scaling,
        }
    }
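    /// Runs attention over `hidden_states` of shape `[batch, seq_len, hidden]`.
    ///
    /// `cos`/`sin` are the rotary embedding tables, and `attention_mask` is an
    /// additive mask broadcastable over `[batch, num_heads, seq_len, seq_len]`
    /// (for example the output of `create_sliding_window_mask`).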
    pub fn forward(
        &self,
        hidden_states: Tensor<B, 3>,
        cos: &Tensor<B, 3>,
        sin: &Tensor<B, 3>,
        attention_mask: &Tensor<B, 4>,
    ) -> Tensor<B, 3> {
        let [batch, seq_len, _] = hidden_states.dims();
        let q = self.q_proj.forward(hidden_states.clone());
        let k = self.k_proj.forward(hidden_states.clone());
        let v = self.v_proj.forward(hidden_states);
        // Split the projections into heads: [batch, heads, seq_len, head_dim].
        let q = q
            .reshape([batch, seq_len, self.num_heads, self.head_dim])
            .swap_dims(1, 2);
        let k = k
            .reshape([batch, seq_len, self.num_kv_heads, self.head_dim])
            .swap_dims(1, 2);
        let v = v
            .reshape([batch, seq_len, self.num_kv_heads, self.head_dim])
            .swap_dims(1, 2);
        let (q, k) = rope::apply_rotary_emb(q, k, cos, sin);
        let q = q.mul_scalar(self.scaling);
        let k = k.mul_scalar(self.scaling);
        // Expand the KV heads so each query head has a matching key/value head.
        let k = repeat_kv(k, self.num_kv_groups);
        let v = repeat_kv(v, self.num_kv_groups);
        let attn_weights = q.matmul(k.swap_dims(2, 3));
        let attn_weights = attn_weights + attention_mask.clone();
        // Append the learned sink logits as one extra column per row of logits.
        let sinks = self
            .sinks
            .val()
            .reshape([1, self.num_heads, 1, 1])
            .expand([batch, self.num_heads, seq_len, 1]);
        let combined = Tensor::cat(vec![attn_weights, sinks], 3);
        // Subtract the row maximum for numerical stability before the softmax.
        let max_vals = combined.clone().max_dim(3);
        let combined = combined - max_vals;
        let probs = burn::tensor::activation::softmax(combined, 3);
        // Drop the sink column: it absorbs probability mass but adds no value.
        let scores = probs.slice([0..batch, 0..self.num_heads, 0..seq_len, 0..seq_len]);
        let attn_output = scores.matmul(v);
        let attn_output = attn_output
            .swap_dims(1, 2)
            .reshape([batch, seq_len, self.num_heads * self.head_dim]);
        self.o_proj.forward(attn_output)
    }
}
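/// Repeats each key/value head `n_rep` times along the head dimension so the
/// grouped KV heads line up one-to-one with the query heads.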
fn repeat_kv<B: Backend>(x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {
    if n_rep == 1 {
        return x;
    }
    let [batch, kv_heads, seq_len, head_dim] = x.dims();
    // Insert a repeat axis after kv_heads, broadcast it to n_rep, then fold it
    // into the head dimension.
    let x = x.unsqueeze_dim::<5>(2);
    let x = x.expand([batch, kv_heads, n_rep, seq_len, head_dim]);
    x.reshape([batch, kv_heads * n_rep, seq_len, head_dim])
}
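/// Builds an additive attention mask of shape `[1, 1, seq_len, seq_len]` that
/// assigns a large negative value to pairs of positions more than `window_size`
/// apart. The band is symmetric (future positions are not masked), so a causal
/// mask should be combined with it for autoregressive decoding.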
pub fn create_sliding_window_mask<B: Backend>(
    seq_len: usize,
    window_size: usize,
    device: &B::Device,
) -> Tensor<B, 4> {
    let n = seq_len;
    let mut mask_data = vec![0f32; n * n];
    // Large negative value used as "minus infinity" for masked positions.
    let neg_inf: f32 = -1e9;
    for i in 0..n {
        for j in 0..n {
            let dist = if i > j { i - j } else { j - i };
            if dist > window_size {
                mask_data[i * n + j] = neg_inf;
            }
        }
    }
    Tensor::<B, 4>::from_data(TensorData::new(mask_data, [1, 1, n, n]), device)
}