Skip to main content

ferrum_models/
registry.rs

1//! Model registry and alias management
2
3use ferrum_types::Result;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
/// A short, user-friendly name that stands in for a full model identifier.
#[derive(Clone, Debug)]
pub struct ModelAlias {
    /// The alias itself, i.e. the short name a user types.
    pub name: String,
    /// The model identifier this alias points to.
    pub target: String,
    /// Free-form description of the alias, when one was supplied.
    pub description: Option<String>,
}
18
19/// Architecture types for models
20#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
21pub enum Architecture {
22    Llama,
23    Qwen2,
24    Qwen3,
25    Mistral,
26    Phi,
27    GPT2,
28    Bert,
29    Clip,
30    Unknown,
31}
32
33impl Architecture {
34    pub fn from_str(s: &str) -> Self {
35        match s.to_lowercase().as_str() {
36            "llama" | "llamaforcausallm" => Architecture::Llama,
37            "qwen2" | "qwen2forcausallm" => Architecture::Qwen2,
38            "qwen3" | "qwen3forcausallm" => Architecture::Qwen3,
39            "mistral" | "mistralforcausallm" => Architecture::Mistral,
40            "phi" | "phiforcausallm" => Architecture::Phi,
41            "gpt2" | "gpt2lmheadmodel" => Architecture::GPT2,
42            "bert" | "bertmodel" | "bertformaskedlm" | "bertforsequenceclassification" => {
43                Architecture::Bert
44            }
45            "clip" | "clipmodel" => Architecture::Clip,
46            "chinese_clip" | "chineseclipmodel" => Architecture::Clip,
47            "siglip" | "siglipmodel" => Architecture::Clip,
48            _ => Architecture::Unknown,
49        }
50    }
51}
52
/// On-disk weight format of a model directory.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelFormatType {
    /// Safetensors weights.
    SafeTensors,
    /// PyTorch checkpoint weights.
    PyTorch,
    /// GGUF weights.
    GGUF,
    /// The format could not be determined.
    Unknown,
}
61
62/// Discovered model entry
63#[derive(Debug, Clone)]
64pub struct ModelDiscoveryEntry {
65    /// Model identifier
66    pub id: String,
67    /// Local path to model
68    pub path: PathBuf,
69    /// Model format
70    pub format: ModelFormatType,
71    /// Architecture type (if detected)
72    pub architecture: Option<Architecture>,
73    /// Whether model passes validation
74    pub is_valid: bool,
75}
76
77/// Model registry for managing models and aliases
78#[derive(Debug)]
79pub struct DefaultModelRegistry {
80    /// Model aliases
81    aliases: HashMap<String, String>,
82    /// Discovered models cache
83    discovered_models: Vec<ModelDiscoveryEntry>,
84}
85
86impl DefaultModelRegistry {
87    /// Create new empty registry
88    pub fn new() -> Self {
89        Self {
90            aliases: HashMap::new(),
91            discovered_models: Vec::new(),
92        }
93    }
94
95    /// Create registry with common aliases
96    pub fn with_defaults() -> Self {
97        let mut registry = Self::new();
98
99        // Common model aliases
100        registry.register_alias("tinyllama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0");
101        registry.register_alias("llama2-7b", "meta-llama/Llama-2-7b-hf");
102        registry.register_alias("llama2-7b-chat", "meta-llama/Llama-2-7b-chat-hf");
103        registry.register_alias("llama3-8b", "meta-llama/Meta-Llama-3-8B");
104        registry.register_alias("llama3-8b-instruct", "meta-llama/Meta-Llama-3-8B-Instruct");
105        registry.register_alias("qwen2-7b", "Qwen/Qwen2-7B");
106        registry.register_alias("qwen2-7b-instruct", "Qwen/Qwen2-7B-Instruct");
107        registry.register_alias("qwen3-0.6b", "Qwen/Qwen3-0.6B");
108        registry.register_alias("qwen3-1.7b", "Qwen/Qwen3-1.7B");
109        registry.register_alias("qwen3-4b", "Qwen/Qwen3-4B");
110        registry.register_alias("mistral-7b", "mistralai/Mistral-7B-v0.1");
111        registry.register_alias("mistral-7b-instruct", "mistralai/Mistral-7B-Instruct-v0.2");
112        registry.register_alias("phi3-mini", "microsoft/Phi-3-mini-4k-instruct");
113
114        registry
115    }
116
117    /// Register a model alias
118    pub fn register_alias(&mut self, alias: impl Into<String>, target: impl Into<String>) {
119        let alias_str = alias.into();
120        let target_str = target.into();
121        debug!("Registering alias: {} -> {}", alias_str, target_str);
122        self.aliases.insert(alias_str, target_str);
123    }
124
125    /// Add alias from struct
126    pub fn add_alias(&mut self, alias: ModelAlias) -> Result<()> {
127        self.register_alias(alias.name, alias.target);
128        Ok(())
129    }
130
131    /// Resolve model ID through aliases
132    pub fn resolve_model_id(&self, name: &str) -> String {
133        self.aliases
134            .get(name)
135            .cloned()
136            .unwrap_or_else(|| name.to_string())
137    }
138
139    /// List all registered aliases
140    pub fn list_aliases(&self) -> Vec<ModelAlias> {
141        self.aliases
142            .iter()
143            .map(|(name, target)| ModelAlias {
144                name: name.clone(),
145                target: target.clone(),
146                description: None,
147            })
148            .collect()
149    }
150
151    /// Discover models in a directory
152    pub async fn discover_models(&mut self, root: &Path) -> Result<Vec<ModelDiscoveryEntry>> {
153        if !root.exists() || !root.is_dir() {
154            return Ok(Vec::new());
155        }
156
157        info!("Discovering models in: {:?}", root);
158
159        let mut discovered = Vec::new();
160
161        // First check if root itself is a model directory
162        if let Some(model_entry) = self.inspect_model_dir(root).await {
163            discovered.push(model_entry);
164        } else {
165            // Otherwise scan subdirectories
166            if let Ok(entries) = std::fs::read_dir(root) {
167                for entry in entries.filter_map(|e| e.ok()) {
168                    let path = entry.path();
169                    if path.is_dir() {
170                        if let Some(model_entry) = self.inspect_model_dir(&path).await {
171                            discovered.push(model_entry);
172                        }
173                    }
174                }
175            }
176        }
177
178        self.discovered_models = discovered.clone();
179        Ok(discovered)
180    }
181
182    /// Inspect a directory to see if it contains a model
183    async fn inspect_model_dir(&self, path: &Path) -> Option<ModelDiscoveryEntry> {
184        // Check for config.json
185        let config_path = path.join("config.json");
186        if !config_path.exists() {
187            debug!("No config.json in: {:?}", path);
188            return None;
189        }
190
191        // Detect format
192        let format = self.detect_model_format(path);
193        if format == ModelFormatType::Unknown {
194            debug!("Unknown format in: {:?}", path);
195            return None;
196        }
197
198        debug!("Found valid model at: {:?}, format: {:?}", path, format);
199
200        // Try to read architecture from config
201        let architecture = self.read_architecture(&config_path);
202
203        // Extract model ID from path - try to get friendly name from parent directory
204        let id = if let Some(parent) = path.parent() {
205            if let Some(grandparent) = parent.parent() {
206                // Extract from models--org--name format
207                if let Some(name) = grandparent.file_name().and_then(|n| n.to_str()) {
208                    if name.starts_with("models--") {
209                        name[8..].replace("--", "/")
210                    } else {
211                        path.file_name()
212                            .and_then(|n| n.to_str())
213                            .unwrap_or("unknown")
214                            .to_string()
215                    }
216                } else {
217                    path.file_name()
218                        .and_then(|n| n.to_str())
219                        .unwrap_or("unknown")
220                        .to_string()
221                }
222            } else {
223                path.file_name()
224                    .and_then(|n| n.to_str())
225                    .unwrap_or("unknown")
226                    .to_string()
227            }
228        } else {
229            path.file_name()
230                .and_then(|n| n.to_str())
231                .unwrap_or("unknown")
232                .to_string()
233        };
234
235        Some(ModelDiscoveryEntry {
236            id,
237            path: path.to_path_buf(),
238            format,
239            architecture,
240            is_valid: true,
241        })
242    }
243
244    /// Detect model format in directory
245    fn detect_model_format(&self, path: &Path) -> ModelFormatType {
246        if path.join("model.safetensors").exists()
247            || path.join("model.safetensors.index.json").exists()
248        {
249            ModelFormatType::SafeTensors
250        } else if path.join("pytorch_model.bin").exists()
251            || path.join("pytorch_model.bin.index.json").exists()
252        {
253            ModelFormatType::PyTorch
254        } else if std::fs::read_dir(path)
255            .ok()
256            .and_then(|entries| {
257                entries
258                    .filter_map(|e| e.ok())
259                    .find(|e| e.path().extension().and_then(|s| s.to_str()) == Some("gguf"))
260            })
261            .is_some()
262        {
263            ModelFormatType::GGUF
264        } else {
265            ModelFormatType::Unknown
266        }
267    }
268
269    /// Read architecture type from config.json
270    fn read_architecture(&self, config_path: &Path) -> Option<Architecture> {
271        let content = std::fs::read_to_string(config_path).ok()?;
272        let config: serde_json::Value = serde_json::from_str(&content).ok()?;
273
274        // Try "model_type" field
275        if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
276            return Some(Architecture::from_str(model_type));
277        }
278
279        // Try "architectures" array
280        if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
281            if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
282                return Some(Architecture::from_str(arch));
283            }
284        }
285
286        None
287    }
288}
289
290impl Default for DefaultModelRegistry {
291    fn default() -> Self {
292        Self::new()
293    }
294}
295
// ============================================================================
// Inline unit tests
// ============================================================================
299
/// Unit tests for the architecture mapping, the plain data types and the alias registry.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_architecture_from_str() {
        assert_eq!(Architecture::from_str("llama"), Architecture::Llama);
        assert_eq!(
            Architecture::from_str("LlamaForCausalLM"),
            Architecture::Llama
        );
        assert_eq!(Architecture::from_str("qwen2"), Architecture::Qwen2);
        assert_eq!(Architecture::from_str("qwen3"), Architecture::Qwen3);
        assert_eq!(
            Architecture::from_str("Qwen3ForCausalLM"),
            Architecture::Qwen3
        );
        assert_eq!(Architecture::from_str("mistral"), Architecture::Mistral);
        assert_eq!(Architecture::from_str("phi"), Architecture::Phi);
        assert_eq!(Architecture::from_str("gpt2"), Architecture::GPT2);
        assert_eq!(Architecture::from_str("bert"), Architecture::Bert);
        // CLIP, Chinese-CLIP and SigLIP all map to the same variant.
        assert_eq!(Architecture::from_str("clip"), Architecture::Clip);
        assert_eq!(
            Architecture::from_str("ChineseCLIPModel"),
            Architecture::Clip
        );
        assert_eq!(Architecture::from_str("siglip"), Architecture::Clip);
        assert_eq!(
            Architecture::from_str("unknown_arch"),
            Architecture::Unknown
        );
    }

    #[test]
    fn test_architecture_copy() {
        let arch = Architecture::Llama;
        let arch2 = arch;
        assert_eq!(arch, arch2);
    }

    #[test]
    fn test_model_format_type_eq() {
        assert_eq!(ModelFormatType::SafeTensors, ModelFormatType::SafeTensors);
        assert_ne!(ModelFormatType::SafeTensors, ModelFormatType::PyTorch);
    }

    #[test]
    fn test_model_alias_creation() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: Some("Test model".to_string()),
        };

        assert_eq!(alias.name, "test");
        assert_eq!(alias.target, "test/model");
        assert_eq!(alias.description.as_deref(), Some("Test model"));
    }

    #[test]
    fn test_model_alias_clone() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: None,
        };

        let cloned = alias.clone();
        assert_eq!(alias.name, cloned.name);
        assert_eq!(alias.target, cloned.target);
        assert_eq!(alias.description, cloned.description);
    }

    #[test]
    fn test_model_discovery_entry() {
        let entry = ModelDiscoveryEntry {
            id: "test-model".to_string(),
            path: PathBuf::from("/path/to/model"),
            format: ModelFormatType::SafeTensors,
            architecture: Some(Architecture::Llama),
            is_valid: true,
        };

        assert_eq!(entry.id, "test-model");
        assert_eq!(entry.format, ModelFormatType::SafeTensors);
        assert!(entry.is_valid);
    }

    #[test]
    fn test_registry_creation() {
        let registry = DefaultModelRegistry::new();
        assert!(registry.aliases.is_empty());
        assert!(registry.discovered_models.is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = DefaultModelRegistry::default();
        assert!(registry.aliases.is_empty());
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = DefaultModelRegistry::with_defaults();

        // There should be some default aliases.
        assert!(!registry.aliases.is_empty());

        // Spot-check a few common aliases.
        assert!(registry.aliases.contains_key("tinyllama"));
        assert!(registry.aliases.contains_key("llama2-7b"));
    }

    #[test]
    fn test_registry_register_alias() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "test/model");

        assert_eq!(
            registry.aliases.get("test").map(String::as_str),
            Some("test/model")
        );
    }

    #[test]
    fn test_registry_resolve_model_id() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("mymodel", "org/actual-model");

        let resolved = registry.resolve_model_id("mymodel");
        assert_eq!(resolved, "org/actual-model");

        // A name that is not a registered alias is returned unchanged.
        let unresolved = registry.resolve_model_id("unknown");
        assert_eq!(unresolved, "unknown");
    }

    #[test]
    fn test_registry_list_aliases() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("model1", "org/model1");
        registry.register_alias("model2", "org/model2");

        // The listing order is unspecified, so sort before comparing.
        let mut aliases = registry.list_aliases();
        aliases.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "model1");
        assert_eq!(aliases[0].target, "org/model1");
        assert_eq!(aliases[1].name, "model2");
        assert_eq!(aliases[1].target, "org/model2");
        assert!(aliases.iter().all(|alias| alias.description.is_none()));
    }

    #[test]
    fn test_architecture_debug() {
        let arch = Architecture::Llama;
        let debug_str = format!("{:?}", arch);
        assert!(debug_str.contains("Llama"));
    }

    #[test]
    fn test_model_format_debug() {
        let format = ModelFormatType::SafeTensors;
        let debug_str = format!("{:?}", format);
        assert!(debug_str.contains("SafeTensors"));
    }

    #[test]
    fn test_model_discovery_entry_clone() {
        let entry = ModelDiscoveryEntry {
            id: "test".to_string(),
            path: PathBuf::from("/path"),
            format: ModelFormatType::GGUF,
            architecture: Some(Architecture::Mistral),
            is_valid: false,
        };

        let cloned = entry.clone();
        assert_eq!(entry.id, cloned.id);
        assert_eq!(entry.path, cloned.path);
        assert_eq!(entry.format, cloned.format);
        assert_eq!(entry.architecture, cloned.architecture);
        assert_eq!(entry.is_valid, cloned.is_valid);
    }

    #[test]
    fn test_registry_multiple_aliases_same_target() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("alias1", "org/model");
        registry.register_alias("alias2", "org/model");

        assert_eq!(registry.resolve_model_id("alias1"), "org/model");
        assert_eq!(registry.resolve_model_id("alias2"), "org/model");
    }

    #[test]
    fn test_architecture_serialization() {
        let arch = Architecture::Qwen2;
        let json = serde_json::to_string(&arch).unwrap();
        assert!(json.contains("Qwen2"));

        let deserialized: Architecture = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, arch);
    }
}