use std::sync::OnceLock;
use serde::{Deserialize, Serialize};
use tiktoken_rs::CoreBPE;
const CHARS_PER_TOKEN: f64 = 3.5;
pub fn estimate_tokens(text: &str) -> usize {
Tokenizer::Heuristic.count(text)
}
pub fn tokens_to_chars(tokens: usize) -> usize {
(tokens as f64 * CHARS_PER_TOKEN).floor() as usize
}
/// Strategy used by [`Tokenizer::count`] to measure text.
///
/// Each variant (de)serializes under the same snake_case name that
/// [`Tokenizer::as_str`] returns. The default is [`Tokenizer::Heuristic`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Tokenizer {
    /// Byte-length estimate: `text.len() / CHARS_PER_TOKEN`, rounded up.
    #[default]
    #[serde(rename = "heuristic")]
    Heuristic,
    /// The `cl100k_base` BPE table from `tiktoken_rs`; falls back to the
    /// heuristic if the table cannot be loaded.
    #[serde(rename = "cl100k_base")]
    Cl100kBase,
    /// The `o200k_base` BPE table from `tiktoken_rs`; falls back to the
    /// heuristic if the table cannot be loaded.
    #[serde(rename = "o200k_base")]
    O200kBase,
}
impl Tokenizer {
    /// Counts the tokens in `text` according to this tokenizer.
    ///
    /// Empty input always yields 0. [`Tokenizer::Heuristic`] divides the UTF-8
    /// byte length by `CHARS_PER_TOKEN` and rounds up. The BPE variants encode
    /// `text` with special tokens enabled and return the number of token ids;
    /// if their BPE table failed to load they fall back to the heuristic.
    #[must_use]
    pub fn count(&self, text: &str) -> usize {
        if text.is_empty() {
            return 0;
        }
        // `None` means "use the heuristic": either it was requested directly,
        // or the BPE table for the requested encoding could not be loaded.
        let bpe = match self {
            Self::Heuristic => None,
            Self::Cl100kBase => cl100k_bpe(),
            Self::O200kBase => o200k_bpe(),
        };
        match bpe {
            Some(bpe) => bpe.encode_with_special_tokens(text).len(),
            None => (text.len() as f64 / CHARS_PER_TOKEN).ceil() as usize,
        }
    }

    /// Returns the canonical name of this tokenizer: `"heuristic"`,
    /// `"cl100k_base"` or `"o200k_base"`.
    ///
    /// [`Tokenizer::from_str_lossy`] maps each of these names back to the
    /// same variant.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Heuristic => "heuristic",
            Self::Cl100kBase => "cl100k_base",
            Self::O200kBase => "o200k_base",
        }
    }

    /// Parses a tokenizer name leniently; never fails.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII
    /// case-insensitive. `cl100k_base`/`cl100k` select
    /// [`Tokenizer::Cl100kBase`], and `o200k_base`/`o200k` select
    /// [`Tokenizer::O200kBase`]. Anything else, including `heuristic` and the
    /// empty string, yields [`Tokenizer::Heuristic`].
    #[must_use]
    pub fn from_str_lossy(s: &str) -> Self {
        // Trim so values with stray whitespace or a trailing newline (e.g.
        // read from a file) are still recognised instead of silently
        // degrading to the heuristic.
        let name = s.trim();
        // Compare in place rather than allocating a lowercased copy.
        let matches_any =
            |aliases: [&str; 2]| aliases.iter().any(|alias| name.eq_ignore_ascii_case(alias));
        if matches_any(["cl100k_base", "cl100k"]) {
            Self::Cl100kBase
        } else if matches_any(["o200k_base", "o200k"]) {
            Self::O200kBase
        } else {
            Self::Heuristic
        }
    }
}
/// Turns the result of loading a tiktoken BPE table into an `Option`,
/// logging a warning when the load failed.
///
/// `table` is the encoding name shown in the warning. Shared by
/// [`cl100k_bpe`] and [`o200k_bpe`] so both report failures identically. The
/// fallback ratio in the message comes from `CHARS_PER_TOKEN`, so the message
/// cannot drift from the heuristic actually used.
fn bpe_or_warn<E: std::fmt::Display>(table: &str, loaded: Result<CoreBPE, E>) -> Option<CoreBPE> {
    match loaded {
        Ok(bpe) => Some(bpe),
        Err(e) => {
            tracing::warn!(
                target: "devboy_format_pipeline::tokenizer",
                "{table} BPE table failed to load: {e} — \
                 falling back to chars/{CHARS_PER_TOKEN} heuristic"
            );
            None
        }
    }
}

/// Returns the process-wide `cl100k_base` BPE table, loading it on first use.
///
/// The load is attempted only once. On failure a warning is logged and every
/// call returns `None`, so callers fall back to the heuristic.
fn cl100k_bpe() -> Option<&'static CoreBPE> {
    static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
    BPE.get_or_init(|| bpe_or_warn("cl100k_base", tiktoken_rs::cl100k_base()))
        .as_ref()
}

/// Returns the process-wide `o200k_base` BPE table, loading it on first use.
///
/// The load is attempted only once. On failure a warning is logged and every
/// call returns `None`, so callers fall back to the heuristic.
fn o200k_bpe() -> Option<&'static CoreBPE> {
    static BPE: OnceLock<Option<CoreBPE>> = OnceLock::new();
    BPE.get_or_init(|| bpe_or_warn("o200k_base", tiktoken_rs::o200k_base()))
        .as_ref()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every tokenizer reports zero tokens for empty input.
    #[test]
    fn test_empty_string() {
        assert_eq!(estimate_tokens(""), 0);
        for tokenizer in [Tokenizer::Cl100kBase, Tokenizer::O200kBase] {
            assert_eq!(tokenizer.count(""), 0);
        }
    }

    /// 5 bytes / 3.5 = 1.43, rounded up to 2.
    #[test]
    fn test_short_text() {
        let estimate = estimate_tokens("hello");
        assert_eq!(estimate, 2);
    }

    /// 36 bytes / 3.5 = 10.29, rounded up to 11.
    #[test]
    fn test_structured_data() {
        let toon = "key: gh#1\ntitle: Fix bug\nstate: open";
        assert_eq!(estimate_tokens(toon), 11);
    }

    /// Converting a budget to a length and back lands within one token.
    #[test]
    fn test_round_trip() {
        let budget: usize = 8000;
        let padding = "x".repeat(tokens_to_chars(budget));
        let recovered = estimate_tokens(&padding);
        assert!(recovered.abs_diff(budget) <= 1);
    }

    #[test]
    fn test_tokens_to_chars() {
        let chars = tokens_to_chars(8000);
        assert_eq!(chars, 28000);
    }

    /// BPE counts on plain prose are positive and no worse than the heuristic.
    #[test]
    fn cl100k_and_o200k_produce_positive_counts_on_simple_input() {
        let phrase = "hello world";
        let heuristic = Tokenizer::Heuristic.count(phrase);
        let bpe_tokenizers = [Tokenizer::Cl100kBase, Tokenizer::O200kBase];
        for tk in bpe_tokenizers {
            let n = tk.count(phrase);
            assert!(n > 0, "{tk:?} returned zero on `{phrase}`");
            assert!(
                n <= heuristic,
                "{tk:?} reported {n} tokens, worse than the {heuristic}-token heuristic prior"
            );
        }
    }

    #[test]
    fn cl100k_and_o200k_agree_on_hello_world() {
        let [cl, o2] =
            [Tokenizer::Cl100kBase, Tokenizer::O200kBase].map(|tk| tk.count("hello world"));
        assert_eq!(cl, o2, "cl100k and o200k should agree on `hello world`");
    }

    #[test]
    fn cl100k_and_o200k_disagree_on_jsonish() {
        let json = r#"{"id":42,"name":"alpha","tags":["x","y","z"]}"#;
        let [cl, o2] = [Tokenizer::Cl100kBase, Tokenizer::O200kBase].map(|tk| tk.count(json));
        assert!(cl > 0 && o2 > 0);
        assert_ne!(cl, o2);
    }

    #[test]
    fn heuristic_default_is_heuristic() {
        let fallback = Tokenizer::default();
        assert_eq!(fallback, Tokenizer::Heuristic);
        assert_eq!(fallback.as_str(), "heuristic");
    }

    /// Known names and aliases map to their variants; anything else is the heuristic.
    #[test]
    fn from_str_lossy_known_and_unknown() {
        let cases = [
            ("cl100k_base", Tokenizer::Cl100kBase),
            ("CL100K", Tokenizer::Cl100kBase),
            ("o200k_base", Tokenizer::O200kBase),
            ("o200k", Tokenizer::O200kBase),
            ("nonsense", Tokenizer::Heuristic),
            ("", Tokenizer::Heuristic),
        ];
        for (input, expected) in cases {
            assert_eq!(Tokenizer::from_str_lossy(input), expected);
        }
    }

    /// Serde names and `as_str` names both round-trip for every variant.
    #[test]
    fn round_trip_serde() {
        let all = [
            Tokenizer::Heuristic,
            Tokenizer::Cl100kBase,
            Tokenizer::O200kBase,
        ];
        for original in all {
            let encoded = serde_json::to_string(&original).unwrap();
            let decoded: Tokenizer = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, decoded);
            assert_eq!(Tokenizer::from_str_lossy(original.as_str()), original);
        }
    }
}