1use anyhow::{Context, bail};
2use async_trait::async_trait;
3use futures_util::StreamExt;
4use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::env::home_dir;
8use std::fs;
9use std::io::{BufWriter, Seek, Write};
10use std::path::{Path, PathBuf};
11use std::sync::{Arc, Mutex};
12
13#[async_trait]
15pub trait ProgressCallback: Send + Sync {
16 async fn on_file_start(&self, file_name: &str, file_size: u64);
18
19 async fn on_file_progress(&self, file_name: &str, downloaded: u64, total: u64);
21
22 async fn on_file_complete(&self, file_name: &str);
24
25 async fn on_file_error(&self, file_name: &str, error: &str);
27}
28
29pub struct ProgressBarCallback {
31 bars: Arc<MultiProgress>,
32 progress_bars: Arc<Mutex<HashMap<String, ProgressBar>>>,
33}
34
35impl ProgressBarCallback {
36 pub fn new() -> Self {
37 Self {
38 bars: Arc::new(MultiProgress::new()),
39 progress_bars: Arc::new(Mutex::new(HashMap::new())),
40 }
41 }
42}
43
44impl Default for ProgressBarCallback {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl Clone for ProgressBarCallback {
51 fn clone(&self) -> Self {
52 Self {
53 bars: self.bars.clone(),
54 progress_bars: self.progress_bars.clone(),
55 }
56 }
57}
58
59#[async_trait]
60impl ProgressCallback for ProgressBarCallback {
61 async fn on_file_start(&self, file_name: &str, file_size: u64) {
62 {
64 let bars = self.progress_bars.lock().unwrap();
65 if bars.contains_key(file_name) {
66 return; }
68 }
69
70 let bar = ProgressBar::new(file_size);
71 let style = ProgressStyle::default_bar().template(BAR_STYLE).unwrap();
72 bar.set_style(style);
73 bar.set_message(file_name.to_string());
74 self.bars.add(bar.clone());
75
76 let mut bars = self.progress_bars.lock().unwrap();
77 bars.insert(file_name.to_string(), bar);
78 }
79
80 async fn on_file_progress(&self, file_name: &str, downloaded: u64, _total: u64) {
81 let bars = self.progress_bars.lock().unwrap();
82 if let Some(bar) = bars.get(file_name) {
83 bar.set_position(downloaded);
84 }
85 }
86
87 async fn on_file_complete(&self, file_name: &str) {
88 let mut bars = self.progress_bars.lock().unwrap();
89 if let Some(bar) = bars.remove(file_name) {
90 bar.finish();
91 }
92 }
93
94 async fn on_file_error(&self, file_name: &str, _error: &str) {
95 let mut bars = self.progress_bars.lock().unwrap();
96 if let Some(bar) = bars.remove(file_name) {
97 bar.abandon();
98 }
99 }
100}
101
/// Stateless [`ProgressCallback`] that reports download events as plain text lines
/// on stdout (failures go to stderr).
#[derive(Clone)]
pub struct SimpleCallback;
105
106#[async_trait]
107impl ProgressCallback for SimpleCallback {
108 async fn on_file_start(&self, file_name: &str, file_size: u64) {
109 println!("开始下载: {} (大小: {} bytes)", file_name, file_size);
110 }
111
112 async fn on_file_progress(&self, file_name: &str, downloaded: u64, total: u64) {
113 let percent = if total > 0 {
114 (downloaded as f64 / total as f64 * 100.0) as u32
115 } else {
116 0
117 };
118 println!("下载中: {} - {}% ({} / {} bytes)", file_name, percent, downloaded, total);
119 }
120
121 async fn on_file_complete(&self, file_name: &str) {
122 println!("下载完成: {}", file_name);
123 }
124
125 async fn on_file_error(&self, file_name: &str, error: &str) {
126 eprintln!("下载失败: {} - 错误: {}", file_name, error);
127 }
128}
129
/// File-listing endpoint; `<model_id>` is substituted before the request is sent.
const FILES_URL: &str = "https://modelscope.cn/api/v1/models/<model_id>/repo/files?Recursive=true";
/// Per-file download endpoint; `<model_id>` and `<path>` are substituted per file.
const DOWNLOAD_URL: &str = "https://modelscope.cn/models/<model_id>/resolve/master/<path>";
/// Endpoint the access token is posted to when logging in.
const LOGIN_URL: &str = "https://modelscope.cn/api/v1/login";
/// Name of the directory created inside the user's home directory.
const DIR: &str = ".modelscope";
/// Name of the file (inside the config directory) that stores the login cookies.
const COOKIES_FILE: &str = "cookies";

/// `(header name, header value)` pair attached to file download requests.
const UA: (&str, &str) = (
    "User-Agent",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.90 Safari/537.36",
);
/// Namespace type: all ModelScope operations (download, login, logout, list) are
/// associated functions on this unit struct.
pub struct ModelScope;
141
142#[derive(Debug, Deserialize)]
143struct ModelScopeResponse {
144 #[serde(rename = "Code")]
145 #[allow(unused)]
146 code: i64,
147 #[serde(rename = "Success")]
148 success: bool,
149 #[serde(rename = "Message")]
150 message: String,
151 #[serde(rename = "Data")]
152 data: Option<ModelScopeResponseData>,
153}
154
155#[derive(Debug, Deserialize)]
156struct ModelScopeResponseData {
157 #[serde(rename = "Files")]
158 files: Vec<RepoFile>,
159}
160#[derive(Debug, Deserialize)]
161struct RepoFile {
162 #[serde(rename = "Name")]
163 name: String,
164 #[serde(rename = "Path")]
165 path: String,
166 #[serde(rename = "Size")]
167 size: u64,
168 #[serde(rename = "Sha256")]
169 #[allow(unused)]
170 sha256: String,
171 #[serde(rename = "Type")]
172 r#type: String,
173}
174
/// `indicatif` template applied to every per-file progress bar.
const BAR_STYLE: &str = "{msg:<30} {bar} {decimal_bytes:<10} / {decimal_total_bytes:<10} {decimal_bytes_per_sec:<12} {percent:<3}% {eta_precise}";
176
177impl ModelScope {
178 async fn get_client() -> anyhow::Result<reqwest::Client> {
179 let client = reqwest::Client::builder().connect_timeout(std::time::Duration::from_secs(10));
180 let mut default_headers = reqwest::header::HeaderMap::new();
181 if let Some(cookies) = Self::get_cookies()? {
182 default_headers.insert("Cookie", cookies.parse()?);
183 }
184 let client = client.default_headers(default_headers);
185 Ok(client.build()?)
186 }
187
188 pub async fn download(model_id: &str, save_dir: impl Into<PathBuf>) -> anyhow::Result<()> {
189 Self::download_with_callback(model_id, save_dir, ProgressBarCallback::default()).await
190 }
191
192 pub async fn download_with_callback<C: ProgressCallback + Clone + 'static>(
193 model_id: &str,
194 save_dir: impl Into<PathBuf>,
195 callback: C,
196 ) -> anyhow::Result<()> {
197 let save_dir = save_dir.into();
199 fs::create_dir_all(&save_dir)?;
200
201 let model_dir = save_dir.join(model_id);
203
204 println!();
205 println!("Downloading model {} to: {}", model_id, model_dir.display());
206 println!();
207
208 fs::create_dir_all(&model_dir)?;
209
210 let files_url = FILES_URL.replace("<model_id>", model_id);
211
212 let client = Arc::new(Self::get_client().await?);
213
214 let resp = client.get(files_url).send().await?;
215
216 if !resp.status().is_success() {
217 bail!(
218 "Failed to get model files: {}\nTip: Maybe the model ID is incorrect or login is required",
219 resp.text().await?
220 );
221 }
222
223 let response = resp.json::<ModelScopeResponse>().await?;
224 if !response.success {
225 bail!("Failed to get model files: {}", response.message);
226 }
227
228 let data = response.data.unwrap();
229 let repo_files = data.files;
230
231 Config::append_save_dir(&save_dir)?;
234
235 let mut tasks = Vec::new();
236
237 for repo_file in repo_files.into_iter().filter(|f| f.r#type == "blob") {
238 let model_id = model_id.to_string();
239 let client = client.clone();
240 let save_dir = model_dir.clone();
241 let callback = callback.clone();
242
243 let task = tokio::spawn(async move {
244 let res = Self::download_file_with_callback(client, model_id, repo_file, save_dir, callback).await;
245 if let Err(e) = res {
246 bail!("Error downloading file: {}", e);
247 }
248 Ok::<(), anyhow::Error>(())
249 });
250
251 tasks.push(task);
252 }
253 for task in tasks {
254 task.await??;
255 }
256
257 Ok(())
258 }
259
260 async fn download_file(
261 client: Arc<reqwest::Client>,
262 model_id: String,
263 repo_file: RepoFile,
264 save_dir: PathBuf,
265 bar: ProgressBar,
266 ) -> anyhow::Result<()> {
267 let path = &repo_file.path;
268 let name = &repo_file.name;
269
270 bar.set_message(name.clone());
271
272 let file_path = save_dir.join(path);
273 if let Some(parent) = file_path.parent() {
274 fs::create_dir_all(parent)?;
275 }
276
277 let mut existing_size = 0;
278 let mut file_options = fs::OpenOptions::new();
279 file_options.write(true).create(true);
280
281 if file_path.exists() {
282 if let Ok(metadata) = fs::metadata(&file_path) {
283 existing_size = metadata.len();
284 file_options.append(true);
285 }
286 } else {
287 file_options.truncate(true);
288 }
289
290 let mut file = BufWriter::new(file_options.open(&file_path)?);
291
292 bar.set_position(existing_size);
294 bar.set_length(repo_file.size);
295
296 let url = DOWNLOAD_URL
297 .replace("<model_id>", &model_id)
298 .replace("<path>", path);
299
300 let mut rb = client.get(&url).header(UA.0, UA.1);
301
302 if existing_size == repo_file.size {
306 bar.finish();
307 return Ok(());
308 }
309
310 if existing_size < repo_file.size {
312 rb = rb.header("Range", format!("bytes={}-", existing_size));
313 }
314
315 let response = rb.send().await?;
316
317 let status = response.status();
318
319 if status == reqwest::StatusCode::OK && existing_size > 0 || existing_size > repo_file.size
322 {
323 file.rewind()?;
324 file.get_ref().set_len(0)?;
325 bar.set_position(0);
326 }
327
328 if !response.status().is_success()
330 && response.status() != reqwest::StatusCode::PARTIAL_CONTENT
331 {
332 bail!(
333 "Failed to download file {}: HTTP {}",
334 name,
335 response.status()
336 );
337 }
338
339 let mut stream = response.bytes_stream();
340
341 while let Some(item) = stream.next().await {
342 let chunk = item?;
343 file.write_all(&chunk)?;
344 bar.inc(chunk.len() as u64);
345 }
346
347 file.flush()?;
348
349 bar.finish();
350
351 Ok(())
352 }
353
354 async fn download_file_with_callback<C: ProgressCallback + Clone + 'static>(
355 client: Arc<reqwest::Client>,
356 model_id: String,
357 repo_file: RepoFile,
358 save_dir: PathBuf,
359 callback: C,
360 ) -> anyhow::Result<()> {
361 let path = &repo_file.path;
362 let name = &repo_file.name;
363
364 let file_path = save_dir.join(path);
365 if let Some(parent) = file_path.parent() {
366 fs::create_dir_all(parent)?;
367 }
368
369 let mut existing_size = 0;
370 let mut file_options = fs::OpenOptions::new();
371 file_options.write(true).create(true);
372
373 if file_path.exists() {
374 if let Ok(metadata) = fs::metadata(&file_path) {
375 existing_size = metadata.len();
376 file_options.append(true);
377 }
378 } else {
379 file_options.truncate(true);
380 }
381
382 let mut file = BufWriter::new(file_options.open(&file_path)?);
383
384 let url = DOWNLOAD_URL
385 .replace("<model_id>", &model_id)
386 .replace("<path>", path);
387
388 callback.on_file_start(name, repo_file.size).await;
390
391 let mut rb = client.get(&url).header(UA.0, UA.1);
392
393 if existing_size == repo_file.size {
395 callback.on_file_progress(name, repo_file.size, repo_file.size).await;
396 callback.on_file_complete(name).await;
397 return Ok(());
398 }
399
400 if existing_size < repo_file.size {
402 rb = rb.header("Range", format!("bytes={}-", existing_size));
403 }
404
405 let response = rb.send().await?;
406
407 let status = response.status();
408
409 if status == reqwest::StatusCode::OK && existing_size > 0 || existing_size > repo_file.size
412 {
413 file.rewind()?;
414 file.get_ref().set_len(0)?;
415 existing_size = 0;
416 callback.on_file_progress(name, 0, repo_file.size).await;
417 }
418
419 if !response.status().is_success()
421 && response.status() != reqwest::StatusCode::PARTIAL_CONTENT
422 {
423 let error_msg = format!("HTTP {}", response.status());
424 callback.on_file_error(name, &error_msg).await;
425 bail!(
426 "Failed to download file {}: HTTP {}",
427 name,
428 response.status()
429 );
430 }
431
432 let mut stream = response.bytes_stream();
433
434 while let Some(item) = stream.next().await {
435 let chunk = item?;
436 file.write_all(&chunk)?;
437 existing_size += chunk.len() as u64;
438 callback.on_file_progress(name, existing_size, repo_file.size).await;
439 }
440
441 file.flush()?;
442
443 callback.on_file_complete(name).await;
444
445 Ok(())
446 }
447
448 pub async fn login(token: &str) -> anyhow::Result<()> {
449 println!("Logging in...");
450 let client = Self::get_client().await?;
451 let resp = client
452 .post(LOGIN_URL)
453 .json(&serde_json::json!({
454 "AccessToken": token
455 }))
456 .send()
457 .await?;
458
459 let status = resp.status();
460
461 if !status.is_success() {
462 bail!("Failed to login: {}", resp.text().await?);
463 }
464
465 let cookies: serde_json::Value = resp
466 .cookies()
467 .map(|cookie| (cookie.name().to_string(), cookie.value().to_string()))
468 .collect();
469
470 let dir = Dirs::config_dir()?;
471
472 let cookies_file = dir.join(COOKIES_FILE);
473 fs::write(cookies_file, cookies.to_string())?;
474
475 println!("Login successful.");
476
477 Ok(())
478 }
479
480 pub async fn download_single_file(
481 model_id: &str,
482 file_path: &str,
483 save_dir: impl Into<PathBuf>,
484 ) -> anyhow::Result<()> {
485 Self::download_single_file_with_callback(model_id, file_path, save_dir, ProgressBarCallback::default()).await
486 }
487
488 pub async fn download_single_file_with_callback<C: ProgressCallback + Clone + 'static>(
489 model_id: &str,
490 file_path: &str,
491 save_dir: impl Into<PathBuf>,
492 callback: C,
493 ) -> anyhow::Result<()> {
494 let save_dir = save_dir.into();
495 fs::create_dir_all(&save_dir)?;
496
497 let model_dir = save_dir.join(model_id);
498 fs::create_dir_all(&model_dir)?;
499
500 println!();
501 println!(
502 "Downloading file {} from model {} to: {}",
503 file_path,
504 model_id,
505 model_dir.display()
506 );
507 println!();
508
509 let files_url = FILES_URL.replace("<model_id>", model_id);
510
511 let client = Arc::new(Self::get_client().await?);
512
513 let resp = client.get(files_url).send().await?;
515
516 if !resp.status().is_success() {
517 bail!(
518 "Failed to get model files: {}\nTip: Maybe the model ID is incorrect or login is required",
519 resp.text().await?
520 );
521 }
522
523 let response = resp.json::<ModelScopeResponse>().await?;
524 if !response.success {
525 bail!("Failed to get model files: {}", response.message);
526 }
527
528 let data = response.data.unwrap();
529 let repo_files = data.files;
530
531 let repo_file = repo_files
533 .into_iter()
534 .find(|f| f.path == file_path && f.r#type == "blob")
535 .ok_or_else(|| anyhow::anyhow!("File not found in model: {}", file_path))?;
536
537 Self::download_file_with_callback(client, model_id.to_string(), repo_file, model_dir, callback).await?;
538
539 Ok(())
540 }
541
542 fn get_cookies() -> anyhow::Result<Option<String>> {
543 let cookies_file = Dirs::config_dir()?.join(COOKIES_FILE);
544
545 if cookies_file.exists() {
546 let cookies = fs::read_to_string(cookies_file)?;
547 let cookies: serde_json::Value = serde_json::from_str(&cookies)?;
548
549 let cookies = cookies
550 .as_object()
551 .context("Failed to parse cookies")?
552 .iter()
553 .map(|(k, v)| format!("{}={}", k, v.as_str().unwrap_or_default()))
554 .collect::<Vec<_>>()
555 .join("; ");
556 return Ok(Some(cookies));
557 }
558
559 Ok(None)
560 }
561
562 pub async fn logout() -> anyhow::Result<()> {
563 let cookies_file = Dirs::config_dir()?.join(COOKIES_FILE);
565 if cookies_file.exists() {
566 fs::remove_file(cookies_file)?;
567 }
568 println!("Logged out.");
569 Ok(())
570 }
571
572 pub async fn list() -> anyhow::Result<Vec<(String, String)>> {
573 let model_paths = Config::get_known_save_dirs()?;
575
576 let mut models = vec![];
577 for model_path in model_paths {
578 for dir in fs::read_dir(model_path)? {
579 let dir = dir?;
580 if dir.file_type()?.is_dir() {
582 for entry in fs::read_dir(dir.path())? {
583 let entry = entry?;
584 if entry.file_type()?.is_dir() {
585 models.push((
586 format!(
588 "{}/{}",
589 dir.file_name().display(),
590 entry.file_name().display()
591 ),
592 dir.path().display().to_string(),
594 ));
595 }
596 }
597 }
598 }
599 }
600 Ok(models)
601 }
602}
603
604struct Dirs {}
605impl Dirs {
606 fn base_dir() -> anyhow::Result<PathBuf> {
607 let base_dir = home_dir()
608 .context("Failed to get home directory")?
609 .join(DIR);
610 if !base_dir.exists() {
611 fs::create_dir_all(&base_dir)?;
612 }
613 Ok(base_dir)
614 }
615
616 fn config_dir() -> anyhow::Result<PathBuf> {
617 let config_dir = Self::base_dir()?.join("config");
618 if !config_dir.exists() {
619 fs::create_dir_all(&config_dir)?;
620 }
621 Ok(config_dir)
622 }
623
624 #[allow(unused)]
625 pub fn model_dir() -> anyhow::Result<PathBuf> {
626 let model_dir = Self::base_dir()?.join("models");
627 if !model_dir.exists() {
628 fs::create_dir_all(&model_dir)?;
629 }
630 Ok(model_dir)
631 }
632}
633
634#[derive(Debug, Serialize, Deserialize)]
635struct Config {
636 known_save_dirs: Vec<PathBuf>,
637}
638
639impl Config {
640 const KNOWN_SAVE_DIRS: &'static str = "known_save_dirs";
641 fn append_save_dir(dir: &Path) -> anyhow::Result<()> {
642 let f = Dirs::config_dir()?.join(Self::KNOWN_SAVE_DIRS);
643
644 let mut known_save_dirs = Self::get_known_save_dirs()?;
646
647 let dir = dir.canonicalize()?;
649
650 if known_save_dirs.contains(&dir) {
651 return Ok(());
652 }
653
654 known_save_dirs.push(dir);
655 fs::write(
656 f,
657 known_save_dirs
658 .iter()
659 .filter(|p| p.exists())
660 .map(|p| p.display().to_string())
661 .filter(|s| !s.trim().is_empty())
662 .collect::<Vec<String>>()
663 .join("\n"),
664 )?;
665
666 Ok(())
667 }
668
669 fn get_known_save_dirs() -> anyhow::Result<Vec<PathBuf>> {
670 let config_dir = Dirs::config_dir()?;
671 if !config_dir.exists() {
672 fs::create_dir_all(&config_dir)?;
673 return Ok(vec![]);
674 }
675
676 let f = config_dir.join(Self::KNOWN_SAVE_DIRS);
677 if !f.exists() {
678 return Ok(vec![]);
679 }
680
681 let paths = fs::read_to_string(f)?
682 .lines()
683 .map(PathBuf::from)
684 .filter(|p| p.exists())
687 .collect::<Vec<_>>();
688
689 Ok(paths)
690 }
691}