// rten_text/models/wordpiece.rs

1use std::collections::HashMap;
2
3use super::{DecodeError, EncodeError, Model};
4use crate::tokenizer::TokenId;
5
6/// WordPiece tokenizer [^1] used by BERT [^2] models.
7///
8/// [^1]: Schuster, Mike, and Kaisuke Nakajima. "Japanese and korean voice
9///       search." 2012 IEEE international conference on acoustics, speech and signal
10///       processing (ICASSP). IEEE, 2012. Accessed at
11///       <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/37842.pdf>
12///
13/// [^2]: Devlin, Jacob, et al. "Bert: Pre-training of deep bidirectional
14///       transformers for language understanding." arXiv preprint arXiv:1810.04805
15///       (2018). <https://arxiv.org/abs/1810.04805>
#[derive(Clone)]
pub struct WordPiece {
    // Map from token string (word, or subword-prefixed continuation piece)
    // to its token ID. Used during encoding.
    token_to_id: HashMap<String, TokenId>,
    // Inverse of `token_to_id`; used to map IDs back to strings when decoding.
    id_to_token: HashMap<TokenId, String>,
    // Prefix prepended to continuation pieces when looking them up in the
    // vocab. Set to "##" in `from_vocab`.
    subword_prefix: String,
    // Maximum word length in `char`s. Longer words are encoded as `[UNK]`.
    max_word_len: usize,
}
23
24/// Configuration for a [`WordPiece`] tokenizer.
#[derive(Debug, Default, Clone)]
pub struct WordPieceOptions {
    /// The maximum length of words that can be tokenized, measured in
    /// `char`s. Any words longer than this are tokenized as `[UNK]`.
    ///
    /// Defaults to 100.
    pub max_word_len: Option<usize>,
}
33
34impl WordPiece {
35    /// Construct a WordPiece tokenizer from a vocabulary.
36    ///
37    /// `vocab` is a mapping from word piece to token ID.
38    pub fn from_vocab(vocab: HashMap<String, TokenId>, options: WordPieceOptions) -> WordPiece {
39        let id_to_token: HashMap<TokenId, String> =
40            vocab.iter().map(|(k, v)| (*v, k.to_string())).collect();
41
42        let subword_prefix = "##".to_string();
43
44        WordPiece {
45            token_to_id: vocab,
46            subword_prefix,
47            max_word_len: options.max_word_len.unwrap_or(100),
48            id_to_token,
49        }
50    }
51}
52
53impl Model for WordPiece {
54    fn encode_with_offsets(
55        &self,
56        word: &str,
57        on_token: &mut dyn FnMut(usize, TokenId),
58    ) -> Result<(), EncodeError> {
59        let mut tmp_buf = String::with_capacity(self.max_word_len);
60        let mut offset = 0;
61
62        macro_rules! add_unknown_token {
63            () => {
64                let unknown_token = "[UNK]";
65                let unknown_token_id = self
66                    .get_token_id(unknown_token)
67                    .ok_or_else(|| EncodeError::TokenIdNotFound(unknown_token.to_string()))?;
68                on_token(offset, unknown_token_id);
69            };
70        }
71
72        if word.trim().is_empty() {
73            return Ok(());
74        }
75
76        if word.chars().count() > self.max_word_len {
77            add_unknown_token!();
78            return Ok(());
79        }
80
81        let mut remainder = word;
82        let mut word_tokens = 0;
83        while !remainder.is_empty() {
84            // Find longest prefix of `remainder` that is in the vocab.
85            let mut len = remainder.len();
86            while len > 0 {
87                let prefix = if word_tokens > 0 {
88                    tmp_buf.clear();
89                    tmp_buf.push_str(&self.subword_prefix);
90                    tmp_buf.push_str(&remainder[..len]);
91                    &tmp_buf[..]
92                } else {
93                    &remainder[..len]
94                };
95
96                if let Some(id) = self.token_to_id.get(prefix) {
97                    on_token(offset, *id);
98                    offset += prefix.len();
99                    remainder = remainder.split_at(len).1;
100                    word_tokens += 1;
101                    break;
102                } else {
103                    let last_char_bytes = prefix.chars().next_back().unwrap().len_utf8();
104                    len -= last_char_bytes;
105                }
106            }
107
108            if len == 0 {
109                add_unknown_token!();
110                break;
111            }
112        }
113
114        Ok(())
115    }
116
117    fn get_token_str(&self, id: TokenId) -> Option<String> {
118        self.id_to_token.get(&id).cloned()
119    }
120
121    fn get_token_id(&self, tok: &str) -> Option<TokenId> {
122        self.token_to_id.get(tok).copied()
123    }
124
125    fn decode(&self, ids: &[TokenId]) -> Result<String, DecodeError> {
126        let token_strings = self.get_tokens(ids)?;
127        Ok(token_strings.join(" "))
128    }
129}
130
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use rten_testing::TestCases;

    use crate::models::{WordPiece, WordPieceOptions};
    use crate::normalizers::Normalizer;
    use crate::tokenizer::{Tokenizer, TokenizerOptions};
    use crate::{normalizers, pre_tokenizers};

    /// Build a tokenizer with a WordPiece model where each entry of `vocab`
    /// gets its index as the token ID.
    fn create_tokenizer(
        vocab: &[&str],
        normalizer: Option<Box<dyn Normalizer>>,
        options: WordPieceOptions,
    ) -> Tokenizer {
        let vocab: HashMap<_, _> = vocab
            .iter()
            .enumerate()
            .map(|(i, token)| (token.to_string(), i as u32))
            .collect();
        let model = WordPiece::from_vocab(vocab, options);
        let mut tokenizer = Tokenizer::new(
            model,
            TokenizerOptions {
                cls_token: Some("[CLS]"),
                sep_token: Some("[SEP]"),
            },
        )
        .with_pre_tokenizer(Box::new(pre_tokenizers::Bert::new()));

        if let Some(normalizer) = normalizer {
            tokenizer = tokenizer.with_normalizer(normalizer);
        }

        tokenizer
    }

    #[test]
    fn test_wordpiece_model() {
        #[derive(Debug)]
        struct Case<'a> {
            text: &'a str,
            tokens: &'a [&'a str],
        }

        // Note: entries must be unique. A duplicate entry would silently
        // overwrite the earlier token's ID in the vocab `HashMap`.
        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "This", "is", "a", "test", "sequence", "Word", "##Piece",
            "Piece", "of", "pie", ".", "!", "?", "Hey", "Hello", "the", "game", "set", "in",
            "Faerûn",
        ];

        let cases = [
            // Single sequence, no subwords.
            Case {
                text: "This is a test sequence",
                tokens: &["[CLS]", "This", "is", "a", "test", "sequence", "[SEP]"],
            },
            Case {
                text: "Piece of pie",
                tokens: &["[CLS]", "Piece", "of", "pie", "[SEP]"],
            },
            // Sequence with unknown word.
            Case {
                text: "This is unknown sequence",
                tokens: &["[CLS]", "This", "is", "[UNK]", "sequence", "[SEP]"],
            },
            // Sequence with subwords.
            Case {
                text: "WordPiece",
                tokens: &["[CLS]", "Word", "##Piece", "[SEP]"],
            },
            // Empty sequence.
            Case {
                text: "",
                tokens: &["[CLS]", "[SEP]"],
            },
            // Punctuation
            Case {
                text: "Hey! Hello?",
                tokens: &["[CLS]", "Hey", "!", "Hello", "?", "[SEP]"],
            },
            // Word that exceeds length limit.
            Case {
                // note that "a" on its own is in the vocab
                text: &"a".repeat(101),
                tokens: &["[CLS]", "[UNK]", "[SEP]"],
            },
            // Chars requiring multiple bytes in UTF-8
            Case {
                text: "the game is set in Faerûn",
                tokens: &["[CLS]", "the", "game", "is", "set", "in", "Faerûn", "[SEP]"],
            },
        ];

        cases.test_each(|case| {
            let &Case { text, tokens } = case;

            let tokenizer = create_tokenizer(vocab, None, Default::default());
            let encoded = tokenizer.encode(text, None).unwrap();
            assert_eq!(
                tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
                tokens
            );
            assert!(encoded.token_type_ids().all(|ttid| ttid == 0));
        });
    }

    #[test]
    fn test_wordpiece_max_word_len() {
        let vocab = &["[CLS]", "[SEP]", "[UNK]", "foo", "##bar", "##foo"];
        let opts = WordPieceOptions {
            max_word_len: Some(6),
            ..Default::default()
        };
        let tokenizer = create_tokenizer(vocab, None, opts);

        // The third word should be tokenized to `[UNK]` because it exceeds
        // `max_word_len`.
        let text = "foobar foofoo foobarfoo";
        let encoded = tokenizer.encode(text, None).unwrap();

        assert_eq!(
            tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
            &["[CLS]", "foo", "##bar", "foo", "##foo", "[UNK]", "[SEP]"]
        );
    }

    #[test]
    fn test_wordpiece_model_lowercase() {
        #[derive(Debug)]
        struct Case<'a> {
            text: &'a str,
            tokens: &'a [&'a str],
        }

        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "this", "is", "a", "test", "sequence",
        ];

        let cases = [
            // Single sequence, no subwords.
            Case {
                text: "this is a test sequence",
                tokens: &["[CLS]", "this", "is", "a", "test", "sequence", "[SEP]"],
            },
            // Lowercasing normalizer should make this match the vocab.
            Case {
                text: "THIS IS A TEST SEQUENCE",
                tokens: &["[CLS]", "this", "is", "a", "test", "sequence", "[SEP]"],
            },
        ];

        cases.test_each(|case| {
            let &Case { text, tokens } = case;

            let normalizer = normalizers::Bert::new(normalizers::BertOptions {
                lowercase: true,
                ..Default::default()
            });
            let tokenizer = create_tokenizer(vocab, Some(Box::new(normalizer)), Default::default());

            let encoded = tokenizer.encode(text, None).unwrap();
            assert_eq!(
                tokenizer.model().get_tokens(encoded.token_ids()).unwrap(),
                tokens
            );
            assert!(encoded.token_type_ids().all(|ttid| ttid == 0));
        })
    }

    #[test]
    fn test_decode() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            expected: &'a str,
        }

        let cases = [
            Case {
                input: "",
                expected: "[CLS] [SEP]",
            },
            Case {
                input: "this is a test sequence",
                expected: "[CLS] this is a test sequence [SEP]",
            },
            Case {
                input: "THIS IS A TEST SEQUENCE",
                expected: "[CLS] this is a test sequence [SEP]",
            },
        ];

        let vocab = &[
            "[CLS]", "[SEP]", "[UNK]", "this", "is", "a", "test", "sequence",
        ];

        cases.test_each(|case| {
            let &Case { input, expected } = case;

            let normalizer = normalizers::Bert::new(normalizers::BertOptions {
                lowercase: true,
                ..Default::default()
            });
            let tokenizer = create_tokenizer(vocab, Some(Box::new(normalizer)), Default::default());

            let encoded = tokenizer.encode(input, None).unwrap();
            let decoded = tokenizer.decode(encoded.token_ids()).unwrap();
            assert_eq!(decoded, expected);
        })
    }
}