use crate::error::TokenError;
use crate::tokenizers::{
claude::{claude_models, ClaudeTokenizer},
google::{google_models, GoogleTokenizer},
openai::OpenAITokenizer,
ModelInfo, Tokenizer,
};
use anyhow::Result;
use std::collections::HashMap;
use std::sync::OnceLock;
/// Static metadata describing one supported model.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Canonical model identifier (e.g. "gpt-4").
    pub name: String,
    /// Tokenizer-backend selector (e.g. "cl100k_base", "anthropic-claude",
    /// "gemini-gemma3") — see `ModelRegistry::get_tokenizer`.
    pub encoding: String,
    /// Maximum context size, in tokens.
    pub context_window: usize,
    /// Human-readable description shown in model listings.
    pub description: String,
    /// Alternate names that resolve to `name` (matched case-insensitively).
    pub aliases: Vec<String>,
}
/// Lookup table of supported models with case-insensitive alias resolution.
pub struct ModelRegistry {
    /// Canonical name -> model metadata.
    models: HashMap<String, ModelConfig>,
    /// Lowercased alias (including the lowercased canonical name) -> canonical name.
    aliases: HashMap<String, String>, }
impl ModelRegistry {
    /// Builds a registry pre-populated with the OpenAI catalog plus the
    /// Claude and Google catalogs supplied by their tokenizer modules.
    pub fn new() -> Self {
        let mut registry = Self { models: HashMap::new(), aliases: HashMap::new() };
        registry.add_model(ModelConfig {
            name: "gpt-3.5-turbo".to_string(),
            encoding: "cl100k_base".to_string(),
            context_window: 16385,
            description: "GPT-3.5 Turbo (16K context)".to_string(),
            aliases: vec![
                "gpt-3.5".to_string(),
                "gpt35".to_string(),
                "gpt-35-turbo".to_string(),
                "openai/gpt-3.5-turbo".to_string(),
            ],
        });
        registry.add_model(ModelConfig {
            name: "gpt-4".to_string(),
            encoding: "cl100k_base".to_string(),
            // Fixed: base GPT-4 has an 8,192-token window; the previous
            // 128000/"128K" values belong to GPT-4 Turbo / GPT-4o and
            // appear to have been copy-pasted from those entries.
            context_window: 8192,
            description: "GPT-4 (8K context)".to_string(),
            aliases: vec!["gpt4".to_string(), "openai/gpt-4".to_string()],
        });
        registry.add_model(ModelConfig {
            name: "gpt-4-turbo".to_string(),
            encoding: "cl100k_base".to_string(),
            context_window: 128000,
            description: "GPT-4 Turbo (128K context)".to_string(),
            aliases: vec![
                "gpt4-turbo".to_string(),
                "gpt-4turbo".to_string(),
                "openai/gpt-4-turbo".to_string(),
            ],
        });
        registry.add_model(ModelConfig {
            name: "gpt-4o".to_string(),
            encoding: "o200k_base".to_string(),
            context_window: 128000,
            description: "GPT-4o (128K context)".to_string(),
            aliases: vec!["gpt4o".to_string(), "openai/gpt-4o".to_string()],
        });
        for model in claude_models() {
            registry.add_model(model);
        }
        for model in google_models() {
            registry.add_model(model);
        }
        registry
    }

    /// Returns the process-wide shared registry, built lazily on first use.
    pub fn global() -> &'static Self {
        static REGISTRY: OnceLock<ModelRegistry> = OnceLock::new();
        REGISTRY.get_or_init(Self::new)
    }

    /// Registers `config`, mapping its canonical name and every alias
    /// (lowercased) to the canonical name.
    ///
    /// NOTE: a later registration silently overwrites an earlier one when
    /// an alias (or name) collides.
    fn add_model(&mut self, config: ModelConfig) {
        let canonical_name = config.name.clone();
        for alias in &config.aliases {
            self.aliases.insert(alias.to_lowercase(), canonical_name.clone());
        }
        // The canonical name itself is also resolvable, case-insensitively.
        self.aliases.insert(canonical_name.to_lowercase(), canonical_name.clone());
        self.models.insert(canonical_name, config);
    }

    /// Resolves a user-supplied model name (trimmed, case-insensitive) to its
    /// canonical name.
    ///
    /// # Errors
    /// Returns `TokenError::UnknownModel` — carrying a "did you mean" hint —
    /// when the name matches neither a model nor an alias.
    pub fn resolve_model_name(&self, name: &str) -> Result<String, TokenError> {
        let normalized = name.trim().to_lowercase();
        if let Some(canonical) = self.aliases.get(&normalized) {
            return Ok(canonical.clone());
        }
        let suggestion = self.generate_suggestions(&normalized);
        Err(TokenError::UnknownModel { model: name.to_string(), suggestion })
    }

    /// Looks up the full configuration for `name` (alias-aware).
    ///
    /// # Errors
    /// Same as [`Self::resolve_model_name`].
    pub fn get_model(&self, name: &str) -> Result<&ModelConfig, TokenError> {
        let canonical = self.resolve_model_name(name)?;
        // add_model always inserts the canonical name into `models`,
        // so a resolved name must be present.
        Ok(self.models.get(&canonical).expect("Canonical name must exist"))
    }

    /// Constructs a tokenizer for `name`, dispatching on the model's
    /// `encoding` field; anything not Claude/Gemini is treated as a
    /// tiktoken encoding name and handled by `OpenAITokenizer`.
    ///
    /// `use_accurate` is only meaningful for the Claude backend.
    ///
    /// # Errors
    /// Returns `TokenError::UnknownModel` for unresolvable names and
    /// `TokenError::Tokenization` when the backend fails to initialize.
    pub fn get_tokenizer(
        &self,
        name: &str,
        use_accurate: bool,
    ) -> Result<Box<dyn Tokenizer>, TokenError> {
        let config = self.get_model(name)?;
        match config.encoding.as_str() {
            "anthropic-claude" => {
                let tokenizer = ClaudeTokenizer::new(config.clone(), use_accurate)?;
                Ok(Box::new(tokenizer))
            }
            "gemini-gemma3" => {
                let tokenizer = GoogleTokenizer::new(config.clone())?;
                Ok(Box::new(tokenizer))
            }
            _ => {
                let model_info = ModelInfo {
                    name: config.name.clone(),
                    encoding: config.encoding.clone(),
                    context_window: config.context_window,
                    description: config.description.clone(),
                };
                let tokenizer = OpenAITokenizer::new(&config.encoding, model_info)
                    .map_err(|e| TokenError::Tokenization(e.to_string()))?;
                Ok(Box::new(tokenizer))
            }
        }
    }

    /// Returns every registered model, sorted by canonical name.
    pub fn list_models(&self) -> Vec<&ModelConfig> {
        let mut models: Vec<&ModelConfig> = self.models.values().collect();
        // Names are unique keys, so an unstable sort is safe and avoids
        // the temporary allocation a stable sort may make.
        models.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        models
    }

    /// Builds the "did you mean" hint for an unknown model name:
    /// the up-to-3 canonical names within Levenshtein distance 3 of
    /// `name`, or a pointer to `--list-models` when nothing is close.
    /// `name` is expected to already be lowercased by the caller.
    fn generate_suggestions(&self, name: &str) -> String {
        let mut suggestions: Vec<(&str, usize)> = self
            .models
            .keys()
            .map(|model_name| {
                let distance = strsim::levenshtein(name, &model_name.to_lowercase());
                (model_name.as_str(), distance)
            })
            .collect();
        suggestions.sort_by_key(|&(_, dist)| dist);
        let close_matches: Vec<&str> = suggestions
            .iter()
            .take(3)
            .filter(|&&(_, dist)| dist <= 3)
            .map(|&(name, _)| name)
            .collect();
        if close_matches.is_empty() {
            "Use --list-models to see all supported models".to_string()
        } else {
            format!("Did you mean: {}?", close_matches.join(", "))
        }
    }
}
impl Default for ModelRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Canonical names resolve regardless of input casing.
    #[test]
    fn test_resolve_canonical_name() {
        let registry = ModelRegistry::new();
        for input in ["gpt-4", "GPT-4"] {
            assert_eq!(registry.resolve_model_name(input).unwrap(), "gpt-4");
        }
    }

    // Aliases map to their canonical names.
    #[test]
    fn test_resolve_alias() {
        let registry = ModelRegistry::new();
        for (alias, canonical) in [("gpt4", "gpt-4"), ("gpt35", "gpt-3.5-turbo")] {
            assert_eq!(registry.resolve_model_name(alias).unwrap(), canonical);
        }
    }

    // Unknown names fail, and the error message mentions nearby models.
    #[test]
    fn test_unknown_model() {
        let registry = ModelRegistry::new();
        let err = registry
            .resolve_model_name("gpt-5")
            .expect_err("gpt-5 must not resolve");
        assert!(err.to_string().contains("gpt"));
    }

    // The listing contains exactly the built-in catalog.
    #[test]
    fn test_list_models() {
        let registry = ModelRegistry::new();
        let models = registry.list_models();
        assert_eq!(models.len(), 11);
        let expected = [
            "gpt-3.5-turbo",
            "gpt-4",
            "gpt-4-turbo",
            "gpt-4o",
            "claude-opus-4-6",
            "claude-sonnet-4-6",
            "claude-haiku-4-5",
            "gemini-2.5-pro",
            "gemini-2.5-flash",
            "gemini-2.5-flash-lite",
            "gemini-3-pro-preview",
        ];
        for name in expected {
            assert!(models.iter().any(|m| m.name == name), "missing {name}");
        }
    }

    #[test]
    fn test_get_tokenizer() {
        let registry = ModelRegistry::new();
        let tokenizer = registry.get_tokenizer("gpt-4", false).unwrap();
        assert_eq!(tokenizer.count_tokens("Hello world").unwrap(), 2);
    }

    #[test]
    fn test_get_claude_tokenizer() {
        let registry = ModelRegistry::new();
        let tokenizer = registry.get_tokenizer("claude-sonnet-4-6", false).unwrap();
        assert_eq!(tokenizer.count_tokens("Hello world").unwrap(), 3);
    }

    // Family-level shorthands point at the current default of each family.
    #[test]
    fn test_claude_alias_resolution() {
        let registry = ModelRegistry::new();
        let cases = [
            ("claude", "claude-sonnet-4-6"),
            ("sonnet", "claude-sonnet-4-6"),
            ("opus", "claude-opus-4-6"),
            ("haiku", "claude-haiku-4-5"),
        ];
        for (alias, canonical) in cases {
            assert_eq!(registry.resolve_model_name(alias).unwrap(), canonical);
        }
    }

    // A near-miss typo produces a "Did you mean" suggestion.
    #[test]
    fn test_fuzzy_suggestions() {
        let registry = ModelRegistry::new();
        let err = registry
            .resolve_model_name("gpt4-tubro")
            .expect_err("typo must not resolve");
        let msg = err.to_string();
        assert!(msg.contains("Did you mean"));
        assert!(msg.contains("gpt-4-turbo"));
    }
}