1use crate::error::TokenError;
4use crate::tokenizers::{
5 claude::{claude_models, ClaudeTokenizer},
6 openai::OpenAITokenizer,
7 ModelInfo, Tokenizer,
8};
9use anyhow::Result;
10use std::collections::HashMap;
11use std::sync::OnceLock;
12
/// Static metadata describing a single supported model: its canonical name,
/// the tokenizer encoding it uses, and the alternate spellings users may type.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Canonical model identifier (e.g. "gpt-4o").
    pub name: String,
    /// Tokenizer encoding name; "anthropic-claude" selects the Claude
    /// tokenizer, anything else is treated as a tiktoken encoding.
    pub encoding: String,
    /// Maximum context size in tokens.
    pub context_window: usize,
    /// Human-readable description shown in model listings.
    pub description: String,
    /// Alternate names accepted by lookup (matched case-insensitively).
    pub aliases: Vec<String>,
}
22
/// Registry of all supported models, supporting case-insensitive lookup by
/// canonical name or alias.
pub struct ModelRegistry {
    /// Canonical name -> full configuration.
    models: HashMap<String, ModelConfig>,
    /// Lowercased alias (including the canonical name itself) -> canonical name.
    aliases: HashMap<String, String>, }
28
29impl ModelRegistry {
30 pub fn new() -> Self {
32 let mut registry = Self { models: HashMap::new(), aliases: HashMap::new() };
33
34 registry.add_model(ModelConfig {
36 name: "gpt-3.5-turbo".to_string(),
37 encoding: "cl100k_base".to_string(),
38 context_window: 16385,
39 description: "GPT-3.5 Turbo (16K context)".to_string(),
40 aliases: vec![
41 "gpt-3.5".to_string(),
42 "gpt35".to_string(),
43 "gpt-35-turbo".to_string(),
44 "openai/gpt-3.5-turbo".to_string(),
45 ],
46 });
47
48 registry.add_model(ModelConfig {
50 name: "gpt-4".to_string(),
51 encoding: "cl100k_base".to_string(),
52 context_window: 128000,
53 description: "GPT-4 (128K context)".to_string(),
54 aliases: vec!["gpt4".to_string(), "openai/gpt-4".to_string()],
55 });
56
57 registry.add_model(ModelConfig {
59 name: "gpt-4-turbo".to_string(),
60 encoding: "cl100k_base".to_string(),
61 context_window: 128000,
62 description: "GPT-4 Turbo (128K context)".to_string(),
63 aliases: vec![
64 "gpt4-turbo".to_string(),
65 "gpt-4turbo".to_string(),
66 "openai/gpt-4-turbo".to_string(),
67 ],
68 });
69
70 registry.add_model(ModelConfig {
72 name: "gpt-4o".to_string(),
73 encoding: "o200k_base".to_string(),
74 context_window: 128000,
75 description: "GPT-4o (128K context)".to_string(),
76 aliases: vec!["gpt4o".to_string(), "openai/gpt-4o".to_string()],
77 });
78
79 for model in claude_models() {
81 registry.add_model(model);
82 }
83
84 registry
85 }
86
87 pub fn global() -> &'static Self {
89 static REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
90 REGISTRY.get_or_init(Self::new)
91 }
92
93 fn add_model(&mut self, config: ModelConfig) {
95 let canonical_name = config.name.clone();
96
97 for alias in &config.aliases {
99 self.aliases.insert(alias.to_lowercase(), canonical_name.clone());
100 }
101
102 self.aliases.insert(canonical_name.to_lowercase(), canonical_name.clone());
104
105 self.models.insert(canonical_name, config);
106 }
107
108 pub fn resolve_model_name(&self, name: &str) -> Result<String, TokenError> {
110 let normalized = name.trim().to_lowercase();
111
112 if let Some(canonical) = self.aliases.get(&normalized) {
113 return Ok(canonical.clone());
114 }
115
116 let suggestion = self.generate_suggestions(&normalized);
118 Err(TokenError::UnknownModel { model: name.to_string(), suggestion })
119 }
120
121 pub fn get_model(&self, name: &str) -> Result<&ModelConfig, TokenError> {
123 let canonical = self.resolve_model_name(name)?;
124 Ok(self.models.get(&canonical).expect("Canonical name must exist"))
125 }
126
127 pub fn get_tokenizer(
137 &self,
138 name: &str,
139 use_accurate: bool,
140 ) -> Result<Box<dyn Tokenizer>, TokenError> {
141 let config = self.get_model(name)?;
142
143 if config.encoding == "anthropic-claude" {
145 let tokenizer = ClaudeTokenizer::new(config.clone(), use_accurate)?;
147 Ok(Box::new(tokenizer))
148 } else {
149 let model_info = ModelInfo {
151 name: config.name.clone(),
152 encoding: config.encoding.clone(),
153 context_window: config.context_window,
154 description: config.description.clone(),
155 };
156
157 let tokenizer = OpenAITokenizer::new(&config.encoding, model_info)
158 .map_err(|e| TokenError::Tokenization(e.to_string()))?;
159
160 Ok(Box::new(tokenizer))
161 }
162 }
163
164 pub fn list_models(&self) -> Vec<&ModelConfig> {
166 let mut models: Vec<&ModelConfig> = self.models.values().collect();
167 models.sort_by(|a, b| a.name.cmp(&b.name));
168 models
169 }
170
171 fn generate_suggestions(&self, name: &str) -> String {
173 let mut suggestions: Vec<(&str, usize)> = self
174 .models
175 .keys()
176 .map(|model_name| {
177 let distance = strsim::levenshtein(name, &model_name.to_lowercase());
178 (model_name.as_str(), distance)
179 })
180 .collect();
181
182 suggestions.sort_by_key(|&(_, dist)| dist);
183
184 let close_matches: Vec<&str> = suggestions
185 .iter()
186 .take(3)
187 .filter(|&&(_, dist)| dist <= 3)
188 .map(|&(name, _)| name)
189 .collect();
190
191 if close_matches.is_empty() {
192 "Use --list-models to see all supported models".to_string()
193 } else {
194 format!("Did you mean: {}?", close_matches.join(", "))
195 }
196 }
197}
198
199impl Default for ModelRegistry {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    // Canonical names resolve to themselves, ignoring case.
    #[test]
    fn test_resolve_canonical_name() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.resolve_model_name("gpt-4").unwrap(), "gpt-4");
        assert_eq!(registry.resolve_model_name("GPT-4").unwrap(), "gpt-4");
    }

    // Aliases map to their canonical model names.
    #[test]
    fn test_resolve_alias() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.resolve_model_name("gpt4").unwrap(), "gpt-4");
        assert_eq!(registry.resolve_model_name("gpt35").unwrap(), "gpt-3.5-turbo");
    }

    // Unknown names fail, and the error message still mentions a nearby
    // model family ("gpt").
    #[test]
    fn test_unknown_model() {
        let registry = ModelRegistry::new();
        let result = registry.resolve_model_name("gpt-5");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("gpt"));
    }

    // The full catalog: 4 OpenAI models plus 3 Claude models = 7.
    #[test]
    fn test_list_models() {
        let registry = ModelRegistry::new();
        let models = registry.list_models();
        assert_eq!(models.len(), 7);
        assert!(models.iter().any(|m| m.name == "gpt-3.5-turbo"));
        assert!(models.iter().any(|m| m.name == "gpt-4"));
        assert!(models.iter().any(|m| m.name == "gpt-4-turbo"));
        assert!(models.iter().any(|m| m.name == "gpt-4o"));
        assert!(models.iter().any(|m| m.name == "claude-opus-4-6"));
        assert!(models.iter().any(|m| m.name == "claude-sonnet-4-6"));
        assert!(models.iter().any(|m| m.name == "claude-haiku-4-5"));
    }

    // OpenAI path: cl100k_base tokenizes "Hello world" into 2 tokens.
    #[test]
    fn test_get_tokenizer() {
        let registry = ModelRegistry::new();
        let tokenizer = registry.get_tokenizer("gpt-4", false).unwrap();
        let count = tokenizer.count_tokens("Hello world").unwrap();
        assert_eq!(count, 2);
    }

    // Claude path (approximate mode, use_accurate = false) counts 3 tokens
    // for the same input.
    #[test]
    fn test_get_claude_tokenizer() {
        let registry = ModelRegistry::new();
        let tokenizer = registry.get_tokenizer("claude-sonnet-4-6", false).unwrap();
        let count = tokenizer.count_tokens("Hello world").unwrap();
        assert_eq!(count, 3);
    }

    // Short Claude aliases resolve to the expected canonical models.
    #[test]
    fn test_claude_alias_resolution() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.resolve_model_name("claude").unwrap(), "claude-sonnet-4-6");
        assert_eq!(registry.resolve_model_name("sonnet").unwrap(), "claude-sonnet-4-6");
        assert_eq!(registry.resolve_model_name("opus").unwrap(), "claude-opus-4-6");
        assert_eq!(registry.resolve_model_name("haiku").unwrap(), "claude-haiku-4-5");
    }

    // A transposition typo ("tubro") is within Levenshtein distance 3, so
    // the error suggests the intended model.
    #[test]
    fn test_fuzzy_suggestions() {
        let registry = ModelRegistry::new();
        let result = registry.resolve_model_name("gpt4-tubro");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Did you mean"));
        assert!(err.to_string().contains("gpt-4-turbo"));
    }
}