// token_count/tokenizers/registry.rs
use crate::error::TokenError;
use crate::tokenizers::{openai::OpenAITokenizer, ModelInfo, Tokenizer};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::OnceLock;
/// Static metadata for one supported model: canonical name, tiktoken
/// encoding, context size, and the alternate spellings users may type.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Canonical model identifier, e.g. "gpt-4".
    pub name: String,
    /// tiktoken encoding name used to build the tokenizer (e.g. "cl100k_base").
    pub encoding: String,
    /// Maximum context window, in tokens.
    pub context_window: usize,
    /// Human-readable one-line description of the model.
    pub description: String,
    /// Alternate spellings that resolve to `name` (matched case-insensitively).
    pub aliases: Vec<String>,
}
/// Case-insensitive lookup table mapping model names and aliases to their
/// configurations.
pub struct ModelRegistry {
    /// Canonical model name -> configuration.
    models: HashMap<String, ModelConfig>,
    /// Lowercased alias (including the canonical name itself) -> canonical name.
    aliases: HashMap<String, String>,
}
25impl ModelRegistry {
26 pub fn new() -> Self {
28 let mut registry = Self { models: HashMap::new(), aliases: HashMap::new() };
29
30 registry.add_model(ModelConfig {
32 name: "gpt-3.5-turbo".to_string(),
33 encoding: "cl100k_base".to_string(),
34 context_window: 16385,
35 description: "GPT-3.5 Turbo (16K context)".to_string(),
36 aliases: vec![
37 "gpt-3.5".to_string(),
38 "gpt35".to_string(),
39 "gpt-35-turbo".to_string(),
40 "openai/gpt-3.5-turbo".to_string(),
41 ],
42 });
43
44 registry.add_model(ModelConfig {
46 name: "gpt-4".to_string(),
47 encoding: "cl100k_base".to_string(),
48 context_window: 128000,
49 description: "GPT-4 (128K context)".to_string(),
50 aliases: vec!["gpt4".to_string(), "openai/gpt-4".to_string()],
51 });
52
53 registry.add_model(ModelConfig {
55 name: "gpt-4-turbo".to_string(),
56 encoding: "cl100k_base".to_string(),
57 context_window: 128000,
58 description: "GPT-4 Turbo (128K context)".to_string(),
59 aliases: vec![
60 "gpt4-turbo".to_string(),
61 "gpt-4turbo".to_string(),
62 "openai/gpt-4-turbo".to_string(),
63 ],
64 });
65
66 registry.add_model(ModelConfig {
68 name: "gpt-4o".to_string(),
69 encoding: "o200k_base".to_string(),
70 context_window: 128000,
71 description: "GPT-4o (128K context)".to_string(),
72 aliases: vec!["gpt4o".to_string(), "openai/gpt-4o".to_string()],
73 });
74
75 registry
76 }
77
78 pub fn global() -> &'static Self {
80 static REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
81 REGISTRY.get_or_init(Self::new)
82 }
83
84 fn add_model(&mut self, config: ModelConfig) {
86 let canonical_name = config.name.clone();
87
88 for alias in &config.aliases {
90 self.aliases.insert(alias.to_lowercase(), canonical_name.clone());
91 }
92
93 self.aliases.insert(canonical_name.to_lowercase(), canonical_name.clone());
95
96 self.models.insert(canonical_name, config);
97 }
98
99 pub fn resolve_model_name(&self, name: &str) -> Result<String, TokenError> {
101 let normalized = name.trim().to_lowercase();
102
103 if let Some(canonical) = self.aliases.get(&normalized) {
104 return Ok(canonical.clone());
105 }
106
107 let suggestion = self.generate_suggestions(&normalized);
109 Err(TokenError::UnknownModel { model: name.to_string(), suggestion })
110 }
111
112 pub fn get_model(&self, name: &str) -> Result<&ModelConfig, TokenError> {
114 let canonical = self.resolve_model_name(name)?;
115 Ok(self.models.get(&canonical).expect("Canonical name must exist"))
116 }
117
118 pub fn get_tokenizer(&self, name: &str) -> Result<Box<dyn Tokenizer>, TokenError> {
120 let config = self.get_model(name)?;
121
122 let model_info = ModelInfo {
123 name: config.name.clone(),
124 encoding: config.encoding.clone(),
125 context_window: config.context_window,
126 description: config.description.clone(),
127 };
128
129 let tokenizer = OpenAITokenizer::new(&config.encoding, model_info)
130 .map_err(|e| TokenError::Tokenization(e.to_string()))?;
131
132 Ok(Box::new(tokenizer))
133 }
134
135 pub fn list_models(&self) -> Vec<&ModelConfig> {
137 let mut models: Vec<&ModelConfig> = self.models.values().collect();
138 models.sort_by(|a, b| a.name.cmp(&b.name));
139 models
140 }
141
142 fn generate_suggestions(&self, name: &str) -> String {
144 let mut suggestions: Vec<(&str, usize)> = self
145 .models
146 .keys()
147 .map(|model_name| {
148 let distance = strsim::levenshtein(name, &model_name.to_lowercase());
149 (model_name.as_str(), distance)
150 })
151 .collect();
152
153 suggestions.sort_by_key(|&(_, dist)| dist);
154
155 let close_matches: Vec<&str> = suggestions
156 .iter()
157 .take(3)
158 .filter(|&&(_, dist)| dist <= 3)
159 .map(|&(name, _)| name)
160 .collect();
161
162 if close_matches.is_empty() {
163 "Use --list-models to see all supported models".to_string()
164 } else {
165 format!("Did you mean: {}?", close_matches.join(", "))
166 }
167 }
168}
169
170impl Default for ModelRegistry {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resolve_canonical_name() {
        let reg = ModelRegistry::new();
        // Canonical names resolve to themselves, regardless of case.
        assert_eq!(reg.resolve_model_name("gpt-4").unwrap(), "gpt-4");
        assert_eq!(reg.resolve_model_name("GPT-4").unwrap(), "gpt-4");
    }

    #[test]
    fn test_resolve_alias() {
        let reg = ModelRegistry::new();
        assert_eq!(reg.resolve_model_name("gpt4").unwrap(), "gpt-4");
        assert_eq!(reg.resolve_model_name("gpt35").unwrap(), "gpt-3.5-turbo");
    }

    #[test]
    fn test_unknown_model() {
        let reg = ModelRegistry::new();
        match reg.resolve_model_name("gpt-5") {
            Ok(name) => panic!("unexpectedly resolved to {name}"),
            Err(err) => assert!(err.to_string().contains("gpt")),
        }
    }

    #[test]
    fn test_list_models() {
        let models = ModelRegistry::new();
        let listed = models.list_models();
        assert_eq!(listed.len(), 4);
        for expected in ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-4o"] {
            assert!(listed.iter().any(|m| m.name == expected));
        }
    }

    #[test]
    fn test_get_tokenizer() {
        let tokenizer = ModelRegistry::new().get_tokenizer("gpt-4").unwrap();
        assert_eq!(tokenizer.count_tokens("Hello world").unwrap(), 2);
    }

    #[test]
    fn test_fuzzy_suggestions() {
        let reg = ModelRegistry::new();
        let err = reg.resolve_model_name("gpt4-tubro").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Did you mean"));
        assert!(msg.contains("gpt-4-turbo"));
    }
}