use std::{collections::HashMap, fs, sync::OnceLock, vec};

use fancy_regex::Regex;
use serde::Deserialize;
const ENCODABLE_UTF8_PATTERN: &str =
r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
pub const PAD_TOKEN: i32 = 50256;
pub const END_OF_TEXT_TOKEN: i32 = 50256;
pub const END_OF_TEXT_STRING: &str = "<|endoftext|>";
pub struct Tokenizer {
bpe_ranks: HashMap<(String, String), u32>,
utf8_bytes_to_char_tokens: HashMap<u8, char>,
char_tokens_to_utf8_bytes: HashMap<char, u8>,
tokens_to_indexes: HashMap<String, i32>,
indexes_to_tokens: HashMap<i32, String>,
}
impl Tokenizer {
pub fn new(bpe_path: &str, encoder_path: &str) -> Self {
let bpe_str = fs::read_to_string(bpe_path).expect("wah");
let mut bpe_rank_tuples = Vec::new();
for line in bpe_str.lines().skip(1) {
let mut split = line.split_whitespace();
bpe_rank_tuples.push((
split.next().expect("k").to_string(),
split.next().expect("v").to_string(),
));
}
let mut bpe_ranks = HashMap::new();
for (tuple, rank) in bpe_rank_tuples.iter().zip(0..bpe_rank_tuples.len() as u32) {
bpe_ranks.insert(tuple.clone(), rank);
}
let encoder_str = fs::read_to_string(encoder_path).expect("wah");
let encoder_json: EncoderJson = serde_json::from_str(&encoder_str).expect("wah");
let tokens_to_indexes = encoder_json.token_indexes;
let mut indexes_to_tokens = HashMap::new();
for (k, v) in &tokens_to_indexes {
indexes_to_tokens.insert(*v, k.clone());
}
let (utf8_bytes_to_char_tokens, char_tokens_to_utf8_bytes) = create_utf8_char_maps();
Self {
bpe_ranks,
utf8_bytes_to_char_tokens,
char_tokens_to_utf8_bytes,
tokens_to_indexes,
indexes_to_tokens,
}
}
pub fn encode_to_length(&self, text: &str, token_sequence_length: usize) -> (Vec<i32>, usize) {
let mut token_sequence = self.encode(text);
let padding_length = if token_sequence.len() > token_sequence_length {
0
} else {
token_sequence_length - token_sequence.len()
};
token_sequence.truncate(token_sequence_length);
while token_sequence.len() < token_sequence_length {
token_sequence.push(PAD_TOKEN);
}
(token_sequence, padding_length)
}
pub fn encode(&self, text: &str) -> Vec<i32> {
let mut token_sequence = vec![];
let mut has_eot_token = false;
let text = match text.ends_with(END_OF_TEXT_STRING) {
true => {
has_eot_token = true;
text.trim_end_matches(END_OF_TEXT_STRING)
}
false => text,
};
let utf8_pattern = Regex::new(ENCODABLE_UTF8_PATTERN).unwrap();
for utf8_fragment in utf8_pattern.captures_iter(text) {
let utf8_fragment = &utf8_fragment.unwrap()[0];
let mut token = String::new();
for utf8_byte in utf8_fragment.as_bytes() {
token.push(
*self
.utf8_bytes_to_char_tokens
.get(utf8_byte)
.expect("unexpected utf8 byte in input"),
)
}
let encoded_tokens = self.byte_pair_encode(&token);
for encoded_token in encoded_tokens.split(' ') {
let token_index = self
.tokens_to_indexes
.get(encoded_token)
.unwrap_or_else(|| {
panic!(
"unexpected bpe-token `{:?}` for token `{:?}` in input",
&encoded_token, &token
)
});
token_sequence.push(*token_index);
}
}
if has_eot_token {
token_sequence.push(END_OF_TEXT_TOKEN);
}
token_sequence
}
/// Converts a sequence of token indexes back into text.
///
/// Token indexes are mapped to their vocabulary strings, concatenated, and
/// each stand-in character is translated back to its raw byte; the byte
/// sequence is then interpreted as UTF-8 (lossily, so invalid sequences
/// become replacement characters rather than panicking).
///
/// # Panics
/// Panics when a token index or stand-in character is not in the maps.
pub fn decode(&self, token_sequence: Vec<i32>) -> String {
    // Concatenate the vocabulary string of every token index.
    let token_chars: String = token_sequence
        .iter()
        .map(|index| {
            self.indexes_to_tokens
                .get(index)
                .expect("unexpected token index in output")
                .as_str()
        })
        .collect();
    // Translate each stand-in character back to the raw byte it represents.
    let utf8_bytes: Vec<u8> = token_chars
        .chars()
        .map(|token_char| {
            *self
                .char_tokens_to_utf8_bytes
                .get(&token_char)
                .expect("unexpected token in output")
        })
        .collect();
    String::from_utf8_lossy(&utf8_bytes).to_string()
}
/// Applies the BPE merge rules to `token` (a string of byte stand-in
/// characters) and returns the resulting symbols joined by single spaces.
///
/// Repeatedly merges the adjacent symbol pair with the lowest (best) rank in
/// `bpe_ranks` until no known pair remains or only one symbol is left — the
/// same greedy procedure as the reference GPT-2 `bpe` function.
fn byte_pair_encode(&self, token: &str) -> String {
    // Start with one symbol per character.
    let mut word: Vec<String> = token.chars().map(|c| c.to_string()).collect();
    let pairs = Self::get_symbol_pairs(&word);
    if pairs.is_none() {
        // Zero- or one-character token: nothing to merge.
        return token.into();
    }
    let mut pairs = pairs.unwrap();
    loop {
        // Pick the adjacent pair with the lowest merge rank; pairs missing
        // from the table compare as u32::MAX so they are never preferred.
        let min_pair = pairs.iter().min_by_key(|pair| {
            let pair = (pair.0.to_string(), pair.1.to_string());
            let rank = self.bpe_ranks.get(&pair).unwrap_or(&u32::MAX);
            rank
        });
        if min_pair.is_none() {
            break;
        }
        let min_pair = min_pair.unwrap();
        // If even the best pair is unknown, no merge rule applies: done.
        if !self.bpe_ranks.contains_key(min_pair) {
            break;
        }
        let (first, second) = min_pair;
        // Rebuild the word, fusing every `first second` occurrence.
        let mut new_word = vec![];
        let mut i = 0;
        while i < word.len() {
            // Find the next occurrence of `first` at or after i; copy the
            // symbols before it unchanged.
            if let Some(k) = word.iter().skip(i).position(|c| c == first) {
                let k = i + k; new_word.extend_from_slice(&word[i..k]);
                i = k;
            } else {
                // No more occurrences: copy the tail and stop.
                new_word.extend_from_slice(&word[i..]);
                break;
            }
            // Merge only when `first` is immediately followed by `second`;
            // otherwise keep the symbol and continue scanning.
            if &word[i] == first && i < word.len() - 1 && &word[i + 1] == second {
                new_word.push(first.clone() + second);
                i += 2;
            } else {
                new_word.push(word[i].clone());
                i += 1;
            }
        }
        word = new_word;
        if word.len() == 1 {
            // Fully merged: nothing left to pair up.
            break;
        } else {
            if let Some(new_pairs) = Self::get_symbol_pairs(&word) {
                pairs = new_pairs;
            } else {
                break;
            }
        }
    }
    // Join the final symbols with single spaces (the separator `encode`
    // splits on).
    let mut return_word = String::new();
    for i in 0..word.len() {
        return_word.push_str(&word[i]);
        if i + 1 < word.len() {
            return_word.push(' ');
        }
    }
    return_word
}
/// Returns every adjacent symbol pair in `word`, in order, or `None` when
/// `word` has fewer than two symbols (so there are no pairs to merge).
///
/// Takes `&[String]` rather than `&Vec<String>` (the idiomatic borrowed-slice
/// parameter); existing `&word` call sites coerce automatically.
fn get_symbol_pairs(word: &[String]) -> Option<Vec<(String, String)>> {
    if word.len() < 2 {
        return None;
    }
    // windows(2) yields each overlapping adjacent pair exactly once.
    Some(
        word.windows(2)
            .map(|pair| (pair[0].clone(), pair[1].clone()))
            .collect(),
    )
}
}
/// Deserialization target for `encoder.json`: the file is a single flat JSON
/// object mapping token strings to vocabulary indexes, captured wholesale via
/// `#[serde(flatten)]`.
#[derive(Deserialize)]
struct EncoderJson {
    #[serde(flatten)]
    token_indexes: HashMap<String, i32>,
}
/// Builds the reversible mapping between all 256 byte values and printable
/// stand-in characters (GPT-2's `bytes_to_unicode` scheme).
///
/// Bytes in three printable ranges ('!'..='~', '¡'..='¬', '®'..='ÿ') map to
/// themselves; every other byte is assigned a fresh code point starting at
/// 256, in ascending byte order, so that no byte maps to an unprintable or
/// whitespace character.
fn create_utf8_char_maps() -> (HashMap<u8, char>, HashMap<char, u8>) {
    // Bytes whose stand-in character is the byte value itself.
    let identity_bytes: Vec<u32> = ('!' as u32..='~' as u32)
        .chain('¡' as u32..='¬' as u32)
        .chain('®' as u32..='ÿ' as u32)
        .collect();
    let mut bytes_to_chars = HashMap::new();
    let mut chars_to_bytes = HashMap::new();
    // Next unused code point for bytes outside the identity ranges.
    let mut next_code = 256u32;
    for byte in 0u32..256 {
        let char_code = if identity_bytes.contains(&byte) {
            byte
        } else {
            let assigned = next_code;
            next_code += 1;
            assigned
        };
        let utf8_byte = u8::try_from(byte).expect("wah");
        let utf8_char = char::from_u32(char_code).expect("wah");
        bytes_to_chars.insert(utf8_byte, utf8_char);
        chars_to_bytes.insert(utf8_char, utf8_byte);
    }
    (bytes_to_chars, chars_to_bytes)
}
#[cfg(test)]
mod test {
    use super::*;

    // These tests require the 124M GPT-2 model assets on disk.
    const BPE_PATH: &str = "./gpt-2-model/saved_models/124M_vocab.bpe";
    const ENCODER_PATH: &str = "./gpt-2-model/saved_models/124M_encoder.json";
    const INPUT_TEXT_STR: &str =
        "GPT-2 is a machine learning model for natural language-processing;";
    const INPUT_TEXT_TOKENS: &[i32] = &[
        38, 11571, 12, 17, 318, 257, 4572, 4673, 2746, 329, 3288, 3303, 12, 36948, 26,
    ];
    const OUTPUT_TEXT_STR: &str = " it is a simple, high-performance, and scalable machine learning model that is designed to be used in real-world applications.";
    const OUTPUT_TEXT_TOKENS: &[i32] = &[
        340, 318, 257, 2829, 11, 1029, 12, 26585, 11, 290, 43865, 4572, 4673, 2746, 326, 318, 3562,
        284, 307, 973, 287, 1103, 12, 6894, 5479, 13,
    ];

    /// Text -> tokens against a known reference, then round-trip back to text.
    #[test]
    fn encode() {
        let tokenizer = Tokenizer::new(BPE_PATH, ENCODER_PATH);
        let tokens = tokenizer.encode(INPUT_TEXT_STR);
        assert_eq!(tokens, INPUT_TEXT_TOKENS.to_vec());
        assert_eq!(tokenizer.decode(tokens), INPUT_TEXT_STR);
    }

    /// Tokens -> text against a known reference, then round-trip back to tokens.
    #[test]
    fn decode() {
        let tokenizer = Tokenizer::new(BPE_PATH, ENCODER_PATH);
        let text = tokenizer.decode(OUTPUT_TEXT_TOKENS.to_vec());
        assert_eq!(text, OUTPUT_TEXT_STR);
        assert_eq!(tokenizer.encode(&text), OUTPUT_TEXT_TOKENS.to_vec());
    }
}