use std::collections::HashMap;

use super::{DecodeError, EncodeError, Model};
use crate::tokenizer::TokenId;

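/// WordPiece tokenization model.
///
/// Words are split using a greedy longest-match-first strategy: the longest
/// vocabulary entry that is a prefix of the remaining input is emitted as the
/// next token, and continuation pieces are looked up with a subword prefix
/// ("##") prepended. This is the scheme used by BERT-style tokenizers.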
#[derive(Clone)]
pub struct WordPiece {
    token_to_id: HashMap<String, TokenId>,
    id_to_token: HashMap<TokenId, String>,
    subword_prefix: String,
    max_word_len: usize,
}

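/// Configuration options for a [`WordPiece`] model.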
#[derive(Debug, Default, Clone)]
pub struct WordPieceOptions {
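    /// Maximum length, in characters, of words that can be tokenized. Words
    /// longer than this are encoded as the unknown token (`[UNK]`).
    ///
    /// Defaults to 100 if not set.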
    pub max_word_len: Option<usize>,
}

impl WordPiece {
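    /// Construct a WordPiece model from a vocabulary mapping token strings
    /// to IDs.
    ///
    /// A sketch of typical usage (module paths assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use std::collections::HashMap;
    ///
    /// let vocab: HashMap<String, u32> = [("fox", 0), ("##es", 1)]
    ///     .into_iter()
    ///     .map(|(tok, id)| (tok.to_string(), id))
    ///     .collect();
    /// let model = WordPiece::from_vocab(vocab, WordPieceOptions::default());
    /// assert_eq!(model.get_token_id("fox"), Some(0));
    /// ```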
    pub fn from_vocab(vocab: HashMap<String, TokenId>, options: WordPieceOptions) -> WordPiece {
        // Build the reverse mapping used for decoding.
        let id_to_token: HashMap<TokenId, String> =
            vocab.iter().map(|(k, v)| (*v, k.to_string())).collect();

        let subword_prefix = "##".to_string();

        WordPiece {
            token_to_id: vocab,
            subword_prefix,
            max_word_len: options.max_word_len.unwrap_or(100),
            id_to_token,
        }
    }
}

impl Model for WordPiece {
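    // Greedy longest-match-first encoding: repeatedly emit the longest vocab
    // entry that is a prefix of the remaining input, looking up continuation
    // pieces with the subword prefix ("##") prepended.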
    fn encode_with_offsets(
        &self,
        word: &str,
        on_token: &mut dyn FnMut(usize, TokenId),
    ) -> Result<(), EncodeError> {
        // Reusable buffer for building "##"-prefixed lookup keys.
        let mut tmp_buf = String::with_capacity(self.max_word_len);
        let mut offset = 0;

        // Emit the `[UNK]` token at the current offset.
        macro_rules! add_unknown_token {
            () => {
                let unknown_token = "[UNK]";
                let unknown_token_id = self
                    .get_token_id(unknown_token)
                    .ok_or_else(|| EncodeError::TokenIdNotFound(unknown_token.to_string()))?;
                on_token(offset, unknown_token_id);
            };
        }

        if word.trim().is_empty() {
            return Ok(());
        }

        // Words longer than `max_word_len` are not split; they map to `[UNK]`.
        if word.chars().count() > self.max_word_len {
            add_unknown_token!();
            return Ok(());
        }

        let mut remainder = word;
        let mut word_tokens = 0;
        while !remainder.is_empty() {
            // Find the longest prefix of `remainder` that is in the vocab.
            let mut len = remainder.len();
            while len > 0 {
                let prefix = if word_tokens > 0 {
                    tmp_buf.clear();
                    tmp_buf.push_str(&self.subword_prefix);
                    tmp_buf.push_str(&remainder[..len]);
                    &tmp_buf[..]
                } else {
                    &remainder[..len]
                };

                if let Some(id) = self.token_to_id.get(prefix) {
                    on_token(offset, *id);
                    // Advance by the number of bytes consumed from `word`.
                    // Using `prefix.len()` here would overcount subword
                    // tokens by the length of the "##" lookup prefix.
                    offset += len;
                    remainder = remainder.split_at(len).1;
                    word_tokens += 1;
                    break;
                } else {
                    // Back off by one (possibly multi-byte) character and retry.
                    let last_char_bytes = prefix.chars().next_back().unwrap().len_utf8();
                    len -= last_char_bytes;
                }
            }

            // No prefix of the remainder is in the vocab.
            if len == 0 {
                add_unknown_token!();
                break;
            }
        }

        Ok(())
    }

    fn get_token_str(&self, id: TokenId) -> Option<String> {
        self.id_to_token.get(&id).cloned()
    }

    fn get_token_id(&self, tok: &str) -> Option<TokenId> {
        self.token_to_id.get(tok).copied()
    }

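    // Note: tokens are joined with spaces and the "##" subword prefix is not
    // merged, so decoding does not exactly invert encoding.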
    fn decode(&self, ids: &[TokenId]) -> Result<String, DecodeError> {
        let token_strings = self.get_tokens(ids)?;
        Ok(token_strings.join(" "))
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use rten_testing::TestCases;

    use crate::models::{WordPiece, WordPieceOptions};
    use crate::normalizers::Normalizer;
    use crate::tokenizer::{Tokenizer, TokenizerOptions};
    use crate::{normalizers, pre_tokenizers};

    /// Create a `Tokenizer` with a `WordPiece` model built from `vocab`,
    /// a BERT pre-tokenizer and an optional normalizer.
    fn create_tokenizer(
        vocab: &[&str],
        normalizer: Option<Box<dyn Normalizer>>,
        options: WordPieceOptions,
    ) -> Tokenizer {
        let vocab: HashMap<_, _> = vocab
            .iter()
            .enumerate()
            .map(|(i, token)| (token.to_string(), i as u32))
            .collect();
        let model = WordPiece::from_vocab(vocab, options);
        let mut tokenizer = Tokenizer::new(
            model,
            TokenizerOptions {
                cls_token: Some("[CLS]"),
                sep_token: Some("[SEP]"),
            },
        )
        .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

        if let Some(normalizer) = normalizer {
            tokenizer = tokenizer.with_normalizer(normalizer);
        }

        tokenizer
    }

    #[test]
    fn test_wordpiece_model() {
        #[derive(Debug)]
        struct Case<'a> {
            text: &'a str,
            tokens: &'a [&'a str],
        }

        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "This", "is", "a", "test", "sequence", "Word", "##Piece",
            "Piece", "of", "pie", ".", "!", "?", "Hey", "Hello", "the", "game", "is", "set", "in",
            "Faerûn",
        ];

        let cases = [
            // Tokens taken whole from the vocab.
            Case {
                text: "This is a test sequence",
                tokens: &["[CLS]", "This", "is", "a", "test", "sequence", "[SEP]"],
            },
            Case {
                text: "Piece of pie",
                tokens: &["[CLS]", "Piece", "of", "pie", "[SEP]"],
            },
            // Word not in the vocab.
            Case {
                text: "This is unknown sequence",
                tokens: &["[CLS]", "This", "is", "[UNK]", "sequence", "[SEP]"],
            },
            // Word split into subwords.
            Case {
                text: "WordPiece",
                tokens: &["[CLS]", "Word", "##Piece", "[SEP]"],
            },
            // Empty input.
            Case {
                text: "",
                tokens: &["[CLS]", "[SEP]"],
            },
            // Punctuation split off by the pre-tokenizer.
            Case {
                text: "Hey! Hello?",
                tokens: &["[CLS]", "Hey", "!", "Hello", "?", "[SEP]"],
            },
            // Word exceeding the default `max_word_len` of 100 characters.
            Case {
                text: &"a".repeat(101),
                tokens: &["[CLS]", "[UNK]", "[SEP]"],
            },
            // Non-ASCII input.
            Case {
                text: "the game is set in Faerûn",
                tokens: &["[CLS]", "the", "game", "is", "set", "in", "Faerûn", "[SEP]"],
            },
        ];

        cases.test_each(|case| {
            let &Case { text, tokens } = case;

            let tokenizer = create_tokenizer(vocab, None, Default::default());
            let encoded = tokenizer.encode(text, None).unwrap();
            assert_eq!(
                tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
                tokens
            );
            assert!(encoded.token_type_ids().all(|ttid| ttid == 0));
        });
    }

    #[test]
    fn test_wordpiece_max_word_len() {
        let vocab = &["[CLS]", "[SEP]", "[UNK]", "foo", "##bar", "##foo"];
        let opts = WordPieceOptions {
            max_word_len: Some(6),
            ..Default::default()
        };
        let tokenizer = create_tokenizer(vocab, None, opts);

        // "foobar" and "foofoo" are within the length limit and can be split
        // into subwords. "foobarfoo" exceeds it and is tokenized as `[UNK]`.
        let text = "foobar foofoo foobarfoo";
        let encoded = tokenizer.encode(text, None).unwrap();

        assert_eq!(
            tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
            &["[CLS]", "foo", "##bar", "foo", "##foo", "[UNK]", "[SEP]"]
        );
    }

    #[test]
    fn test_wordpiece_model_lowercase() {
        #[derive(Debug)]
        struct Case<'a> {
            text: &'a str,
            tokens: &'a [&'a str],
        }

        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "this", "is", "a", "test", "sequence",
        ];

        let cases = [
            // Already lowercase.
            Case {
                text: "this is a test sequence",
                tokens: &["[CLS]", "this", "is", "a", "test", "sequence", "[SEP]"],
            },
            // Uppercase input should be lowercased by the normalizer.
            Case {
                text: "THIS IS A TEST SEQUENCE",
                tokens: &["[CLS]", "this", "is", "a", "test", "sequence", "[SEP]"],
            },
        ];

        cases.test_each(|case| {
            let &Case { text, tokens } = case;

            let normalizer = normalizers::Bert::new(normalizers::BertOptions {
                lowercase: true,
                ..Default::default()
            });
            let tokenizer =
                create_tokenizer(vocab, Some(Box::new(normalizer)), Default::default());

            let encoded = tokenizer.encode(text, None).unwrap();
            assert_eq!(
                tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
                tokens
            );
            assert!(encoded.token_type_ids().all(|ttid| ttid == 0));
        })
    }

    #[test]
    fn test_decode() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            expected: &'a str,
        }

        let cases = [
            Case {
                input: "",
                expected: "[CLS] [SEP]",
            },
            Case {
                input: "this is a test sequence",
                expected: "[CLS] this is a test sequence [SEP]",
            },
            Case {
                input: "THIS IS A TEST SEQUENCE",
                expected: "[CLS] this is a test sequence [SEP]",
            },
        ];

        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "this", "is", "a", "test", "sequence",
        ];

        cases.test_each(|case| {
            let &Case { input, expected } = case;

            let normalizer = normalizers::Bert::new(normalizers::BertOptions {
                lowercase: true,
                ..Default::default()
            });
            let tokenizer =
                create_tokenizer(vocab, Some(Box::new(normalizer)), Default::default());

            let encoded = tokenizer.encode(input, None).unwrap();
            let decoded = tokenizer.decode(encoded.token_ids()).unwrap();
            assert_eq!(decoded, expected);
        })
    }
}