use alloc::vec::Vec;
use burn_tensor::ops::IntElem;

use crate::tensor::{Bool, ElementConversion, Int, Shape, Tensor, TensorData, backend::Backend};

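/// Generate an autoregressive attention mask.
///
/// The mask can be used in transformer modules to train models to generate tensors sequentially.
/// Position `(i, j)` is `true` (masked) whenever `j > i`, so each position may only attend to
/// itself and to earlier positions.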
pub fn generate_autoregressive_mask<B: Backend>(
    batch_size: usize,
    seq_length: usize,
    device: &B::Device,
) -> Tensor<B, 3, Bool> {
    let mut mask = Tensor::<B, 3, Int>::zeros([1, seq_length, seq_length], device);

    // Fill the strict upper triangle with ones; `saturating_sub` avoids an
    // integer underflow when `seq_length` is 0.
    for i in 0..seq_length.saturating_sub(1) {
        let values = Tensor::<B, 3, Int>::ones([1, 1, seq_length - (i + 1)], device);
        mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
    }

    // Broadcast the single mask across the whole batch.
    mask = mask.repeat_dim(0, batch_size);

    mask.equal_elem(1)
}

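/// Padded batch of tokens together with its padding mask, as returned by
/// [`generate_padding_mask`].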
pub struct GeneratePaddingMask<B: Backend> {
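    /// The padded tensor of token ids.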
    pub tensor: Tensor<B, 2, Int>,
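    /// The mask, `true` at padded positions.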
    pub mask: Tensor<B, 2, Bool>,
}
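/// Generate a padding attention mask for a batch of token sequences.
///
/// Each sequence is padded with `pad_token` up to the length of the longest sequence in the
/// batch, capped at `max_seq_length` when one is provided; longer sequences are truncated.
/// The mask is `true` at padded positions. Note that a real token equal to `pad_token` is
/// indistinguishable from padding and is also masked.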
pub fn generate_padding_mask<B: Backend>(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: Option<usize>,
    device: &B::Device,
) -> GeneratePaddingMask<B> {
    let mut max_size = 0;
    let batch_size = tokens_list.len();

    // Pad to the longest sequence in the batch, capped at `max_seq_length` when
    // one is provided. Once the cap is reached, no later sequence can increase
    // the padded size, so the scan can stop early.
    for tokens in tokens_list.iter() {
        if tokens.len() > max_size {
            max_size = tokens.len();
        }

        if let Some(max_seq_length) = max_seq_length {
            if tokens.len() >= max_seq_length {
                max_size = max_seq_length;
                break;
            }
        }
    }

    // Start with every position filled with the padding token.
    let mut tensor = Tensor::full([batch_size, max_size], pad_token as i64, device);

    for (index, mut tokens) in tokens_list.into_iter().enumerate() {
        let mut seq_length = tokens.len();

        // Truncate any sequence that exceeds the maximum length.
        if let Some(max_seq_length) = max_seq_length {
            if seq_length > max_seq_length {
                seq_length = max_seq_length;
                tokens.truncate(seq_length);
            }
        }

        // Overwrite the padding with the actual token ids for this sequence.
        tensor = tensor.slice_assign(
            [index..index + 1, 0..seq_length],
            Tensor::from_data(
                TensorData::new(
                    tokens
                        .into_iter()
                        .map(|e| (e as i64).elem::<IntElem<B>>())
                        .collect(),
                    Shape::new([1, seq_length]),
                ),
                device,
            ),
        );
    }

    // The mask is `true` wherever the padding token remains.
    let mask = tensor.clone().equal_elem(pad_token as i64);

    GeneratePaddingMask { tensor, mask }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::vec;

    #[test]
    fn test_generate_autoregressive_mask() {
        let device = <TestBackend as Backend>::Device::default();

        let mask = generate_autoregressive_mask::<TestBackend>(2, 3, &device);

        mask.into_data().assert_eq(
            &TensorData::from([
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
                [
                    [false, true, true],
                    [false, false, true],
                    [false, false, false],
                ],
            ]),
            false,
        );
    }

    #[test]
    fn test_generate_padding_mask() {
        let device = <TestBackend as Backend>::Device::default();
        let tokens = vec![
            vec![3, 3, 3],
            vec![3, 3, 3],
            vec![3, 3, 3, 4],
            vec![3, 3, 3, 4, 10, 15],
        ];

        let mask = generate_padding_mask::<TestBackend>(0, tokens, None, &device);

        mask.mask.into_data().assert_eq(
            &TensorData::from([
                [false, false, false, true, true, true],
                [false, false, false, true, true, true],
                [false, false, false, false, true, true],
                [false, false, false, false, false, false],
            ]),
            false,
        );
    }
}