Skip to main content

model_hub/
lib.rs

1use std::{
2    collections::HashSet,
3    path::{Path, PathBuf},
4    sync::Arc,
5    time::Duration,
6};
7
8use anyhow::{Context, Result, bail};
9use futures_util::StreamExt;
10use serde::Deserialize;
11use tokio::{io::AsyncWriteExt, sync::Semaphore, task::JoinSet};
12
13// ── Provider ──────────────────────────────────────────────────────────────────
14
15/// 下载平台及鉴权信息。
16#[derive(Debug, Clone)]
17pub enum HubProvider {
18    /// Hugging Face 平台,可选传入 HF Token。
19    HuggingFace { token: Option<String> },
20    /// ModelScope 平台,可选传入 AccessToken。
21    ModelScope { token: Option<String> },
22}
23
24impl HubProvider {
25    /// 统一获取 token,消除重复分支逻辑。
26    fn token(&self) -> Option<&str> {
27        match self {
28            Self::HuggingFace { token } | Self::ModelScope { token } => token.as_deref(),
29        }
30    }
31
32    /// 各平台默认分支名称。
33    fn default_revision(&self) -> &'static str {
34        match self {
35            Self::HuggingFace { .. } => "main",
36            Self::ModelScope { .. } => "master",
37        }
38    }
39}
40
41// ── 内部数据结构 ───────────────────────────────────────────────────────────────
42
43#[derive(Debug, Deserialize)]
44struct HfFile {
45    path:   String,
46    size:   u64,
47    r#type: String,
48}
49
50#[derive(Debug, Deserialize)]
51struct MsResponse {
52    #[serde(rename = "Success")]
53    success: bool,
54    #[serde(rename = "Data")]
55    data:    Option<MsData>,
56}
57
58#[derive(Debug, Deserialize)]
59struct MsData {
60    #[serde(rename = "Files")]
61    files: Vec<MsFile>,
62}
63
64#[derive(Debug, Deserialize)]
65struct MsFile {
66    #[serde(rename = "Path")]
67    path:   String,
68    #[serde(rename = "Size")]
69    size:   u64,
70    #[serde(rename = "Type")]
71    r#type: String,
72}
73
74/// 平台无关的统一文件描述,需要 Clone 以便在 retry 闭包中复用。
75#[derive(Clone)]
76struct UnifiedFile {
77    path:         String,
78    size:         u64,
79    download_url: String,
80}
81
82// ── 公开 API ───────────────────────────────────────────────────────────────────
83
84/// 单次下载请求的参数。
85pub struct DownloadOptions {
86    /// 仓库 ID,例如 `"meta-llama/Llama-2-7b-hf"`。
87    pub repo_id:  String,
88    /// 分支、tag 或 commit hash。`None` 时使用平台默认分支。
89    pub revision: Option<String>,
90    /// 本地根目录,库会在其下自动创建 `<owner>--<model>` 子目录。
91    pub save_dir: PathBuf,
92    /// 允许下载的相对路径白名单,`None` 表示下载全部文件。
93    pub files:    Option<Vec<String>>,
94}
95
96/// 模型下载器,持有 HTTP 客户端与配置,可复用于多次下载。
97pub struct ModelDownloader {
98    /// reqwest::Client 内部已是 Arc,无需再套一层 Arc。
99    client:      reqwest::Client,
100    provider:    HubProvider,
101    concurrency: usize,
102    max_retries: u32,
103}
104
105impl ModelDownloader {
106    const DEFAULT_CONCURRENCY: usize = 4;
107    const DEFAULT_MAX_RETRIES: u32 = 3;
108
109    /// 为指定平台创建下载器。
110    pub fn new(provider: HubProvider) -> Result<Self> {
111        let client = Self::build_client(&provider)?;
112        Ok(Self {
113            client,
114            provider,
115            concurrency: Self::DEFAULT_CONCURRENCY,
116            max_retries: Self::DEFAULT_MAX_RETRIES,
117        })
118    }
119
120    /// 设置同时进行的最大下载数(默认 4,最小 1)。
121    pub fn with_concurrency(mut self, n: usize) -> Self {
122        self.concurrency = n.max(1);
123        self
124    }
125
126    /// 设置每个文件的最大重试次数(默认 3)。
127    pub fn with_max_retries(mut self, n: u32) -> Self {
128        self.max_retries = n;
129        self
130    }
131
132    /// 执行下载。
133    pub async fn download(&self, options: DownloadOptions) -> Result<()> {
134        Self::validate_options(&options)?;
135
136        let revision = options
137            .revision
138            .as_deref()
139            .unwrap_or_else(|| self.provider.default_revision());
140
141        let model_dir = options
142            .repo_id
143            .split('/')
144            .fold(options.save_dir.clone(), |p, c| p.join(c));
145        tokio::fs::create_dir_all(&model_dir).await?;
146
147        let files = match &self.provider {
148            HubProvider::HuggingFace { .. } => self.get_hf_files(&options.repo_id, revision).await?,
149            HubProvider::ModelScope { .. } => self.get_ms_files(&options.repo_id, revision).await?,
150        };
151
152        // 使用 HashSet 将过滤从 O(n×m) 降为 O(1)
153        let filter: Option<HashSet<String>> = options.files.map(|v| v.into_iter().collect());
154
155        let sem = Arc::new(Semaphore::new(self.concurrency));
156        let mut join_set: JoinSet<Result<()>> = JoinSet::new();
157
158        for file in files {
159            if let Some(ref set) = filter {
160                if !set.contains(&file.path) {
161                    continue;
162                }
163            }
164
165            // 路径穿越校验:在派发任务前提前验证,快速失败
166            let dest = safe_join(&model_dir, &file.path)?;
167            let client = self.client.clone();
168            let sem = Arc::clone(&sem);
169            let max_retries = self.max_retries;
170
171            join_set.spawn(async move {
172                // 通过信号量限制并发数,_permit 在任务结束时自动释放
173                let _permit = sem.acquire().await.expect("信号量意外关闭");
174                with_retry(max_retries, || {
175                    let c = client.clone();
176                    let f = file.clone();
177                    let d = dest.clone();
178                    async move { download_single_file(c, f, d).await }
179                })
180                .await
181            });
182        }
183
184        // 使用 JoinSet 收集结果;任意任务失败时立即 abort 其余任务
185        while let Some(result) = join_set.join_next().await {
186            match result {
187                Ok(Ok(())) => {}
188                Ok(Err(e)) => {
189                    join_set.abort_all();
190                    return Err(e);
191                }
192                // abort_all 后剩余任务以 Cancelled 结束,忽略即可
193                Err(e) if e.is_cancelled() => {}
194                Err(e) => {
195                    join_set.abort_all();
196                    bail!("下载任务发生 panic: {}", e);
197                }
198            }
199        }
200
201        Ok(())
202    }
203
204    // ── 私有方法 ───────────────────────────────────────────────────────────────
205
206    fn build_client(provider: &HubProvider) -> Result<reqwest::Client> {
207        let mut headers = reqwest::header::HeaderMap::new();
208
209        // 使用语义化 UA,而非伪造浏览器标识
210        headers.insert(
211            reqwest::header::USER_AGENT,
212            concat!("model-hub/", env!("CARGO_PKG_VERSION")).parse()?,
213        );
214
215        // 两个平台鉴权逻辑相同,统一处理,消除重复
216        if let Some(token) = provider.token() {
217            headers.insert(reqwest::header::AUTHORIZATION, format!("Bearer {}", token).parse()?);
218        }
219
220        Ok(reqwest::Client::builder()
221            .default_headers(headers)
222            .build()?)
223    }
224
225    fn validate_options(options: &DownloadOptions) -> Result<()> {
226        if options.repo_id.is_empty() {
227            bail!("repo_id 不能为空");
228        }
229        if options.repo_id.contains("..") {
230            bail!("repo_id 含有非法字符 '..'");
231        }
232        if let Some(ref files) = options.files {
233            for path in files {
234                if path.contains("..") || path.starts_with('/') || path.starts_with('\\') {
235                    bail!("files 列表中含有非法路径: {:?}", path);
236                }
237            }
238        }
239        Ok(())
240    }
241
242    /// 获取 HuggingFace 文件列表,支持递归子目录与 Link 分页。
243    async fn get_hf_files(&self, repo_id: &str, revision: &str) -> Result<Vec<UnifiedFile>> {
244        let base_url = std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_string());
245
246        let mut all_files = Vec::new();
247        // ?recursive=1 获取所有子目录文件
248        let mut next_url: Option<String> = Some(format!(
249            "{}/api/models/{}/tree/{}?recursive=1",
250            base_url, repo_id, revision
251        ));
252
253        while let Some(url) = next_url.take() {
254            let resp = self.client.get(&url).send().await?;
255            if !resp.status().is_success() {
256                bail!("HuggingFace API 请求失败 (HTTP {}): {}", resp.status(), url);
257            }
258
259            // 从 Link 响应头提取下一页 URL(分页处理)
260            next_url = resp
261                .headers()
262                .get(reqwest::header::LINK)
263                .and_then(|v| v.to_str().ok())
264                .and_then(parse_link_next);
265
266            let page: Vec<HfFile> = resp.json().await?;
267            all_files.extend(
268                page.into_iter()
269                    .filter(|f| f.r#type == "file")
270                    .map(|f| UnifiedFile {
271                        download_url: format!("{}/{}/resolve/{}/{}", base_url, repo_id, revision, f.path),
272                        path:         f.path,
273                        size:         f.size,
274                    }),
275            );
276        }
277
278        Ok(all_files)
279    }
280
281    /// 获取 ModelScope 文件列表。
282    async fn get_ms_files(&self, repo_id: &str, revision: &str) -> Result<Vec<UnifiedFile>> {
283        let url = format!(
284            "https://modelscope.cn/api/v1/models/{}/repo/files?Recursive=true&Revision={}",
285            repo_id, revision
286        );
287
288        let resp = self.client.get(&url).send().await?;
289        if !resp.status().is_success() {
290            bail!("ModelScope API 请求失败 (HTTP {})", resp.status());
291        }
292
293        let parsed: MsResponse = resp.json().await?;
294        if !parsed.success {
295            bail!("ModelScope API 返回失败状态");
296        }
297
298        let files = parsed.data.context("未获取到 ModelScope 文件数据")?.files;
299
300        Ok(files
301            .into_iter()
302            .filter(|f| f.r#type == "blob")
303            .map(|f| UnifiedFile {
304                download_url: format!(
305                    "https://modelscope.cn/models/{}/resolve/{}/{}",
306                    repo_id, revision, f.path
307                ),
308                path:         f.path,
309                size:         f.size,
310            })
311            .collect())
312    }
313}
314
315// ── 辅助函数 ───────────────────────────────────────────────────────────────────
316
317/// 安全路径拼接:过滤 `..` / 绝对路径组件,并在最终验证 dest 在 base 内。
318///
319/// 防止服务端返回恶意路径(路径穿越攻击)时写入 model_dir 之外的位置。
320fn safe_join(base: &Path, file_path: &str) -> Result<PathBuf> {
321    let clean: PathBuf = file_path
322        .split('/')
323        .filter(|c| !c.is_empty() && *c != "." && *c != "..")
324        .collect();
325
326    let dest = base.join(&clean);
327
328    // 双重保险:即使 filter 有遗漏,canonicalize 前的前缀检查也能兜底
329    if !dest.starts_with(base) {
330        bail!("检测到非法路径(路径穿越): {:?}", file_path);
331    }
332    Ok(dest)
333}
334
335/// 解析 HTTP `Link` 响应头,返回 `rel="next"` 对应的 URL。
336///
337/// 格式示例:`<https://...?cursor=xxx>; rel="next", <https://...>; rel="last"`
338fn parse_link_next(header: &str) -> Option<String> {
339    header.split(',').find_map(|part| {
340        let mut seg = part.trim().splitn(2, ';');
341        let url_part = seg.next()?.trim();
342        let rel_part = seg.next()?.trim();
343        if rel_part == r#"rel="next""# {
344            Some(
345                url_part
346                    .trim_start_matches('<')
347                    .trim_end_matches('>')
348                    .to_string(),
349            )
350        } else {
351            None
352        }
353    })
354}
355
356/// 带指数退避的重试包装器。
357///
358/// - 首次立即执行,失败后等待 1 → 2 → 4 → … 秒(上限 60 秒)再重试。
359/// - `max_retries = 0` 表示不重试,失败直接返回错误。
360async fn with_retry<F, Fut>(max_retries: u32, mut f: F) -> Result<()>
361where
362    F: FnMut() -> Fut,
363    Fut: std::future::Future<Output = Result<()>>,
364{
365    // 前 max_retries 次:失败后等待再重试
366    for attempt in 0..max_retries {
367        if f().await.is_ok() {
368            return Ok(());
369        }
370        let secs = 2u64.saturating_pow(attempt).min(60);
371        tokio::time::sleep(Duration::from_secs(secs)).await;
372    }
373    // 末次尝试:直接返回结果,编译器可静态证明此处一定返回,无需 unreachable!()
374    f().await
375        .map_err(|e| e.context(format!("已重试 {} 次后仍失败", max_retries)))
376}
377
378/// 下载单个文件,支持断点续传。
379///
380/// **续传安全性**:仅当服务端真实返回 `206 Partial Content` 时才以追加模式
381/// 写入;若服务端忽略 `Range` 头返回 `200`,则截断重写,防止文件静默损坏。
382async fn download_single_file(client: reqwest::Client, file_info: UnifiedFile, dest: PathBuf) -> Result<()> {
383    if let Some(parent) = dest.parent() {
384        tokio::fs::create_dir_all(parent).await?;
385    }
386
387    // 检查本地已有文件大小(使用 tokio::fs 避免在 async 中阻塞线程)
388    let existing_size = match tokio::fs::metadata(&dest).await {
389        Ok(meta) => {
390            let size = meta.len();
391            if size == file_info.size {
392                return Ok(()); // 文件已完整,直接跳过
393            }
394            size
395        }
396        Err(_) => 0,
397    };
398
399    let should_resume = existing_size > 0 && existing_size < file_info.size;
400
401    let req = if should_resume {
402        client
403            .get(&file_info.download_url)
404            .header("Range", format!("bytes={}-", existing_size))
405    } else {
406        client.get(&file_info.download_url)
407    };
408
409    let resp = req.send().await?;
410    let status = resp.status();
411
412    if !status.is_success() {
413        bail!("下载失败: {} (HTTP {})", file_info.path, status);
414    }
415
416    // 关键:以实际响应状态码决定写入模式,而非以请求意图决定
417    let append = should_resume && status == reqwest::StatusCode::PARTIAL_CONTENT;
418
419    let file = if append {
420        tokio::fs::OpenOptions::new()
421            .write(true)
422            .create(true)
423            .append(true)
424            .open(&dest)
425            .await
426    } else {
427        tokio::fs::OpenOptions::new()
428            .write(true)
429            .create(true)
430            .truncate(true)
431            .open(&dest)
432            .await
433    }
434    .with_context(|| format!("无法打开文件: {:?}", dest))?;
435
436    let mut writer = tokio::io::BufWriter::new(file);
437    let mut stream = resp.bytes_stream();
438
439    while let Some(chunk) = stream.next().await {
440        writer.write_all(&chunk?).await?;
441    }
442    writer.flush().await?;
443
444    Ok(())
445}