// attnres/block_state.rs

1/// Tracks the state of block accumulation across layers.
2///
3/// This is passed between layers during the forward pass.
4/// It maintains:
5/// - Completed block representations (b_0, b_1, ..., b_{n-1})
6/// - The current partial sum within the active block (b_n^i)
7///
8/// Paper reference: Section 3, Block Attention Residuals.
9use burn::prelude::*;
10
11#[derive(Clone, Debug)]
12pub struct BlockState<B: Backend> {
13    /// Completed block representations.
14    /// blocks[0] = token embedding (b_0 = h_1).
15    /// blocks[n] = sum of layer outputs in block n.
16    pub blocks: Vec<Tensor<B, 3>>, // Each: [batch, seq_len, d_model]
17
18    /// Partial sum of layer outputs within the current (incomplete) block.
19    /// None at the start of a new block.
20    pub partial_block: Option<Tensor<B, 3>>, // [batch, seq_len, d_model]
21}
22
23impl<B: Backend> BlockState<B> {
24    /// Initialize with token embeddings as the first block (b_0).
25    pub fn new(token_embeddings: Tensor<B, 3>) -> Self {
26        Self {
27            blocks: vec![token_embeddings],
28            partial_block: None,
29        }
30    }
31
32    /// Number of completed blocks.
33    pub fn num_blocks(&self) -> usize {
34        self.blocks.len()
35    }
36}
37
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type TestBackend = NdArray;

    /// A fresh state holds exactly the embedding block (unchanged in shape)
    /// and no partial sum.
    #[test]
    fn test_new_state() {
        let device = Default::default();
        let emb = Tensor::<TestBackend, 3>::zeros([2, 16, 64], &device);
        let state = BlockState::new(emb);
        assert_eq!(state.num_blocks(), 1);
        // The embedding itself must be stored as b_0.
        assert_eq!(state.blocks[0].dims(), [2, 16, 64]);
        assert!(state.partial_block.is_none());
    }
}