burn_nn/modules/attention/mask.rs

use burn_core as burn;
use burn_core::config::Config;

use alloc::vec::Vec;
use burn::tensor::ops::IntElem;

use burn::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend};

/// Generates an autoregressive (causal) attention mask.
///
/// Entries are `true` wherever attention must be blocked, i.e. at every
/// position that lies in the future relative to the query position.
pub fn generate_autoregressive_mask<B: Backend>(
    batch_size: usize,
    seq_length: usize,
    device: &B::Device,
) -> Tensor<B, 3, Bool> {
    // `tril_mask` with offset 0 marks the strict upper triangle (future
    // positions) as `true`; the 2D mask is then broadcast across the batch.
    let mask = Tensor::<B, 2, Bool>::tril_mask([seq_length, seq_length], 0, device);
    mask.expand([batch_size, seq_length, seq_length])
}
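
// Usage sketch (illustrative, not part of the module): the `true` entries mark
// blocked positions, so a typical consumer fills them with negative infinity in
// the attention scores before the softmax. `scores` here is a hypothetical
// `[batch_size, seq_length, seq_length]` tensor.
//
//     let mask = generate_autoregressive_mask::<B>(batch_size, seq_length, &device);
//     let scores = scores.mask_fill(mask, f32::NEG_INFINITY);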

/// Output of [`generate_padding_mask`]: the padded token tensor together with
/// its boolean padding mask.
pub struct GeneratePaddingMask<B: Backend> {
    /// The padded token tensor of shape `[batch_size, seq_length]`.
    pub tensor: Tensor<B, 2, Int>,

    /// The mask, `true` at padded positions.
    pub mask: Tensor<B, 2, Bool>,
}

/// How the target sequence length is determined when padding a batch.
#[derive(Config, Debug, Copy)]
pub enum SeqLengthOption {
    /// Use the longest sequence in the batch.
    NoMax,
    /// Use the longest sequence in the batch, capped at the given maximum.
    Max(usize),
    /// Always use the given length, padding or truncating as needed.
    Fixed(usize),
}

impl From<Option<usize>> for SeqLengthOption {
    fn from(val: Option<usize>) -> Self {
        match val {
            Some(max) => SeqLengthOption::Max(max),
            None => SeqLengthOption::NoMax,
        }
    }
}
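
// Because of the `From` impl above and the `impl Into<SeqLengthOption>` bound
// on `generate_padding_mask`, callers may pass an `Option<usize>` or an
// explicit variant interchangeably. A sketch, with made-up values:
//
//     generate_padding_mask::<B>(0, tokens.clone(), None, &device);      // NoMax
//     generate_padding_mask::<B>(0, tokens.clone(), Some(512), &device); // Max(512)
//     generate_padding_mask::<B>(0, tokens, SeqLengthOption::Fixed(128), &device);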

/// Pads a batch of token sequences with `pad_token` and generates the matching
/// padding attention mask.
///
/// Note that any real token equal to `pad_token` will also be flagged by the mask.
pub fn generate_padding_mask<B: Backend>(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    seq_length: impl Into<SeqLengthOption>,
    device: &B::Device,
) -> GeneratePaddingMask<B> {
    // Length of the longest sequence in the batch; defaults to 1 for an empty batch.
    let tokens_max = || {
        tokens_list
            .iter()
            .map(|tokens| tokens.len())
            .max()
            .unwrap_or(1)
    };

    let size = match seq_length.into() {
        SeqLengthOption::NoMax => tokens_max(),
        SeqLengthOption::Max(max) => usize::min(tokens_max(), max),
        SeqLengthOption::Fixed(limit) => limit,
    };
    let batch_size = tokens_list.len();

    // Fill the whole tensor with the pad token, then overwrite the real tokens
    // row by row.
    let mut tensor = Tensor::zeros([batch_size, size], device);
    tensor = tensor.add_scalar(pad_token as i64);

    for (index, tokens) in tokens_list.into_iter().enumerate() {
        // Truncate sequences that exceed the target size.
        let seq_length = tokens.len().min(size);
        tensor = tensor.slice_assign(
            [index..index + 1, 0..seq_length],
            Tensor::from_data(
                TensorData::new(
                    tokens
                        .into_iter()
                        .take(size)
                        .map(|e| (e as i64).elem::<IntElem<B>>())
                        .collect(),
                    Shape::new([1, seq_length]),
                ),
                device,
            ),
        );
    }

    // Padded positions are exactly those still holding the pad token.
    let mask = tensor.clone().equal_elem(pad_token as i64);

    GeneratePaddingMask { tensor, mask }
}
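
// Usage sketch (illustrative): the `[batch_size, seq_length]` mask is typically
// expanded over the query dimension so padded keys receive no attention.
// `scores` is a hypothetical `[batch_size, seq_length, seq_length]` tensor.
//
//     let GeneratePaddingMask { tensor, mask } =
//         generate_padding_mask::<B>(pad, tokens, None, &device);
//     let mask = mask.unsqueeze_dim::<3>(1).expand([batch, seq, seq]);
//     let scores = scores.mask_fill(mask, f32::NEG_INFINITY);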

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use alloc::vec;
    use burn::tensor::TensorData;

    #[test]
    fn test_generate_autoregressive_mask() {
        let device = Default::default();

        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);

        mask.into_data().assert_eq(
            &TensorData::from([
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
            ]),
            false,
        );
    }

    #[test]
    fn test_generate_padding_mask() {
        let device = Default::default();
        let tokens = vec![
            vec![3, 3, 3],
            vec![3, 3, 3],
            vec![3, 3, 3, 4],
            vec![3, 3, 3, 4, 10, 15],
        ];

        let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);

        mask.mask.into_data().assert_eq(
            &TensorData::from([
                [false, false, false, true, true, true],
                [false, false, false, true, true, true],
                [false, false, false, false, true, true],
                [false, false, false, false, false, false],
            ]),
            false,
        );
    }
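
    // An extra sanity check added for illustration (not part of the original
    // tests): `SeqLengthOption::Fixed` should pad short sequences and truncate
    // long ones to the fixed size.
    #[test]
    fn test_generate_padding_mask_fixed() {
        let device = Default::default();
        let tokens = vec![vec![3, 3], vec![3, 3, 3, 4, 10]];

        let mask =
            generate_padding_mask::<TestBackend>(0, tokens, SeqLengthOption::Fixed(3), &device);

        mask.mask.into_data().assert_eq(
            &TensorData::from([[false, false, true], [false, false, false]]),
            false,
        );
    }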
}