modelscope_ng/
lib.rs

1use anyhow::{Context, bail};
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::env::home_dir;
8use std::fs;
9use std::io::{BufWriter, Seek, Write};
10use std::path::{Path, PathBuf};
11use std::sync::{Arc, Mutex};
12
13/// 进度回调 trait
14#[async_trait]
15pub trait ProgressCallback: Send + Sync {
16    /// 当文件下载开始时调用
17    async fn on_file_start(&self, file_name: &str, file_size: u64);
18    
19    /// 当文件下载进度更新时调用
20    async fn on_file_progress(&self, file_name: &str, downloaded: u64, total: u64);
21    
22    /// 当文件下载完成时调用
23    async fn on_file_complete(&self, file_name: &str);
24    
25    /// 当文件下载失败时调用
26    async fn on_file_error(&self, file_name: &str, error: &str);
27}
28
29/// 默认的进度回调实现(使用进度条)
30pub struct ProgressBarCallback {
31    bars: Arc<MultiProgress>,
32    progress_bars: Arc<Mutex<HashMap<String, ProgressBar>>>,
33}
34
35impl ProgressBarCallback {
36    pub fn new() -> Self {
37        Self {
38            bars: Arc::new(MultiProgress::new()),
39            progress_bars: Arc::new(Mutex::new(HashMap::new())),
40        }
41    }
42}
43
44impl Default for ProgressBarCallback {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl Clone for ProgressBarCallback {
51    fn clone(&self) -> Self {
52        Self {
53            bars: self.bars.clone(),
54            progress_bars: self.progress_bars.clone(),
55        }
56    }
57}
58
59#[async_trait]
60impl ProgressCallback for ProgressBarCallback {
61    async fn on_file_start(&self, file_name: &str, file_size: u64) {
62        // 检查是否已经存在相同名称的进度条
63        {
64            let bars = self.progress_bars.lock().unwrap();
65            if bars.contains_key(file_name) {
66                return; // 如果已存在,不再创建新的进度条
67            }
68        }
69        
70        let bar = ProgressBar::new(file_size);
71        let style = ProgressStyle::default_bar().template(BAR_STYLE).unwrap();
72        bar.set_style(style);
73        bar.set_message(file_name.to_string());
74        self.bars.add(bar.clone());
75        
76        let mut bars = self.progress_bars.lock().unwrap();
77        bars.insert(file_name.to_string(), bar);
78    }
79    
80    async fn on_file_progress(&self, file_name: &str, downloaded: u64, _total: u64) {
81        let bars = self.progress_bars.lock().unwrap();
82        if let Some(bar) = bars.get(file_name) {
83            bar.set_position(downloaded);
84        }
85    }
86    
87    async fn on_file_complete(&self, file_name: &str) {
88        let mut bars = self.progress_bars.lock().unwrap();
89        if let Some(bar) = bars.remove(file_name) {
90            bar.finish();
91        }
92    }
93    
94    async fn on_file_error(&self, file_name: &str, _error: &str) {
95        let mut bars = self.progress_bars.lock().unwrap();
96        if let Some(bar) = bars.remove(file_name) {
97            bar.abandon();
98        }
99    }
100}
101
102/// 简单的回调实现,只打印进度信息
103#[derive(Clone)]
104pub struct SimpleCallback;
105
106#[async_trait]
107impl ProgressCallback for SimpleCallback {
108    async fn on_file_start(&self, file_name: &str, file_size: u64) {
109        println!("开始下载: {} (大小: {} bytes)", file_name, file_size);
110    }
111    
112    async fn on_file_progress(&self, file_name: &str, downloaded: u64, total: u64) {
113        let percent = if total > 0 {
114            (downloaded as f64 / total as f64 * 100.0) as u32
115        } else {
116            0
117        };
118        println!("下载中: {} - {}% ({} / {} bytes)", file_name, percent, downloaded, total);
119    }
120    
121    async fn on_file_complete(&self, file_name: &str) {
122        println!("下载完成: {}", file_name);
123    }
124    
125    async fn on_file_error(&self, file_name: &str, error: &str) {
126        eprintln!("下载失败: {} - 错误: {}", file_name, error);
127    }
128}
129
/// Recursive file-listing endpoint; `<model_id>` is substituted before use.
const FILES_URL: &'static str =
    "https://modelscope.cn/api/v1/models/<model_id>/repo/files?Recursive=true";
/// Raw file download endpoint; `<model_id>` and `<path>` are substituted before use.
const DOWNLOAD_URL: &'static str =
    "https://modelscope.cn/models/<model_id>/resolve/master/<path>";
/// Endpoint that accepts an access token and answers with session cookies.
const LOGIN_URL: &'static str = "https://modelscope.cn/api/v1/login";
/// Name of the directory under the user's home that holds this tool's state.
const DIR: &'static str = ".modelscope";
/// File name, inside the config directory, where login cookies are stored.
const COOKIES_FILE: &'static str = "cookies";

/// `User-Agent` header (name, value) sent with file download requests.
const UA: (&'static str, &'static str) = (
    "User-Agent",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.90 Safari/537.36",
);

/// Entry point for ModelScope operations: downloading, login/logout and listing.
pub struct ModelScope;
141
142#[derive(Debug, Deserialize)]
143struct ModelScopeResponse {
144    #[serde(rename = "Code")]
145    #[allow(unused)]
146    code: i64,
147    #[serde(rename = "Success")]
148    success: bool,
149    #[serde(rename = "Message")]
150    message: String,
151    #[serde(rename = "Data")]
152    data: Option<ModelScopeResponseData>,
153}
154
155#[derive(Debug, Deserialize)]
156struct ModelScopeResponseData {
157    #[serde(rename = "Files")]
158    files: Vec<RepoFile>,
159}
160#[derive(Debug, Deserialize)]
161struct RepoFile {
162    #[serde(rename = "Name")]
163    name: String,
164    #[serde(rename = "Path")]
165    path: String,
166    #[serde(rename = "Size")]
167    size: u64,
168    #[serde(rename = "Sha256")]
169    #[allow(unused)]
170    sha256: String,
171    #[serde(rename = "Type")]
172    r#type: String,
173}
174
175const BAR_STYLE: &str = "{msg:<30} {bar} {decimal_bytes:<10} / {decimal_total_bytes:<10} {decimal_bytes_per_sec:<12} {percent:<3}%  {eta_precise}";
176
177impl ModelScope {
178    async fn get_client() -> anyhow::Result<reqwest::Client> {
179        let client = reqwest::Client::builder().connect_timeout(std::time::Duration::from_secs(10));
180        let mut default_headers = reqwest::header::HeaderMap::new();
181        if let Some(cookies) = Self::get_cookies()? {
182            default_headers.insert("Cookie", cookies.parse()?);
183        }
184        let client = client.default_headers(default_headers);
185        Ok(client.build()?)
186    }
187
188    pub async fn download(model_id: &str, save_dir: impl Into<PathBuf>) -> anyhow::Result<()> {
189        Self::download_with_callback(model_id, save_dir, ProgressBarCallback::default()).await
190    }
191
192    pub async fn download_with_callback<C: ProgressCallback + Clone + 'static>(
193        model_id: &str,
194        save_dir: impl Into<PathBuf>,
195        callback: C,
196    ) -> anyhow::Result<()> {
197        // Model root dir
198        let save_dir = save_dir.into();
199        fs::create_dir_all(&save_dir)?;
200
201        // Model save dir, like <save_dir>/<model_id>
202        let model_dir = save_dir.join(model_id);
203
204        println!();
205        println!("Downloading model {} to: {}", model_id, model_dir.display());
206        println!();
207
208        fs::create_dir_all(&model_dir)?;
209
210        let files_url = FILES_URL.replace("<model_id>", model_id);
211
212        let client = Arc::new(Self::get_client().await?);
213
214        let resp = client.get(files_url).send().await?;
215
216        if !resp.status().is_success() {
217            bail!(
218                "Failed to get model files: {}\nTip: Maybe the model ID is incorrect or login is required",
219                resp.text().await?
220            );
221        }
222
223        let response = resp.json::<ModelScopeResponse>().await?;
224        if !response.success {
225            bail!("Failed to get model files: {}", response.message);
226        }
227
228        let data = response.data.unwrap();
229        let repo_files = data.files;
230
231        // Add the incoming model save path to the known model paths
232        // This is used when using the list command
233        Config::append_save_dir(&save_dir)?;
234
235        let mut tasks = Vec::new();
236
237        for repo_file in repo_files.into_iter().filter(|f| f.r#type == "blob") {
238            let model_id = model_id.to_string();
239            let client = client.clone();
240            let save_dir = model_dir.clone();
241            let callback = callback.clone();
242
243            let task = tokio::spawn(async move {
244                let res = Self::download_file_with_callback(client, model_id, repo_file, save_dir, callback).await;
245                if let Err(e) = res {
246                    bail!("Error downloading file: {}", e);
247                }
248                Ok::<(), anyhow::Error>(())
249            });
250
251            tasks.push(task);
252        }
253        for task in tasks {
254            task.await??;
255        }
256
257        Ok(())
258    }
259
260    async fn download_file(
261        client: Arc<reqwest::Client>,
262        model_id: String,
263        repo_file: RepoFile,
264        save_dir: PathBuf,
265        bar: ProgressBar,
266    ) -> anyhow::Result<()> {
267        let path = &repo_file.path;
268        let name = &repo_file.name;
269
270        bar.set_message(name.clone());
271
272        let file_path = save_dir.join(path);
273        if let Some(parent) = file_path.parent() {
274            fs::create_dir_all(parent)?;
275        }
276
277        let mut existing_size = 0;
278        let mut file_options = fs::OpenOptions::new();
279        file_options.write(true).create(true);
280
281        if file_path.exists() {
282            if let Ok(metadata) = fs::metadata(&file_path) {
283                existing_size = metadata.len();
284                file_options.append(true);
285            }
286        } else {
287            file_options.truncate(true);
288        }
289
290        let mut file = BufWriter::new(file_options.open(&file_path)?);
291
292        // Set progress bar initial position
293        bar.set_position(existing_size);
294        bar.set_length(repo_file.size);
295
296        let url = DOWNLOAD_URL
297            .replace("<model_id>", &model_id)
298            .replace("<path>", path);
299
300        let mut rb = client.get(&url).header(UA.0, UA.1);
301
302        // Already downloaded, just return ok.
303        // If file size equal repo file size, maybe check sha256
304        // But I think the probability of files having the same number of bytes is relatively low, so I won't check here. 🙊
305        if existing_size == repo_file.size {
306            bar.finish();
307            return Ok(());
308        }
309
310        // Resume download
311        if existing_size < repo_file.size {
312            rb = rb.header("Range", format!("bytes={}-", existing_size));
313        }
314
315        let response = rb.send().await?;
316
317        let status = response.status();
318
319        // Server doesn't support resume download, re-downloading from beginning
320        // Or existing file size is larger than repo size, re-downloading from beginning
321        if status == reqwest::StatusCode::OK && existing_size > 0 || existing_size > repo_file.size
322        {
323            file.rewind()?;
324            file.get_ref().set_len(0)?;
325            bar.set_position(0);
326        }
327
328        // If status is not success or partial content, bail
329        if !response.status().is_success()
330            && response.status() != reqwest::StatusCode::PARTIAL_CONTENT
331        {
332            bail!(
333                "Failed to download file {}: HTTP {}",
334                name,
335                response.status()
336            );
337        }
338
339        let mut stream = response.bytes_stream();
340
341        while let Some(item) = stream.next().await {
342            let chunk = item?;
343            file.write_all(&chunk)?;
344            bar.inc(chunk.len() as u64);
345        }
346
347        file.flush()?;
348
349        bar.finish();
350
351        Ok(())
352    }
353
354    async fn download_file_with_callback<C: ProgressCallback + Clone + 'static>(
355        client: Arc<reqwest::Client>,
356        model_id: String,
357        repo_file: RepoFile,
358        save_dir: PathBuf,
359        callback: C,
360    ) -> anyhow::Result<()> {
361        let path = &repo_file.path;
362        let name = &repo_file.name;
363
364        let file_path = save_dir.join(path);
365        if let Some(parent) = file_path.parent() {
366            fs::create_dir_all(parent)?;
367        }
368
369        let mut existing_size = 0;
370        let mut file_options = fs::OpenOptions::new();
371        file_options.write(true).create(true);
372
373        if file_path.exists() {
374            if let Ok(metadata) = fs::metadata(&file_path) {
375                existing_size = metadata.len();
376                file_options.append(true);
377            }
378        } else {
379            file_options.truncate(true);
380        }
381
382        let mut file = BufWriter::new(file_options.open(&file_path)?);
383
384        let url = DOWNLOAD_URL
385            .replace("<model_id>", &model_id)
386            .replace("<path>", path);
387
388        // Now we call on_file_start after checking if file exists
389        callback.on_file_start(name, repo_file.size).await;
390
391        let mut rb = client.get(&url).header(UA.0, UA.1);
392
393        // Already downloaded, just return ok.
394        if existing_size == repo_file.size {
395            callback.on_file_progress(name, repo_file.size, repo_file.size).await;
396            callback.on_file_complete(name).await;
397            return Ok(());
398        }
399
400        // Resume download
401        if existing_size < repo_file.size {
402            rb = rb.header("Range", format!("bytes={}-", existing_size));
403        }
404
405        let response = rb.send().await?;
406
407        let status = response.status();
408
409        // Server doesn't support resume download, re-downloading from beginning
410        // Or existing file size is larger than repo size, re-downloading from beginning
411        if status == reqwest::StatusCode::OK && existing_size > 0 || existing_size > repo_file.size
412        {
413            file.rewind()?;
414            file.get_ref().set_len(0)?;
415            existing_size = 0;
416            callback.on_file_progress(name, 0, repo_file.size).await;
417        }
418
419        // If status is not success or partial content, bail
420        if !response.status().is_success()
421            && response.status() != reqwest::StatusCode::PARTIAL_CONTENT
422        {
423            let error_msg = format!("HTTP {}", response.status());
424            callback.on_file_error(name, &error_msg).await;
425            bail!(
426                "Failed to download file {}: HTTP {}",
427                name,
428                response.status()
429            );
430        }
431
432        let mut stream = response.bytes_stream();
433
434        while let Some(item) = stream.next().await {
435            let chunk = item?;
436            file.write_all(&chunk)?;
437            existing_size += chunk.len() as u64;
438            callback.on_file_progress(name, existing_size, repo_file.size).await;
439        }
440
441        file.flush()?;
442
443        callback.on_file_complete(name).await;
444
445        Ok(())
446    }
447
448    pub async fn login(token: &str) -> anyhow::Result<()> {
449        println!("Logging in...");
450        let client = Self::get_client().await?;
451        let resp = client
452            .post(LOGIN_URL)
453            .json(&serde_json::json!({
454                "AccessToken": token
455            }))
456            .send()
457            .await?;
458
459        let status = resp.status();
460
461        if !status.is_success() {
462            bail!("Failed to login: {}", resp.text().await?);
463        }
464
465        let cookies: serde_json::Value = resp
466            .cookies()
467            .map(|cookie| (cookie.name().to_string(), cookie.value().to_string()))
468            .collect();
469
470        let dir = Dirs::config_dir()?;
471
472        let cookies_file = dir.join(COOKIES_FILE);
473        fs::write(cookies_file, cookies.to_string())?;
474
475        println!("Login successful.");
476
477        Ok(())
478    }
479
480    pub async fn download_single_file(
481        model_id: &str,
482        file_path: &str,
483        save_dir: impl Into<PathBuf>,
484    ) -> anyhow::Result<()> {
485        Self::download_single_file_with_callback(model_id, file_path, save_dir, ProgressBarCallback::default()).await
486    }
487
488    pub async fn download_single_file_with_callback<C: ProgressCallback + Clone + 'static>(
489        model_id: &str,
490        file_path: &str,
491        save_dir: impl Into<PathBuf>,
492        callback: C,
493    ) -> anyhow::Result<()> {
494        let save_dir = save_dir.into();
495        fs::create_dir_all(&save_dir)?;
496
497        let model_dir = save_dir.join(model_id);
498        fs::create_dir_all(&model_dir)?;
499
500        println!();
501        println!(
502            "Downloading file {} from model {} to: {}",
503            file_path,
504            model_id,
505            model_dir.display()
506        );
507        println!();
508
509        let files_url = FILES_URL.replace("<model_id>", model_id);
510
511        let client = Arc::new(Self::get_client().await?);
512
513        // Get file list from API
514        let resp = client.get(files_url).send().await?;
515
516        if !resp.status().is_success() {
517            bail!(
518                "Failed to get model files: {}\nTip: Maybe the model ID is incorrect or login is required",
519                resp.text().await?
520            );
521        }
522
523        let response = resp.json::<ModelScopeResponse>().await?;
524        if !response.success {
525            bail!("Failed to get model files: {}", response.message);
526        }
527
528        let data = response.data.unwrap();
529        let repo_files = data.files;
530
531        // Find the target file
532        let repo_file = repo_files
533            .into_iter()
534            .find(|f| f.path == file_path && f.r#type == "blob")
535            .ok_or_else(|| anyhow::anyhow!("File not found in model: {}", file_path))?;
536
537        Self::download_file_with_callback(client, model_id.to_string(), repo_file, model_dir, callback).await?;
538
539        Ok(())
540    }
541
542    fn get_cookies() -> anyhow::Result<Option<String>> {
543        let cookies_file = Dirs::config_dir()?.join(COOKIES_FILE);
544
545        if cookies_file.exists() {
546            let cookies = fs::read_to_string(cookies_file)?;
547            let cookies: serde_json::Value = serde_json::from_str(&cookies)?;
548
549            let cookies = cookies
550                .as_object()
551                .context("Failed to parse cookies")?
552                .iter()
553                .map(|(k, v)| format!("{}={}", k, v.as_str().unwrap_or_default()))
554                .collect::<Vec<_>>()
555                .join("; ");
556            return Ok(Some(cookies));
557        }
558
559        Ok(None)
560    }
561
562    pub async fn logout() -> anyhow::Result<()> {
563        // May just delete cookies file
564        let cookies_file = Dirs::config_dir()?.join(COOKIES_FILE);
565        if cookies_file.exists() {
566            fs::remove_file(cookies_file)?;
567        }
568        println!("Logged out.");
569        Ok(())
570    }
571
572    pub async fn list() -> anyhow::Result<Vec<(String, String)>> {
573        // Known model save paths
574        let model_paths = Config::get_known_save_dirs()?;
575
576        let mut models = vec![];
577        for model_path in model_paths {
578            for dir in fs::read_dir(model_path)? {
579                let dir = dir?;
580                // This level is the model vendor, and the next level is the model name
581                if dir.file_type()?.is_dir() {
582                    for entry in fs::read_dir(dir.path())? {
583                        let entry = entry?;
584                        if entry.file_type()?.is_dir() {
585                            models.push((
586                                // Model ID
587                                format!(
588                                    "{}/{}",
589                                    dir.file_name().display(),
590                                    entry.file_name().display()
591                                ),
592                                // Model path
593                                dir.path().display().to_string(),
594                            ));
595                        }
596                    }
597                }
598            }
599        }
600        Ok(models)
601    }
602}
603
604struct Dirs {}
605impl Dirs {
606    fn base_dir() -> anyhow::Result<PathBuf> {
607        let base_dir = home_dir()
608            .context("Failed to get home directory")?
609            .join(DIR);
610        if !base_dir.exists() {
611            fs::create_dir_all(&base_dir)?;
612        }
613        Ok(base_dir)
614    }
615
616    fn config_dir() -> anyhow::Result<PathBuf> {
617        let config_dir = Self::base_dir()?.join("config");
618        if !config_dir.exists() {
619            fs::create_dir_all(&config_dir)?;
620        }
621        Ok(config_dir)
622    }
623
624    #[allow(unused)]
625    pub fn model_dir() -> anyhow::Result<PathBuf> {
626        let model_dir = Self::base_dir()?.join("models");
627        if !model_dir.exists() {
628            fs::create_dir_all(&model_dir)?;
629        }
630        Ok(model_dir)
631    }
632}
633
634#[derive(Debug, Serialize, Deserialize)]
635struct Config {
636    known_save_dirs: Vec<PathBuf>,
637}
638
639impl Config {
640    const KNOWN_SAVE_DIRS: &'static str = "known_save_dirs";
641    fn append_save_dir(dir: &Path) -> anyhow::Result<()> {
642        let f = Dirs::config_dir()?.join(Self::KNOWN_SAVE_DIRS);
643
644        // Get existing known save dirs
645        let mut known_save_dirs = Self::get_known_save_dirs()?;
646
647        // Canonicalize the directory
648        let dir = dir.canonicalize()?;
649
650        if known_save_dirs.contains(&dir) {
651            return Ok(());
652        }
653
654        known_save_dirs.push(dir);
655        fs::write(
656            f,
657            known_save_dirs
658                .iter()
659                .filter(|p| p.exists())
660                .map(|p| p.display().to_string())
661                .filter(|s| !s.trim().is_empty())
662                .collect::<Vec<String>>()
663                .join("\n"),
664        )?;
665
666        Ok(())
667    }
668
669    fn get_known_save_dirs() -> anyhow::Result<Vec<PathBuf>> {
670        let config_dir = Dirs::config_dir()?;
671        if !config_dir.exists() {
672            fs::create_dir_all(&config_dir)?;
673            return Ok(vec![]);
674        }
675
676        let f = config_dir.join(Self::KNOWN_SAVE_DIRS);
677        if !f.exists() {
678            return Ok(vec![]);
679        }
680
681        let paths = fs::read_to_string(f)?
682            .lines()
683            .map(PathBuf::from)
684            // Filter out non-existent paths
685            // These paths will be cleaned up when append_save_dir is called
686            .filter(|p| p.exists())
687            .collect::<Vec<_>>();
688
689        Ok(paths)
690    }
691}