use crate::preprocessing::tokenizer::base_tokenizer::{MultiThreadedTokenizer, BaseTokenizer, Tokenizer};
use std::sync::Arc;
use crate::preprocessing::tokenizer::tokenization_utils::{tokenize_wordpiece, split_on_special_tokens};
use crate::preprocessing::vocab::base_vocab::Vocab;
use crate::BertVocab;

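/// BERT tokenizer: runs a `BaseTokenizer` pass (whitespace and punctuation splitting)
/// followed by WordPiece sub-token splitting, both driven by a shared `BertVocab`.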
pub struct BertTokenizer {
    vocab: Arc<BertVocab>,
    base_tokenizer: BaseTokenizer<BertVocab>,
}

impl BertTokenizer {
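    /// Creates a `BertTokenizer` by loading a vocabulary from the file at `path`.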
    pub fn from_file(path: &str) -> BertTokenizer {
        let vocab = Arc::new(BertVocab::from_file(path));
        let base_tokenizer = BaseTokenizer::from_existing_vocab(vocab.clone());
        BertTokenizer { vocab, base_tokenizer }
    }

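    /// Creates a `BertTokenizer` from a vocabulary that has already been loaded and shared via `Arc`.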
    pub fn from_existing_vocab(vocab: Arc<BertVocab>) -> BertTokenizer {
        let base_tokenizer = BaseTokenizer::from_existing_vocab(vocab.clone());
        BertTokenizer { vocab, base_tokenizer }
    }
}

impl Tokenizer<BertVocab> for BertTokenizer {

    fn vocab(&self) -> &BertVocab {
        &self.vocab
    }

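    /// Splits the input on special tokens, runs the base tokenizer on each fragment,
    /// then applies WordPiece to every resulting token (with a 100-character cap per word).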
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokenized_text: Vec<String> = Vec::with_capacity(text.len());
        let temp_text = split_on_special_tokens(text, self.vocab.as_ref());
        for text in temp_text {
            tokenized_text.extend(self.base_tokenizer.tokenize(text));
        }

        let tokenized_text: Vec<String> = tokenized_text
            .iter()
            .flat_map(|v| tokenize_wordpiece(v.to_owned(), self.vocab.as_ref(), 100))
            .map(|s| s.to_string())
            .collect();
        tokenized_text
    }

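    /// Builds `[CLS] tokens_1 [SEP]` (and `tokens_2 [SEP]` when a second sequence is provided),
    /// together with segment ids (0 for the first sequence, 1 for the second) and a special-tokens
    /// mask marking the inserted `[CLS]`/`[SEP]` positions with 1.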
    fn build_input_with_special_tokens(&self, tokens_1: Vec<i64>, tokens_2: Option<Vec<i64>>) -> (Vec<i64>, Vec<i8>, Vec<i8>) {
        let mut output: Vec<i64> = vec!();
        let mut token_segment_ids: Vec<i8> = vec!();
        let mut special_tokens_mask: Vec<i8> = vec!();
        special_tokens_mask.push(1);
        special_tokens_mask.extend(vec![0; tokens_1.len()]);
        special_tokens_mask.push(1);
        token_segment_ids.extend(vec![0; tokens_1.len() + 2]);
        output.push(self.vocab.token_to_id(BertVocab::cls_value()));
        output.extend(tokens_1);
        output.push(self.vocab.token_to_id(BertVocab::sep_value()));
        if let Some(add_tokens) = tokens_2 {
            special_tokens_mask.extend(vec![0; add_tokens.len()]);
            special_tokens_mask.push(1);
            token_segment_ids.extend(vec![1; add_tokens.len() + 1]);
            output.extend(add_tokens);
            output.push(self.vocab.token_to_id(BertVocab::sep_value()));
        }
        (output, token_segment_ids, special_tokens_mask)
    }
}

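// All multi-threaded methods come from the trait's default implementations.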
impl MultiThreadedTokenizer<BertVocab> for BertTokenizer {}


#[cfg(test)]
mod tests {
    use super::*;
    use crate::BertVocab;
    use std::collections::HashMap;
    use crate::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, TokenizedInput};
    use crate::preprocessing::vocab::base_vocab::swap_key_values;

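    // Builds a small in-memory vocabulary (including WordPiece continuation tokens such as "##ffa")
    // so the tests do not depend on an external vocabulary file.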
    fn generate_test_vocab() -> BertVocab {
        let values: HashMap<String, i64> = [
            ("hello".to_owned(), 0),
            ("world".to_owned(), 1),
            ("[UNK]".to_owned(), 2),
            ("!".to_owned(), 3),
            ("[CLS]".to_owned(), 4),
            ("[SEP]".to_owned(), 5),
            ("[MASK]".to_owned(), 6),
            ("中".to_owned(), 7),
            ("华".to_owned(), 8),
            ("人".to_owned(), 9),
            ("[PAD]".to_owned(), 10),
            ("una".to_owned(), 11),
            ("##ffa".to_owned(), 12),
            ("##ble".to_owned(), 13)
        ].iter().cloned().collect();

        let special_values: HashMap<String, i64> = [
            ("[UNK]".to_owned(), 2),
            ("[CLS]".to_owned(), 4),
            ("[SEP]".to_owned(), 5),
            ("[MASK]".to_owned(), 6),
            ("[PAD]".to_owned(), 10)
        ].iter().cloned().collect();

        let indices = swap_key_values(&values);
        let special_indices = swap_key_values(&special_values);

        BertVocab { values, indices, unknown_value: "[UNK]", special_values, special_indices }
    }

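    // Tokenization only: special tokens are kept intact, out-of-vocabulary words map to [UNK],
    // and "unaffable" is split into WordPiece sub-tokens.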
    #[test]
    fn test_bert_tokenizer() {
        let vocab = Arc::new(generate_test_vocab());
        let bert_tokenizer: BertTokenizer = BertTokenizer::from_existing_vocab(vocab);
        let test_tuples = [
            (
                "Hello [MASK] world!",
                vec!("hello", "[MASK]", "world", "!")
            ),
            (
                "Hello, unaffable world!",
                vec!("hello", "[UNK]", "una", "##ffa", "##ble", "world", "!")
            ),
            (
                "[UNK]中华人民共和国 [PAD] asdf",
                vec!("[UNK]", "中", "华", "人", "[UNK]", "[UNK]", "[UNK]", "[UNK]", "[PAD]", "[UNK]")
            )
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<Vec<&str>> = test_tuples.iter().map(|v| v.1.clone()).collect();

        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(bert_tokenizer.tokenize(*source_text), *expected_result);
        }

        assert_eq!(Tokenizer::tokenize_list(&bert_tokenizer, source_texts.clone()), expected_results);
        assert_eq!(MultiThreadedTokenizer::tokenize_list(&bert_tokenizer, source_texts.clone()), expected_results);
    }

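    // Single-sentence encoding: ids wrapped in [CLS] ... [SEP], all-zero segment ids,
    // and a special-tokens mask flagging the two added tokens.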
    #[test]
    fn test_encode() {
        let vocab = Arc::new(generate_test_vocab());
        let bert_tokenizer: BertTokenizer = BertTokenizer::from_existing_vocab(vocab);
        let truncation_strategy = TruncationStrategy::LongestFirst;
        let test_tuples = [
            (
                "hello[MASK] world!",
                TokenizedInput { token_ids: vec!(4, 0, 6, 1, 3, 5), segment_ids: vec!(0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(1, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                "hello, unaffable world!",
                TokenizedInput { token_ids: vec!(4, 0, 2, 11, 12, 13, 1, 3, 5), segment_ids: vec!(0, 0, 0, 0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(1, 0, 0, 0, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                "[UNK]中华人民共和国 [PAD] asdf",
                TokenizedInput { token_ids: vec!(4, 2, 7, 8, 9, 2, 2, 2, 2, 10, 2, 5), segment_ids: vec!(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), special_tokens_mask: vec!(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            )
        ];
        let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<TokenizedInput> = test_tuples.iter().map(|v| v.1.clone()).collect();

        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(bert_tokenizer.encode(source_text, None, 128, &truncation_strategy, 0),
                       *expected_result);
        }
        assert_eq!(Tokenizer::encode_list(&bert_tokenizer, source_texts.clone(), 128, &truncation_strategy, 0), expected_results);
        assert_eq!(MultiThreadedTokenizer::encode_list(&bert_tokenizer, source_texts.clone(), 128, &truncation_strategy, 0), expected_results);
    }

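    // Sentence-pair encoding truncated to a total length of 10, checking segment ids,
    // overflowing tokens and the number of truncated tokens.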
    #[test]
    fn test_encode_sentence_pair() {
        let vocab = Arc::new(generate_test_vocab());
        let bert_tokenizer: BertTokenizer = BertTokenizer::from_existing_vocab(vocab);
        let truncation_strategy = TruncationStrategy::LongestFirst;
        let test_tuples = [
            (
                ("hello world", "This is the second sentence"),
                TokenizedInput { token_ids: vec!(4, 0, 1, 5, 2, 2, 2, 2, 2, 5), segment_ids: vec!(0, 0, 0, 0, 1, 1, 1, 1, 1, 1), special_tokens_mask: vec!(1, 0, 0, 1, 0, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 0 }
            ),
            (
                ("hello world", "!This is the second sentence!!!"),
                TokenizedInput { token_ids: vec!(4, 0, 1, 5, 3, 2, 2, 2, 2, 5), segment_ids: vec!(0, 0, 0, 0, 1, 1, 1, 1, 1, 1), special_tokens_mask: vec!(1, 0, 0, 1, 0, 0, 0, 0, 0, 1), overflowing_tokens: vec!(), num_truncated_tokens: 4 }
            ),
            (
                ("[UNK] hello hello hello hello hello hello hello hello hello hello hello", "!!!"),
                TokenizedInput { token_ids: vec!(4, 2, 0, 0, 0, 5, 3, 3, 3, 5), segment_ids: vec!(0, 0, 0, 0, 0, 0, 1, 1, 1, 1), special_tokens_mask: vec!(1, 0, 0, 0, 0, 1, 0, 0, 0, 1), overflowing_tokens: vec!(0, 0, 0, 0, 0, 0, 0, 0), num_truncated_tokens: 8 }
            ),
            (
                ("[UNK] hello hello hello hello hello", "!!!!!!!!"),
                TokenizedInput { token_ids: vec!(4, 2, 0, 0, 5, 3, 3, 3, 3, 5), segment_ids: vec!(0, 0, 0, 0, 0, 1, 1, 1, 1, 1), special_tokens_mask: vec!(1, 0, 0, 0, 1, 0, 0, 0, 0, 1), overflowing_tokens: vec!(0, 0, 0), num_truncated_tokens: 7 }
            )
        ];
        let source_texts: Vec<(&str, &str)> = test_tuples.iter().map(|v| v.0).collect();
        let expected_results: Vec<TokenizedInput> = test_tuples.iter().map(|v| v.1.clone()).collect();

        for (source_text, expected_result) in test_tuples.iter() {
            assert_eq!(bert_tokenizer.encode(source_text.0, Some(source_text.1), 10, &truncation_strategy, 0),
                       *expected_result);
        }
        assert_eq!(Tokenizer::encode_pair_list(&bert_tokenizer, source_texts.clone(), 10, &truncation_strategy, 0), expected_results);
        assert_eq!(MultiThreadedTokenizer::encode_pair_list(&bert_tokenizer, source_texts.clone(), 10, &truncation_strategy, 0), expected_results);
    }
}