//! Attention-mask generation utilities: autoregressive (causal) masks and
//! padding masks for batches of token sequences.
1use burn_core as burn;
2use burn_core::config::Config;
3
4use alloc::vec::Vec;
5use burn::tensor::ops::IntElem;
6
7use burn::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend};
8
9/// Generate an autoregressive attention mask.
10///
11/// The mask can be used in Transformer modules to train models to generate tensors sequentially.
12pub fn generate_autoregressive_mask<B: Backend>(
13    batch_size: usize,
14    seq_length: usize,
15    device: &B::Device,
16) -> Tensor<B, 3, Bool> {
17    let mask = Tensor::<B, 2, Bool>::tril_mask([seq_length, seq_length], 0, device);
18    mask.expand([batch_size, seq_length, seq_length])
19}
20
21/// Generate a padding attention mask.
22pub struct GeneratePaddingMask<B: Backend> {
23    /// The generated tensor.
24    pub tensor: Tensor<B, 2, Int>,
25
26    /// The generated mask.
27    pub mask: Tensor<B, 2, Bool>,
28}
29
30/// Defines an enumeration to specify sequence length options for padding
31#[derive(Config, Debug, Copy)]
32pub enum SeqLengthOption {
33    /// No maximum length; use the longest sequence
34    NoMax,
35    /// Maximum length specified, truncate if necessary
36    Max(usize),
37    /// Fixed length, pad or truncate to this exact length
38    Fixed(usize),
39}
40
41impl From<Option<usize>> for SeqLengthOption {
42    fn from(val: Option<usize>) -> Self {
43        match val {
44            Some(max) => SeqLengthOption::Max(max),
45            None => SeqLengthOption::NoMax,
46        }
47    }
48}
49
50/// Generates a padding attention mask for a batch of token sequences.
51///
52/// # Arguments
53///
54/// * `pad_token` - The token ID used for padding
55/// * `tokens_list` - Vector of token sequences (each sequence is a vector of token IDs)
56/// * `seq_length` - Sequence length option (NoMax, Max, or Fixed)
57/// * `device` - The device for tensor operations
58///
59/// # Returns
60///
61/// A `GeneratePaddingMask` containing the padded tensor and corresponding mask
62pub fn generate_padding_mask<B: Backend>(
63    pad_token: usize,
64    tokens_list: Vec<Vec<usize>>,
65    seq_length: impl Into<SeqLengthOption>,
66    device: &B::Device,
67) -> GeneratePaddingMask<B> {
68    let tokens_max = || {
69        tokens_list
70            .iter()
71            .map(|tokens| tokens.len())
72            .max()
73            .unwrap_or(1)
74    };
75
76    let size = match seq_length.into() {
77        SeqLengthOption::NoMax => tokens_max(),
78        SeqLengthOption::Max(max) => usize::min(tokens_max(), max),
79        SeqLengthOption::Fixed(limit) => limit,
80    };
81    let batch_size = tokens_list.len();
82
83    let mut tensor = Tensor::zeros([batch_size, size], device);
84    tensor = tensor.add_scalar(pad_token as i64);
85
86    for (index, tokens) in tokens_list.into_iter().enumerate() {
87        let seq_length = tokens.len().min(size);
88        tensor = tensor.slice_assign(
89            [index..index + 1, 0..seq_length],
90            Tensor::from_data(
91                TensorData::new(
92                    tokens
93                        .into_iter()
94                        .take(size)
95                        .map(|e| (e as i64).elem::<IntElem<B>>())
96                        .collect(),
97                    Shape::new([1, seq_length]),
98                ),
99                device,
100            ),
101        );
102    }
103
104    let mask = tensor.clone().equal_elem(pad_token as i64);
105
106    GeneratePaddingMask { tensor, mask }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use alloc::vec;
    use burn::tensor::TensorData;

    #[test]
    fn test_generate_autoregressive_mask() {
        let device = Default::default();

        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);

        // Every batch entry carries the same strictly-upper-triangular mask.
        let expected = [
            [false, true, true],
            [false, false, true],
            [false, false, false],
        ];
        mask.into_data()
            .assert_eq(&TensorData::from([expected, expected]), false);
    }

    #[test]
    fn test_generate_padding_mask() {
        let device = Default::default();
        let tokens = vec![
            vec![3, 3, 3],
            vec![3, 3, 3],
            vec![3, 3, 3, 4],
            vec![3, 3, 3, 4, 10, 15],
        ];

        // `None` => NoMax: every row is padded out to the longest sequence (6).
        let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);

        mask.mask.into_data().assert_eq(
            &TensorData::from([
                [false, false, false, true, true, true],
                [false, false, false, true, true, true],
                [false, false, false, false, true, true],
                [false, false, false, false, false, false],
            ]),
            false,
        );
    }
}
161}