#![deny(unsafe_code)]
#![warn(missing_docs)]
#![warn(rust_2018_idioms)]
use rayon::prelude::*;
use thiserror::Error;
use tiktoken_rs::CoreBPE;
/// Crate-local result alias specialized to [`TokenizerError`].
pub type Result<T> = std::result::Result<T, TokenizerError>;
/// Errors produced when constructing a [`Tokenizer`] or decoding tokens.
#[derive(Error, Debug)]
pub enum TokenizerError {
    /// The requested encoding name is not one of the supported encodings.
    #[error("unknown encoding: {0} (expected cl100k_base or o200k_base)")]
    UnknownEncoding(String),
    /// An error surfaced by the underlying tiktoken-rs library.
    #[error("tiktoken-rs error: {0}")]
    Tiktoken(String),
}
/// A BPE tokenizer wrapping a tiktoken-rs [`CoreBPE`] encoding.
pub struct Tokenizer {
    // Loaded BPE instance used for all encode/decode/count calls.
    bpe: CoreBPE,
    // Name of the loaded encoding ("cl100k_base" or "o200k_base"),
    // kept so callers can report which encoding was selected.
    encoding_name: String,
}
impl Tokenizer {
    /// Builds a tokenizer for a model name.
    ///
    /// Tries tiktoken-rs' own model lookup first; if the model is unknown
    /// there, falls back to a prefix-based guess of the encoding (see
    /// `encoding_for_model`), so unreleased/unknown model names still work.
    ///
    /// # Errors
    ///
    /// Returns an error only if the fallback encoding fails to load.
    pub fn for_model(model: &str) -> Result<Self> {
        match tiktoken_rs::get_bpe_from_model(model) {
            Ok(bpe) => Ok(Self {
                bpe,
                encoding_name: encoding_for_model(model).to_string(),
            }),
            // Model unknown to tiktoken-rs: guess the encoding from its name.
            Err(_) => Self::for_encoding(encoding_for_model(model)),
        }
    }

    /// Builds a tokenizer for an encoding name.
    ///
    /// # Errors
    ///
    /// Returns [`TokenizerError::UnknownEncoding`] unless `name` is
    /// `"cl100k_base"` or `"o200k_base"`, and [`TokenizerError::Tiktoken`]
    /// if loading the encoding data fails.
    pub fn for_encoding(name: &str) -> Result<Self> {
        let bpe =
            match name {
                "cl100k_base" => tiktoken_rs::cl100k_base()
                    .map_err(|e| TokenizerError::Tiktoken(e.to_string()))?,
                "o200k_base" => tiktoken_rs::o200k_base()
                    .map_err(|e| TokenizerError::Tiktoken(e.to_string()))?,
                other => return Err(TokenizerError::UnknownEncoding(other.to_string())),
            };
        Ok(Self {
            bpe,
            encoding_name: name.to_string(),
        })
    }

    /// Returns the name of the loaded encoding (e.g. `"cl100k_base"`).
    pub fn encoding_name(&self) -> &str {
        &self.encoding_name
    }

    /// Counts the tokens in `text` (no special-token handling).
    pub fn count(&self, text: &str) -> usize {
        self.bpe.encode_ordinary(text).len()
    }

    /// Counts tokens for each text, optionally in parallel via rayon.
    ///
    /// `parallel = true` is only worthwhile when texts are large or numerous;
    /// both paths return counts in input order.
    pub fn count_many(&self, texts: &[&str], parallel: bool) -> Vec<usize> {
        if parallel {
            texts
                .par_iter()
                .map(|t| self.bpe.encode_ordinary(t).len())
                .collect()
        } else {
            texts
                .iter()
                .map(|t| self.bpe.encode_ordinary(t).len())
                .collect()
        }
    }

    /// Encodes `text` into token ids (no special-token handling).
    pub fn encode(&self, text: &str) -> Vec<u32> {
        self.bpe.encode_ordinary(text)
    }

    /// Decodes token ids back into a string.
    ///
    /// # Errors
    ///
    /// Returns [`TokenizerError::Tiktoken`] if the tokens do not decode to
    /// valid UTF-8 or contain unknown ids.
    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.bpe
            .decode(tokens.to_vec())
            .map_err(|e| TokenizerError::Tiktoken(e.to_string()))
    }

    /// Returns `true` if `text` encodes to at most `budget` tokens.
    pub fn fits(&self, text: &str, budget: usize) -> bool {
        self.count(text) <= budget
    }

    /// Truncates `text` to at most `budget` tokens and decodes it back.
    ///
    /// Returns the input unchanged when it already fits. A token-boundary cut
    /// can land inside a multi-byte UTF-8 character that BPE split across
    /// tokens; in that case trailing tokens are dropped until the remainder
    /// decodes cleanly (empty string in the worst case), instead of surfacing
    /// a spurious decode error for valid input text.
    ///
    /// # Errors
    ///
    /// Returns [`TokenizerError::Tiktoken`] only if decoding fails for every
    /// non-empty prefix, which should not happen for tokens produced by this
    /// tokenizer's own encoder.
    pub fn truncate_to(&self, text: &str, budget: usize) -> Result<String> {
        let mut tokens = self.bpe.encode_ordinary(text);
        if tokens.len() <= budget {
            return Ok(text.to_string());
        }
        tokens.truncate(budget);
        // Back off token by token if the cut split a multi-byte character.
        while !tokens.is_empty() {
            match self.bpe.decode(tokens.clone()) {
                Ok(s) => return Ok(s),
                Err(_) => {
                    tokens.pop();
                }
            }
        }
        Ok(String::new())
    }
}
/// Maps a model name to its tiktoken encoding name.
///
/// Models from the gpt-4o / gpt-5 / o-series families use `o200k_base`;
/// everything else (including unknown models) defaults to `cl100k_base`.
fn encoding_for_model(model: &str) -> &'static str {
    // Model-name prefixes that select the newer o200k_base vocabulary.
    const O200K_PREFIXES: [&str; 6] = ["gpt-4o", "gpt-5", "o1", "o3", "o4", "chatgpt-4o"];
    match O200K_PREFIXES.iter().any(|p| model.starts_with(p)) {
        true => "o200k_base",
        false => "cl100k_base",
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: a freshly loaded cl100k_base tokenizer.
    fn cl100k() -> Tokenizer {
        Tokenizer::for_encoding("cl100k_base").expect("cl100k_base must load")
    }

    #[test]
    fn round_trip_simple_text() {
        let tok = cl100k();
        let text = "hello world";
        assert_eq!(tok.decode(&tok.encode(text)).unwrap(), text);
    }

    #[test]
    fn count_matches_encode_len() {
        let tok = cl100k();
        let text = "the quick brown fox jumps over the lazy dog";
        assert_eq!(tok.count(text), tok.encode(text).len());
    }

    #[test]
    fn count_many_serial_and_parallel_agree() {
        let tok = cl100k();
        let texts = vec!["hi", "world", "lorem ipsum dolor sit amet"];
        assert_eq!(tok.count_many(&texts, false), tok.count_many(&texts, true));
    }

    #[test]
    fn for_model_gpt4_is_cl100k() {
        let tok = Tokenizer::for_model("gpt-4").unwrap();
        assert_eq!(tok.encoding_name(), "cl100k_base");
    }

    #[test]
    fn for_model_gpt5_is_o200k() {
        let tok = Tokenizer::for_model("gpt-5").unwrap();
        assert_eq!(tok.encoding_name(), "o200k_base");
    }

    #[test]
    fn for_model_o3_is_o200k() {
        let tok = Tokenizer::for_model("o3-mini").unwrap();
        assert_eq!(tok.encoding_name(), "o200k_base");
    }

    #[test]
    fn for_model_unknown_falls_back_to_cl100k() {
        let tok = Tokenizer::for_model("future-model-7b").unwrap();
        assert_eq!(tok.encoding_name(), "cl100k_base");
    }

    #[test]
    fn for_model_gpt4o_is_o200k() {
        let tok = Tokenizer::for_model("gpt-4o").unwrap();
        assert_eq!(tok.encoding_name(), "o200k_base");
    }

    #[test]
    fn unknown_encoding_rejected() {
        assert!(Tokenizer::for_encoding("unknown_base").is_err());
    }

    #[test]
    fn fits_and_truncate() {
        let tok = cl100k();
        let text = "the quick brown fox";
        let n = tok.count(text);
        assert!(tok.fits(text, n));
        assert!(tok.fits(text, n + 1));
        assert!(!tok.fits(text, n - 1));
        let truncated = tok.truncate_to(text, 2).unwrap();
        assert!(tok.count(&truncated) <= 2);
        assert!(truncated.len() <= text.len());
    }

    #[test]
    fn truncate_returns_input_when_fits() {
        let tok = cl100k();
        assert_eq!(tok.truncate_to("hi", 100).unwrap(), "hi");
    }

    #[test]
    fn empty_text_is_zero_tokens() {
        let tok = cl100k();
        assert_eq!(tok.count(""), 0);
        assert_eq!(tok.encode(""), Vec::<u32>::new());
    }

    #[test]
    fn unicode_text_round_trips() {
        let tok = cl100k();
        let text = "你好世界 🌍";
        assert_eq!(tok.decode(&tok.encode(text)).unwrap(), text);
    }

    #[test]
    fn count_many_handles_empty_list() {
        let tok = cl100k();
        let empty: Vec<&str> = vec![];
        assert!(tok.count_many(&empty, false).is_empty());
        assert!(tok.count_many(&empty, true).is_empty());
    }
}