llm_tokenizer/
tiktoken.rs

1use anyhow::{Error, Result};
2use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
3
4use crate::traits::{
5    Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
6};
7
/// Tiktoken tokenizer wrapper for OpenAI GPT models
pub(crate) struct TiktokenTokenizer {
    /// tiktoken-rs BPE instance used for all encode/decode calls.
    tokenizer: CoreBPE,
    /// Encoding this tokenizer was built from (set in `new`).
    // `dead_code` is allowed because the field may have no readers, which
    // would otherwise trigger the unused-field lint.
    #[allow(dead_code)]
    model: TiktokenModel,
    /// Special-token strings returned by `get_special_tokens`.
    special_tokens: SpecialTokens,
    /// Value returned by `vocab_size()`; taken from a per-model table in `new`.
    vocab_size: usize,
}
16
/// Supported Tiktoken encodings, each named after the BPE rank table it loads.
///
/// The type is `Copy`, comparable and hashable, so it can be matched, compared
/// with `==`, or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TiktokenModel {
    /// `cl100k_base`: GPT-4, GPT-3.5-turbo, text-embedding-ada-002
    Cl100kBase,
    /// `p50k_base`: Codex models, text-davinci-002, text-davinci-003
    P50kBase,
    /// `p50k_edit`: edit models like text-davinci-edit-001, code-davinci-edit-001
    P50kEdit,
    /// `r50k_base`: GPT-3 models like davinci
    R50kBase,
}
29
30impl TiktokenTokenizer {
31    /// Create a new Tiktoken tokenizer for the specified model
32    pub fn new(model: TiktokenModel) -> Result<Self> {
33        let tokenizer =
34            match model {
35                TiktokenModel::Cl100kBase => cl100k_base()
36                    .map_err(|e| Error::msg(format!("Failed to load cl100k_base: {}", e)))?,
37                TiktokenModel::P50kBase => p50k_base()
38                    .map_err(|e| Error::msg(format!("Failed to load p50k_base: {}", e)))?,
39                TiktokenModel::P50kEdit => p50k_edit()
40                    .map_err(|e| Error::msg(format!("Failed to load p50k_edit: {}", e)))?,
41                TiktokenModel::R50kBase => r50k_base()
42                    .map_err(|e| Error::msg(format!("Failed to load r50k_base: {}", e)))?,
43            };
44
45        // Extract special tokens (tiktoken-rs doesn't expose them directly)
46        // We'll use common ones for GPT models
47        let special_tokens = Self::get_special_tokens_for_model(model);
48
49        // Get vocabulary size (this is an approximation)
50        let vocab_size = match model {
51            TiktokenModel::Cl100kBase => 100256, // cl100k has ~100k tokens
52            TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281, // p50k has ~50k tokens
53            TiktokenModel::R50kBase => 50257,    // r50k has ~50k tokens
54        };
55
56        Ok(TiktokenTokenizer {
57            tokenizer,
58            model,
59            special_tokens,
60            vocab_size,
61        })
62    }
63
64    /// Create a tokenizer from a model string (e.g., "gpt-4", "gpt-3.5-turbo")
65    pub fn from_model_name(model_name: &str) -> Result<Self> {
66        let model = Self::model_from_name(model_name)?;
67        Self::new(model)
68    }
69
70    /// Determine the appropriate model from a model name
71    fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
72        // Based on OpenAI's model-to-encoding mapping
73        if model_name.contains("gpt-4")
74            || model_name.contains("gpt-3.5")
75            || model_name.contains("turbo")
76        {
77            Ok(TiktokenModel::Cl100kBase)
78        } else if model_name.contains("davinci-002")
79            || model_name.contains("davinci-003")
80            || model_name.contains("codex")
81        {
82            Ok(TiktokenModel::P50kBase)
83        } else if model_name.contains("edit") {
84            Ok(TiktokenModel::P50kEdit)
85        } else if model_name.contains("davinci")
86            || model_name.contains("curie")
87            || model_name.contains("babbage")
88            || model_name.contains("ada")
89        {
90            Ok(TiktokenModel::R50kBase)
91        } else {
92            // Return an error for unrecognized model names to prevent silent failures
93            Err(anyhow::anyhow!(
94                "Unrecognized OpenAI model name: '{}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names",
95                model_name
96            ))
97        }
98    }
99
100    /// Get special tokens for a specific model
101    fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
102        // These are common special tokens for GPT models
103        // The actual token IDs might vary by model
104        match model {
105            TiktokenModel::Cl100kBase => SpecialTokens {
106                bos_token: Some("<|endoftext|>".to_string()),
107                eos_token: Some("<|endoftext|>".to_string()),
108                unk_token: None,
109                sep_token: None,
110                pad_token: Some("<|endoftext|>".to_string()),
111                cls_token: None,
112                mask_token: None,
113                additional_special_tokens: vec![
114                    "<|fim_prefix|>".to_string(),
115                    "<|fim_middle|>".to_string(),
116                    "<|fim_suffix|>".to_string(),
117                    "<|endofprompt|>".to_string(),
118                ],
119            },
120            _ => SpecialTokens {
121                bos_token: Some("<|endoftext|>".to_string()),
122                eos_token: Some("<|endoftext|>".to_string()),
123                unk_token: None,
124                sep_token: None,
125                pad_token: Some("<|endoftext|>".to_string()),
126                cls_token: None,
127                mask_token: None,
128                additional_special_tokens: vec![],
129            },
130        }
131    }
132}
133
134impl Encoder for TiktokenTokenizer {
135    fn encode(&self, input: &str, _add_special_tokens: bool) -> Result<Encoding> {
136        // tiktoken uses encode_ordinary which doesn't add special tokens
137        let tokens = self.tokenizer.encode_ordinary(input);
138        Ok(Encoding::Tiktoken(tokens))
139    }
140
141    fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
142        inputs
143            .iter()
144            .map(|input| self.encode(input, add_special_tokens))
145            .collect()
146    }
147}
148
149impl Decoder for TiktokenTokenizer {
150    fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
151        // tiktoken-rs 0.7.0 now uses u32 (Rank type)
152        self.tokenizer
153            .decode(token_ids.to_vec())
154            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
155    }
156}
157
158impl TokenizerTrait for TiktokenTokenizer {
159    fn vocab_size(&self) -> usize {
160        self.vocab_size
161    }
162
163    fn get_special_tokens(&self) -> &SpecialTokens {
164        &self.special_tokens
165    }
166
167    fn token_to_id(&self, _token: &str) -> Option<TokenIdType> {
168        // Tiktoken doesn't provide direct token-to-id mapping
169        // We'd need to encode the token and check if it produces a single ID
170        None
171    }
172
173    fn id_to_token(&self, _id: TokenIdType) -> Option<String> {
174        // Tiktoken doesn't provide direct id-to-token mapping
175        // We can only decode IDs to text
176        None
177    }
178
179    fn as_any(&self) -> &dyn std::any::Any {
180        self
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::{TiktokenModel, TiktokenTokenizer};
    use crate::traits::{Decoder, Encoder, Tokenizer};

    /// Build the cl100k_base tokenizer used by most tests.
    fn cl100k() -> TiktokenTokenizer {
        TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap()
    }

    /// Assert that `name` resolves to the same variant as `expected`.
    fn assert_resolves_to(name: &str, expected: TiktokenModel) {
        let actual = TiktokenTokenizer::model_from_name(name).unwrap();
        assert_eq!(
            std::mem::discriminant(&actual),
            std::mem::discriminant(&expected),
            "unexpected encoding for '{}'",
            name
        );
    }

    #[test]
    fn test_tiktoken_creation() {
        assert_eq!(cl100k().vocab_size(), 100256);
    }

    #[test]
    fn test_model_from_name() {
        let cases = [
            ("gpt-4", TiktokenModel::Cl100kBase),
            ("gpt-3.5-turbo", TiktokenModel::Cl100kBase),
            ("text-davinci-003", TiktokenModel::P50kBase),
            ("text-davinci-edit-001", TiktokenModel::P50kEdit),
            ("davinci", TiktokenModel::R50kBase),
        ];
        for &(name, expected) in cases.iter() {
            assert_resolves_to(name, expected);
        }
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = cl100k();
        let text = "Hello, world!";

        let encoding = tokenizer.encode(text, false).unwrap();
        let round_trip = tokenizer.decode(encoding.token_ids(), false).unwrap();

        assert_eq!(round_trip, text);
    }

    #[test]
    fn test_batch_encode() {
        let tokenizer = cl100k();
        let texts = ["Hello", "World", "Test"];

        let encodings = tokenizer.encode_batch(&texts, false).unwrap();

        assert_eq!(encodings.len(), texts.len());
        for (encoding, &expected) in encodings.iter().zip(texts.iter()) {
            let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
            assert_eq!(decoded, expected);
        }
    }

    #[test]
    fn test_special_tokens() {
        let tokenizer = cl100k();
        let eos = tokenizer.get_special_tokens().eos_token.as_deref();
        assert_eq!(eos, Some("<|endoftext|>"));
    }

    #[test]
    fn test_unrecognized_model_name_returns_error() {
        for name in ["distilgpt-2", "bert-base-uncased", "llama-7b"].iter() {
            match TiktokenTokenizer::from_model_name(name) {
                Ok(_) => panic!("expected an error for model name '{}'", name),
                Err(e) => assert!(e.to_string().contains("Unrecognized OpenAI model name")),
            }
        }
    }

    #[test]
    fn test_recognized_model_names() {
        let names = [
            "gpt-4",
            "gpt-3.5-turbo",
            "text-davinci-003",
            "code-davinci-002",
            "text-curie-001",
            "text-babbage-001",
            "text-ada-001",
        ];
        for name in names.iter() {
            assert!(
                TiktokenTokenizer::from_model_name(name).is_ok(),
                "'{}' should be recognized",
                name
            );
        }
    }
}