use crate::error::TokenizerError;
use crate::tokenizer::tokenization_utils::{openai_gpt_bpe, split_on_bpe_pairs, BpeCache};
use crate::tokenizer::{BaseTokenizer, MultiThreadedTokenizer, Tokenizer};
use crate::vocab::bpe_vocab::BpePairVocab;
use crate::vocab::{OpenAiGptVocab, Vocab};
use crate::{Mask, Token, TokenRef};
use std::collections::HashMap;
use std::path::Path;
use std::sync::RwLock;
pub struct OpenAiGptTokenizer {
vocab: OpenAiGptVocab,
base_tokenizer: BaseTokenizer<OpenAiGptVocab>,
bpe_ranks: BpePairVocab,
cache: BpeCache,
}
impl OpenAiGptTokenizer {
pub fn from_file<P: AsRef<Path>, M: AsRef<Path>>(
vocab_path: P,
merges_path: M,
lower_case: bool,
) -> Result<OpenAiGptTokenizer, TokenizerError> {
let vocab = OpenAiGptVocab::from_file(vocab_path)?;
let base_tokenizer = BaseTokenizer::from_existing_vocab(vocab.clone(), lower_case, true);
let bpe_ranks = BpePairVocab::from_file(merges_path)?;
let cache = RwLock::new(HashMap::new());
Ok(OpenAiGptTokenizer {
vocab,
base_tokenizer,
bpe_ranks,
cache,
})
}
pub fn from_file_with_special_token_mapping<V: AsRef<Path>, M: AsRef<Path>, S: AsRef<Path>>(
vocab_path: V,
merges_path: M,
lower_case: bool,
special_token_mapping_path: S,
) -> Result<OpenAiGptTokenizer, TokenizerError> {
let vocab = OpenAiGptVocab::from_file_with_special_token_mapping(
vocab_path,
special_token_mapping_path,
)?;
let base_tokenizer = BaseTokenizer::from_existing_vocab(vocab.clone(), lower_case, true);
let bpe_ranks = BpePairVocab::from_file(merges_path)?;
let cache = RwLock::new(HashMap::new());
Ok(OpenAiGptTokenizer {
vocab,
base_tokenizer,
bpe_ranks,
cache,
})
}
pub fn from_existing_vocab_and_merges(
vocab: OpenAiGptVocab,
merges: BpePairVocab,
lower_case: bool,
) -> OpenAiGptTokenizer {
let base_tokenizer = BaseTokenizer::from_existing_vocab(vocab.clone(), lower_case, true);
let cache = RwLock::new(HashMap::new());
OpenAiGptTokenizer {
vocab,
base_tokenizer,
bpe_ranks: merges,
cache,
}
}
}
impl Tokenizer<OpenAiGptVocab> for OpenAiGptTokenizer {
fn vocab(&self) -> &OpenAiGptVocab {
&self.vocab
}
fn tokenize_to_tokens(&self, initial_token: TokenRef) -> Vec<Token> {
let tokens: Vec<Token> = self
.base_tokenizer
.tokenize_to_tokens(initial_token)
.into_iter()
.flat_map(|token| {
if token.mask != Mask::Special && token.mask != Mask::Unknown {
split_on_bpe_pairs(
token.as_ref(),
openai_gpt_bpe,
&self.bpe_ranks,
&self.cache,
false,
)
} else {
vec![token]
}
})
.collect();
tokens
}
fn convert_tokens_to_string(&self, tokens: Vec<String>) -> String {
tokens.join("").replace("</w>", " ").trim().to_owned()
}
}
impl MultiThreadedTokenizer<OpenAiGptVocab> for OpenAiGptTokenizer {}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::base_tokenizer::{Offset, TokenizedInput, TruncationStrategy};
use crate::vocab::base_vocab::{swap_key_values, SpecialTokenMap};
use crate::vocab::OpenAiGptVocab;
use itertools::Itertools;
use std::collections::HashMap;
fn generate_test_vocab() -> OpenAiGptVocab {
let values: HashMap<String, i64> = [
("t".to_owned(), 0),
("h".to_owned(), 1),
("a</w>".to_owned(), 2),
("n".to_owned(), 3),
("the".to_owned(), 4),
("Ä ".to_owned(), 5),
("<unk>".to_owned(), 6),
("o</w>".to_owned(), 7),
("the</w>".to_owned(), 8),
("rth</w>".to_owned(), 9),
("ea".to_owned(), 10),
]
.iter()
.cloned()
.collect();
let special_token_map = SpecialTokenMap {
unk_token: "<unk>".to_string(),
pad_token: None,
bos_token: None,
sep_token: None,
cls_token: None,
eos_token: None,
mask_token: None,
additional_special_tokens: None,
};
let special_values: HashMap<String, i64> =
[("<unk>".to_owned(), 6)].iter().cloned().collect();
let indices = swap_key_values(&values);
let special_indices = swap_key_values(&special_values);
OpenAiGptVocab {
values,
indices,
special_token_map,
special_values,
special_indices,
}
}
fn generate_test_merges() -> BpePairVocab {
let values: HashMap<(String, String), i64> = [
(("4".to_owned(), "t".to_owned()), 0),
(("2".to_owned(), "n".to_owned()), 1),
(("r".to_owned(), "th</w>".to_owned()), 2),
(("t".to_owned(), "he</w>".to_owned()), 3),
(("h".to_owned(), "e".to_owned()), 4),
(("t".to_owned(), "h</w>".to_owned()), 5),
(("t".to_owned(), "h".to_owned()), 6),
(("th".to_owned(), "e</w>".to_owned()), 7),
(("e".to_owned(), "a".to_owned()), 8),
]
.iter()
.cloned()
.collect();
BpePairVocab { values }
}
#[test]
fn test_openai_gpt_tokenizer() {
let vocab = generate_test_vocab();
let merges = generate_test_merges();
let openai_gpt_tokenizer: OpenAiGptTokenizer =
OpenAiGptTokenizer::from_existing_vocab_and_merges(vocab, merges, true);
let test_tuples = [
("The earth", vec!["the</w>", "ea", "rth</w>"]),
("", vec![]),
(" ", vec![]),
(" \n ", vec![]),
];
let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
let expected_results: Vec<Vec<&str>> = test_tuples.iter().map(|v| v.1.clone()).collect();
for (source_text, expected_result) in test_tuples.iter() {
assert_eq!(openai_gpt_tokenizer.tokenize(source_text), *expected_result);
}
assert_eq!(
MultiThreadedTokenizer::tokenize_list(&openai_gpt_tokenizer, &source_texts),
expected_results
);
}
#[test]
fn test_openai_gpt_tokenizer_no_lower_casing() {
let vocab = generate_test_vocab();
let merges = generate_test_merges();
let openai_gpt_tokenizer: OpenAiGptTokenizer =
OpenAiGptTokenizer::from_existing_vocab_and_merges(vocab, merges, false);
let test_tuples = [
("The Earth", vec!["T", "h", "e</w>", "E", "a", "rth</w>"]),
("", vec![]),
(" ", vec![]),
(" \n ", vec![]),
];
let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
let expected_results: Vec<Vec<&str>> = test_tuples.iter().map(|v| v.1.clone()).collect();
for (source_text, expected_result) in test_tuples.iter() {
assert_eq!(openai_gpt_tokenizer.tokenize(source_text), *expected_result);
}
assert_eq!(
MultiThreadedTokenizer::tokenize_list(&openai_gpt_tokenizer, &source_texts),
expected_results
);
}
#[test]
fn test_encode() {
let vocab = generate_test_vocab();
let merges = generate_test_merges();
let openai_gpt_tokenizer: OpenAiGptTokenizer =
OpenAiGptTokenizer::from_existing_vocab_and_merges(vocab, merges, true);
let truncation_strategy = TruncationStrategy::LongestFirst;
let test_tuples = [
(
"the earth",
TokenizedInput {
token_ids: vec![8, 10, 9],
segment_ids: vec![0, 0, 0],
special_tokens_mask: vec![0, 0, 0],
overflowing_tokens: vec![],
num_truncated_tokens: 0,
token_offsets: vec![
Some(Offset { begin: 0, end: 3 }),
Some(Offset { begin: 4, end: 6 }),
Some(Offset { begin: 6, end: 9 }),
],
reference_offsets: vec![vec![0, 1, 2], vec![4, 5], vec![6, 7, 8]],
mask: vec![Mask::None, Mask::Begin, Mask::Continuation],
},
),
(
" ",
TokenizedInput {
token_ids: vec![],
segment_ids: vec![],
special_tokens_mask: vec![],
overflowing_tokens: vec![],
num_truncated_tokens: 0,
token_offsets: vec![],
reference_offsets: vec![],
mask: vec![],
},
),
(
"",
TokenizedInput {
token_ids: vec![],
segment_ids: vec![],
special_tokens_mask: vec![],
overflowing_tokens: vec![],
num_truncated_tokens: 0,
token_offsets: vec![],
reference_offsets: vec![],
mask: vec![],
},
),
];
let source_texts: Vec<&str> = test_tuples.iter().map(|v| v.0).collect();
let expected_results: Vec<TokenizedInput> =
test_tuples.iter().map(|v| v.1.clone()).collect();
for (source_text, expected_result) in test_tuples.iter() {
assert_eq!(
openai_gpt_tokenizer.encode(source_text, None, 128, &truncation_strategy, 0),
*expected_result
);
}
assert_eq!(
MultiThreadedTokenizer::encode_list(
&openai_gpt_tokenizer,
&source_texts,
128,
&truncation_strategy,
0
),
expected_results
);
}
#[test]
fn test_decode() {
let vocab = generate_test_vocab();
let merges = generate_test_merges();
let openai_gpt_tokenizer: OpenAiGptTokenizer =
OpenAiGptTokenizer::from_existing_vocab_and_merges(vocab, merges, true);
let skip_special_tokens = false;
let clean_up_tokenization_spaces = false;
let test_tuples = [(vec![8, 10, 9], "the earth")];
let source_ids: Vec<Vec<i64>> = test_tuples.iter().map(|v| v.0.clone()).collect_vec();
let expected_results: Vec<&str> = test_tuples.iter().map(|v| v.1).collect_vec();
for (source_ids, expected_result) in test_tuples.iter() {
assert_eq!(
openai_gpt_tokenizer.decode(
source_ids,
skip_special_tokens,
clean_up_tokenization_spaces
),
*expected_result
);
}
assert_eq!(
Tokenizer::decode_list(
&openai_gpt_tokenizer,
&source_ids,
skip_special_tokens,
clean_up_tokenization_spaces
),
expected_results
);
}
}