1use ferrum_types::Result;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
/// A short name that stands in for a full model identifier.
#[derive(Clone, Debug)]
pub struct ModelAlias {
    /// The alias as typed by the user; used as the lookup key.
    pub name: String,
    /// The identifier the alias expands to.
    pub target: String,
    /// Optional free-form description of the alias.
    pub description: Option<String>,
}
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
21pub enum Architecture {
22 Llama,
23 Qwen2,
24 Qwen3,
25 Qwen3Moe,
28 Mistral,
29 Phi,
30 GPT2,
31 Bert,
32 Clip,
33 Whisper,
34 Qwen3TTS,
35 Unknown,
36}
37
38impl Architecture {
39 pub fn from_str(s: &str) -> Self {
40 match s.to_lowercase().as_str() {
41 "llama" | "llamaforcausallm" => Architecture::Llama,
42 "qwen2" | "qwen2forcausallm" => Architecture::Qwen2,
43 "qwen3" | "qwen3forcausallm" => Architecture::Qwen3,
44 "qwen3_moe" | "qwen3moe" | "qwen3moeforcausallm" => Architecture::Qwen3Moe,
45 "mistral" | "mistralforcausallm" => Architecture::Mistral,
46 "phi" | "phiforcausallm" => Architecture::Phi,
47 "gpt2" | "gpt2lmheadmodel" => Architecture::GPT2,
48 "bert" | "bertmodel" | "bertformaskedlm" | "bertforsequenceclassification" => {
49 Architecture::Bert
50 }
51 "clip" | "clipmodel" => Architecture::Clip,
52 "chinese_clip" | "chineseclipmodel" => Architecture::Clip,
53 "siglip" | "siglipmodel" => Architecture::Clip,
54 "whisper" | "whisperforconditionalgeneration" => Architecture::Whisper,
55 "qwen3_tts" | "qwen3ttsforconditionalgeneration" => Architecture::Qwen3TTS,
56 _ => Architecture::Unknown,
57 }
58 }
59}
60
/// On-disk weight format of a model directory.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelFormatType {
    /// SafeTensors weights.
    SafeTensors,
    /// PyTorch weights.
    PyTorch,
    /// GGUF weights.
    GGUF,
    /// No recognized weight files.
    Unknown,
}
69
70#[derive(Debug, Clone)]
72pub struct ModelDiscoveryEntry {
73 pub id: String,
75 pub path: PathBuf,
77 pub format: ModelFormatType,
79 pub architecture: Option<Architecture>,
81 pub is_valid: bool,
83}
84
85#[derive(Debug)]
87pub struct DefaultModelRegistry {
88 aliases: HashMap<String, String>,
90 discovered_models: Vec<ModelDiscoveryEntry>,
92}
93
94impl DefaultModelRegistry {
95 pub fn new() -> Self {
97 Self {
98 aliases: HashMap::new(),
99 discovered_models: Vec::new(),
100 }
101 }
102
103 pub fn with_defaults() -> Self {
105 let mut registry = Self::new();
106
107 registry.register_alias("tinyllama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0");
109 registry.register_alias("llama2-7b", "meta-llama/Llama-2-7b-hf");
110 registry.register_alias("llama2-7b-chat", "meta-llama/Llama-2-7b-chat-hf");
111 registry.register_alias("llama3-8b", "meta-llama/Meta-Llama-3-8B");
112 registry.register_alias("llama3-8b-instruct", "meta-llama/Meta-Llama-3-8B-Instruct");
113 registry.register_alias("qwen2-7b", "Qwen/Qwen2-7B");
114 registry.register_alias("qwen2-7b-instruct", "Qwen/Qwen2-7B-Instruct");
115 registry.register_alias("qwen3-0.6b", "Qwen/Qwen3-0.6B");
116 registry.register_alias("qwen3-1.7b", "Qwen/Qwen3-1.7B");
117 registry.register_alias("qwen3-4b", "Qwen/Qwen3-4B");
118 registry.register_alias("mistral-7b", "mistralai/Mistral-7B-v0.1");
119 registry.register_alias("mistral-7b-instruct", "mistralai/Mistral-7B-Instruct-v0.2");
120 registry.register_alias("phi3-mini", "microsoft/Phi-3-mini-4k-instruct");
121
122 registry.register_alias("whisper-tiny", "openai/whisper-tiny");
124 registry.register_alias("whisper-base", "openai/whisper-base");
125 registry.register_alias("whisper-small", "openai/whisper-small");
126 registry.register_alias("whisper-medium", "openai/whisper-medium");
127 registry.register_alias("whisper-large-v3", "openai/whisper-large-v3");
128 registry.register_alias("whisper-turbo", "openai/whisper-large-v3-turbo");
129 registry.register_alias("whisper-large-v3-turbo", "openai/whisper-large-v3-turbo");
130
131 registry
132 }
133
134 pub fn register_alias(&mut self, alias: impl Into<String>, target: impl Into<String>) {
136 let alias_str = alias.into();
137 let target_str = target.into();
138 debug!("Registering alias: {} -> {}", alias_str, target_str);
139 self.aliases.insert(alias_str, target_str);
140 }
141
142 pub fn add_alias(&mut self, alias: ModelAlias) -> Result<()> {
144 self.register_alias(alias.name, alias.target);
145 Ok(())
146 }
147
148 pub fn resolve_model_id(&self, name: &str) -> String {
150 self.aliases
151 .get(name)
152 .cloned()
153 .unwrap_or_else(|| name.to_string())
154 }
155
156 pub fn list_aliases(&self) -> Vec<ModelAlias> {
158 self.aliases
159 .iter()
160 .map(|(name, target)| ModelAlias {
161 name: name.clone(),
162 target: target.clone(),
163 description: None,
164 })
165 .collect()
166 }
167
168 pub async fn discover_models(&mut self, root: &Path) -> Result<Vec<ModelDiscoveryEntry>> {
170 if !root.exists() || !root.is_dir() {
171 return Ok(Vec::new());
172 }
173
174 info!("Discovering models in: {:?}", root);
175
176 let mut discovered = Vec::new();
177
178 if let Some(model_entry) = self.inspect_model_dir(root).await {
180 discovered.push(model_entry);
181 } else {
182 if let Ok(entries) = std::fs::read_dir(root) {
184 for entry in entries.filter_map(|e| e.ok()) {
185 let path = entry.path();
186 if path.is_dir() {
187 if let Some(model_entry) = self.inspect_model_dir(&path).await {
188 discovered.push(model_entry);
189 }
190 }
191 }
192 }
193 }
194
195 self.discovered_models = discovered.clone();
196 Ok(discovered)
197 }
198
199 async fn inspect_model_dir(&self, path: &Path) -> Option<ModelDiscoveryEntry> {
201 let config_path = path.join("config.json");
203 if !config_path.exists() {
204 debug!("No config.json in: {:?}", path);
205 return None;
206 }
207
208 let format = self.detect_model_format(path);
210 if format == ModelFormatType::Unknown {
211 debug!("Unknown format in: {:?}", path);
212 return None;
213 }
214
215 debug!("Found valid model at: {:?}, format: {:?}", path, format);
216
217 let architecture = self.read_architecture(&config_path);
219
220 let id = if let Some(parent) = path.parent() {
222 if let Some(grandparent) = parent.parent() {
223 if let Some(name) = grandparent.file_name().and_then(|n| n.to_str()) {
225 if name.starts_with("models--") {
226 name[8..].replace("--", "/")
227 } else {
228 path.file_name()
229 .and_then(|n| n.to_str())
230 .unwrap_or("unknown")
231 .to_string()
232 }
233 } else {
234 path.file_name()
235 .and_then(|n| n.to_str())
236 .unwrap_or("unknown")
237 .to_string()
238 }
239 } else {
240 path.file_name()
241 .and_then(|n| n.to_str())
242 .unwrap_or("unknown")
243 .to_string()
244 }
245 } else {
246 path.file_name()
247 .and_then(|n| n.to_str())
248 .unwrap_or("unknown")
249 .to_string()
250 };
251
252 Some(ModelDiscoveryEntry {
253 id,
254 path: path.to_path_buf(),
255 format,
256 architecture,
257 is_valid: true,
258 })
259 }
260
261 fn detect_model_format(&self, path: &Path) -> ModelFormatType {
263 if path.join("model.safetensors").exists()
264 || path.join("model.safetensors.index.json").exists()
265 {
266 ModelFormatType::SafeTensors
267 } else if path.join("pytorch_model.bin").exists()
268 || path.join("pytorch_model.bin.index.json").exists()
269 {
270 ModelFormatType::PyTorch
271 } else if std::fs::read_dir(path)
272 .ok()
273 .and_then(|entries| {
274 entries
275 .filter_map(|e| e.ok())
276 .find(|e| e.path().extension().and_then(|s| s.to_str()) == Some("gguf"))
277 })
278 .is_some()
279 {
280 ModelFormatType::GGUF
281 } else {
282 ModelFormatType::Unknown
283 }
284 }
285
286 fn read_architecture(&self, config_path: &Path) -> Option<Architecture> {
288 let content = std::fs::read_to_string(config_path).ok()?;
289 let config: serde_json::Value = serde_json::from_str(&content).ok()?;
290
291 if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
293 return Some(Architecture::from_str(model_type));
294 }
295
296 if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
298 if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
299 return Some(Architecture::from_str(arch));
300 }
301 }
302
303 None
304 }
305}
306
307impl Default for DefaultModelRegistry {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
#[cfg(test)]
mod tests {
    //! Unit tests for the registry's data types, alias handling and
    //! architecture name mapping.

    use super::*;

    #[test]
    fn test_architecture_from_str() {
        assert_eq!(Architecture::from_str("llama"), Architecture::Llama);
        assert_eq!(
            Architecture::from_str("LlamaForCausalLM"),
            Architecture::Llama
        );
        assert_eq!(Architecture::from_str("qwen2"), Architecture::Qwen2);
        assert_eq!(Architecture::from_str("qwen3"), Architecture::Qwen3);
        assert_eq!(Architecture::from_str("qwen3_moe"), Architecture::Qwen3Moe);
        assert_eq!(Architecture::from_str("mistral"), Architecture::Mistral);
        assert_eq!(Architecture::from_str("phi"), Architecture::Phi);
        assert_eq!(Architecture::from_str("gpt2"), Architecture::GPT2);
        assert_eq!(Architecture::from_str("BertModel"), Architecture::Bert);
        assert_eq!(Architecture::from_str("clip"), Architecture::Clip);
        assert_eq!(Architecture::from_str("chinese_clip"), Architecture::Clip);
        assert_eq!(Architecture::from_str("siglip"), Architecture::Clip);
        assert_eq!(Architecture::from_str("whisper"), Architecture::Whisper);
        assert_eq!(Architecture::from_str("qwen3_tts"), Architecture::Qwen3TTS);
        assert_eq!(
            Architecture::from_str("unknown_arch"),
            Architecture::Unknown
        );
    }

    #[test]
    fn test_architecture_copy() {
        let arch = Architecture::Llama;
        let arch2 = arch;
        // `arch` is still usable after the assignment because the type is `Copy`.
        assert_eq!(arch, arch2);
    }

    #[test]
    fn test_model_format_type_eq() {
        assert_eq!(ModelFormatType::SafeTensors, ModelFormatType::SafeTensors);
        assert_ne!(ModelFormatType::SafeTensors, ModelFormatType::PyTorch);
    }

    #[test]
    fn test_model_alias_creation() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: Some("Test model".to_string()),
        };

        assert_eq!(alias.name, "test");
        assert_eq!(alias.target, "test/model");
        assert_eq!(alias.description.as_deref(), Some("Test model"));
    }

    #[test]
    fn test_model_alias_clone() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: None,
        };

        let cloned = alias.clone();
        assert_eq!(alias.name, cloned.name);
        assert_eq!(alias.target, cloned.target);
        assert_eq!(alias.description, cloned.description);
    }

    #[test]
    fn test_model_discovery_entry() {
        let entry = ModelDiscoveryEntry {
            id: "test-model".to_string(),
            path: PathBuf::from("/path/to/model"),
            format: ModelFormatType::SafeTensors,
            architecture: Some(Architecture::Llama),
            is_valid: true,
        };

        assert_eq!(entry.id, "test-model");
        assert_eq!(entry.format, ModelFormatType::SafeTensors);
        assert_eq!(entry.architecture, Some(Architecture::Llama));
        assert!(entry.is_valid);
    }

    #[test]
    fn test_registry_creation() {
        let registry = DefaultModelRegistry::new();
        assert!(registry.aliases.is_empty());
        assert!(registry.discovered_models.is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = DefaultModelRegistry::default();
        assert!(registry.aliases.is_empty());
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = DefaultModelRegistry::with_defaults();

        assert!(!registry.aliases.is_empty());

        assert!(registry.aliases.contains_key("tinyllama"));
        assert!(registry.aliases.contains_key("llama2-7b"));
        assert_eq!(
            registry.resolve_model_id("whisper-turbo"),
            "openai/whisper-large-v3-turbo"
        );
    }

    #[test]
    fn test_registry_register_alias() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "test/model");

        assert_eq!(
            registry.aliases.get("test"),
            Some(&"test/model".to_string())
        );
    }

    #[test]
    fn test_registry_add_alias() {
        let mut registry = DefaultModelRegistry::new();

        let outcome = registry.add_alias(ModelAlias {
            name: "short".to_string(),
            target: "org/long-name".to_string(),
            description: Some("ignored by the registry".to_string()),
        });

        assert!(outcome.is_ok());
        assert_eq!(registry.resolve_model_id("short"), "org/long-name");
    }

    #[test]
    fn test_registry_resolve_model_id() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("mymodel", "org/actual-model");

        let resolved = registry.resolve_model_id("mymodel");
        assert_eq!(resolved, "org/actual-model");

        // Names that are not aliases pass through unchanged.
        let unresolved = registry.resolve_model_id("unknown");
        assert_eq!(unresolved, "unknown");
    }

    #[test]
    fn test_registry_list_aliases() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("model1", "org/model1");
        registry.register_alias("model2", "org/model2");

        let aliases = registry.list_aliases();
        assert_eq!(aliases.len(), 2);
        assert!(aliases.iter().all(|a| a.description.is_none()));

        // Compare contents without relying on the order of the listing.
        let mut pairs: Vec<(String, String)> = aliases
            .into_iter()
            .map(|a| (a.name, a.target))
            .collect();
        pairs.sort();
        assert_eq!(
            pairs,
            vec![
                ("model1".to_string(), "org/model1".to_string()),
                ("model2".to_string(), "org/model2".to_string()),
            ]
        );
    }

    #[test]
    fn test_architecture_debug() {
        let arch = Architecture::Llama;
        let debug_str = format!("{:?}", arch);
        assert!(debug_str.contains("Llama"));
    }

    #[test]
    fn test_model_format_debug() {
        let format = ModelFormatType::SafeTensors;
        let debug_str = format!("{:?}", format);
        assert!(debug_str.contains("SafeTensors"));
    }

    #[test]
    fn test_model_discovery_entry_clone() {
        let entry = ModelDiscoveryEntry {
            id: "test".to_string(),
            path: PathBuf::from("/path"),
            format: ModelFormatType::GGUF,
            architecture: Some(Architecture::Mistral),
            is_valid: false,
        };

        let cloned = entry.clone();
        assert_eq!(entry.id, cloned.id);
        assert_eq!(entry.path, cloned.path);
        assert_eq!(entry.format, cloned.format);
        assert_eq!(entry.architecture, cloned.architecture);
        assert_eq!(entry.is_valid, cloned.is_valid);
    }

    #[test]
    fn test_registry_multiple_aliases_same_target() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("alias1", "org/model");
        registry.register_alias("alias2", "org/model");

        assert_eq!(registry.resolve_model_id("alias1"), "org/model");
        assert_eq!(registry.resolve_model_id("alias2"), "org/model");
    }

    #[test]
    fn test_architecture_serialization() {
        let arch = Architecture::Qwen2;
        let json = serde_json::to_string(&arch).unwrap();
        assert!(json.contains("Qwen2"));

        let deserialized: Architecture = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, arch);
    }
}