// multi_llm/internals/tokens.rs

1//! Token counting utilities for LLM providers.
2//!
3//! This module provides token counting implementations for different LLM providers.
4//! Accurate token counting is important for:
5//! - Staying within context window limits
6//! - Estimating API costs
7//! - Optimizing prompts
8//!
9//! # Usage
10//!
11//! Use [`TokenCounterFactory`] to create counters for specific providers:
12//!
13//! ```rust,no_run
14//! use multi_llm::{TokenCounterFactory, TokenCounter};
15//!
16//! // Create counter for OpenAI GPT-4
17//! let counter = TokenCounterFactory::create_counter("openai", "gpt-4")?;
18//!
19//! // Count tokens in text
20//! let tokens = counter.count_tokens("Hello, world!")?;
21//! println!("Token count: {}", tokens);
22//!
23//! // Check context limit
24//! let max = counter.max_context_tokens();
25//! println!("Max context: {} tokens", max);
26//! # Ok::<(), multi_llm::LlmError>(())
27//! ```
28//!
29//! # Provider-Specific Notes
30//!
31//! - **OpenAI**: Uses tiktoken with exact tokenization
32//! - **Anthropic**: Uses cl100k_base with 1.1x approximation factor
33//! - **Ollama/LM Studio**: Uses cl100k_base (may vary by model)
34//!
35//! # Available Types
36//!
37//! - [`TokenCounter`]: Trait for all token counters
38//! - [`OpenAITokenCounter`]: OpenAI GPT model tokenizer
39//! - [`AnthropicTokenCounter`]: Anthropic Claude approximation
40//! - [`TokenCounterFactory`]: Factory for creating counters
41
42use crate::error::{LlmError, LlmResult};
43use crate::logging::{log_debug, log_warn};
44
45use std::sync::Arc;
46use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
47
48/// Trait for counting tokens in text and messages.
49///
50/// Implement this trait to add support for new tokenizers.
51/// Use [`TokenCounterFactory`] to create instances for supported providers.
52///
53/// # Example
54///
55/// ```rust,no_run
56/// use multi_llm::{TokenCounter, TokenCounterFactory};
57///
58/// # fn example() -> multi_llm::LlmResult<()> {
59/// let counter = TokenCounterFactory::create_counter("openai", "gpt-4")?;
60///
61/// // Count tokens
62/// let count = counter.count_tokens("Hello, world!")?;
63///
64/// // Validate against limit
65/// counter.validate_token_limit("Some text...")?;
66///
67/// // Truncate if needed
68/// let truncated = counter.truncate_to_limit("Very long text...", 100)?;
69/// # Ok(())
70/// # }
71/// ```
pub trait TokenCounter: Send + Sync + std::fmt::Debug {
    /// Count tokens in a text string.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::ConfigurationError`] if the tokenizer
    /// fails to encode the text.
    fn count_tokens(&self, text: &str) -> LlmResult<u32>;

    /// Count tokens in a list of messages (includes formatting overhead).
    ///
    /// The count includes tokens for role markers, message separators,
    /// and other provider-specific formatting.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying per-message token counting.
    fn count_message_tokens(&self, messages: &[serde_json::Value]) -> LlmResult<u32>;

    /// Get the maximum context window size for this tokenizer.
    fn max_context_tokens(&self) -> u32;

    /// Validate that text doesn't exceed the token limit.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::TokenLimitExceeded`] if the text exceeds
    /// the maximum context window.
    fn validate_token_limit(&self, text: &str) -> LlmResult<()>;

    /// Truncate text to fit within a token limit.
    ///
    /// If the text already fits, it's returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if the truncated token sequence cannot be decoded
    /// back into a valid string.
    fn truncate_to_limit(&self, text: &str, max_tokens: u32) -> LlmResult<String>;
}
103
104/// Token counter for OpenAI GPT models using tiktoken.
105///
106/// Provides exact token counts for OpenAI models. Automatically selects
107/// the correct tokenizer based on the model name.
108///
109/// # Supported Models
110///
111/// | Model | Tokenizer | Context Window |
112/// |-------|-----------|---------------|
113/// | gpt-4-turbo | cl100k_base | 128K |
114/// | gpt-4 | cl100k_base | 8K |
115/// | gpt-3.5-turbo | cl100k_base | 16K |
116/// | o1-* | o200k_base | 200K |
117///
118/// # Example
119///
120/// ```rust,no_run
121/// use multi_llm::{OpenAITokenCounter, TokenCounter};
122///
123/// let counter = OpenAITokenCounter::new("gpt-4")?;
124/// let tokens = counter.count_tokens("Hello, world!")?;
125/// # Ok::<(), multi_llm::LlmError>(())
126/// ```
pub struct OpenAITokenCounter {
    // tiktoken BPE vocabulary (cl100k_base or o200k_base, chosen from the
    // model name). Not shown in Debug output — see the manual impl below.
    tokenizer: CoreBPE,
    // Maximum context window size in tokens for the configured model.
    max_tokens: u32,
    // Model name this counter was built for; used in log output.
    model_name: String,
}
132
133impl std::fmt::Debug for OpenAITokenCounter {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        f.debug_struct("OpenAITokenCounter")
136            .field("max_tokens", &self.max_tokens)
137            .field("model_name", &self.model_name)
138            .finish()
139    }
140}
141
142impl OpenAITokenCounter {
143    /// Determine max tokens for GPT-4 model
144    fn gpt4_max_tokens(model: &str) -> u32 {
145        if model.contains("turbo") || model.contains("preview") {
146            128000
147        } else if model.contains("32k") {
148            32768
149        } else {
150            8192
151        }
152    }
153
154    /// Determine max tokens for GPT-3.5 model
155    fn gpt35_max_tokens(model: &str) -> u32 {
156        if model.contains("16k") {
157            16384
158        } else {
159            4096
160        }
161    }
162
163    /// Get tokenizer and max tokens for model
164    fn get_model_config(model: &str) -> LlmResult<(CoreBPE, u32)> {
165        match model {
166            m if m.starts_with("gpt-4") => {
167                let tokenizer = cl100k_base().map_err(|e| {
168                    LlmError::configuration_error(format!("Failed to initialize tokenizer: {}", e))
169                })?;
170                Ok((tokenizer, Self::gpt4_max_tokens(m)))
171            }
172            m if m.starts_with("gpt-3.5") => {
173                let tokenizer = cl100k_base().map_err(|e| {
174                    LlmError::configuration_error(format!("Failed to initialize tokenizer: {}", e))
175                })?;
176                Ok((tokenizer, Self::gpt35_max_tokens(m)))
177            }
178            m if m.starts_with("o1") => {
179                let tokenizer = o200k_base().map_err(|e| {
180                    LlmError::configuration_error(format!("Failed to initialize tokenizer: {}", e))
181                })?;
182                Ok((tokenizer, 200000))
183            }
184            _ => {
185                log_warn!(model = %model, "Unknown model, using cl100k_base tokenizer with 4k context");
186                let tokenizer = cl100k_base().map_err(|e| {
187                    LlmError::configuration_error(format!("Failed to initialize tokenizer: {}", e))
188                })?;
189                Ok((tokenizer, 4096))
190            }
191        }
192    }
193
194    /// Create token counter for specific OpenAI model
195    pub fn new(model: &str) -> LlmResult<Self> {
196        let (tokenizer, max_tokens) = Self::get_model_config(model)?;
197
198        Ok(Self {
199            tokenizer,
200            max_tokens,
201            model_name: model.to_string(),
202        })
203    }
204
205    /// Create token counter for LM Studio (uses cl100k_base as default)
206    pub fn for_lm_studio(max_tokens: u32) -> LlmResult<Self> {
207        // log_debug!(max_tokens = max_tokens, "Creating LM Studio token counter");
208
209        let tokenizer = cl100k_base().map_err(|e| {
210            LlmError::configuration_error(format!(
211                "Failed to initialize LM Studio tokenizer: {}",
212                e
213            ))
214        })?;
215
216        Ok(Self {
217            tokenizer,
218            max_tokens,
219            model_name: "lm-studio".to_string(),
220        })
221    }
222}
223
224impl TokenCounter for OpenAITokenCounter {
225    fn count_tokens(&self, text: &str) -> LlmResult<u32> {
226        let tokens = self.tokenizer.encode_with_special_tokens(text);
227        Ok(tokens.len() as u32)
228    }
229
230    fn count_message_tokens(&self, messages: &[serde_json::Value]) -> LlmResult<u32> {
231        let mut total_tokens = 3u32; // Base conversation formatting
232
233        for message in messages {
234            total_tokens += self.count_single_message_tokens(message);
235        }
236
237        total_tokens += 3; // Reply end tokens
238
239        log_debug!(
240            total_tokens = total_tokens,
241            message_count = messages.len(),
242            model = %self.model_name,
243            "Calculated message token count"
244        );
245
246        Ok(total_tokens)
247    }
248
249    fn max_context_tokens(&self) -> u32 {
250        self.max_tokens
251    }
252
253    fn validate_token_limit(&self, text: &str) -> LlmResult<()> {
254        let token_count = self.count_tokens(text)?;
255        if token_count > self.max_tokens {
256            return Err(LlmError::token_limit_exceeded(
257                token_count as usize,
258                self.max_tokens as usize,
259            ));
260        }
261        Ok(())
262    }
263
264    fn truncate_to_limit(&self, text: &str, max_tokens: u32) -> LlmResult<String> {
265        let tokens = self.tokenizer.encode_with_special_tokens(text);
266
267        if tokens.len() <= max_tokens as usize {
268            return Ok(text.to_string());
269        }
270
271        // log_debug!(
272        //     original_tokens = tokens.len(),
273        //     max_tokens = max_tokens,
274        //     "Truncating text to fit token limit"
275        // );
276
277        let truncated_tokens = &tokens[..max_tokens as usize];
278        let truncated_text = self
279            .tokenizer
280            .decode(truncated_tokens.to_vec())
281            .map_err(|e| {
282                LlmError::response_parsing_error(format!(
283                    "Failed to decode truncated tokens: {}",
284                    e
285                ))
286            })?;
287
288        Ok(truncated_text)
289    }
290}
291
292impl OpenAITokenCounter {
293    fn count_single_message_tokens(&self, message: &serde_json::Value) -> u32 {
294        let role = message
295            .get("role")
296            .and_then(|r| r.as_str())
297            .unwrap_or("user");
298        let content = message
299            .get("content")
300            .and_then(|c| c.as_str())
301            .unwrap_or("");
302
303        let mut tokens = 4u32; // Message formatting tokens
304        tokens += self.tokenizer.encode_with_special_tokens(role).len() as u32;
305        tokens += self.tokenizer.encode_with_special_tokens(content).len() as u32;
306        tokens += self.count_tool_call_tokens(message);
307
308        tokens
309    }
310
311    fn count_tool_call_tokens(&self, message: &serde_json::Value) -> u32 {
312        let Some(tool_calls) = message.get("tool_calls") else {
313            return 0;
314        };
315
316        let Some(calls_array) = tool_calls.as_array() else {
317            return 0;
318        };
319
320        calls_array
321            .iter()
322            .filter_map(|call| {
323                call.get("function")
324                    .and_then(|f| f.get("arguments"))
325                    .and_then(|a| a.as_str())
326            })
327            .map(|args_str| self.tokenizer.encode_with_special_tokens(args_str).len() as u32)
328            .sum()
329    }
330}
331
332/// Token counter for Anthropic Claude models.
333///
334/// Uses cl100k_base tokenizer with a 1.1x approximation factor, since
335/// Claude's actual tokenizer isn't publicly available. This provides
336/// conservative estimates (slightly over-counting).
337///
338/// # Context Windows
339///
340/// | Model | Context Window |
341/// |-------|---------------|
342/// | claude-3-5-sonnet | 200K |
343/// | claude-3-opus | 200K |
344/// | claude-3-haiku | 200K |
345/// | claude-2.x | 100K |
346///
347/// # Example
348///
349/// ```rust,no_run
350/// use multi_llm::{AnthropicTokenCounter, TokenCounter};
351///
352/// let counter = AnthropicTokenCounter::new("claude-3-5-sonnet-20241022")?;
353/// let tokens = counter.count_tokens("Hello, world!")?;
354/// # Ok::<(), multi_llm::LlmError>(())
355/// ```
356///
357/// # Accuracy Note
358///
359/// Token counts are approximate. The 1.1x factor provides a safety margin
360/// to avoid accidentally exceeding context limits.
pub struct AnthropicTokenCounter {
    // cl100k_base tokenizer used as a stand-in for Claude's (non-public)
    // tokenizer; counts are scaled by 1.1 at use sites.
    tokenizer: CoreBPE,
    // Maximum context window size in tokens for the configured model.
    max_tokens: u32,
}
365
366impl std::fmt::Debug for AnthropicTokenCounter {
367    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        f.debug_struct("AnthropicTokenCounter")
369            .field("max_tokens", &self.max_tokens)
370            .finish()
371    }
372}
373
374impl AnthropicTokenCounter {
375    /// Create token counter for Anthropic Claude models
376    pub fn new(model: &str) -> LlmResult<Self> {
377        // log_debug!(model = %model, "Creating Anthropic token counter");
378
379        let max_tokens = match model {
380            m if m.contains("claude-3-5-sonnet") => 200000,
381            m if m.contains("claude-3") => 200000,
382            m if m.contains("claude-2") => 100000,
383            _ => {
384                log_warn!(model = %model, "Unknown Anthropic model, using 100k context");
385                100000
386            }
387        };
388
389        // Use cl100k_base as approximation for Claude tokenization
390        let tokenizer = cl100k_base().map_err(|e| {
391            LlmError::configuration_error(format!(
392                "Failed to initialize Anthropic tokenizer: {}",
393                e
394            ))
395        })?;
396
397        Ok(Self {
398            tokenizer,
399            max_tokens,
400        })
401    }
402}
403
404impl TokenCounter for AnthropicTokenCounter {
405    fn count_tokens(&self, text: &str) -> LlmResult<u32> {
406        let tokens = self.tokenizer.encode_with_special_tokens(text);
407        // Apply approximation factor for Claude tokenization differences
408        Ok((tokens.len() as f32 * 1.1) as u32)
409    }
410
411    fn count_message_tokens(&self, messages: &[serde_json::Value]) -> LlmResult<u32> {
412        let mut total_tokens = 0u32;
413
414        for message in messages {
415            let content = message
416                .get("content")
417                .and_then(|c| c.as_str())
418                .unwrap_or("");
419
420            let content_tokens = self.count_tokens(content)?;
421            total_tokens += content_tokens;
422            total_tokens += 10; // Overhead for role and formatting
423        }
424
425        log_debug!(
426            total_tokens = total_tokens,
427            message_count = messages.len(),
428            "Calculated Anthropic message token count"
429        );
430
431        Ok(total_tokens)
432    }
433
434    fn max_context_tokens(&self) -> u32 {
435        self.max_tokens
436    }
437
438    fn validate_token_limit(&self, text: &str) -> LlmResult<()> {
439        let token_count = self.count_tokens(text)?;
440        if token_count > self.max_tokens {
441            return Err(LlmError::token_limit_exceeded(
442                token_count as usize,
443                self.max_tokens as usize,
444            ));
445        }
446        Ok(())
447    }
448
449    fn truncate_to_limit(&self, text: &str, max_tokens: u32) -> LlmResult<String> {
450        let tokens = self.tokenizer.encode_with_special_tokens(text);
451        let adjusted_limit = (max_tokens as f32 / 1.1) as usize; // Account for approximation factor
452
453        if tokens.len() <= adjusted_limit {
454            return Ok(text.to_string());
455        }
456
457        log_debug!(
458            original_tokens = tokens.len(),
459            max_tokens = max_tokens,
460            adjusted_limit = adjusted_limit,
461            "Truncating Anthropic text to fit token limit"
462        );
463
464        let truncated_tokens = &tokens[..adjusted_limit];
465        let truncated_text = self
466            .tokenizer
467            .decode(truncated_tokens.to_vec())
468            .map_err(|e| {
469                LlmError::response_parsing_error(format!(
470                    "Failed to decode truncated tokens: {}",
471                    e
472                ))
473            })?;
474
475        Ok(truncated_text)
476    }
477}
478
479/// Factory for creating token counters for different providers.
480///
481/// Use this factory to get the appropriate token counter for your provider
482/// and model. The factory handles selecting the correct tokenizer and
483/// context window size.
484///
485/// # Example
486///
487/// ```rust,no_run
488/// use multi_llm::{TokenCounterFactory, TokenCounter};
489///
490/// // Create counter for OpenAI
491/// let openai = TokenCounterFactory::create_counter("openai", "gpt-4")?;
492///
493/// // Create counter for Anthropic
494/// let anthropic = TokenCounterFactory::create_counter("anthropic", "claude-3-5-sonnet")?;
495///
496/// // Create counter with custom limit
497/// let custom = TokenCounterFactory::create_counter_with_limit("openai", "gpt-4", 4096)?;
498/// # Ok::<(), multi_llm::LlmError>(())
499/// ```
500///
501/// # Supported Providers
502///
503/// - `openai`: Uses tiktoken for exact counts
504/// - `anthropic`: Uses approximation with safety margin
505/// - `ollama`: Uses cl100k_base (approximation)
506/// - `lmstudio`: Uses cl100k_base (approximation)
pub struct TokenCounterFactory; // Stateless: all constructors are associated functions.
508
509impl TokenCounterFactory {
510    /// Create token counter for specific provider and model
511    pub fn create_counter(provider: &str, model: &str) -> LlmResult<Arc<dyn TokenCounter>> {
512        match provider.to_lowercase().as_str() {
513            "openai" => {
514                let counter = OpenAITokenCounter::new(model)?;
515                Ok(Arc::new(counter))
516            }
517            "lmstudio" => {
518                // Default to 4k context for local models, but this should be configurable
519                let counter = OpenAITokenCounter::for_lm_studio(4096)?;
520                Ok(Arc::new(counter))
521            }
522            "ollama" => {
523                // Default to 4k context for Ollama models, but this should be configurable
524                let counter = OpenAITokenCounter::for_lm_studio(4096)?;
525                Ok(Arc::new(counter))
526            }
527            "anthropic" => {
528                let counter = AnthropicTokenCounter::new(model)?;
529                Ok(Arc::new(counter))
530            }
531            _ => Err(LlmError::unsupported_provider(provider)),
532        }
533    }
534
535    /// Create counter with custom context window size
536    pub fn create_counter_with_limit(
537        provider: &str,
538        model: &str,
539        max_tokens: u32,
540    ) -> LlmResult<Arc<dyn TokenCounter>> {
541        match provider.to_lowercase().as_str() {
542            "openai" => {
543                let mut counter = OpenAITokenCounter::new(model)?;
544                counter.max_tokens = max_tokens;
545                Ok(Arc::new(counter))
546            }
547            "lmstudio" => {
548                let counter = OpenAITokenCounter::for_lm_studio(max_tokens)?;
549                Ok(Arc::new(counter))
550            }
551            "ollama" => {
552                let counter = OpenAITokenCounter::for_lm_studio(max_tokens)?;
553                Ok(Arc::new(counter))
554            }
555            "anthropic" => {
556                let mut counter = AnthropicTokenCounter::new(model)?;
557                counter.max_tokens = max_tokens;
558                Ok(Arc::new(counter))
559            }
560            _ => Err(LlmError::unsupported_provider(provider)),
561        }
562    }
563}