Skip to main content

ferrum_models/
registry.rs

1//! Model registry and alias management
2
3use ferrum_types::Result;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
/// Model alias entry: a short name that stands in for a full model identifier.
///
/// Note that `DefaultModelRegistry::add_alias` keeps only `name` and `target`;
/// the `description` is not stored by the registry.
#[derive(Debug, Clone)]
pub struct ModelAlias {
    /// Alias name (short name)
    pub name: String,
    /// Target model identifier the alias resolves to
    pub target: String,
    /// Optional human-readable description
    pub description: Option<String>,
}
18
/// Architecture types for models.
///
/// Values are normally produced by `Architecture::from_str`, which matches
/// identifiers case-insensitively; the strings listed on each variant are the
/// lower-cased spellings that function accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum Architecture {
    /// `llama` / `llamaforcausallm`
    Llama,
    /// `qwen2` / `qwen2forcausallm`
    Qwen2,
    /// `qwen3` / `qwen3forcausallm`
    Qwen3,
    /// `mistral` / `mistralforcausallm`
    Mistral,
    /// `phi` / `phiforcausallm`
    Phi,
    /// `gpt2` / `gpt2lmheadmodel`
    GPT2,
    /// `bert` / `bertmodel` / `bertformaskedlm` / `bertforsequenceclassification`
    Bert,
    /// `clip` / `clipmodel`, plus the `chinese_clip` and `siglip` identifiers
    Clip,
    /// `whisper` / `whisperforconditionalgeneration`
    Whisper,
    /// `qwen3_tts` / `qwen3ttsforconditionalgeneration`
    Qwen3TTS,
    /// Fallback for any identifier that is not recognised
    Unknown,
}
34
35impl Architecture {
36    pub fn from_str(s: &str) -> Self {
37        match s.to_lowercase().as_str() {
38            "llama" | "llamaforcausallm" => Architecture::Llama,
39            "qwen2" | "qwen2forcausallm" => Architecture::Qwen2,
40            "qwen3" | "qwen3forcausallm" => Architecture::Qwen3,
41            "mistral" | "mistralforcausallm" => Architecture::Mistral,
42            "phi" | "phiforcausallm" => Architecture::Phi,
43            "gpt2" | "gpt2lmheadmodel" => Architecture::GPT2,
44            "bert" | "bertmodel" | "bertformaskedlm" | "bertforsequenceclassification" => {
45                Architecture::Bert
46            }
47            "clip" | "clipmodel" => Architecture::Clip,
48            "chinese_clip" | "chineseclipmodel" => Architecture::Clip,
49            "siglip" | "siglipmodel" => Architecture::Clip,
50            "whisper" | "whisperforconditionalgeneration" => Architecture::Whisper,
51            "qwen3_tts" | "qwen3ttsforconditionalgeneration" => Architecture::Qwen3TTS,
52            _ => Architecture::Unknown,
53        }
54    }
55}
56
/// Model format type, i.e. the on-disk weight format found in a model directory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormatType {
    /// `model.safetensors` or `model.safetensors.index.json` is present
    SafeTensors,
    /// `pytorch_model.bin` or `pytorch_model.bin.index.json` is present
    PyTorch,
    /// The directory contains at least one `*.gguf` file
    GGUF,
    /// None of the known weight files were found
    Unknown,
}
65
/// Discovered model entry, describing one model directory found on disk.
#[derive(Debug, Clone)]
pub struct ModelDiscoveryEntry {
    /// Model identifier (derived from the directory layout during discovery)
    pub id: String,
    /// Local path to model
    pub path: PathBuf,
    /// Model format
    pub format: ModelFormatType,
    /// Architecture type (if detected); `None` when `config.json` could not be
    /// read or names no architecture
    pub architecture: Option<Architecture>,
    /// Whether model passes validation (discovery in this file always sets `true`)
    pub is_valid: bool,
}
80
/// Model registry for managing models and aliases.
///
/// Holds an alias table (short name -> target model identifier) and the
/// results of the most recent successful `discover_models` scan.
#[derive(Debug)]
pub struct DefaultModelRegistry {
    /// Model aliases, keyed by alias name; the value is the target identifier
    aliases: HashMap<String, String>,
    /// Discovered models cache, overwritten by each completed discovery scan
    discovered_models: Vec<ModelDiscoveryEntry>,
}
89
90impl DefaultModelRegistry {
91    /// Create new empty registry
92    pub fn new() -> Self {
93        Self {
94            aliases: HashMap::new(),
95            discovered_models: Vec::new(),
96        }
97    }
98
99    /// Create registry with common aliases
100    pub fn with_defaults() -> Self {
101        let mut registry = Self::new();
102
103        // Common model aliases
104        registry.register_alias("tinyllama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0");
105        registry.register_alias("llama2-7b", "meta-llama/Llama-2-7b-hf");
106        registry.register_alias("llama2-7b-chat", "meta-llama/Llama-2-7b-chat-hf");
107        registry.register_alias("llama3-8b", "meta-llama/Meta-Llama-3-8B");
108        registry.register_alias("llama3-8b-instruct", "meta-llama/Meta-Llama-3-8B-Instruct");
109        registry.register_alias("qwen2-7b", "Qwen/Qwen2-7B");
110        registry.register_alias("qwen2-7b-instruct", "Qwen/Qwen2-7B-Instruct");
111        registry.register_alias("qwen3-0.6b", "Qwen/Qwen3-0.6B");
112        registry.register_alias("qwen3-1.7b", "Qwen/Qwen3-1.7B");
113        registry.register_alias("qwen3-4b", "Qwen/Qwen3-4B");
114        registry.register_alias("mistral-7b", "mistralai/Mistral-7B-v0.1");
115        registry.register_alias("mistral-7b-instruct", "mistralai/Mistral-7B-Instruct-v0.2");
116        registry.register_alias("phi3-mini", "microsoft/Phi-3-mini-4k-instruct");
117
118        // Whisper ASR models
119        registry.register_alias("whisper-tiny", "openai/whisper-tiny");
120        registry.register_alias("whisper-base", "openai/whisper-base");
121        registry.register_alias("whisper-small", "openai/whisper-small");
122        registry.register_alias("whisper-medium", "openai/whisper-medium");
123        registry.register_alias("whisper-large-v3", "openai/whisper-large-v3");
124        registry.register_alias("whisper-turbo", "openai/whisper-large-v3-turbo");
125        registry.register_alias("whisper-large-v3-turbo", "openai/whisper-large-v3-turbo");
126
127        registry
128    }
129
130    /// Register a model alias
131    pub fn register_alias(&mut self, alias: impl Into<String>, target: impl Into<String>) {
132        let alias_str = alias.into();
133        let target_str = target.into();
134        debug!("Registering alias: {} -> {}", alias_str, target_str);
135        self.aliases.insert(alias_str, target_str);
136    }
137
138    /// Add alias from struct
139    pub fn add_alias(&mut self, alias: ModelAlias) -> Result<()> {
140        self.register_alias(alias.name, alias.target);
141        Ok(())
142    }
143
144    /// Resolve model ID through aliases
145    pub fn resolve_model_id(&self, name: &str) -> String {
146        self.aliases
147            .get(name)
148            .cloned()
149            .unwrap_or_else(|| name.to_string())
150    }
151
152    /// List all registered aliases
153    pub fn list_aliases(&self) -> Vec<ModelAlias> {
154        self.aliases
155            .iter()
156            .map(|(name, target)| ModelAlias {
157                name: name.clone(),
158                target: target.clone(),
159                description: None,
160            })
161            .collect()
162    }
163
164    /// Discover models in a directory
165    pub async fn discover_models(&mut self, root: &Path) -> Result<Vec<ModelDiscoveryEntry>> {
166        if !root.exists() || !root.is_dir() {
167            return Ok(Vec::new());
168        }
169
170        info!("Discovering models in: {:?}", root);
171
172        let mut discovered = Vec::new();
173
174        // First check if root itself is a model directory
175        if let Some(model_entry) = self.inspect_model_dir(root).await {
176            discovered.push(model_entry);
177        } else {
178            // Otherwise scan subdirectories
179            if let Ok(entries) = std::fs::read_dir(root) {
180                for entry in entries.filter_map(|e| e.ok()) {
181                    let path = entry.path();
182                    if path.is_dir() {
183                        if let Some(model_entry) = self.inspect_model_dir(&path).await {
184                            discovered.push(model_entry);
185                        }
186                    }
187                }
188            }
189        }
190
191        self.discovered_models = discovered.clone();
192        Ok(discovered)
193    }
194
195    /// Inspect a directory to see if it contains a model
196    async fn inspect_model_dir(&self, path: &Path) -> Option<ModelDiscoveryEntry> {
197        // Check for config.json
198        let config_path = path.join("config.json");
199        if !config_path.exists() {
200            debug!("No config.json in: {:?}", path);
201            return None;
202        }
203
204        // Detect format
205        let format = self.detect_model_format(path);
206        if format == ModelFormatType::Unknown {
207            debug!("Unknown format in: {:?}", path);
208            return None;
209        }
210
211        debug!("Found valid model at: {:?}, format: {:?}", path, format);
212
213        // Try to read architecture from config
214        let architecture = self.read_architecture(&config_path);
215
216        // Extract model ID from path - try to get friendly name from parent directory
217        let id = if let Some(parent) = path.parent() {
218            if let Some(grandparent) = parent.parent() {
219                // Extract from models--org--name format
220                if let Some(name) = grandparent.file_name().and_then(|n| n.to_str()) {
221                    if name.starts_with("models--") {
222                        name[8..].replace("--", "/")
223                    } else {
224                        path.file_name()
225                            .and_then(|n| n.to_str())
226                            .unwrap_or("unknown")
227                            .to_string()
228                    }
229                } else {
230                    path.file_name()
231                        .and_then(|n| n.to_str())
232                        .unwrap_or("unknown")
233                        .to_string()
234                }
235            } else {
236                path.file_name()
237                    .and_then(|n| n.to_str())
238                    .unwrap_or("unknown")
239                    .to_string()
240            }
241        } else {
242            path.file_name()
243                .and_then(|n| n.to_str())
244                .unwrap_or("unknown")
245                .to_string()
246        };
247
248        Some(ModelDiscoveryEntry {
249            id,
250            path: path.to_path_buf(),
251            format,
252            architecture,
253            is_valid: true,
254        })
255    }
256
257    /// Detect model format in directory
258    fn detect_model_format(&self, path: &Path) -> ModelFormatType {
259        if path.join("model.safetensors").exists()
260            || path.join("model.safetensors.index.json").exists()
261        {
262            ModelFormatType::SafeTensors
263        } else if path.join("pytorch_model.bin").exists()
264            || path.join("pytorch_model.bin.index.json").exists()
265        {
266            ModelFormatType::PyTorch
267        } else if std::fs::read_dir(path)
268            .ok()
269            .and_then(|entries| {
270                entries
271                    .filter_map(|e| e.ok())
272                    .find(|e| e.path().extension().and_then(|s| s.to_str()) == Some("gguf"))
273            })
274            .is_some()
275        {
276            ModelFormatType::GGUF
277        } else {
278            ModelFormatType::Unknown
279        }
280    }
281
282    /// Read architecture type from config.json
283    fn read_architecture(&self, config_path: &Path) -> Option<Architecture> {
284        let content = std::fs::read_to_string(config_path).ok()?;
285        let config: serde_json::Value = serde_json::from_str(&content).ok()?;
286
287        // Try "model_type" field
288        if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
289            return Some(Architecture::from_str(model_type));
290        }
291
292        // Try "architectures" array
293        if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
294            if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
295                return Some(Architecture::from_str(arch));
296            }
297        }
298
299        None
300    }
301}
302
303impl Default for DefaultModelRegistry {
304    fn default() -> Self {
305        Self::new()
306    }
307}
308
309// ============================================================================
// Inline unit tests
311// ============================================================================
312
#[cfg(test)]
mod tests {
    //! Unit tests for the architecture mapping, the plain data types and the
    //! alias-handling half of `DefaultModelRegistry`.

    use super::*;

    #[test]
    fn test_architecture_from_str() {
        assert_eq!(Architecture::from_str("llama"), Architecture::Llama);
        assert_eq!(
            Architecture::from_str("LlamaForCausalLM"),
            Architecture::Llama
        );
        assert_eq!(Architecture::from_str("qwen2"), Architecture::Qwen2);
        assert_eq!(
            Architecture::from_str("Qwen3ForCausalLM"),
            Architecture::Qwen3
        );
        assert_eq!(Architecture::from_str("mistral"), Architecture::Mistral);
        assert_eq!(Architecture::from_str("phi"), Architecture::Phi);
        assert_eq!(Architecture::from_str("gpt2"), Architecture::GPT2);
        assert_eq!(Architecture::from_str("bert"), Architecture::Bert);
        // Every CLIP-style identifier maps onto the single Clip variant.
        assert_eq!(Architecture::from_str("clip"), Architecture::Clip);
        assert_eq!(Architecture::from_str("chinese_clip"), Architecture::Clip);
        assert_eq!(Architecture::from_str("SiglipModel"), Architecture::Clip);
        assert_eq!(
            Architecture::from_str("WhisperForConditionalGeneration"),
            Architecture::Whisper
        );
        assert_eq!(
            Architecture::from_str("qwen3_tts"),
            Architecture::Qwen3TTS
        );
        assert_eq!(
            Architecture::from_str("unknown_arch"),
            Architecture::Unknown
        );
        assert_eq!(Architecture::from_str(""), Architecture::Unknown);
    }

    #[test]
    fn test_architecture_copy() {
        let arch = Architecture::Llama;
        let arch2 = arch;
        assert_eq!(arch, arch2);
    }

    #[test]
    fn test_model_format_type_eq() {
        assert_eq!(ModelFormatType::SafeTensors, ModelFormatType::SafeTensors);
        assert_ne!(ModelFormatType::SafeTensors, ModelFormatType::PyTorch);
    }

    #[test]
    fn test_model_alias_creation() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: Some("Test model".to_string()),
        };

        assert_eq!(alias.name, "test");
        assert_eq!(alias.target, "test/model");
        assert_eq!(alias.description.as_deref(), Some("Test model"));
    }

    #[test]
    fn test_model_alias_clone() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: None,
        };

        let cloned = alias.clone();
        assert_eq!(alias.name, cloned.name);
        assert_eq!(alias.target, cloned.target);
    }

    #[test]
    fn test_model_discovery_entry() {
        let entry = ModelDiscoveryEntry {
            id: "test-model".to_string(),
            path: PathBuf::from("/path/to/model"),
            format: ModelFormatType::SafeTensors,
            architecture: Some(Architecture::Llama),
            is_valid: true,
        };

        assert_eq!(entry.id, "test-model");
        assert_eq!(entry.format, ModelFormatType::SafeTensors);
        assert!(entry.is_valid);
    }

    #[test]
    fn test_registry_creation() {
        let registry = DefaultModelRegistry::new();
        assert!(registry.aliases.is_empty());
        assert!(registry.discovered_models.is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = DefaultModelRegistry::default();
        assert!(registry.aliases.is_empty());
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = DefaultModelRegistry::with_defaults();

        // There should be some default aliases
        assert!(!registry.aliases.is_empty());

        // Check a few common aliases
        assert!(registry.aliases.contains_key("tinyllama"));
        assert!(registry.aliases.contains_key("llama2-7b"));
        assert_eq!(
            registry.resolve_model_id("whisper-turbo"),
            "openai/whisper-large-v3-turbo"
        );
    }

    #[test]
    fn test_registry_register_alias() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "test/model");

        assert_eq!(
            registry.aliases.get("test"),
            Some(&"test/model".to_string())
        );
    }

    #[test]
    fn test_registry_add_alias() {
        let mut registry = DefaultModelRegistry::new();

        let outcome = registry.add_alias(ModelAlias {
            name: "short".to_string(),
            target: "org/long-name".to_string(),
            description: Some("ignored by the registry".to_string()),
        });

        assert!(outcome.is_ok());
        assert_eq!(registry.resolve_model_id("short"), "org/long-name");
    }

    #[test]
    fn test_registry_resolve_model_id() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("mymodel", "org/actual-model");

        let resolved = registry.resolve_model_id("mymodel");
        assert_eq!(resolved, "org/actual-model");

        // An unregistered alias should come back unchanged
        let unresolved = registry.resolve_model_id("unknown");
        assert_eq!(unresolved, "unknown");
    }

    #[test]
    fn test_registry_list_aliases() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("model1", "org/model1");
        registry.register_alias("model2", "org/model2");

        let aliases = registry.list_aliases();
        assert_eq!(aliases.len(), 2);

        // Map iteration order is unspecified, so sort before comparing.
        let mut pairs: Vec<(String, String)> = aliases
            .into_iter()
            .map(|alias| (alias.name, alias.target))
            .collect();
        pairs.sort();
        assert_eq!(
            pairs,
            vec![
                ("model1".to_string(), "org/model1".to_string()),
                ("model2".to_string(), "org/model2".to_string()),
            ]
        );
    }

    #[test]
    fn test_architecture_debug() {
        let arch = Architecture::Llama;
        let debug_str = format!("{:?}", arch);
        assert!(debug_str.contains("Llama"));
    }

    #[test]
    fn test_model_format_debug() {
        let format = ModelFormatType::SafeTensors;
        let debug_str = format!("{:?}", format);
        assert!(debug_str.contains("SafeTensors"));
    }

    #[test]
    fn test_model_discovery_entry_clone() {
        let entry = ModelDiscoveryEntry {
            id: "test".to_string(),
            path: PathBuf::from("/path"),
            format: ModelFormatType::GGUF,
            architecture: Some(Architecture::Mistral),
            is_valid: false,
        };

        let cloned = entry.clone();
        assert_eq!(entry.id, cloned.id);
        assert_eq!(entry.format, cloned.format);
        assert_eq!(entry.is_valid, cloned.is_valid);
    }

    #[test]
    fn test_registry_multiple_aliases_same_target() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("alias1", "org/model");
        registry.register_alias("alias2", "org/model");

        assert_eq!(registry.resolve_model_id("alias1"), "org/model");
        assert_eq!(registry.resolve_model_id("alias2"), "org/model");
    }

    #[test]
    fn test_architecture_serialization() {
        let arch = Architecture::Qwen2;
        let json = serde_json::to_string(&arch).unwrap();
        assert!(json.contains("Qwen2"));

        let deserialized: Architecture = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, arch);
    }
}