1use crate::processors::byte_level::process_offsets;
2use crate::tokenizer::{Encoding, PostProcessor, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::iter::FromIterator;
6
/// Post-processor applying RoBERTa's special-token scheme:
/// `<s> A </s>` for a single sequence, `<s> A </s></s> B </s>` for a pair.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "type")]
pub struct RobertaProcessing {
    // Separator token and its id (defaults to "</s>", 2).
    sep: (String, u32),
    // Classifier/start token and its id (defaults to "<s>", 0).
    cls: (String, u32),
    // When true, `process_offsets` is applied to every encoding (and its
    // overflowing parts) before the special tokens are added.
    trim_offsets: bool,
    // Forwarded to `process_offsets`; indicates whether a prefix space was
    // added during pre-tokenization — TODO confirm against byte_level docs.
    add_prefix_space: bool,
}
15
16impl Default for RobertaProcessing {
17 fn default() -> Self {
18 Self {
19 sep: ("</s>".into(), 2),
20 cls: ("<s>".into(), 0),
21 trim_offsets: true,
22 add_prefix_space: true,
23 }
24 }
25}
26
27impl RobertaProcessing {
28 pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
29 Self {
30 sep,
31 cls,
32 ..Default::default()
33 }
34 }
35
36 #[must_use]
37 pub fn trim_offsets(mut self, v: bool) -> Self {
38 self.trim_offsets = v;
39 self
40 }
41
42 #[must_use]
43 pub fn add_prefix_space(mut self, v: bool) -> Self {
44 self.add_prefix_space = v;
45 self
46 }
47}
48
49impl PostProcessor for RobertaProcessing {
50 fn added_tokens(&self, is_pair: bool) -> usize {
51 if is_pair {
52 4
53 } else {
54 2
55 }
56 }
57
58 fn process_encodings(
59 &self,
60 mut encodings: Vec<Encoding>,
61 add_special_tokens: bool,
62 ) -> Result<Vec<Encoding>> {
63 if self.trim_offsets {
64 for encoding in encodings.iter_mut() {
65 process_offsets(encoding, self.add_prefix_space);
66 encoding
67 .get_overflowing_mut()
68 .iter_mut()
69 .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
70 }
71 }
72
73 encodings
75 .iter_mut()
76 .for_each(|encoding| encoding.set_type_ids(vec![0; encoding.len()]));
77
78 if !add_special_tokens {
79 return Ok(encodings);
80 }
81
82 let encodings: Vec<Encoding> = encodings
83 .iter_mut()
84 .enumerate()
85 .map(|(i, encoding)| {
86 if i == 0 {
87 let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
88 let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
89 let tokens = [
90 &[self.cls.0.clone()],
91 encoding.get_tokens(),
92 &[self.sep.0.clone()],
93 ]
94 .concat();
95 let words = [&[None], encoding.get_word_ids(), &[None]].concat();
96 let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
97 let special_tokens =
98 [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
99 let attention_mask = vec![1; ids.len()];
100
101 let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
104 Encoding::new(
105 ids,
106 type_ids,
107 tokens,
108 words,
109 offsets,
110 special_tokens,
111 attention_mask,
112 encoding
113 .take_overflowing()
114 .into_iter()
115 .map(|encoding| {
116 let ids =
117 [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
118 let type_ids = vec![0; encoding.get_ids().len() + 2];
119 let tokens = [
120 &[self.cls.0.clone()],
121 encoding.get_tokens(),
122 &[self.sep.0.clone()],
123 ]
124 .concat();
125 let words = [&[None], encoding.get_word_ids(), &[None]].concat();
126 let offsets =
127 [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
128 let special_tokens =
129 [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
130 .concat();
131 let attention_mask = vec![1; ids.len()];
132
133 let sequence_ranges =
136 HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
137 Encoding::new(
138 ids,
139 type_ids,
140 tokens,
141 words,
142 offsets,
143 special_tokens,
144 attention_mask,
145 vec![],
146 sequence_ranges,
147 )
148 })
149 .collect(),
150 sequence_ranges,
151 )
152 } else {
153 let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
154 let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
155 let pair_tokens = [
156 &[self.sep.0.clone()],
157 encoding.get_tokens(),
158 &[self.sep.0.clone()],
159 ]
160 .concat();
161 let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
162 let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
163 let pair_special_tokens =
164 [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
165 let pair_attention_mask = vec![1; pair_ids.len()];
166
167 let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
170 Encoding::new(
171 pair_ids,
172 pair_type_ids,
173 pair_tokens,
174 pair_words,
175 pair_offsets,
176 pair_special_tokens,
177 pair_attention_mask,
178 encoding
179 .take_overflowing()
180 .into_iter()
181 .map(|encoding| {
182 let pair_ids =
183 [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
184 let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
185 let pair_tokens = [
186 &[self.sep.0.clone()],
187 encoding.get_tokens(),
188 &[self.sep.0.clone()],
189 ]
190 .concat();
191 let pair_words =
192 [&[None], encoding.get_word_ids(), &[None]].concat();
193 let pair_offsets =
194 [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
195 let pair_special_tokens =
196 [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]]
197 .concat();
198 let pair_attention_mask = vec![1; pair_ids.len()];
199
200 let pair_sequence_ranges =
203 HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
204 Encoding::new(
205 pair_ids,
206 pair_type_ids,
207 pair_tokens,
208 pair_words,
209 pair_offsets,
210 pair_special_tokens,
211 pair_attention_mask,
212 vec![],
213 pair_sequence_ranges,
214 )
215 })
216 .collect(),
217 pair_sequence_ranges,
218 )
219 }
220 })
221 .collect();
222
223 Ok(encodings)
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip the default processor through serde and check the exact
    // tagged-JSON shape (`"type":"RobertaProcessing"` from #[serde(tag)]).
    #[test]
    fn serde() {
        let roberta = RobertaProcessing::default();
        let roberta_r = r#"{
            "type":"RobertaProcessing",
            "sep":["</s>",2],
            "cls":["<s>",0],
            "trim_offsets":true,
            "add_prefix_space":true
        }"#
        .replace(char::is_whitespace, "");
        assert_eq!(serde_json::to_string(&roberta).unwrap(), roberta_r);
        assert_eq!(
            serde_json::from_str::<RobertaProcessing>(&roberta_r).unwrap(),
            roberta
        );
    }

    #[test]
    fn roberta_processing() {
        let processor = RobertaProcessing::default();
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 4);

        use crate::Token;
        let encoding = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let pair = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);

        // Single sequence: <s> Hello there </s>, all type ids 0, content
        // recorded as sequence 0 spanning tokens 1..3.
        let single_encoding = processor.process(encoding.clone(), None, true).unwrap();
        assert_eq!(
            single_encoding,
            Encoding::new(
                vec![0, 12, 14, 2],
                vec![0, 0, 0, 0],
                vec!["<s>".into(), "Hello".into(), "there".into(), "</s>".into()],
                vec![None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0)],
                vec![1, 0, 0, 1],
                vec![1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3)]),
            )
        );
        // Special tokens belong to no sequence.
        assert_eq!(single_encoding.token_to_sequence(2), Some(0));
        assert_eq!(single_encoding.token_to_sequence(3), None);

        // Pair: <s> Hello there </s></s> pair </s>; the second sequence is
        // wrapped by </s> on both sides and spans tokens 5..6.
        let pair_encoding = processor
            .process(encoding.clone(), Some(pair.clone()), true)
            .unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![0, 12, 14, 2, 2, 15, 2],
                vec![0, 0, 0, 0, 0, 0, 0],
                vec![
                    "<s>".into(),
                    "Hello".into(),
                    "there".into(),
                    "</s>".into(),
                    "</s>".into(),
                    "pair".into(),
                    "</s>".into()
                ],
                vec![None, None, None, None, None, None, None],
                vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 0), (0, 4), (0, 0)],
                vec![1, 0, 0, 1, 1, 0, 1],
                vec![1, 1, 1, 1, 1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 1..3), (1, 5..6)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(2), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(3), None);
        assert_eq!(pair_encoding.token_to_sequence(4), None);
        assert_eq!(pair_encoding.token_to_sequence(5), Some(1));
        assert_eq!(pair_encoding.token_to_sequence(6), None);

        // With add_special_tokens = false the sequences are merged as-is:
        // no special tokens, sequence ranges cover the raw tokens.
        let pair_encoding = processor.process(encoding, Some(pair), false).unwrap();
        assert_eq!(
            pair_encoding,
            Encoding::new(
                vec![12, 14, 15],
                vec![0, 0, 0],
                vec!["Hello".into(), "there".into(), "pair".into(),],
                vec![None, None, None],
                vec![(0, 5), (6, 11), (0, 4)],
                vec![0, 0, 0],
                vec![1, 1, 1],
                vec![],
                HashMap::from_iter(vec![(0, 0..2), (1, 2..3)]),
            )
        );
        assert_eq!(pair_encoding.token_to_sequence(0), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(1), Some(0));
        assert_eq!(pair_encoding.token_to_sequence(2), Some(1));
    }
}
332}