1use ferrum_types::{FerrumError, ModelSource, Result};
4use hf_hub::api::tokio::{Api, ApiBuilder, ApiRepo};
5use std::path::{Path, PathBuf};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tracing::{debug, info, warn};
10
/// Tunables for locating, authenticating, and downloading model artifacts.
#[derive(Debug, Clone)]
pub struct ModelSourceConfig {
    /// Root of the Hugging Face cache; `None` lets hf-hub pick its default.
    pub cache_dir: Option<PathBuf>,
    /// Auth token for gated/private repos; `None` means anonymous access.
    pub hf_token: Option<String>,
    /// When true, avoid network access. NOTE(review): not enforced anywhere
    /// in this file — confirm callers honor it.
    pub offline_mode: bool,
    /// Maximum download retry attempts. NOTE(review): unused in this file.
    pub max_retries: usize,
    /// Download timeout in seconds. NOTE(review): unused in this file.
    pub download_timeout: u64,
    /// Whether to serialize concurrent downloads via a file lock.
    /// NOTE(review): unused in this file.
    pub use_file_lock: bool,
}
21
22impl Default for ModelSourceConfig {
23 fn default() -> Self {
24 let default_cache = std::env::var("HF_HOME")
26 .ok()
27 .or_else(|| {
28 dirs::home_dir()
29 .map(|h| h.join(".cache/huggingface"))
30 .and_then(|p| p.to_str().map(String::from))
31 })
32 .map(PathBuf::from);
33
34 Self {
35 cache_dir: default_cache,
36 hf_token: Self::get_hf_token(),
37 offline_mode: false,
38 max_retries: 3,
39 download_timeout: 300,
40 use_file_lock: true,
41 }
42 }
43}
44
45impl ModelSourceConfig {
46 pub fn get_hf_token() -> Option<String> {
47 std::env::var("HF_TOKEN")
48 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
49 .ok()
50 }
51}
52
/// On-disk weight format detected for a resolved model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormat {
    /// `model.safetensors`, single file or sharded via an index JSON.
    SafeTensors,
    /// Legacy `pytorch_model.bin` pickle weights.
    PyTorchBin,
    /// GGUF container. NOTE(review): no probe in this file ever produces
    /// this variant — `detect_format` has no `*.gguf` check.
    GGUF,
    /// None of the known weight files were found.
    Unknown,
}
60
/// Outcome of resolving a model reference to a local directory.
#[derive(Debug, Clone)]
pub struct ResolvedModelSource {
    /// The reference exactly as the caller supplied it (path or repo id).
    pub original: String,
    /// Local directory holding config/tokenizer/weight files.
    pub local_path: PathBuf,
    /// Weight format that was detected or downloaded.
    pub format: ModelFormat,
    /// True when resolution used only pre-existing local files.
    /// NOTE(review): hub resolutions always report `false`, even when hf-hub
    /// actually served every file from its local cache.
    pub from_cache: bool,
}
68
69impl From<ResolvedModelSource> for ModelSource {
70 fn from(value: ResolvedModelSource) -> Self {
71 ModelSource::Local(value.local_path.display().to_string())
72 }
73}
74
/// Resolves a model reference (local path or hub repo id) to local files.
#[async_trait::async_trait]
pub trait ModelSourceResolver: Send + Sync {
    /// Resolves `id` (optionally pinned to `revision`) to a local directory,
    /// downloading artifacts as needed.
    async fn resolve(&self, id: &str, revision: Option<&str>) -> Result<ResolvedModelSource>;
}
79
/// Default resolver backed by the hf-hub async API.
pub struct DefaultModelSourceResolver {
    // Retained for potential future use (offline mode, retries, locking);
    // only consulted while building `api` — hence the underscore.
    _config: ModelSourceConfig,
    api: Api,
}
84
impl DefaultModelSourceResolver {
    /// Builds a resolver and its hf-hub API client, applying the cache
    /// directory and auth token from `config` when present.
    ///
    /// NOTE(review): if the configured builder fails, the fallback
    /// `Api::new()` silently drops the configured cache dir and token, and
    /// the `expect` still panics if even the default client cannot be built.
    pub fn new(config: ModelSourceConfig) -> Self {
        let mut builder = ApiBuilder::new();

        if let Some(cache_dir) = &config.cache_dir {
            builder = builder.with_cache_dir(cache_dir.clone());
        }

        if let Some(token) = &config.hf_token {
            builder = builder.with_token(Some(token.clone()));
        }

        let api = builder.build().unwrap_or_else(|e| {
            warn!("Failed to build HF API: {}, using default", e);
            Api::new().expect("Failed to create default HF API")
        });

        Self {
            _config: config,
            api,
        }
    }

    /// An id counts as "local" iff something already exists at that
    /// filesystem path; hub repo ids (`org/model`) normally don't, so they
    /// fall through to the hub resolution path.
    fn is_local_path(id: &str) -> bool {
        Path::new(id).exists()
    }

    /// Sniffs the weight format from well-known file names inside `path`.
    ///
    /// NOTE(review): this never returns `ModelFormat::GGUF` (no `*.gguf`
    /// probe), which makes the GGUF branch in `download_weights` dead.
    fn detect_format(path: &Path) -> ModelFormat {
        if path.join("model.safetensors").exists()
            || path.join("model.safetensors.index.json").exists()
        {
            ModelFormat::SafeTensors
        } else if path.join("pytorch_model.bin").exists() {
            ModelFormat::PyTorchBin
        } else {
            ModelFormat::Unknown
        }
    }

    /// Resolves an id that is an existing local path: validates existence,
    /// sniffs the format, and reports the result as served from cache.
    async fn resolve_local(&self, path: &str) -> Result<ResolvedModelSource> {
        let path_buf = PathBuf::from(path);

        if !path_buf.exists() {
            return Err(FerrumError::model(format!("Path does not exist: {}", path)));
        }

        let format = Self::detect_format(&path_buf);

        Ok(ResolvedModelSource {
            original: path.to_string(),
            local_path: path_buf,
            format,
            from_cache: true,
        })
    }

    /// Downloads `filename` from `repo` while a background task logs
    /// progress.
    ///
    /// `ApiRepo::get` exposes no progress callback, so a spawned monitor
    /// polls `expected_cache_dir` (and nearby directories) for the `*.part`
    /// temp file hf-hub writes, logging size and speed until the download
    /// future resolves and flips the shared `done` flag.
    async fn download_with_monitor(
        &self,
        repo: &ApiRepo,
        filename: &str,
        expected_cache_dir: &Path,
    ) -> Result<PathBuf> {
        info!("📥 下载中: {}...", filename);

        // Completion flag shared between the download path (stores true)
        // and the monitor loop (polls it). `done_clone` is the handle the
        // download path keeps; the spawn block takes its own clone.
        let done = Arc::new(AtomicBool::new(false));
        let done_clone = done.clone();
        let filename_str = filename.to_string();

        let monitor_task = tokio::spawn({
            let done = done.clone();
            let filename = filename_str.clone();
            let cache_dir = expected_cache_dir.to_path_buf();

            async move {
                // Give the download a moment to create its temp file.
                tokio::time::sleep(Duration::from_millis(1000)).await;

                let start_time = Instant::now();
                let mut last_size = 0u64;
                let mut last_time = Instant::now();
                let mut last_print = Instant::now();

                while !done.load(Ordering::SeqCst) {
                    if let Some(current_size) = find_downloading_file(&cache_dir, &filename) {
                        let elapsed_since_last = last_time.elapsed().as_secs_f64();

                        // Only compute speed over a >0.5s window with actual
                        // growth, so the MB/s figure stays meaningful.
                        if elapsed_since_last > 0.5 && current_size > last_size {
                            let delta = current_size - last_size;
                            let speed_mbps = delta as f64 / elapsed_since_last / 1024.0 / 1024.0;
                            let current_mb = current_size as f64 / 1024.0 / 1024.0;

                            // Throttle output to roughly one line per 2s.
                            if last_print.elapsed().as_secs() >= 2 {
                                info!(
                                    " 📊 已下载: {:.2} MB (速度: {:.1} MB/s)",
                                    current_mb, speed_mbps
                                );
                                last_print = Instant::now();
                            }

                            last_size = current_size;
                            last_time = Instant::now();
                        }
                    }

                    tokio::time::sleep(Duration::from_millis(500)).await;
                }

                // Final summary from the last observed partial-file size;
                // skipped when no partial file was ever seen (e.g. cache hit).
                let total_time = start_time.elapsed().as_secs_f64();
                if last_size > 0 && total_time > 0.0 {
                    let avg_speed = last_size as f64 / total_time / 1024.0 / 1024.0;
                    info!(
                        " ✅ 下载完成: {:.2} MB (平均速度: {:.1} MB/s, 耗时: {:.1}s)",
                        last_size as f64 / 1024.0 / 1024.0,
                        avg_speed,
                        total_time
                    );
                }
            }
        });

        // The actual download; the monitor runs concurrently.
        let path = repo
            .get(&filename_str)
            .await
            .map_err(|e| FerrumError::model(format!("Download failed: {}", e)))?;

        // Stop the monitor and wait for it to finish its final log line.
        done_clone.store(true, Ordering::SeqCst);

        let _ = monitor_task.await;

        Ok(path)
    }

    /// Resolves a Hugging Face repo id: downloads config, tokenizer files,
    /// and weights into the hf-hub cache and returns the snapshot directory
    /// that `config.json` landed in.
    async fn resolve_huggingface(
        &self,
        repo_id: &str,
        revision: Option<&str>,
    ) -> Result<ResolvedModelSource> {
        info!("🔍 正在解析模型: {}", repo_id);

        // Pin to an explicit revision when given; otherwise hf-hub's default
        // revision is used.
        let repo = if let Some(rev) = revision {
            self.api.repo(hf_hub::Repo::with_revision(
                repo_id.to_string(),
                hf_hub::RepoType::Model,
                rev.to_string(),
            ))
        } else {
            self.api.repo(hf_hub::Repo::new(
                repo_id.to_string(),
                hf_hub::RepoType::Model,
            ))
        };

        // config.json is mandatory: it both validates the repo and tells us
        // the local directory every subsequent file lands in.
        info!("📥 下载中: config.json...");
        let config_path = repo
            .get("config.json")
            .await
            .map_err(|e| FerrumError::model(format!("Failed to download config: {}", e)))?;

        info!("✅ config.json 下载完成");

        let model_dir = config_path
            .parent()
            .ok_or_else(|| FerrumError::model("Invalid cache path"))?
            .to_path_buf();

        info!("📁 缓存目录: {:?}", model_dir);

        self.download_tokenizer_files(&repo).await?;

        let format = self.download_weights(&repo, &model_dir).await?;

        // NOTE(review): `from_cache` is false unconditionally here, even when
        // hf-hub served everything from its local cache — confirm whether
        // callers rely on this distinction.
        Ok(ResolvedModelSource {
            original: repo_id.to_string(),
            local_path: model_dir,
            format,
            from_cache: false,
        })
    }

    /// Fetches the usual tokenizer artifacts. Every file is treated as
    /// optional: per-file failures are logged at debug level only, and we
    /// merely warn when nothing at all was found.
    async fn download_tokenizer_files(&self, repo: &ApiRepo) -> Result<()> {
        info!("📥 下载 tokenizer 文件...");

        let tokenizer_files = vec![
            "tokenizer.json",
            "tokenizer_config.json",
            "vocab.json",
            "merges.txt",
            "special_tokens_map.json",
        ];

        let mut downloaded_count = 0;
        for filename in &tokenizer_files {
            match repo.get(filename).await {
                Ok(_path) => {
                    info!(" ✅ {}", filename);
                    downloaded_count += 1;
                }
                Err(e) => {
                    debug!(" ⏭️ {} (optional): {}", filename, e);
                }
            }
        }

        if downloaded_count > 0 {
            info!("✅ Tokenizer 文件下载完成 ({} 个文件)", downloaded_count);
        } else {
            warn!("⚠️ 未找到 tokenizer 文件,可能影响推理");
        }

        Ok(())
    }

    /// Downloads model weights, preferring formats in this order:
    /// single-file SafeTensors → sharded SafeTensors (shard list taken from
    /// the index's `weight_map`) → legacy `pytorch_model.bin`. Errors when
    /// none of these can be obtained.
    async fn download_weights(&self, repo: &ApiRepo, model_dir: &Path) -> Result<ModelFormat> {
        // 1) Single-file SafeTensors.
        info!("🔍 检查 model.safetensors...");
        match self
            .download_with_monitor(repo, "model.safetensors", model_dir)
            .await
        {
            Ok(path) => {
                if let Ok(metadata) = std::fs::metadata(&path) {
                    info!(
                        "✅ model.safetensors 完成 ({:.2} GB)",
                        metadata.len() as f64 / 1e9
                    );
                }
                return Ok(ModelFormat::SafeTensors);
            }
            Err(e) => debug!("model.safetensors not found: {}", e),
        }

        // 2) Sharded SafeTensors: fetch the index, dedupe the shard file
        //    names referenced by `weight_map`, then download each shard.
        info!("🔍 检查分片模型...");
        match repo.get("model.safetensors.index.json").await {
            Ok(index_path) => {
                info!("✅ 发现分片 SafeTensors 模型");

                let content = std::fs::read_to_string(&index_path)
                    .map_err(|e| FerrumError::io(format!("Failed to read index: {}", e)))?;

                let index: serde_json::Value = serde_json::from_str(&content)
                    .map_err(|e| FerrumError::model(format!("Failed to parse index: {}", e)))?;

                if let Some(weight_map) = index.get("weight_map").and_then(|w| w.as_object()) {
                    // HashSet dedupes shards shared by many tensors.
                    let shards: std::collections::HashSet<_> =
                        weight_map.values().filter_map(|v| v.as_str()).collect();

                    let total = shards.len();
                    info!("📦 需要下载 {} 个分片", total);

                    let mut total_bytes = 0u64;
                    for (i, shard) in shards.iter().enumerate() {
                        info!("📥 [{}/{}] {}", i + 1, total, shard);

                        let shard_path = self.download_with_monitor(repo, shard, model_dir).await?;

                        if let Ok(meta) = std::fs::metadata(&shard_path) {
                            let size = meta.len();
                            total_bytes += size;
                            info!(
                                "📊 进度: [{}/{}] 分片, 累计 {:.2} GB",
                                i + 1,
                                total,
                                total_bytes as f64 / 1e9
                            );
                        }
                    }

                    info!(
                        "🎉 全部下载完成! 总大小: {:.2} GB",
                        total_bytes as f64 / 1e9
                    );
                }

                return Ok(ModelFormat::SafeTensors);
            }
            Err(e) => debug!("Sharded model not found: {}", e),
        }

        // 3) Legacy PyTorch pickle weights.
        info!("🔍 检查 pytorch_model.bin...");
        match self
            .download_with_monitor(repo, "pytorch_model.bin", model_dir)
            .await
        {
            Ok(path) => {
                warn!("⚠️ 使用 PyTorch 格式 (推荐使用 SafeTensors)");
                if let Ok(meta) = std::fs::metadata(&path) {
                    info!(
                        "✅ pytorch_model.bin 完成 ({:.2} GB)",
                        meta.len() as f64 / 1e9
                    );
                }
                return Ok(ModelFormat::PyTorchBin);
            }
            Err(e) => debug!("pytorch_model.bin not found: {}", e),
        }

        // NOTE(review): dead branch — `detect_format` never returns GGUF.
        if Self::detect_format(model_dir) == ModelFormat::GGUF {
            return Ok(ModelFormat::GGUF);
        }

        Err(FerrumError::model("未找到支持的模型格式"))
    }
}
401
/// Best-effort probe for the size of an in-progress hf-hub download.
///
/// hf-hub stages partial downloads as `*.part` temp files under the cache's
/// `blobs/` directory, so that is checked first. Because the directory we are
/// handed may be a snapshot path nested below the real cache root, we then
/// walk up to three ancestor directories and recursively scan their
/// subdirectories for any partial file. Returns the byte size of the first
/// `.part` file found, or `None` when no download appears to be active.
///
/// `_filename` is currently unused — the first partial file wins; it is kept
/// in the signature for future per-file matching.
fn find_downloading_file(cache_dir: &Path, _filename: &str) -> Option<u64> {
    // Fast path: conventional hf-hub layout (`<cache>/blobs/<hash>.part`).
    if let Ok(entries) = std::fs::read_dir(cache_dir.join("blobs")) {
        for entry in entries.filter_map(|e| e.ok()) {
            let path = entry.path();
            let path_str = path.to_string_lossy();

            if path_str.ends_with(".part") || path_str.contains(".sync.part") {
                if let Ok(metadata) = std::fs::metadata(&path) {
                    return Some(metadata.len());
                }
            }
        }
    }

    // Fallback: climb up to three ancestors and scan their child directories.
    // Fix: this previously read `read_dir(¤t)` — an HTML-entity-mangled
    // `&current` — which did not compile.
    let mut current = cache_dir.to_path_buf();
    for _ in 0..3 {
        if let Ok(entries) = std::fs::read_dir(&current) {
            for entry in entries.filter_map(|e| e.ok()) {
                if entry.path().is_dir() {
                    if let Some(size) = scan_dir_for_part_files(&entry.path()) {
                        return Some(size);
                    }
                }
            }
        }

        if let Some(parent) = current.parent() {
            current = parent.to_path_buf();
        } else {
            break;
        }
    }

    None
}

/// Depth-first search of `dir` for the first `*.part` partial-download file;
/// returns its current size, or `None` if none is present. I/O errors are
/// deliberately ignored — this is an opportunistic progress probe.
fn scan_dir_for_part_files(dir: &Path) -> Option<u64> {
    if let Ok(entries) = std::fs::read_dir(dir) {
        for entry in entries.filter_map(|e| e.ok()) {
            let path = entry.path();
            let path_str = path.to_string_lossy();

            if path_str.ends_with(".part") || path_str.contains(".sync.part") {
                if let Ok(metadata) = std::fs::metadata(&path) {
                    return Some(metadata.len());
                }
            }

            if path.is_dir() {
                if let Some(size) = scan_dir_for_part_files(&path) {
                    return Some(size);
                }
            }
        }
    }
    None
}
466
467#[async_trait::async_trait]
468impl ModelSourceResolver for DefaultModelSourceResolver {
469 async fn resolve(&self, id: &str, revision: Option<&str>) -> Result<ResolvedModelSource> {
470 if Self::is_local_path(id) {
471 return self.resolve_local(id).await;
472 }
473
474 self.resolve_huggingface(id, revision).await
475 }
476}