rust_transformers/preprocessing/tokenizer/bert_tokenizer.rs

// Copyright 2018 The Google AI Language Team Authors
// Copyright 2018 The HuggingFace Inc. team.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::preprocessing::tokenizer::base_tokenizer::{MultiThreadedTokenizer, BaseTokenizer, Tokenizer};
use std::sync::Arc;
use crate::preprocessing::tokenizer::tokenization_utils::{tokenize_wordpiece, split_on_special_tokens};
use crate::preprocessing::vocab::base_vocab::Vocab;
use crate::BertVocab;

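/// Tokenizer for BERT models: the input is split on special tokens, run through the shared base
/// tokenizer, and the resulting tokens are further broken down into WordPiece sub-words, all
/// backed by the same `BertVocab`.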
pub struct BertTokenizer {
    vocab: Arc<BertVocab>,
    base_tokenizer: BaseTokenizer<BertVocab>,
}

impl BertTokenizer {
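    /// Creates a `BertTokenizer` from a vocabulary file located at `path`.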
    pub fn from_file(path: &str) -> BertTokenizer {
        let vocab = Arc::new(BertVocab::from_file(path));
        let base_tokenizer = BaseTokenizer::from_existing_vocab(vocab.clone());
        BertTokenizer { vocab, base_tokenizer }
    }

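    /// Creates a `BertTokenizer` from a vocabulary that has already been loaded and wrapped in an `Arc`.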
    pub fn from_existing_vocab(vocab: Arc<BertVocab>) -> BertTokenizer {
        let base_tokenizer = BaseTokenizer::from_existing_vocab(vocab.clone());
        BertTokenizer { vocab, base_tokenizer }
    }
}

impl Tokenizer<BertVocab> for BertTokenizer {

    fn vocab(&self) -> &BertVocab {
        &self.vocab
    }

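    // Splits the text on special tokens first (so markers like "[MASK]" survive intact), runs the
    // base tokenizer over each fragment, then breaks every resulting token into WordPieces; the
    // third argument to `tokenize_wordpiece` caps the word length considered for sub-word splitting.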
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokenized_text: Vec<String> = Vec::with_capacity(text.len());
        let temp_text = split_on_special_tokens(text, self.vocab.as_ref());
        for text in temp_text {
            tokenized_text.extend(self.base_tokenizer.tokenize(text));
        }

        let tokenized_text: Vec<String> = tokenized_text
            .iter()
            .flat_map(|v| tokenize_wordpiece(v.to_owned(), self.vocab.as_ref(), 100))
            .map(|s| s.to_string())
            .collect();
        tokenized_text
    }

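    // Wraps the token ids as [CLS] tokens_1 [SEP] (tokens_2 [SEP] when a second sequence is given)
    // and returns, alongside the ids, the segment ids (0 for the first sequence, 1 for the second)
    // and a special-token mask where 1 marks positions occupied by [CLS] or [SEP].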
    fn build_input_with_special_tokens(&self, tokens_1: Vec<i64>, tokens_2: Option<Vec<i64>>) -> (Vec<i64>, Vec<i8>, Vec<i8>) {
        let mut output: Vec<i64> = vec!();
        let mut token_segment_ids: Vec<i8> = vec!();
        let mut special_tokens_mask: Vec<i8> = vec!();
        special_tokens_mask.push(1);
        special_tokens_mask.extend(vec![0; tokens_1.len()]);
        special_tokens_mask.push(1);
        token_segment_ids.extend(vec![0; tokens_1.len() + 2]);
        output.push(self.vocab.token_to_id(BertVocab::cls_value()));
        output.extend(tokens_1);
        output.push(self.vocab.token_to_id(BertVocab::sep_value()));
        if let Some(add_tokens) = tokens_2 {
            special_tokens_mask.extend(vec![0; add_tokens.len()]);
            special_tokens_mask.push(1);
            token_segment_ids.extend(vec![1; add_tokens.len() + 1]);
            output.extend(add_tokens);
            output.push(self.vocab.token_to_id(BertVocab::sep_value()));
        }
        (output, token_segment_ids, special_tokens_mask)
    }
}

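// The empty impl opts `BertTokenizer` into the multi-threaded API, reusing the trait's default methods.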
impl MultiThreadedTokenizer<BertVocab> for BertTokenizer {}


//==============================
// Unit tests
//==============================
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BertVocab;
    use std::collections::HashMap;
    use crate::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, TokenizedInput};
    use crate::preprocessing::vocab::base_vocab::swap_key_values;

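    // Builds a small in-memory vocabulary (with its matching special-token map) shared by the tests below.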
    fn generate_test_vocab() -> BertVocab {
        let values: HashMap<String, i64> = [
            ("hello".to_owned(), 0),
            ("world".to_owned(), 1),
            ("[UNK]".to_owned(), 2),
            ("!".to_owned(), 3),
            ("[CLS]".to_owned(), 4),
            ("[SEP]".to_owned(), 5),
            ("[MASK]".to_owned(), 6),
            ("中".to_owned(), 7),
            ("华".to_owned(), 8),
            ("人".to_owned(), 9),
            ("[PAD]".to_owned(), 10),
            ("una".to_owned(), 11),
            ("##ffa".to_owned(), 12),
            ("##ble".to_owned(), 13)
        ].iter().cloned().collect();

        let special_values: HashMap<String, i64> = [
            ("[UNK]".to_owned(), 2),
            ("[CLS]".to_owned(), 4),
            ("[SEP]".to_owned(), 5),
            ("[MASK]".to_owned(), 6),
            ("[PAD]".to_owned(), 10)
        ].iter().cloned().collect();

        let indices = swap_key_values(&values);
        let special_indices = swap_key_values(&special_values);

        BertVocab { values, indices, unknown_value: "[UNK]", special_values, special_indices }
    }

    #[test]
    fn test_bert_tokenizer() {
//        Given
        let vocab = Arc::new(generate_test_vocab());
        let bert_tokenizer: BertTokenizer = BertTokenizer::from_existing_vocab(vocab);
        let test_tuples = [
            (
                "Hello [MASK] world!",
                vec!("hello", "[MASK]", "world", "!")
            ),
            (
                "Hello, unaffable world!",
                vec!("hello", "[UNK]", "una", "##ffa", "##ble", "world", "!")
            ),
            (
                "[UNK]中华人民共和国 [PAD] asdf",
                vec!("[UNK]", "中", "华", "人", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[PAD]", "[UNK]")
            )
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<Vec<&str>> = test_tuples.iter().map(|v| v.1.clone()).collect();

//        When & Then
        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(bert_tokenizer.tokenize(*source_text), *expected_result);
        }

        assert_eq!(Tokenizer::tokenize_list(&bert_tokenizer, source_texts.clone()), expected_results);
        assert_eq!(MultiThreadedTokenizer::tokenize_list(&bert_tokenizer, source_texts.clone()), expected_results);
    }

    #[test]
    fn test_encode() {
//        Given
        let vocab = Arc::new(generate_test_vocab());
        let bert_tokenizer: BertTokenizer = BertTokenizer::from_existing_vocab(vocab);
        let truncation_strategy = TruncationStrategy::LongestFirst;
        let test_tuples = [
            (
                "hello[MASK] world!",
                TokenizedInput { token_ids: vec!(4, 0, 6, 1, 3, 5), segment_ids: vec!(0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(1, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                "hello, unaffable world!",
                TokenizedInput { token_ids: vec!(4, 0, 2, 11, 12, 13, 1, 3, 5), segment_ids: vec!(0, 0, 0, 0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(1, 0, 0, 0, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                "[UNK]中华人民共和国 [PAD] asdf",
                TokenizedInput { token_ids: vec!(4, 2, 7, 8, 9, 2, 2, 2, 2, 10, 2, 5), segment_ids: vec!(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            )
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<TokenizedInput> = test_tuples.iter().map(|v| v.1.clone()).collect();

//        When & Then
        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(bert_tokenizer.encode(source_text, None, 128, &truncation_strategy, 0),
                       *expected_result);
        }
        assert_eq!(Tokenizer::encode_list(&bert_tokenizer, source_texts.clone(), 128, &truncation_strategy, 0), expected_results);
        assert_eq!(MultiThreadedTokenizer::encode_list(&bert_tokenizer, source_texts.clone(), 128, &truncation_strategy, 0), expected_results);
    }

    #[test]
    fn test_encode_sentence_pair() {
//        Given
        let vocab = Arc::new(generate_test_vocab());
        let bert_tokenizer: BertTokenizer = BertTokenizer::from_existing_vocab(vocab);
        let truncation_strategy = TruncationStrategy::LongestFirst;
        let test_tuples = [
//            No truncation required
            (
                ("hello world", "This is the second sentence"),
                TokenizedInput { token_ids: vec!(4, 0, 1, 5, 2, 2, 2, 2, 2, 5), segment_ids: vec!(0, 0, 0, 0, 1, 1, 1, 1, 1, 1), special_tokens_mask: vec!(1, 0, 0, 1, 0, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
//            Truncation of sentence 2 (longest)
            (
                ("hello world", "!This is the second sentence!!!"),
                TokenizedInput { token_ids: vec!(4, 0, 1, 5, 3, 2, 2, 2, 2, 5), segment_ids: vec!(0, 0, 0, 0, 1, 1, 1, 1, 1, 1), special_tokens_mask: vec!(1, 0, 0, 1, 0, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 4 }
            ),
//            Truncation of sentence 1 (longest)
            (
                ("[UNK] hello  hello  hello  hello  hello  hello  hello  hello  hello  hello  hello", "!!!"),
                TokenizedInput { token_ids: vec!(4, 2, 0, 0, 0, 5, 3, 3, 3, 5), segment_ids: vec!(0, 0, 0, 0, 0, 0, 1, 1, 1, 1), special_tokens_mask: vec!(1, 0, 0, 0, 0, 1, 0, 0, 0, 1), overflowing_tokens: vec!(0, 0, 0, 0, 0, 0, 0, 0), num_truncated_tokens: 8 }
            ),
//            Truncation of both sentences (longest)
            (
                ("[UNK] hello  hello  hello  hello  hello", "!!!!!!!!"),
                TokenizedInput { token_ids: vec!(4, 2, 0, 0, 5, 3, 3, 3, 3, 5), segment_ids: vec!(0, 0, 0, 0, 0, 1, 1, 1, 1, 1), special_tokens_mask: vec!(1, 0, 0, 0, 1, 0, 0, 0, 0, 1), overflowing_tokens: vec!(0, 0, 0), num_truncated_tokens: 7 }
            )
        ];
        let source_texts: Vec<(&str, &str)> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<TokenizedInput> = test_tuples.iter().map(|v| v.1.clone()).collect();

//        When & Then
        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(bert_tokenizer.encode(source_text.0, Some(source_text.1), 10, &truncation_strategy, 0),
                       *expected_result);
        }
        assert_eq!(Tokenizer::encode_pair_list(&bert_tokenizer, source_texts.clone(), 10, &truncation_strategy, 0), expected_results);
        assert_eq!(MultiThreadedTokenizer::encode_pair_list(&bert_tokenizer, source_texts.clone(), 10, &truncation_strategy, 0), expected_results);
    }
}