#![allow(dead_code)]
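//! Causal multi-head self-attention for decoder-style transformers, built on
//! Burn's `nn` primitives: grouped-query KV heads, optional QK RMSNorm, and
//! partial rotary position embeddings (RoPE).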
use burn::nn::{Linear, LinearConfig, RmsNorm, RmsNormConfig, RotaryEncoding, RotaryEncodingConfig};
use burn::prelude::*;
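/// Configuration for [`FullAttention`].
///
/// `n_heads` query heads share `n_kv_heads` key/value heads (grouped-query
/// attention), and `partial_rotary_factor` sets the fraction of each head's
/// channels that receive rotary position embeddings.
///
/// A minimal usage sketch (assumes Burn's `ndarray` backend feature is
/// enabled; the argument order follows the `Config`-derived `new`):
///
/// ```ignore
/// use burn::backend::NdArray;
///
/// let device = Default::default();
/// // d_model, n_heads, n_kv_heads, head_dim, max_seq_len
/// let attn = FullAttentionConfig::new(64, 4, 2, 16, 128).init::<NdArray<f32>>(&device);
/// let x = Tensor::<NdArray<f32>, 3>::zeros([2, 8, 64], &device);
/// assert_eq!(attn.forward(x).dims(), [2, 8, 64]);
/// ```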
#[derive(Config, Debug)]
pub struct FullAttentionConfig {
pub d_model: usize,
pub n_heads: usize,
pub n_kv_heads: usize,
pub head_dim: usize,
pub max_seq_len: usize,
#[config(default = false)]
pub qk_norm: bool,
#[config(default = 1.0)]
pub partial_rotary_factor: f64,
#[config(default = 10_000.0)]
pub rope_theta: f32,
#[config(default = 1e-6)]
pub rms_norm_eps: f64,
}
impl FullAttentionConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> FullAttention<B> {
        assert_eq!(
            self.n_heads % self.n_kv_heads,
            0,
            "n_heads must be divisible by n_kv_heads"
        );
        let no_bias = |d_in, d_out| {
            LinearConfig::new(d_in, d_out)
                .with_bias(false)
                .init(device)
        };
let q_proj = no_bias(self.d_model, self.n_heads * self.head_dim);
let k_proj = no_bias(self.d_model, self.n_kv_heads * self.head_dim);
let v_proj = no_bias(self.d_model, self.n_kv_heads * self.head_dim);
let o_proj = no_bias(self.n_heads * self.head_dim, self.d_model);
        // Optional RMSNorm applied per head to queries and keys (QK-norm).
        let make_norm = || {
            self.qk_norm.then(|| {
                RmsNormConfig::new(self.head_dim)
                    .with_epsilon(self.rms_norm_eps)
                    .init(device)
            })
        };
        let q_norm = make_norm();
        let k_norm = make_norm();
        // Portion of each head that receives rotary embeddings. RoPE rotates
        // coordinate pairs, so the factor should yield an even rotary_dim.
        let rotary_dim =
            (self.head_dim as f64 * self.partial_rotary_factor).floor() as usize;
let rope = if rotary_dim > 0 {
Some(
RotaryEncodingConfig::new(self.max_seq_len, rotary_dim)
.with_theta(self.rope_theta)
.init(device),
)
} else {
None
};
FullAttention {
q_proj,
k_proj,
v_proj,
o_proj,
q_norm,
k_norm,
rope,
n_heads: self.n_heads,
n_kv_heads: self.n_kv_heads,
head_dim: self.head_dim,
rotary_dim,
}
}
}
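/// Causal multi-head self-attention with grouped-query KV sharing, optional
/// QK-norm, and partial RoPE. Construct via [`FullAttentionConfig::init`].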
#[derive(Module, Debug)]
pub struct FullAttention<B: Backend> {
pub(crate) q_proj: Linear<B>,
pub(crate) k_proj: Linear<B>,
pub(crate) v_proj: Linear<B>,
pub(crate) o_proj: Linear<B>,
pub(crate) q_norm: Option<RmsNorm<B>>,
pub(crate) k_norm: Option<RmsNorm<B>>,
pub(crate) rope: Option<RotaryEncoding<B>>,
pub(crate) n_heads: usize,
pub(crate) n_kv_heads: usize,
pub(crate) head_dim: usize,
pub(crate) rotary_dim: usize,
}
impl<B: Backend> FullAttention<B> {
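    /// Computes causal self-attention over `x` of shape
    /// `[batch, seq_len, d_model]`, returning a tensor of the same shape.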
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let [batch, seq_len, _] = x.dims();
let device = x.device();
let q = self.q_proj.forward(x.clone());
let k = self.k_proj.forward(x.clone());
let v = self.v_proj.forward(x);
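        // Split each projection into heads: [batch, heads, seq_len, head_dim].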
let q = q
.reshape([batch, seq_len, self.n_heads, self.head_dim])
.swap_dims(1, 2);
let k = k
.reshape([batch, seq_len, self.n_kv_heads, self.head_dim])
.swap_dims(1, 2);
let v = v
.reshape([batch, seq_len, self.n_kv_heads, self.head_dim])
.swap_dims(1, 2);
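        // Optional QK-norm, then rotary embeddings on the rotary slice of each head.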
let q = self.apply_optional_norm(&self.q_norm, q);
let k = self.apply_optional_norm(&self.k_norm, k);
let q = self.apply_partial_rope(q);
let k = self.apply_partial_rope(k);
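        // Grouped-query attention: tile KV heads to match the query head count.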
let kv_repeat = self.n_heads / self.n_kv_heads;
let k = if kv_repeat > 1 {
Self::repeat_kv(k, kv_repeat)
} else {
k
};
let v = if kv_repeat > 1 {
Self::repeat_kv(v, kv_repeat)
} else {
v
};
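        // Scaled dot-product attention with an additive causal mask.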
let scale = (self.head_dim as f64).sqrt();
let scores = q.matmul(k.transpose()) / scale;
let mask = Self::causal_mask(seq_len, &device);
let scores = scores + mask;
let attn = burn::tensor::activation::softmax(scores, 3);
let out = attn.matmul(v);
let out = out
.swap_dims(1, 2)
.reshape([batch, seq_len, self.n_heads * self.head_dim]);
self.o_proj.forward(out)
}
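    /// Applies `norm` over the last dimension of `x` when present; otherwise
    /// returns `x` unchanged.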
fn apply_optional_norm(
&self,
norm: &Option<RmsNorm<B>>,
x: Tensor<B, 4>,
) -> Tensor<B, 4> {
match norm {
Some(n) => n.forward(x),
None => x,
}
}
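    /// Applies RoPE to the first `rotary_dim` channels of each head and passes
    /// the remaining channels through unchanged.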
fn apply_partial_rope(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
let rope = match &self.rope {
Some(r) => r,
            None => return x,
        };
if self.rotary_dim == self.head_dim {
return rope.forward(x);
}
let rotary_part = x.clone().narrow(3, 0, self.rotary_dim);
let pass_through = x.narrow(3, self.rotary_dim, self.head_dim - self.rotary_dim);
let rotary_part = rope.forward(rotary_part);
Tensor::cat(vec![rotary_part, pass_through], 3)
}
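    /// Repeats each KV head `n_rep` times along the head axis so grouped KV
    /// heads align with the query heads.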
fn repeat_kv(x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {
let [batch, n_kv_heads, seq_len, head_dim] = x.dims();
let x = x.unsqueeze_dim::<5>(2);
let x = x.repeat_dim(2, n_rep);
x.reshape([batch, n_kv_heads * n_rep, seq_len, head_dim])
}
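    /// Builds an additive `[1, 1, seq_len, seq_len]` causal mask: 0 on and
    /// below the diagonal, -1e9 on future (upper-triangular) positions.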
fn causal_mask(seq_len: usize, device: &B::Device) -> Tensor<B, 4> {
let rows = Tensor::<B, 1, Int>::arange(0..(seq_len as i64), device)
.reshape([seq_len, 1])
.float();
let cols = Tensor::<B, 1, Int>::arange(0..(seq_len as i64), device)
.reshape([1, seq_len])
.float();
let future = cols.greater(rows);
let zeros = Tensor::<B, 2>::zeros([seq_len, seq_len], device);
let mask = zeros.mask_fill(future, -1e9);
mask.unsqueeze_dim::<3>(0).unsqueeze_dim::<4>(0)
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
type B = NdArray<f32>;
fn test_config() -> FullAttentionConfig {
FullAttentionConfig {
d_model: 64,
n_heads: 4,
n_kv_heads: 2,
head_dim: 16,
max_seq_len: 128,
qk_norm: true,
partial_rotary_factor: 0.25,
rope_theta: 10_000.0,
rms_norm_eps: 1e-6,
}
}
#[test]
fn forward_preserves_shape() {
let device = Default::default();
let attn = test_config().init::<B>(&device);
let x = Tensor::<B, 3>::zeros([2, 8, 64], &device);
let out = attn.forward(x);
assert_eq!(out.dims(), [2, 8, 64]);
}
#[test]
fn forward_single_token() {
let device = Default::default();
let attn = test_config().init::<B>(&device);
let x = Tensor::<B, 3>::zeros([1, 1, 64], &device);
let out = attn.forward(x);
assert_eq!(out.dims(), [1, 1, 64]);
}
#[test]
fn causal_mask_blocks_future() {
let device: <B as Backend>::Device = Default::default();
let mask = FullAttention::<B>::causal_mask(4, &device);
assert_eq!(mask.dims(), [1, 1, 4, 4]);
let data = mask.squeeze_dim::<3>(0).squeeze_dim::<2>(0);
let vals: Vec<f32> = data.to_data().to_vec().unwrap();
        assert!(vals[0].abs() < 1e-6, "mask[0,0] should be 0");
        assert!(vals[1] < -1e8, "mask[0,1] should be large negative");
        assert!(vals[12].abs() < 1e-6, "mask[3,0] should be 0");
        assert!(vals[15].abs() < 1e-6, "mask[3,3] should be 0");
    }
#[test]
fn partial_rope_leaves_passthrough_unchanged() {
let device = Default::default();
let attn = test_config().init::<B>(&device);
assert_eq!(attn.rotary_dim, 4);
let x = Tensor::<B, 4>::zeros([1, 4, 2, 16], &device);
let ones_part = Tensor::<B, 4>::ones([1, 4, 2, 12], &device);
let x = Tensor::cat(vec![x.narrow(3, 0, 4), ones_part], 3);
let out = attn.apply_partial_rope(x);
let pass_through = out.narrow(3, 4, 12);
let vals: Vec<f32> = pass_through.to_data().to_vec().unwrap();
for (i, v) in vals.iter().enumerate() {
assert!(
(v - 1.0).abs() < 1e-6,
"pass-through dim {i} should be 1.0, got {v}"
);
}
}
#[test]
fn repeat_kv_expands_correctly() {
let device: <B as Backend>::Device = Default::default();
let x = Tensor::<B, 4>::ones([1, 2, 4, 8], &device);
let expanded = FullAttention::<B>::repeat_kv(x, 3);
assert_eq!(expanded.dims(), [1, 6, 4, 8]);
}
#[test]
fn no_qk_norm_works() {
let device = Default::default();
let mut cfg = test_config();
cfg.qk_norm = false;
let attn = cfg.init::<B>(&device);
let x = Tensor::<B, 3>::random(
[1, 4, 64],
burn::tensor::Distribution::Normal(0.0, 1.0),
&device,
);
let out = attn.forward(x);
assert_eq!(out.dims(), [1, 4, 64]);
}
#[test]
fn full_rope_when_factor_is_one() {
let device = Default::default();
let mut cfg = test_config();
cfg.partial_rotary_factor = 1.0;
let attn = cfg.init::<B>(&device);
assert_eq!(attn.rotary_dim, cfg.head_dim);
}
}