1use alloc::vec::Vec;
2use burn_tensor::ops::IntElem;
3
4use crate::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend};
5
6pub fn generate_autoregressive_mask<B: Backend>(
10 batch_size: usize,
11 seq_length: usize,
12 device: &B::Device,
13) -> Tensor<B, 3, Bool> {
14 let mask = Tensor::<B, 2, Bool>::tril_mask([seq_length, seq_length], 0, device);
15 mask.expand([batch_size, seq_length, seq_length])
16}
17
18pub struct GeneratePaddingMask<B: Backend> {
20 pub tensor: Tensor<B, 2, Int>,
22
23 pub mask: Tensor<B, 2, Bool>,
25}
26
27pub fn generate_padding_mask<B: Backend>(
29 pad_token: usize,
30 tokens_list: Vec<Vec<usize>>,
31 max_seq_length: Option<usize>,
32 device: &B::Device,
33) -> GeneratePaddingMask<B> {
34 let mut max_size = 0;
35 let batch_size = tokens_list.len();
36
37 for tokens in tokens_list.iter() {
38 if tokens.len() > max_size {
39 max_size = tokens.len();
40 }
41
42 if let Some(max_seq_length) = max_seq_length {
43 if tokens.len() >= max_seq_length {
44 max_size = max_seq_length;
45 break;
46 }
47 }
48 }
49
50 let mut tensor = Tensor::zeros([batch_size, max_size], device);
51 tensor = tensor.add_scalar(pad_token as i64);
52
53 for (index, tokens) in tokens_list.into_iter().enumerate() {
54 let seq_length = tokens.len().min(max_size);
55 tensor = tensor.slice_assign(
56 [index..index + 1, 0..seq_length],
57 Tensor::from_data(
58 TensorData::new(
59 tokens
60 .into_iter()
61 .take(max_size)
62 .map(|e| (e as i64).elem::<IntElem<B>>())
63 .collect(),
64 Shape::new([1, seq_length]),
65 ),
66 device,
67 ),
68 );
69 }
70
71 let mask = tensor.clone().equal_elem(pad_token as i64);
72
73 GeneratePaddingMask { tensor, mask }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::vec;

    #[test]
    fn test_generate_autoregressive_mask() {
        let device = <TestBackend as Backend>::Device::default();

        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);

        mask.into_data().assert_eq(
            &TensorData::from([
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
            ]),
            false,
        );
    }

    #[test]
    fn test_generate_padding_mask() {
        let device = <TestBackend as Backend>::Device::default();
        let tokens = vec![
            vec![3, 3, 3],
            vec![3, 3, 3],
            vec![3, 3, 3, 4],
            vec![3, 3, 3, 4, 10, 15],
        ];

        let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);

        // The padded token ids themselves must also be correct, not just the mask.
        mask.tensor.into_data().assert_eq(
            &TensorData::from([
                [3, 3, 3, 0, 0, 0],
                [3, 3, 3, 0, 0, 0],
                [3, 3, 3, 4, 0, 0],
                [3, 3, 3, 4, 10, 15],
            ]),
            false,
        );
        mask.mask.into_data().assert_eq(
            &TensorData::from([
                [false, false, false, true, true, true],
                [false, false, false, true, true, true],
                [false, false, false, false, true, true],
                [false, false, false, false, false, false],
            ]),
            false,
        );
    }

    #[test]
    fn test_generate_padding_mask_truncates_to_max_seq_length() {
        let device = <TestBackend as Backend>::Device::default();
        let tokens = vec![vec![3, 3, 3], vec![3, 3, 3, 4, 10, 15]];

        let output = generate_padding_mask::<TestBackend>(0, tokens, Some(4), &device);

        // The longer sequence is cut to 4 tokens; the shorter one is padded up to 4.
        output
            .tensor
            .into_data()
            .assert_eq(&TensorData::from([[3, 3, 3, 0], [3, 3, 3, 4]]), false);
        output.mask.into_data().assert_eq(
            &TensorData::from([[false, false, false, true], [false, false, false, false]]),
            false,
        );
    }
}