candle_coreml/
unified_model_loader.rs

1//! Unified model loader that combines downloading, config generation, and model loading
2//!
3//! This module provides a simplified API that replaces hardcoded paths with
4//! automatic HuggingFace downloading and config generation.
5
6use crate::config::model::ModelConfig;
7use crate::download::unified::ensure_model_downloaded;
8use crate::{CacheManager, ConfigGenerator, QwenConfig, QwenModel};
9use anyhow::Result;
10use serde_json::Value;
11use std::path::Path;
12use tracing::{debug, info};
13
/// Unified model loader that handles downloading, config generation, and model loading
pub struct UnifiedModelLoader {
    // Owns the on-disk cache layout (models dir / configs dir); used when
    // scanning cached models.
    cache_manager: CacheManager,
    // Generates and caches per-model configs. Public so callers can drive
    // config generation directly for advanced workflows.
    pub config_generator: ConfigGenerator,
}
19
20impl UnifiedModelLoader {
21    /// Create a new unified model loader
22    pub fn new() -> Result<Self> {
23        let cache_manager = CacheManager::new()?;
24        let config_generator = ConfigGenerator::new()?;
25
26        Ok(Self {
27            cache_manager,
28            config_generator,
29        })
30    }
31
32    /// Load a model by HuggingFace model ID with automatic downloading and config generation
33    ///
34    /// This replaces the pattern of hardcoded paths in config files.
35    ///
36    /// # Example
37    /// ```rust,no_run
38    /// use candle_coreml::UnifiedModelLoader;
39    ///
40    /// let loader = UnifiedModelLoader::new()?;
41    /// let _model = loader.load_model("mazhewitt/qwen-typo-fixer-coreml")?;
42    /// # Ok::<(), Box<dyn std::error::Error>>(())
43    /// ```
44    pub fn load_model(&self, model_id: &str) -> Result<QwenModel> {
45        info!("🚀 Loading model: {}", model_id);
46
47        // Step 1: Check if we have a cached config
48        if let Some(cached_config) = self.config_generator.load_cached_config(model_id)? {
49            info!("📖 Found cached config for {}", model_id);
50
51            // Verify the model files still exist
52            if self.verify_model_files_exist(&cached_config) {
53                // Validate config and internal wiring; regenerate if invalid or inconsistent
54                let valid_basic = cached_config.validate();
55                let valid_wiring = cached_config.validate_internal_wiring();
56                if valid_basic.is_ok() && valid_wiring.is_ok() {
57                    // If the cached config points to a Hugging Face snapshot path, upgrade it to our clean cache
58                    if let Some(model_path_str) = &cached_config.model_info.path {
59                        let looks_like_hf_snapshot = model_path_str.contains("/huggingface/hub/")
60                            || model_path_str.contains("/snapshots/");
61                        if looks_like_hf_snapshot {
62                            info!(
63                                "♻️  Cached config points to HF snapshot; regenerating config from clean download"
64                            );
65                            let clean_path = self.ensure_model_available(model_id)?;
66                            let config = self
67                                .config_generator
68                                .generate_config_from_directory_enhanced(
69                                    &clean_path,
70                                    model_id,
71                                    "qwen",
72                                )?;
73                            return self.load_model_from_config(&config);
74                        }
75                    }
76
77                    // Extra: if FFN package exposes both prefill & infer functions but config lacks ffn_infer, regenerate
78                    if self.config_requires_ffn_split_upgrade(&cached_config) {
79                        info!(
80                            "♻️  Cached config lacks 'ffn_infer' but FFN manifest has both functions; regenerating config"
81                        );
82                        if let Some(model_path_str) = &cached_config.model_info.path {
83                            let model_path = std::path::PathBuf::from(model_path_str);
84                            if model_path.exists() {
85                                let config = self
86                                    .config_generator
87                                    .generate_config_from_directory_enhanced(
88                                        &model_path,
89                                        model_id,
90                                        "qwen",
91                                    )?;
92                                return self.load_model_from_config(&config);
93                            }
94                        }
95                    }
96
97                    info!("✅ Cached config validated, using it");
98                    return self.load_model_from_config(&cached_config);
99                } else {
100                    // Log why we are regenerating
101                    if let Err(e) = valid_basic {
102                        info!("♻️  Cached config failed validation, regenerating: {e}");
103                    }
104                    if let Err(e) = valid_wiring {
105                        info!("♻️  Cached config failed internal wiring, regenerating: {e}");
106                    }
107
108                    // Regenerate from existing model directory if available
109                    if let Some(model_path_str) = &cached_config.model_info.path {
110                        let model_path = std::path::PathBuf::from(model_path_str);
111                        if model_path.exists() {
112                            info!(
113                                "🔍 Regenerating config from existing model at {}",
114                                model_path.display()
115                            );
116                            let config = self
117                                .config_generator
118                                .generate_config_from_directory_enhanced(
119                                    &model_path,
120                                    model_id,
121                                    "qwen",
122                                )?;
123                            return self.load_model_from_config(&config);
124                        } else {
125                            info!("⚠️  Cached model path missing, will re-download");
126                        }
127                    } else {
128                        info!("⚠️  Cached config missing model path, will re-download");
129                    }
130                }
131            } else {
132                info!("⚠️  Model files missing, will re-download");
133            }
134        }
135
136        // Step 2: Ensure the model is available in our clean cache (handles download if missing)
137        info!(
138            "⬇️  Ensuring model is available in clean cache: {}",
139            model_id
140        );
141        let model_path = self.ensure_model_available(model_id)?;
142
143        // Step 3: Generate config from downloaded files
144        info!("🔍 Generating config from downloaded model");
145        let config = self
146            .config_generator
147            .generate_config_from_directory_enhanced(
148                &model_path,
149                model_id,
150                "qwen", // Auto-detect this in the future
151            )?;
152
153        // Step 4: Load the model using the generated config
154        self.load_model_from_config(&config)
155    }
156
157    /// Determine if a cached config should be upgraded to include a separate ffn_infer
158    /// component by inspecting the FFN package manifest for both 'prefill' and 'infer' functions.
159    fn config_requires_ffn_split_upgrade(&self, config: &ModelConfig) -> bool {
160        // If ffn_infer already exists, nothing to do
161        if config.components.contains_key("ffn_infer") {
162            return false;
163        }
164
165        // Look for any FFN component file path to inspect its manifest
166        let ffn_component = config
167            .components
168            .iter()
169            .find(|(name, _)| name.to_lowercase().contains("ffn"))
170            .and_then(|(_, comp)| comp.file_path.as_ref());
171
172        let Some(ffn_path_str) = ffn_component else {
173            return false;
174        };
175        let ffn_path = std::path::Path::new(ffn_path_str);
176
177        // Determine manifest path (.mlpackage -> Manifest.json, .mlmodelc -> metadata.json)
178        let manifest_path = if ffn_path.join("Manifest.json").exists() {
179            ffn_path.join("Manifest.json")
180        } else if ffn_path.join("metadata.json").exists() {
181            ffn_path.join("metadata.json")
182        } else {
183            return false;
184        };
185
186        // Read and parse manifest
187        let Ok(content) = std::fs::read_to_string(&manifest_path) else {
188            return false;
189        };
190        let Ok(json): Result<Value, _> = serde_json::from_str(&content) else {
191            return false;
192        };
193
194        // Extract functions array
195        let funcs = json
196            .get(0)
197            .and_then(|m| m.get("functions"))
198            .and_then(|f| f.as_array());
199
200        if let Some(functions) = funcs {
201            let mut has_prefill = false;
202            let mut has_infer = false;
203            for f in functions {
204                if let Some(name) = f.get("name").and_then(|n| n.as_str()) {
205                    if name == "prefill" {
206                        has_prefill = true;
207                    } else if name == "infer" {
208                        has_infer = true;
209                    }
210                }
211            }
212            // If both are present but config lacks ffn_infer, we should regenerate
213            return has_prefill && has_infer;
214        }
215
216        false
217    }
218
219    /// Load a model from a pre-existing config (useful for advanced use cases)
220    pub fn load_model_from_config(&self, config: &ModelConfig) -> Result<QwenModel> {
221        info!("🔧 Loading model from config");
222
223        // Convert ModelConfig to QwenConfig
224        let qwen_config = QwenConfig::from_model_config(config.clone());
225
226        // Extract the model directory from the config
227        let model_dir = config
228            .model_info
229            .path
230            .as_ref()
231            .ok_or_else(|| anyhow::Error::msg("Model config missing path"))?;
232
233        // Load the QwenModel
234        let mut model = QwenModel::load_from_directory(model_dir, Some(qwen_config))?;
235        model.initialize_states()?;
236
237        info!("✅ Model loaded successfully");
238        Ok(model)
239    }
240
241    /// Ensure model is downloaded and return the path (useful for external tools)
242    pub fn ensure_model_available(&self, model_id: &str) -> Result<std::path::PathBuf> {
243        ensure_model_downloaded(model_id, false)
244    }
245
246    /// Generate or update config for a model without loading it
247    pub fn generate_config(&self, model_id: &str) -> Result<ModelConfig> {
248        let model_path = self.ensure_model_available(model_id)?;
249
250        self.config_generator
251            .generate_config_from_directory_enhanced(&model_path, model_id, "qwen")
252    }
253
254    /// List all cached models and their status
255    pub fn list_cached_models(&self) -> Result<Vec<CachedModelInfo>> {
256        let models_dir = self.cache_manager.models_dir();
257        let configs_dir = self.cache_manager.configs_dir();
258
259        let mut cached_models = Vec::new();
260
261        // Scan models directory
262        if models_dir.exists() {
263            for entry in std::fs::read_dir(&models_dir)? {
264                let entry = entry?;
265                if entry.file_type()?.is_dir() {
266                    let model_name = entry.file_name().to_string_lossy().to_string();
267                    let model_id = model_name.replace("--", "/"); // Convert back from filename
268
269                    let config_path = configs_dir.join(format!("{model_name}.json"));
270                    let has_config = config_path.exists();
271
272                    // Check if .mlpackage files exist
273                    let model_files = self.count_mlpackage_files(&entry.path())?;
274
275                    cached_models.push(CachedModelInfo {
276                        model_id,
277                        model_path: entry.path(),
278                        has_config,
279                        config_path: if has_config { Some(config_path) } else { None },
280                        mlpackage_count: model_files,
281                        size_bytes: self.get_directory_size(&entry.path())?,
282                    });
283                }
284            }
285        }
286
287        // Sort by model ID for consistent output
288        cached_models.sort_by(|a, b| a.model_id.cmp(&b.model_id));
289        Ok(cached_models)
290    }
291
292    /// Verify that all model files referenced in config still exist
293    fn verify_model_files_exist(&self, config: &ModelConfig) -> bool {
294        for (component_name, component) in &config.components {
295            match &component.file_path {
296                Some(file_path) => {
297                    let path = Path::new(file_path);
298                    if !path.exists() {
299                        debug!("Component '{}' file missing: {}", component_name, file_path);
300                        return false;
301                    }
302                }
303                None => {
304                    // Missing file_path makes the config unusable for model loading
305                    debug!(
306                        "Component '{}' missing file_path in cached config; regeneration required",
307                        component_name
308                    );
309                    return false;
310                }
311            }
312        }
313        true
314    }
315
316    /// Count .mlpackage files in a directory
317    fn count_mlpackage_files(&self, dir: &Path) -> Result<usize> {
318        let mut count = 0;
319
320        for entry in std::fs::read_dir(dir)? {
321            let entry = entry?;
322            if entry.file_type()?.is_dir() {
323                if let Some(extension) = entry.path().extension() {
324                    if extension == "mlpackage" {
325                        count += 1;
326                    }
327                }
328            }
329        }
330
331        Ok(count)
332    }
333
334    /// Get directory size in bytes
335    fn get_directory_size(&self, dir: &Path) -> Result<u64> {
336        let mut total_size = 0;
337        Self::visit_dir_size(dir, &mut total_size)?;
338        Ok(total_size)
339    }
340
341    fn visit_dir_size(dir: &Path, total: &mut u64) -> Result<()> {
342        for entry in std::fs::read_dir(dir)? {
343            let entry = entry?;
344            let path = entry.path();
345
346            if path.is_dir() {
347                Self::visit_dir_size(&path, total)?;
348            } else {
349                *total += entry.metadata()?.len();
350            }
351        }
352        Ok(())
353    }
354}
355
/// Information about a cached model
#[derive(Debug, Clone)]
pub struct CachedModelInfo {
    pub model_id: String,
    pub model_path: std::path::PathBuf,
    pub has_config: bool,
    pub config_path: Option<std::path::PathBuf>,
    pub mlpackage_count: usize,
    pub size_bytes: u64,
}

impl CachedModelInfo {
    /// Get human-readable size
    ///
    /// Formats with decimal (SI) units: GB, MB, KB, or raw bytes.
    pub fn size_human(&self) -> String {
        const KB: f64 = 1_000.0;
        const MB: f64 = 1_000_000.0;
        const GB: f64 = 1_000_000_000.0;

        let bytes = self.size_bytes as f64;
        if bytes >= GB {
            format!("{:.1} GB", bytes / GB)
        } else if bytes >= MB {
            format!("{:.1} MB", bytes / MB)
        } else if bytes >= KB {
            format!("{:.1} KB", bytes / KB)
        } else {
            format!("{} B", self.size_bytes)
        }
    }

    /// Check if the model appears to be complete
    ///
    /// Complete means a config exists and at least one .mlpackage was found.
    pub fn is_complete(&self) -> bool {
        self.mlpackage_count > 0 && self.has_config
    }
}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the loader constructs and can enumerate the local cache.
    #[test]
    fn test_unified_loader_creation() {
        let loader = UnifiedModelLoader::new().expect("Failed to create unified loader");

        // Should have valid cache manager and config generator
        let cached = loader
            .list_cached_models()
            .expect("Failed to list cached models");
        println!("Found {} cached models", cached.len());

        for entry in &cached {
            let status = if entry.is_complete() {
                "complete"
            } else {
                "incomplete"
            };
            println!(
                "  • {} ({}, {} packages, {})",
                entry.model_id,
                entry.size_human(),
                entry.mlpackage_count,
                status
            );
        }
    }

    /// Unit test for the `CachedModelInfo` helpers.
    #[test]
    fn test_cached_model_info() {
        let info = CachedModelInfo {
            model_id: String::from("test/model"),
            model_path: "/tmp/test".into(),
            has_config: true,
            config_path: Some("/tmp/test.json".into()),
            mlpackage_count: 4,
            size_bytes: 1_500_000_000,
        };

        assert!(info.is_complete());
        assert_eq!(info.size_human(), "1.5 GB");
    }
}