use std::{collections::HashMap, fs};

use fancy_regex::Regex;
use serde::Deserialize;

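/// GPT-2's pre-tokenization regex: it splits text into common English
/// contractions, runs of letters, runs of digits, runs of other symbols,
/// and whitespace, before byte-level BPE is applied to each fragment.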
const ENCODABLE_UTF8_PATTERN: &str =
    r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";

/// GPT-2 has no dedicated padding token; the end-of-text token is reused
/// when padding sequences to a fixed length.
pub const PAD_TOKEN: i32 = 50256;

pub const END_OF_TEXT_TOKEN: i32 = 50256;
pub const END_OF_TEXT_STRING: &str = "<|endoftext|>";

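/// A byte-level BPE tokenizer for the GPT-2 vocabulary, built from the
/// `vocab.bpe` merge list and the `encoder.json` token-to-index table.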
pub struct Tokenizer {
    /// Merge priority for each adjacent symbol pair; lower rank merges first.
    bpe_ranks: HashMap<(String, String), u32>,

    /// Maps each byte value to the printable character used to represent it
    /// in the BPE vocabulary.
    utf8_bytes_to_char_tokens: HashMap<u8, char>,

    /// Inverse mapping, used when decoding tokens back to bytes.
    char_tokens_to_utf8_bytes: HashMap<char, u8>,

    /// BPE token string to vocabulary index.
    tokens_to_indexes: HashMap<String, i32>,

    /// Vocabulary index back to BPE token string.
    indexes_to_tokens: HashMap<i32, String>,
}

impl Tokenizer {
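    /// Loads the BPE merge ranks from `bpe_path` and the token vocabulary from
    /// `encoder_path`.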
    pub fn new(bpe_path: &str, encoder_path: &str) -> Self {
        let bpe_str = fs::read_to_string(bpe_path).expect("failed to read BPE merges file");
        let mut bpe_rank_tuples = Vec::new();
        // The first line of the merges file is a version header, so skip it.
        for line in bpe_str.lines().skip(1) {
            let mut split = line.split_whitespace();
            bpe_rank_tuples.push((
                split.next().expect("missing first symbol in BPE merge").to_string(),
                split.next().expect("missing second symbol in BPE merge").to_string(),
            ));
        }

        // A pair's rank is its position in the file: earlier pairs merge first.
        let mut bpe_ranks = HashMap::new();
        for (rank, tuple) in bpe_rank_tuples.iter().enumerate() {
            bpe_ranks.insert(tuple.clone(), rank as u32);
        }

        let encoder_str = fs::read_to_string(encoder_path).expect("failed to read encoder file");
        let encoder_json: EncoderJson =
            serde_json::from_str(&encoder_str).expect("failed to parse encoder JSON");
        let tokens_to_indexes = encoder_json.token_indexes;
        let mut indexes_to_tokens = HashMap::new();
        for (k, v) in &tokens_to_indexes {
            indexes_to_tokens.insert(*v, k.clone());
        }

        let (utf8_bytes_to_char_tokens, char_tokens_to_utf8_bytes) = create_utf8_char_maps();

        Self {
            bpe_ranks,
            utf8_bytes_to_char_tokens,
            char_tokens_to_utf8_bytes,
            tokens_to_indexes,
            indexes_to_tokens,
        }
    }

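    /// Encodes `text`, then truncates or pads the result with `PAD_TOKEN` so it
    /// is exactly `token_sequence_length` tokens long. Returns the padded
    /// sequence and the number of padding tokens that were appended.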
    pub fn encode_to_length(&self, text: &str, token_sequence_length: usize) -> (Vec<i32>, usize) {
        let mut token_sequence = self.encode(text);
        let padding_length = if token_sequence.len() > token_sequence_length {
            0
        } else {
            token_sequence_length - token_sequence.len()
        };

        token_sequence.truncate(token_sequence_length);

        while token_sequence.len() < token_sequence_length {
            token_sequence.push(PAD_TOKEN);
        }

        (token_sequence, padding_length)
    }

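    /// Encodes `text` into GPT-2 token indexes: the text is split with the
    /// pre-tokenization regex, each fragment's bytes are mapped to their
    /// printable stand-in characters, byte-pair encoding merges are applied,
    /// and the resulting BPE tokens are looked up in the vocabulary.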
    pub fn encode(&self, text: &str) -> Vec<i32> {
        let mut token_sequence = vec![];

        // A trailing end-of-text marker is stripped here and re-added as a
        // single token at the end of the sequence.
        let mut has_eot_token = false;
        let text = match text.ends_with(END_OF_TEXT_STRING) {
            true => {
                has_eot_token = true;
                text.trim_end_matches(END_OF_TEXT_STRING)
            }
            false => text,
        };

        let utf8_pattern = Regex::new(ENCODABLE_UTF8_PATTERN).unwrap();
        for utf8_fragment in utf8_pattern.captures_iter(text) {
            let utf8_fragment = &utf8_fragment.unwrap()[0];

            // Map every byte of the fragment to its printable stand-in character.
            let mut token = String::new();
            for utf8_byte in utf8_fragment.as_bytes() {
                token.push(
                    *self
                        .utf8_bytes_to_char_tokens
                        .get(utf8_byte)
                        .expect("unexpected utf8 byte in input"),
                )
            }

            // Apply the BPE merges, then look up each resulting token's index.
            let encoded_tokens = self.byte_pair_encode(&token);
            for encoded_token in encoded_tokens.split(' ') {
                let token_index = self
                    .tokens_to_indexes
                    .get(encoded_token)
                    .unwrap_or_else(|| {
                        panic!(
                            "unexpected bpe-token `{:?}` for token `{:?}` in input",
                            &encoded_token, &token
                        )
                    });
                token_sequence.push(*token_index);
            }
        }

        if has_eot_token {
            token_sequence.push(END_OF_TEXT_TOKEN);
        }

        token_sequence
    }

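    /// Decodes a token sequence back into text by looking up each index's BPE
    /// token, mapping every character back to its original byte, and
    /// re-interpreting the bytes as UTF-8.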
    pub fn decode(&self, token_sequence: Vec<i32>) -> String {
        let mut tokens = String::new();
        for token_index in token_sequence {
            let token = self
                .indexes_to_tokens
                .get(&token_index)
                .expect("unexpected token index in output");
            tokens.push_str(token);
        }

        let mut utf8_bytes = vec![];
        for token in tokens.chars() {
            let utf8_byte = self
                .char_tokens_to_utf8_bytes
                .get(&token)
                .expect("unexpected token in output");
            utf8_bytes.push(*utf8_byte);
        }

        String::from_utf8_lossy(&utf8_bytes).to_string()
    }

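    /// Runs byte-pair encoding on a single pre-tokenized fragment: the fragment
    /// starts as individual characters, and the adjacent pair with the lowest
    /// merge rank is repeatedly merged until no ranked pair remains. The merged
    /// symbols are returned joined by spaces.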
    fn byte_pair_encode(&self, token: &str) -> String {
        let mut word: Vec<String> = token.chars().map(|c| c.to_string()).collect();
        let pairs = Self::get_symbol_pairs(&word);

        if pairs.is_none() {
            return token.into();
        }
        let mut pairs = pairs.unwrap();

        loop {
            // Find the adjacent pair with the lowest (best) merge rank.
            let min_pair = pairs.iter().min_by_key(|pair| {
                let pair = (pair.0.to_string(), pair.1.to_string());
                *self.bpe_ranks.get(&pair).unwrap_or(&u32::MAX)
            });

            if min_pair.is_none() {
                break;
            }
            let min_pair = min_pair.unwrap();
            if !self.bpe_ranks.contains_key(min_pair) {
                break;
            }
            let (first, second) = min_pair;

            // Rebuild the word, merging every occurrence of (first, second).
            let mut new_word = vec![];
            let mut i = 0;
            while i < word.len() {
                if let Some(k) = word.iter().skip(i).position(|c| c == first) {
                    let k = i + k;
                    new_word.extend_from_slice(&word[i..k]);
                    i = k;
                } else {
                    new_word.extend_from_slice(&word[i..]);
                    break;
                }

                if &word[i] == first && i < word.len() - 1 && &word[i + 1] == second {
                    new_word.push(first.clone() + second);
                    i += 2;
                } else {
                    new_word.push(word[i].clone());
                    i += 1;
                }
            }

            word = new_word;
            if word.len() == 1 {
                break;
            } else if let Some(new_pairs) = Self::get_symbol_pairs(&word) {
                pairs = new_pairs;
            } else {
                break;
            }
        }

        word.join(" ")
    }

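    /// Returns every adjacent pair of symbols in `word`, or `None` if the word
    /// is too short to contain a pair.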
    fn get_symbol_pairs(word: &[String]) -> Option<Vec<(String, String)>> {
        if word.len() < 2 {
            return None;
        }

        let mut pairs = vec![];
        let mut prev_char = &word[0];
        for character in &word[1..] {
            pairs.push((prev_char.to_string(), character.to_string()));
            prev_char = character;
        }

        Some(pairs)
    }
}

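/// Wrapper for deserializing the encoder JSON, whose top-level object maps BPE
/// token strings directly to token indexes.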
#[derive(Deserialize)]
struct EncoderJson {
    #[serde(flatten)]
    token_indexes: HashMap<String, i32>,
}

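/// Builds GPT-2's byte-to-unicode mapping: every one of the 256 possible byte
/// values is assigned a printable character, so arbitrary bytes can be
/// represented as vocabulary strings. Printable bytes map to themselves; the
/// remaining bytes are shifted into code points starting at 256.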
fn create_utf8_char_maps() -> (HashMap<u8, char>, HashMap<char, u8>) {
    // Byte values that are already printable characters map to themselves:
    // '!'..='~', '¡'..='¬', and '®'..='ÿ'.
    let a = '!' as u32;
    let b = '~' as u32 + 1;
    let mut list_one = (a..b).collect::<Vec<_>>();
    let c = '¡' as u32;
    let d = '¬' as u32 + 1;
    let mut list_two = (c..d).collect::<Vec<_>>();
    let e = '®' as u32;
    let f = 'ÿ' as u32 + 1;
    let mut list_three = (e..f).collect::<Vec<_>>();

    list_one.append(&mut list_two);
    list_one.append(&mut list_three);
    let mut utf8_bytes: Vec<u32> = list_one;

    let mut utf8_char_codes = utf8_bytes.clone();

    // Every remaining byte value is mapped to the next free code point at or
    // above 256, keeping the mapping printable and reversible.
    let mut i = 0;
    for byte in 0u32..256 {
        if !utf8_bytes.contains(&byte) {
            utf8_bytes.push(byte);
            utf8_char_codes.push(256 + i);
            i += 1;
        }
    }

    let mut bytes_to_chars = HashMap::new();
    let mut chars_to_bytes = HashMap::new();
    for (b, c) in utf8_bytes.iter().zip(utf8_char_codes.iter()) {
        let utf8_byte = u8::try_from(*b).expect("byte value out of range");
        let utf8_char = char::from_u32(*c).expect("invalid char code");
        bytes_to_chars.insert(utf8_byte, utf8_char);
        chars_to_bytes.insert(utf8_char, utf8_byte);
    }

    (bytes_to_chars, chars_to_bytes)
}

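/// Round-trip tests against known GPT-2 token ids; they require the 124M
/// vocabulary and encoder files to be present at the paths below.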
#[cfg(test)]
mod test {

    use super::*;

    const BPE_PATH: &str = "./gpt-2-model/saved_models/124M_vocab.bpe";
    const ENCODER_PATH: &str = "./gpt-2-model/saved_models/124M_encoder.json";

    const INPUT_TEXT_STR: &str =
        "GPT-2 is a machine learning model for natural language-processing;";
    const INPUT_TEXT_TOKENS: &[i32] = &[
        38, 11571, 12, 17, 318, 257, 4572, 4673, 2746, 329, 3288, 3303, 12, 36948, 26,
    ];

    const OUTPUT_TEXT_STR: &str = " it is a simple, high-performance, and scalable machine learning model that is designed to be used in real-world applications.";
    const OUTPUT_TEXT_TOKENS: &[i32] = &[
        340, 318, 257, 2829, 11, 1029, 12, 26585, 11, 290, 43865, 4572, 4673, 2746, 326, 318, 3562,
        284, 307, 973, 287, 1103, 12, 6894, 5479, 13,
    ];

    #[test]
    fn encode() {
        let tokenizer = Tokenizer::new(BPE_PATH, ENCODER_PATH);
        let tokens = tokenizer.encode(INPUT_TEXT_STR);
        assert_eq!(tokens, Vec::from(INPUT_TEXT_TOKENS));

        let text = tokenizer.decode(tokens);
        assert_eq!(text, INPUT_TEXT_STR);
    }

    #[test]
    fn decode() {
        let tokenizer = Tokenizer::new(BPE_PATH, ENCODER_PATH);
        let text = tokenizer.decode(Vec::from(OUTPUT_TEXT_TOKENS));
        assert_eq!(text, OUTPUT_TEXT_STR);

        let tokens = tokenizer.encode(&text);
        assert_eq!(tokens, Vec::from(OUTPUT_TEXT_TOKENS));
    }
}