1use std::collections::HashMap;
6
7use lazy_static::lazy_static;
8
/// The tokenizer a model name resolves to.
///
/// This is the value type stored in `MODEL_TO_TOKENIZER` and
/// `MODEL_PREFIX_TO_TOKENIZER` and returned by `get_tokenizer`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Tokenizer {
    /// Mapped from the `gpt-4*`, `gpt-3.5-turbo*` and `text-embedding-ada-002` entries.
    Cl100kBase,
    /// Mapped from `text-davinci-002`/`-003` and the `code-*` / `*-codex` entries.
    P50kBase,
    /// Mapped from the `*-001` text models, the bare `davinci`/`curie`/`babbage`/`ada`
    /// names, and the `*-similarity-*` / `*-search-*` entries.
    R50kBase,
    /// Mapped from the `*-edit-001` entries.
    P50kEdit,
    /// Mapped from the `gpt2` entry.
    Gpt2,
}
29
30const MODEL_PREFIX_TO_TOKENIZER: &[(&str, Tokenizer)] = &[
31 ("gpt-4-", Tokenizer::Cl100kBase),
33 ("gpt-3.5-turbo-", Tokenizer::Cl100kBase),
34];
35
36const MODEL_TO_TOKENIZER: &[(&str, Tokenizer)] = &[
37 ("gpt-4-32k", Tokenizer::Cl100kBase),
39 ("gpt-4", Tokenizer::Cl100kBase),
40 ("gpt-3.5-turbo", Tokenizer::Cl100kBase),
41 ("text-davinci-003", Tokenizer::P50kBase),
43 ("text-davinci-002", Tokenizer::P50kBase),
44 ("text-davinci-001", Tokenizer::R50kBase),
45 ("text-curie-001", Tokenizer::R50kBase),
46 ("text-babbage-001", Tokenizer::R50kBase),
47 ("text-ada-001", Tokenizer::R50kBase),
48 ("davinci", Tokenizer::R50kBase),
49 ("curie", Tokenizer::R50kBase),
50 ("babbage", Tokenizer::R50kBase),
51 ("ada", Tokenizer::R50kBase),
52 ("code-davinci-002", Tokenizer::P50kBase),
54 ("code-davinci-001", Tokenizer::P50kBase),
55 ("code-cushman-002", Tokenizer::P50kBase),
56 ("code-cushman-001", Tokenizer::P50kBase),
57 ("davinci-codex", Tokenizer::P50kBase),
58 ("cushman-codex", Tokenizer::P50kBase),
59 ("text-davinci-edit-001", Tokenizer::P50kEdit),
61 ("code-davinci-edit-001", Tokenizer::P50kEdit),
62 ("text-embedding-ada-002", Tokenizer::Cl100kBase),
64 ("text-similarity-davinci-001", Tokenizer::R50kBase),
66 ("text-similarity-curie-001", Tokenizer::R50kBase),
67 ("text-similarity-babbage-001", Tokenizer::R50kBase),
68 ("text-similarity-ada-001", Tokenizer::R50kBase),
69 ("text-search-davinci-doc-001", Tokenizer::R50kBase),
70 ("text-search-curie-doc-001", Tokenizer::R50kBase),
71 ("text-search-babbage-doc-001", Tokenizer::R50kBase),
72 ("text-search-ada-doc-001", Tokenizer::R50kBase),
73 ("code-search-babbage-code-001", Tokenizer::R50kBase),
74 ("code-search-ada-code-001", Tokenizer::R50kBase),
75 ("gpt2", Tokenizer::Gpt2),
77];
78
79lazy_static! {
80 static ref MODEL_TO_TOKENIZER_MAP: HashMap<&'static str, Tokenizer> = {
81 let mut map = HashMap::new();
82 MODEL_TO_TOKENIZER.iter().for_each(|&(model, tokenizer)| {
83 map.insert(model, tokenizer);
84 });
85 map
86 };
87}
88
89pub fn get_tokenizer(model_name: &str) -> Option<Tokenizer> {
113 if let Some(tokenizer) = MODEL_TO_TOKENIZER_MAP.get(model_name) {
114 return Some(*tokenizer);
115 }
116 if let Some(tokenizer) = MODEL_PREFIX_TO_TOKENIZER
117 .iter()
118 .find(|(model_prefix, _)| model_name.starts_with(*model_prefix))
119 {
120 return Some(tokenizer.1);
121 }
122
123 None
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers exact-name hits, prefix-only hits, and an unknown model name.
    #[test]
    fn test_get_tokenizer() {
        let cases: &[(&str, Option<Tokenizer>)] = &[
            // prefix match on "gpt-4-"
            ("gpt-4-32k-0314", Some(Tokenizer::Cl100kBase)),
            // exact match
            ("gpt-3.5-turbo", Some(Tokenizer::Cl100kBase)),
            // prefix match on "gpt-3.5-turbo-"
            ("gpt-3.5-turbo-0301", Some(Tokenizer::Cl100kBase)),
            // exact matches
            ("text-davinci-003", Some(Tokenizer::P50kBase)),
            ("code-search-ada-code-001", Some(Tokenizer::R50kBase)),
            // unknown model
            ("foo", None),
        ];

        for &(model, expected) in cases {
            // The model name is part of both tuples so a failure shows which case broke.
            assert_eq!((model, get_tokenizer(model)), (model, expected));
        }
    }
}