// candle_coreml/unified_model_loader.rs
use crate::config::model::ModelConfig;
7use crate::download::unified::ensure_model_downloaded;
8use crate::{CacheManager, ConfigGenerator, QwenConfig, QwenModel};
9use anyhow::Result;
10use serde_json::Value;
11use std::path::Path;
12use tracing::{debug, info};
13
/// High-level entry point that ties together the model cache, config
/// generation, and `QwenModel` construction: callers hand it a model id and
/// get back a ready-to-use model.
pub struct UnifiedModelLoader {
    // Owns the on-disk cache layout (models dir, configs dir).
    cache_manager: CacheManager,
    /// Generates `ModelConfig`s from model directories and caches them;
    /// exposed publicly so callers can drive config generation directly.
    pub config_generator: ConfigGenerator,
}
19
20impl UnifiedModelLoader {
21 pub fn new() -> Result<Self> {
23 let cache_manager = CacheManager::new()?;
24 let config_generator = ConfigGenerator::new()?;
25
26 Ok(Self {
27 cache_manager,
28 config_generator,
29 })
30 }
31
32 pub fn load_model(&self, model_id: &str) -> Result<QwenModel> {
45 info!("🚀 Loading model: {}", model_id);
46
47 if let Some(cached_config) = self.config_generator.load_cached_config(model_id)? {
49 info!("📖 Found cached config for {}", model_id);
50
51 if self.verify_model_files_exist(&cached_config) {
53 let valid_basic = cached_config.validate();
55 let valid_wiring = cached_config.validate_internal_wiring();
56 if valid_basic.is_ok() && valid_wiring.is_ok() {
57 if let Some(model_path_str) = &cached_config.model_info.path {
59 let looks_like_hf_snapshot = model_path_str.contains("/huggingface/hub/")
60 || model_path_str.contains("/snapshots/");
61 if looks_like_hf_snapshot {
62 info!(
63 "♻️ Cached config points to HF snapshot; regenerating config from clean download"
64 );
65 let clean_path = self.ensure_model_available(model_id)?;
66 let config = self
67 .config_generator
68 .generate_config_from_directory_enhanced(
69 &clean_path,
70 model_id,
71 "qwen",
72 )?;
73 return self.load_model_from_config(&config);
74 }
75 }
76
77 if self.config_requires_ffn_split_upgrade(&cached_config) {
79 info!(
80 "♻️ Cached config lacks 'ffn_infer' but FFN manifest has both functions; regenerating config"
81 );
82 if let Some(model_path_str) = &cached_config.model_info.path {
83 let model_path = std::path::PathBuf::from(model_path_str);
84 if model_path.exists() {
85 let config = self
86 .config_generator
87 .generate_config_from_directory_enhanced(
88 &model_path,
89 model_id,
90 "qwen",
91 )?;
92 return self.load_model_from_config(&config);
93 }
94 }
95 }
96
97 info!("✅ Cached config validated, using it");
98 return self.load_model_from_config(&cached_config);
99 } else {
100 if let Err(e) = valid_basic {
102 info!("♻️ Cached config failed validation, regenerating: {e}");
103 }
104 if let Err(e) = valid_wiring {
105 info!("♻️ Cached config failed internal wiring, regenerating: {e}");
106 }
107
108 if let Some(model_path_str) = &cached_config.model_info.path {
110 let model_path = std::path::PathBuf::from(model_path_str);
111 if model_path.exists() {
112 info!(
113 "🔍 Regenerating config from existing model at {}",
114 model_path.display()
115 );
116 let config = self
117 .config_generator
118 .generate_config_from_directory_enhanced(
119 &model_path,
120 model_id,
121 "qwen",
122 )?;
123 return self.load_model_from_config(&config);
124 } else {
125 info!("⚠️ Cached model path missing, will re-download");
126 }
127 } else {
128 info!("⚠️ Cached config missing model path, will re-download");
129 }
130 }
131 } else {
132 info!("⚠️ Model files missing, will re-download");
133 }
134 }
135
136 info!(
138 "⬇️ Ensuring model is available in clean cache: {}",
139 model_id
140 );
141 let model_path = self.ensure_model_available(model_id)?;
142
143 info!("🔍 Generating config from downloaded model");
145 let config = self
146 .config_generator
147 .generate_config_from_directory_enhanced(
148 &model_path,
149 model_id,
150 "qwen", )?;
152
153 self.load_model_from_config(&config)
155 }
156
157 fn config_requires_ffn_split_upgrade(&self, config: &ModelConfig) -> bool {
160 if config.components.contains_key("ffn_infer") {
162 return false;
163 }
164
165 let ffn_component = config
167 .components
168 .iter()
169 .find(|(name, _)| name.to_lowercase().contains("ffn"))
170 .and_then(|(_, comp)| comp.file_path.as_ref());
171
172 let Some(ffn_path_str) = ffn_component else {
173 return false;
174 };
175 let ffn_path = std::path::Path::new(ffn_path_str);
176
177 let manifest_path = if ffn_path.join("Manifest.json").exists() {
179 ffn_path.join("Manifest.json")
180 } else if ffn_path.join("metadata.json").exists() {
181 ffn_path.join("metadata.json")
182 } else {
183 return false;
184 };
185
186 let Ok(content) = std::fs::read_to_string(&manifest_path) else {
188 return false;
189 };
190 let Ok(json): Result<Value, _> = serde_json::from_str(&content) else {
191 return false;
192 };
193
194 let funcs = json
196 .get(0)
197 .and_then(|m| m.get("functions"))
198 .and_then(|f| f.as_array());
199
200 if let Some(functions) = funcs {
201 let mut has_prefill = false;
202 let mut has_infer = false;
203 for f in functions {
204 if let Some(name) = f.get("name").and_then(|n| n.as_str()) {
205 if name == "prefill" {
206 has_prefill = true;
207 } else if name == "infer" {
208 has_infer = true;
209 }
210 }
211 }
212 return has_prefill && has_infer;
214 }
215
216 false
217 }
218
219 pub fn load_model_from_config(&self, config: &ModelConfig) -> Result<QwenModel> {
221 info!("🔧 Loading model from config");
222
223 let qwen_config = QwenConfig::from_model_config(config.clone());
225
226 let model_dir = config
228 .model_info
229 .path
230 .as_ref()
231 .ok_or_else(|| anyhow::Error::msg("Model config missing path"))?;
232
233 let mut model = QwenModel::load_from_directory(model_dir, Some(qwen_config))?;
235 model.initialize_states()?;
236
237 info!("✅ Model loaded successfully");
238 Ok(model)
239 }
240
    /// Ensures the model is present in the clean (non-HF-hub) cache,
    /// downloading it if necessary, and returns its local directory.
    // NOTE(review): the `false` argument presumably disables forced
    // re-download — confirm against `ensure_model_downloaded`'s signature.
    pub fn ensure_model_available(&self, model_id: &str) -> Result<std::path::PathBuf> {
        ensure_model_downloaded(model_id, false)
    }
245
246 pub fn generate_config(&self, model_id: &str) -> Result<ModelConfig> {
248 let model_path = self.ensure_model_available(model_id)?;
249
250 self.config_generator
251 .generate_config_from_directory_enhanced(&model_path, model_id, "qwen")
252 }
253
254 pub fn list_cached_models(&self) -> Result<Vec<CachedModelInfo>> {
256 let models_dir = self.cache_manager.models_dir();
257 let configs_dir = self.cache_manager.configs_dir();
258
259 let mut cached_models = Vec::new();
260
261 if models_dir.exists() {
263 for entry in std::fs::read_dir(&models_dir)? {
264 let entry = entry?;
265 if entry.file_type()?.is_dir() {
266 let model_name = entry.file_name().to_string_lossy().to_string();
267 let model_id = model_name.replace("--", "/"); let config_path = configs_dir.join(format!("{model_name}.json"));
270 let has_config = config_path.exists();
271
272 let model_files = self.count_mlpackage_files(&entry.path())?;
274
275 cached_models.push(CachedModelInfo {
276 model_id,
277 model_path: entry.path(),
278 has_config,
279 config_path: if has_config { Some(config_path) } else { None },
280 mlpackage_count: model_files,
281 size_bytes: self.get_directory_size(&entry.path())?,
282 });
283 }
284 }
285 }
286
287 cached_models.sort_by(|a, b| a.model_id.cmp(&b.model_id));
289 Ok(cached_models)
290 }
291
292 fn verify_model_files_exist(&self, config: &ModelConfig) -> bool {
294 for (component_name, component) in &config.components {
295 match &component.file_path {
296 Some(file_path) => {
297 let path = Path::new(file_path);
298 if !path.exists() {
299 debug!("Component '{}' file missing: {}", component_name, file_path);
300 return false;
301 }
302 }
303 None => {
304 debug!(
306 "Component '{}' missing file_path in cached config; regeneration required",
307 component_name
308 );
309 return false;
310 }
311 }
312 }
313 true
314 }
315
316 fn count_mlpackage_files(&self, dir: &Path) -> Result<usize> {
318 let mut count = 0;
319
320 for entry in std::fs::read_dir(dir)? {
321 let entry = entry?;
322 if entry.file_type()?.is_dir() {
323 if let Some(extension) = entry.path().extension() {
324 if extension == "mlpackage" {
325 count += 1;
326 }
327 }
328 }
329 }
330
331 Ok(count)
332 }
333
334 fn get_directory_size(&self, dir: &Path) -> Result<u64> {
336 let mut total_size = 0;
337 Self::visit_dir_size(dir, &mut total_size)?;
338 Ok(total_size)
339 }
340
341 fn visit_dir_size(dir: &Path, total: &mut u64) -> Result<()> {
342 for entry in std::fs::read_dir(dir)? {
343 let entry = entry?;
344 let path = entry.path();
345
346 if path.is_dir() {
347 Self::visit_dir_size(&path, total)?;
348 } else {
349 *total += entry.metadata()?.len();
350 }
351 }
352 Ok(())
353 }
354}
355
/// Snapshot of one cached model as reported by
/// `UnifiedModelLoader::list_cached_models`.
#[derive(Debug, Clone)]
pub struct CachedModelInfo {
    /// Model id in "org/name" form (decoded from the cache dir name).
    pub model_id: String,
    /// Directory holding the model's packages.
    pub model_path: std::path::PathBuf,
    /// Whether a generated config JSON exists for this model.
    pub has_config: bool,
    /// Path to the config JSON; `Some` only when `has_config` is true.
    pub config_path: Option<std::path::PathBuf>,
    /// Number of `.mlpackage` bundles directly inside `model_path`.
    pub mlpackage_count: usize,
    /// Total recursive size of `model_path` in bytes.
    pub size_bytes: u64,
}
366
367impl CachedModelInfo {
368 pub fn size_human(&self) -> String {
370 let size = self.size_bytes as f64;
371
372 if size >= 1_000_000_000.0 {
373 format!("{:.1} GB", size / 1_000_000_000.0)
374 } else if size >= 1_000_000.0 {
375 format!("{:.1} MB", size / 1_000_000.0)
376 } else if size >= 1_000.0 {
377 format!("{:.1} KB", size / 1_000.0)
378 } else {
379 format!("{} B", size as u64)
380 }
381 }
382
383 pub fn is_complete(&self) -> bool {
385 self.has_config && self.mlpackage_count > 0
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the loader constructs and can enumerate the local cache.
    /// Note: this reads the real cache directory, so the printed results
    /// depend on the host environment.
    #[test]
    fn test_unified_loader_creation() {
        let loader = UnifiedModelLoader::new().expect("Failed to create unified loader");

        let models = loader
            .list_cached_models()
            .expect("Failed to list cached models");
        println!("Found {} cached models", models.len());

        for model in &models {
            let status = if model.is_complete() {
                "complete"
            } else {
                "incomplete"
            };
            println!(
                " • {} ({}, {} packages, {})",
                model.model_id,
                model.size_human(),
                model.mlpackage_count,
                status
            );
        }
    }

    /// `size_human` and `is_complete` behave as expected on a synthetic
    /// 1.5 GB entry with a config and four packages.
    #[test]
    fn test_cached_model_info() {
        let info = CachedModelInfo {
            model_id: String::from("test/model"),
            model_path: std::path::PathBuf::from("/tmp/test"),
            has_config: true,
            config_path: Some(std::path::PathBuf::from("/tmp/test.json")),
            mlpackage_count: 4,
            size_bytes: 1_500_000_000,
        };

        assert_eq!(info.size_human(), "1.5 GB");
        assert!(info.is_complete());
    }
}