Skip to main content

entelix_core/
tokens.rs

1//! `TokenCounter` — operator-supplied token-count surface.
2//!
//! Tokens are the unit in which budget caps and chunk boundaries are
//! denominated. The vendor's wire-level usage report (the `Usage` block on a
5//! `ModelResponse`) gives the *post-flight* count. `TokenCounter` is
6//! the *pre-flight* counterpart — what operators reach for when:
7//!
8//! - **`RunBudget` pre-flight check** — refuse a request whose
9//!   estimated input would already exceed the configured input-token
10//!   cap, before paying the round-trip cost.
11//! - **RAG chunking** — `entelix-rag::TokenCountSplitter` slices a
12//!   document so each chunk lands under the model's per-message
13//!   ceiling.
//! - **Content-economy budgeting** — estimating the combined size of the
//!   system prompt, tools, and history when assembling a request.
16//!
17//! Vendor-accurate counters live in companion crates —
18//! [`entelix-tokenizer-tiktoken`](https://docs.rs/entelix-tokenizer-tiktoken)
19//! for the OpenAI BPE family (`cl100k_base`, `o200k_base`,
20//! `p50k_base`, `r50k_base`) and
21//! [`entelix-tokenizer-hf`](https://docs.rs/entelix-tokenizer-hf)
22//! for HuggingFace tokenizer.json sources (Llama, Qwen, Mistral,
//! DeepSeek, Gemma, Phi, …). Future companions will add locale-aware
//! morphological accuracy for Korean and Japanese.
//! [`ByteCountTokenCounter`] ships in core as a zero-dependency
//! heuristic default — good enough for development scaffolding, but
//! not for production budgeting on non-English content.
28
29use std::sync::Arc;
30
31use crate::ir::Message;
32
33/// Counts tokens for budget enforcement, splitter sizing, and
34/// content-economy estimation.
35///
36/// Synchronous by contract — counters that need IO (remote tokenizer
37/// service, lazy file-backed model) should pre-load eagerly at
38/// construction or wrap the slow path in
39/// `tokio::task::spawn_blocking` at the *call* site rather than
40/// hiding `.await` inside the trait. Mirrors the
41/// [`crate::time::Clock`] discipline: low-level primitives stay
42/// pure so they compose with locks and hot paths.
43pub trait TokenCounter: Send + Sync + std::fmt::Debug {
44    /// Count the tokens in `text` under this counter's encoding.
45    fn count(&self, text: &str) -> u64;
46
47    /// Sum the token count across every text-bearing content part
48    /// of a message slice. The default impl walks
49    /// [`crate::ir::ContentPart::Text`] parts; non-text parts (image,
50    /// tool-use, tool-result blocks) are vendor-specific in their
51    /// token cost — counters that need an exact tally for those
52    /// shapes override this method.
53    fn count_messages(&self, msgs: &[Message]) -> u64 {
54        msgs.iter()
55            .flat_map(|m| m.content.iter())
56            .filter_map(|part| match part {
57                crate::ir::ContentPart::Text { text, .. } => Some(text.as_str()),
58                _ => None,
59            })
60            .map(|t| self.count(t))
61            .sum()
62    }
63
64    /// Vendor-published encoding name (`"cl100k_base"`,
65    /// `"o200k_base"`, `"claude"`, `"gemini-tokenizer"`, …) — surfaced
66    /// on OTel `gen_ai.tokenizer.name` and operator diagnostics.
67    fn encoding_name(&self) -> &'static str;
68}
69
70impl<T: TokenCounter + ?Sized> TokenCounter for Arc<T> {
71    fn count(&self, text: &str) -> u64 {
72        (**self).count(text)
73    }
74    fn count_messages(&self, msgs: &[Message]) -> u64 {
75        (**self).count_messages(msgs)
76    }
77    fn encoding_name(&self) -> &'static str {
78        (**self).encoding_name()
79    }
80}
81
/// Zero-dependency heuristic counter — `bytes.div_ceil(4)`.
///
/// Approximates English text using the ~4-bytes-per-token rule of thumb
/// commonly quoted for tiktoken's `cl100k_base`. It is **systematically
/// inaccurate** for CJK, Devanagari, Arabic, and other scripts whose
/// UTF-8 byte cost diverges from typical token boundaries. Operators
/// shipping multilingual workloads inject a vendor-accurate counter
/// (`entelix-tokenizer-tiktoken`, `entelix-tokenizer-hf`, locale-aware
/// companions) at `ChatModel::with_token_counter(...)` time.
///
/// ## Bias direction
///
/// `div_ceil` rounds up, so the result is never below `bytes / 4`, and
/// any non-empty string counts as at least one token. For English-like
/// text this keeps pre-flight `RunBudget` checks on the conservative
/// side: a near-budget call is more likely refused than admitted, which
/// is the correct error direction for budget enforcement.
///
/// That guarantee holds only relative to the bytes/4 estimate. Scripts
/// that tokenise at fewer than four UTF-8 bytes per token can be
/// **under**-counted. For example, Korean `"안녕"` is 6 bytes and counts
/// as 2 here, while `cl100k_base` may emit 2–3 tokens. Budget checks
/// built on this counter can therefore fail *open* for such content.
#[derive(Clone, Copy, Debug, Default)]
pub struct ByteCountTokenCounter;

impl ByteCountTokenCounter {
    /// Construct the counter.
    ///
    /// The type is a stateless unit struct, so every call returns an
    /// interchangeable value. Being `const`, it can initialise `const`
    /// and `static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
108
109impl TokenCounter for ByteCountTokenCounter {
110    fn count(&self, text: &str) -> u64 {
111        // `usize::div_ceil` only stabilised in 1.73 — the workspace
112        // pins 1.95 so the direct call is fine. `u64::from` over an
113        // `as` cast keeps the lossless-conversion lint happy.
114        u64::from(u32::try_from(text.len().div_ceil(4)).unwrap_or(u32::MAX))
115    }
116
117    fn encoding_name(&self) -> &'static str {
118        "byte-count-naive"
119    }
120}
121
122/// Routing table from `(provider, model)` pairs to a vendor-accurate
123/// [`TokenCounter`].
124///
125/// Operators registering one counter per `(provider, model_prefix)`
126/// pair drive the [`Self::resolve`] lookup; gateways (one process
127/// fronting many model handles) read every chat-side dispatch
128/// through one shared registry. The facade ships
129/// `default_token_counter_registry()` factory (feature-gated on
130/// `tokenizer-tiktoken` / `tokenizer-hf`) with the OpenAI BPE
131/// pre-populated; operators extend it with their HuggingFace
132/// `tokenizer.json` bytes.
133///
134/// ## Matching algorithm
135///
136/// Entries match when the provider name is **exactly equal** and the
137/// model name **starts with** the registered prefix. Among all
138/// matching entries, the one with the longest prefix wins — so
139/// registering both `"gpt-4"` and `"gpt-4o"` routes
140/// `"gpt-4o-mini"` to the `"gpt-4o"` entry without depending on
141/// registration order. Ties on prefix length resolve to the
142/// last-registered entry (operator-overridable). Misses fall through
143/// to the registry's fallback counter.
144///
145/// ## Why prefix matching
146///
147/// Vendors version models with stable family prefixes (`gpt-4o-*`,
148/// `claude-sonnet-*`, `gemini-1.5-*`). Exact-name matching forces
149/// the operator to update the registry on every minor model release;
150/// prefix matching absorbs new patch revisions silently into the
151/// same tokenizer mapping the family uses. Regex was considered and
152/// rejected — too expressive, and the typical mistake is *missing* a
153/// model, which prefix matching handles by falling through to the
154/// fallback rather than misrouting silently.
155pub struct TokenCounterRegistry {
156    entries: Vec<RegistryEntry>,
157    fallback: Arc<dyn TokenCounter>,
158}
159
160struct RegistryEntry {
161    provider: &'static str,
162    model_prefix: &'static str,
163    counter: Arc<dyn TokenCounter>,
164}
165
166impl TokenCounterRegistry {
167    /// Construct an empty registry with [`ByteCountTokenCounter`] as
168    /// the fallback. Add entries with [`Self::register`]; replace
169    /// the fallback with [`Self::with_default`].
170    #[must_use]
171    pub fn new() -> Self {
172        Self {
173            entries: Vec::new(),
174            fallback: Arc::new(ByteCountTokenCounter::new()),
175        }
176    }
177
178    /// Replace the fallback counter used when no entry matches.
179    /// Default is [`ByteCountTokenCounter`] (conservative — biased
180    /// to over-count so pre-flight budget checks fail closed).
181    #[must_use]
182    pub fn with_default(mut self, fallback: Arc<dyn TokenCounter>) -> Self {
183        self.fallback = fallback;
184        self
185    }
186
187    /// Append a `(provider, model_prefix) → counter` entry. Multiple
188    /// entries for the same provider partition the model space by
189    /// prefix; the longest-prefix match wins at lookup time.
190    /// Ties on prefix length resolve to the last-registered entry,
191    /// so a later `register` call overrides an earlier one for the
192    /// same `(provider, model_prefix)` pair.
193    #[must_use]
194    pub fn register(
195        mut self,
196        provider: &'static str,
197        model_prefix: &'static str,
198        counter: Arc<dyn TokenCounter>,
199    ) -> Self {
200        self.entries.push(RegistryEntry {
201            provider,
202            model_prefix,
203            counter,
204        });
205        self
206    }
207
208    /// Resolve `(provider, model)` to a counter. The returned
209    /// [`Resolution`] makes match vs. fallback visible at the type
210    /// level so callers branch on the miss without inspecting
211    /// `encoding_name()` after the fact (invariant 15 — silent
212    /// fallback is not an option). See [`Resolution`] for the
213    /// pattern-match shape and the deliberate "accept fallback"
214    /// idiom.
215    #[must_use]
216    pub fn resolve(&self, provider: &str, model: &str) -> Resolution {
217        let mut best: Option<&RegistryEntry> = None;
218        for entry in &self.entries {
219            if entry.provider != provider {
220                continue;
221            }
222            if !model.starts_with(entry.model_prefix) {
223                continue;
224            }
225            match best {
226                Some(prev) if prev.model_prefix.len() > entry.model_prefix.len() => {}
227                _ => best = Some(entry),
228            }
229        }
230        match best {
231            Some(entry) => Resolution::Matched(Arc::clone(&entry.counter)),
232            None => Resolution::Fallback(Arc::clone(&self.fallback)),
233        }
234    }
235
236    /// Number of registered `(provider, model_prefix)` entries.
237    /// Excludes the fallback. Operators wiring a `tracing::info!` on
238    /// boot read this to confirm the table is the expected size.
239    #[must_use]
240    pub fn len(&self) -> usize {
241        self.entries.len()
242    }
243
244    /// Whether the registry has no registered entries (the fallback
245    /// is always present).
246    #[must_use]
247    pub fn is_empty(&self) -> bool {
248        self.entries.is_empty()
249    }
250}
251
252impl Default for TokenCounterRegistry {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258/// Outcome of [`TokenCounterRegistry::resolve`] — surfaces whether a
259/// registered entry actually matched `(provider, model)` or the
260/// registry's fallback counter was returned because no entry matched.
261///
262/// Making fallback visible at the type level lets production cost
263/// estimation paths log / alert on unknown models without inspecting
264/// `encoding_name()` after the fact (invariant 15 — silent fallback is
265/// not an option). Callers consume the resolution via pattern match —
266/// there is no "give me the counter regardless" shortcut, because
267/// such a shortcut would re-enable silent absorption of the fallback.
268///
269/// ```ignore
270/// let counter = match registry.resolve("openai", "gpt-9") {
271///     Resolution::Matched(c) => c,
272///     Resolution::Fallback(c) => {
273///         tracing::warn!("unknown model — using fallback counter");
274///         c
275///     }
276/// };
277/// ```
278///
279/// Operators that intentionally accept the fallback bind both arms to
280/// the same name (`Resolution::Matched(c) | Resolution::Fallback(c) => c`).
281/// The pattern is short, and the call site signals the deliberate
282/// choice in source.
283///
284/// ## Future variants
285///
286/// `Resolution` is `#[non_exhaustive]`. The two predicates
287/// [`Self::is_match`] and [`Self::is_fallback`] each return `true`
288/// only for their own variant — `!is_match()` and `is_fallback()`
289/// are **not** equivalent under future variant additions. Operators
290/// that want to alert on the no-match branch should match on
291/// `Self::Fallback` explicitly.
292#[derive(Clone)]
293#[non_exhaustive]
294pub enum Resolution {
295    /// A registered `(provider, model_prefix)` entry matched.
296    Matched(Arc<dyn TokenCounter>),
297    /// No entry matched; the registry's fallback counter is returned.
298    Fallback(Arc<dyn TokenCounter>),
299}
300
301impl Resolution {
302    /// Borrow the resolved counter regardless of which branch matched.
303    /// Operators that need the owned `Arc` consume the enum through a
304    /// pattern match.
305    #[must_use]
306    pub fn counter(&self) -> &Arc<dyn TokenCounter> {
307        match self {
308            Self::Matched(c) | Self::Fallback(c) => c,
309        }
310    }
311
312    /// `true` when a registered entry matched.
313    #[must_use]
314    pub const fn is_match(&self) -> bool {
315        matches!(self, Self::Matched(_))
316    }
317
318    /// `true` when no entry matched and the fallback was returned.
319    /// Distinct from `!is_match()` — `Resolution` is `non_exhaustive`,
320    /// so a future variant could make both predicates return `false`.
321    #[must_use]
322    pub const fn is_fallback(&self) -> bool {
323        matches!(self, Self::Fallback(_))
324    }
325}
326
327impl std::fmt::Debug for Resolution {
328    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329        let (kind, encoding) = match self {
330            Self::Matched(c) => ("Matched", c.encoding_name()),
331            Self::Fallback(c) => ("Fallback", c.encoding_name()),
332        };
333        f.debug_struct(kind).field("encoding", &encoding).finish()
334    }
335}
336
337impl std::fmt::Debug for TokenCounterRegistry {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        let entries: Vec<(&'static str, &'static str, &'static str)> = self
340            .entries
341            .iter()
342            .map(|e| (e.provider, e.model_prefix, e.counter.encoding_name()))
343            .collect();
344        f.debug_struct("TokenCounterRegistry")
345            .field("entries", &entries)
346            .field("fallback", &self.fallback.encoding_name())
347            .finish()
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{ContentPart, Role};

    #[test]
    fn byte_count_rounds_up() {
        let c = ByteCountTokenCounter::new();
        assert_eq!(c.count(""), 0, "empty string is zero");
        assert_eq!(c.count("a"), 1, "one byte rounds up to one token");
        assert_eq!(c.count("abcd"), 1, "exactly four bytes is one token");
        assert_eq!(c.count("abcde"), 2, "five bytes rounds up to two");
        assert_eq!(c.count("abcdefgh"), 2, "exactly eight bytes is two");
    }

    #[test]
    fn byte_count_handles_multibyte_utf8_at_byte_granularity() {
        // Korean "안녕" is 2 chars and 6 UTF-8 bytes, so the div_ceil(4)
        // heuristic gives 2 tokens. Real cl100k_base would tokenise
        // these as 2-3 tokens, so the naive counter can land *below*
        // the true count here. It is documented as approximate, and
        // not conservative for such scripts.
        let c = ByteCountTokenCounter::new();
        assert_eq!(c.count("안녕"), 2);
    }

    #[test]
    fn count_messages_sums_text_parts_only() {
        let counter = ByteCountTokenCounter::new();
        let msg = Message::new(
            Role::User,
            vec![
                ContentPart::text("hello world!"), // 12 bytes → 3 tokens
                ContentPart::text("xyz"),          // 3 bytes  → 1 token
            ],
        );
        assert_eq!(counter.count_messages(std::slice::from_ref(&msg)), 4);
    }

    #[test]
    fn count_messages_default_impl_skips_non_text_parts() {
        // Tool-use blocks etc. carry vendor-specific tokens, so the
        // default counter walks Text parts only. A counter that
        // wants exact counting for tool-use shapes overrides
        // count_messages.
        let counter = ByteCountTokenCounter::new();
        let msg = Message::new(
            Role::Assistant,
            vec![
                ContentPart::text("hi"), // 2 bytes → 1 token
                ContentPart::ToolUse {
                    id: "call_1".into(),
                    name: "tool".into(),
                    input: serde_json::json!({}),
                    provider_echoes: Vec::new(),
                },
            ],
        );
        assert_eq!(counter.count_messages(std::slice::from_ref(&msg)), 1);
    }

    #[test]
    fn encoding_name_surfaces_for_otel_attribute() {
        assert_eq!(
            ByteCountTokenCounter::new().encoding_name(),
            "byte-count-naive"
        );
    }

    #[test]
    fn arc_blanket_impl_forwards() {
        let c: Arc<dyn TokenCounter> = Arc::new(ByteCountTokenCounter::new());
        assert_eq!(c.count("abcd"), 1);
        assert_eq!(c.encoding_name(), "byte-count-naive");
    }

    #[derive(Debug)]
    struct LabelledCounter(&'static str, u64);
    impl TokenCounter for LabelledCounter {
        fn count(&self, _text: &str) -> u64 {
            self.1
        }
        fn encoding_name(&self) -> &'static str {
            self.0
        }
    }

    fn labelled(name: &'static str, fixed: u64) -> Arc<dyn TokenCounter> {
        Arc::new(LabelledCounter(name, fixed))
    }

    #[test]
    fn arc_blanket_impl_forwards_count_messages() {
        let c: Arc<dyn TokenCounter> = labelled("fixed", 7);
        let msg = Message::new(
            Role::User,
            vec![ContentPart::text("a"), ContentPart::text("b")],
        );
        // The pointee's default impl runs: one `count` (= 7) per text part.
        assert_eq!(c.count_messages(std::slice::from_ref(&msg)), 14);
    }

    #[test]
    fn registry_returns_fallback_when_empty() {
        let reg = TokenCounterRegistry::new();
        let resolution = reg.resolve("openai", "gpt-5");
        assert!(
            resolution.is_fallback(),
            "empty registry should fall through"
        );
        assert_eq!(resolution.counter().encoding_name(), "byte-count-naive");
    }

    #[test]
    fn registry_new_is_empty() {
        let reg = TokenCounterRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
        assert!(TokenCounterRegistry::default().is_empty());
    }

    #[test]
    fn registry_resolves_exact_provider_and_prefix() {
        let reg = TokenCounterRegistry::new().register("openai", "gpt-4o", labelled("o200k", 1));
        let resolution = reg.resolve("openai", "gpt-4o-mini");
        assert!(resolution.is_match(), "registered prefix should match");
        assert_eq!(resolution.counter().encoding_name(), "o200k");
    }

    #[test]
    fn registry_ignores_wrong_provider() {
        let reg =
            TokenCounterRegistry::new().register("anthropic", "claude", labelled("anthropic", 2));
        let resolution = reg.resolve("openai", "claude-clone");
        // Provider mismatch — fall through to fallback.
        assert!(resolution.is_fallback());
        assert_eq!(resolution.counter().encoding_name(), "byte-count-naive");
    }

    #[test]
    fn registry_provider_match_is_exact() {
        // Provider names are compared for equality, never by prefix or
        // case-insensitively.
        let reg = TokenCounterRegistry::new().register("openai", "gpt", labelled("o200k", 1));
        assert!(reg.resolve("OpenAI", "gpt-4o").is_fallback());
        assert!(reg.resolve("openai-compat", "gpt-4o").is_fallback());
        assert!(reg.resolve("open", "gpt-4o").is_fallback());
    }

    #[test]
    fn registry_longest_prefix_wins_regardless_of_registration_order() {
        // Register "gpt-4" first, then "gpt-4o". The longest prefix wins
        // on "gpt-4o-mini" regardless of order.
        let reg = TokenCounterRegistry::new()
            .register("openai", "gpt-4", labelled("cl100k", 1))
            .register("openai", "gpt-4o", labelled("o200k", 1));
        assert_eq!(
            reg.resolve("openai", "gpt-4o-mini")
                .counter()
                .encoding_name(),
            "o200k"
        );

        // Reverse registration order — same outcome.
        let reg = TokenCounterRegistry::new()
            .register("openai", "gpt-4o", labelled("o200k", 1))
            .register("openai", "gpt-4", labelled("cl100k", 1));
        assert_eq!(
            reg.resolve("openai", "gpt-4o-mini")
                .counter()
                .encoding_name(),
            "o200k"
        );
    }

    #[test]
    fn registry_empty_prefix_is_provider_catch_all() {
        let reg = TokenCounterRegistry::new()
            .register("openai", "", labelled("catch-all", 1))
            .register("openai", "gpt-4o", labelled("o200k", 1));
        assert_eq!(
            reg.resolve("openai", "davinci").counter().encoding_name(),
            "catch-all"
        );
        // A longer, more specific prefix still beats the catch-all.
        assert_eq!(
            reg.resolve("openai", "gpt-4o-mini").counter().encoding_name(),
            "o200k"
        );
        // The catch-all is scoped to its provider.
        assert!(reg.resolve("anthropic", "claude").is_fallback());
    }

    #[test]
    fn registry_falls_through_to_fallback_on_non_matching_model() {
        let reg = TokenCounterRegistry::new().register("openai", "gpt-4o", labelled("o200k", 1));
        // Same provider, prefix doesn't match — fall through.
        let resolution = reg.resolve("openai", "davinci");
        assert!(resolution.is_fallback());
        assert_eq!(resolution.counter().encoding_name(), "byte-count-naive");
    }

    #[test]
    fn registry_last_wins_on_tie() {
        let reg = TokenCounterRegistry::new()
            .register("openai", "gpt-4", labelled("first", 1))
            .register("openai", "gpt-4", labelled("second", 1));
        assert_eq!(
            reg.resolve("openai", "gpt-4-turbo")
                .counter()
                .encoding_name(),
            "second"
        );
    }

    #[test]
    fn registry_with_default_replaces_fallback() {
        let reg = TokenCounterRegistry::new().with_default(labelled("custom-fb", 0));
        let resolution = reg.resolve("any", "x");
        assert!(resolution.is_fallback());
        assert_eq!(resolution.counter().encoding_name(), "custom-fb");
    }

    #[test]
    fn registry_len_excludes_fallback() {
        let reg = TokenCounterRegistry::new()
            .register("openai", "gpt-4", labelled("a", 1))
            .register("openai", "gpt-4o", labelled("b", 1));
        assert_eq!(reg.len(), 2);
        assert!(!reg.is_empty());
    }

    #[test]
    fn resolution_pattern_match_yields_counter() {
        let reg = TokenCounterRegistry::new().register("openai", "gpt-4o", labelled("o200k", 1));
        let counter = match reg.resolve("openai", "gpt-4o") {
            Resolution::Matched(c) | Resolution::Fallback(c) => c,
        };
        assert_eq!(counter.encoding_name(), "o200k");
    }

    #[test]
    fn resolution_match_and_fallback_are_distinguishable() {
        let reg = TokenCounterRegistry::new().register("openai", "gpt-4o", labelled("o200k", 1));
        let matched = reg.resolve("openai", "gpt-4o-mini");
        let fallback = reg.resolve("openai", "davinci-002");
        assert!(matched.is_match() && !matched.is_fallback());
        assert!(fallback.is_fallback() && !fallback.is_match());
    }

    #[test]
    fn debug_impls_surface_encoding_names() {
        let reg = TokenCounterRegistry::new().register("openai", "gpt-4o", labelled("o200k", 1));
        assert_eq!(
            format!("{:?}", reg.resolve("openai", "gpt-4o-mini")),
            r#"Matched { encoding: "o200k" }"#
        );
        assert_eq!(
            format!("{:?}", reg.resolve("openai", "davinci")),
            r#"Fallback { encoding: "byte-count-naive" }"#
        );
        assert_eq!(
            format!("{reg:?}"),
            concat!(
                r#"TokenCounterRegistry { entries: [("openai", "gpt-4o", "o200k")], "#,
                r#"fallback: "byte-count-naive" }"#
            )
        );
    }
}