token_count/tokenizers/openai.rs

use crate::tokenizers::{ModelInfo, Tokenizer};
use anyhow::{Context, Result};
use tiktoken_rs::CoreBPE;

/// Tokenizer for OpenAI models, backed by a tiktoken-rs `CoreBPE` encoding.
pub struct OpenAITokenizer {
    bpe: CoreBPE,
    model_info: ModelInfo,
}

impl OpenAITokenizer {
    /// Builds a tokenizer for a named tiktoken encoding (for example "cl100k_base").
    /// Encoding names outside the list below are rejected with an error.
    pub fn new(encoding_name: &str, model_info: ModelInfo) -> Result<Self> {
        let tokenizer_enum = match encoding_name {
            "o200k_base" => tiktoken_rs::tokenizer::Tokenizer::O200kBase,
            "cl100k_base" => tiktoken_rs::tokenizer::Tokenizer::Cl100kBase,
            "p50k_base" => tiktoken_rs::tokenizer::Tokenizer::P50kBase,
            "r50k_base" => tiktoken_rs::tokenizer::Tokenizer::R50kBase,
            "gpt2" => tiktoken_rs::tokenizer::Tokenizer::Gpt2,
            _ => anyhow::bail!("Unsupported encoding: {}", encoding_name),
        };

        let bpe = tiktoken_rs::get_bpe_from_tokenizer(tokenizer_enum)
            .with_context(|| format!("Failed to load encoding: {}", encoding_name))?;

        Ok(Self { bpe, model_info })
    }
}

impl Tokenizer for OpenAITokenizer {
    fn count_tokens(&self, text: &str) -> Result<usize> {
        // encode_with_special_tokens treats special tokens (e.g. "<|endoftext|>")
        // as single tokens rather than plain text.
        let tokens = self.bpe.encode_with_special_tokens(text);
        Ok(tokens.len())
    }

    fn get_model_info(&self) -> ModelInfo {
        self.model_info.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_tokenization() {
        let model_info = ModelInfo {
            name: "gpt-4".to_string(),
            encoding: "cl100k_base".to_string(),
            context_window: 128000,
            description: "GPT-4 model".to_string(),
        };

        let tokenizer = OpenAITokenizer::new("cl100k_base", model_info).unwrap();
        let count = tokenizer.count_tokens("Hello world").unwrap();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_empty_string() {
        let model_info = ModelInfo {
            name: "gpt-4".to_string(),
            encoding: "cl100k_base".to_string(),
            context_window: 128000,
            description: "GPT-4 model".to_string(),
        };

        let tokenizer = OpenAITokenizer::new("cl100k_base", model_info).unwrap();
        let count = tokenizer.count_tokens("").unwrap();
        assert_eq!(count, 0);
    }
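
    // Added test sketch: exercises the `anyhow::bail!` branch in
    // `OpenAITokenizer::new` for an encoding name that is not in the match.
    // The encoding and model names below are hypothetical, used only here.
    #[test]
    fn test_unsupported_encoding() {
        let model_info = ModelInfo {
            name: "unknown-model".to_string(),
            encoding: "does_not_exist".to_string(),
            context_window: 0,
            description: "Placeholder model info for the error-path test".to_string(),
        };

        assert!(OpenAITokenizer::new("does_not_exist", model_info).is_err());
    }

    // Added test sketch: smoke-checks that every encoding name accepted by `new`
    // resolves to a loadable BPE. Assumes tiktoken-rs provides tables for all
    // five encodings listed in the match.
    #[test]
    fn test_supported_encodings_load() {
        for encoding in ["o200k_base", "cl100k_base", "p50k_base", "r50k_base", "gpt2"] {
            let model_info = ModelInfo {
                name: "test-model".to_string(),
                encoding: encoding.to_string(),
                context_window: 0,
                description: "Placeholder model info for the smoke test".to_string(),
            };
            assert!(OpenAITokenizer::new(encoding, model_info).is_ok());
        }
    }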
}