1use ferrum_types::Result;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
/// A short, user-facing name that resolves to a full model id.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelAlias {
    /// The alias itself, e.g. `"tinyllama"`.
    pub name: String,
    /// The model id the alias resolves to, e.g. `"TinyLlama/TinyLlama-1.1B-Chat-v1.0"`.
    pub target: String,
    /// Optional human-readable description of the alias.
    pub description: Option<String>,
}
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
21pub enum Architecture {
22 Llama,
23 Qwen2,
24 Qwen3,
25 Mistral,
26 Phi,
27 GPT2,
28 Bert,
29 Unknown,
30}
31
32impl Architecture {
33 pub fn from_str(s: &str) -> Self {
34 match s.to_lowercase().as_str() {
35 "llama" | "llamaforcausallm" => Architecture::Llama,
36 "qwen2" | "qwen2forcausallm" => Architecture::Qwen2,
37 "qwen3" | "qwen3forcausallm" => Architecture::Qwen3,
38 "mistral" | "mistralforcausallm" => Architecture::Mistral,
39 "phi" | "phiforcausallm" => Architecture::Phi,
40 "gpt2" | "gpt2lmheadmodel" => Architecture::GPT2,
41 "bert" | "bertmodel" | "bertformaskedlm" | "bertforsequenceclassification" => {
42 Architecture::Bert
43 }
44 _ => Architecture::Unknown,
45 }
46 }
47}
48
/// On-disk format of a model directory's weight files.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormatType {
    /// `safetensors` weights (single file or sharded with an index).
    SafeTensors,
    /// PyTorch `pytorch_model.bin` checkpoints (single file or sharded with an index).
    PyTorch,
    /// A `*.gguf` weight file.
    GGUF,
    /// No recognised weight files were found.
    Unknown,
}
57
58#[derive(Debug, Clone)]
60pub struct ModelDiscoveryEntry {
61 pub id: String,
63 pub path: PathBuf,
65 pub format: ModelFormatType,
67 pub architecture: Option<Architecture>,
69 pub is_valid: bool,
71}
72
73#[derive(Debug)]
75pub struct DefaultModelRegistry {
76 aliases: HashMap<String, String>,
78 discovered_models: Vec<ModelDiscoveryEntry>,
80}
81
82impl DefaultModelRegistry {
83 pub fn new() -> Self {
85 Self {
86 aliases: HashMap::new(),
87 discovered_models: Vec::new(),
88 }
89 }
90
91 pub fn with_defaults() -> Self {
93 let mut registry = Self::new();
94
95 registry.register_alias("tinyllama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0");
97 registry.register_alias("llama2-7b", "meta-llama/Llama-2-7b-hf");
98 registry.register_alias("llama2-7b-chat", "meta-llama/Llama-2-7b-chat-hf");
99 registry.register_alias("llama3-8b", "meta-llama/Meta-Llama-3-8B");
100 registry.register_alias("llama3-8b-instruct", "meta-llama/Meta-Llama-3-8B-Instruct");
101 registry.register_alias("qwen2-7b", "Qwen/Qwen2-7B");
102 registry.register_alias("qwen2-7b-instruct", "Qwen/Qwen2-7B-Instruct");
103 registry.register_alias("qwen3-0.6b", "Qwen/Qwen3-0.6B");
104 registry.register_alias("qwen3-1.7b", "Qwen/Qwen3-1.7B");
105 registry.register_alias("qwen3-4b", "Qwen/Qwen3-4B");
106 registry.register_alias("mistral-7b", "mistralai/Mistral-7B-v0.1");
107 registry.register_alias("mistral-7b-instruct", "mistralai/Mistral-7B-Instruct-v0.2");
108 registry.register_alias("phi3-mini", "microsoft/Phi-3-mini-4k-instruct");
109
110 registry
111 }
112
113 pub fn register_alias(&mut self, alias: impl Into<String>, target: impl Into<String>) {
115 let alias_str = alias.into();
116 let target_str = target.into();
117 debug!("Registering alias: {} -> {}", alias_str, target_str);
118 self.aliases.insert(alias_str, target_str);
119 }
120
121 pub fn add_alias(&mut self, alias: ModelAlias) -> Result<()> {
123 self.register_alias(alias.name, alias.target);
124 Ok(())
125 }
126
127 pub fn resolve_model_id(&self, name: &str) -> String {
129 self.aliases
130 .get(name)
131 .cloned()
132 .unwrap_or_else(|| name.to_string())
133 }
134
135 pub fn list_aliases(&self) -> Vec<ModelAlias> {
137 self.aliases
138 .iter()
139 .map(|(name, target)| ModelAlias {
140 name: name.clone(),
141 target: target.clone(),
142 description: None,
143 })
144 .collect()
145 }
146
147 pub async fn discover_models(&mut self, root: &Path) -> Result<Vec<ModelDiscoveryEntry>> {
149 if !root.exists() || !root.is_dir() {
150 return Ok(Vec::new());
151 }
152
153 info!("Discovering models in: {:?}", root);
154
155 let mut discovered = Vec::new();
156
157 if let Some(model_entry) = self.inspect_model_dir(root).await {
159 discovered.push(model_entry);
160 } else {
161 if let Ok(entries) = std::fs::read_dir(root) {
163 for entry in entries.filter_map(|e| e.ok()) {
164 let path = entry.path();
165 if path.is_dir() {
166 if let Some(model_entry) = self.inspect_model_dir(&path).await {
167 discovered.push(model_entry);
168 }
169 }
170 }
171 }
172 }
173
174 self.discovered_models = discovered.clone();
175 Ok(discovered)
176 }
177
178 async fn inspect_model_dir(&self, path: &Path) -> Option<ModelDiscoveryEntry> {
180 let config_path = path.join("config.json");
182 if !config_path.exists() {
183 debug!("No config.json in: {:?}", path);
184 return None;
185 }
186
187 let format = self.detect_model_format(path);
189 if format == ModelFormatType::Unknown {
190 debug!("Unknown format in: {:?}", path);
191 return None;
192 }
193
194 debug!("Found valid model at: {:?}, format: {:?}", path, format);
195
196 let architecture = self.read_architecture(&config_path);
198
199 let id = if let Some(parent) = path.parent() {
201 if let Some(grandparent) = parent.parent() {
202 if let Some(name) = grandparent.file_name().and_then(|n| n.to_str()) {
204 if name.starts_with("models--") {
205 name[8..].replace("--", "/")
206 } else {
207 path.file_name()
208 .and_then(|n| n.to_str())
209 .unwrap_or("unknown")
210 .to_string()
211 }
212 } else {
213 path.file_name()
214 .and_then(|n| n.to_str())
215 .unwrap_or("unknown")
216 .to_string()
217 }
218 } else {
219 path.file_name()
220 .and_then(|n| n.to_str())
221 .unwrap_or("unknown")
222 .to_string()
223 }
224 } else {
225 path.file_name()
226 .and_then(|n| n.to_str())
227 .unwrap_or("unknown")
228 .to_string()
229 };
230
231 Some(ModelDiscoveryEntry {
232 id,
233 path: path.to_path_buf(),
234 format,
235 architecture,
236 is_valid: true,
237 })
238 }
239
240 fn detect_model_format(&self, path: &Path) -> ModelFormatType {
242 if path.join("model.safetensors").exists()
243 || path.join("model.safetensors.index.json").exists()
244 {
245 ModelFormatType::SafeTensors
246 } else if path.join("pytorch_model.bin").exists()
247 || path.join("pytorch_model.bin.index.json").exists()
248 {
249 ModelFormatType::PyTorch
250 } else if std::fs::read_dir(path)
251 .ok()
252 .and_then(|entries| {
253 entries
254 .filter_map(|e| e.ok())
255 .find(|e| e.path().extension().and_then(|s| s.to_str()) == Some("gguf"))
256 })
257 .is_some()
258 {
259 ModelFormatType::GGUF
260 } else {
261 ModelFormatType::Unknown
262 }
263 }
264
265 fn read_architecture(&self, config_path: &Path) -> Option<Architecture> {
267 let content = std::fs::read_to_string(config_path).ok()?;
268 let config: serde_json::Value = serde_json::from_str(&content).ok()?;
269
270 if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
272 return Some(Architecture::from_str(model_type));
273 }
274
275 if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
277 if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
278 return Some(Architecture::from_str(arch));
279 }
280 }
281
282 None
283 }
284}
285
286impl Default for DefaultModelRegistry {
287 fn default() -> Self {
288 Self::new()
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an empty scratch directory that is unique to this process and `tag`.
    fn scratch_dir(tag: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "ferrum-model-registry-{}-{}",
            std::process::id(),
            tag
        ));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).expect("failed to create scratch dir");
        dir
    }

    /// Creates an empty file called `name` inside `dir`.
    fn touch(dir: &Path, name: &str) {
        std::fs::write(dir.join(name), b"").expect("failed to create test file");
    }

    #[test]
    fn test_architecture_from_str() {
        assert_eq!(Architecture::from_str("llama"), Architecture::Llama);
        assert_eq!(
            Architecture::from_str("LlamaForCausalLM"),
            Architecture::Llama
        );
        assert_eq!(Architecture::from_str("qwen2"), Architecture::Qwen2);
        assert_eq!(Architecture::from_str("QWEN2"), Architecture::Qwen2);
        assert_eq!(
            Architecture::from_str("Qwen3ForCausalLM"),
            Architecture::Qwen3
        );
        assert_eq!(Architecture::from_str("mistral"), Architecture::Mistral);
        assert_eq!(Architecture::from_str("phi"), Architecture::Phi);
        assert_eq!(Architecture::from_str("gpt2"), Architecture::GPT2);
        assert_eq!(
            Architecture::from_str("GPT2LMHeadModel"),
            Architecture::GPT2
        );
        assert_eq!(
            Architecture::from_str("BertForSequenceClassification"),
            Architecture::Bert
        );
        assert_eq!(
            Architecture::from_str("unknown_arch"),
            Architecture::Unknown
        );
    }

    #[test]
    fn test_architecture_copy() {
        let arch = Architecture::Llama;
        let arch2 = arch;
        assert_eq!(arch, arch2);
    }

    #[test]
    fn test_model_format_type_eq() {
        assert_eq!(ModelFormatType::SafeTensors, ModelFormatType::SafeTensors);
        assert_ne!(ModelFormatType::SafeTensors, ModelFormatType::PyTorch);
    }

    #[test]
    fn test_model_alias_creation() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: Some("Test model".to_string()),
        };

        assert_eq!(alias.name, "test");
        assert_eq!(alias.target, "test/model");
        assert_eq!(alias.description.as_deref(), Some("Test model"));
    }

    #[test]
    fn test_model_alias_clone() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: None,
        };

        let cloned = alias.clone();
        assert_eq!(alias.name, cloned.name);
        assert_eq!(alias.target, cloned.target);
        assert_eq!(alias.description, cloned.description);
    }

    #[test]
    fn test_model_discovery_entry() {
        let entry = ModelDiscoveryEntry {
            id: "test-model".to_string(),
            path: PathBuf::from("/path/to/model"),
            format: ModelFormatType::SafeTensors,
            architecture: Some(Architecture::Llama),
            is_valid: true,
        };

        assert_eq!(entry.id, "test-model");
        assert_eq!(entry.format, ModelFormatType::SafeTensors);
        assert_eq!(entry.architecture, Some(Architecture::Llama));
        assert!(entry.is_valid);
    }

    #[test]
    fn test_registry_creation() {
        let registry = DefaultModelRegistry::new();
        assert!(registry.aliases.is_empty());
        assert!(registry.discovered_models.is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = DefaultModelRegistry::default();
        assert!(registry.aliases.is_empty());
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = DefaultModelRegistry::with_defaults();

        assert!(!registry.aliases.is_empty());
        assert!(registry.aliases.contains_key("tinyllama"));
        assert!(registry.aliases.contains_key("llama2-7b"));
        assert_eq!(
            registry.resolve_model_id("tinyllama"),
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        );
    }

    #[test]
    fn test_registry_register_alias() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "test/model");

        assert_eq!(
            registry.aliases.get("test"),
            Some(&"test/model".to_string())
        );
    }

    #[test]
    fn test_registry_register_alias_overwrites() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "org/old");
        registry.register_alias("test", "org/new");

        assert_eq!(registry.resolve_model_id("test"), "org/new");
        assert_eq!(registry.aliases.len(), 1);
    }

    #[test]
    fn test_registry_resolve_model_id() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("mymodel", "org/actual-model");

        let resolved = registry.resolve_model_id("mymodel");
        assert_eq!(resolved, "org/actual-model");

        // Names that are not aliases are passed through unchanged.
        let unresolved = registry.resolve_model_id("unknown");
        assert_eq!(unresolved, "unknown");
    }

    #[test]
    fn test_registry_list_aliases() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("model1", "org/model1");
        registry.register_alias("model2", "org/model2");

        let aliases = registry.list_aliases();
        assert_eq!(aliases.len(), 2);

        // Compare contents order-independently.
        let mut pairs: Vec<(String, String)> = aliases
            .into_iter()
            .map(|a| (a.name, a.target))
            .collect();
        pairs.sort();
        assert_eq!(
            pairs,
            vec![
                ("model1".to_string(), "org/model1".to_string()),
                ("model2".to_string(), "org/model2".to_string()),
            ]
        );
    }

    #[test]
    fn test_architecture_debug() {
        let arch = Architecture::Llama;
        let debug_str = format!("{:?}", arch);
        assert!(debug_str.contains("Llama"));
    }

    #[test]
    fn test_model_format_debug() {
        let format = ModelFormatType::SafeTensors;
        let debug_str = format!("{:?}", format);
        assert!(debug_str.contains("SafeTensors"));
    }

    #[test]
    fn test_model_discovery_entry_clone() {
        let entry = ModelDiscoveryEntry {
            id: "test".to_string(),
            path: PathBuf::from("/path"),
            format: ModelFormatType::GGUF,
            architecture: Some(Architecture::Mistral),
            is_valid: false,
        };

        let cloned = entry.clone();
        assert_eq!(entry.id, cloned.id);
        assert_eq!(entry.path, cloned.path);
        assert_eq!(entry.format, cloned.format);
        assert_eq!(entry.architecture, cloned.architecture);
        assert_eq!(entry.is_valid, cloned.is_valid);
    }

    #[test]
    fn test_registry_multiple_aliases_same_target() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("alias1", "org/model");
        registry.register_alias("alias2", "org/model");

        assert_eq!(registry.resolve_model_id("alias1"), "org/model");
        assert_eq!(registry.resolve_model_id("alias2"), "org/model");
    }

    #[test]
    fn test_architecture_serialization() {
        let arch = Architecture::Qwen2;
        let json = serde_json::to_string(&arch).unwrap();
        assert!(json.contains("Qwen2"));

        let deserialized: Architecture = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, arch);
    }

    #[test]
    fn test_detect_model_format() {
        let registry = DefaultModelRegistry::new();

        let empty = scratch_dir("format-empty");
        assert_eq!(
            registry.detect_model_format(&empty),
            ModelFormatType::Unknown
        );

        let gguf = scratch_dir("format-gguf");
        touch(&gguf, "model-q4.gguf");
        assert_eq!(registry.detect_model_format(&gguf), ModelFormatType::GGUF);

        let pytorch = scratch_dir("format-pytorch");
        touch(&pytorch, "pytorch_model.bin");
        assert_eq!(
            registry.detect_model_format(&pytorch),
            ModelFormatType::PyTorch
        );

        // safetensors takes precedence when several weight formats are present.
        let both = scratch_dir("format-both");
        touch(&both, "pytorch_model.bin");
        touch(&both, "model.safetensors");
        assert_eq!(
            registry.detect_model_format(&both),
            ModelFormatType::SafeTensors
        );

        for dir in [empty, gguf, pytorch, both] {
            let _ = std::fs::remove_dir_all(dir);
        }
    }

    #[test]
    fn test_read_architecture() {
        let registry = DefaultModelRegistry::new();
        let dir = scratch_dir("read-arch");
        let config = dir.join("config.json");

        std::fs::write(&config, r#"{"model_type": "llama"}"#).expect("write config");
        assert_eq!(
            registry.read_architecture(&config),
            Some(Architecture::Llama)
        );

        std::fs::write(&config, r#"{"architectures": ["MistralForCausalLM"]}"#)
            .expect("write config");
        assert_eq!(
            registry.read_architecture(&config),
            Some(Architecture::Mistral)
        );

        std::fs::write(&config, "{}").expect("write config");
        assert_eq!(registry.read_architecture(&config), None);

        std::fs::write(&config, "not json").expect("write config");
        assert_eq!(registry.read_architecture(&config), None);

        assert_eq!(registry.read_architecture(&dir.join("missing.json")), None);

        let _ = std::fs::remove_dir_all(&dir);
    }
}