// token_count/tokenizers/openai.rs
1//! OpenAI tokenization using tiktoken-rs
2
3use crate::tokenizers::{ModelInfo, TokenDetail, Tokenizer};
4use anyhow::{Context, Result};
5use tiktoken_rs::CoreBPE;
6
/// OpenAI tokenizer using tiktoken-rs
pub struct OpenAITokenizer {
    // Loaded BPE encoder; used for all encode/decode calls on this tokenizer.
    bpe: CoreBPE,
    // Model metadata (name, encoding, context window) handed back verbatim
    // by `get_model_info`.
    model_info: ModelInfo,
}
12
13impl OpenAITokenizer {
14    /// Create a new OpenAI tokenizer for the given encoding
15    pub fn new(encoding_name: &str, model_info: ModelInfo) -> Result<Self> {
16        let tokenizer_enum = match encoding_name {
17            "o200k_base" => tiktoken_rs::tokenizer::Tokenizer::O200kBase,
18            "cl100k_base" => tiktoken_rs::tokenizer::Tokenizer::Cl100kBase,
19            "p50k_base" => tiktoken_rs::tokenizer::Tokenizer::P50kBase,
20            "r50k_base" => tiktoken_rs::tokenizer::Tokenizer::R50kBase,
21            "gpt2" => tiktoken_rs::tokenizer::Tokenizer::Gpt2,
22            _ => anyhow::bail!("Unsupported encoding: {}", encoding_name),
23        };
24
25        let bpe = tiktoken_rs::get_bpe_from_tokenizer(tokenizer_enum)
26            .with_context(|| format!("Failed to load encoding: {}", encoding_name))?;
27
28        Ok(Self { bpe, model_info })
29    }
30}
31
32impl Tokenizer for OpenAITokenizer {
33    fn count_tokens(&self, text: &str) -> Result<usize> {
34        let tokens = self.bpe.encode_with_special_tokens(text);
35        Ok(tokens.len())
36    }
37
38    fn get_model_info(&self) -> ModelInfo {
39        self.model_info.clone()
40    }
41
42    fn encode_with_details(&self, text: &str) -> Result<Option<Vec<TokenDetail>>> {
43        // Skip detailed tokenization for inputs >50KB to prevent stack overflow
44        // tiktoken-rs has known recursion depth issues with large inputs
45        // See: https://github.com/zurawiki/tiktoken-rs/issues/327
46        const MAX_DEBUG_INPUT_SIZE: usize = 50 * 1024; // 50KB safety limit
47
48        if text.len() > MAX_DEBUG_INPUT_SIZE {
49            eprintln!(
50                "Warning: Input size ({} bytes) exceeds debug mode limit ({} bytes). \
51                 Showing token count only. For token IDs, provide smaller input.",
52                text.len(),
53                MAX_DEBUG_INPUT_SIZE
54            );
55            return Ok(None);
56        }
57
58        let token_ids = self.bpe.encode_with_special_tokens(text);
59
60        // Limit to first 10 tokens to avoid overwhelming output
61        let mut details = Vec::new();
62        for token_id in token_ids.iter().take(10) {
63            // Decode individual token
64            let decoded = self.bpe.decode(vec![*token_id])?;
65            details.push(TokenDetail { id: *token_id, text: decoded });
66        }
67
68        Ok(Some(details))
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: this identical `ModelInfo` literal was previously
    /// duplicated in every test.
    fn gpt4_model_info() -> ModelInfo {
        ModelInfo {
            name: "gpt-4".to_string(),
            encoding: "cl100k_base".to_string(),
            context_window: 128000,
            description: "GPT-4 model".to_string(),
        }
    }

    /// Build a cl100k_base tokenizer with the shared fixture metadata.
    fn make_tokenizer() -> OpenAITokenizer {
        OpenAITokenizer::new("cl100k_base", gpt4_model_info()).unwrap()
    }

    #[test]
    fn test_basic_tokenization() {
        let tokenizer = make_tokenizer();
        assert_eq!(tokenizer.count_tokens("Hello world").unwrap(), 2);
    }

    #[test]
    fn test_empty_string() {
        let tokenizer = make_tokenizer();
        assert_eq!(tokenizer.count_tokens("").unwrap(), 0);
    }

    #[test]
    fn test_unsupported_encoding_is_error() {
        // The constructor must reject encoding names outside the match arms.
        assert!(OpenAITokenizer::new("bogus_base", gpt4_model_info()).is_err());
    }

    #[test]
    fn test_encode_with_details_large_input() {
        let tokenizer = make_tokenizer();

        // Create a 60KB input (exceeds the 50KB debug-mode limit).
        let large_input = "a".repeat(60 * 1024);
        let result = tokenizer.encode_with_details(&large_input);

        // Should return None gracefully, not panic
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), None);
    }

    #[test]
    fn test_encode_with_details_normal_input() {
        let tokenizer = make_tokenizer();

        // Normal-sized input should produce per-token details.
        let details = tokenizer
            .encode_with_details("Hello world")
            .unwrap()
            .expect("small input should yield details");

        assert_eq!(details.len(), 2); // "Hello" + " world"
        assert_eq!(details[0].id, 9906); // "Hello"
        assert_eq!(details[1].id, 1917); // " world"
    }
}
147}