1use anyhow::{Error, Result};
2use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
3
4use crate::traits::{
5 Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
6};
7
8pub(crate) struct TiktokenTokenizer {
10 tokenizer: CoreBPE,
11 #[allow(dead_code)]
12 model: TiktokenModel,
13 special_tokens: SpecialTokens,
14 vocab_size: usize,
15}
16
/// The tiktoken encodings a `TiktokenTokenizer` can be built from.
///
/// Each variant names the `tiktoken_rs` loader of the same name
/// (`cl100k_base`, `p50k_base`, `p50k_edit`, `r50k_base`).
///
/// `PartialEq`/`Eq` are derived so values can be compared directly
/// (`model == TiktokenModel::Cl100kBase`) instead of going through `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TiktokenModel {
    /// The `cl100k_base` encoding.
    Cl100kBase,
    /// The `p50k_base` encoding.
    P50kBase,
    /// The `p50k_edit` encoding.
    P50kEdit,
    /// The `r50k_base` encoding.
    R50kBase,
}
29
30impl TiktokenTokenizer {
31 pub fn new(model: TiktokenModel) -> Result<Self> {
33 let tokenizer =
34 match model {
35 TiktokenModel::Cl100kBase => cl100k_base()
36 .map_err(|e| Error::msg(format!("Failed to load cl100k_base: {}", e)))?,
37 TiktokenModel::P50kBase => p50k_base()
38 .map_err(|e| Error::msg(format!("Failed to load p50k_base: {}", e)))?,
39 TiktokenModel::P50kEdit => p50k_edit()
40 .map_err(|e| Error::msg(format!("Failed to load p50k_edit: {}", e)))?,
41 TiktokenModel::R50kBase => r50k_base()
42 .map_err(|e| Error::msg(format!("Failed to load r50k_base: {}", e)))?,
43 };
44
45 let special_tokens = Self::get_special_tokens_for_model(model);
48
49 let vocab_size = match model {
51 TiktokenModel::Cl100kBase => 100256, TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281, TiktokenModel::R50kBase => 50257, };
55
56 Ok(TiktokenTokenizer {
57 tokenizer,
58 model,
59 special_tokens,
60 vocab_size,
61 })
62 }
63
64 pub fn from_model_name(model_name: &str) -> Result<Self> {
66 let model = Self::model_from_name(model_name)?;
67 Self::new(model)
68 }
69
70 fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
72 if model_name.contains("gpt-4")
74 || model_name.contains("gpt-3.5")
75 || model_name.contains("turbo")
76 {
77 Ok(TiktokenModel::Cl100kBase)
78 } else if model_name.contains("davinci-002")
79 || model_name.contains("davinci-003")
80 || model_name.contains("codex")
81 {
82 Ok(TiktokenModel::P50kBase)
83 } else if model_name.contains("edit") {
84 Ok(TiktokenModel::P50kEdit)
85 } else if model_name.contains("davinci")
86 || model_name.contains("curie")
87 || model_name.contains("babbage")
88 || model_name.contains("ada")
89 {
90 Ok(TiktokenModel::R50kBase)
91 } else {
92 Err(anyhow::anyhow!(
94 "Unrecognized OpenAI model name: '{}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names",
95 model_name
96 ))
97 }
98 }
99
100 fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
102 match model {
105 TiktokenModel::Cl100kBase => SpecialTokens {
106 bos_token: Some("<|endoftext|>".to_string()),
107 eos_token: Some("<|endoftext|>".to_string()),
108 unk_token: None,
109 sep_token: None,
110 pad_token: Some("<|endoftext|>".to_string()),
111 cls_token: None,
112 mask_token: None,
113 additional_special_tokens: vec![
114 "<|fim_prefix|>".to_string(),
115 "<|fim_middle|>".to_string(),
116 "<|fim_suffix|>".to_string(),
117 "<|endofprompt|>".to_string(),
118 ],
119 },
120 _ => SpecialTokens {
121 bos_token: Some("<|endoftext|>".to_string()),
122 eos_token: Some("<|endoftext|>".to_string()),
123 unk_token: None,
124 sep_token: None,
125 pad_token: Some("<|endoftext|>".to_string()),
126 cls_token: None,
127 mask_token: None,
128 additional_special_tokens: vec![],
129 },
130 }
131 }
132}
133
134impl Encoder for TiktokenTokenizer {
135 fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
136 let tokens = self.tokenizer.encode_ordinary(input);
138 Ok(Encoding::Tiktoken(tokens))
139 }
140
141 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
142 inputs
143 .iter()
144 .map(|input| self.encode(input, add_special_tokens))
145 .collect()
146 }
147}
148
149impl Decoder for TiktokenTokenizer {
150 fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
151 self.tokenizer
153 .decode(token_ids.to_vec())
154 .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
155 }
156}
157
158impl TokenizerTrait for TiktokenTokenizer {
159 fn vocab_size(&self) -> usize {
160 self.vocab_size
161 }
162
163 fn get_special_tokens(&self) -> &SpecialTokens {
164 &self.special_tokens
165 }
166
167 fn token_to_id(&self, _token: &str) -> Option<TokenIdType> {
168 None
171 }
172
173 fn id_to_token(&self, _id: TokenIdType) -> Option<String> {
174 None
177 }
178
179 fn as_any(&self) -> &dyn std::any::Any {
180 self
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::{TiktokenModel, TiktokenTokenizer};
    use crate::traits::{Decoder, Encoder, Tokenizer};

    /// Shared fixture: a tokenizer built for the cl100k_base encoding.
    fn cl100k() -> TiktokenTokenizer {
        TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap()
    }

    /// Resolves a model name, panicking if it is not recognized.
    fn resolve(name: &str) -> TiktokenModel {
        TiktokenTokenizer::model_from_name(name).unwrap()
    }

    #[test]
    fn test_tiktoken_creation() {
        assert_eq!(cl100k().vocab_size(), 100256);
    }

    #[test]
    fn test_model_from_name() {
        assert!(matches!(resolve("gpt-4"), TiktokenModel::Cl100kBase));
        assert!(matches!(
            resolve("gpt-3.5-turbo"),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            resolve("text-davinci-003"),
            TiktokenModel::P50kBase
        ));
        assert!(matches!(
            resolve("text-davinci-edit-001"),
            TiktokenModel::P50kEdit
        ));
        assert!(matches!(resolve("davinci"), TiktokenModel::R50kBase));
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = cl100k();
        let original = "Hello, world!";

        let encoding = tokenizer.encode(original, false).unwrap();
        let round_tripped = tokenizer.decode(encoding.token_ids(), false).unwrap();

        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_batch_encode() {
        let tokenizer = cl100k();
        let texts = ["Hello", "World", "Test"];

        let encodings = tokenizer.encode_batch(&texts, false).unwrap();
        assert_eq!(encodings.len(), 3);

        // Lengths are equal (asserted above), so zipping visits every pair.
        for (encoding, expected) in encodings.iter().zip(texts.iter()) {
            let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
            assert_eq!(decoded, *expected);
        }
    }

    #[test]
    fn test_special_tokens() {
        let tokenizer = cl100k();
        let eos = tokenizer.get_special_tokens().eos_token.as_deref();

        assert_eq!(eos, Some("<|endoftext|>"));
    }

    #[test]
    fn test_unrecognized_model_name_returns_error() {
        let unknown_names = ["distilgpt-2", "bert-base-uncased", "llama-7b"];

        for name in unknown_names.iter() {
            // `TiktokenTokenizer` is not `Debug`, so `unwrap_err` is unavailable.
            match TiktokenTokenizer::from_model_name(name) {
                Ok(_) => panic!("expected an error for model name '{}'", name),
                Err(err) => {
                    assert!(err.to_string().contains("Unrecognized OpenAI model name"));
                }
            }
        }
    }

    #[test]
    fn test_recognized_model_names() {
        let known_names = [
            "gpt-4",
            "gpt-3.5-turbo",
            "text-davinci-003",
            "code-davinci-002",
            "text-curie-001",
            "text-babbage-001",
            "text-ada-001",
        ];

        for name in known_names.iter() {
            assert!(
                TiktokenTokenizer::from_model_name(name).is_ok(),
                "model name '{}' should be recognized",
                name
            );
        }
    }
}