tokenizers/processors/roberta.rs

1use crate::processors::byte_level::process_offsets;
2use crate::tokenizer::{Encoding, PostProcessor, Result};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::iter::FromIterator;
6
7#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
8#[serde(tag = "type")]
9pub struct RobertaProcessing {
10    sep: (String, u32),
11    cls: (String, u32),
12    trim_offsets: bool,
13    add_prefix_space: bool,
14}
15
16impl Default for RobertaProcessing {
17    fn default() -> Self {
18        Self {
19            sep: ("</s>".into(), 2),
20            cls: ("<s>".into(), 0),
21            trim_offsets: true,
22            add_prefix_space: true,
23        }
24    }
25}
26
27impl RobertaProcessing {
28    pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
29        Self {
30            sep,
31            cls,
32            ..Default::default()
33        }
34    }
35
36    #[must_use]
37    pub fn trim_offsets(mut self, v: bool) -> Self {
38        self.trim_offsets = v;
39        self
40    }
41
42    #[must_use]
43    pub fn add_prefix_space(mut self, v: bool) -> Self {
44        self.add_prefix_space = v;
45        self
46    }
47}
48
49impl PostProcessor for RobertaProcessing {
50    fn added_tokens(&self, is_pair: bool) -> usize {
51        if is_pair {
52            4
53        } else {
54            2
55        }
56    }
57
58    fn process_encodings(
59        &self,
60        mut encodings: Vec<Encoding>,
61        add_special_tokens: bool,
62    ) -> Result<Vec<Encoding>> {
63        if self.trim_offsets {
64            for encoding in encodings.iter_mut() {
65                process_offsets(encoding, self.add_prefix_space);
66                encoding
67                    .get_overflowing_mut()
68                    .iter_mut()
69                    .for_each(|encoding| process_offsets(encoding, self.add_prefix_space));
70            }
71        }
72
73        // Roberta is weird, and every encoding is type_id=0.
74        encodings
75            .iter_mut()
76            .for_each(|encoding| encoding.set_type_ids(vec![0; encoding.len()]));
77
78        if !add_special_tokens {
79            return Ok(encodings);
80        }
81
82        let encodings: Vec<Encoding> = encodings
83            .iter_mut()
84            .enumerate()
85            .map(|(i, encoding)| {
86                if i == 0 {
87                    let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
88                    let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
89                    let tokens = [
90                        &[self.cls.0.clone()],
91                        encoding.get_tokens(),
92                        &[self.sep.0.clone()],
93                    ]
94                    .concat();
95                    let words = [&[None], encoding.get_word_ids(), &[None]].concat();
96                    let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
97                    let special_tokens =
98                        [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
99                    let attention_mask = vec![1; ids.len()];
100
101                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
102                    // the special tokens.
103                    let sequence_ranges = HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
104                    Encoding::new(
105                        ids,
106                        type_ids,
107                        tokens,
108                        words,
109                        offsets,
110                        special_tokens,
111                        attention_mask,
112                        encoding
113                            .take_overflowing()
114                            .into_iter()
115                            .map(|encoding| {
116                                let ids =
117                                    [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
118                                let type_ids = vec![0; encoding.get_ids().len() + 2];
119                                let tokens = [
120                                    &[self.cls.0.clone()],
121                                    encoding.get_tokens(),
122                                    &[self.sep.0.clone()],
123                                ]
124                                .concat();
125                                let words = [&[None], encoding.get_word_ids(), &[None]].concat();
126                                let offsets =
127                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
128                                let special_tokens =
129                                    [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]]
130                                        .concat();
131                                let attention_mask = vec![1; ids.len()];
132
133                                // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't
134                                // contain the special tokens.
135                                let sequence_ranges =
136                                    HashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
137                                Encoding::new(
138                                    ids,
139                                    type_ids,
140                                    tokens,
141                                    words,
142                                    offsets,
143                                    special_tokens,
144                                    attention_mask,
145                                    vec![],
146                                    sequence_ranges,
147                                )
148                            })
149                            .collect(),
150                        sequence_ranges,
151                    )
152                } else {
153                    let pair_ids = [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
154                    let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
155                    let pair_tokens = [
156                        &[self.sep.0.clone()],
157                        encoding.get_tokens(),
158                        &[self.sep.0.clone()],
159                    ]
160                    .concat();
161                    let pair_words = [&[None], encoding.get_word_ids(), &[None]].concat();
162                    let pair_offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
163                    let pair_special_tokens =
164                        [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]].concat();
165                    let pair_attention_mask = vec![1; pair_ids.len()];
166
167                    // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
168                    // the special tokens.
169                    let pair_sequence_ranges = HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
170                    Encoding::new(
171                        pair_ids,
172                        pair_type_ids,
173                        pair_tokens,
174                        pair_words,
175                        pair_offsets,
176                        pair_special_tokens,
177                        pair_attention_mask,
178                        encoding
179                            .take_overflowing()
180                            .into_iter()
181                            .map(|encoding| {
182                                let pair_ids =
183                                    [&[self.sep.1], encoding.get_ids(), &[self.sep.1]].concat();
184                                let pair_type_ids = vec![0; encoding.get_ids().len() + 2];
185                                let pair_tokens = [
186                                    &[self.sep.0.clone()],
187                                    encoding.get_tokens(),
188                                    &[self.sep.0.clone()],
189                                ]
190                                .concat();
191                                let pair_words =
192                                    [&[None], encoding.get_word_ids(), &[None]].concat();
193                                let pair_offsets =
194                                    [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
195                                let pair_special_tokens =
196                                    [&[1], &vec![0u32; encoding.get_type_ids().len()][..], &[1]]
197                                        .concat();
198                                let pair_attention_mask = vec![1; pair_ids.len()];
199
200                                // For compatibility with `TemplateProcessing`, the sequence_ranges
201                                // shouldn't contain the special tokens.
202                                let pair_sequence_ranges =
203                                    HashMap::from_iter(vec![(1, 1..pair_ids.len() - 1)]);
204                                Encoding::new(
205                                    pair_ids,
206                                    pair_type_ids,
207                                    pair_tokens,
208                                    pair_words,
209                                    pair_offsets,
210                                    pair_special_tokens,
211                                    pair_attention_mask,
212                                    vec![],
213                                    pair_sequence_ranges,
214                                )
215                            })
216                            .collect(),
217                        pair_sequence_ranges,
218                    )
219                }
220            })
221            .collect();
222
223        Ok(encodings)
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Token;

    /// Encoding of the first test sequence, `Hello there`.
    fn first_sequence() -> Encoding {
        Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        )
    }

    /// Encoding of the second test sequence, `pair`.
    fn second_sequence() -> Encoding {
        Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0)
    }

    #[test]
    fn serde() {
        // Expected JSON for the default processor, minified before comparison.
        let expected_json = r#"{
            "type":"RobertaProcessing",
            "sep":["</s>",2],
            "cls":["<s>",0],
            "trim_offsets":true,
            "add_prefix_space":true
        }"#
        .replace(char::is_whitespace, "");
        let processor = RobertaProcessing::default();

        let serialized = serde_json::to_string(&processor).unwrap();
        assert_eq!(serialized, expected_json);

        let deserialized: RobertaProcessing = serde_json::from_str(&expected_json).unwrap();
        assert_eq!(deserialized, processor);
    }

    #[test]
    fn roberta_processing() {
        let roberta = RobertaProcessing::default();
        assert_eq!(roberta.added_tokens(false), 2);
        assert_eq!(roberta.added_tokens(true), 4);

        // Single sequence: `<s> Hello there </s>`.
        let single = roberta.process(first_sequence(), None, true).unwrap();
        let expected_single = Encoding::new(
            vec![0, 12, 14, 2],
            vec![0, 0, 0, 0],
            vec!["<s>".into(), "Hello".into(), "there".into(), "</s>".into()],
            vec![None, None, None, None],
            vec![(0, 0), (0, 5), (6, 11), (0, 0)],
            vec![1, 0, 0, 1],
            vec![1, 1, 1, 1],
            vec![],
            HashMap::from_iter(vec![(0, 1..3)]),
        );
        assert_eq!(single, expected_single);
        assert_eq!(single.token_to_sequence(2), Some(0));
        assert_eq!(single.token_to_sequence(3), None);

        // Pair: `<s> Hello there </s></s> pair </s>`.
        let pair = roberta
            .process(first_sequence(), Some(second_sequence()), true)
            .unwrap();
        let expected_pair = Encoding::new(
            vec![0, 12, 14, 2, 2, 15, 2],
            vec![0, 0, 0, 0, 0, 0, 0],
            vec![
                "<s>".into(),
                "Hello".into(),
                "there".into(),
                "</s>".into(),
                "</s>".into(),
                "pair".into(),
                "</s>".into(),
            ],
            vec![None, None, None, None, None, None, None],
            vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 0), (0, 4), (0, 0)],
            vec![1, 0, 0, 1, 1, 0, 1],
            vec![1, 1, 1, 1, 1, 1, 1],
            vec![],
            HashMap::from_iter(vec![(0, 1..3), (1, 5..6)]),
        );
        assert_eq!(pair, expected_pair);
        for &(token, sequence) in &[(2, Some(0)), (3, None), (4, None), (5, Some(1)), (6, None)] {
            assert_eq!(pair.token_to_sequence(token), sequence);
        }

        // Pair without special tokens: only the raw tokens, all with type id 0.
        let raw = roberta
            .process(first_sequence(), Some(second_sequence()), false)
            .unwrap();
        let expected_raw = Encoding::new(
            vec![12, 14, 15],
            vec![0, 0, 0],
            vec!["Hello".into(), "there".into(), "pair".into()],
            vec![None, None, None],
            vec![(0, 5), (6, 11), (0, 4)],
            vec![0, 0, 0],
            vec![1, 1, 1],
            vec![],
            HashMap::from_iter(vec![(0, 0..2), (1, 2..3)]),
        );
        assert_eq!(raw, expected_raw);
        for &(token, sequence) in &[(0, Some(0)), (1, Some(0)), (2, Some(1))] {
            assert_eq!(raw.token_to_sequence(token), sequence);
        }
    }
}