use std::collections::HashMap;
use lazy_static::lazy_static;
/// Identifies which tokenizer a given model name resolves to.
///
/// Values of this type are what [`get_tokenizer`] returns; the variant chosen
/// for each model is defined by the lookup tables in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Tokenizer {
    /// Used by the `gpt-4` and `gpt-3.5-turbo` families and by
    /// `text-embedding-ada-002`.
    Cl100kBase,
    /// Used by `text-davinci-002`/`-003` and the `code-*` / `*-codex` models.
    P50kBase,
    /// Used by the bare `davinci`/`curie`/`babbage`/`ada` names, the `*-001`
    /// text models, and the similarity/search models.
    R50kBase,
    /// Used by the `*-edit-001` models.
    P50kEdit,
    /// Used by `gpt2`.
    Gpt2,
}
/// Model-name prefixes and the tokenizer assigned to any model whose name
/// starts with one of them.
///
/// `get_tokenizer` consults this table only after an exact-name lookup has
/// failed, and takes the first entry (in table order) whose prefix matches.
/// Every prefix ends in `-`, so the bare family names (`gpt-4`,
/// `gpt-3.5-turbo`) are not matched here.
const MODEL_PREFIX_TO_TOKENIZER: &[(&'static str, Tokenizer)] = &[
    ("gpt-4-", Tokenizer::Cl100kBase),
    ("gpt-3.5-turbo-", Tokenizer::Cl100kBase),
];
/// Exact model names paired with the tokenizer each one uses.
///
/// This slice is the source data for `MODEL_TO_TOKENIZER_MAP`; entries are
/// grouped below by name family purely for readability.
const MODEL_TO_TOKENIZER: &[(&'static str, Tokenizer)] = &[
    // gpt-4 / gpt-3.5-turbo
    ("gpt-4-32k", Tokenizer::Cl100kBase),
    ("gpt-4", Tokenizer::Cl100kBase),
    ("gpt-3.5-turbo", Tokenizer::Cl100kBase),
    // text-* models and bare base-model names
    ("text-davinci-003", Tokenizer::P50kBase),
    ("text-davinci-002", Tokenizer::P50kBase),
    ("text-davinci-001", Tokenizer::R50kBase),
    ("text-curie-001", Tokenizer::R50kBase),
    ("text-babbage-001", Tokenizer::R50kBase),
    ("text-ada-001", Tokenizer::R50kBase),
    ("davinci", Tokenizer::R50kBase),
    ("curie", Tokenizer::R50kBase),
    ("babbage", Tokenizer::R50kBase),
    ("ada", Tokenizer::R50kBase),
    // code-* and *-codex
    ("code-davinci-002", Tokenizer::P50kBase),
    ("code-davinci-001", Tokenizer::P50kBase),
    ("code-cushman-002", Tokenizer::P50kBase),
    ("code-cushman-001", Tokenizer::P50kBase),
    ("davinci-codex", Tokenizer::P50kBase),
    ("cushman-codex", Tokenizer::P50kBase),
    // *-edit-*
    ("text-davinci-edit-001", Tokenizer::P50kEdit),
    ("code-davinci-edit-001", Tokenizer::P50kEdit),
    // text-embedding-*
    ("text-embedding-ada-002", Tokenizer::Cl100kBase),
    // text-similarity-*, text-search-*, code-search-*
    ("text-similarity-davinci-001", Tokenizer::R50kBase),
    ("text-similarity-curie-001", Tokenizer::R50kBase),
    ("text-similarity-babbage-001", Tokenizer::R50kBase),
    ("text-similarity-ada-001", Tokenizer::R50kBase),
    ("text-search-davinci-doc-001", Tokenizer::R50kBase),
    ("text-search-curie-doc-001", Tokenizer::R50kBase),
    ("text-search-babbage-doc-001", Tokenizer::R50kBase),
    ("text-search-ada-doc-001", Tokenizer::R50kBase),
    ("code-search-babbage-code-001", Tokenizer::R50kBase),
    ("code-search-ada-code-001", Tokenizer::R50kBase),
    // gpt2
    ("gpt2", Tokenizer::Gpt2),
];
lazy_static! {
    /// Hash-map view of `MODEL_TO_TOKENIZER`, keyed by exact model name.
    ///
    /// Collected from the slice entry by entry, so if a name were ever listed
    /// twice the later entry would win.
    static ref MODEL_TO_TOKENIZER_MAP: HashMap<&'static str, Tokenizer> =
        MODEL_TO_TOKENIZER.iter().copied().collect();
}
/// Resolves a model name to the [`Tokenizer`] it uses.
///
/// Resolution happens in two steps:
///
/// 1. An exact match of `model_name` against `MODEL_TO_TOKENIZER_MAP`.
/// 2. Failing that, the first entry of `MODEL_PREFIX_TO_TOKENIZER` (in table
///    order) whose prefix `model_name` starts with.
///
/// Returns `None` when neither step finds a match.
pub fn get_tokenizer(model_name: &str) -> Option<Tokenizer> {
    MODEL_TO_TOKENIZER_MAP
        .get(model_name)
        .copied()
        .or_else(|| {
            // No exact hit: fall back to the prefix table.
            MODEL_PREFIX_TO_TOKENIZER
                .iter()
                .find_map(|&(prefix, tokenizer)| {
                    if model_name.starts_with(prefix) {
                        Some(tokenizer)
                    } else {
                        None
                    }
                })
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises both lookup paths of `get_tokenizer` plus the miss case.
    #[test]
    fn test_get_tokenizer() {
        let cl100k = Some(Tokenizer::Cl100kBase);
        let p50k = Some(Tokenizer::P50kBase);
        let r50k = Some(Tokenizer::R50kBase);

        // Names present verbatim in the exact-match table.
        assert_eq!(get_tokenizer("gpt-3.5-turbo"), cl100k);
        assert_eq!(get_tokenizer("text-davinci-003"), p50k);
        assert_eq!(get_tokenizer("code-search-ada-code-001"), r50k);

        // Names absent from the exact-match table, resolved via a prefix.
        assert_eq!(get_tokenizer("gpt-4-32k-0314"), cl100k);
        assert_eq!(get_tokenizer("gpt-3.5-turbo-0301"), cl100k);

        // Name matching neither table.
        assert_eq!(get_tokenizer("foo"), None);
    }
}