gpt_model/
tokenizer.rs

1//! Pure Rust implementation of the GPT-2
2//! byte-pair encoder (aka "text tokenizer").
3use std::{collections::HashMap, fs, vec};
4
5use fancy_regex::Regex;
6use serde::Deserialize;
7
/// Pattern used to match encodable UTF-8 text in unencoded text.
///
/// Mirrors the original GPT-2 tokenizer pattern: common English
/// contractions, runs of letters, runs of digits, runs of other
/// symbols (each optionally preceded by one space), and runs of
/// whitespace. Note the `(?!\S)` negative lookahead, which is why
/// `fancy_regex` is used instead of the plain `regex` crate.
const ENCODABLE_UTF8_PATTERN: &str =
    r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";

/// Token `50256` is used as the padding token, which
/// corresponds to the `<|endoftext|>` token in the
/// OpenAI GPT-2 encoder.
pub const PAD_TOKEN: i32 = 50256;

/// Vocabulary index of the `<|endoftext|>` token
/// (the same index doubles as [PAD_TOKEN]).
pub const END_OF_TEXT_TOKEN: i32 = 50256;

/// Textual form of the end-of-text marker recognized
/// at the end of input text during encoding.
pub const END_OF_TEXT_STRING: &str = "<|endoftext|>";
19
/// Tokenizer which converts strings into
/// token sequences consumable by the
/// GPT-2 model, and vice-versa.
///
/// This tokenizer loads its configuration
/// from the original OpenAI GPT-2 encoder
/// and vocabulary "byte-pair encoding" (BPE).
pub struct Tokenizer {
    /// Byte-pair encoding "ranks",
    /// obtained from an existing
    /// `*_vocab.bpe` file.
    ///
    /// Each symbol pair maps to its line index in
    /// that file; lower-ranked pairs merge first
    /// during byte-pair encoding.
    bpe_ranks: HashMap<(String, String), u32>,

    /// Mapping of UTF-8 bytes to tokens,
    /// obtained via computation on start-up.
    ///
    /// A "UTF-8 byte" is a literal `u8`
    /// representing a unicode-encodable byte,
    /// and a "token" is a character (like `c`).
    utf8_bytes_to_char_tokens: HashMap<u8, char>,

    /// Reverse of `utf8_bytes_to_char_tokens`
    /// for decoding.
    char_tokens_to_utf8_bytes: HashMap<char, u8>,

    /// Mapping of tokens to their indexes,
    /// typically obtained from an existing
    /// `*_encoder.json` file.
    ///
    /// A "token" is a character (like `c`)
    /// or a word fragment (like `ing`), and
    /// an index is a number that represents
    /// that token.
    tokens_to_indexes: HashMap<String, i32>,

    /// Reverse of `tokens_to_indexes`
    /// for decoding.
    indexes_to_tokens: HashMap<i32, String>,
}
59
60impl Tokenizer {
61    /// Creates a new in-memory tokenizer
62    /// from the BPE file at `bpe_path`
63    /// and the character encoding file at `encoder_path`.
64    pub fn new(bpe_path: &str, encoder_path: &str) -> Self {
65        // Parse byte-pair encoding lines into tuples.
66        let bpe_str = fs::read_to_string(bpe_path).expect("wah");
67        let mut bpe_rank_tuples = Vec::new();
68        for line in bpe_str.lines().skip(1) {
69            let mut split = line.split_whitespace();
70            bpe_rank_tuples.push((
71                split.next().expect("k").to_string(),
72                split.next().expect("v").to_string(),
73            ));
74        }
75
76        // Build byte-pair encoding ranks,
77        // where each tuple is mapped to its index
78        // in the original byte-pair encoding file.
79        let mut bpe_ranks = HashMap::new();
80        for (tuple, rank) in bpe_rank_tuples.iter().zip(0..bpe_rank_tuples.len() as u32) {
81            bpe_ranks.insert(tuple.clone(), rank);
82        }
83
84        // Parse encoder JSON.
85        let encoder_str = fs::read_to_string(encoder_path).expect("wah");
86        let encoder_json: EncoderJson = serde_json::from_str(&encoder_str).expect("wah");
87        let tokens_to_indexes = encoder_json.token_indexes;
88        let mut indexes_to_tokens = HashMap::new();
89        for (k, v) in &tokens_to_indexes {
90            indexes_to_tokens.insert(*v, k.clone());
91        }
92
93        // Compute UTF-8 byte-token maps.
94        let (utf8_bytes_to_char_tokens, char_tokens_to_utf8_bytes) = create_utf8_char_maps();
95
96        Self {
97            bpe_ranks,
98            utf8_bytes_to_char_tokens,
99            char_tokens_to_utf8_bytes,
100            tokens_to_indexes,
101            indexes_to_tokens,
102        }
103    }
104
105    /// Encodes `text` into a token sequence,
106    /// truncating and/or "right-padding" the encoded
107    /// token sequence to fit `token_sequence_length`,
108    /// using [PAD_TOKEN] as the padding token.
109    ///
110    /// The returned tuple contains `(token_sequence, padding_length)`,
111    /// where `padding_length` is the number of padding tokens
112    /// in `token_sequence`. If the length of `token_sequence` before
113    /// truncation exceeds `token_sequence_length`, `padding_length`
114    /// will always be zero.
115    ///
116    /// ## Left vs. Right Padding
117    ///
118    /// Common wisdom in the ML community is to "pad-left"
119    /// on natural language models like GPT-2; that is,
120    /// by adding padding tokens to the _front_ of the input
121    /// tokens until they fit a required input length.
122    ///
123    /// However, this method "pads-right" by adding padding
124    /// tokens to the _end_ of the input tokens.
125    ///
126    /// Right-padding works because GPT-2 never looks "ahead"
127    /// (to the right) of its inputs, and so the right-padding
128    /// will not influence the inference results of any tokens
129    /// to the left.
130    ///
131    /// Conversely, left-padding on GPT-2 only works if an
132    /// attention mask is used, which tells the GPT-2 model
133    /// to ignore certain tokens (like the padding tokens).
134    /// However, attention masking is slightly more complicated
135    /// to implement (albeit more efficient); therefore, this
136    /// implementation does not use it.
137    pub fn encode_to_length(&self, text: &str, token_sequence_length: usize) -> (Vec<i32>, usize) {
138        let mut token_sequence = self.encode(text);
139        let padding_length = if token_sequence.len() > token_sequence_length {
140            0
141        } else {
142            token_sequence_length - token_sequence.len()
143        };
144
145        // Truncate to maximum length; no-op if shorter
146        // than max length.
147        token_sequence.truncate(token_sequence_length);
148
149        // Right-pad to maximum length; no-op if at
150        // max length.
151        while token_sequence.len() < token_sequence_length {
152            token_sequence.push(PAD_TOKEN);
153        }
154
155        (token_sequence, padding_length)
156    }
157
158    /// Encodes `text` into a token sequence for
159    /// consumption by the GPT-2 model.
160    pub fn encode(&self, text: &str) -> Vec<i32> {
161        let mut token_sequence = vec![];
162
163        // Strip end of text token.
164        let mut has_eot_token = false;
165        let text = match text.ends_with(END_OF_TEXT_STRING) {
166            true => {
167                has_eot_token = true;
168                text.trim_end_matches(END_OF_TEXT_STRING)
169            }
170            false => text,
171        };
172
173        // Find all encodable UTF-8 text.
174        let utf8_pattern = Regex::new(ENCODABLE_UTF8_PATTERN).unwrap();
175        for utf8_fragment in utf8_pattern.captures_iter(text) {
176            let utf8_fragment = &utf8_fragment.unwrap()[0];
177
178            // Convert token UTF-8 bytes to one or more tokens.
179            // Note: Rust strings are UTF-8 by default.
180            let mut token = String::new();
181            for utf8_byte in utf8_fragment.as_bytes() {
182                token.push(
183                    *self
184                        .utf8_bytes_to_char_tokens
185                        .get(utf8_byte)
186                        .expect("unexpected utf8 byte in input"),
187                )
188            }
189
190            // Encode token into byte-pairs.
191            let encoded_tokens = self.byte_pair_encode(&token);
192            for encoded_token in encoded_tokens.split(' ') {
193                let token_index = self
194                    .tokens_to_indexes
195                    .get(encoded_token)
196                    .unwrap_or_else(|| {
197                        panic!(
198                            "unexpected bpe-token `{:?}` for token `{:?}` in input",
199                            &encoded_token, &token
200                        )
201                    });
202                token_sequence.push(*token_index);
203            }
204        }
205
206        // If the end-of-text token was provided, add it to the text.
207        if has_eot_token {
208            token_sequence.push(END_OF_TEXT_TOKEN);
209        }
210
211        token_sequence
212    }
213
214    /// Decodes `token_sequence` into text.
215    pub fn decode(&self, token_sequence: Vec<i32>) -> String {
216        // Decode each token index into a token.
217        let mut tokens = String::new();
218        for token_index in token_sequence {
219            let token = self
220                .indexes_to_tokens
221                .get(&token_index)
222                .expect("unexpected token index in output");
223            tokens.push_str(token);
224        }
225
226        // Decode tokens into UTF-8 bytes.
227        let mut utf8_bytes = vec![];
228        for token in tokens.chars() {
229            let utf8_byte = self
230                .char_tokens_to_utf8_bytes
231                .get(&token)
232                .expect("unexpected token in output");
233            utf8_bytes.push(*utf8_byte);
234        }
235
236        // Decode UTf-8 bytes into a string.
237        String::from_utf8_lossy(&utf8_bytes).to_string()
238    }
239
    /// Applies the byte-pair-encoding merge rules to `token`
    /// (a string of intermediate character tokens): the
    /// lowest-ranked adjacent symbol pair (per `bpe_ranks`) is
    /// repeatedly merged until no known pair remains, and the
    /// resulting sub-tokens are returned joined by spaces.
    fn byte_pair_encode(&self, token: &str) -> String {
        let mut word: Vec<String> = token.chars().map(|c| c.to_string()).collect();
        let pairs = Self::get_symbol_pairs(&word);

        // If no pairs were generated, there's only
        // one actual token in the input, and thus
        // no work to do.
        if pairs.is_none() {
            return token.into();
        }
        let mut pairs = pairs.unwrap();

        // Perform encoding.
        loop {
            // Find the pair with the lowest rank. Pairs absent
            // from `bpe_ranks` compare as `u32::MAX`, so they
            // can only win when no known pair exists at all.
            let min_pair = pairs.iter().min_by_key(|pair| {
                let pair = (pair.0.to_string(), pair.1.to_string());
                let rank = self.bpe_ranks.get(&pair).unwrap_or(&u32::MAX);
                rank
            });

            // `min_by_key` yields `None` only when `pairs` is
            // empty; nothing is mergeable then, so we're done.
            if min_pair.is_none() {
                break;
            }
            let min_pair = min_pair.unwrap();
            // If even the lowest-ranked pair is unknown to the
            // vocabulary, no further merges can apply.
            if !self.bpe_ranks.contains_key(min_pair) {
                break;
            }
            let (first, second) = min_pair;

            // Rebuild the word, merging every adjacent
            // `first`+`second` occurrence into a single symbol.
            let mut new_word = vec![];
            let mut i = 0;
            while i < word.len() {
                // Copy symbols verbatim up to the next occurrence
                // of `first`; if there is none, copy the rest of
                // the word and stop scanning.
                if let Some(k) = word.iter().skip(i).position(|c| c == first) {
                    let k = i + k; // adjust for skip
                    new_word.extend_from_slice(&word[i..k]);
                    i = k;
                } else {
                    new_word.extend_from_slice(&word[i..]);
                    break;
                }

                // `first` immediately followed by `second`:
                // emit the merged symbol and skip both.
                if &word[i] == first && i < word.len() - 1 && &word[i + 1] == second {
                    new_word.push(first.clone() + second);
                    i += 2;

                // A lone `first` not followed by `second`
                // (or at the end of the word): copy it as-is.
                } else {
                    new_word.push(word[i].clone());
                    i += 1;
                }
            }

            // Adopt the merged word; once it collapses to a
            // single symbol there is nothing left to merge.
            word = new_word;
            if word.len() == 1 {
                break;
            } else {
                // Recompute the pair list for the next round.
                // NOTE: unlike the original OpenAI encoder, a
                // word with no remaining pairs ends the loop
                // gracefully instead of erroring.
                if let Some(new_pairs) = Self::get_symbol_pairs(&word) {
                    pairs = new_pairs;
                } else {
                    break;
                }
            }
        }

        // Convert word into space-separated tokens.
        let mut return_word = String::new();
        for i in 0..word.len() {
            return_word.push_str(&word[i]);
            if i + 1 < word.len() {
                return_word.push(' ');
            }
        }

        return_word
    }
325
326    /// Returns the set of all _pairs_
327    /// of unicode symbols in `word`,
328    /// returning nothing if there are no pairs.
329    fn get_symbol_pairs(word: &Vec<String>) -> Option<Vec<(String, String)>> {
330        if word.len() < 2 {
331            return None;
332        }
333
334        let mut pairs = vec![];
335        let mut prev_char = &word[0];
336        for character in &word[1..] {
337            pairs.push((prev_char.to_string(), character.to_string()));
338            prev_char = character;
339        }
340
341        Some(pairs)
342    }
343}
344
/// Structure of the data in an encoder JSON file.
///
/// The file is a single flat JSON object mapping each token
/// string to its integer index, so the entire object is
/// captured into one map via `#[serde(flatten)]`.
#[derive(Deserialize)]
struct EncoderJson {
    #[serde(flatten)]
    token_indexes: HashMap<String, i32>,
}
351
/// Returns a pair of mirrored maps between raw bytes and
/// printable unicode characters, covering all 256 byte values.
///
/// Bytes whose values are already printable characters
/// (`'!'..='~'`, `'¡'..='¬'`, and `'®'..='ÿ'`) map to
/// themselves; every other byte (whitespace, control
/// characters, etc.) is assigned a fresh character starting
/// at U+0100, in ascending byte order. This mirrors the
/// lookup tables in the original OpenAI encoder, whose docs
/// explain:
///
/// > To avoid that, we want lookup tables between
/// > utf-8 bytes and unicode strings. And avoids
/// > mapping to whitespace/control characters the
/// > bpe code barfs on.
fn create_utf8_char_maps() -> (HashMap<u8, char>, HashMap<char, u8>) {
    // Byte values that map directly to their own character
    // (called 'bs' in the OpenAI encoder source).
    let mut utf8_bytes: Vec<u32> = ('!' as u32..='~' as u32)
        .chain('¡' as u32..='¬' as u32)
        .chain('®' as u32..='ÿ' as u32)
        .collect();

    // Character codes for each byte, initially identical
    // (called 'cs' in the OpenAI encoder source).
    let mut utf8_char_codes = utf8_bytes.clone();

    // Assign every remaining byte a fresh character code at
    // 256 and up, so no byte ever maps to a whitespace or
    // control character.
    let mut next_offset = 0;
    for byte in 0u32..256 {
        if !utf8_bytes.contains(&byte) {
            utf8_bytes.push(byte);
            utf8_char_codes.push(256 + next_offset);
            next_offset += 1;
        }
    }

    // Zip the two parallel lists into the mirrored maps.
    let mut bytes_to_chars = HashMap::new();
    let mut chars_to_bytes = HashMap::new();
    for (byte, code) in utf8_bytes.iter().zip(utf8_char_codes.iter()) {
        let utf8_byte = u8::try_from(*byte).expect("byte value exceeds u8 range");
        let utf8_char = char::from_u32(*code).expect("invalid unicode scalar value");
        bytes_to_chars.insert(utf8_byte, utf8_char);
        chars_to_bytes.insert(utf8_char, utf8_byte);
    }

    (bytes_to_chars, chars_to_bytes)
}
409
#[cfg(test)]
mod test {

    use super::*;

    // Paths to OpenAI training data for the 124M (smallest) GPT-2 model.
    const BPE_PATH: &str = "./gpt-2-model/saved_models/124M_vocab.bpe";
    const ENCODER_PATH: &str = "./gpt-2-model/saved_models/124M_encoder.json";

    // Sample input text for encoding, along
    // with expected encoded tokens.
    const INPUT_TEXT_STR: &str =
        "GPT-2 is a machine learning model for natural language-processing;";
    const INPUT_TEXT_TOKENS: &[i32] = &[
        38, 11571, 12, 17, 318, 257, 4572, 4673, 2746, 329, 3288, 3303, 12, 36948, 26,
    ];

    // Sample output tokens for decoding,
    // along with expected decoded text.
    const OUTPUT_TEXT_STR: &str = " it is a simple, high-performance, and scalable machine learning model that is designed to be used in real-world applications.";
    const OUTPUT_TEXT_TOKENS: &[i32] = &[
        340, 318, 257, 2829, 11, 1029, 12, 26585, 11, 290, 43865, 4572, 4673, 2746, 326, 318, 3562,
        284, 307, 973, 287, 1103, 12, 6894, 5479, 13,
    ];

    /// Encoding sample text yields the expected tokens, and
    /// decoding those tokens restores the original text.
    #[test]
    fn encode() {
        let tokenizer = Tokenizer::new(BPE_PATH, ENCODER_PATH);

        let encoded = tokenizer.encode(INPUT_TEXT_STR);
        assert_eq!(encoded, INPUT_TEXT_TOKENS.to_vec());

        // Round-trip back through the decoder to double-check.
        let decoded = tokenizer.decode(encoded);
        assert_eq!(decoded, INPUT_TEXT_STR);
    }

    /// Decoding sample tokens yields the expected text, and
    /// re-encoding that text restores the original tokens.
    #[test]
    fn decode() {
        let tokenizer = Tokenizer::new(BPE_PATH, ENCODER_PATH);

        let decoded = tokenizer.decode(OUTPUT_TEXT_TOKENS.to_vec());
        assert_eq!(decoded, OUTPUT_TEXT_STR);

        // Round-trip back through the encoder to double-check.
        let reencoded = tokenizer.encode(&decoded);
        assert_eq!(reencoded, OUTPUT_TEXT_TOKENS.to_vec());
    }
}