1use ferrum_types::{FerrumError, ModelSource, Result};
4use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
5use std::path::{Path, PathBuf};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{Arc, OnceLock};
8use std::time::{Duration, Instant};
9use tracing::{debug, info, warn};
10
/// Snapshot of the Hugging Face related environment variables read by this module.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ModelSourceRuntimeEnv {
    /// Value of `HF_HOME`, if set to a non-blank value.
    hf_home: Option<String>,
    /// `HF_TOKEN` if set to a non-blank value, otherwise `HUGGING_FACE_HUB_TOKEN`.
    hf_token: Option<String>,
}

impl ModelSourceRuntimeEnv {
    /// Reads the relevant variables from the current process environment.
    ///
    /// `std::env::vars` panics as soon as it meets *any* variable (even an
    /// unrelated one) whose key or value is not valid Unicode, so `vars_os` is
    /// used instead and such entries are skipped.
    fn from_env() -> Self {
        Self::from_env_vars(std::env::vars_os().filter_map(|(key, value)| {
            Some((key.into_string().ok()?, value.into_string().ok()?))
        }))
    }

    /// Builds a snapshot from arbitrary `(key, value)` pairs.
    ///
    /// Unrelated keys are ignored. `HF_TOKEN` takes precedence over
    /// `HUGGING_FACE_HUB_TOKEN`. Values that are empty or whitespace-only are
    /// treated as if the variable were unset.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: Into<String>,
    {
        let mut hf_home = None;
        let mut hf_token = None;
        let mut hf_hub_token = None;

        for (key, value) in vars {
            let value: String = value.into();
            // An exported-but-empty variable (e.g. `HF_TOKEN=`) must neither mask
            // the fallback token nor produce an empty cache path.
            if value.trim().is_empty() {
                continue;
            }
            match key.as_ref() {
                "HF_HOME" => hf_home = Some(value),
                "HF_TOKEN" => hf_token = Some(value),
                "HUGGING_FACE_HUB_TOKEN" => hf_hub_token = Some(value),
                _ => {}
            }
        }

        Self {
            hf_home,
            hf_token: hf_token.or(hf_hub_token),
        }
    }
}
48
49fn model_source_runtime_env() -> &'static ModelSourceRuntimeEnv {
50 static CONFIG: OnceLock<ModelSourceRuntimeEnv> = OnceLock::new();
51 CONFIG.get_or_init(ModelSourceRuntimeEnv::from_env)
52}
53
/// Settings for resolving and downloading model sources.
///
/// Consumed by `DefaultModelSourceResolver::new`, which reads `cache_dir` and
/// `hf_token` when building its Hugging Face client.
#[derive(Debug, Clone)]
pub struct ModelSourceConfig {
    /// Cache directory passed to the Hugging Face client builder (`with_cache_dir`), if any.
    pub cache_dir: Option<PathBuf>,
    /// Hugging Face access token passed to the client builder (`with_token`), if any.
    pub hf_token: Option<String>,
    // NOTE(review): not read anywhere in this module — confirm whether offline
    // mode is enforced elsewhere.
    pub offline_mode: bool,
    // NOTE(review): not read anywhere in this module.
    pub max_retries: usize,
    // NOTE(review): not read anywhere in this module; presumably seconds
    // (the Default impl uses 300) — TODO confirm the unit.
    pub download_timeout: u64,
    // NOTE(review): not read anywhere in this module.
    pub use_file_lock: bool,
}
64
65impl Default for ModelSourceConfig {
66 fn default() -> Self {
67 let default_cache = model_source_runtime_env()
69 .hf_home
70 .clone()
71 .or_else(|| {
72 dirs::home_dir()
73 .map(|h| h.join(".cache/huggingface"))
74 .and_then(|p| p.to_str().map(String::from))
75 })
76 .map(PathBuf::from);
77
78 Self {
79 cache_dir: default_cache,
80 hf_token: Self::get_hf_token(),
81 offline_mode: false,
82 max_retries: 3,
83 download_timeout: 300,
84 use_file_lock: true,
85 }
86 }
87}
88
89impl ModelSourceConfig {
90 pub fn get_hf_token() -> Option<String> {
91 model_source_runtime_env().hf_token.clone()
92 }
93}
94
/// Weight format detected for a resolved model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormat {
    /// `model.safetensors`, or a sharded layout described by `model.safetensors.index.json`.
    SafeTensors,
    /// A single `pytorch_model.bin` checkpoint.
    PyTorchBin,
    /// GGUF-format weights.
    GGUF,
    /// None of the recognised layouts was found.
    Unknown,
}
102
/// Result of resolving a model identifier to files on local disk.
#[derive(Debug, Clone)]
pub struct ResolvedModelSource {
    /// The identifier exactly as given to the resolver (a local path or a repo id).
    pub original: String,
    /// Local path holding the model. For Hugging Face repos, this is the
    /// directory that contains the downloaded `config.json`.
    pub local_path: PathBuf,
    /// Detected weight format.
    pub format: ModelFormat,
    /// `DefaultModelSourceResolver` sets this to `true` for local paths and
    /// `false` for Hugging Face resolutions, even when files were already cached.
    pub from_cache: bool,
}
110
111impl From<ResolvedModelSource> for ModelSource {
112 fn from(value: ResolvedModelSource) -> Self {
113 ModelSource::Local(value.local_path.display().to_string())
114 }
115}
116
/// Turns a model identifier into a [`ResolvedModelSource`] on local disk.
#[async_trait::async_trait]
pub trait ModelSourceResolver: Send + Sync {
    /// Resolves `id`, optionally pinned to `revision`.
    ///
    /// Implementations may ignore `revision` where it does not apply (e.g. local paths).
    async fn resolve(&self, id: &str, revision: Option<&str>) -> Result<ResolvedModelSource>;
}
121
/// Resolver that uses existing local paths directly and otherwise downloads from the Hugging Face Hub.
pub struct DefaultModelSourceResolver {
    // Stored configuration. Not read after construction in this module, hence
    // the leading underscore.
    _config: ModelSourceConfig,
    /// Async Hugging Face client built from the configuration.
    api: Api,
}
126
127impl DefaultModelSourceResolver {
128 pub fn new(config: ModelSourceConfig) -> Self {
129 let mut builder = ApiBuilder::new();
130
131 if let Some(cache_dir) = &config.cache_dir {
132 builder = builder.with_cache_dir(cache_dir.clone());
133 }
134
135 if let Some(token) = &config.hf_token {
136 builder = builder.with_token(Some(token.clone()));
137 }
138
139 let api = builder.build().unwrap_or_else(|e| {
140 warn!("Failed to build HF API: {}, using default", e);
141 Api::new().expect("Failed to create default HF API")
142 });
143
144 Self {
145 _config: config,
146 api,
147 }
148 }
149
150 fn is_local_path(id: &str) -> bool {
151 Path::new(id).exists()
152 }
153
154 fn detect_format(path: &Path) -> ModelFormat {
155 if path.join("model.safetensors").exists()
156 || path.join("model.safetensors.index.json").exists()
157 {
158 ModelFormat::SafeTensors
159 } else if path.join("pytorch_model.bin").exists() {
160 ModelFormat::PyTorchBin
161 } else {
162 ModelFormat::Unknown
163 }
164 }
165
166 async fn resolve_local(&self, path: &str) -> Result<ResolvedModelSource> {
167 let path_buf = PathBuf::from(path);
168
169 if !path_buf.exists() {
170 return Err(FerrumError::model(format!("Path does not exist: {}", path)));
171 }
172
173 let format = Self::detect_format(&path_buf);
174
175 Ok(ResolvedModelSource {
176 original: path.to_string(),
177 local_path: path_buf,
178 format,
179 from_cache: true,
180 })
181 }
182
183 async fn download_with_monitor(
185 &self,
186 repo: &ApiRepo,
187 filename: &str,
188 expected_cache_dir: &Path,
189 ) -> Result<PathBuf> {
190 info!("📥 下载中: {}...", filename);
191
192 let done = Arc::new(AtomicBool::new(false));
193 let done_clone = done.clone();
194 let filename_str = filename.to_string();
195
196 let monitor_task = tokio::spawn({
198 let done = done.clone();
199 let filename = filename_str.clone();
200 let cache_dir = expected_cache_dir.to_path_buf();
201
202 async move {
203 tokio::time::sleep(Duration::from_millis(1000)).await;
204
205 let start_time = Instant::now();
206 let mut last_size = 0u64;
207 let mut last_time = Instant::now();
208 let mut last_print = Instant::now();
209
210 while !done.load(Ordering::SeqCst) {
211 if let Some(current_size) = find_downloading_file(&cache_dir, &filename) {
213 let elapsed_since_last = last_time.elapsed().as_secs_f64();
214
215 if elapsed_since_last > 0.5 && current_size > last_size {
216 let delta = current_size - last_size;
217 let speed_mbps = delta as f64 / elapsed_since_last / 1024.0 / 1024.0;
218 let current_mb = current_size as f64 / 1024.0 / 1024.0;
219
220 if last_print.elapsed().as_secs() >= 2 {
222 info!(
223 " 📊 已下载: {:.2} MB (速度: {:.1} MB/s)",
224 current_mb, speed_mbps
225 );
226 last_print = Instant::now();
227 }
228
229 last_size = current_size;
230 last_time = Instant::now();
231 }
232 }
233
234 tokio::time::sleep(Duration::from_millis(500)).await;
235 }
236
237 let total_time = start_time.elapsed().as_secs_f64();
239 if last_size > 0 && total_time > 0.0 {
240 let avg_speed = last_size as f64 / total_time / 1024.0 / 1024.0;
241 info!(
242 " ✅ 下载完成: {:.2} MB (平均速度: {:.1} MB/s, 耗时: {:.1}s)",
243 last_size as f64 / 1024.0 / 1024.0,
244 avg_speed,
245 total_time
246 );
247 }
248 }
249 });
250
251 let path = repo
253 .get(&filename_str)
254 .await
255 .map_err(|e| FerrumError::model(format!("Download failed: {}", e)))?;
256
257 done_clone.store(true, Ordering::SeqCst);
259
260 let _ = monitor_task.await;
262
263 Ok(path)
264 }
265
266 async fn resolve_huggingface(
267 &self,
268 repo_id: &str,
269 revision: Option<&str>,
270 ) -> Result<ResolvedModelSource> {
271 info!("🔍 正在解析模型: {}", repo_id);
272
273 let repo = if let Some(rev) = revision {
274 self.api.repo(hf_hub::Repo::with_revision(
275 repo_id.to_string(),
276 hf_hub::RepoType::Model,
277 rev.to_string(),
278 ))
279 } else {
280 self.api.repo(hf_hub::Repo::new(
281 repo_id.to_string(),
282 hf_hub::RepoType::Model,
283 ))
284 };
285
286 info!("📥 下载中: config.json...");
288 let config_path = repo
289 .get("config.json")
290 .await
291 .map_err(|e| FerrumError::model(format!("Failed to download config: {}", e)))?;
292
293 info!("✅ config.json 下载完成");
294
295 let model_dir = config_path
296 .parent()
297 .ok_or_else(|| FerrumError::model("Invalid cache path"))?
298 .to_path_buf();
299
300 info!("📁 缓存目录: {:?}", model_dir);
301
302 self.download_tokenizer_files(&repo).await?;
304
305 let format = self.download_weights(&repo, &model_dir).await?;
307
308 Ok(ResolvedModelSource {
309 original: repo_id.to_string(),
310 local_path: model_dir,
311 format,
312 from_cache: false,
313 })
314 }
315
316 async fn download_tokenizer_files(&self, repo: &ApiRepo) -> Result<()> {
317 info!("📥 下载 tokenizer 文件...");
318
319 let tokenizer_files = vec![
321 "tokenizer.json",
322 "tokenizer_config.json",
323 "vocab.json",
324 "merges.txt",
325 "special_tokens_map.json",
326 ];
327
328 let mut downloaded_count = 0;
329 for filename in &tokenizer_files {
330 match repo.get(filename).await {
331 Ok(_path) => {
332 info!(" ✅ {}", filename);
333 downloaded_count += 1;
334 }
335 Err(e) => {
336 debug!(" ⏭️ {} (optional): {}", filename, e);
337 }
338 }
339 }
340
341 if downloaded_count > 0 {
342 info!("✅ Tokenizer 文件下载完成 ({} 个文件)", downloaded_count);
343 } else {
344 warn!("⚠️ 未找到 tokenizer 文件,可能影响推理");
345 }
346
347 Ok(())
348 }
349
350 async fn download_weights(&self, repo: &ApiRepo, model_dir: &Path) -> Result<ModelFormat> {
351 info!("🔍 检查 model.safetensors...");
353 match self
354 .download_with_monitor(repo, "model.safetensors", model_dir)
355 .await
356 {
357 Ok(path) => {
358 if let Ok(metadata) = std::fs::metadata(&path) {
359 info!(
360 "✅ model.safetensors 完成 ({:.2} GB)",
361 metadata.len() as f64 / 1e9
362 );
363 }
364 return Ok(ModelFormat::SafeTensors);
365 }
366 Err(e) => debug!("model.safetensors not found: {}", e),
367 }
368
369 info!("🔍 检查分片模型...");
371 match repo.get("model.safetensors.index.json").await {
372 Ok(index_path) => {
373 info!("✅ 发现分片 SafeTensors 模型");
374
375 let content = std::fs::read_to_string(&index_path)
376 .map_err(|e| FerrumError::io(format!("Failed to read index: {}", e)))?;
377
378 let index: serde_json::Value = serde_json::from_str(&content)
379 .map_err(|e| FerrumError::model(format!("Failed to parse index: {}", e)))?;
380
381 if let Some(weight_map) = index.get("weight_map").and_then(|w| w.as_object()) {
382 let shards: std::collections::HashSet<_> =
383 weight_map.values().filter_map(|v| v.as_str()).collect();
384
385 let total = shards.len();
386 info!("📦 需要下载 {} 个分片", total);
387
388 let mut total_bytes = 0u64;
389 for (i, shard) in shards.iter().enumerate() {
390 info!("📥 [{}/{}] {}", i + 1, total, shard);
391
392 let shard_path = self.download_with_monitor(repo, shard, model_dir).await?;
393
394 if let Ok(meta) = std::fs::metadata(&shard_path) {
395 let size = meta.len();
396 total_bytes += size;
397 info!(
398 "📊 进度: [{}/{}] 分片, 累计 {:.2} GB",
399 i + 1,
400 total,
401 total_bytes as f64 / 1e9
402 );
403 }
404 }
405
406 info!(
407 "🎉 全部下载完成! 总大小: {:.2} GB",
408 total_bytes as f64 / 1e9
409 );
410 }
411
412 return Ok(ModelFormat::SafeTensors);
413 }
414 Err(e) => debug!("Sharded model not found: {}", e),
415 }
416
417 info!("🔍 检查 pytorch_model.bin...");
419 match self
420 .download_with_monitor(repo, "pytorch_model.bin", model_dir)
421 .await
422 {
423 Ok(path) => {
424 warn!("⚠️ 使用 PyTorch 格式 (推荐使用 SafeTensors)");
425 if let Ok(meta) = std::fs::metadata(&path) {
426 info!(
427 "✅ pytorch_model.bin 完成 ({:.2} GB)",
428 meta.len() as f64 / 1e9
429 );
430 }
431 return Ok(ModelFormat::PyTorchBin);
432 }
433 Err(e) => debug!("pytorch_model.bin not found: {}", e),
434 }
435
436 if Self::detect_format(model_dir) == ModelFormat::GGUF {
437 return Ok(ModelFormat::GGUF);
438 }
439
440 Err(FerrumError::model("未找到支持的模型格式"))
441 }
442}
443
444fn find_downloading_file(cache_dir: &Path, _filename: &str) -> Option<u64> {
446 if let Ok(entries) = std::fs::read_dir(cache_dir.join("blobs")) {
451 for entry in entries.filter_map(|e| e.ok()) {
452 let path = entry.path();
453 let path_str = path.to_string_lossy();
454
455 if path_str.ends_with(".part") || path_str.contains(".sync.part") {
456 if let Ok(metadata) = std::fs::metadata(&path) {
457 return Some(metadata.len());
458 }
459 }
460 }
461 }
462
463 let mut current = cache_dir.to_path_buf();
465 for _ in 0..3 {
466 if let Ok(entries) = std::fs::read_dir(¤t) {
467 for entry in entries.filter_map(|e| e.ok()) {
468 if entry.path().is_dir() {
469 if let Some(size) = scan_dir_for_part_files(&entry.path()) {
470 return Some(size);
471 }
472 }
473 }
474 }
475
476 if let Some(parent) = current.parent() {
477 current = parent.to_path_buf();
478 } else {
479 break;
480 }
481 }
482
483 None
484}
485
/// Recursively searches `dir` for a partial download and returns its size.
///
/// A partial download is a file whose name ends in `.part` or contains
/// `.sync.part`. The first match found in directory-iteration order wins.
/// Unreadable directories and entries are skipped. Returns `None` if nothing
/// matches.
fn scan_dir_for_part_files(dir: &Path) -> Option<u64> {
    fn is_partial_download(path: &Path) -> bool {
        let name = path.to_string_lossy();
        name.ends_with(".part") || name.contains(".sync.part")
    }

    let entries = std::fs::read_dir(dir).ok()?;
    for entry in entries.filter_map(|e| e.ok()) {
        let path = entry.path();
        // `DirEntry::file_type` does not follow symlinks, so symlinked
        // directories are never descended into. `Path::is_dir` (used before)
        // did follow them, which let this scan — polled every 500 ms — loop
        // through symlink cycles or wander outside the tree it was given.
        let Ok(file_type) = entry.file_type() else {
            continue;
        };

        if file_type.is_dir() {
            if let Some(size) = scan_dir_for_part_files(&path) {
                return Some(size);
            }
        } else if is_partial_download(&path) {
            // `fs::metadata` follows a symlink to a partial file and reports the
            // size of its target.
            if let Ok(metadata) = std::fs::metadata(&path) {
                return Some(metadata.len());
            }
        }
    }
    None
}
508
509#[async_trait::async_trait]
510impl ModelSourceResolver for DefaultModelSourceResolver {
511 async fn resolve(&self, id: &str, revision: Option<&str>) -> Result<ResolvedModelSource> {
512 if Self::is_local_path(id) {
513 return self.resolve_local(id).await;
514 }
515
516 self.resolve_huggingface(id, revision).await
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses fixed `(key, value)` pairs as if they were the process environment.
    fn env_from(pairs: &[(&str, &str)]) -> ModelSourceRuntimeEnv {
        ModelSourceRuntimeEnv::from_env_vars(pairs.iter().copied())
    }

    #[test]
    fn model_source_runtime_env_parses_hf_cache_and_token() {
        let parsed = env_from(&[
            ("HF_HOME", "/tmp/hf"),
            ("HF_TOKEN", "primary"),
            ("HUGGING_FACE_HUB_TOKEN", "fallback"),
        ]);

        let expected = ModelSourceRuntimeEnv {
            hf_home: Some("/tmp/hf".to_string()),
            hf_token: Some("primary".to_string()),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn model_source_runtime_env_uses_hub_token_fallback() {
        let parsed = env_from(&[("HUGGING_FACE_HUB_TOKEN", "fallback")]);

        let expected = ModelSourceRuntimeEnv {
            hf_home: None,
            hf_token: Some("fallback".to_string()),
        };
        assert_eq!(parsed, expected);
    }
}