#![cfg_attr(docsrs, feature(doc_cfg))]
#![doc(html_root_url = "https://docs.rs/entelix-tokenizer-tiktoken/0.5.3")]
#![deny(missing_docs)]
#![allow(
// Proper nouns and acronyms (`OpenAI`, `OTel`, `BPE`, `GPT-4o`)
// appear throughout the docs; backtick-quoting every occurrence
// hurts readability without adding signal.
clippy::doc_markdown
)]
use std::fmt;
use std::sync::Arc;
use entelix_core::TokenCounter;
use thiserror::Error;
use tiktoken_rs::CoreBPE;
/// A BPE encoding shipped with `tiktoken-rs`.
///
/// Each variant corresponds to one of tiktoken's named encodings; use
/// [`TiktokenEncoding::name`] to obtain its canonical string identifier.
/// Which models use which encoding is defined upstream by tiktoken.
///
/// Marked `#[non_exhaustive]` so further encodings can be added without a
/// breaking change; downstream `match`es need a wildcard arm.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum TiktokenEncoding {
    /// The `cl100k_base` encoding.
    Cl100kBase,
    /// The `o200k_base` encoding.
    O200kBase,
    /// The `p50k_base` encoding.
    P50kBase,
    /// The `r50k_base` encoding.
    R50kBase,
}

impl TiktokenEncoding {
    /// Returns the canonical tiktoken identifier of this encoding
    /// (for example `"cl100k_base"`).
    ///
    /// The string is `'static` and usable in `const` contexts; it is the
    /// same value a counter built for this encoding reports as its
    /// encoding name.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Cl100kBase => "cl100k_base",
            Self::O200kBase => "o200k_base",
            Self::P50kBase => "p50k_base",
            Self::R50kBase => "r50k_base",
        }
    }
}
/// Errors produced while constructing a [`TiktokenCounter`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TiktokenError {
    /// `tiktoken-rs` failed to build the BPE tables for an encoding.
    #[error("tiktoken BPE load failed for {encoding_name}: {message}")]
    Load {
        /// Canonical name of the encoding that failed to load
        /// (as returned by [`TiktokenEncoding::name`]).
        encoding_name: &'static str,
        /// Rendered text of the underlying `tiktoken-rs` error.
        message: String,
    },
}
/// A [`TokenCounter`] backed by a tiktoken BPE encoding.
///
/// Construct one with [`TiktokenCounter::for_encoding`]. The BPE tables sit
/// behind an [`Arc`], so cloning a counter only bumps a reference count and
/// every clone shares the same tables.
#[derive(Clone)]
pub struct TiktokenCounter {
    /// BPE tables, shared (not copied) across clones.
    bpe: Arc<CoreBPE>,
    /// Encoding the tables were loaded for.
    encoding: TiktokenEncoding,
}
impl fmt::Debug for TiktokenCounter {
    /// Renders the counter by its encoding alone.
    ///
    /// The BPE tables are intentionally left out so the output stays short;
    /// `finish_non_exhaustive` appends `..` to signal that fields were elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("TiktokenCounter");
        builder.field("encoding", &self.encoding);
        builder.finish_non_exhaustive()
    }
}
impl TiktokenCounter {
    /// Builds a counter for `encoding` by loading its BPE tables through
    /// `tiktoken-rs`.
    ///
    /// Every call constructs a fresh `CoreBPE`. To reuse one set of tables,
    /// build the counter once and [`Clone`] it: clones share the tables
    /// through an [`Arc`].
    ///
    /// # Errors
    ///
    /// Returns [`TiktokenError::Load`] if `tiktoken-rs` fails to construct the
    /// BPE for `encoding`. The error records the encoding's canonical name and
    /// the rendered message of the underlying failure.
    pub fn for_encoding(encoding: TiktokenEncoding) -> Result<Self, TiktokenError> {
        let bpe = match encoding {
            TiktokenEncoding::Cl100kBase => tiktoken_rs::cl100k_base(),
            TiktokenEncoding::O200kBase => tiktoken_rs::o200k_base(),
            TiktokenEncoding::P50kBase => tiktoken_rs::p50k_base(),
            TiktokenEncoding::R50kBase => tiktoken_rs::r50k_base(),
        }
        .map_err(|e| TiktokenError::Load {
            encoding_name: encoding.name(),
            // The alternate `{:#}` form keeps the full context chain of
            // `anyhow`-style errors; plain `to_string()` shows only the
            // outermost message.
            message: format!("{e:#}"),
        })?;
        Ok(Self {
            bpe: Arc::new(bpe),
            encoding,
        })
    }

    /// Returns the encoding this counter was built for.
    #[must_use]
    pub const fn encoding(&self) -> TiktokenEncoding {
        self.encoding
    }
}
impl TokenCounter for TiktokenCounter {
fn count(&self, text: &str) -> u64 {
u64::try_from(self.bpe.encode_ordinary(text).len()).unwrap_or(u64::MAX)
}
fn encoding_name(&self) -> &'static str {
self.encoding.name()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use entelix_core::ir::{ContentPart, Message, Role};

    type TestResult = Result<(), TiktokenError>;

    /// Every encoding variant currently exposed by the crate.
    const ALL_ENCODINGS: [TiktokenEncoding; 4] = [
        TiktokenEncoding::Cl100kBase,
        TiktokenEncoding::O200kBase,
        TiktokenEncoding::P50kBase,
        TiktokenEncoding::R50kBase,
    ];

    /// Builds the `cl100k_base` counter most tests pin exact counts against.
    fn cl100k() -> Result<TiktokenCounter, TiktokenError> {
        TiktokenCounter::for_encoding(TiktokenEncoding::Cl100kBase)
    }

    #[test]
    fn each_encoding_loads_successfully() -> TestResult {
        for encoding in ALL_ENCODINGS {
            let loaded = TiktokenCounter::for_encoding(encoding)?;
            assert_eq!(loaded.encoding(), encoding);
            assert_eq!(loaded.encoding_name(), encoding.name());
        }
        Ok(())
    }

    #[test]
    fn empty_string_counts_zero() -> TestResult {
        assert_eq!(cl100k()?.count(""), 0);
        Ok(())
    }

    #[test]
    fn cl100k_base_counts_match_known_tiktoken_values() -> TestResult {
        let counter = cl100k()?;
        // Known cl100k_base reference counts.
        for (text, expected) in [("Hello world", 2), ("tiktoken is great!", 6)] {
            assert_eq!(counter.count(text), expected);
        }
        Ok(())
    }

    #[test]
    fn o200k_base_handles_multibyte_utf8() -> TestResult {
        let count = TiktokenCounter::for_encoding(TiktokenEncoding::O200kBase)?.count("안녕 세계");
        assert!(count > 0, "non-empty CJK text must count above zero");
        assert!(
            count < 20,
            "five-grapheme CJK should not bloat past 20 tokens"
        );
        Ok(())
    }

    #[test]
    fn longer_text_produces_more_tokens() -> TestResult {
        let counter = cl100k()?;
        let (short, long) = (
            counter.count("hello"),
            counter.count("hello world this is a longer sentence with more tokens"),
        );
        assert!(long > short, "monotonicity: longer input → more tokens");
        Ok(())
    }

    #[test]
    fn count_messages_default_walks_text_parts() -> TestResult {
        let parts = vec![
            ContentPart::text("Hello world"),
            ContentPart::text("tiktoken is great!"),
        ];
        let msg = Message::new(Role::User, parts);
        // 2 + 6: the per-string cl100k_base counts, summed across parts.
        assert_eq!(cl100k()?.count_messages(std::slice::from_ref(&msg)), 8);
        Ok(())
    }

    #[test]
    fn count_messages_skips_non_text_parts() -> TestResult {
        let tool_call = ContentPart::ToolUse {
            id: "call_1".into(),
            name: "search".into(),
            input: serde_json::json!({"q": "rust"}),
            provider_echoes: Vec::new(),
        };
        let msg = Message::new(
            Role::Assistant,
            vec![ContentPart::text("Hello world"), tool_call],
        );
        // Only "Hello world" (2 tokens) contributes; the tool call is ignored.
        assert_eq!(cl100k()?.count_messages(std::slice::from_ref(&msg)), 2);
        Ok(())
    }

    #[test]
    fn arc_dyn_dispatch_forwards_through_blanket_impl() -> TestResult {
        let concrete = cl100k()?;
        let counter: Arc<dyn TokenCounter> = Arc::new(concrete);
        assert_eq!(counter.count("Hello world"), 2);
        assert_eq!(counter.encoding_name(), "cl100k_base");
        Ok(())
    }

    #[test]
    fn clone_shares_bpe_and_keeps_encoding() -> TestResult {
        let source = TiktokenCounter::for_encoding(TiktokenEncoding::O200kBase)?;
        let copy = source.clone();
        assert_eq!(copy.encoding(), TiktokenEncoding::O200kBase);
        assert_eq!(copy.count("hello"), source.count("hello"));
        assert!(Arc::ptr_eq(&source.bpe, &copy.bpe));
        Ok(())
    }

    #[test]
    fn debug_includes_encoding_not_bpe_table() -> TestResult {
        let debug = format!("{:?}", cl100k()?);
        assert!(debug.contains("Cl100kBase"));
        assert!(
            !debug.contains("CoreBPE"),
            "Debug must not dump the BPE tables: {debug}"
        );
        Ok(())
    }

    #[test]
    fn encoding_name_round_trips() {
        let expected = [
            (TiktokenEncoding::Cl100kBase, "cl100k_base"),
            (TiktokenEncoding::O200kBase, "o200k_base"),
            (TiktokenEncoding::P50kBase, "p50k_base"),
            (TiktokenEncoding::R50kBase, "r50k_base"),
        ];
        for (encoding, name) in expected {
            assert_eq!(encoding.name(), name);
        }
    }
}