Skip to main content

token_count/tokenizers/
registry.rs

1//! Model registry for managing supported models
2
3use crate::error::TokenError;
4use crate::tokenizers::{
5    claude::{claude_models, ClaudeTokenizer},
6    google::{google_models, GoogleTokenizer},
7    openai::OpenAITokenizer,
8    ModelInfo, Tokenizer,
9};
10use anyhow::Result;
11use std::collections::HashMap;
12use std::sync::OnceLock;
13
/// Configuration for a specific model.
///
/// One entry per canonical model in the [`ModelRegistry`]; aliases listed
/// here are folded into the registry's alias table on insertion.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    // Canonical model name, e.g. "gpt-4-turbo".
    pub name: String,
    // Tokenizer encoding id; selects the tokenizer backend in
    // `ModelRegistry::get_tokenizer` (e.g. "cl100k_base", "anthropic-claude").
    pub encoding: String,
    // Maximum context window in tokens.
    pub context_window: usize,
    // Human-readable description shown by model listings.
    pub description: String,
    // Alternate spellings accepted by `resolve_model_name` (matched
    // case-insensitively).
    pub aliases: Vec<String>,
}
23
/// Registry of all supported models.
///
/// Lookups are case-insensitive: every canonical name and every alias is
/// stored lowercased in `aliases`, mapping to the canonical key of `models`.
pub struct ModelRegistry {
    // canonical name → configuration
    models: HashMap<String, ModelConfig>,
    // lowercased alias (including the canonical name itself) → canonical name
    aliases: HashMap<String, String>,
}
29
30impl ModelRegistry {
31    /// Create a new model registry with all supported models
32    pub fn new() -> Self {
33        let mut registry = Self { models: HashMap::new(), aliases: HashMap::new() };
34
35        // GPT-3.5-turbo (default)
36        registry.add_model(ModelConfig {
37            name: "gpt-3.5-turbo".to_string(),
38            encoding: "cl100k_base".to_string(),
39            context_window: 16385,
40            description: "GPT-3.5 Turbo (16K context)".to_string(),
41            aliases: vec![
42                "gpt-3.5".to_string(),
43                "gpt35".to_string(),
44                "gpt-35-turbo".to_string(),
45                "openai/gpt-3.5-turbo".to_string(),
46            ],
47        });
48
49        // GPT-4
50        registry.add_model(ModelConfig {
51            name: "gpt-4".to_string(),
52            encoding: "cl100k_base".to_string(),
53            context_window: 128000,
54            description: "GPT-4 (128K context)".to_string(),
55            aliases: vec!["gpt4".to_string(), "openai/gpt-4".to_string()],
56        });
57
58        // GPT-4-turbo
59        registry.add_model(ModelConfig {
60            name: "gpt-4-turbo".to_string(),
61            encoding: "cl100k_base".to_string(),
62            context_window: 128000,
63            description: "GPT-4 Turbo (128K context)".to_string(),
64            aliases: vec![
65                "gpt4-turbo".to_string(),
66                "gpt-4turbo".to_string(),
67                "openai/gpt-4-turbo".to_string(),
68            ],
69        });
70
71        // GPT-4o
72        registry.add_model(ModelConfig {
73            name: "gpt-4o".to_string(),
74            encoding: "o200k_base".to_string(),
75            context_window: 128000,
76            description: "GPT-4o (128K context)".to_string(),
77            aliases: vec!["gpt4o".to_string(), "openai/gpt-4o".to_string()],
78        });
79
80        // Claude models
81        for model in claude_models() {
82            registry.add_model(model);
83        }
84
85        // Google Gemini models
86        for model in google_models() {
87            registry.add_model(model);
88        }
89
90        registry
91    }
92
93    /// Get the global registry instance
94    pub fn global() -> &'static Self {
95        static REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
96        REGISTRY.get_or_init(Self::new)
97    }
98
99    /// Add a model to the registry
100    fn add_model(&mut self, config: ModelConfig) {
101        let canonical_name = config.name.clone();
102
103        // Add aliases
104        for alias in &config.aliases {
105            self.aliases.insert(alias.to_lowercase(), canonical_name.clone());
106        }
107
108        // Add model itself as an alias (case-insensitive)
109        self.aliases.insert(canonical_name.to_lowercase(), canonical_name.clone());
110
111        self.models.insert(canonical_name, config);
112    }
113
114    /// Resolve a model name (canonical or alias) to its canonical name
115    pub fn resolve_model_name(&self, name: &str) -> Result<String, TokenError> {
116        let normalized = name.trim().to_lowercase();
117
118        if let Some(canonical) = self.aliases.get(&normalized) {
119            return Ok(canonical.clone());
120        }
121
122        // Model not found - generate suggestions
123        let suggestion = self.generate_suggestions(&normalized);
124        Err(TokenError::UnknownModel { model: name.to_string(), suggestion })
125    }
126
127    /// Get a model configuration by name (canonical or alias)
128    pub fn get_model(&self, name: &str) -> Result<&ModelConfig, TokenError> {
129        let canonical = self.resolve_model_name(name)?;
130        Ok(self.models.get(&canonical).expect("Canonical name must exist"))
131    }
132
133    /// Create a tokenizer for the given model
134    ///
135    /// # Arguments
136    /// * `name` - Model name (canonical or alias)
137    /// * `use_accurate` - Whether to use accurate mode for models that support it (Claude API)
138    ///
139    /// # Returns
140    /// * `Ok(Box<dyn Tokenizer>)` - Tokenizer instance for the model
141    /// * `Err(TokenError)` - Model not found or tokenizer creation failed
142    pub fn get_tokenizer(
143        &self,
144        name: &str,
145        use_accurate: bool,
146    ) -> Result<Box<dyn Tokenizer>, TokenError> {
147        let config = self.get_model(name)?;
148
149        // Detect tokenizer type based on encoding
150        match config.encoding.as_str() {
151            "anthropic-claude" => {
152                // Claude tokenizer (estimation or API)
153                let tokenizer = ClaudeTokenizer::new(config.clone(), use_accurate)?;
154                Ok(Box::new(tokenizer))
155            }
156            "gemini-gemma3" => {
157                // Google Gemini tokenizer
158                let tokenizer = GoogleTokenizer::new(config.clone())?;
159                Ok(Box::new(tokenizer))
160            }
161            _ => {
162                // OpenAI tokenizer (tiktoken)
163                let model_info = ModelInfo {
164                    name: config.name.clone(),
165                    encoding: config.encoding.clone(),
166                    context_window: config.context_window,
167                    description: config.description.clone(),
168                };
169
170                let tokenizer = OpenAITokenizer::new(&config.encoding, model_info)
171                    .map_err(|e| TokenError::Tokenization(e.to_string()))?;
172
173                Ok(Box::new(tokenizer))
174            }
175        }
176    }
177
178    /// List all supported models
179    pub fn list_models(&self) -> Vec<&ModelConfig> {
180        let mut models: Vec<&ModelConfig> = self.models.values().collect();
181        models.sort_by(|a, b| a.name.cmp(&b.name));
182        models
183    }
184
185    /// Generate fuzzy suggestions for unknown model names
186    fn generate_suggestions(&self, name: &str) -> String {
187        let mut suggestions: Vec<(&str, usize)> = self
188            .models
189            .keys()
190            .map(|model_name| {
191                let distance = strsim::levenshtein(name, &model_name.to_lowercase());
192                (model_name.as_str(), distance)
193            })
194            .collect();
195
196        suggestions.sort_by_key(|&(_, dist)| dist);
197
198        let close_matches: Vec<&str> = suggestions
199            .iter()
200            .take(3)
201            .filter(|&&(_, dist)| dist <= 3)
202            .map(|&(name, _)| name)
203            .collect();
204
205        if close_matches.is_empty() {
206            "Use --list-models to see all supported models".to_string()
207        } else {
208            format!("Did you mean: {}?", close_matches.join(", "))
209        }
210    }
211}
212
213impl Default for ModelRegistry {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resolve_canonical_name() {
        // Canonical names resolve to themselves, regardless of case.
        let reg = ModelRegistry::new();
        assert_eq!(reg.resolve_model_name("gpt-4").unwrap(), "gpt-4");
        assert_eq!(reg.resolve_model_name("GPT-4").unwrap(), "gpt-4");
    }

    #[test]
    fn test_resolve_alias() {
        // Aliases map onto their canonical model names.
        let reg = ModelRegistry::new();
        assert_eq!(reg.resolve_model_name("gpt4").unwrap(), "gpt-4");
        assert_eq!(reg.resolve_model_name("gpt35").unwrap(), "gpt-3.5-turbo");
    }

    #[test]
    fn test_unknown_model() {
        // An unrecognized name errors, and the message echoes the input.
        let reg = ModelRegistry::new();
        let outcome = reg.resolve_model_name("gpt-5");
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("gpt"));
    }

    #[test]
    fn test_list_models() {
        let reg = ModelRegistry::new();
        let listed = reg.list_models();
        assert_eq!(listed.len(), 11); // 4 OpenAI + 3 Claude + 4 Gemini

        let expected = [
            "gpt-3.5-turbo",
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o",
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5",
            "gemini-2.5-pro",
            "gemini-2.5-flash",
            "gemini-2.5-flash-lite",
            "gemini-3-pro-preview",
        ];
        for name in expected {
            assert!(listed.iter().any(|m| m.name == name));
        }
    }

    #[test]
    fn test_get_tokenizer() {
        let reg = ModelRegistry::new();
        let tok = reg.get_tokenizer("gpt-4", false).unwrap();
        assert_eq!(tok.count_tokens("Hello world").unwrap(), 2);
    }

    #[test]
    fn test_get_claude_tokenizer() {
        let reg = ModelRegistry::new();
        let tok = reg.get_tokenizer("claude-sonnet-4-6", false).unwrap();
        // Estimation mode: "Hello world" = 11 chars
        // Prose detection: ratio < 5% → 4.5 chars/token → 11/4.5 = 2.44 → ceil = 3 tokens
        assert_eq!(tok.count_tokens("Hello world").unwrap(), 3);
    }

    #[test]
    fn test_claude_alias_resolution() {
        let reg = ModelRegistry::new();
        assert_eq!(reg.resolve_model_name("claude").unwrap(), "claude-sonnet-4-6");
        assert_eq!(reg.resolve_model_name("sonnet").unwrap(), "claude-sonnet-4-6");
        assert_eq!(reg.resolve_model_name("opus").unwrap(), "claude-opus-4-6");
        assert_eq!(reg.resolve_model_name("haiku").unwrap(), "claude-haiku-4-5");
    }

    #[test]
    fn test_fuzzy_suggestions() {
        // A typo within edit distance 3 should produce a "Did you mean" hint.
        let reg = ModelRegistry::new();
        let outcome = reg.resolve_model_name("gpt4-tubro");
        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(err.to_string().contains("Did you mean"));
        assert!(err.to_string().contains("gpt-4-turbo"));
    }
}
299}