multi_llm/internals/tokens.rs
1//! Token counting utilities for LLM providers.
2//!
3//! This module provides token counting implementations for different LLM providers.
4//! Accurate token counting is important for:
5//! - Staying within context window limits
6//! - Estimating API costs
7//! - Optimizing prompts
8//!
9//! # Usage
10//!
11//! Use [`TokenCounterFactory`] to create counters for specific providers:
12//!
13//! ```rust,no_run
14//! use multi_llm::{TokenCounterFactory, TokenCounter};
15//!
16//! // Create counter for OpenAI GPT-4
17//! let counter = TokenCounterFactory::create_counter("openai", "gpt-4")?;
18//!
19//! // Count tokens in text
20//! let tokens = counter.count_tokens("Hello, world!")?;
21//! println!("Token count: {}", tokens);
22//!
23//! // Check context limit
24//! let max = counter.max_context_tokens();
25//! println!("Max context: {} tokens", max);
26//! # Ok::<(), multi_llm::LlmError>(())
27//! ```
28//!
29//! # Provider-Specific Notes
30//!
31//! - **OpenAI**: Uses tiktoken with exact tokenization
32//! - **Anthropic**: Uses cl100k_base with 1.1x approximation factor
33//! - **Ollama/LM Studio**: Uses cl100k_base (may vary by model)
34//!
35//! # Available Types
36//!
37//! - [`TokenCounter`]: Trait for all token counters
38//! - [`OpenAITokenCounter`]: OpenAI GPT model tokenizer
39//! - [`AnthropicTokenCounter`]: Anthropic Claude approximation
40//! - [`TokenCounterFactory`]: Factory for creating counters
41
42use crate::error::{LlmError, LlmResult};
43use crate::logging::{log_debug, log_warn};
44
45use std::sync::Arc;
46use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
47
/// Trait for counting tokens in text and messages.
///
/// Implement this trait to add support for new tokenizers.
/// Use [`TokenCounterFactory`] to create instances for supported providers.
///
/// # Example
///
/// ```rust,no_run
/// use multi_llm::{TokenCounter, TokenCounterFactory};
///
/// # fn example() -> multi_llm::LlmResult<()> {
/// let counter = TokenCounterFactory::create_counter("openai", "gpt-4")?;
///
/// // Count tokens
/// let count = counter.count_tokens("Hello, world!")?;
///
/// // Validate against limit
/// counter.validate_token_limit("Some text...")?;
///
/// // Truncate if needed
/// let truncated = counter.truncate_to_limit("Very long text...", 100)?;
/// # Ok(())
/// # }
/// ```
pub trait TokenCounter: Send + Sync + std::fmt::Debug {
    /// Count tokens in a text string.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::ConfigurationError`] if the tokenizer
    /// fails to encode the text.
    fn count_tokens(&self, text: &str) -> LlmResult<u32>;

    /// Count tokens in a list of messages (includes formatting overhead).
    ///
    /// The count includes tokens for role markers, message separators,
    /// and other provider-specific formatting.
    ///
    /// # Errors
    ///
    /// Returns an error if token counting fails for any message.
    fn count_message_tokens(&self, messages: &[serde_json::Value]) -> LlmResult<u32>;

    /// Get the maximum context window size for this tokenizer, in tokens.
    fn max_context_tokens(&self) -> u32;

    /// Validate that text doesn't exceed the token limit.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::TokenLimitExceeded`] if the text exceeds
    /// the maximum context window.
    fn validate_token_limit(&self, text: &str) -> LlmResult<()>;

    /// Truncate text to fit within a token limit.
    ///
    /// If the text already fits, it's returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if the truncated token sequence cannot be
    /// decoded back into a valid string.
    fn truncate_to_limit(&self, text: &str, max_tokens: u32) -> LlmResult<String>;
}
103
/// Token counter for OpenAI GPT models using tiktoken.
///
/// Provides exact token counts for OpenAI models. Automatically selects
/// the correct tokenizer based on the model name.
///
/// # Supported Models
///
/// | Model | Tokenizer | Context Window |
/// |-------|-----------|---------------|
/// | gpt-4-turbo | cl100k_base | 128K |
/// | gpt-4 | cl100k_base | 8K |
/// | gpt-3.5-turbo | cl100k_base | 16K |
/// | o1-* | o200k_base | 200K |
///
/// # Example
///
/// ```rust,no_run
/// use multi_llm::{OpenAITokenCounter, TokenCounter};
///
/// let counter = OpenAITokenCounter::new("gpt-4")?;
/// let tokens = counter.count_tokens("Hello, world!")?;
/// # Ok::<(), multi_llm::LlmError>(())
/// ```
pub struct OpenAITokenCounter {
    // BPE tokenizer instance (cl100k_base or o200k_base, chosen by model name).
    tokenizer: CoreBPE,
    // Maximum context window for the model, in tokens.
    max_tokens: u32,
    // Model identifier, retained for logging.
    model_name: String,
}
132
133impl std::fmt::Debug for OpenAITokenCounter {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct("OpenAITokenCounter")
136 .field("max_tokens", &self.max_tokens)
137 .field("model_name", &self.model_name)
138 .finish()
139 }
140}
141
142impl OpenAITokenCounter {
143 /// Determine max tokens for GPT-4 model
144 fn gpt4_max_tokens(model: &str) -> u32 {
145 if model.contains("turbo") || model.contains("preview") {
146 128000
147 } else if model.contains("32k") {
148 32768
149 } else {
150 8192
151 }
152 }
153
154 /// Determine max tokens for GPT-3.5 model
155 fn gpt35_max_tokens(model: &str) -> u32 {
156 if model.contains("16k") {
157 16384
158 } else {
159 4096
160 }
161 }
162
163 /// Get tokenizer and max tokens for model
164 fn get_model_config(model: &str) -> LlmResult<(CoreBPE, u32)> {
165 match model {
166 m if m.starts_with("gpt-4") => {
167 let tokenizer = cl100k_base().map_err(|e| {
168 LlmError::configuration_error(format!("Failed to initialize tokenizer: {}", e))
169 })?;
170 Ok((tokenizer, Self::gpt4_max_tokens(m)))
171 }
172 m if m.starts_with("gpt-3.5") => {
173 let tokenizer = cl100k_base().map_err(|e| {
174 LlmError::configuration_error(format!("Failed to initialize tokenizer: {}", e))
175 })?;
176 Ok((tokenizer, Self::gpt35_max_tokens(m)))
177 }
178 m if m.starts_with("o1") => {
179 let tokenizer = o200k_base().map_err(|e| {
180 LlmError::configuration_error(format!("Failed to initialize tokenizer: {}", e))
181 })?;
182 Ok((tokenizer, 200000))
183 }
184 _ => {
185 log_warn!(model = %model, "Unknown model, using cl100k_base tokenizer with 4k context");
186 let tokenizer = cl100k_base().map_err(|e| {
187 LlmError::configuration_error(format!("Failed to initialize tokenizer: {}", e))
188 })?;
189 Ok((tokenizer, 4096))
190 }
191 }
192 }
193
194 /// Create token counter for specific OpenAI model
195 pub fn new(model: &str) -> LlmResult<Self> {
196 let (tokenizer, max_tokens) = Self::get_model_config(model)?;
197
198 Ok(Self {
199 tokenizer,
200 max_tokens,
201 model_name: model.to_string(),
202 })
203 }
204
205 /// Create token counter for LM Studio (uses cl100k_base as default)
206 pub fn for_lm_studio(max_tokens: u32) -> LlmResult<Self> {
207 // log_debug!(max_tokens = max_tokens, "Creating LM Studio token counter");
208
209 let tokenizer = cl100k_base().map_err(|e| {
210 LlmError::configuration_error(format!(
211 "Failed to initialize LM Studio tokenizer: {}",
212 e
213 ))
214 })?;
215
216 Ok(Self {
217 tokenizer,
218 max_tokens,
219 model_name: "lm-studio".to_string(),
220 })
221 }
222}
223
224impl TokenCounter for OpenAITokenCounter {
225 fn count_tokens(&self, text: &str) -> LlmResult<u32> {
226 let tokens = self.tokenizer.encode_with_special_tokens(text);
227 Ok(tokens.len() as u32)
228 }
229
230 fn count_message_tokens(&self, messages: &[serde_json::Value]) -> LlmResult<u32> {
231 let mut total_tokens = 3u32; // Base conversation formatting
232
233 for message in messages {
234 total_tokens += self.count_single_message_tokens(message);
235 }
236
237 total_tokens += 3; // Reply end tokens
238
239 log_debug!(
240 total_tokens = total_tokens,
241 message_count = messages.len(),
242 model = %self.model_name,
243 "Calculated message token count"
244 );
245
246 Ok(total_tokens)
247 }
248
249 fn max_context_tokens(&self) -> u32 {
250 self.max_tokens
251 }
252
253 fn validate_token_limit(&self, text: &str) -> LlmResult<()> {
254 let token_count = self.count_tokens(text)?;
255 if token_count > self.max_tokens {
256 return Err(LlmError::token_limit_exceeded(
257 token_count as usize,
258 self.max_tokens as usize,
259 ));
260 }
261 Ok(())
262 }
263
264 fn truncate_to_limit(&self, text: &str, max_tokens: u32) -> LlmResult<String> {
265 let tokens = self.tokenizer.encode_with_special_tokens(text);
266
267 if tokens.len() <= max_tokens as usize {
268 return Ok(text.to_string());
269 }
270
271 // log_debug!(
272 // original_tokens = tokens.len(),
273 // max_tokens = max_tokens,
274 // "Truncating text to fit token limit"
275 // );
276
277 let truncated_tokens = &tokens[..max_tokens as usize];
278 let truncated_text = self
279 .tokenizer
280 .decode(truncated_tokens.to_vec())
281 .map_err(|e| {
282 LlmError::response_parsing_error(format!(
283 "Failed to decode truncated tokens: {}",
284 e
285 ))
286 })?;
287
288 Ok(truncated_text)
289 }
290}
291
292impl OpenAITokenCounter {
293 fn count_single_message_tokens(&self, message: &serde_json::Value) -> u32 {
294 let role = message
295 .get("role")
296 .and_then(|r| r.as_str())
297 .unwrap_or("user");
298 let content = message
299 .get("content")
300 .and_then(|c| c.as_str())
301 .unwrap_or("");
302
303 let mut tokens = 4u32; // Message formatting tokens
304 tokens += self.tokenizer.encode_with_special_tokens(role).len() as u32;
305 tokens += self.tokenizer.encode_with_special_tokens(content).len() as u32;
306 tokens += self.count_tool_call_tokens(message);
307
308 tokens
309 }
310
311 fn count_tool_call_tokens(&self, message: &serde_json::Value) -> u32 {
312 let Some(tool_calls) = message.get("tool_calls") else {
313 return 0;
314 };
315
316 let Some(calls_array) = tool_calls.as_array() else {
317 return 0;
318 };
319
320 calls_array
321 .iter()
322 .filter_map(|call| {
323 call.get("function")
324 .and_then(|f| f.get("arguments"))
325 .and_then(|a| a.as_str())
326 })
327 .map(|args_str| self.tokenizer.encode_with_special_tokens(args_str).len() as u32)
328 .sum()
329 }
330}
331
/// Token counter for Anthropic Claude models.
///
/// Uses cl100k_base tokenizer with a 1.1x approximation factor, since
/// Claude's actual tokenizer isn't publicly available. This provides
/// conservative estimates (slightly over-counting).
///
/// # Context Windows
///
/// | Model | Context Window |
/// |-------|---------------|
/// | claude-3-5-sonnet | 200K |
/// | claude-3-opus | 200K |
/// | claude-3-haiku | 200K |
/// | claude-2.x | 100K |
///
/// # Example
///
/// ```rust,no_run
/// use multi_llm::{AnthropicTokenCounter, TokenCounter};
///
/// let counter = AnthropicTokenCounter::new("claude-3-5-sonnet-20241022")?;
/// let tokens = counter.count_tokens("Hello, world!")?;
/// # Ok::<(), multi_llm::LlmError>(())
/// ```
///
/// # Accuracy Note
///
/// Token counts are approximate. The 1.1x factor provides a safety margin
/// to avoid accidentally exceeding context limits.
pub struct AnthropicTokenCounter {
    // cl100k_base tokenizer used as an approximation of Claude's tokenizer.
    tokenizer: CoreBPE,
    // Maximum context window for the model, in tokens.
    max_tokens: u32,
}
365
366impl std::fmt::Debug for AnthropicTokenCounter {
367 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 f.debug_struct("AnthropicTokenCounter")
369 .field("max_tokens", &self.max_tokens)
370 .finish()
371 }
372}
373
374impl AnthropicTokenCounter {
375 /// Create token counter for Anthropic Claude models
376 pub fn new(model: &str) -> LlmResult<Self> {
377 // log_debug!(model = %model, "Creating Anthropic token counter");
378
379 let max_tokens = match model {
380 m if m.contains("claude-3-5-sonnet") => 200000,
381 m if m.contains("claude-3") => 200000,
382 m if m.contains("claude-2") => 100000,
383 _ => {
384 log_warn!(model = %model, "Unknown Anthropic model, using 100k context");
385 100000
386 }
387 };
388
389 // Use cl100k_base as approximation for Claude tokenization
390 let tokenizer = cl100k_base().map_err(|e| {
391 LlmError::configuration_error(format!(
392 "Failed to initialize Anthropic tokenizer: {}",
393 e
394 ))
395 })?;
396
397 Ok(Self {
398 tokenizer,
399 max_tokens,
400 })
401 }
402}
403
404impl TokenCounter for AnthropicTokenCounter {
405 fn count_tokens(&self, text: &str) -> LlmResult<u32> {
406 let tokens = self.tokenizer.encode_with_special_tokens(text);
407 // Apply approximation factor for Claude tokenization differences
408 Ok((tokens.len() as f32 * 1.1) as u32)
409 }
410
411 fn count_message_tokens(&self, messages: &[serde_json::Value]) -> LlmResult<u32> {
412 let mut total_tokens = 0u32;
413
414 for message in messages {
415 let content = message
416 .get("content")
417 .and_then(|c| c.as_str())
418 .unwrap_or("");
419
420 let content_tokens = self.count_tokens(content)?;
421 total_tokens += content_tokens;
422 total_tokens += 10; // Overhead for role and formatting
423 }
424
425 log_debug!(
426 total_tokens = total_tokens,
427 message_count = messages.len(),
428 "Calculated Anthropic message token count"
429 );
430
431 Ok(total_tokens)
432 }
433
434 fn max_context_tokens(&self) -> u32 {
435 self.max_tokens
436 }
437
438 fn validate_token_limit(&self, text: &str) -> LlmResult<()> {
439 let token_count = self.count_tokens(text)?;
440 if token_count > self.max_tokens {
441 return Err(LlmError::token_limit_exceeded(
442 token_count as usize,
443 self.max_tokens as usize,
444 ));
445 }
446 Ok(())
447 }
448
449 fn truncate_to_limit(&self, text: &str, max_tokens: u32) -> LlmResult<String> {
450 let tokens = self.tokenizer.encode_with_special_tokens(text);
451 let adjusted_limit = (max_tokens as f32 / 1.1) as usize; // Account for approximation factor
452
453 if tokens.len() <= adjusted_limit {
454 return Ok(text.to_string());
455 }
456
457 log_debug!(
458 original_tokens = tokens.len(),
459 max_tokens = max_tokens,
460 adjusted_limit = adjusted_limit,
461 "Truncating Anthropic text to fit token limit"
462 );
463
464 let truncated_tokens = &tokens[..adjusted_limit];
465 let truncated_text = self
466 .tokenizer
467 .decode(truncated_tokens.to_vec())
468 .map_err(|e| {
469 LlmError::response_parsing_error(format!(
470 "Failed to decode truncated tokens: {}",
471 e
472 ))
473 })?;
474
475 Ok(truncated_text)
476 }
477}
478
/// Factory for creating token counters for different providers.
///
/// Use this factory to get the appropriate token counter for your provider
/// and model. The factory handles selecting the correct tokenizer and
/// context window size.
///
/// # Example
///
/// ```rust,no_run
/// use multi_llm::{TokenCounterFactory, TokenCounter};
///
/// // Create counter for OpenAI
/// let openai = TokenCounterFactory::create_counter("openai", "gpt-4")?;
///
/// // Create counter for Anthropic
/// let anthropic = TokenCounterFactory::create_counter("anthropic", "claude-3-5-sonnet")?;
///
/// // Create counter with custom limit
/// let custom = TokenCounterFactory::create_counter_with_limit("openai", "gpt-4", 4096)?;
/// # Ok::<(), multi_llm::LlmError>(())
/// ```
///
/// # Supported Providers
///
/// - `openai`: Uses tiktoken for exact counts
/// - `anthropic`: Uses approximation with safety margin
/// - `ollama`: Uses cl100k_base (approximation)
/// - `lmstudio`: Uses cl100k_base (approximation)
// Unit struct: the factory is stateless and exposes only associated functions.
pub struct TokenCounterFactory;
508
509impl TokenCounterFactory {
510 /// Create token counter for specific provider and model
511 pub fn create_counter(provider: &str, model: &str) -> LlmResult<Arc<dyn TokenCounter>> {
512 match provider.to_lowercase().as_str() {
513 "openai" => {
514 let counter = OpenAITokenCounter::new(model)?;
515 Ok(Arc::new(counter))
516 }
517 "lmstudio" => {
518 // Default to 4k context for local models, but this should be configurable
519 let counter = OpenAITokenCounter::for_lm_studio(4096)?;
520 Ok(Arc::new(counter))
521 }
522 "ollama" => {
523 // Default to 4k context for Ollama models, but this should be configurable
524 let counter = OpenAITokenCounter::for_lm_studio(4096)?;
525 Ok(Arc::new(counter))
526 }
527 "anthropic" => {
528 let counter = AnthropicTokenCounter::new(model)?;
529 Ok(Arc::new(counter))
530 }
531 _ => Err(LlmError::unsupported_provider(provider)),
532 }
533 }
534
535 /// Create counter with custom context window size
536 pub fn create_counter_with_limit(
537 provider: &str,
538 model: &str,
539 max_tokens: u32,
540 ) -> LlmResult<Arc<dyn TokenCounter>> {
541 match provider.to_lowercase().as_str() {
542 "openai" => {
543 let mut counter = OpenAITokenCounter::new(model)?;
544 counter.max_tokens = max_tokens;
545 Ok(Arc::new(counter))
546 }
547 "lmstudio" => {
548 let counter = OpenAITokenCounter::for_lm_studio(max_tokens)?;
549 Ok(Arc::new(counter))
550 }
551 "ollama" => {
552 let counter = OpenAITokenCounter::for_lm_studio(max_tokens)?;
553 Ok(Arc::new(counter))
554 }
555 "anthropic" => {
556 let mut counter = AnthropicTokenCounter::new(model)?;
557 counter.max_tokens = max_tokens;
558 Ok(Arc::new(counter))
559 }
560 _ => Err(LlmError::unsupported_provider(provider)),
561 }
562 }
563}