use serde::Deserialize;
use std::collections::HashMap;
use std::path::PathBuf;
/// A registry entry that maps a short alias to a model repository.
///
/// Deserialized from one entry of the registry's `aliases` table. Every field
/// except `repo` is optional in the input and falls back to its default.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelAlias {
/// Repository identifier; `ModelRegistry::to_hf_spec` emits it right after
/// the `hf:` prefix.
pub repo: String,
/// File name appended to the generated spec as `hf:<repo>:<file>` when set.
#[serde(default)]
pub default_file: Option<String>,
/// Copied unchanged into `ResolvedModel::HuggingFace::mmproj` by
/// `resolve_model_name` (presumably a multimodal projector file; confirm
/// against the consumer).
#[serde(default)]
pub mmproj: Option<String>,
/// Model family label; matched case-insensitively by `ModelRegistry::search`.
#[serde(default)]
pub family: Option<String>,
/// Free-form description; matched case-insensitively by `ModelRegistry::search`.
#[serde(default)]
pub description: Option<String>,
/// Tags; each one is matched case-insensitively by `ModelRegistry::search`.
#[serde(default)]
pub tags: Vec<String>,
}
/// Quantization settings read from the registry's `quantizations` table.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct QuantizationConfig {
/// Lists of patterns keyed by name. `#[serde(flatten)]` makes this map
/// collect the table's own keys rather than a nested `mappings` key.
/// `ModelRegistry::get_quant_patterns` looks entries up by the suffix it is
/// given.
#[serde(flatten)]
pub mappings: HashMap<String, Vec<String>>,
/// Preferred quantization order. When empty,
/// `ModelRegistry::default_quant_order` substitutes a built-in list.
#[serde(default)]
pub default_order: Vec<String>,
}
/// Optional metadata read from the registry's `meta` table.
///
/// Nothing in this file interprets these values; they are stored as given.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct RegistryMeta {
/// Version string of the registry data, if provided.
pub version: Option<String>,
/// Last-updated marker of the registry data, if provided (format not
/// constrained here).
pub updated: Option<String>,
}
/// In-memory model registry: known aliases, quantization settings and metadata.
///
/// The derived `Default` yields an empty registry.
#[derive(Debug, Clone, Default)]
pub struct ModelRegistry {
/// Alias definitions keyed by alias name.
pub aliases: HashMap<String, ModelAlias>,
/// Quantization mappings and preferred ordering.
pub quantizations: QuantizationConfig,
/// Metadata about the registry data itself.
pub meta: RegistryMeta,
}
/// Result of classifying a user-supplied model string with
/// `ModelRegistry::parse_spec`.
#[derive(Debug, Clone)]
pub struct ParsedModelSpec {
/// The input with surrounding whitespace trimmed.
pub original: String,
/// Model name to look up; for `hf:` specs and local paths this is the
/// trimmed input itself.
pub name: String,
/// Quantization suffix split off the input by `parse_spec` (e.g. `q8`), if any.
pub quantization: Option<String>,
/// `true` when the input starts with `hf:`.
pub is_hf_spec: bool,
/// `true` when the input looks like a filesystem path.
pub is_local_path: bool,
/// Clone of the registry alias that `name` resolved to, if any.
pub alias: Option<ModelAlias>,
}
impl ModelRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn load_embedded() -> Result<Self, RegistryError> {
let toml_content = include_str!("../../configs/models.toml");
Self::from_toml(toml_content)
}
pub fn from_toml(content: &str) -> Result<Self, RegistryError> {
#[derive(Deserialize)]
struct RawRegistry {
#[serde(default)]
meta: RegistryMeta,
#[serde(default)]
aliases: HashMap<String, ModelAlias>,
#[serde(default)]
quantizations: QuantizationConfig,
}
let raw: RawRegistry =
toml::from_str(content).map_err(|e| RegistryError::ParseError(e.to_string()))?;
Ok(Self {
aliases: raw.aliases,
quantizations: raw.quantizations,
meta: raw.meta,
})
}
pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self, RegistryError> {
let content = std::fs::read_to_string(path.as_ref())
.map_err(|e| RegistryError::IoError(e.to_string()))?;
Self::from_toml(&content)
}
pub fn get(&self, name: &str) -> Option<&ModelAlias> {
self.aliases.get(name)
}
pub fn list_aliases(&self) -> Vec<&str> {
self.aliases.keys().map(|s| s.as_str()).collect()
}
pub fn search(&self, query: &str) -> Vec<(&str, &ModelAlias)> {
let query_lower = query.to_lowercase();
self.aliases
.iter()
.filter(|(name, alias)| {
name.to_lowercase().contains(&query_lower)
|| alias
.description
.as_ref()
.map(|d| d.to_lowercase().contains(&query_lower))
.unwrap_or(false)
|| alias
.family
.as_ref()
.map(|f| f.to_lowercase().contains(&query_lower))
.unwrap_or(false)
|| alias
.tags
.iter()
.any(|t| t.to_lowercase().contains(&query_lower))
})
.map(|(name, alias)| (name.as_str(), alias))
.collect()
}
pub fn parse_spec(&self, input: &str) -> ParsedModelSpec {
let input = input.trim();
if input.starts_with("hf:") {
return ParsedModelSpec {
original: input.to_string(),
name: input.to_string(),
quantization: None,
is_hf_spec: true,
is_local_path: false,
alias: None,
};
}
if Self::looks_like_local_path(input) {
return ParsedModelSpec {
original: input.to_string(),
name: input.to_string(),
quantization: None,
is_hf_spec: false,
is_local_path: true,
alias: None,
};
}
let (base_name, quantization) = self.extract_quantization(input);
let alias = self.get(&base_name).cloned();
ParsedModelSpec {
original: input.to_string(),
name: base_name,
quantization,
is_hf_spec: false,
is_local_path: false,
alias,
}
}
fn extract_quantization(&self, input: &str) -> (String, Option<String>) {
let quant_suffixes = ["q2", "q3", "q4", "q5", "q6", "q8", "f16", "f32"];
for suffix in quant_suffixes {
let pattern = format!("-{}", suffix);
if input.ends_with(&pattern) {
let base = input[..input.len() - pattern.len()].to_string();
return (base, Some(suffix.to_string()));
}
}
(input.to_string(), None)
}
fn looks_like_local_path(input: &str) -> bool {
let looks_like_windows_abs = input.len() >= 3
&& input.as_bytes()[0].is_ascii_alphabetic()
&& input.as_bytes()[1] == b':'
&& (input.as_bytes()[2] == b'\\' || input.as_bytes()[2] == b'/');
input.starts_with('/')
|| input.starts_with("./")
|| input.starts_with("../")
|| input.starts_with("~/")
|| looks_like_windows_abs
|| input.ends_with(".gguf")
|| input.contains('\\')
}
pub fn get_quant_patterns(&self, suffix: &str) -> Vec<String> {
self.quantizations
.mappings
.get(suffix)
.cloned()
.unwrap_or_default()
}
pub fn default_quant_order(&self) -> Vec<String> {
if self.quantizations.default_order.is_empty() {
vec![
"Q4_K_M".to_string(),
"Q4_K_S".to_string(),
"Q5_K_M".to_string(),
"Q4_0".to_string(),
"Q8_0".to_string(),
"F16".to_string(),
]
} else {
self.quantizations.default_order.clone()
}
}
pub fn to_hf_spec(&self, spec: &ParsedModelSpec) -> Option<String> {
if spec.is_hf_spec {
return Some(spec.original.clone());
}
if spec.is_local_path {
return None;
}
let alias = spec.alias.as_ref()?;
let mut hf_spec = format!("hf:{}", alias.repo);
if let Some(ref file) = alias.default_file {
hf_spec.push(':');
hf_spec.push_str(file);
}
Some(hf_spec)
}
}
/// Failures that can occur while loading a model registry.
///
/// Both variants carry the underlying error already rendered as text.
#[derive(Debug)]
pub enum RegistryError {
    /// The registry file could not be read.
    IoError(String),
    /// The registry contents could not be parsed.
    ParseError(String),
}

impl std::fmt::Display for RegistryError {
    /// Writes a category label followed by the stored message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IoError(message) => write!(f, "IO error: {}", message),
            Self::ParseError(message) => write!(f, "Parse error: {}", message),
        }
    }
}

impl std::error::Error for RegistryError {}
/// Process-wide registry instance, filled in on the first call to `registry`.
static REGISTRY: std::sync::OnceLock<ModelRegistry> = std::sync::OnceLock::new();

/// Returns the shared registry, loading the embedded data on first use.
///
/// If the embedded registry fails to load, a warning is printed to stderr and
/// an empty registry is stored instead, so this function never fails.
pub fn registry() -> &'static ModelRegistry {
    REGISTRY.get_or_init(|| match ModelRegistry::load_embedded() {
        Ok(loaded) => loaded,
        Err(err) => {
            eprintln!("Warning: Failed to load model registry: {}", err);
            ModelRegistry::new()
        }
    })
}
/// Splits an Ollama reference of the form `name[:tag]` at the first colon.
///
/// A missing or empty tag becomes `latest`.
fn split_ollama_ref(reference: &str) -> (String, String) {
    let (model, tag) = reference.split_once(':').unwrap_or((reference, "latest"));
    let tag = if tag.is_empty() { "latest" } else { tag };
    (model.to_string(), tag.to_string())
}

/// Resolves a user-supplied model string to a concrete model source.
///
/// Surrounding whitespace is ignored. Resolution order:
/// 1. Path-like input becomes [`ResolvedModel::LocalPath`].
/// 2. An `hf:` spec becomes [`ResolvedModel::HuggingFace`] without `mmproj`.
/// 3. An explicit `ollama:` prefix becomes [`ResolvedModel::Ollama`].
/// 4. A registry alias becomes [`ResolvedModel::HuggingFace`] with the
///    alias's `mmproj`.
/// 5. Anything `OllamaClient::is_ollama_ref` accepts becomes
///    [`ResolvedModel::Ollama`].
/// 6. Everything else becomes [`ResolvedModel::Unknown`].
pub fn resolve_model_name(name: &str) -> ResolvedModel {
    // Trim once so every branch sees the same text that `parse_spec` (which
    // trims internally) classified.
    let name = name.trim();
    let reg = registry();
    let spec = reg.parse_spec(name);

    if spec.is_local_path {
        return ResolvedModel::LocalPath(PathBuf::from(&spec.name));
    }
    if spec.is_hf_spec {
        return ResolvedModel::HuggingFace {
            spec: spec.original,
            mmproj: None,
        };
    }
    if let Some(ollama_name) = name.strip_prefix("ollama:") {
        let (name, tag) = split_ollama_ref(ollama_name);
        return ResolvedModel::Ollama { name, tag };
    }
    if let Some(alias) = spec.alias.as_ref() {
        let hf_spec = reg
            .to_hf_spec(&spec)
            .unwrap_or_else(|| format!("hf:{}", alias.repo));
        return ResolvedModel::HuggingFace {
            spec: hf_spec,
            mmproj: alias.mmproj.clone(),
        };
    }
    if super::ollama::OllamaClient::is_ollama_ref(name) {
        let (name, tag) = split_ollama_ref(name);
        return ResolvedModel::Ollama { name, tag };
    }
    ResolvedModel::Unknown(name.to_string())
}
/// Where a model name was resolved to.
#[derive(Debug, Clone)]
pub enum ResolvedModel {
    /// A file on the local filesystem.
    LocalPath(PathBuf),
    /// A Hugging Face spec string plus an optional `mmproj` value.
    HuggingFace {
        spec: String,
        mmproj: Option<String>,
    },
    /// An Ollama model identified by name and tag.
    Ollama { name: String, tag: String },
    /// Input that matched none of the known sources; holds the input text.
    Unknown(String),
}

impl ResolvedModel {
    /// Returns `true` for [`ResolvedModel::LocalPath`].
    pub fn is_local(&self) -> bool {
        match self {
            Self::LocalPath(_) => true,
            Self::HuggingFace { .. } | Self::Ollama { .. } | Self::Unknown(_) => false,
        }
    }

    /// Returns `true` for [`ResolvedModel::HuggingFace`].
    pub fn is_hf(&self) -> bool {
        match self {
            Self::HuggingFace { .. } => true,
            Self::LocalPath(_) | Self::Ollama { .. } | Self::Unknown(_) => false,
        }
    }

    /// Returns `true` for [`ResolvedModel::Ollama`].
    pub fn is_ollama(&self) -> bool {
        match self {
            Self::Ollama { .. } => true,
            Self::LocalPath(_) | Self::HuggingFace { .. } | Self::Unknown(_) => false,
        }
    }

    /// Returns `true` for [`ResolvedModel::Unknown`].
    pub fn is_unknown(&self) -> bool {
        match self {
            Self::Unknown(_) => true,
            Self::LocalPath(_) | Self::HuggingFace { .. } | Self::Ollama { .. } => false,
        }
    }

    /// Returns `name:tag` for an Ollama model and `None` for every other variant.
    pub fn ollama_name(&self) -> Option<String> {
        if let Self::Ollama { name, tag } = self {
            Some(format!("{}:{}", name, tag))
        } else {
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the registry embedded in the binary; panics if it fails to parse.
    fn embedded() -> ModelRegistry {
        ModelRegistry::load_embedded().unwrap()
    }

    // The embedded registry must parse and define at least one alias.
    #[test]
    fn test_load_embedded_registry() {
        assert!(!embedded().aliases.is_empty());
    }

    // A plain alias name resolves to its alias and is neither hf nor local.
    #[test]
    fn test_parse_alias() {
        let parsed = embedded().parse_spec("llama3.2:1b");
        assert_eq!(parsed.name, "llama3.2:1b");
        assert!(parsed.alias.is_some());
        assert!(!parsed.is_hf_spec);
        assert!(!parsed.is_local_path);
    }

    // A trailing `-q8` is split off into the quantization field.
    #[test]
    fn test_parse_with_quantization() {
        let parsed = embedded().parse_spec("llama3.2:1b-q8");
        assert_eq!(parsed.name, "llama3.2:1b");
        assert_eq!(parsed.quantization, Some("q8".to_string()));
    }

    // `hf:` input is flagged as an hf spec and never matched against aliases.
    #[test]
    fn test_parse_hf_spec() {
        let parsed = embedded().parse_spec("hf:TheBloke/Llama-2-7B-GGUF");
        assert!(parsed.is_hf_spec);
        assert!(!parsed.is_local_path);
        assert!(parsed.alias.is_none());
    }

    // Relative, absolute and bare `.gguf` names all count as local paths.
    #[test]
    fn test_parse_local_path() {
        let reg = embedded();
        for path in ["./model.gguf", "/home/user/model.gguf", "model.gguf"] {
            assert!(reg.parse_spec(path).is_local_path);
        }
    }

    // Both queries are expected to match something in the embedded data.
    #[test]
    fn test_search() {
        let reg = embedded();
        for query in ["llama", "coding"] {
            assert!(!reg.search(query).is_empty());
        }
    }

    #[test]
    fn test_resolve_alias() {
        assert!(resolve_model_name("llama3.2:1b").is_hf());
    }

    #[test]
    fn test_resolve_local() {
        assert!(resolve_model_name("./model.gguf").is_local());
    }

    #[test]
    fn test_resolve_ollama_user_model() {
        assert!(resolve_model_name("user/model:tag").is_ollama());
    }
}