// token_count/tokenizers/registry.rs
1//! Model registry for managing supported models
2
3use crate::error::TokenError;
4use crate::tokenizers::{
5    claude::{claude_models, ClaudeTokenizer},
6    openai::OpenAITokenizer,
7    ModelInfo, Tokenizer,
8};
9use anyhow::Result;
10use std::collections::HashMap;
11use std::sync::OnceLock;
12
/// Configuration for a specific model
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Canonical model name, e.g. "gpt-4" (used as the registry key).
    pub name: String,
    /// Tokenizer encoding id, e.g. "cl100k_base", "o200k_base", or
    /// "anthropic-claude" — doubles as the backend selector in
    /// `ModelRegistry::get_tokenizer`.
    pub encoding: String,
    /// Maximum context window, in tokens.
    pub context_window: usize,
    /// Human-readable description shown in model listings.
    pub description: String,
    /// Alternative names that resolve to this model (matched case-insensitively).
    pub aliases: Vec<String>,
}
22
/// Registry of all supported models
pub struct ModelRegistry {
    // Canonical name → configuration.
    models: HashMap<String, ModelConfig>,
    // Lowercased alias → canonical name. Invariant: also contains each
    // canonical name (lowercased) mapping to itself — see `add_model` —
    // so lookups only ever consult this map.
    aliases: HashMap<String, String>, // alias → canonical name
}
28
29impl ModelRegistry {
30    /// Create a new model registry with all supported models
31    pub fn new() -> Self {
32        let mut registry = Self { models: HashMap::new(), aliases: HashMap::new() };
33
34        // GPT-3.5-turbo (default)
35        registry.add_model(ModelConfig {
36            name: "gpt-3.5-turbo".to_string(),
37            encoding: "cl100k_base".to_string(),
38            context_window: 16385,
39            description: "GPT-3.5 Turbo (16K context)".to_string(),
40            aliases: vec![
41                "gpt-3.5".to_string(),
42                "gpt35".to_string(),
43                "gpt-35-turbo".to_string(),
44                "openai/gpt-3.5-turbo".to_string(),
45            ],
46        });
47
48        // GPT-4
49        registry.add_model(ModelConfig {
50            name: "gpt-4".to_string(),
51            encoding: "cl100k_base".to_string(),
52            context_window: 128000,
53            description: "GPT-4 (128K context)".to_string(),
54            aliases: vec!["gpt4".to_string(), "openai/gpt-4".to_string()],
55        });
56
57        // GPT-4-turbo
58        registry.add_model(ModelConfig {
59            name: "gpt-4-turbo".to_string(),
60            encoding: "cl100k_base".to_string(),
61            context_window: 128000,
62            description: "GPT-4 Turbo (128K context)".to_string(),
63            aliases: vec![
64                "gpt4-turbo".to_string(),
65                "gpt-4turbo".to_string(),
66                "openai/gpt-4-turbo".to_string(),
67            ],
68        });
69
70        // GPT-4o
71        registry.add_model(ModelConfig {
72            name: "gpt-4o".to_string(),
73            encoding: "o200k_base".to_string(),
74            context_window: 128000,
75            description: "GPT-4o (128K context)".to_string(),
76            aliases: vec!["gpt4o".to_string(), "openai/gpt-4o".to_string()],
77        });
78
79        // Claude models
80        for model in claude_models() {
81            registry.add_model(model);
82        }
83
84        registry
85    }
86
87    /// Get the global registry instance
88    pub fn global() -> &'static Self {
89        static REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
90        REGISTRY.get_or_init(Self::new)
91    }
92
93    /// Add a model to the registry
94    fn add_model(&mut self, config: ModelConfig) {
95        let canonical_name = config.name.clone();
96
97        // Add aliases
98        for alias in &config.aliases {
99            self.aliases.insert(alias.to_lowercase(), canonical_name.clone());
100        }
101
102        // Add model itself as an alias (case-insensitive)
103        self.aliases.insert(canonical_name.to_lowercase(), canonical_name.clone());
104
105        self.models.insert(canonical_name, config);
106    }
107
108    /// Resolve a model name (canonical or alias) to its canonical name
109    pub fn resolve_model_name(&self, name: &str) -> Result<String, TokenError> {
110        let normalized = name.trim().to_lowercase();
111
112        if let Some(canonical) = self.aliases.get(&normalized) {
113            return Ok(canonical.clone());
114        }
115
116        // Model not found - generate suggestions
117        let suggestion = self.generate_suggestions(&normalized);
118        Err(TokenError::UnknownModel { model: name.to_string(), suggestion })
119    }
120
121    /// Get a model configuration by name (canonical or alias)
122    pub fn get_model(&self, name: &str) -> Result<&ModelConfig, TokenError> {
123        let canonical = self.resolve_model_name(name)?;
124        Ok(self.models.get(&canonical).expect("Canonical name must exist"))
125    }
126
127    /// Create a tokenizer for the given model
128    ///
129    /// # Arguments
130    /// * `name` - Model name (canonical or alias)
131    /// * `use_accurate` - Whether to use accurate mode for models that support it (Claude API)
132    ///
133    /// # Returns
134    /// * `Ok(Box<dyn Tokenizer>)` - Tokenizer instance for the model
135    /// * `Err(TokenError)` - Model not found or tokenizer creation failed
136    pub fn get_tokenizer(
137        &self,
138        name: &str,
139        use_accurate: bool,
140    ) -> Result<Box<dyn Tokenizer>, TokenError> {
141        let config = self.get_model(name)?;
142
143        // Detect tokenizer type based on encoding
144        if config.encoding == "anthropic-claude" {
145            // Claude tokenizer (estimation or API)
146            let tokenizer = ClaudeTokenizer::new(config.clone(), use_accurate)?;
147            Ok(Box::new(tokenizer))
148        } else {
149            // OpenAI tokenizer (tiktoken)
150            let model_info = ModelInfo {
151                name: config.name.clone(),
152                encoding: config.encoding.clone(),
153                context_window: config.context_window,
154                description: config.description.clone(),
155            };
156
157            let tokenizer = OpenAITokenizer::new(&config.encoding, model_info)
158                .map_err(|e| TokenError::Tokenization(e.to_string()))?;
159
160            Ok(Box::new(tokenizer))
161        }
162    }
163
164    /// List all supported models
165    pub fn list_models(&self) -> Vec<&ModelConfig> {
166        let mut models: Vec<&ModelConfig> = self.models.values().collect();
167        models.sort_by(|a, b| a.name.cmp(&b.name));
168        models
169    }
170
171    /// Generate fuzzy suggestions for unknown model names
172    fn generate_suggestions(&self, name: &str) -> String {
173        let mut suggestions: Vec<(&str, usize)> = self
174            .models
175            .keys()
176            .map(|model_name| {
177                let distance = strsim::levenshtein(name, &model_name.to_lowercase());
178                (model_name.as_str(), distance)
179            })
180            .collect();
181
182        suggestions.sort_by_key(|&(_, dist)| dist);
183
184        let close_matches: Vec<&str> = suggestions
185            .iter()
186            .take(3)
187            .filter(|&&(_, dist)| dist <= 3)
188            .map(|&(name, _)| name)
189            .collect();
190
191        if close_matches.is_empty() {
192            "Use --list-models to see all supported models".to_string()
193        } else {
194            format!("Did you mean: {}?", close_matches.join(", "))
195        }
196    }
197}
198
199impl Default for ModelRegistry {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    // Canonical names resolve to themselves, case-insensitively.
    #[test]
    fn test_resolve_canonical_name() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.resolve_model_name("gpt-4").unwrap(), "gpt-4");
        assert_eq!(registry.resolve_model_name("GPT-4").unwrap(), "gpt-4");
    }

    // Declared aliases map to their canonical names.
    #[test]
    fn test_resolve_alias() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.resolve_model_name("gpt4").unwrap(), "gpt-4");
        assert_eq!(registry.resolve_model_name("gpt35").unwrap(), "gpt-3.5-turbo");
    }

    // Unknown names error, and the message mentions a nearby model name.
    #[test]
    fn test_unknown_model() {
        let registry = ModelRegistry::new();
        let result = registry.resolve_model_name("gpt-5");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("gpt"));
    }

    // The full catalog: 4 OpenAI models plus the 3 Claude models
    // contributed by claude_models().
    #[test]
    fn test_list_models() {
        let registry = ModelRegistry::new();
        let models = registry.list_models();
        assert_eq!(models.len(), 7); // 4 OpenAI + 3 Claude
        assert!(models.iter().any(|m| m.name == "gpt-3.5-turbo"));
        assert!(models.iter().any(|m| m.name == "gpt-4"));
        assert!(models.iter().any(|m| m.name == "gpt-4-turbo"));
        assert!(models.iter().any(|m| m.name == "gpt-4o"));
        assert!(models.iter().any(|m| m.name == "claude-opus-4-6"));
        assert!(models.iter().any(|m| m.name == "claude-sonnet-4-6"));
        assert!(models.iter().any(|m| m.name == "claude-haiku-4-5"));
    }

    // OpenAI path: tiktoken-backed tokenizer produces exact counts.
    #[test]
    fn test_get_tokenizer() {
        let registry = ModelRegistry::new();
        let tokenizer = registry.get_tokenizer("gpt-4", false).unwrap();
        let count = tokenizer.count_tokens("Hello world").unwrap();
        assert_eq!(count, 2);
    }

    // Claude path with use_accurate=false: character-based estimation.
    #[test]
    fn test_get_claude_tokenizer() {
        let registry = ModelRegistry::new();
        let tokenizer = registry.get_tokenizer("claude-sonnet-4-6", false).unwrap();
        let count = tokenizer.count_tokens("Hello world").unwrap();
        // Estimation mode: "Hello world" = 11 chars
        // Prose detection: ratio < 5% → 4.5 chars/token → 11/4.5 = 2.44 → ceil = 3 tokens
        assert_eq!(count, 3);
    }

    // Claude shorthand aliases (declared in claude_models()) resolve
    // to their canonical model names.
    #[test]
    fn test_claude_alias_resolution() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.resolve_model_name("claude").unwrap(), "claude-sonnet-4-6");
        assert_eq!(registry.resolve_model_name("sonnet").unwrap(), "claude-sonnet-4-6");
        assert_eq!(registry.resolve_model_name("opus").unwrap(), "claude-opus-4-6");
        assert_eq!(registry.resolve_model_name("haiku").unwrap(), "claude-haiku-4-5");
    }

    // A typo within Levenshtein distance 3 yields a "Did you mean" hint.
    #[test]
    fn test_fuzzy_suggestions() {
        let registry = ModelRegistry::new();
        let result = registry.resolve_model_name("gpt4-tubro");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Did you mean"));
        assert!(err.to_string().contains("gpt-4-turbo"));
    }
}