Skip to main content

ferrum_models/
registry.rs

1//! Model registry and alias management
2
3use ferrum_types::Result;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
/// A short, human-friendly name that stands in for a full model identifier.
#[derive(Debug, Clone)]
pub struct ModelAlias {
    /// The alias itself, i.e. the short name a user types.
    pub name: String,
    /// The model identifier the alias points to.
    pub target: String,
    /// Free-form description of the alias, when one was supplied.
    pub description: Option<String>,
}
18
19/// Architecture types for models
20#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
21pub enum Architecture {
22    Llama,
23    Qwen2,
24    Qwen3,
25    Mistral,
26    Phi,
27    GPT2,
28    Bert,
29    Unknown,
30}
31
32impl Architecture {
33    pub fn from_str(s: &str) -> Self {
34        match s.to_lowercase().as_str() {
35            "llama" | "llamaforcausallm" => Architecture::Llama,
36            "qwen2" | "qwen2forcausallm" => Architecture::Qwen2,
37            "qwen3" | "qwen3forcausallm" => Architecture::Qwen3,
38            "mistral" | "mistralforcausallm" => Architecture::Mistral,
39            "phi" | "phiforcausallm" => Architecture::Phi,
40            "gpt2" | "gpt2lmheadmodel" => Architecture::GPT2,
41            "bert" | "bertmodel" | "bertformaskedlm" | "bertforsequenceclassification" => {
42                Architecture::Bert
43            }
44            _ => Architecture::Unknown,
45        }
46    }
47}
48
/// On-disk weight format of a model directory.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormatType {
    /// SafeTensors weights.
    SafeTensors,
    /// PyTorch weights.
    PyTorch,
    /// GGUF weights.
    GGUF,
    /// No recognised weight format.
    Unknown,
}
57
58/// Discovered model entry
59#[derive(Debug, Clone)]
60pub struct ModelDiscoveryEntry {
61    /// Model identifier
62    pub id: String,
63    /// Local path to model
64    pub path: PathBuf,
65    /// Model format
66    pub format: ModelFormatType,
67    /// Architecture type (if detected)
68    pub architecture: Option<Architecture>,
69    /// Whether model passes validation
70    pub is_valid: bool,
71}
72
73/// Model registry for managing models and aliases
74#[derive(Debug)]
75pub struct DefaultModelRegistry {
76    /// Model aliases
77    aliases: HashMap<String, String>,
78    /// Discovered models cache
79    discovered_models: Vec<ModelDiscoveryEntry>,
80}
81
82impl DefaultModelRegistry {
83    /// Create new empty registry
84    pub fn new() -> Self {
85        Self {
86            aliases: HashMap::new(),
87            discovered_models: Vec::new(),
88        }
89    }
90
91    /// Create registry with common aliases
92    pub fn with_defaults() -> Self {
93        let mut registry = Self::new();
94
95        // Common model aliases
96        registry.register_alias("tinyllama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0");
97        registry.register_alias("llama2-7b", "meta-llama/Llama-2-7b-hf");
98        registry.register_alias("llama2-7b-chat", "meta-llama/Llama-2-7b-chat-hf");
99        registry.register_alias("llama3-8b", "meta-llama/Meta-Llama-3-8B");
100        registry.register_alias("llama3-8b-instruct", "meta-llama/Meta-Llama-3-8B-Instruct");
101        registry.register_alias("qwen2-7b", "Qwen/Qwen2-7B");
102        registry.register_alias("qwen2-7b-instruct", "Qwen/Qwen2-7B-Instruct");
103        registry.register_alias("qwen3-0.6b", "Qwen/Qwen3-0.6B");
104        registry.register_alias("qwen3-1.7b", "Qwen/Qwen3-1.7B");
105        registry.register_alias("qwen3-4b", "Qwen/Qwen3-4B");
106        registry.register_alias("mistral-7b", "mistralai/Mistral-7B-v0.1");
107        registry.register_alias("mistral-7b-instruct", "mistralai/Mistral-7B-Instruct-v0.2");
108        registry.register_alias("phi3-mini", "microsoft/Phi-3-mini-4k-instruct");
109
110        registry
111    }
112
113    /// Register a model alias
114    pub fn register_alias(&mut self, alias: impl Into<String>, target: impl Into<String>) {
115        let alias_str = alias.into();
116        let target_str = target.into();
117        debug!("Registering alias: {} -> {}", alias_str, target_str);
118        self.aliases.insert(alias_str, target_str);
119    }
120
121    /// Add alias from struct
122    pub fn add_alias(&mut self, alias: ModelAlias) -> Result<()> {
123        self.register_alias(alias.name, alias.target);
124        Ok(())
125    }
126
127    /// Resolve model ID through aliases
128    pub fn resolve_model_id(&self, name: &str) -> String {
129        self.aliases
130            .get(name)
131            .cloned()
132            .unwrap_or_else(|| name.to_string())
133    }
134
135    /// List all registered aliases
136    pub fn list_aliases(&self) -> Vec<ModelAlias> {
137        self.aliases
138            .iter()
139            .map(|(name, target)| ModelAlias {
140                name: name.clone(),
141                target: target.clone(),
142                description: None,
143            })
144            .collect()
145    }
146
147    /// Discover models in a directory
148    pub async fn discover_models(&mut self, root: &Path) -> Result<Vec<ModelDiscoveryEntry>> {
149        if !root.exists() || !root.is_dir() {
150            return Ok(Vec::new());
151        }
152
153        info!("Discovering models in: {:?}", root);
154
155        let mut discovered = Vec::new();
156
157        // First check if root itself is a model directory
158        if let Some(model_entry) = self.inspect_model_dir(root).await {
159            discovered.push(model_entry);
160        } else {
161            // Otherwise scan subdirectories
162            if let Ok(entries) = std::fs::read_dir(root) {
163                for entry in entries.filter_map(|e| e.ok()) {
164                    let path = entry.path();
165                    if path.is_dir() {
166                        if let Some(model_entry) = self.inspect_model_dir(&path).await {
167                            discovered.push(model_entry);
168                        }
169                    }
170                }
171            }
172        }
173
174        self.discovered_models = discovered.clone();
175        Ok(discovered)
176    }
177
178    /// Inspect a directory to see if it contains a model
179    async fn inspect_model_dir(&self, path: &Path) -> Option<ModelDiscoveryEntry> {
180        // Check for config.json
181        let config_path = path.join("config.json");
182        if !config_path.exists() {
183            debug!("No config.json in: {:?}", path);
184            return None;
185        }
186
187        // Detect format
188        let format = self.detect_model_format(path);
189        if format == ModelFormatType::Unknown {
190            debug!("Unknown format in: {:?}", path);
191            return None;
192        }
193
194        debug!("Found valid model at: {:?}, format: {:?}", path, format);
195
196        // Try to read architecture from config
197        let architecture = self.read_architecture(&config_path);
198
199        // Extract model ID from path - try to get friendly name from parent directory
200        let id = if let Some(parent) = path.parent() {
201            if let Some(grandparent) = parent.parent() {
202                // Extract from models--org--name format
203                if let Some(name) = grandparent.file_name().and_then(|n| n.to_str()) {
204                    if name.starts_with("models--") {
205                        name[8..].replace("--", "/")
206                    } else {
207                        path.file_name()
208                            .and_then(|n| n.to_str())
209                            .unwrap_or("unknown")
210                            .to_string()
211                    }
212                } else {
213                    path.file_name()
214                        .and_then(|n| n.to_str())
215                        .unwrap_or("unknown")
216                        .to_string()
217                }
218            } else {
219                path.file_name()
220                    .and_then(|n| n.to_str())
221                    .unwrap_or("unknown")
222                    .to_string()
223            }
224        } else {
225            path.file_name()
226                .and_then(|n| n.to_str())
227                .unwrap_or("unknown")
228                .to_string()
229        };
230
231        Some(ModelDiscoveryEntry {
232            id,
233            path: path.to_path_buf(),
234            format,
235            architecture,
236            is_valid: true,
237        })
238    }
239
240    /// Detect model format in directory
241    fn detect_model_format(&self, path: &Path) -> ModelFormatType {
242        if path.join("model.safetensors").exists()
243            || path.join("model.safetensors.index.json").exists()
244        {
245            ModelFormatType::SafeTensors
246        } else if path.join("pytorch_model.bin").exists()
247            || path.join("pytorch_model.bin.index.json").exists()
248        {
249            ModelFormatType::PyTorch
250        } else if std::fs::read_dir(path)
251            .ok()
252            .and_then(|entries| {
253                entries
254                    .filter_map(|e| e.ok())
255                    .find(|e| e.path().extension().and_then(|s| s.to_str()) == Some("gguf"))
256            })
257            .is_some()
258        {
259            ModelFormatType::GGUF
260        } else {
261            ModelFormatType::Unknown
262        }
263    }
264
265    /// Read architecture type from config.json
266    fn read_architecture(&self, config_path: &Path) -> Option<Architecture> {
267        let content = std::fs::read_to_string(config_path).ok()?;
268        let config: serde_json::Value = serde_json::from_str(&content).ok()?;
269
270        // Try "model_type" field
271        if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
272            return Some(Architecture::from_str(model_type));
273        }
274
275        // Try "architectures" array
276        if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
277            if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
278                return Some(Architecture::from_str(arch));
279            }
280        }
281
282        None
283    }
284}
285
286impl Default for DefaultModelRegistry {
287    fn default() -> Self {
288        Self::new()
289    }
290}
291
// ============================================================================
// Inline unit tests
// ============================================================================
295
#[cfg(test)]
mod tests {
    //! Unit tests for the registry types and alias handling.

    use super::*;

    #[test]
    fn test_architecture_from_str() {
        // Short model-type names.
        assert_eq!(Architecture::from_str("llama"), Architecture::Llama);
        assert_eq!(Architecture::from_str("qwen2"), Architecture::Qwen2);
        assert_eq!(Architecture::from_str("qwen3"), Architecture::Qwen3);
        assert_eq!(Architecture::from_str("mistral"), Architecture::Mistral);
        assert_eq!(Architecture::from_str("phi"), Architecture::Phi);
        assert_eq!(Architecture::from_str("gpt2"), Architecture::GPT2);
        assert_eq!(Architecture::from_str("bert"), Architecture::Bert);

        // Class-style names; matching is case-insensitive.
        assert_eq!(
            Architecture::from_str("LlamaForCausalLM"),
            Architecture::Llama
        );
        assert_eq!(
            Architecture::from_str("Qwen3ForCausalLM"),
            Architecture::Qwen3
        );
        assert_eq!(
            Architecture::from_str("GPT2LMHeadModel"),
            Architecture::GPT2
        );
        assert_eq!(
            Architecture::from_str("BertForSequenceClassification"),
            Architecture::Bert
        );

        // Anything else maps to Unknown.
        assert_eq!(
            Architecture::from_str("unknown_arch"),
            Architecture::Unknown
        );
        assert_eq!(Architecture::from_str(""), Architecture::Unknown);
    }

    #[test]
    fn test_architecture_copy() {
        let arch = Architecture::Llama;
        let arch2 = arch;
        assert_eq!(arch, arch2);
    }

    #[test]
    fn test_model_format_type_eq() {
        assert_eq!(ModelFormatType::SafeTensors, ModelFormatType::SafeTensors);
        assert_ne!(ModelFormatType::SafeTensors, ModelFormatType::PyTorch);
    }

    #[test]
    fn test_model_alias_creation() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: Some("Test model".to_string()),
        };

        assert_eq!(alias.name, "test");
        assert_eq!(alias.target, "test/model");
        assert!(alias.description.is_some());
    }

    #[test]
    fn test_model_alias_clone() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: None,
        };

        let cloned = alias.clone();
        assert_eq!(alias.name, cloned.name);
        assert_eq!(alias.target, cloned.target);
    }

    #[test]
    fn test_model_discovery_entry() {
        let entry = ModelDiscoveryEntry {
            id: "test-model".to_string(),
            path: PathBuf::from("/path/to/model"),
            format: ModelFormatType::SafeTensors,
            architecture: Some(Architecture::Llama),
            is_valid: true,
        };

        assert_eq!(entry.id, "test-model");
        assert_eq!(entry.format, ModelFormatType::SafeTensors);
        assert!(entry.is_valid);
    }

    #[test]
    fn test_registry_creation() {
        let registry = DefaultModelRegistry::new();
        assert!(registry.aliases.is_empty());
        assert!(registry.discovered_models.is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = DefaultModelRegistry::default();
        assert!(registry.aliases.is_empty());
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = DefaultModelRegistry::with_defaults();

        // Some default aliases must be present.
        assert!(!registry.aliases.is_empty());

        // Spot-check a few common aliases.
        assert!(registry.aliases.contains_key("tinyllama"));
        assert!(registry.aliases.contains_key("llama2-7b"));

        // A default alias resolves to its full model identifier.
        assert_eq!(
            registry.resolve_model_id("tinyllama"),
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        );
    }

    #[test]
    fn test_registry_register_alias() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "test/model");

        assert_eq!(
            registry.aliases.get("test"),
            Some(&"test/model".to_string())
        );
    }

    #[test]
    fn test_registry_register_alias_overwrites() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "org/first");
        registry.register_alias("test", "org/second");

        // Re-registering replaces the target instead of adding a second entry.
        assert_eq!(registry.aliases.len(), 1);
        assert_eq!(registry.resolve_model_id("test"), "org/second");
    }

    #[test]
    fn test_registry_add_alias() {
        let mut registry = DefaultModelRegistry::new();

        let alias = ModelAlias {
            name: "short".to_string(),
            target: "org/long-name".to_string(),
            description: Some("described".to_string()),
        };

        assert!(registry.add_alias(alias).is_ok());
        assert_eq!(registry.resolve_model_id("short"), "org/long-name");
    }

    #[test]
    fn test_registry_resolve_model_id() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("mymodel", "org/actual-model");

        let resolved = registry.resolve_model_id("mymodel");
        assert_eq!(resolved, "org/actual-model");

        // An unregistered name is returned unchanged.
        let unresolved = registry.resolve_model_id("unknown");
        assert_eq!(unresolved, "unknown");
    }

    #[test]
    fn test_registry_list_aliases() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("model1", "org/model1");
        registry.register_alias("model2", "org/model2");

        let aliases = registry.list_aliases();
        assert_eq!(aliases.len(), 2);

        // Listed aliases never carry a description.
        assert!(aliases.iter().all(|alias| alias.description.is_none()));
    }

    #[test]
    fn test_architecture_debug() {
        let arch = Architecture::Llama;
        let debug_str = format!("{:?}", arch);
        assert!(debug_str.contains("Llama"));
    }

    #[test]
    fn test_model_format_debug() {
        let format = ModelFormatType::SafeTensors;
        let debug_str = format!("{:?}", format);
        assert!(debug_str.contains("SafeTensors"));
    }

    #[test]
    fn test_model_discovery_entry_clone() {
        let entry = ModelDiscoveryEntry {
            id: "test".to_string(),
            path: PathBuf::from("/path"),
            format: ModelFormatType::GGUF,
            architecture: Some(Architecture::Mistral),
            is_valid: false,
        };

        let cloned = entry.clone();
        assert_eq!(entry.id, cloned.id);
        assert_eq!(entry.format, cloned.format);
        assert_eq!(entry.is_valid, cloned.is_valid);
    }

    #[test]
    fn test_registry_multiple_aliases_same_target() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("alias1", "org/model");
        registry.register_alias("alias2", "org/model");

        assert_eq!(registry.resolve_model_id("alias1"), "org/model");
        assert_eq!(registry.resolve_model_id("alias2"), "org/model");
    }

    #[test]
    fn test_architecture_serialization() {
        let arch = Architecture::Qwen2;
        let json = serde_json::to_string(&arch).unwrap();
        assert!(json.contains("Qwen2"));

        let deserialized: Architecture = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, arch);
    }
}