use burn::module::{Module, Param};
use burn::prelude::*;
use burn::tensor::activation::softmax;
use crate::config::AttnResConfig;
use crate::rms_norm::{RmsNorm, RmsNormConfig};
/// Module that blends several same-shaped source tensors into one, using
/// softmax weights scored against a single learned query vector.
///
/// Instances are created through `AttnResConfig::init_op`.
#[derive(Module, Debug)]
pub struct AttnResOp<B: Backend> {
    /// Learned rank-1 query vector; every normalized source is scored against
    /// it (dot product over the last axis) to obtain the mixing logits.
    pub pseudo_query: Param<Tensor<B, 1>>,
    /// Normalization applied to the stacked sources before scoring. Only the
    /// scores use the normalized values; the blended output does not.
    pub norm: RmsNorm<B>,
}
impl AttnResConfig {
pub fn init_op<B: Backend>(&self, device: &B::Device) -> AttnResOp<B> {
AttnResOp {
pseudo_query: Param::from_tensor(Tensor::zeros([self.d_model], device)),
norm: RmsNormConfig::new(self.d_model)
.with_eps(self.rms_norm_eps)
.init(device),
}
}
}
impl<B: Backend> AttnResOp<B> {
    /// Blends `blocks` (plus `partial_block`, when present) into one tensor.
    ///
    /// Steps:
    /// 1. Stack all sources along a new leading "source" axis.
    /// 2. Pass the stack through `self.norm` to obtain the keys.
    /// 3. Score each source per position by summing `key * pseudo_query`
    ///    over the last axis.
    /// 4. Softmax the scores across the source axis.
    /// 5. Return the weighted sum of the *un-normalized* sources.
    ///
    /// All sources must share one shape so they can be stacked, and the last
    /// axis must broadcast against `pseudo_query`.
    ///
    /// # Panics
    ///
    /// Panics if `blocks` is empty and `partial_block` is `None`.
    pub fn forward_optional_partial(
        &self,
        blocks: &[Tensor<B, 3>],
        partial_block: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        // Completed blocks first, then the optional in-progress block last.
        let candidates: Vec<Tensor<B, 3>> =
            blocks.iter().chain(partial_block).cloned().collect();
        assert!(
            !candidates.is_empty(),
            "AttnResOp requires at least one source tensor"
        );

        // Source axis is axis 0 of the rank-4 stack.
        let values: Tensor<B, 4> = Tensor::stack(candidates, 0);
        let keys = self.norm.forward_4d(values.clone());

        // Lift the rank-1 query to rank 4 so it broadcasts over the keys.
        let query = self
            .pseudo_query
            .val()
            .unsqueeze_dim::<2>(0)
            .unsqueeze_dim::<3>(0)
            .unsqueeze_dim::<4>(0);

        // Dot product over the last axis: one score per source and position.
        let scores = (keys * query).sum_dim(3).squeeze_dim::<3>(3);

        // Normalize across sources, then restore the trailing axis so the
        // weights broadcast over the values.
        let mix = softmax(scores, 0).unsqueeze_dim::<4>(3);

        (values * mix).sum_dim(0).squeeze_dim::<3>(0)
    }

    /// Convenience wrapper around [`Self::forward_optional_partial`] for
    /// callers that always have a partial block to include.
    pub fn forward(&self, blocks: &[Tensor<B, 3>], partial_block: &Tensor<B, 3>) -> Tensor<B, 3> {
        self.forward_optional_partial(blocks, Some(partial_block))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::Distribution;

    type TestBackend = NdArray;
    type TestDevice = <TestBackend as Backend>::Device;

    /// Draws a rank-3 tensor of the given shape from `Normal(0, 1)`.
    fn standard_normal(shape: [usize; 3], device: &TestDevice) -> Tensor<TestBackend, 3> {
        Tensor::random(shape, Distribution::Normal(0.0, 1.0), device)
    }

    /// Largest element-wise absolute difference between two tensors.
    fn max_abs_diff(actual: Tensor<TestBackend, 3>, expected: Tensor<TestBackend, 3>) -> f32 {
        (actual - expected).abs().max().into_scalar()
    }

    /// The output has the same shape as each individual source.
    #[test]
    fn test_output_shape() {
        let device: TestDevice = Default::default();
        let op = AttnResConfig::new(64, 12, 4).init_op::<TestBackend>(&device);

        let shape = [2, 16, 64];
        let blocks = vec![
            standard_normal(shape, &device),
            standard_normal(shape, &device),
            standard_normal(shape, &device),
        ];
        let partial = standard_normal(shape, &device);

        let output = op.forward(&blocks, &partial);

        assert_eq!(output.dims(), [2, 16, 64]);
    }

    /// A freshly initialized op has a zero pseudo-query, so all sources get
    /// equal weight and the output is their arithmetic mean.
    #[test]
    fn test_zero_init_uniform_weights() {
        let device: TestDevice = Default::default();
        let op = AttnResConfig::new(64, 12, 4).init_op::<TestBackend>(&device);

        let shape = [2, 16, 64];
        let blocks = vec![
            standard_normal(shape, &device),
            standard_normal(shape, &device),
        ];
        let partial = standard_normal(shape, &device);

        let output = op.forward(&blocks, &partial);
        let expected = (blocks[0].clone() + blocks[1].clone() + partial) / 3.0;

        let diff = max_abs_diff(output, expected);
        assert!(
            diff < 1e-4,
            "Zero-init should produce uniform weights (mean of sources), diff={diff}"
        );
    }

    /// One completed block plus the partial block average to their midpoint
    /// under zero-initialized weights.
    #[test]
    fn test_single_block_is_mean() {
        let device: TestDevice = Default::default();
        let op = AttnResConfig::new(32, 4, 4).init_op::<TestBackend>(&device);

        let shape = [1, 8, 32];
        let blocks = vec![standard_normal(shape, &device)];
        let partial = standard_normal(shape, &device);

        let output = op.forward(&blocks, &partial);
        let expected = (blocks[0].clone() + partial) / 2.0;

        let diff = max_abs_diff(output, expected);
        assert!(diff < 1e-4, "Single block should produce mean, diff={diff}");
    }

    /// With exactly one source and no partial block, the softmax weight is 1
    /// and the source passes through unchanged.
    #[test]
    fn test_blocks_only_returns_only_source() {
        let device: TestDevice = Default::default();
        let op = AttnResConfig::new(32, 4, 2).init_op::<TestBackend>(&device);

        let embedding = standard_normal([1, 8, 32], &device);
        let output = op.forward_optional_partial(&[embedding.clone()], None);

        let diff = max_abs_diff(output, embedding);
        assert!(
            diff < 1e-5,
            "A single completed block should be returned unchanged, diff={diff}"
        );
    }
}