burn_core/nn/attention/mask.rs

use alloc::vec::Vec;
use burn_tensor::ops::IntElem;

use crate::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend};

/// Generate an autoregressive (causal) attention mask.
///
/// Position `(i, j)` of the mask is `true` when `j > i`, so that in Transformer modules
/// each position can attend only to itself and to earlier positions, as required when
/// training models to generate sequences autoregressively.
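///
/// # Example
///
/// A minimal usage sketch, assuming some backend type `B`:
///
/// ```rust,ignore
/// let device = <B as Backend>::Device::default();
/// // `[2, 3, 3]` causal mask: `true` marks future positions that must not be attended to.
/// let mask = generate_autoregressive_mask::<B>(2, 3, &device);
/// ```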
pub fn generate_autoregressive_mask<B: Backend>(
    batch_size: usize,
    seq_length: usize,
    device: &B::Device,
) -> Tensor<B, 3, Bool> {
    // TODO: replace with a more efficient combination of `triu_mask` and `expand`
    let mut mask = Tensor::<B, 3, Int>::zeros([1, seq_length, seq_length], device);

    // Fill the strict upper triangle with ones, one row at a time.
    for i in 0..(seq_length - 1) {
        let values = Tensor::<B, 3, Int>::ones([1, 1, seq_length - (i + 1)], device);
        mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
    }

    // Repeat the single mask for every batch element.
    mask = mask.repeat_dim(0, batch_size);

    mask.equal_elem(1_i64)
}

/// Output of [`generate_padding_mask`]: the padded token tensor and its padding mask.
pub struct GeneratePaddingMask<B: Backend> {
    /// The tokens, padded with the pad token to a common sequence length.
    pub tensor: Tensor<B, 2, Int>,

    /// The padding mask, `true` where `tensor` contains the pad token.
    pub mask: Tensor<B, 2, Bool>,
}

/// Generate a padding attention mask for a batch of token sequences.
///
/// Sequences shorter than the longest sequence in the batch (or than `max_seq_length`,
/// when provided) are padded with `pad_token`; longer sequences are truncated.
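///
/// # Example
///
/// A minimal usage sketch, assuming some backend type `B` and a pad token of `0`:
///
/// ```rust,ignore
/// let device = <B as Backend>::Device::default();
/// let batch = generate_padding_mask::<B>(0, vec![vec![5, 7], vec![5]], None, &device);
/// // `batch.tensor` holds `[[5, 7], [5, 0]]`; `batch.mask` is `true` at the padded slot.
/// ```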
pub fn generate_padding_mask<B: Backend>(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: Option<usize>,
    device: &B::Device,
) -> GeneratePaddingMask<B> {
    // The padded length is the longest sequence in the batch, capped at `max_seq_length`.
    let mut max_size = 0;
    let batch_size = tokens_list.len();

    for tokens in tokens_list.iter() {
        if tokens.len() > max_size {
            max_size = tokens.len();
        }

        if let Some(max_seq_length) = max_seq_length {
            if tokens.len() >= max_seq_length {
                max_size = max_seq_length;
                break;
            }
        }
    }

    // Start from a tensor filled with the pad token; real tokens overwrite it below.
    let mut tensor = Tensor::full([batch_size, max_size], pad_token as i64, device);

    for (index, mut tokens) in tokens_list.into_iter().enumerate() {
        let mut seq_length = tokens.len();

        // Truncate sequences that exceed the maximum length.
        if let Some(max_seq_length) = max_seq_length {
            if seq_length > max_seq_length {
                seq_length = max_seq_length;
                let _ = tokens.split_off(seq_length);
            }
        }

        // Write the (possibly truncated) tokens over the padding in row `index`.
        tensor = tensor.slice_assign(
            [index..index + 1, 0..seq_length],
            Tensor::from_data(
                TensorData::new(
                    tokens
                        .into_iter()
                        .map(|e| (e as i64).elem::<IntElem<B>>())
                        .collect(),
                    Shape::new([1, seq_length]),
                ),
                device,
            ),
        );
    }

    // The mask is `true` wherever the tensor still equals the pad token.
    let mask = tensor.clone().equal_elem(pad_token as i64);

    GeneratePaddingMask { tensor, mask }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::vec;

    #[test]
    fn test_generate_autoregressive_mask() {
        let device = <TestBackend as Backend>::Device::default();

        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);

        mask.into_data().assert_eq(
            &TensorData::from([
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
            ]),
            false,
        );
    }

    #[test]
    fn test_generate_padding_mask() {
        let device = <TestBackend as Backend>::Device::default();
        let tokens = vec![
            vec![3, 3, 3],
            vec![3, 3, 3],
            vec![3, 3, 3, 4],
            vec![3, 3, 3, 4, 10, 15],
        ];

        let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);

        mask.mask.into_data().assert_eq(
            &TensorData::from([
                [false, false, false, true, true, true],
                [false, false, false, true, true, true],
                [false, false, false, false, true, true],
                [false, false, false, false, false, false],
            ]),
            false,
        );
    }
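
    // An added illustrative test (a sketch, not from the original suite): exercises the
    // `max_seq_length` path, where over-long sequences are truncated before padding.
    #[test]
    fn test_generate_padding_mask_with_max_seq_length() {
        let device = <TestBackend as Backend>::Device::default();
        let tokens = vec![vec![3, 3], vec![3, 3, 3, 4, 10, 15]];

        let mask = generate_padding_mask::<TestBackend>(0, tokens, Some(3), &device);

        mask.mask.into_data().assert_eq(
            &TensorData::from([
                [false, false, true],
                [false, false, false],
            ]),
            false,
        );
    }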
}