1use ferrum_types::Result;
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use tracing::{debug, info};
7
/// A short, human-friendly name that resolves to a canonical model id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelAlias {
    /// The alias itself, e.g. `"tinyllama"`.
    pub name: String,
    /// The model id the alias resolves to, e.g. `"TinyLlama/TinyLlama-1.1B-Chat-v1.0"`.
    pub target: String,
    /// Optional free-form description of the aliased model.
    pub description: Option<String>,
}
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
21pub enum Architecture {
22 Llama,
23 Qwen2,
24 Qwen3,
25 Mistral,
26 Phi,
27 GPT2,
28 Bert,
29 Clip,
30 Unknown,
31}
32
33impl Architecture {
34 pub fn from_str(s: &str) -> Self {
35 match s.to_lowercase().as_str() {
36 "llama" | "llamaforcausallm" => Architecture::Llama,
37 "qwen2" | "qwen2forcausallm" => Architecture::Qwen2,
38 "qwen3" | "qwen3forcausallm" => Architecture::Qwen3,
39 "mistral" | "mistralforcausallm" => Architecture::Mistral,
40 "phi" | "phiforcausallm" => Architecture::Phi,
41 "gpt2" | "gpt2lmheadmodel" => Architecture::GPT2,
42 "bert" | "bertmodel" | "bertformaskedlm" | "bertforsequenceclassification" => {
43 Architecture::Bert
44 }
45 "clip" | "clipmodel" => Architecture::Clip,
46 "chinese_clip" | "chineseclipmodel" => Architecture::Clip,
47 "siglip" | "siglipmodel" => Architecture::Clip,
48 _ => Architecture::Unknown,
49 }
50 }
51}
52
/// On-disk weight format detected in a model directory.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelFormatType {
    /// `model.safetensors`, or a sharded checkpoint described by `model.safetensors.index.json`.
    SafeTensors,
    /// `pytorch_model.bin`, or a sharded checkpoint described by `pytorch_model.bin.index.json`.
    PyTorch,
    /// At least one entry with the `.gguf` extension.
    GGUF,
    /// None of the recognised weight files were found.
    Unknown,
}
61
62#[derive(Debug, Clone)]
64pub struct ModelDiscoveryEntry {
65 pub id: String,
67 pub path: PathBuf,
69 pub format: ModelFormatType,
71 pub architecture: Option<Architecture>,
73 pub is_valid: bool,
75}
76
77#[derive(Debug)]
79pub struct DefaultModelRegistry {
80 aliases: HashMap<String, String>,
82 discovered_models: Vec<ModelDiscoveryEntry>,
84}
85
86impl DefaultModelRegistry {
87 pub fn new() -> Self {
89 Self {
90 aliases: HashMap::new(),
91 discovered_models: Vec::new(),
92 }
93 }
94
95 pub fn with_defaults() -> Self {
97 let mut registry = Self::new();
98
99 registry.register_alias("tinyllama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0");
101 registry.register_alias("llama2-7b", "meta-llama/Llama-2-7b-hf");
102 registry.register_alias("llama2-7b-chat", "meta-llama/Llama-2-7b-chat-hf");
103 registry.register_alias("llama3-8b", "meta-llama/Meta-Llama-3-8B");
104 registry.register_alias("llama3-8b-instruct", "meta-llama/Meta-Llama-3-8B-Instruct");
105 registry.register_alias("qwen2-7b", "Qwen/Qwen2-7B");
106 registry.register_alias("qwen2-7b-instruct", "Qwen/Qwen2-7B-Instruct");
107 registry.register_alias("qwen3-0.6b", "Qwen/Qwen3-0.6B");
108 registry.register_alias("qwen3-1.7b", "Qwen/Qwen3-1.7B");
109 registry.register_alias("qwen3-4b", "Qwen/Qwen3-4B");
110 registry.register_alias("mistral-7b", "mistralai/Mistral-7B-v0.1");
111 registry.register_alias("mistral-7b-instruct", "mistralai/Mistral-7B-Instruct-v0.2");
112 registry.register_alias("phi3-mini", "microsoft/Phi-3-mini-4k-instruct");
113
114 registry
115 }
116
117 pub fn register_alias(&mut self, alias: impl Into<String>, target: impl Into<String>) {
119 let alias_str = alias.into();
120 let target_str = target.into();
121 debug!("Registering alias: {} -> {}", alias_str, target_str);
122 self.aliases.insert(alias_str, target_str);
123 }
124
125 pub fn add_alias(&mut self, alias: ModelAlias) -> Result<()> {
127 self.register_alias(alias.name, alias.target);
128 Ok(())
129 }
130
131 pub fn resolve_model_id(&self, name: &str) -> String {
133 self.aliases
134 .get(name)
135 .cloned()
136 .unwrap_or_else(|| name.to_string())
137 }
138
139 pub fn list_aliases(&self) -> Vec<ModelAlias> {
141 self.aliases
142 .iter()
143 .map(|(name, target)| ModelAlias {
144 name: name.clone(),
145 target: target.clone(),
146 description: None,
147 })
148 .collect()
149 }
150
151 pub async fn discover_models(&mut self, root: &Path) -> Result<Vec<ModelDiscoveryEntry>> {
153 if !root.exists() || !root.is_dir() {
154 return Ok(Vec::new());
155 }
156
157 info!("Discovering models in: {:?}", root);
158
159 let mut discovered = Vec::new();
160
161 if let Some(model_entry) = self.inspect_model_dir(root).await {
163 discovered.push(model_entry);
164 } else {
165 if let Ok(entries) = std::fs::read_dir(root) {
167 for entry in entries.filter_map(|e| e.ok()) {
168 let path = entry.path();
169 if path.is_dir() {
170 if let Some(model_entry) = self.inspect_model_dir(&path).await {
171 discovered.push(model_entry);
172 }
173 }
174 }
175 }
176 }
177
178 self.discovered_models = discovered.clone();
179 Ok(discovered)
180 }
181
182 async fn inspect_model_dir(&self, path: &Path) -> Option<ModelDiscoveryEntry> {
184 let config_path = path.join("config.json");
186 if !config_path.exists() {
187 debug!("No config.json in: {:?}", path);
188 return None;
189 }
190
191 let format = self.detect_model_format(path);
193 if format == ModelFormatType::Unknown {
194 debug!("Unknown format in: {:?}", path);
195 return None;
196 }
197
198 debug!("Found valid model at: {:?}, format: {:?}", path, format);
199
200 let architecture = self.read_architecture(&config_path);
202
203 let id = if let Some(parent) = path.parent() {
205 if let Some(grandparent) = parent.parent() {
206 if let Some(name) = grandparent.file_name().and_then(|n| n.to_str()) {
208 if name.starts_with("models--") {
209 name[8..].replace("--", "/")
210 } else {
211 path.file_name()
212 .and_then(|n| n.to_str())
213 .unwrap_or("unknown")
214 .to_string()
215 }
216 } else {
217 path.file_name()
218 .and_then(|n| n.to_str())
219 .unwrap_or("unknown")
220 .to_string()
221 }
222 } else {
223 path.file_name()
224 .and_then(|n| n.to_str())
225 .unwrap_or("unknown")
226 .to_string()
227 }
228 } else {
229 path.file_name()
230 .and_then(|n| n.to_str())
231 .unwrap_or("unknown")
232 .to_string()
233 };
234
235 Some(ModelDiscoveryEntry {
236 id,
237 path: path.to_path_buf(),
238 format,
239 architecture,
240 is_valid: true,
241 })
242 }
243
244 fn detect_model_format(&self, path: &Path) -> ModelFormatType {
246 if path.join("model.safetensors").exists()
247 || path.join("model.safetensors.index.json").exists()
248 {
249 ModelFormatType::SafeTensors
250 } else if path.join("pytorch_model.bin").exists()
251 || path.join("pytorch_model.bin.index.json").exists()
252 {
253 ModelFormatType::PyTorch
254 } else if std::fs::read_dir(path)
255 .ok()
256 .and_then(|entries| {
257 entries
258 .filter_map(|e| e.ok())
259 .find(|e| e.path().extension().and_then(|s| s.to_str()) == Some("gguf"))
260 })
261 .is_some()
262 {
263 ModelFormatType::GGUF
264 } else {
265 ModelFormatType::Unknown
266 }
267 }
268
269 fn read_architecture(&self, config_path: &Path) -> Option<Architecture> {
271 let content = std::fs::read_to_string(config_path).ok()?;
272 let config: serde_json::Value = serde_json::from_str(&content).ok()?;
273
274 if let Some(model_type) = config.get("model_type").and_then(|v| v.as_str()) {
276 return Some(Architecture::from_str(model_type));
277 }
278
279 if let Some(architectures) = config.get("architectures").and_then(|v| v.as_array()) {
281 if let Some(arch) = architectures.first().and_then(|v| v.as_str()) {
282 return Some(Architecture::from_str(arch));
283 }
284 }
285
286 None
287 }
288}
289
290impl Default for DefaultModelRegistry {
291 fn default() -> Self {
292 Self::new()
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_architecture_from_str() {
        assert_eq!(Architecture::from_str("llama"), Architecture::Llama);
        assert_eq!(
            Architecture::from_str("LlamaForCausalLM"),
            Architecture::Llama
        );
        assert_eq!(Architecture::from_str("qwen2"), Architecture::Qwen2);
        assert_eq!(Architecture::from_str("mistral"), Architecture::Mistral);
        assert_eq!(Architecture::from_str("phi"), Architecture::Phi);
        assert_eq!(Architecture::from_str("gpt2"), Architecture::GPT2);
        assert_eq!(
            Architecture::from_str("unknown_arch"),
            Architecture::Unknown
        );
    }

    #[test]
    fn test_architecture_from_str_case_insensitive_class_names() {
        assert_eq!(Architecture::from_str("QWEN3"), Architecture::Qwen3);
        assert_eq!(
            Architecture::from_str("Qwen3ForCausalLM"),
            Architecture::Qwen3
        );
        assert_eq!(
            Architecture::from_str("Qwen2ForCausalLM"),
            Architecture::Qwen2
        );
        assert_eq!(
            Architecture::from_str("MistralForCausalLM"),
            Architecture::Mistral
        );
        assert_eq!(Architecture::from_str("PhiForCausalLM"), Architecture::Phi);
        assert_eq!(
            Architecture::from_str("GPT2LMHeadModel"),
            Architecture::GPT2
        );
    }

    #[test]
    fn test_architecture_from_str_encoder_families() {
        for name in [
            "bert",
            "BertModel",
            "BertForMaskedLM",
            "BertForSequenceClassification",
        ] {
            assert_eq!(Architecture::from_str(name), Architecture::Bert, "{}", name);
        }
        for name in [
            "clip",
            "CLIPModel",
            "chinese_clip",
            "ChineseCLIPModel",
            "siglip",
            "SiglipModel",
        ] {
            assert_eq!(Architecture::from_str(name), Architecture::Clip, "{}", name);
        }
    }

    #[test]
    fn test_architecture_from_str_empty_is_unknown() {
        assert_eq!(Architecture::from_str(""), Architecture::Unknown);
    }

    #[test]
    fn test_architecture_copy() {
        let arch = Architecture::Llama;
        let arch2 = arch;
        assert_eq!(arch, arch2);
    }

    #[test]
    fn test_model_format_type_eq() {
        assert_eq!(ModelFormatType::SafeTensors, ModelFormatType::SafeTensors);
        assert_ne!(ModelFormatType::SafeTensors, ModelFormatType::PyTorch);
    }

    #[test]
    fn test_model_alias_creation() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: Some("Test model".to_string()),
        };

        assert_eq!(alias.name, "test");
        assert_eq!(alias.target, "test/model");
        assert!(alias.description.is_some());
    }

    #[test]
    fn test_model_alias_clone() {
        let alias = ModelAlias {
            name: "test".to_string(),
            target: "test/model".to_string(),
            description: None,
        };

        let cloned = alias.clone();
        assert_eq!(alias.name, cloned.name);
        assert_eq!(alias.target, cloned.target);
    }

    #[test]
    fn test_model_discovery_entry() {
        let entry = ModelDiscoveryEntry {
            id: "test-model".to_string(),
            path: PathBuf::from("/path/to/model"),
            format: ModelFormatType::SafeTensors,
            architecture: Some(Architecture::Llama),
            is_valid: true,
        };

        assert_eq!(entry.id, "test-model");
        assert_eq!(entry.format, ModelFormatType::SafeTensors);
        assert!(entry.is_valid);
    }

    #[test]
    fn test_registry_creation() {
        let registry = DefaultModelRegistry::new();
        assert!(registry.aliases.is_empty());
        assert!(registry.discovered_models.is_empty());
    }

    #[test]
    fn test_registry_default() {
        let registry = DefaultModelRegistry::default();
        assert!(registry.aliases.is_empty());
    }

    #[test]
    fn test_registry_with_defaults() {
        let registry = DefaultModelRegistry::with_defaults();

        assert!(!registry.aliases.is_empty());

        assert!(registry.aliases.contains_key("tinyllama"));
        assert!(registry.aliases.contains_key("llama2-7b"));
    }

    #[test]
    fn test_registry_with_defaults_resolves_known_aliases() {
        let registry = DefaultModelRegistry::with_defaults();

        assert_eq!(
            registry.resolve_model_id("tinyllama"),
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        );
        assert_eq!(registry.resolve_model_id("qwen3-0.6b"), "Qwen/Qwen3-0.6B");
        assert_eq!(
            registry.resolve_model_id("phi3-mini"),
            "microsoft/Phi-3-mini-4k-instruct"
        );
    }

    #[test]
    fn test_registry_register_alias() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "test/model");

        assert_eq!(
            registry.aliases.get("test"),
            Some(&"test/model".to_string())
        );
    }

    #[test]
    fn test_registry_register_alias_overwrites() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("test", "org/first");
        registry.register_alias("test", "org/second");

        assert_eq!(registry.aliases.len(), 1);
        assert_eq!(registry.resolve_model_id("test"), "org/second");
    }

    #[test]
    fn test_registry_add_alias() {
        let mut registry = DefaultModelRegistry::new();
        let alias = ModelAlias {
            name: "short".to_string(),
            target: "org/long-name".to_string(),
            description: Some("desc".to_string()),
        };

        assert!(registry.add_alias(alias).is_ok());
        assert_eq!(registry.resolve_model_id("short"), "org/long-name");
    }

    #[test]
    fn test_registry_resolve_model_id() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("mymodel", "org/actual-model");

        let resolved = registry.resolve_model_id("mymodel");
        assert_eq!(resolved, "org/actual-model");

        // Names that are not aliases pass through unchanged.
        let unresolved = registry.resolve_model_id("unknown");
        assert_eq!(unresolved, "unknown");
    }

    #[test]
    fn test_registry_list_aliases() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("model1", "org/model1");
        registry.register_alias("model2", "org/model2");

        let aliases = registry.list_aliases();
        assert_eq!(aliases.len(), 2);
        assert!(aliases.iter().all(|a| a.description.is_none()));

        // Sort locally so this test does not depend on the listing order.
        let mut pairs: Vec<(String, String)> = aliases
            .into_iter()
            .map(|a| (a.name, a.target))
            .collect();
        pairs.sort();
        assert_eq!(
            pairs,
            vec![
                ("model1".to_string(), "org/model1".to_string()),
                ("model2".to_string(), "org/model2".to_string()),
            ]
        );
    }

    #[test]
    fn test_architecture_debug() {
        let arch = Architecture::Llama;
        let debug_str = format!("{:?}", arch);
        assert!(debug_str.contains("Llama"));
    }

    #[test]
    fn test_model_format_debug() {
        let format = ModelFormatType::SafeTensors;
        let debug_str = format!("{:?}", format);
        assert!(debug_str.contains("SafeTensors"));
    }

    #[test]
    fn test_model_discovery_entry_clone() {
        let entry = ModelDiscoveryEntry {
            id: "test".to_string(),
            path: PathBuf::from("/path"),
            format: ModelFormatType::GGUF,
            architecture: Some(Architecture::Mistral),
            is_valid: false,
        };

        let cloned = entry.clone();
        assert_eq!(entry.id, cloned.id);
        assert_eq!(entry.format, cloned.format);
        assert_eq!(entry.is_valid, cloned.is_valid);
    }

    #[test]
    fn test_registry_multiple_aliases_same_target() {
        let mut registry = DefaultModelRegistry::new();

        registry.register_alias("alias1", "org/model");
        registry.register_alias("alias2", "org/model");

        assert_eq!(registry.resolve_model_id("alias1"), "org/model");
        assert_eq!(registry.resolve_model_id("alias2"), "org/model");
    }

    #[test]
    fn test_architecture_serialization() {
        let arch = Architecture::Qwen2;
        let json = serde_json::to_string(&arch).unwrap();
        assert!(json.contains("Qwen2"));

        let deserialized: Architecture = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, arch);
    }
}