1use std::{
2 collections::HashSet,
3 path::{Path, PathBuf},
4 sync::Arc,
5 time::Duration,
6};
7
8use anyhow::{Context, Result, bail};
9use futures_util::StreamExt;
10use serde::Deserialize;
11use tokio::{io::AsyncWriteExt, sync::Semaphore, task::JoinSet};
12
/// Model hub a [`ModelDownloader`] talks to, together with the optional
/// access token that is sent as a bearer credential.
///
/// `Debug` is implemented by hand instead of derived so that a token never
/// ends up in logs or panic messages: only the presence of a token is shown.
#[derive(Clone)]
pub enum HubProvider {
    /// Hugging Face Hub.
    HuggingFace { token: Option<String> },
    /// ModelScope.
    ModelScope { token: Option<String> },
}

impl std::fmt::Debug for HubProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, token) = match self {
            Self::HuggingFace { token } => ("HuggingFace", token),
            Self::ModelScope { token } => ("ModelScope", token),
        };
        // Keep the shape of the derived output but never print the secret.
        f.debug_struct(variant)
            .field("token", &token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
23
24impl HubProvider {
25 fn token(&self) -> Option<&str> {
27 match self {
28 Self::HuggingFace { token } | Self::ModelScope { token } => token.as_deref(),
29 }
30 }
31
32 fn default_revision(&self) -> &'static str {
34 match self {
35 Self::HuggingFace { .. } => "main",
36 Self::ModelScope { .. } => "master",
37 }
38 }
39}
40
41#[derive(Debug, Deserialize)]
44struct HfFile {
45 path: String,
46 size: u64,
47 r#type: String,
48}
49
50#[derive(Debug, Deserialize)]
51struct MsResponse {
52 #[serde(rename = "Success")]
53 success: bool,
54 #[serde(rename = "Data")]
55 data: Option<MsData>,
56}
57
58#[derive(Debug, Deserialize)]
59struct MsData {
60 #[serde(rename = "Files")]
61 files: Vec<MsFile>,
62}
63
64#[derive(Debug, Deserialize)]
65struct MsFile {
66 #[serde(rename = "Path")]
67 path: String,
68 #[serde(rename = "Size")]
69 size: u64,
70 #[serde(rename = "Type")]
71 r#type: String,
72}
73
/// Provider-independent description of one file to download.
///
/// Derives `Debug` like every other data type in this file so that values can
/// be logged and used in assertions.
#[derive(Debug, Clone)]
struct UnifiedFile {
    /// Path of the file inside the repository; also used to build the local
    /// destination path and to match the caller's file filter.
    path: String,
    /// Expected size in bytes, compared against the length of the local file.
    size: u64,
    /// Fully built URL the file content is fetched from.
    download_url: String,
}
81
/// Parameters of a single [`ModelDownloader::download`] call.
///
/// Derives `Debug` and `Clone` so callers can log the options and reuse them
/// for several downloads.
#[derive(Debug, Clone)]
pub struct DownloadOptions {
    /// Repository identifier such as `org/name`. Each `/`-separated segment
    /// becomes one directory level below `save_dir`.
    pub repo_id: String,
    /// Revision to download; the provider's default revision is used when
    /// this is `None`.
    pub revision: Option<String>,
    /// Root directory under which the repository directory is created.
    pub save_dir: PathBuf,
    /// Optional allow-list of repository paths. When set, only files whose
    /// path is exactly equal to one of these strings are downloaded.
    pub files: Option<Vec<String>>,
}
95
96pub struct ModelDownloader {
98 client: reqwest::Client,
100 provider: HubProvider,
101 concurrency: usize,
102 max_retries: u32,
103}
104
105impl ModelDownloader {
106 const DEFAULT_CONCURRENCY: usize = 4;
107 const DEFAULT_MAX_RETRIES: u32 = 3;
108
109 pub fn new(provider: HubProvider) -> Result<Self> {
111 let client = Self::build_client(&provider)?;
112 Ok(Self {
113 client,
114 provider,
115 concurrency: Self::DEFAULT_CONCURRENCY,
116 max_retries: Self::DEFAULT_MAX_RETRIES,
117 })
118 }
119
120 pub fn with_concurrency(mut self, n: usize) -> Self {
122 self.concurrency = n.max(1);
123 self
124 }
125
126 pub fn with_max_retries(mut self, n: u32) -> Self {
128 self.max_retries = n;
129 self
130 }
131
132 pub async fn download(&self, options: DownloadOptions) -> Result<()> {
134 Self::validate_options(&options)?;
135
136 let revision = options
137 .revision
138 .as_deref()
139 .unwrap_or_else(|| self.provider.default_revision());
140
141 let model_dir = options
142 .repo_id
143 .split('/')
144 .fold(options.save_dir.clone(), |p, c| p.join(c));
145 tokio::fs::create_dir_all(&model_dir).await?;
146
147 let files = match &self.provider {
148 HubProvider::HuggingFace { .. } => self.get_hf_files(&options.repo_id, revision).await?,
149 HubProvider::ModelScope { .. } => self.get_ms_files(&options.repo_id, revision).await?,
150 };
151
152 let filter: Option<HashSet<String>> = options.files.map(|v| v.into_iter().collect());
154
155 let sem = Arc::new(Semaphore::new(self.concurrency));
156 let mut join_set: JoinSet<Result<()>> = JoinSet::new();
157
158 for file in files {
159 if let Some(ref set) = filter {
160 if !set.contains(&file.path) {
161 continue;
162 }
163 }
164
165 let dest = safe_join(&model_dir, &file.path)?;
167 let client = self.client.clone();
168 let sem = Arc::clone(&sem);
169 let max_retries = self.max_retries;
170
171 join_set.spawn(async move {
172 let _permit = sem.acquire().await.expect("信号量意外关闭");
174 with_retry(max_retries, || {
175 let c = client.clone();
176 let f = file.clone();
177 let d = dest.clone();
178 async move { download_single_file(c, f, d).await }
179 })
180 .await
181 });
182 }
183
184 while let Some(result) = join_set.join_next().await {
186 match result {
187 Ok(Ok(())) => {}
188 Ok(Err(e)) => {
189 join_set.abort_all();
190 return Err(e);
191 }
192 Err(e) if e.is_cancelled() => {}
194 Err(e) => {
195 join_set.abort_all();
196 bail!("下载任务发生 panic: {}", e);
197 }
198 }
199 }
200
201 Ok(())
202 }
203
204 fn build_client(provider: &HubProvider) -> Result<reqwest::Client> {
207 let mut headers = reqwest::header::HeaderMap::new();
208
209 headers.insert(
211 reqwest::header::USER_AGENT,
212 concat!("model-hub/", env!("CARGO_PKG_VERSION")).parse()?,
213 );
214
215 if let Some(token) = provider.token() {
217 headers.insert(reqwest::header::AUTHORIZATION, format!("Bearer {}", token).parse()?);
218 }
219
220 Ok(reqwest::Client::builder()
221 .default_headers(headers)
222 .build()?)
223 }
224
225 fn validate_options(options: &DownloadOptions) -> Result<()> {
226 if options.repo_id.is_empty() {
227 bail!("repo_id 不能为空");
228 }
229 if options.repo_id.contains("..") {
230 bail!("repo_id 含有非法字符 '..'");
231 }
232 if let Some(ref files) = options.files {
233 for path in files {
234 if path.contains("..") || path.starts_with('/') || path.starts_with('\\') {
235 bail!("files 列表中含有非法路径: {:?}", path);
236 }
237 }
238 }
239 Ok(())
240 }
241
242 async fn get_hf_files(&self, repo_id: &str, revision: &str) -> Result<Vec<UnifiedFile>> {
244 let base_url = std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_string());
245
246 let mut all_files = Vec::new();
247 let mut next_url: Option<String> = Some(format!(
249 "{}/api/models/{}/tree/{}?recursive=1",
250 base_url, repo_id, revision
251 ));
252
253 while let Some(url) = next_url.take() {
254 let resp = self.client.get(&url).send().await?;
255 if !resp.status().is_success() {
256 bail!("HuggingFace API 请求失败 (HTTP {}): {}", resp.status(), url);
257 }
258
259 next_url = resp
261 .headers()
262 .get(reqwest::header::LINK)
263 .and_then(|v| v.to_str().ok())
264 .and_then(parse_link_next);
265
266 let page: Vec<HfFile> = resp.json().await?;
267 all_files.extend(
268 page.into_iter()
269 .filter(|f| f.r#type == "file")
270 .map(|f| UnifiedFile {
271 download_url: format!("{}/{}/resolve/{}/{}", base_url, repo_id, revision, f.path),
272 path: f.path,
273 size: f.size,
274 }),
275 );
276 }
277
278 Ok(all_files)
279 }
280
281 async fn get_ms_files(&self, repo_id: &str, revision: &str) -> Result<Vec<UnifiedFile>> {
283 let url = format!(
284 "https://modelscope.cn/api/v1/models/{}/repo/files?Recursive=true&Revision={}",
285 repo_id, revision
286 );
287
288 let resp = self.client.get(&url).send().await?;
289 if !resp.status().is_success() {
290 bail!("ModelScope API 请求失败 (HTTP {})", resp.status());
291 }
292
293 let parsed: MsResponse = resp.json().await?;
294 if !parsed.success {
295 bail!("ModelScope API 返回失败状态");
296 }
297
298 let files = parsed.data.context("未获取到 ModelScope 文件数据")?.files;
299
300 Ok(files
301 .into_iter()
302 .filter(|f| f.r#type == "blob")
303 .map(|f| UnifiedFile {
304 download_url: format!(
305 "https://modelscope.cn/models/{}/resolve/{}/{}",
306 repo_id, revision, f.path
307 ),
308 path: f.path,
309 size: f.size,
310 })
311 .collect())
312 }
313}
314
315fn safe_join(base: &Path, file_path: &str) -> Result<PathBuf> {
321 let clean: PathBuf = file_path
322 .split('/')
323 .filter(|c| !c.is_empty() && *c != "." && *c != "..")
324 .collect();
325
326 let dest = base.join(&clean);
327
328 if !dest.starts_with(base) {
330 bail!("检测到非法路径(路径穿越): {:?}", file_path);
331 }
332 Ok(dest)
333}
334
/// Extracts the target of the `rel="next"` entry from an HTTP `Link` header.
///
/// Each entry has the form `<target>; param; param...`. Targets are located
/// by their angle brackets rather than by splitting on commas, so a target
/// that itself contains a comma is handled. An entry matches when one of its
/// parameters is `rel` (case-insensitive) and its value, quoted or not,
/// lists `next` among its space-separated relation types; other parameters
/// may appear before or after it.
///
/// Returns `None` when the header has no such entry.
fn parse_link_next(header: &str) -> Option<String> {
    // True when `param` is a `rel` parameter that lists the `next` relation.
    fn is_rel_next(param: &str) -> bool {
        match param.split_once('=') {
            Some((name, value)) if name.trim().eq_ignore_ascii_case("rel") => value
                .trim()
                .trim_matches('"')
                .split_ascii_whitespace()
                .any(|relation| relation.eq_ignore_ascii_case("next")),
            _ => false,
        }
    }

    let mut rest = header;
    loop {
        // `<` and `>` are single-byte, so the byte offsets below are valid
        // char boundaries.
        let open = rest.find('<')?;
        let after_open = &rest[open + 1..];
        let close = after_open.find('>')?;
        let target = &after_open[..close];
        let tail = &after_open[close + 1..];

        // Parameters of this entry run up to the next entry's `<`.
        let params_end = tail.find('<').unwrap_or(tail.len());
        let is_next = tail[..params_end]
            .split(';')
            .any(|param| is_rel_next(param.trim().trim_end_matches(',')));
        if is_next {
            return Some(target.trim().to_string());
        }
        rest = &tail[params_end..];
    }
}
355
356async fn with_retry<F, Fut>(max_retries: u32, mut f: F) -> Result<()>
361where
362 F: FnMut() -> Fut,
363 Fut: std::future::Future<Output = Result<()>>,
364{
365 for attempt in 0..max_retries {
367 if f().await.is_ok() {
368 return Ok(());
369 }
370 let secs = 2u64.saturating_pow(attempt).min(60);
371 tokio::time::sleep(Duration::from_secs(secs)).await;
372 }
373 f().await
375 .map_err(|e| e.context(format!("已重试 {} 次后仍失败", max_retries)))
376}
377
378async fn download_single_file(client: reqwest::Client, file_info: UnifiedFile, dest: PathBuf) -> Result<()> {
383 if let Some(parent) = dest.parent() {
384 tokio::fs::create_dir_all(parent).await?;
385 }
386
387 let existing_size = match tokio::fs::metadata(&dest).await {
389 Ok(meta) => {
390 let size = meta.len();
391 if size == file_info.size {
392 return Ok(()); }
394 size
395 }
396 Err(_) => 0,
397 };
398
399 let should_resume = existing_size > 0 && existing_size < file_info.size;
400
401 let req = if should_resume {
402 client
403 .get(&file_info.download_url)
404 .header("Range", format!("bytes={}-", existing_size))
405 } else {
406 client.get(&file_info.download_url)
407 };
408
409 let resp = req.send().await?;
410 let status = resp.status();
411
412 if !status.is_success() {
413 bail!("下载失败: {} (HTTP {})", file_info.path, status);
414 }
415
416 let append = should_resume && status == reqwest::StatusCode::PARTIAL_CONTENT;
418
419 let file = if append {
420 tokio::fs::OpenOptions::new()
421 .write(true)
422 .create(true)
423 .append(true)
424 .open(&dest)
425 .await
426 } else {
427 tokio::fs::OpenOptions::new()
428 .write(true)
429 .create(true)
430 .truncate(true)
431 .open(&dest)
432 .await
433 }
434 .with_context(|| format!("无法打开文件: {:?}", dest))?;
435
436 let mut writer = tokio::io::BufWriter::new(file);
437 let mut stream = resp.bytes_stream();
438
439 while let Some(chunk) = stream.next().await {
440 writer.write_all(&chunk?).await?;
441 }
442 writer.flush().await?;
443
444 Ok(())
445}