use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
/// Token counter/encoder built on a `tiktoken_rs` BPE encoding.
///
/// Holds the encoder together with the model name it was created for, so the
/// name can be reported back through the `model` accessor.
pub struct Tokenizer {
/// BPE encoder that performs all encoding and decoding.
encoder: CoreBPE,
/// Model name passed to `for_model`, or `"default"` for the default tokenizer.
model: String,
}
impl Tokenizer {
pub fn for_model(model: &str) -> Self {
let encoder = match model {
m if m.contains("gpt-4o") || m.contains("o1") || m.contains("o3") => {
o200k_base().expect("Failed to load o200k tokenizer")
}
m if m.contains("gpt-4") || m.contains("gpt-3.5") => {
cl100k_base().expect("Failed to load cl100k tokenizer")
}
m if m.contains("text-embedding") => {
cl100k_base().expect("Failed to load cl100k tokenizer")
}
_ => cl100k_base().expect("Failed to load cl100k tokenizer"),
};
Self {
encoder,
model: model.to_string(),
}
}
pub fn default_tokenizer() -> Self {
Self {
encoder: cl100k_base().expect("Failed to load cl100k tokenizer"),
model: "default".to_string(),
}
}
pub fn count_tokens(&self, text: &str) -> usize {
self.encoder.encode_with_special_tokens(text).len()
}
pub fn encode(&self, text: &str) -> Vec<u32> {
self.encoder.encode_with_special_tokens(text)
}
pub fn decode(&self, tokens: &[u32]) -> String {
self.encoder.decode(tokens.to_vec()).unwrap_or_default()
}
pub fn truncate(&self, text: &str, max_tokens: usize) -> String {
let tokens = self.encode(text);
if tokens.len() <= max_tokens {
return text.to_string();
}
self.decode(&tokens[..max_tokens])
}
pub fn chunk(&self, text: &str, max_tokens: usize, overlap_tokens: usize) -> Vec<String> {
let tokens = self.encode(text);
let mut chunks = Vec::new();
let mut start = 0;
while start < tokens.len() {
let end = (start + max_tokens).min(tokens.len());
let chunk_tokens = &tokens[start..end];
chunks.push(self.decode(chunk_tokens));
if end >= tokens.len() {
break;
}
start = end.saturating_sub(overlap_tokens);
}
chunks
}
pub fn model(&self) -> &str {
&self.model
}
}
impl Default for Tokenizer {
fn default() -> Self {
Self::default_tokenizer()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Counting a short sentence gives a positive count below its byte length.
    #[test]
    fn test_token_counting() {
        let tok = Tokenizer::default_tokenizer();
        let sample = "Hello, world!";
        let n = tok.count_tokens(sample);
        assert!(n > 0);
        assert!(n < sample.len());
    }

    /// Encoding and then decoding returns the original text.
    #[test]
    fn test_encode_decode_roundtrip() {
        let tok = Tokenizer::default_tokenizer();
        let sample = "This is a test sentence.";
        let ids = tok.encode(sample);
        let restored = tok.decode(&ids);
        assert_eq!(restored, sample);
    }

    /// Truncated text re-encodes to no more than the requested token budget.
    #[test]
    fn test_truncate() {
        let tok = Tokenizer::default_tokenizer();
        let sample = "This is a longer sentence that should be truncated.";
        let shortened = tok.truncate(sample, 5);
        let n = tok.count_tokens(&shortened);
        assert!(n <= 5);
    }

    /// Text longer than the chunk size is split into several overlapping chunks.
    #[test]
    fn test_chunking() {
        let tok = Tokenizer::default_tokenizer();
        let sample = "One two three four five six seven eight nine ten.";
        let pieces = tok.chunk(sample, 3, 1);
        assert!(pieces.len() > 1);
    }

    /// Both the `gpt-4` and `gpt-4o` tokenizers produce tokens for the same text.
    #[test]
    fn test_model_specific_tokenizer() {
        let sample = "Hello, world!";
        let gpt4 = Tokenizer::for_model("gpt-4");
        let gpt4o = Tokenizer::for_model("gpt-4o");
        assert!(gpt4.count_tokens(sample) > 0);
        assert!(gpt4o.count_tokens(sample) > 0);
    }

    /// A `gpt-3.5` model name is kept and yields a working tokenizer.
    #[test]
    fn test_for_model_gpt35() {
        let tok = Tokenizer::for_model("gpt-3.5-turbo");
        assert_eq!(tok.model(), "gpt-3.5-turbo");
        assert!(tok.count_tokens("Hello") > 0);
    }

    /// An `o1` model name is kept and yields a working tokenizer.
    #[test]
    fn test_for_model_o1() {
        let tok = Tokenizer::for_model("o1-mini");
        assert_eq!(tok.model(), "o1-mini");
        assert!(tok.count_tokens("Hello") > 0);
    }

    /// An `o3` model name is kept and yields a working tokenizer.
    #[test]
    fn test_for_model_o3() {
        let tok = Tokenizer::for_model("o3-mini");
        assert_eq!(tok.model(), "o3-mini");
        assert!(tok.count_tokens("Hello") > 0);
    }

    /// An embedding model name is kept and yields a working tokenizer.
    #[test]
    fn test_for_model_embedding() {
        let tok = Tokenizer::for_model("text-embedding-ada-002");
        assert_eq!(tok.model(), "text-embedding-ada-002");
        assert!(tok.count_tokens("Hello") > 0);
    }

    /// An unrecognised model name still produces a usable tokenizer.
    #[test]
    fn test_for_model_unknown_falls_back() {
        let tok = Tokenizer::for_model("some-unknown-model");
        assert_eq!(tok.model(), "some-unknown-model");
        assert!(tok.count_tokens("Hello") > 0);
    }

    /// The `Default` impl reports the `"default"` model name.
    #[test]
    fn test_default_impl() {
        let tok = Tokenizer::default();
        assert_eq!(tok.model(), "default");
        assert!(tok.count_tokens("Hello") > 0);
    }

    /// Text already within the budget is returned unchanged by `truncate`.
    #[test]
    fn test_truncate_within_limit() {
        let tok = Tokenizer::default_tokenizer();
        let sample = "Hello";
        let shortened = tok.truncate(sample, 100);
        assert_eq!(shortened, sample);
    }

    /// Text within the chunk size comes back as a single, identical chunk.
    #[test]
    fn test_chunk_within_limit() {
        let tok = Tokenizer::default_tokenizer();
        let sample = "Short";
        let pieces = tok.chunk(sample, 100, 0);
        assert_eq!(pieces.len(), 1);
        assert_eq!(pieces[0], sample);
    }

    /// Chunking with zero overlap still splits long text into several chunks.
    #[test]
    fn test_chunk_no_overlap() {
        let tok = Tokenizer::default_tokenizer();
        let sample = "One two three four five six seven eight nine ten eleven twelve.";
        let pieces = tok.chunk(sample, 3, 0);
        assert!(pieces.len() > 1);
    }

    /// `model()` returns exactly the name given to `for_model`.
    #[test]
    fn test_model_accessor() {
        let tok = Tokenizer::for_model("gpt-4o-mini");
        assert_eq!(tok.model(), "gpt-4o-mini");
    }

    /// The empty string has zero tokens and encodes to an empty vector.
    #[test]
    fn test_empty_string() {
        let tok = Tokenizer::default_tokenizer();
        assert_eq!(tok.count_tokens(""), 0);
        assert!(tok.encode("").is_empty());
    }

    /// Decoding an empty token slice gives the empty string.
    #[test]
    fn test_decode_empty() {
        let tok = Tokenizer::default_tokenizer();
        let restored = tok.decode(&[]);
        assert_eq!(restored, "");
    }
}