use std::sync::Arc;
use crate::ir::Message;
/// Strategy for turning text into a token count.
///
/// Implementors must be `Send + Sync + Debug` so a counter can be shared across
/// threads behind an `Arc` and show up in diagnostics.
pub trait TokenCounter: Send + Sync + std::fmt::Debug {
    /// Returns the number of tokens this counter attributes to `text`.
    fn count(&self, text: &str) -> u64;

    /// Returns the combined token count of all text parts in `msgs`.
    ///
    /// Only `ContentPart::Text` parts are counted; every other part kind is
    /// skipped and contributes nothing to the total.
    fn count_messages(&self, msgs: &[Message]) -> u64 {
        let mut total: u64 = 0;
        for message in msgs {
            for part in message.content.iter() {
                if let crate::ir::ContentPart::Text { text, .. } = part {
                    total += self.count(text.as_str());
                }
            }
        }
        total
    }

    /// Static name identifying the counting scheme this counter implements.
    fn encoding_name(&self) -> &'static str;
}
/// Lets an `Arc<T>` (including `Arc<dyn TokenCounter>`) be used wherever a
/// [`TokenCounter`] is expected by delegating to the pointee.
///
/// `count_messages` is forwarded explicitly, rather than left to the trait
/// default, so the call reaches the pointee's own `count_messages`.
impl<T: TokenCounter + ?Sized> TokenCounter for Arc<T> {
    fn count(&self, text: &str) -> u64 {
        let inner: &T = self;
        inner.count(text)
    }

    fn count_messages(&self, msgs: &[Message]) -> u64 {
        let inner: &T = self;
        inner.count_messages(msgs)
    }

    fn encoding_name(&self) -> &'static str {
        let inner: &T = self;
        inner.encoding_name()
    }
}
/// Stateless, zero-sized counter that estimates tokens from the byte length of
/// the text (see its [`TokenCounter`] implementation for the exact formula).
#[derive(Debug, Default, Clone, Copy)]
pub struct ByteCountTokenCounter;

impl ByteCountTokenCounter {
    /// Creates the counter. Usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        ByteCountTokenCounter
    }
}
impl TokenCounter for ByteCountTokenCounter {
    /// Estimates one token per four bytes of `text`, rounding up.
    ///
    /// The byte length is used (not the character count), so multi-byte UTF-8
    /// characters weigh more than ASCII. The estimate saturates at `u32::MAX`
    /// if the rounded quotient does not fit in a `u32`.
    fn count(&self, text: &str) -> u64 {
        let estimate = text.len().div_ceil(4);
        match u32::try_from(estimate) {
            Ok(tokens) => u64::from(tokens),
            Err(_) => u64::from(u32::MAX),
        }
    }

    /// Always `"byte-count-naive"`.
    fn encoding_name(&self) -> &'static str {
        "byte-count-naive"
    }
}
/// Lookup table from `(provider, model prefix)` to a token counter, with a
/// fallback counter for requests that no registered entry matches.
pub struct TokenCounterRegistry {
/// Registered entries, kept in registration order.
entries: Vec<RegistryEntry>,
/// Counter handed out by `resolve` when no entry matches.
fallback: Arc<dyn TokenCounter>,
}
/// One registration: a counter bound to a provider and a model-name prefix.
struct RegistryEntry {
/// Provider name; compared for exact equality during resolution.
provider: &'static str,
/// Prefix tested against the requested model name with `str::starts_with`.
model_prefix: &'static str,
/// Counter returned when this entry wins resolution.
counter: Arc<dyn TokenCounter>,
}
impl TokenCounterRegistry {
    /// Creates a registry with no entries whose fallback is a
    /// [`ByteCountTokenCounter`].
    #[must_use]
    pub fn new() -> Self {
        let fallback: Arc<dyn TokenCounter> = Arc::new(ByteCountTokenCounter::new());
        Self {
            entries: Vec::new(),
            fallback,
        }
    }

    /// Builder step: replaces the fallback counter and returns the registry.
    #[must_use]
    pub fn with_default(mut self, fallback: Arc<dyn TokenCounter>) -> Self {
        self.fallback = fallback;
        self
    }

    /// Builder step: appends an entry mapping `provider` plus any model name
    /// starting with `model_prefix` to `counter`, then returns the registry.
    ///
    /// Existing entries are never replaced; duplicates are resolved at lookup
    /// time (see [`Self::resolve`]).
    #[must_use]
    pub fn register(
        mut self,
        provider: &'static str,
        model_prefix: &'static str,
        counter: Arc<dyn TokenCounter>,
    ) -> Self {
        let entry = RegistryEntry {
            provider,
            model_prefix,
            counter,
        };
        self.entries.push(entry);
        self
    }

    /// Picks the counter for `provider` / `model`.
    ///
    /// An entry is a candidate when its provider equals `provider` exactly and
    /// `model` starts with its prefix. Among candidates the longest prefix
    /// wins; when several candidates share the longest prefix length, the one
    /// registered last wins. With no candidate, the fallback counter is
    /// returned as [`Resolution::Fallback`].
    #[must_use]
    pub fn resolve(&self, provider: &str, model: &str) -> Resolution {
        let winner = self
            .entries
            .iter()
            .filter(|candidate| {
                candidate.provider == provider && model.starts_with(candidate.model_prefix)
            })
            // `max_by_key` yields the last of several equal maxima, which is
            // what gives later registrations priority on a tie.
            .max_by_key(|candidate| candidate.model_prefix.len());

        winner.map_or_else(
            || Resolution::Fallback(Arc::clone(&self.fallback)),
            |hit| Resolution::Matched(Arc::clone(&hit.counter)),
        )
    }

    /// Number of registered entries; the fallback is not counted.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when nothing has been registered (a fallback still exists).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
impl Default for TokenCounterRegistry {
fn default() -> Self {
Self::new()
}
}
/// Outcome of a registry lookup: the chosen counter, tagged with whether it
/// came from a registered entry or from the registry's fallback.
#[derive(Clone)]
#[non_exhaustive]
pub enum Resolution {
/// A registered entry matched the requested provider and model.
Matched(Arc<dyn TokenCounter>),
/// No entry matched; carries the registry's fallback counter.
Fallback(Arc<dyn TokenCounter>),
}
impl Resolution {
    /// Borrows the resolved counter, regardless of which variant carries it.
    #[must_use]
    pub fn counter(&self) -> &Arc<dyn TokenCounter> {
        let (Self::Matched(counter) | Self::Fallback(counter)) = self;
        counter
    }

    /// `true` when a registered entry produced this resolution.
    #[must_use]
    pub const fn is_match(&self) -> bool {
        match self {
            Self::Matched(_) => true,
            Self::Fallback(_) => false,
        }
    }

    /// `true` when the registry's fallback counter was used.
    #[must_use]
    pub const fn is_fallback(&self) -> bool {
        match self {
            Self::Fallback(_) => true,
            Self::Matched(_) => false,
        }
    }
}
/// Renders as `Matched { encoding: ".." }` or `Fallback { encoding: ".." }`,
/// showing the counter's encoding name rather than the counter itself.
impl std::fmt::Debug for Resolution {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = if self.is_match() {
            "Matched"
        } else {
            "Fallback"
        };
        let encoding = self.counter().encoding_name();
        f.debug_struct(label).field("encoding", &encoding).finish()
    }
}
/// Renders each entry as a `(provider, model_prefix, encoding_name)` tuple and
/// the fallback as its encoding name.
impl std::fmt::Debug for TokenCounterRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rows: Vec<(&'static str, &'static str, &'static str)> =
            Vec::with_capacity(self.entries.len());
        for entry in &self.entries {
            rows.push((
                entry.provider,
                entry.model_prefix,
                entry.counter.encoding_name(),
            ));
        }
        let fallback_encoding = self.fallback.encoding_name();

        let mut out = f.debug_struct("TokenCounterRegistry");
        out.field("entries", &rows);
        out.field("fallback", &fallback_encoding);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{ContentPart, Role};

    /// Test double that reports a fixed count under a caller-chosen encoding
    /// label, so tests can tell which registry entry was selected.
    #[derive(Debug)]
    struct LabelledCounter {
        label: &'static str,
        fixed: u64,
    }

    impl TokenCounter for LabelledCounter {
        fn count(&self, _text: &str) -> u64 {
            self.fixed
        }

        fn encoding_name(&self) -> &'static str {
            self.label
        }
    }

    /// Wraps a [`LabelledCounter`] as the trait object the registry stores.
    fn labelled(name: &'static str, fixed: u64) -> Arc<dyn TokenCounter> {
        Arc::new(LabelledCounter { label: name, fixed })
    }

    /// Registry with a single `openai` / `gpt-4o` entry labelled `o200k`.
    fn gpt4o_registry() -> TokenCounterRegistry {
        TokenCounterRegistry::new().register("openai", "gpt-4o", labelled("o200k", 1))
    }

    // ---- ByteCountTokenCounter -------------------------------------------

    #[test]
    fn byte_count_rounds_up() {
        let counter = ByteCountTokenCounter::new();
        let cases: [(&str, u64, &str); 5] = [
            ("", 0, "empty string is zero"),
            ("a", 1, "one byte rounds up to one token"),
            ("abcd", 1, "exactly four bytes is one token"),
            ("abcde", 2, "five bytes rounds up to two"),
            ("abcdefgh", 2, "exactly eight bytes is two"),
        ];
        for (input, expected, why) in cases {
            assert_eq!(counter.count(input), expected, "{why}");
        }
    }

    #[test]
    fn byte_count_handles_multibyte_utf8_at_byte_granularity() {
        // Two Hangul syllables occupy six UTF-8 bytes: ceil(6 / 4) == 2.
        let counter = ByteCountTokenCounter::new();
        assert_eq!(counter.count("안녕"), 2);
    }

    #[test]
    fn encoding_name_surfaces_for_otel_attribute() {
        let name = ByteCountTokenCounter::new().encoding_name();
        assert_eq!(name, "byte-count-naive");
    }

    // ---- count_messages default implementation ---------------------------

    #[test]
    fn count_messages_sums_text_parts_only() {
        // 12 bytes -> 3 tokens, 3 bytes -> 1 token.
        let parts = vec![ContentPart::text("hello world!"), ContentPart::text("xyz")];
        let message = Message::new(Role::User, parts);
        let total = ByteCountTokenCounter::new().count_messages(std::slice::from_ref(&message));
        assert_eq!(total, 4);
    }

    #[test]
    fn count_messages_default_impl_skips_non_text_parts() {
        let tool_call = ContentPart::ToolUse {
            id: "call_1".into(),
            name: "tool".into(),
            input: serde_json::json!({}),
            provider_echoes: Vec::new(),
        };
        let message = Message::new(
            Role::Assistant,
            vec![ContentPart::text("hi"), tool_call],
        );
        // Only the two-byte text part is counted.
        let total = ByteCountTokenCounter::new().count_messages(std::slice::from_ref(&message));
        assert_eq!(total, 1);
    }

    // ---- Arc forwarding --------------------------------------------------

    #[test]
    fn arc_blanket_impl_forwards() {
        let shared: Arc<dyn TokenCounter> = Arc::new(ByteCountTokenCounter::new());
        assert_eq!(shared.count("abcd"), 1);
        assert_eq!(shared.encoding_name(), "byte-count-naive");
    }

    // ---- TokenCounterRegistry --------------------------------------------

    #[test]
    fn registry_returns_fallback_when_empty() {
        let registry = TokenCounterRegistry::new();
        let outcome = registry.resolve("openai", "gpt-5");
        assert!(outcome.is_fallback(), "empty registry should fall through");
        assert_eq!(outcome.counter().encoding_name(), "byte-count-naive");
    }

    #[test]
    fn registry_resolves_exact_provider_and_prefix() {
        let outcome = gpt4o_registry().resolve("openai", "gpt-4o-mini");
        assert!(outcome.is_match(), "registered prefix should match");
        assert_eq!(outcome.counter().encoding_name(), "o200k");
    }

    #[test]
    fn registry_ignores_wrong_provider() {
        let registry =
            TokenCounterRegistry::new().register("anthropic", "claude", labelled("anthropic", 2));
        let outcome = registry.resolve("openai", "claude-clone");
        assert!(outcome.is_fallback());
        assert_eq!(outcome.counter().encoding_name(), "byte-count-naive");
    }

    #[test]
    fn registry_longest_prefix_wins_regardless_of_registration_order() {
        let short_first = TokenCounterRegistry::new()
            .register("openai", "gpt-4", labelled("cl100k", 1))
            .register("openai", "gpt-4o", labelled("o200k", 1));
        let long_first = TokenCounterRegistry::new()
            .register("openai", "gpt-4o", labelled("o200k", 1))
            .register("openai", "gpt-4", labelled("cl100k", 1));

        for registry in [short_first, long_first] {
            let outcome = registry.resolve("openai", "gpt-4o-mini");
            assert_eq!(outcome.counter().encoding_name(), "o200k");
        }
    }

    #[test]
    fn registry_falls_through_to_fallback_on_non_matching_model() {
        let outcome = gpt4o_registry().resolve("openai", "davinci");
        assert!(outcome.is_fallback());
        assert_eq!(outcome.counter().encoding_name(), "byte-count-naive");
    }

    #[test]
    fn registry_last_wins_on_tie() {
        let registry = TokenCounterRegistry::new()
            .register("openai", "gpt-4", labelled("first", 1))
            .register("openai", "gpt-4", labelled("second", 1));
        let outcome = registry.resolve("openai", "gpt-4-turbo");
        assert_eq!(outcome.counter().encoding_name(), "second");
    }

    #[test]
    fn registry_with_default_replaces_fallback() {
        let registry = TokenCounterRegistry::new().with_default(labelled("custom-fb", 0));
        let outcome = registry.resolve("any", "x");
        assert!(outcome.is_fallback());
        assert_eq!(outcome.counter().encoding_name(), "custom-fb");
    }

    #[test]
    fn registry_len_excludes_fallback() {
        let registry = TokenCounterRegistry::new()
            .register("openai", "gpt-4", labelled("a", 1))
            .register("openai", "gpt-4o", labelled("b", 1));
        assert_eq!(registry.len(), 2);
        assert!(!registry.is_empty());
    }

    // ---- Resolution ------------------------------------------------------

    #[test]
    fn resolution_pattern_match_yields_counter() {
        let outcome = gpt4o_registry().resolve("openai", "gpt-4o");
        let counter = match outcome {
            Resolution::Matched(inner) | Resolution::Fallback(inner) => inner,
        };
        assert_eq!(counter.encoding_name(), "o200k");
    }

    #[test]
    fn resolution_match_and_fallback_are_distinguishable() {
        let registry = gpt4o_registry();
        let matched = registry.resolve("openai", "gpt-4o-mini");
        let fallback = registry.resolve("openai", "davinci-002");
        assert!(matched.is_match() && !matched.is_fallback());
        assert!(fallback.is_fallback() && !fallback.is_match());
    }
}