Skip to main content

ferrum_models/
source.rs

1//! Model source resolution and downloading with progress tracking
2
3use ferrum_types::{FerrumError, ModelSource, Result};
4use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
5use std::path::{Path, PathBuf};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{Arc, OnceLock};
8use std::time::{Duration, Instant};
9use tracing::{debug, info, warn};
10
/// Snapshot of the Hugging Face related environment variables this module reads.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ModelSourceRuntimeEnv {
    /// Value of `HF_HOME`; `None` when unset or empty.
    hf_home: Option<String>,
    /// Access token: `HF_TOKEN` when usable, otherwise `HUGGING_FACE_HUB_TOKEN`.
    hf_token: Option<String>,
}

impl ModelSourceRuntimeEnv {
    /// Builds the snapshot from the current process environment.
    ///
    /// Iterates with `std::env::vars_os` instead of `std::env::vars`: the
    /// latter panics during iteration as soon as *any* variable holds a key or
    /// value that is not valid Unicode, even one unrelated to Hugging Face.
    /// Variables that cannot be converted to `String` are skipped.
    fn from_env() -> Self {
        Self::from_env_vars(std::env::vars_os().filter_map(|(key, value)| {
            Some((key.into_string().ok()?, value.into_string().ok()?))
        }))
    }

    /// Builds the snapshot from an arbitrary sequence of `(key, value)` pairs.
    ///
    /// Only `HF_HOME`, `HF_TOKEN` and `HUGGING_FACE_HUB_TOKEN` are inspected;
    /// every other pair is ignored without being converted. When a key occurs
    /// more than once the last occurrence wins. Empty values are treated as
    /// unset, and tokens are trimmed of surrounding whitespace, so that an
    /// empty `HF_TOKEN` does not hide a usable `HUGGING_FACE_HUB_TOKEN`.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: Into<String>,
    {
        let mut hf_home = None;
        let mut hf_token = None;
        let mut hf_hub_token = None;

        for (key, value) in vars {
            match key.as_ref() {
                "HF_HOME" => hf_home = Self::non_empty(value.into()),
                "HF_TOKEN" => hf_token = Self::clean_token(value.into()),
                "HUGGING_FACE_HUB_TOKEN" => hf_hub_token = Self::clean_token(value.into()),
                _ => {}
            }
        }

        Self {
            hf_home,
            // `HF_TOKEN` takes precedence; the legacy variable is the fallback.
            hf_token: hf_token.or(hf_hub_token),
        }
    }

    /// Returns `Some(value)` unless `value` is the empty string.
    fn non_empty(value: String) -> Option<String> {
        if value.is_empty() {
            None
        } else {
            Some(value)
        }
    }

    /// Trims surrounding whitespace from a token and drops it when nothing is left.
    fn clean_token(value: String) -> Option<String> {
        let trimmed = value.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed.to_owned())
        }
    }
}
48
49fn model_source_runtime_env() -> &'static ModelSourceRuntimeEnv {
50    static CONFIG: OnceLock<ModelSourceRuntimeEnv> = OnceLock::new();
51    CONFIG.get_or_init(ModelSourceRuntimeEnv::from_env)
52}
53
/// Settings that control how a model identifier is resolved to local files.
#[derive(Clone, Debug)]
pub struct ModelSourceConfig {
    /// Cache root handed to the Hugging Face API builder; `None` leaves the
    /// builder's own default in place.
    pub cache_dir: Option<PathBuf>,
    /// Access token handed to the Hugging Face API builder, if any.
    pub hf_token: Option<String>,
    /// NOTE(review): not read by the resolver code in this file — presumably
    /// meant to forbid network access; confirm against callers.
    pub offline_mode: bool,
    /// NOTE(review): not read by the resolver code in this file — presumably
    /// the number of download attempts; confirm against callers.
    pub max_retries: usize,
    /// NOTE(review): not read by the resolver code in this file — unit is
    /// presumably seconds; confirm against callers.
    pub download_timeout: u64,
    /// NOTE(review): not read by the resolver code in this file — presumably
    /// toggles cross-process locking of cache files; confirm against callers.
    pub use_file_lock: bool,
}
64
65impl Default for ModelSourceConfig {
66    fn default() -> Self {
67        // Use HuggingFace standard cache directory
68        let default_cache = model_source_runtime_env()
69            .hf_home
70            .clone()
71            .or_else(|| {
72                dirs::home_dir()
73                    .map(|h| h.join(".cache/huggingface"))
74                    .and_then(|p| p.to_str().map(String::from))
75            })
76            .map(PathBuf::from);
77
78        Self {
79            cache_dir: default_cache,
80            hf_token: Self::get_hf_token(),
81            offline_mode: false,
82            max_retries: 3,
83            download_timeout: 300,
84            use_file_lock: true,
85        }
86    }
87}
88
89impl ModelSourceConfig {
90    pub fn get_hf_token() -> Option<String> {
91        model_source_runtime_env().hf_token.clone()
92    }
93}
94
/// On-disk weight format of a resolved model.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ModelFormat {
    /// SafeTensors weights (`model.safetensors` or a sharded index).
    SafeTensors,
    /// PyTorch checkpoint (`pytorch_model.bin`).
    PyTorchBin,
    /// GGUF weights.
    GGUF,
    /// None of the recognised weight files was found.
    Unknown,
}
102
103#[derive(Debug, Clone)]
104pub struct ResolvedModelSource {
105    pub original: String,
106    pub local_path: PathBuf,
107    pub format: ModelFormat,
108    pub from_cache: bool,
109}
110
111impl From<ResolvedModelSource> for ModelSource {
112    fn from(value: ResolvedModelSource) -> Self {
113        ModelSource::Local(value.local_path.display().to_string())
114    }
115}
116
117#[async_trait::async_trait]
118pub trait ModelSourceResolver: Send + Sync {
119    async fn resolve(&self, id: &str, revision: Option<&str>) -> Result<ResolvedModelSource>;
120}
121
122pub struct DefaultModelSourceResolver {
123    _config: ModelSourceConfig,
124    api: Api,
125}
126
127impl DefaultModelSourceResolver {
128    pub fn new(config: ModelSourceConfig) -> Self {
129        let mut builder = ApiBuilder::new();
130
131        if let Some(cache_dir) = &config.cache_dir {
132            builder = builder.with_cache_dir(cache_dir.clone());
133        }
134
135        if let Some(token) = &config.hf_token {
136            builder = builder.with_token(Some(token.clone()));
137        }
138
139        let api = builder.build().unwrap_or_else(|e| {
140            warn!("Failed to build HF API: {}, using default", e);
141            Api::new().expect("Failed to create default HF API")
142        });
143
144        Self {
145            _config: config,
146            api,
147        }
148    }
149
150    fn is_local_path(id: &str) -> bool {
151        Path::new(id).exists()
152    }
153
154    fn detect_format(path: &Path) -> ModelFormat {
155        if path.join("model.safetensors").exists()
156            || path.join("model.safetensors.index.json").exists()
157        {
158            ModelFormat::SafeTensors
159        } else if path.join("pytorch_model.bin").exists() {
160            ModelFormat::PyTorchBin
161        } else {
162            ModelFormat::Unknown
163        }
164    }
165
166    async fn resolve_local(&self, path: &str) -> Result<ResolvedModelSource> {
167        let path_buf = PathBuf::from(path);
168
169        if !path_buf.exists() {
170            return Err(FerrumError::model(format!("Path does not exist: {}", path)));
171        }
172
173        let format = Self::detect_format(&path_buf);
174
175        Ok(ResolvedModelSource {
176            original: path.to_string(),
177            local_path: path_buf,
178            format,
179            from_cache: true,
180        })
181    }
182
183    /// Download file with progress monitoring
184    async fn download_with_monitor(
185        &self,
186        repo: &ApiRepo,
187        filename: &str,
188        expected_cache_dir: &Path,
189    ) -> Result<PathBuf> {
190        info!("📥 下载中: {}...", filename);
191
192        let done = Arc::new(AtomicBool::new(false));
193        let done_clone = done.clone();
194        let filename_str = filename.to_string();
195
196        // Start monitor task
197        let monitor_task = tokio::spawn({
198            let done = done.clone();
199            let filename = filename_str.clone();
200            let cache_dir = expected_cache_dir.to_path_buf();
201
202            async move {
203                tokio::time::sleep(Duration::from_millis(1000)).await;
204
205                let start_time = Instant::now();
206                let mut last_size = 0u64;
207                let mut last_time = Instant::now();
208                let mut last_print = Instant::now();
209
210                while !done.load(Ordering::SeqCst) {
211                    // Try to find downloading file
212                    if let Some(current_size) = find_downloading_file(&cache_dir, &filename) {
213                        let elapsed_since_last = last_time.elapsed().as_secs_f64();
214
215                        if elapsed_since_last > 0.5 && current_size > last_size {
216                            let delta = current_size - last_size;
217                            let speed_mbps = delta as f64 / elapsed_since_last / 1024.0 / 1024.0;
218                            let current_mb = current_size as f64 / 1024.0 / 1024.0;
219
220                            // Only print every 2 seconds to avoid spam
221                            if last_print.elapsed().as_secs() >= 2 {
222                                info!(
223                                    "  📊 已下载: {:.2} MB (速度: {:.1} MB/s)",
224                                    current_mb, speed_mbps
225                                );
226                                last_print = Instant::now();
227                            }
228
229                            last_size = current_size;
230                            last_time = Instant::now();
231                        }
232                    }
233
234                    tokio::time::sleep(Duration::from_millis(500)).await;
235                }
236
237                // Final statistics
238                let total_time = start_time.elapsed().as_secs_f64();
239                if last_size > 0 && total_time > 0.0 {
240                    let avg_speed = last_size as f64 / total_time / 1024.0 / 1024.0;
241                    info!(
242                        "  ✅ 下载完成: {:.2} MB (平均速度: {:.1} MB/s, 耗时: {:.1}s)",
243                        last_size as f64 / 1024.0 / 1024.0,
244                        avg_speed,
245                        total_time
246                    );
247                }
248            }
249        });
250
251        // Do the actual download (blocking, but monitored)
252        let path = repo
253            .get(&filename_str)
254            .await
255            .map_err(|e| FerrumError::model(format!("Download failed: {}", e)))?;
256
257        // Signal completion
258        done_clone.store(true, Ordering::SeqCst);
259
260        // Wait for monitor to finish
261        let _ = monitor_task.await;
262
263        Ok(path)
264    }
265
266    async fn resolve_huggingface(
267        &self,
268        repo_id: &str,
269        revision: Option<&str>,
270    ) -> Result<ResolvedModelSource> {
271        info!("🔍 正在解析模型: {}", repo_id);
272
273        let repo = if let Some(rev) = revision {
274            self.api.repo(hf_hub::Repo::with_revision(
275                repo_id.to_string(),
276                hf_hub::RepoType::Model,
277                rev.to_string(),
278            ))
279        } else {
280            self.api.repo(hf_hub::Repo::new(
281                repo_id.to_string(),
282                hf_hub::RepoType::Model,
283            ))
284        };
285
286        // Download config first (small file, no need for progress)
287        info!("📥 下载中: config.json...");
288        let config_path = repo
289            .get("config.json")
290            .await
291            .map_err(|e| FerrumError::model(format!("Failed to download config: {}", e)))?;
292
293        info!("✅ config.json 下载完成");
294
295        let model_dir = config_path
296            .parent()
297            .ok_or_else(|| FerrumError::model("Invalid cache path"))?
298            .to_path_buf();
299
300        info!("📁 缓存目录: {:?}", model_dir);
301
302        // Download tokenizer files (critical for inference)
303        self.download_tokenizer_files(&repo).await?;
304
305        // Download weights
306        let format = self.download_weights(&repo, &model_dir).await?;
307
308        Ok(ResolvedModelSource {
309            original: repo_id.to_string(),
310            local_path: model_dir,
311            format,
312            from_cache: false,
313        })
314    }
315
316    async fn download_tokenizer_files(&self, repo: &ApiRepo) -> Result<()> {
317        info!("📥 下载 tokenizer 文件...");
318
319        // List of common tokenizer files
320        let tokenizer_files = vec![
321            "tokenizer.json",
322            "tokenizer_config.json",
323            "vocab.json",
324            "merges.txt",
325            "special_tokens_map.json",
326        ];
327
328        let mut downloaded_count = 0;
329        for filename in &tokenizer_files {
330            match repo.get(filename).await {
331                Ok(_path) => {
332                    info!("  ✅ {}", filename);
333                    downloaded_count += 1;
334                }
335                Err(e) => {
336                    debug!("  ⏭️  {} (optional): {}", filename, e);
337                }
338            }
339        }
340
341        if downloaded_count > 0 {
342            info!("✅ Tokenizer 文件下载完成 ({} 个文件)", downloaded_count);
343        } else {
344            warn!("⚠️  未找到 tokenizer 文件,可能影响推理");
345        }
346
347        Ok(())
348    }
349
350    async fn download_weights(&self, repo: &ApiRepo, model_dir: &Path) -> Result<ModelFormat> {
351        // Try SafeTensors single file
352        info!("🔍 检查 model.safetensors...");
353        match self
354            .download_with_monitor(repo, "model.safetensors", model_dir)
355            .await
356        {
357            Ok(path) => {
358                if let Ok(metadata) = std::fs::metadata(&path) {
359                    info!(
360                        "✅ model.safetensors 完成 ({:.2} GB)",
361                        metadata.len() as f64 / 1e9
362                    );
363                }
364                return Ok(ModelFormat::SafeTensors);
365            }
366            Err(e) => debug!("model.safetensors not found: {}", e),
367        }
368
369        // Try sharded SafeTensors
370        info!("🔍 检查分片模型...");
371        match repo.get("model.safetensors.index.json").await {
372            Ok(index_path) => {
373                info!("✅ 发现分片 SafeTensors 模型");
374
375                let content = std::fs::read_to_string(&index_path)
376                    .map_err(|e| FerrumError::io(format!("Failed to read index: {}", e)))?;
377
378                let index: serde_json::Value = serde_json::from_str(&content)
379                    .map_err(|e| FerrumError::model(format!("Failed to parse index: {}", e)))?;
380
381                if let Some(weight_map) = index.get("weight_map").and_then(|w| w.as_object()) {
382                    let shards: std::collections::HashSet<_> =
383                        weight_map.values().filter_map(|v| v.as_str()).collect();
384
385                    let total = shards.len();
386                    info!("📦 需要下载 {} 个分片", total);
387
388                    let mut total_bytes = 0u64;
389                    for (i, shard) in shards.iter().enumerate() {
390                        info!("📥 [{}/{}] {}", i + 1, total, shard);
391
392                        let shard_path = self.download_with_monitor(repo, shard, model_dir).await?;
393
394                        if let Ok(meta) = std::fs::metadata(&shard_path) {
395                            let size = meta.len();
396                            total_bytes += size;
397                            info!(
398                                "📊 进度: [{}/{}] 分片, 累计 {:.2} GB",
399                                i + 1,
400                                total,
401                                total_bytes as f64 / 1e9
402                            );
403                        }
404                    }
405
406                    info!(
407                        "🎉 全部下载完成! 总大小: {:.2} GB",
408                        total_bytes as f64 / 1e9
409                    );
410                }
411
412                return Ok(ModelFormat::SafeTensors);
413            }
414            Err(e) => debug!("Sharded model not found: {}", e),
415        }
416
417        // Try PyTorch
418        info!("🔍 检查 pytorch_model.bin...");
419        match self
420            .download_with_monitor(repo, "pytorch_model.bin", model_dir)
421            .await
422        {
423            Ok(path) => {
424                warn!("⚠️  使用 PyTorch 格式 (推荐使用 SafeTensors)");
425                if let Ok(meta) = std::fs::metadata(&path) {
426                    info!(
427                        "✅ pytorch_model.bin 完成 ({:.2} GB)",
428                        meta.len() as f64 / 1e9
429                    );
430                }
431                return Ok(ModelFormat::PyTorchBin);
432            }
433            Err(e) => debug!("pytorch_model.bin not found: {}", e),
434        }
435
436        if Self::detect_format(model_dir) == ModelFormat::GGUF {
437            return Ok(ModelFormat::GGUF);
438        }
439
440        Err(FerrumError::model("未找到支持的模型格式"))
441    }
442}
443
444/// Find downloading file in cache directory
445fn find_downloading_file(cache_dir: &Path, _filename: &str) -> Option<u64> {
446    // Just search for ANY .part file in the cache directory tree
447    // This is more reliable than trying to match filenames
448
449    // Check blobs directory
450    if let Ok(entries) = std::fs::read_dir(cache_dir.join("blobs")) {
451        for entry in entries.filter_map(|e| e.ok()) {
452            let path = entry.path();
453            let path_str = path.to_string_lossy();
454
455            if path_str.ends_with(".part") || path_str.contains(".sync.part") {
456                if let Ok(metadata) = std::fs::metadata(&path) {
457                    return Some(metadata.len());
458                }
459            }
460        }
461    }
462
463    // Also try to find in parent directories
464    let mut current = cache_dir.to_path_buf();
465    for _ in 0..3 {
466        if let Ok(entries) = std::fs::read_dir(&current) {
467            for entry in entries.filter_map(|e| e.ok()) {
468                if entry.path().is_dir() {
469                    if let Some(size) = scan_dir_for_part_files(&entry.path()) {
470                        return Some(size);
471                    }
472                }
473            }
474        }
475
476        if let Some(parent) = current.parent() {
477            current = parent.to_path_buf();
478        } else {
479            break;
480        }
481    }
482
483    None
484}
485
/// Searches `dir` and its sub-directories for a partially downloaded file.
///
/// An entry qualifies when its path ends with `.part` or contains
/// `.sync.part`. Returns the size in bytes of the first qualifying entry whose
/// metadata can be read, or `None` when there is none or `dir` is unreadable.
fn scan_dir_for_part_files(dir: &Path) -> Option<u64> {
    let listing = std::fs::read_dir(dir).ok()?;

    for candidate in listing.filter_map(Result::ok) {
        let candidate_path = candidate.path();

        let looks_partial = {
            let text = candidate_path.to_string_lossy();
            text.ends_with(".part") || text.contains(".sync.part")
        };

        if looks_partial {
            if let Ok(meta) = std::fs::metadata(&candidate_path) {
                return Some(meta.len());
            }
        }

        // Descend into sub-directories before moving on to the next sibling.
        if candidate_path.is_dir() {
            let nested = scan_dir_for_part_files(&candidate_path);
            if nested.is_some() {
                return nested;
            }
        }
    }

    None
}
508
509#[async_trait::async_trait]
510impl ModelSourceResolver for DefaultModelSourceResolver {
511    async fn resolve(&self, id: &str, revision: Option<&str>) -> Result<ResolvedModelSource> {
512        if Self::is_local_path(id) {
513            return self.resolve_local(id).await;
514        }
515
516        self.resolve_huggingface(id, revision).await
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// `HF_HOME` is captured and `HF_TOKEN` wins over the legacy token.
    #[test]
    fn model_source_runtime_env_parses_hf_cache_and_token() {
        let vars = vec![
            ("HF_HOME", "/tmp/hf"),
            ("HF_TOKEN", "primary"),
            ("HUGGING_FACE_HUB_TOKEN", "fallback"),
        ];

        let parsed = ModelSourceRuntimeEnv::from_env_vars(vars);

        assert_eq!(parsed.hf_home, Some(String::from("/tmp/hf")));
        assert_eq!(parsed.hf_token, Some(String::from("primary")));
    }

    /// The legacy token is used when `HF_TOKEN` is absent.
    #[test]
    fn model_source_runtime_env_uses_hub_token_fallback() {
        let vars = vec![("HUGGING_FACE_HUB_TOKEN", "fallback")];

        let parsed = ModelSourceRuntimeEnv::from_env_vars(vars);

        assert!(parsed.hf_home.is_none());
        assert_eq!(parsed.hf_token, Some(String::from("fallback")));
    }
}
543}