1use ferrum_types::Result;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
/// A short name that stands in for a full model identifier.
#[derive(Debug, Clone)]
pub struct ModelAlias {
    /// The alias itself; used as the lookup key in the registry (e.g. `tinyllama`).
    pub name: String,
    /// Identifier the alias resolves to (e.g. `TinyLlama/TinyLlama-1.1B-Chat-v1.0`).
    pub target: String,
    /// Optional human-readable description of the alias.
    pub description: Option<String>,
}
18
/// Model architecture families this module can tell apart.
///
/// Values are normally produced by `Architecture::from_str` from a
/// `model_type` or architecture class name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum Architecture {
    /// `llama` / `LlamaForCausalLM`.
    Llama,
    /// `qwen2` / `Qwen2ForCausalLM`.
    Qwen2,
    /// `qwen3` / `Qwen3ForCausalLM`.
    Qwen3,
    /// `mistral` / `MistralForCausalLM`.
    Mistral,
    /// `phi` / `PhiForCausalLM`.
    Phi,
    /// `gpt2` / `GPT2LMHeadModel`.
    GPT2,
    /// `bert` and its `BertModel`, `BertForMaskedLM`, `BertForSequenceClassification` classes.
    Bert,
    /// `clip`, `chinese_clip` and `siglip` models all map to this variant.
    Clip,
    /// `whisper` / `WhisperForConditionalGeneration`.
    Whisper,
    /// `qwen3_tts` / `Qwen3TTSForConditionalGeneration`.
    Qwen3TTS,
    /// Any identifier that none of the names above match.
    Unknown,
}
34
35impl Architecture {
36 pub fn from_str(s: &str) -> Self {
37 match s.to_lowercase().as_str() {
38 "llama" | "llamaforcausallm" => Architecture::Llama,
39 "qwen2" | "qwen2forcausallm" => Architecture::Qwen2,
40 "qwen3" | "qwen3forcausallm" => Architecture::Qwen3,
41 "mistral" | "mistralforcausallm" => Architecture::Mistral,
42 "phi" | "phiforcausallm" => Architecture::Phi,
43 "gpt2" | "gpt2lmheadmodel" => Architecture::GPT2,
44 "bert" | "bertmodel" | "bertformaskedlm" | "bertforsequenceclassification" => {
45 Architecture::Bert
46 }
47 "clip" | "clipmodel" => Architecture::Clip,
48 "chinese_clip" | "chineseclipmodel" => Architecture::Clip,
49 "siglip" | "siglipmodel" => Architecture::Clip,
50 "whisper" | "whisperforconditionalgeneration" => Architecture::Whisper,
51 "qwen3_tts" | "qwen3ttsforconditionalgeneration" => Architecture::Qwen3TTS,
52 _ => Architecture::Unknown,
53 }
54 }
55}
56
/// On-disk weight format of a model directory, as determined by
/// `DefaultModelRegistry::detect_model_format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormatType {
    /// `model.safetensors` or `model.safetensors.index.json` is present.
    SafeTensors,
    /// `pytorch_model.bin` or `pytorch_model.bin.index.json` is present.
    PyTorch,
    /// The directory contains at least one file with a `gguf` extension.
    GGUF,
    /// None of the known weight files were found.
    Unknown,
}
65
/// One model directory found by `DefaultModelRegistry::discover_models`.
#[derive(Debug, Clone)]
pub struct ModelDiscoveryEntry {
    /// Model identifier: `<org>/<name>` when the directory sits two levels below a
    /// `models--<org>--<name>` folder, otherwise the directory's own name.
    pub id: String,
    /// Directory that contains the model's `config.json`.
    pub path: PathBuf,
    /// Weight format detected in `path`.
    pub format: ModelFormatType,
    /// Architecture read from `config.json`; `None` when the file cannot be read or
    /// parsed, or carries neither `model_type` nor `architectures`.
    pub architecture: Option<Architecture>,
    /// Validity flag; discovery in this module always sets it to `true`.
    pub is_valid: bool,
}
80
/// Registry of model aliases plus the results of the latest on-disk model scan.
#[derive(Debug)]
pub struct DefaultModelRegistry {
    /// Alias name -> target model identifier.
    aliases: HashMap<String, String>,
    /// Entries recorded by the most recent `discover_models` scan.
    discovered_models: Vec<ModelDiscoveryEntry>,
}
89
90impl DefaultModelRegistry {
91 pub fn new() -> Self {
93 Self {
94 aliases: HashMap::new(),
95 discovered_models: Vec::new(),
96 }
97 }
98
99 pub fn with_defaults() -> Self {
101 let mut registry = Self::new();
102
103 registry.register_alias("tinyllama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0");
105 registry.register_alias("llama2-7b", "meta-llama/Llama-2-7b-hf");
106 registry.register_alias("llama2-7b-chat", "meta-llama/Llama-2-7b-chat-hf");
107 registry.register_alias("llama3-8b", "meta-llama/Meta-Llama-3-8B");
108 registry.register_alias("llama3-8b-instruct", "meta-llama/Meta-Llama-3-8B-Instruct");
109 registry.register_alias("qwen2-7b", "Qwen/Qwen2-7B");
110 registry.register_alias("qwen2-7b-instruct", "Qwen/Qwen2-7B-Instruct");
111 registry.register_alias("qwen3-0.6b", "Qwen/Qwen3-0.6B");
112 registry.register_alias("qwen3-1.7b", "Qwen/Qwen3-1.7B");
113 registry.register_alias("qwen3-4b", "Qwen/Qwen3-4B");
114 registry.register_alias("mistral-7b", "mistralai/Mistral-7B-v0.1");
115 registry.register_alias("mistral-7b-instruct", "mistralai/Mistral-7B-Instruct-v0.2");
116 registry.register_alias("phi3-mini", "microsoft/Phi-3-mini-4k-instruct");
117
118 registry.register_alias("whisper-tiny", "openai/whisper-tiny");
120 registry.register_alias("whisper-base", "openai/whisper-base");
121 registry.register_alias("whisper-small", "openai/whisper-small");
122 registry.register_alias("whisper-medium", "openai/whisper-medium");
123 registry.register_alias("whisper-large-v3", "openai/whisper-large-v3");
124 registry.register_alias("whisper-turbo", "openai/whisper-large-v3-turbo");
125 registry.register_alias("whisper-large-v3-turbo", "openai/whisper-large-v3-turbo");
126
127 registry
128 }
129
130 pub fn register_alias(&mut self, alias: impl Into<String>, target: impl Into<String>) {
132 let alias_str = alias.into();
133 let target_str = target.into();
134 debug!("Registering alias: {} -> {}", alias_str, target_str);
135 self.aliases.insert(alias_str, target_str);
136 }
137
138 pub fn add_alias(&mut self, alias: ModelAlias) -> Result<()> {
140 self.register_alias(alias.name, alias.target);
141 Ok(())
142 }
143
144 pub fn resolve_model_id(&self, name: &str) -> String {
146 self.aliases
147 .get(name)
148 .cloned()
149 .unwrap_or_else(|| name.to_string())
150 }
151
152 pub fn list_aliases(&self) -> Vec<ModelAlias> {
154 self.aliases
155 .iter()
156 .map(|(name, target)| ModelAlias {
157 name: name.clone(),
158 target: target.clone(),
159 description: None,
160 })
161 .collect()
162 }
163
164 pub async fn discover_models(&mut self, root: &Path) -> Result<Vec<ModelDiscoveryEntry>> {
166 if !root.exists() || !root.is_dir() {
167 return Ok(Vec::new());
168 }
169
170 info!("Discovering models in: {:?}", root);
171
172 let mut discovered = Vec::new();
173
174 if let Some(model_entry) = self.inspect_model_dir(root).await {
176 discovered.push(model_entry);
177 } else {
178 if let Ok(entries) = std::fs::read_dir(root) {
180 for entry in entries.filter_map(|e| e.ok()) {
181 let path = entry.path();
182 if path.is_dir() {
183 if let Some(model_entry) = self.inspect_model_dir(&path).await {
184 discovered.push(model_entry);
185 }
186 }
187 }
188 }
189 }
190
191 self.discovered_models = discovered.clone();
192 Ok(discovered)
193 }
194
195 async fn inspect_model_dir(&self, path: &Path) -> Option<ModelDiscoveryEntry> {
197 let config_path = path.join("config.json");
199 if !config_path.exists() {
200 debug!("No config.json in: {:?}", path);
201 return None;
202 }
203
204 let format = self.detect_model_format(path);
206 if format == ModelFormatType::Unknown {
207 debug!("Unknown format in: {:?}", path);
208 return None;
209 }
210
211 debug!("Found valid model at: {:?}, format: {:?}", path, format);
212
213 let architecture = self.read_architecture(&config_path);
215
216 let id = if let Some(parent) = path.parent() {
218 if let Some(grandparent) = parent.parent() {
219 if let Some(name) = grandparent.file_name().and_then(|n| n.to_str()) {
221 if name.starts_with("models--") {
222 name[8..].replace("--", "/")
223 } else {
224 path.file_name()
225 .and_then(|n| n.to_str())
226 .unwrap_or("unknown")
227 .to_string()
228 }
229 } else {
230 path.file_name()
231 .and_then(|n| n.to_str())
232 .unwrap_or("unknown")
233 .to_string()
234 }
235 } else {
236 path.file_name()
237 .and_then(|n| n.to_str())
238 .unwrap_or("unknown")
239 .to_string()
240 }
241 } else {
242 path.file_name()
243 .and_then(|n| n.to_str())
244 .unwrap_or("unknown")
245 .to_string()
246 };
247
248 Some(ModelDiscoveryEntry {
249 id,
250 path: path.to_path_buf(),
251 format,
252 architecture,
253 is_valid: true,
254 })
255 }
256
257 fn detect_model_format(&self, path: &Path) -> ModelFormatType {
259 if path.join("model.safetensors").exists()
260 || path.join("model.safetensors.index.json").exists()
261 {
262 ModelFormatType::SafeTensors
263 } else if path.join("pytorch_model.bin").exists()
264 || path.join("pytorch_model.bin.index.json").exists()
265 {
266 ModelFormatType::PyTorch
267 } else if std::fs::read_dir(path)
268 .ok()
269 .and_then(|entries| {
270 entries
271 .filter_map(|e| e.ok())
272 .find(|e| e.path().extension().and_then(|s| s.to_str()) == Some("gguf"))
273 })
274 .is_some()
275 {
276 ModelFormatType::GGUF
277 } else {
278 ModelFormatType::Unknown
279 }
280 }
281
282 fn read_architecture(&self, config_path: &Path) -> Option<Architecture> {
284 let content = std::fs::read_to_string(config_path).ok()?;
285 let config: serde_json::Value = serde_json::from_str(&content).ok()?;
286
287 if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
289 return Some(Architecture::from_str(model_type));
290 }
291
292 if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
294 if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
295 return Some(Architecture::from_str(arch));
296 }
297 }
298
299 None
300 }
301}
302
303impl Default for DefaultModelRegistry {
304 fn default() -> Self {
305 Self::new()
306 }
307}
308
// Unit tests for architecture parsing, the plain data types and the alias registry.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_architecture_from_str() {
        assert_eq!(Architecture::from_str("llama"), Architecture::Llama);
        assert_eq!(
            Architecture::from_str("LlamaForCausalLM"),
            Architecture::Llama
        );
        assert_eq!(Architecture::from_str("qwen2"), Architecture::Qwen2);
        assert_eq!(Architecture::from_str("mistral"), Architecture::Mistral);
        assert_eq!(Architecture::from_str("phi"), Architecture::Phi);
        assert_eq!(Architecture::from_str("gpt2"), Architecture::GPT2);
        assert_eq!(
            Architecture::from_str("unknown_arch"),
            Architecture::Unknown
        );
    }

    #[test]
    fn test_architecture_from_str_remaining_variants() {
        assert_eq!(
            Architecture::from_str("Qwen3ForCausalLM"),
            Architecture::Qwen3
        );
        assert_eq!(
            Architecture::from_str("BertForMaskedLM"),
            Architecture::Bert
        );
        assert_eq!(
            Architecture::from_str("WhisperForConditionalGeneration"),
            Architecture::Whisper
        );
        assert_eq!(Architecture::from_str("qwen3_tts"), Architecture::Qwen3TTS);
        assert_eq!(Architecture::from_str(""), Architecture::Unknown);
    }

    #[test]
    fn test_architecture_from_str_clip_family() {
        // CLIP, Chinese-CLIP and SigLIP identifiers all map to the same variant.
        for name in [
            "clip",
            "CLIPModel",
            "chinese_clip",
            "ChineseCLIPModel",
            "siglip",
            "SiglipModel",
        ] {
            assert_eq!(Architecture::from_str(name), Architecture::Clip, "{name}");
        }
    }

    #[test]
    fn test_architecture_copy() {
        let arch = Architecture::Llama;
        let arch2 = arch;
        assert_eq!(arch, arch2);
    }

    #[test]
    fn test_model_format_type_eq() {
        assert_eq!(ModelFormatType::SafeTensors, ModelFormatType::SafeTensors);
        assert_ne!(ModelFormatType::SafeTensors, ModelFormatType::PyTorch);
    }

    #[test]
    fn test_model_alias_creation() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: Some("Test model".to_string()),
        };

        assert_eq!(alias.name, "test");
        assert_eq!(alias.target, "test/model");
        assert!(alias.description.is_some());
    }

    #[test]
    fn test_model_alias_clone() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: None,
        };

        let cloned = alias.clone();
        assert_eq!(alias.name, cloned.name);
        assert_eq!(alias.target, cloned.target);
    }

    #[test]
    fn test_model_discovery_entry() {
        let entry = ModelDiscoveryEntry {
            id: "test-model".to_string(),
            path: PathBuf::from("/path/to/model"),
            format: ModelFormatType::SafeTensors,
            architecture: Some(Architecture::Llama),
            is_valid: true,
        };

        assert_eq!(entry.id, "test-model");
        assert_eq!(entry.format, ModelFormatType::SafeTensors);
        assert!(entry.is_valid);
    }

    #[test]
    fn test_registry_creation() {
        let registry = DefaultModelRegistry::new();
        assert!(registry.aliases.is_empty());
        assert!(registry.discovered_models.is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = DefaultModelRegistry::default();
        assert!(registry.aliases.is_empty());
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = DefaultModelRegistry::with_defaults();

        assert!(!registry.aliases.is_empty());

        assert!(registry.aliases.contains_key("tinyllama"));
        assert!(registry.aliases.contains_key("llama2-7b"));
        assert_eq!(
            registry.resolve_model_id("tinyllama"),
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        );
    }

    #[test]
    fn test_registry_register_alias() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "test/model");

        assert_eq!(
            registry.aliases.get("test"),
            Some(&"test/model".to_string())
        );
    }

    #[test]
    fn test_registry_register_alias_overwrites() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "org/old");
        registry.register_alias("test", "org/new");

        assert_eq!(registry.aliases.len(), 1);
        assert_eq!(registry.resolve_model_id("test"), "org/new");
    }

    #[test]
    fn test_registry_add_alias() {
        let mut registry = DefaultModelRegistry::new();

        let alias = ModelAlias {
            name: "mymodel".to_string(),
            target: "org/actual-model".to_string(),
            description: Some("ignored by the registry".to_string()),
        };

        assert!(registry.add_alias(alias).is_ok());
        assert_eq!(registry.resolve_model_id("mymodel"), "org/actual-model");
    }

    #[test]
    fn test_registry_resolve_model_id() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("mymodel", "org/actual-model");

        let resolved = registry.resolve_model_id("mymodel");
        assert_eq!(resolved, "org/actual-model");

        // Names without an alias are passed through unchanged.
        let unresolved = registry.resolve_model_id("unknown");
        assert_eq!(unresolved, "unknown");
    }

    #[test]
    fn test_registry_list_aliases() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("model1", "org/model1");
        registry.register_alias("model2", "org/model2");

        let aliases = registry.list_aliases();
        assert_eq!(aliases.len(), 2);
        assert!(aliases.iter().all(|alias| alias.description.is_none()));

        // Compare names independently of the order in which they are returned.
        let mut names: Vec<String> = aliases.into_iter().map(|alias| alias.name).collect();
        names.sort();
        assert_eq!(names, vec!["model1".to_string(), "model2".to_string()]);
    }

    #[test]
    fn test_architecture_debug() {
        let arch = Architecture::Llama;
        let debug_str = format!("{:?}", arch);
        assert!(debug_str.contains("Llama"));
    }

    #[test]
    fn test_model_format_debug() {
        let format = ModelFormatType::SafeTensors;
        let debug_str = format!("{:?}", format);
        assert!(debug_str.contains("SafeTensors"));
    }

    #[test]
    fn test_model_discovery_entry_clone() {
        let entry = ModelDiscoveryEntry {
            id: "test".to_string(),
            path: PathBuf::from("/path"),
            format: ModelFormatType::GGUF,
            architecture: Some(Architecture::Mistral),
            is_valid: false,
        };

        let cloned = entry.clone();
        assert_eq!(entry.id, cloned.id);
        assert_eq!(entry.format, cloned.format);
        assert_eq!(entry.is_valid, cloned.is_valid);
    }

    #[test]
    fn test_registry_multiple_aliases_same_target() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("alias1", "org/model");
        registry.register_alias("alias2", "org/model");

        assert_eq!(registry.resolve_model_id("alias1"), "org/model");
        assert_eq!(registry.resolve_model_id("alias2"), "org/model");
    }

    #[test]
    fn test_architecture_serialization() {
        let arch = Architecture::Qwen2;
        let json = serde_json::to_string(&arch).unwrap();
        assert!(json.contains("Qwen2"));

        let deserialized: Architecture = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, arch);
    }
}