Skip to main content

ferrum_models/
source.rs

1//! Model source resolution and downloading with progress tracking
2
3use ferrum_types::{FerrumError, ModelSource, Result};
4use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
5use std::path::{Path, PathBuf};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tracing::{debug, info, warn};
10
/// Configuration for model source resolution
#[derive(Debug, Clone)]
pub struct ModelSourceConfig {
    /// Cache directory handed to the HF hub API builder
    /// (`None` = the hub client's own default location).
    pub cache_dir: Option<PathBuf>,
    /// HuggingFace access token forwarded to the hub client
    /// (needed for gated/private repos).
    pub hf_token: Option<String>,
    /// NOTE(review): not read anywhere in this file — presumably consumed
    /// by callers; confirm intended use.
    pub offline_mode: bool,
    /// Retry budget for downloads. NOTE(review): unused in this file.
    pub max_retries: usize,
    /// Download timeout (presumably seconds; defaults to 300).
    /// NOTE(review): unused in this file.
    pub download_timeout: u64,
    /// Whether to guard downloads with a file lock.
    /// NOTE(review): unused in this file.
    pub use_file_lock: bool,
}
21
22impl Default for ModelSourceConfig {
23    fn default() -> Self {
24        // Use HuggingFace standard cache directory
25        let default_cache = std::env::var("HF_HOME")
26            .ok()
27            .or_else(|| {
28                dirs::home_dir()
29                    .map(|h| h.join(".cache/huggingface"))
30                    .and_then(|p| p.to_str().map(String::from))
31            })
32            .map(PathBuf::from);
33
34        Self {
35            cache_dir: default_cache,
36            hf_token: Self::get_hf_token(),
37            offline_mode: false,
38            max_retries: 3,
39            download_timeout: 300,
40            use_file_lock: true,
41        }
42    }
43}
44
45impl ModelSourceConfig {
46    pub fn get_hf_token() -> Option<String> {
47        std::env::var("HF_TOKEN")
48            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
49            .ok()
50    }
51}
52
/// Weight serialization formats a resolved model may use.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormat {
    /// `model.safetensors` or sharded weights listed by
    /// `model.safetensors.index.json`.
    SafeTensors,
    /// Legacy `pytorch_model.bin` checkpoint.
    PyTorchBin,
    /// GGUF quantized weights.
    GGUF,
    /// No recognized weight files were found.
    Unknown,
}
60
/// Outcome of resolving a model identifier: where the files live on disk
/// and which weight format they use.
#[derive(Debug, Clone)]
pub struct ResolvedModelSource {
    /// The identifier exactly as the caller supplied it (path or repo id).
    pub original: String,
    /// Local directory containing the model files.
    pub local_path: PathBuf,
    /// Detected weight format.
    pub format: ModelFormat,
    /// `true` for local paths; HuggingFace resolutions always report
    /// `false`, even when hf-hub served every file from its cache.
    pub from_cache: bool,
}
68
69impl From<ResolvedModelSource> for ModelSource {
70    fn from(value: ResolvedModelSource) -> Self {
71        ModelSource::Local(value.local_path.display().to_string())
72    }
73}
74
/// Resolves a model identifier (local path or HuggingFace repo id) to
/// files on the local filesystem.
#[async_trait::async_trait]
pub trait ModelSourceResolver: Send + Sync {
    /// Resolve `id` (optionally pinned to `revision`) to a local source.
    async fn resolve(&self, id: &str, revision: Option<&str>) -> Result<ResolvedModelSource>;
}
79
/// Default resolver backed by the `hf-hub` API client.
pub struct DefaultModelSourceResolver {
    /// Retained for reference; only `cache_dir` and `hf_token` are
    /// consumed (at construction time, to configure the API client).
    _config: ModelSourceConfig,
    /// Configured HF hub client used for all downloads.
    api: Api,
}
84
85impl DefaultModelSourceResolver {
86    pub fn new(config: ModelSourceConfig) -> Self {
87        let mut builder = ApiBuilder::new();
88
89        if let Some(cache_dir) = &config.cache_dir {
90            builder = builder.with_cache_dir(cache_dir.clone());
91        }
92
93        if let Some(token) = &config.hf_token {
94            builder = builder.with_token(Some(token.clone()));
95        }
96
97        let api = builder.build().unwrap_or_else(|e| {
98            warn!("Failed to build HF API: {}, using default", e);
99            Api::new().expect("Failed to create default HF API")
100        });
101
102        Self {
103            _config: config,
104            api,
105        }
106    }
107
108    fn is_local_path(id: &str) -> bool {
109        Path::new(id).exists()
110    }
111
112    fn detect_format(path: &Path) -> ModelFormat {
113        if path.join("model.safetensors").exists()
114            || path.join("model.safetensors.index.json").exists()
115        {
116            ModelFormat::SafeTensors
117        } else if path.join("pytorch_model.bin").exists() {
118            ModelFormat::PyTorchBin
119        } else {
120            ModelFormat::Unknown
121        }
122    }
123
124    async fn resolve_local(&self, path: &str) -> Result<ResolvedModelSource> {
125        let path_buf = PathBuf::from(path);
126
127        if !path_buf.exists() {
128            return Err(FerrumError::model(format!("Path does not exist: {}", path)));
129        }
130
131        let format = Self::detect_format(&path_buf);
132
133        Ok(ResolvedModelSource {
134            original: path.to_string(),
135            local_path: path_buf,
136            format,
137            from_cache: true,
138        })
139    }
140
    /// Download `filename` from `repo` while a background task polls the
    /// cache directory for a partial-download file and logs size/speed.
    ///
    /// The monitor is purely observational: it watches for `*.part` files
    /// under `expected_cache_dir` (via `find_downloading_file`) and is
    /// shut down through the shared `done` flag once the download call
    /// returns. Returns the final path of the downloaded file.
    ///
    /// # Errors
    /// Propagates any failure from the hub `get` call as a model error.
    async fn download_with_monitor(
        &self,
        repo: &ApiRepo,
        filename: &str,
        expected_cache_dir: &Path,
    ) -> Result<PathBuf> {
        info!("📥 下载中: {}...", filename);

        // Shared completion flag: the downloader sets it, the monitor polls it.
        let done = Arc::new(AtomicBool::new(false));
        let done_clone = done.clone();
        let filename_str = filename.to_string();

        // Start monitor task
        let monitor_task = tokio::spawn({
            let done = done.clone();
            let filename = filename_str.clone();
            let cache_dir = expected_cache_dir.to_path_buf();

            async move {
                // Grace period so the download has a chance to create its
                // temp file before we start looking for it.
                tokio::time::sleep(Duration::from_millis(1000)).await;

                let start_time = Instant::now();
                let mut last_size = 0u64;
                let mut last_time = Instant::now();
                let mut last_print = Instant::now();

                while !done.load(Ordering::SeqCst) {
                    // Try to find downloading file
                    if let Some(current_size) = find_downloading_file(&cache_dir, &filename) {
                        let elapsed_since_last = last_time.elapsed().as_secs_f64();

                        // Only update stats when some time has passed and
                        // the file actually grew (avoids divide-by-tiny
                        // intervals and stale readings).
                        if elapsed_since_last > 0.5 && current_size > last_size {
                            let delta = current_size - last_size;
                            let speed_mbps = delta as f64 / elapsed_since_last / 1024.0 / 1024.0;
                            let current_mb = current_size as f64 / 1024.0 / 1024.0;

                            // Only print every 2 seconds to avoid spam
                            if last_print.elapsed().as_secs() >= 2 {
                                info!(
                                    "  📊 已下载: {:.2} MB (速度: {:.1} MB/s)",
                                    current_mb, speed_mbps
                                );
                                last_print = Instant::now();
                            }

                            last_size = current_size;
                            last_time = Instant::now();
                        }
                    }

                    tokio::time::sleep(Duration::from_millis(500)).await;
                }

                // Final statistics
                // NOTE(review): `last_size` is the last observed temp-file
                // size, not necessarily the final file size.
                let total_time = start_time.elapsed().as_secs_f64();
                if last_size > 0 && total_time > 0.0 {
                    let avg_speed = last_size as f64 / total_time / 1024.0 / 1024.0;
                    info!(
                        "  ✅ 下载完成: {:.2} MB (平均速度: {:.1} MB/s, 耗时: {:.1}s)",
                        last_size as f64 / 1024.0 / 1024.0,
                        avg_speed,
                        total_time
                    );
                }
            }
        });

        // Do the actual download (blocking, but monitored)
        let path = repo
            .get(&filename_str)
            .await
            .map_err(|e| FerrumError::model(format!("Download failed: {}", e)))?;

        // Signal completion
        done_clone.store(true, Ordering::SeqCst);

        // Wait for monitor to finish
        let _ = monitor_task.await;

        Ok(path)
    }
223
    /// Resolve a HuggingFace repo id by downloading `config.json`, the
    /// tokenizer files, and the model weights into the hub cache, then
    /// returning the snapshot directory that holds them.
    ///
    /// # Errors
    /// Fails when `config.json` cannot be fetched or when no supported
    /// weight format can be downloaded.
    async fn resolve_huggingface(
        &self,
        repo_id: &str,
        revision: Option<&str>,
    ) -> Result<ResolvedModelSource> {
        info!("🔍 正在解析模型: {}", repo_id);

        // Pin the repo to the requested revision; otherwise use the hub's
        // default branch.
        let repo = if let Some(rev) = revision {
            self.api.repo(hf_hub::Repo::with_revision(
                repo_id.to_string(),
                hf_hub::RepoType::Model,
                rev.to_string(),
            ))
        } else {
            self.api.repo(hf_hub::Repo::new(
                repo_id.to_string(),
                hf_hub::RepoType::Model,
            ))
        };

        // Download config first (small file, no need for progress)
        info!("📥 下载中: config.json...");
        let config_path = repo
            .get("config.json")
            .await
            .map_err(|e| FerrumError::model(format!("Failed to download config: {}", e)))?;

        info!("✅ config.json 下载完成");

        // The cache places all snapshot files next to config.json, so its
        // parent directory is the model directory.
        let model_dir = config_path
            .parent()
            .ok_or_else(|| FerrumError::model("Invalid cache path"))?
            .to_path_buf();

        info!("📁 缓存目录: {:?}", model_dir);

        // Download tokenizer files (critical for inference)
        self.download_tokenizer_files(&repo).await?;

        // Download weights
        let format = self.download_weights(&repo, &model_dir).await?;

        // NOTE(review): `from_cache` is reported false even when hf-hub
        // served every file from its local cache.
        Ok(ResolvedModelSource {
            original: repo_id.to_string(),
            local_path: model_dir,
            format,
            from_cache: false,
        })
    }
273
274    async fn download_tokenizer_files(&self, repo: &ApiRepo) -> Result<()> {
275        info!("📥 下载 tokenizer 文件...");
276
277        // List of common tokenizer files
278        let tokenizer_files = vec![
279            "tokenizer.json",
280            "tokenizer_config.json",
281            "vocab.json",
282            "merges.txt",
283            "special_tokens_map.json",
284        ];
285
286        let mut downloaded_count = 0;
287        for filename in &tokenizer_files {
288            match repo.get(filename).await {
289                Ok(_path) => {
290                    info!("  ✅ {}", filename);
291                    downloaded_count += 1;
292                }
293                Err(e) => {
294                    debug!("  ⏭️  {} (optional): {}", filename, e);
295                }
296            }
297        }
298
299        if downloaded_count > 0 {
300            info!("✅ Tokenizer 文件下载完成 ({} 个文件)", downloaded_count);
301        } else {
302            warn!("⚠️  未找到 tokenizer 文件,可能影响推理");
303        }
304
305        Ok(())
306    }
307
    /// Download the model weights, trying formats in preference order:
    /// single-file SafeTensors → sharded SafeTensors (driven by
    /// `model.safetensors.index.json`) → `pytorch_model.bin` → whatever
    /// `detect_format` already finds in the cache directory (GGUF).
    ///
    /// Returns the format that was obtained.
    ///
    /// # Errors
    /// Fails when a shard listed in the index cannot be downloaded, when
    /// the index file cannot be read or parsed, or when no supported
    /// format is found at all.
    async fn download_weights(&self, repo: &ApiRepo, model_dir: &Path) -> Result<ModelFormat> {
        // Try SafeTensors single file
        info!("🔍 检查 model.safetensors...");
        match self
            .download_with_monitor(repo, "model.safetensors", model_dir)
            .await
        {
            Ok(path) => {
                if let Ok(metadata) = std::fs::metadata(&path) {
                    info!(
                        "✅ model.safetensors 完成 ({:.2} GB)",
                        metadata.len() as f64 / 1e9
                    );
                }
                return Ok(ModelFormat::SafeTensors);
            }
            // A miss here is normal for sharded repos; fall through.
            Err(e) => debug!("model.safetensors not found: {}", e),
        }

        // Try sharded SafeTensors
        info!("🔍 检查分片模型...");
        match repo.get("model.safetensors.index.json").await {
            Ok(index_path) => {
                info!("✅ 发现分片 SafeTensors 模型");

                let content = std::fs::read_to_string(&index_path)
                    .map_err(|e| FerrumError::io(format!("Failed to read index: {}", e)))?;

                let index: serde_json::Value = serde_json::from_str(&content)
                    .map_err(|e| FerrumError::model(format!("Failed to parse index: {}", e)))?;

                // The index maps tensor names to shard filenames; collect
                // the distinct shard files to download each exactly once.
                if let Some(weight_map) = index.get("weight_map").and_then(|w| w.as_object()) {
                    let shards: std::collections::HashSet<_> =
                        weight_map.values().filter_map(|v| v.as_str()).collect();

                    let total = shards.len();
                    info!("📦 需要下载 {} 个分片", total);

                    let mut total_bytes = 0u64;
                    for (i, shard) in shards.iter().enumerate() {
                        info!("📥 [{}/{}] {}", i + 1, total, shard);

                        // Any shard failure aborts the whole resolution.
                        let shard_path = self.download_with_monitor(repo, shard, model_dir).await?;

                        if let Ok(meta) = std::fs::metadata(&shard_path) {
                            let size = meta.len();
                            total_bytes += size;
                            info!(
                                "📊 进度: [{}/{}] 分片, 累计 {:.2} GB",
                                i + 1,
                                total,
                                total_bytes as f64 / 1e9
                            );
                        }
                    }

                    info!(
                        "🎉 全部下载完成! 总大小: {:.2} GB",
                        total_bytes as f64 / 1e9
                    );
                }

                // NOTE(review): if the index has no usable `weight_map`,
                // this still reports SafeTensors without downloading shards.
                return Ok(ModelFormat::SafeTensors);
            }
            Err(e) => debug!("Sharded model not found: {}", e),
        }

        // Try PyTorch
        info!("🔍 检查 pytorch_model.bin...");
        match self
            .download_with_monitor(repo, "pytorch_model.bin", model_dir)
            .await
        {
            Ok(path) => {
                warn!("⚠️  使用 PyTorch 格式 (推荐使用 SafeTensors)");
                if let Ok(meta) = std::fs::metadata(&path) {
                    info!(
                        "✅ pytorch_model.bin 完成 ({:.2} GB)",
                        meta.len() as f64 / 1e9
                    );
                }
                return Ok(ModelFormat::PyTorchBin);
            }
            Err(e) => debug!("pytorch_model.bin not found: {}", e),
        }

        // Last resort: accept whatever `detect_format` recognizes as GGUF
        // in the already-populated cache directory.
        if Self::detect_format(model_dir) == ModelFormat::GGUF {
            return Ok(ModelFormat::GGUF);
        }

        Err(FerrumError::model("未找到支持的模型格式"))
    }
400}
401
/// Locate an in-progress download (a `*.part` temp file) under or near
/// `cache_dir` and report its current size in bytes.
///
/// The `blobs/` directory directly under `cache_dir` is checked first;
/// failing that, subdirectories of `cache_dir` and of up to three of its
/// ancestors are scanned recursively. The filename hint is unused because
/// matching temp-file names against it proved unreliable.
fn find_downloading_file(cache_dir: &Path, _filename: &str) -> Option<u64> {
    // Depth-first search below `dir` for any partial-download file.
    fn scan(dir: &Path) -> Option<u64> {
        let entries = std::fs::read_dir(dir).ok()?;
        for entry in entries.flatten() {
            let path = entry.path();
            let name = path.to_string_lossy();

            if name.ends_with(".part") || name.contains(".sync.part") {
                if let Ok(meta) = std::fs::metadata(&path) {
                    return Some(meta.len());
                }
            }

            if path.is_dir() {
                if let Some(size) = scan(&path) {
                    return Some(size);
                }
            }
        }
        None
    }

    // Fast path: partial downloads normally land in `blobs/`.
    if let Ok(entries) = std::fs::read_dir(cache_dir.join("blobs")) {
        for entry in entries.flatten() {
            let path = entry.path();
            let name = path.to_string_lossy();

            if name.ends_with(".part") || name.contains(".sync.part") {
                if let Ok(meta) = std::fs::metadata(&path) {
                    return Some(meta.len());
                }
            }
        }
    }

    // Fallback: scan child directories while walking up the ancestry.
    let mut current = cache_dir.to_path_buf();
    for _ in 0..3 {
        if let Ok(entries) = std::fs::read_dir(&current) {
            for entry in entries.flatten() {
                let child = entry.path();
                if child.is_dir() {
                    if let Some(size) = scan(&child) {
                        return Some(size);
                    }
                }
            }
        }

        match current.parent() {
            Some(parent) => current = parent.to_path_buf(),
            None => break,
        }
    }

    None
}
443
/// Depth-first search of `dir` for a partial-download file; returns the
/// size in bytes of the first `*.part` (or `*.sync.part*`) entry found,
/// or `None` when the directory is unreadable or holds no such file.
fn scan_dir_for_part_files(dir: &Path) -> Option<u64> {
    let entries = std::fs::read_dir(dir).ok()?;
    for entry in entries.flatten() {
        let path = entry.path();
        let name = path.to_string_lossy();

        // A match only counts when its metadata is readable.
        if name.ends_with(".part") || name.contains(".sync.part") {
            if let Ok(meta) = std::fs::metadata(&path) {
                return Some(meta.len());
            }
        }

        // Recurse into subdirectories; the first hit wins.
        if path.is_dir() {
            if let Some(size) = scan_dir_for_part_files(&path) {
                return Some(size);
            }
        }
    }
    None
}
466
467#[async_trait::async_trait]
468impl ModelSourceResolver for DefaultModelSourceResolver {
469    async fn resolve(&self, id: &str, revision: Option<&str>) -> Result<ResolvedModelSource> {
470        if Self::is_local_path(id) {
471            return self.resolve_local(id).await;
472        }
473
474        self.resolve_huggingface(id, revision).await
475    }
476}