Skip to main content

ferrum_models/
registry.rs

1//! Model registry and alias management
2
3use ferrum_types::Result;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
/// A short alias paired with the model identifier it stands for.
///
/// This is the value type accepted by `DefaultModelRegistry::add_alias` and
/// produced by `DefaultModelRegistry::list_aliases`.
#[derive(Debug, Clone)]
pub struct ModelAlias {
    /// Short name a caller may use in place of the full identifier
    /// (e.g. `"qwen3-4b"`).
    pub name: String,
    /// Model identifier the alias resolves to (e.g. `"Qwen/Qwen3-4B"`).
    pub target: String,
    /// Optional free-form description of the alias.
    pub description: Option<String>,
}
18
/// Model architecture families the registry can tell apart.
///
/// Values are normally produced by `Architecture::from_str` from the
/// `model_type` / `architectures` strings found in a model's `config.json`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum Architecture {
    /// Llama family (`llama`, `LlamaForCausalLM`).
    Llama,
    /// Qwen2 family (`qwen2`, `Qwen2ForCausalLM`).
    Qwen2,
    /// Dense Qwen3 family (`qwen3`, `Qwen3ForCausalLM`).
    Qwen3,
    /// Qwen3-MoE family (Qwen3-30B-A3B and friends). Distinct from Qwen3
    /// because the FFN per layer is replaced by a router + N experts.
    Qwen3Moe,
    /// Mistral family (`mistral`, `MistralForCausalLM`).
    Mistral,
    /// Phi family (`phi`, `PhiForCausalLM`).
    Phi,
    /// GPT-2 (`gpt2`, `GPT2LMHeadModel`).
    GPT2,
    /// BERT encoders (`bert`, `BertModel`, `BertForMaskedLM`,
    /// `BertForSequenceClassification`).
    Bert,
    /// CLIP-style models; `from_str` also maps Chinese-CLIP and SigLIP here.
    Clip,
    /// Whisper (`whisper`, `WhisperForConditionalGeneration`).
    Whisper,
    /// Qwen3 TTS (`qwen3_tts`, `Qwen3TTSForConditionalGeneration`).
    Qwen3TTS,
    /// Fallback for any string `from_str` does not recognize.
    Unknown,
}
37
38impl Architecture {
39    pub fn from_str(s: &str) -> Self {
40        match s.to_lowercase().as_str() {
41            "llama" | "llamaforcausallm" => Architecture::Llama,
42            "qwen2" | "qwen2forcausallm" => Architecture::Qwen2,
43            "qwen3" | "qwen3forcausallm" => Architecture::Qwen3,
44            "qwen3_moe" | "qwen3moe" | "qwen3moeforcausallm" => Architecture::Qwen3Moe,
45            "mistral" | "mistralforcausallm" => Architecture::Mistral,
46            "phi" | "phiforcausallm" => Architecture::Phi,
47            "gpt2" | "gpt2lmheadmodel" => Architecture::GPT2,
48            "bert" | "bertmodel" | "bertformaskedlm" | "bertforsequenceclassification" => {
49                Architecture::Bert
50            }
51            "clip" | "clipmodel" => Architecture::Clip,
52            "chinese_clip" | "chineseclipmodel" => Architecture::Clip,
53            "siglip" | "siglipmodel" => Architecture::Clip,
54            "whisper" | "whisperforconditionalgeneration" => Architecture::Whisper,
55            "qwen3_tts" | "qwen3ttsforconditionalgeneration" => Architecture::Qwen3TTS,
56            _ => Architecture::Unknown,
57        }
58    }
59}
60
/// On-disk weight format of a model directory, as classified by
/// `DefaultModelRegistry::detect_model_format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormatType {
    /// `model.safetensors` or `model.safetensors.index.json` is present.
    SafeTensors,
    /// `pytorch_model.bin` or `pytorch_model.bin.index.json` is present.
    PyTorch,
    /// At least one file with a `.gguf` extension is present.
    GGUF,
    /// None of the known weight files were found.
    Unknown,
}
69
/// One model found on disk by `DefaultModelRegistry::discover_models`.
#[derive(Debug, Clone)]
pub struct ModelDiscoveryEntry {
    /// Model identifier derived from the directory layout (see
    /// `inspect_model_dir`); `"unknown"` when no usable name can be extracted.
    pub id: String,
    /// Local directory that contains the model files.
    pub path: PathBuf,
    /// Weight format detected in `path`.
    pub format: ModelFormatType,
    /// Architecture read from `config.json`, or `None` when the file could
    /// not be read or parsed, or carries no architecture field.
    pub architecture: Option<Architecture>,
    /// Whether the model passes validation. Discovery currently always sets
    /// this to `true` for the entries it returns.
    pub is_valid: bool,
}
84
/// Registry that maps short alias names to model identifiers and remembers
/// the models found by the most recent directory scan.
#[derive(Debug)]
pub struct DefaultModelRegistry {
    /// Alias name -> target model identifier.
    aliases: HashMap<String, String>,
    /// Entries from the last `discover_models` call that got past its
    /// initial directory check; each such call replaces the previous contents.
    discovered_models: Vec<ModelDiscoveryEntry>,
}
93
94impl DefaultModelRegistry {
95    /// Create new empty registry
96    pub fn new() -> Self {
97        Self {
98            aliases: HashMap::new(),
99            discovered_models: Vec::new(),
100        }
101    }
102
103    /// Create registry with common aliases
104    pub fn with_defaults() -> Self {
105        let mut registry = Self::new();
106
107        // Common model aliases
108        registry.register_alias("tinyllama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0");
109        registry.register_alias("llama2-7b", "meta-llama/Llama-2-7b-hf");
110        registry.register_alias("llama2-7b-chat", "meta-llama/Llama-2-7b-chat-hf");
111        registry.register_alias("llama3-8b", "meta-llama/Meta-Llama-3-8B");
112        registry.register_alias("llama3-8b-instruct", "meta-llama/Meta-Llama-3-8B-Instruct");
113        registry.register_alias("qwen2-7b", "Qwen/Qwen2-7B");
114        registry.register_alias("qwen2-7b-instruct", "Qwen/Qwen2-7B-Instruct");
115        registry.register_alias("qwen3-0.6b", "Qwen/Qwen3-0.6B");
116        registry.register_alias("qwen3-1.7b", "Qwen/Qwen3-1.7B");
117        registry.register_alias("qwen3-4b", "Qwen/Qwen3-4B");
118        registry.register_alias("mistral-7b", "mistralai/Mistral-7B-v0.1");
119        registry.register_alias("mistral-7b-instruct", "mistralai/Mistral-7B-Instruct-v0.2");
120        registry.register_alias("phi3-mini", "microsoft/Phi-3-mini-4k-instruct");
121
122        // Whisper ASR models
123        registry.register_alias("whisper-tiny", "openai/whisper-tiny");
124        registry.register_alias("whisper-base", "openai/whisper-base");
125        registry.register_alias("whisper-small", "openai/whisper-small");
126        registry.register_alias("whisper-medium", "openai/whisper-medium");
127        registry.register_alias("whisper-large-v3", "openai/whisper-large-v3");
128        registry.register_alias("whisper-turbo", "openai/whisper-large-v3-turbo");
129        registry.register_alias("whisper-large-v3-turbo", "openai/whisper-large-v3-turbo");
130
131        registry
132    }
133
134    /// Register a model alias
135    pub fn register_alias(&mut self, alias: impl Into<String>, target: impl Into<String>) {
136        let alias_str = alias.into();
137        let target_str = target.into();
138        debug!("Registering alias: {} -> {}", alias_str, target_str);
139        self.aliases.insert(alias_str, target_str);
140    }
141
142    /// Add alias from struct
143    pub fn add_alias(&mut self, alias: ModelAlias) -> Result<()> {
144        self.register_alias(alias.name, alias.target);
145        Ok(())
146    }
147
148    /// Resolve model ID through aliases
149    pub fn resolve_model_id(&self, name: &str) -> String {
150        self.aliases
151            .get(name)
152            .cloned()
153            .unwrap_or_else(|| name.to_string())
154    }
155
156    /// List all registered aliases
157    pub fn list_aliases(&self) -> Vec<ModelAlias> {
158        self.aliases
159            .iter()
160            .map(|(name, target)| ModelAlias {
161                name: name.clone(),
162                target: target.clone(),
163                description: None,
164            })
165            .collect()
166    }
167
168    /// Discover models in a directory
169    pub async fn discover_models(&mut self, root: &Path) -> Result<Vec<ModelDiscoveryEntry>> {
170        if !root.exists() || !root.is_dir() {
171            return Ok(Vec::new());
172        }
173
174        info!("Discovering models in: {:?}", root);
175
176        let mut discovered = Vec::new();
177
178        // First check if root itself is a model directory
179        if let Some(model_entry) = self.inspect_model_dir(root).await {
180            discovered.push(model_entry);
181        } else {
182            // Otherwise scan subdirectories
183            if let Ok(entries) = std::fs::read_dir(root) {
184                for entry in entries.filter_map(|e| e.ok()) {
185                    let path = entry.path();
186                    if path.is_dir() {
187                        if let Some(model_entry) = self.inspect_model_dir(&path).await {
188                            discovered.push(model_entry);
189                        }
190                    }
191                }
192            }
193        }
194
195        self.discovered_models = discovered.clone();
196        Ok(discovered)
197    }
198
199    /// Inspect a directory to see if it contains a model
200    async fn inspect_model_dir(&self, path: &Path) -> Option<ModelDiscoveryEntry> {
201        // Check for config.json
202        let config_path = path.join("config.json");
203        if !config_path.exists() {
204            debug!("No config.json in: {:?}", path);
205            return None;
206        }
207
208        // Detect format
209        let format = self.detect_model_format(path);
210        if format == ModelFormatType::Unknown {
211            debug!("Unknown format in: {:?}", path);
212            return None;
213        }
214
215        debug!("Found valid model at: {:?}, format: {:?}", path, format);
216
217        // Try to read architecture from config
218        let architecture = self.read_architecture(&config_path);
219
220        // Extract model ID from path - try to get friendly name from parent directory
221        let id = if let Some(parent) = path.parent() {
222            if let Some(grandparent) = parent.parent() {
223                // Extract from models--org--name format
224                if let Some(name) = grandparent.file_name().and_then(|n| n.to_str()) {
225                    if name.starts_with("models--") {
226                        name[8..].replace("--", "/")
227                    } else {
228                        path.file_name()
229                            .and_then(|n| n.to_str())
230                            .unwrap_or("unknown")
231                            .to_string()
232                    }
233                } else {
234                    path.file_name()
235                        .and_then(|n| n.to_str())
236                        .unwrap_or("unknown")
237                        .to_string()
238                }
239            } else {
240                path.file_name()
241                    .and_then(|n| n.to_str())
242                    .unwrap_or("unknown")
243                    .to_string()
244            }
245        } else {
246            path.file_name()
247                .and_then(|n| n.to_str())
248                .unwrap_or("unknown")
249                .to_string()
250        };
251
252        Some(ModelDiscoveryEntry {
253            id,
254            path: path.to_path_buf(),
255            format,
256            architecture,
257            is_valid: true,
258        })
259    }
260
261    /// Detect model format in directory
262    fn detect_model_format(&self, path: &Path) -> ModelFormatType {
263        if path.join("model.safetensors").exists()
264            || path.join("model.safetensors.index.json").exists()
265        {
266            ModelFormatType::SafeTensors
267        } else if path.join("pytorch_model.bin").exists()
268            || path.join("pytorch_model.bin.index.json").exists()
269        {
270            ModelFormatType::PyTorch
271        } else if std::fs::read_dir(path)
272            .ok()
273            .and_then(|entries| {
274                entries
275                    .filter_map(|e| e.ok())
276                    .find(|e| e.path().extension().and_then(|s| s.to_str()) == Some("gguf"))
277            })
278            .is_some()
279        {
280            ModelFormatType::GGUF
281        } else {
282            ModelFormatType::Unknown
283        }
284    }
285
286    /// Read architecture type from config.json
287    fn read_architecture(&self, config_path: &Path) -> Option<Architecture> {
288        let content = std::fs::read_to_string(config_path).ok()?;
289        let config: serde_json::Value = serde_json::from_str(&content).ok()?;
290
291        // Try "model_type" field
292        if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
293            return Some(Architecture::from_str(model_type));
294        }
295
296        // Try "architectures" array
297        if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
298            if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
299                return Some(Architecture::from_str(arch));
300            }
301        }
302
303        None
304    }
305}
306
307impl Default for DefaultModelRegistry {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
313// ============================================================================
// Inline unit tests
315// ============================================================================
316
#[cfg(test)]
mod tests {
    //! Unit tests for the alias registry and its plain data types.
    //! Filesystem discovery is not exercised here.

    use super::*;

    #[test]
    fn test_architecture_from_str() {
        assert_eq!(Architecture::from_str("llama"), Architecture::Llama);
        // Matching is case-insensitive.
        assert_eq!(
            Architecture::from_str("LlamaForCausalLM"),
            Architecture::Llama
        );
        assert_eq!(Architecture::from_str("qwen2"), Architecture::Qwen2);
        assert_eq!(Architecture::from_str("qwen3"), Architecture::Qwen3);
        assert_eq!(Architecture::from_str("qwen3_moe"), Architecture::Qwen3Moe);
        assert_eq!(
            Architecture::from_str("Qwen3MoeForCausalLM"),
            Architecture::Qwen3Moe
        );
        assert_eq!(Architecture::from_str("mistral"), Architecture::Mistral);
        assert_eq!(Architecture::from_str("phi"), Architecture::Phi);
        assert_eq!(Architecture::from_str("gpt2"), Architecture::GPT2);
        assert_eq!(
            Architecture::from_str("BertForMaskedLM"),
            Architecture::Bert
        );
        assert_eq!(Architecture::from_str("whisper"), Architecture::Whisper);
        assert_eq!(
            Architecture::from_str("Qwen3TTSForConditionalGeneration"),
            Architecture::Qwen3TTS
        );
        assert_eq!(
            Architecture::from_str("unknown_arch"),
            Architecture::Unknown
        );
        assert_eq!(Architecture::from_str(""), Architecture::Unknown);
    }

    #[test]
    fn test_architecture_clip_family() {
        // CLIP, Chinese-CLIP and SigLIP all map to the same variant.
        for name in ["clip", "CLIPModel", "chinese_clip", "siglip", "SiglipModel"] {
            assert_eq!(Architecture::from_str(name), Architecture::Clip);
        }
    }

    #[test]
    fn test_architecture_copy() {
        let arch = Architecture::Llama;
        let arch2 = arch;
        assert_eq!(arch, arch2);
    }

    #[test]
    fn test_model_format_type_eq() {
        assert_eq!(ModelFormatType::SafeTensors, ModelFormatType::SafeTensors);
        assert_ne!(ModelFormatType::SafeTensors, ModelFormatType::PyTorch);
    }

    #[test]
    fn test_model_alias_creation() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: Some("Test model".to_string()),
        };

        assert_eq!(alias.name, "test");
        assert_eq!(alias.target, "test/model");
        assert!(alias.description.is_some());
    }

    #[test]
    fn test_model_alias_clone() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: None,
        };

        let cloned = alias.clone();
        assert_eq!(alias.name, cloned.name);
        assert_eq!(alias.target, cloned.target);
    }

    #[test]
    fn test_model_discovery_entry() {
        let entry = ModelDiscoveryEntry {
            id: "test-model".to_string(),
            path: PathBuf::from("/path/to/model"),
            format: ModelFormatType::SafeTensors,
            architecture: Some(Architecture::Llama),
            is_valid: true,
        };

        assert_eq!(entry.id, "test-model");
        assert_eq!(entry.format, ModelFormatType::SafeTensors);
        assert!(entry.is_valid);
    }

    #[test]
    fn test_registry_creation() {
        let registry = DefaultModelRegistry::new();
        assert!(registry.aliases.is_empty());
        assert!(registry.discovered_models.is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = DefaultModelRegistry::default();
        assert!(registry.aliases.is_empty());
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = DefaultModelRegistry::with_defaults();

        // There should be some default aliases.
        assert!(!registry.aliases.is_empty());

        // Spot-check a few common aliases.
        assert!(registry.aliases.contains_key("tinyllama"));
        assert!(registry.aliases.contains_key("llama2-7b"));
        assert_eq!(
            registry.resolve_model_id("whisper-turbo"),
            "openai/whisper-large-v3-turbo"
        );
    }

    #[test]
    fn test_registry_register_alias() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "test/model");

        assert_eq!(
            registry.aliases.get("test"),
            Some(&"test/model".to_string())
        );
    }

    #[test]
    fn test_registry_register_alias_overwrites() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("m", "org/a");
        registry.register_alias("m", "org/b");

        // Re-registering replaces the target instead of adding an entry.
        assert_eq!(registry.aliases.len(), 1);
        assert_eq!(registry.resolve_model_id("m"), "org/b");
    }

    #[test]
    fn test_registry_add_alias() {
        let mut registry = DefaultModelRegistry::new();

        let added = registry.add_alias(ModelAlias {
            name: "short".to_string(),
            target: "org/long-name".to_string(),
            description: Some("dropped by the registry".to_string()),
        });

        assert!(added.is_ok());
        assert_eq!(registry.resolve_model_id("short"), "org/long-name");

        // The description is not stored, so listing reports `None`.
        let listed = registry.list_aliases();
        assert_eq!(listed.len(), 1);
        assert!(listed[0].description.is_none());
    }

    #[test]
    fn test_registry_resolve_model_id() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("mymodel", "org/actual-model");

        let resolved = registry.resolve_model_id("mymodel");
        assert_eq!(resolved, "org/actual-model");

        // An unregistered alias resolves to the original value.
        let unresolved = registry.resolve_model_id("unknown");
        assert_eq!(unresolved, "unknown");
    }

    #[test]
    fn test_registry_list_aliases() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("model1", "org/model1");
        registry.register_alias("model2", "org/model2");

        let aliases = registry.list_aliases();
        assert_eq!(aliases.len(), 2);

        // Listing order is unspecified, so compare sorted pairs.
        let mut pairs: Vec<(String, String)> = aliases
            .into_iter()
            .map(|alias| (alias.name, alias.target))
            .collect();
        pairs.sort();
        assert_eq!(
            pairs,
            vec![
                ("model1".to_string(), "org/model1".to_string()),
                ("model2".to_string(), "org/model2".to_string()),
            ]
        );
    }

    #[test]
    fn test_architecture_debug() {
        let arch = Architecture::Llama;
        let debug_str = format!("{:?}", arch);
        assert!(debug_str.contains("Llama"));
    }

    #[test]
    fn test_model_format_debug() {
        let format = ModelFormatType::SafeTensors;
        let debug_str = format!("{:?}", format);
        assert!(debug_str.contains("SafeTensors"));
    }

    #[test]
    fn test_model_discovery_entry_clone() {
        let entry = ModelDiscoveryEntry {
            id: "test".to_string(),
            path: PathBuf::from("/path"),
            format: ModelFormatType::GGUF,
            architecture: Some(Architecture::Mistral),
            is_valid: false,
        };

        let cloned = entry.clone();
        assert_eq!(entry.id, cloned.id);
        assert_eq!(entry.format, cloned.format);
        assert_eq!(entry.is_valid, cloned.is_valid);
    }

    #[test]
    fn test_registry_multiple_aliases_same_target() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("alias1", "org/model");
        registry.register_alias("alias2", "org/model");

        assert_eq!(registry.resolve_model_id("alias1"), "org/model");
        assert_eq!(registry.resolve_model_id("alias2"), "org/model");
    }

    #[test]
    fn test_architecture_serialization() {
        let arch = Architecture::Qwen2;
        let json = serde_json::to_string(&arch).unwrap();
        assert!(json.contains("Qwen2"));

        let deserialized: Architecture = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, arch);
    }
}