1use crate::error::TokenError;
4use crate::tokenizers::{
5 claude::{claude_models, ClaudeTokenizer},
6 google::{google_models, GoogleTokenizer},
7 openai::OpenAITokenizer,
8 ModelInfo, Tokenizer,
9};
10use anyhow::Result;
11use std::collections::HashMap;
12use std::sync::OnceLock;
13
/// Static description of one supported model: which tokenizer encoding it
/// uses, its context size, and the alternate names it resolves from.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Canonical model identifier, e.g. "gpt-4o".
    pub name: String,
    /// Tokenizer encoding key; dispatched on in `get_tokenizer`
    /// (e.g. "cl100k_base", "anthropic-claude", "gemini-gemma3").
    pub encoding: String,
    /// Maximum context window size, in tokens.
    pub context_window: usize,
    /// Human-readable one-line description of the model.
    pub description: String,
    /// Alternate spellings accepted when resolving this model by name.
    pub aliases: Vec<String>,
}
23
24pub struct ModelRegistry {
26 models: HashMap<String, ModelConfig>,
27 aliases: HashMap<String, String>, }
29
30impl ModelRegistry {
31 pub fn new() -> Self {
33 let mut registry = Self { models: HashMap::new(), aliases: HashMap::new() };
34
35 registry.add_model(ModelConfig {
37 name: "gpt-3.5-turbo".to_string(),
38 encoding: "cl100k_base".to_string(),
39 context_window: 16385,
40 description: "GPT-3.5 Turbo (16K context)".to_string(),
41 aliases: vec![
42 "gpt-3.5".to_string(),
43 "gpt35".to_string(),
44 "gpt-35-turbo".to_string(),
45 "openai/gpt-3.5-turbo".to_string(),
46 ],
47 });
48
49 registry.add_model(ModelConfig {
51 name: "gpt-4".to_string(),
52 encoding: "cl100k_base".to_string(),
53 context_window: 128000,
54 description: "GPT-4 (128K context)".to_string(),
55 aliases: vec!["gpt4".to_string(), "openai/gpt-4".to_string()],
56 });
57
58 registry.add_model(ModelConfig {
60 name: "gpt-4-turbo".to_string(),
61 encoding: "cl100k_base".to_string(),
62 context_window: 128000,
63 description: "GPT-4 Turbo (128K context)".to_string(),
64 aliases: vec![
65 "gpt4-turbo".to_string(),
66 "gpt-4turbo".to_string(),
67 "openai/gpt-4-turbo".to_string(),
68 ],
69 });
70
71 registry.add_model(ModelConfig {
73 name: "gpt-4o".to_string(),
74 encoding: "o200k_base".to_string(),
75 context_window: 128000,
76 description: "GPT-4o (128K context)".to_string(),
77 aliases: vec!["gpt4o".to_string(), "openai/gpt-4o".to_string()],
78 });
79
80 for model in claude_models() {
82 registry.add_model(model);
83 }
84
85 for model in google_models() {
87 registry.add_model(model);
88 }
89
90 registry
91 }
92
93 pub fn global() -> &'static Self {
95 static REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
96 REGISTRY.get_or_init(Self::new)
97 }
98
99 fn add_model(&mut self, config: ModelConfig) {
101 let canonical_name = config.name.clone();
102
103 for alias in &config.aliases {
105 self.aliases.insert(alias.to_lowercase(), canonical_name.clone());
106 }
107
108 self.aliases.insert(canonical_name.to_lowercase(), canonical_name.clone());
110
111 self.models.insert(canonical_name, config);
112 }
113
114 pub fn resolve_model_name(&self, name: &str) -> Result<String, TokenError> {
116 let normalized = name.trim().to_lowercase();
117
118 if let Some(canonical) = self.aliases.get(&normalized) {
119 return Ok(canonical.clone());
120 }
121
122 let suggestion = self.generate_suggestions(&normalized);
124 Err(TokenError::UnknownModel { model: name.to_string(), suggestion })
125 }
126
127 pub fn get_model(&self, name: &str) -> Result<&ModelConfig, TokenError> {
129 let canonical = self.resolve_model_name(name)?;
130 Ok(self.models.get(&canonical).expect("Canonical name must exist"))
131 }
132
133 pub fn get_tokenizer(
143 &self,
144 name: &str,
145 use_accurate: bool,
146 ) -> Result<Box<dyn Tokenizer>, TokenError> {
147 let config = self.get_model(name)?;
148
149 match config.encoding.as_str() {
151 "anthropic-claude" => {
152 let tokenizer = ClaudeTokenizer::new(config.clone(), use_accurate)?;
154 Ok(Box::new(tokenizer))
155 }
156 "gemini-gemma3" => {
157 let tokenizer = GoogleTokenizer::new(config.clone())?;
159 Ok(Box::new(tokenizer))
160 }
161 _ => {
162 let model_info = ModelInfo {
164 name: config.name.clone(),
165 encoding: config.encoding.clone(),
166 context_window: config.context_window,
167 description: config.description.clone(),
168 };
169
170 let tokenizer = OpenAITokenizer::new(&config.encoding, model_info)
171 .map_err(|e| TokenError::Tokenization(e.to_string()))?;
172
173 Ok(Box::new(tokenizer))
174 }
175 }
176 }
177
178 pub fn list_models(&self) -> Vec<&ModelConfig> {
180 let mut models: Vec<&ModelConfig> = self.models.values().collect();
181 models.sort_by(|a, b| a.name.cmp(&b.name));
182 models
183 }
184
185 fn generate_suggestions(&self, name: &str) -> String {
187 let mut suggestions: Vec<(&str, usize)> = self
188 .models
189 .keys()
190 .map(|model_name| {
191 let distance = strsim::levenshtein(name, &model_name.to_lowercase());
192 (model_name.as_str(), distance)
193 })
194 .collect();
195
196 suggestions.sort_by_key(|&(_, dist)| dist);
197
198 let close_matches: Vec<&str> = suggestions
199 .iter()
200 .take(3)
201 .filter(|&&(_, dist)| dist <= 3)
202 .map(|&(name, _)| name)
203 .collect();
204
205 if close_matches.is_empty() {
206 "Use --list-models to see all supported models".to_string()
207 } else {
208 format!("Did you mean: {}?", close_matches.join(", "))
209 }
210 }
211}
212
213impl Default for ModelRegistry {
214 fn default() -> Self {
215 Self::new()
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resolve_canonical_name() {
        let registry = ModelRegistry::new();
        // Canonical names resolve to themselves, regardless of case.
        assert_eq!(registry.resolve_model_name("gpt-4").unwrap(), "gpt-4");
        assert_eq!(registry.resolve_model_name("GPT-4").unwrap(), "gpt-4");
    }

    #[test]
    fn test_resolve_alias() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.resolve_model_name("gpt4").unwrap(), "gpt-4");
        assert_eq!(registry.resolve_model_name("gpt35").unwrap(), "gpt-3.5-turbo");
    }

    #[test]
    fn test_unknown_model() {
        let result = ModelRegistry::new().resolve_model_name("gpt-5");
        assert!(result.is_err());
        // The error message should mention a nearby model name.
        assert!(result.unwrap_err().to_string().contains("gpt"));
    }

    #[test]
    fn test_list_models() {
        let registry = ModelRegistry::new();
        let models = registry.list_models();
        assert_eq!(models.len(), 11);
        let expected = [
            "gpt-3.5-turbo",
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o",
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5",
            "gemini-2.5-pro",
            "gemini-2.5-flash",
            "gemini-2.5-flash-lite",
            "gemini-3-pro-preview",
        ];
        for name in expected {
            assert!(models.iter().any(|m| m.name == name));
        }
    }

    #[test]
    fn test_get_tokenizer() {
        let tokenizer = ModelRegistry::new().get_tokenizer("gpt-4", false).unwrap();
        assert_eq!(tokenizer.count_tokens("Hello world").unwrap(), 2);
    }

    #[test]
    fn test_get_claude_tokenizer() {
        let tokenizer = ModelRegistry::new()
            .get_tokenizer("claude-sonnet-4-6", false)
            .unwrap();
        assert_eq!(tokenizer.count_tokens("Hello world").unwrap(), 3);
    }

    #[test]
    fn test_claude_alias_resolution() {
        let registry = ModelRegistry::new();
        let cases = [
            ("claude", "claude-sonnet-4-6"),
            ("sonnet", "claude-sonnet-4-6"),
            ("opus", "claude-opus-4-6"),
            ("haiku", "claude-haiku-4-5"),
        ];
        for (alias, canonical) in cases {
            assert_eq!(registry.resolve_model_name(alias).unwrap(), canonical);
        }
    }

    #[test]
    fn test_fuzzy_suggestions() {
        let registry = ModelRegistry::new();
        let err = registry.resolve_model_name("gpt4-tubro").unwrap_err();
        let message = err.to_string();
        assert!(message.contains("Did you mean"));
        assert!(message.contains("gpt-4-turbo"));
    }
}