burn_core/nn/attention/mask.rs

use alloc::vec::Vec;
use burn_tensor::ops::IntElem;

use crate::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend};

/// Generate an autoregressive attention mask.
///
/// The mask can be used in Transformer modules to train models to generate sequences one
/// position at a time: masked (`true`) entries lie strictly above the diagonal, so position
/// `i` may only attend to positions `0..=i`.
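///
/// # Example
///
/// A minimal sketch of applying the mask to attention scores. The `scores` tensor and the
/// use of `mask_fill` with `f32::NEG_INFINITY` are illustrative assumptions, not part of
/// this module:
///
/// ```ignore
/// let mask = generate_autoregressive_mask::<B>(batch_size, seq_length, &device);
/// // Hypothetical raw attention scores of shape [batch_size, seq_length, seq_length].
/// let scores = query.matmul(key.transpose());
/// // Masked (`true`) positions are replaced before the softmax.
/// let scores = scores.mask_fill(mask, f32::NEG_INFINITY);
/// ```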
pub fn generate_autoregressive_mask<B: Backend>(
    batch_size: usize,
    seq_length: usize,
    device: &B::Device,
) -> Tensor<B, 3, Bool> {
    // `tril_mask` marks everything outside the lower triangle (offset 0) as `true`,
    // i.e. the strictly upper-triangular future positions.
    let mask = Tensor::<B, 2, Bool>::tril_mask([seq_length, seq_length], 0, device);
    // Broadcast the same [seq_length, seq_length] mask across the batch dimension.
    mask.expand([batch_size, seq_length, seq_length])
}

/// The result of [`generate_padding_mask`]: padded tokens and their padding mask.
pub struct GeneratePaddingMask<B: Backend> {
    /// The padded token tensor, shape `[batch_size, max_size]`.
    pub tensor: Tensor<B, 2, Int>,

    /// The padding mask, `true` where the token is padding.
    pub mask: Tensor<B, 2, Bool>,
}

/// Generate a padding attention mask for a batch of token sequences.
///
/// Each sequence is padded with `pad_token` up to the longest sequence in the batch,
/// capped at `max_seq_length` when provided; longer sequences are truncated. The mask is
/// `true` at padded positions.
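///
/// # Example
///
/// A minimal sketch mirroring the test below (the token values are illustrative):
///
/// ```ignore
/// let tokens = vec![vec![3, 3, 3], vec![3, 3, 3, 4, 10, 15]];
/// let out = generate_padding_mask::<B>(0, tokens, None, &device);
/// // out.tensor: [[3, 3, 3, 0, 0, 0], [3, 3, 3, 4, 10, 15]]
/// // out.mask:   [[false, false, false, true, true, true],
/// //              [false, false, false, false, false, false]]
/// ```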
pub fn generate_padding_mask<B: Backend>(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: Option<usize>,
    device: &B::Device,
) -> GeneratePaddingMask<B> {
    let mut max_size = 0;
    let batch_size = tokens_list.len();

    // Find the longest sequence in the batch; stop early once `max_seq_length` is reached,
    // since no row can be wider than that.
    for tokens in tokens_list.iter() {
        if tokens.len() > max_size {
            max_size = tokens.len();
        }

        if let Some(max_seq_length) = max_seq_length {
            if tokens.len() >= max_seq_length {
                max_size = max_seq_length;
                break;
            }
        }
    }

    // Start from a tensor filled entirely with the padding token.
    let mut tensor = Tensor::zeros([batch_size, max_size], device);
    tensor = tensor.add_scalar(pad_token as i64);

    // Write each sequence into its row, truncating to `max_size`; the remainder of the
    // row keeps the padding token.
    for (index, tokens) in tokens_list.into_iter().enumerate() {
        let seq_length = tokens.len().min(max_size);
        tensor = tensor.slice_assign(
            [index..index + 1, 0..seq_length],
            Tensor::from_data(
                TensorData::new(
                    tokens
                        .into_iter()
                        .take(max_size)
                        .map(|e| (e as i64).elem::<IntElem<B>>())
                        .collect(),
                    Shape::new([1, seq_length]),
                ),
                device,
            ),
        );
    }

    // Padded positions still equal `pad_token`, so equality marks them as `true`. Note
    // that real tokens equal to `pad_token` would be flagged as padding as well.
    let mask = tensor.clone().equal_elem(pad_token as i64);

    GeneratePaddingMask { tensor, mask }
}
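
// A minimal sketch of how the padding mask is typically consumed in attention (not part
// of this module): broadcast it over the score rows and fill padded key positions before
// the softmax. `scores`, `unsqueeze_dim`, and the fill value are illustrative assumptions.
//
// let out = generate_padding_mask::<B>(pad_token, tokens, None, &device);
// // [batch_size, seq_length] -> [batch_size, 1, seq_length], broadcast over query rows.
// let scores = scores.mask_fill(out.mask.unsqueeze_dim(1), f32::NEG_INFINITY);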

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::vec;

    #[test]
    fn test_generate_autoregressive_mask() {
        let device = <TestBackend as Backend>::Device::default();

        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);

        mask.into_data().assert_eq(
            &TensorData::from([
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
            ]),
            false,
        );
    }

    #[test]
    fn test_generate_padding_mask() {
        let device = <TestBackend as Backend>::Device::default();
        let tokens = vec![
            vec![3, 3, 3],
            vec![3, 3, 3],
            vec![3, 3, 3, 4],
            vec![3, 3, 3, 4, 10, 15],
        ];

        let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);

        mask.mask.into_data().assert_eq(
            &TensorData::from([
                [false, false, false, true, true, true],
                [false, false, false, true, true, true],
                [false, false, false, false, true, true],
                [false, false, false, false, false, false],
            ]),
            false,
        );
    }
}