// attnres/block_state.rs
//! Tracks the state of block accumulation across layers.
//!
//! This is passed between layers during the forward pass.
//! It maintains:
//! - Completed block representations (b_0, b_1, ..., b_{n-1})
//! - The current partial sum within the active block (b_n^i)
//!
//! Paper reference: Section 3, Block Attention Residuals.
9use burn::prelude::*;
10
/// Running state of block accumulation, threaded through the forward pass.
///
/// Holds the completed block representations (b_0, ..., b_{n-1}) together
/// with the partial sum of the block currently being accumulated.
/// Paper reference: Section 3, Block Attention Residuals.
#[derive(Clone, Debug)]
pub struct BlockState<B: Backend> {
    /// Completed block representations.
    /// blocks[0] = token embedding (b_0 = h_1).
    /// blocks[n] = sum of layer outputs in block n.
    pub blocks: Vec<Tensor<B, 3>>, // Each: [batch, seq_len, d_model]

    /// Partial sum of layer outputs within the current (incomplete) block.
    /// None at the start of a new block.
    pub partial_block: Option<Tensor<B, 3>>, // [batch, seq_len, d_model]
}
22
23impl<B: Backend> BlockState<B> {
24 /// Initialize with token embeddings as the first block (b_0).
25 pub fn new(token_embeddings: Tensor<B, 3>) -> Self {
26 Self {
27 blocks: vec![token_embeddings],
28 partial_block: None,
29 }
30 }
31
32 /// Number of completed blocks.
33 pub fn num_blocks(&self) -> usize {
34 self.blocks.len()
35 }
36}
37
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type TestBackend = NdArray;

    /// A freshly constructed state holds exactly one block (the token
    /// embeddings) and no partial accumulation.
    #[test]
    fn test_new_state() {
        let device = Default::default();
        let embeddings = Tensor::<TestBackend, 3>::zeros([2, 16, 64], &device);

        let state = BlockState::new(embeddings);

        assert_eq!(state.num_blocks(), 1);
        assert!(state.partial_block.is_none());
    }
}
53}