Skip to main content

novel_tts/
download.rs

1//! 文件下载管理模块
2//!
3//! 该模块提供了文件下载功能,支持断点续传和进度回调。
4
5use super::Result;
6use crate::NovelTTSError;
7use std::{
8    io::SeekFrom,
9    path::{Path, PathBuf},
10};
11use tokio::fs;
12use tokio::{
13    io::{AsyncSeekExt, AsyncWriteExt},
14    select,
15};
16use tokio_util::sync::CancellationToken;
17
18/// 缓存目录名称
19pub static CACHE_DIR: &str = ".novel-tts";
20
21/// 获取缓存目录路径
22///
23/// # 返回值
24/// 返回Result包装的PathBuf,包含缓存目录的路径
25pub fn get_cache_dir() -> Result<PathBuf> {
26    Ok(dirs::home_dir()
27        .map(|home| home.join(CACHE_DIR))
28        .ok_or_else(|| anyhow::anyhow!("No home directory found"))?)
29}
30
31/// 从URL下载文件
32///
33/// # 参数
34/// * `url` - 下载地址
35/// * `dest` - 目标文件路径
36/// * `on_progress` - 进度回调函数
37///
38/// # 返回值
39/// 返回Result,下载成功返回Ok,失败返回Err
40pub async fn download_from_url<F>(url: &str, dest: &PathBuf, mut on_progress: F) -> Result<()>
41where
42    F: FnMut(u64, u64),
43{
44    if let Some(parent) = dest.parent()
45        && !parent.exists()
46    {
47        fs::create_dir_all(parent).await?;
48    }
49
50    let path = format!("{}.download", dest.display());
51
52    let (mut downloaded, mut file) = if let Ok(metadata) = std::fs::metadata(&path) {
53        let mut file = fs::File::options().append(true).open(&path).await?;
54        file.seek(SeekFrom::Start(metadata.len())).await?;
55        (metadata.len(), file)
56    } else {
57        (0, fs::File::create(&path).await?)
58    };
59
60    let client = reqwest::Client::new();
61
62    let mut client = client.get(url);
63    if downloaded > 0 {
64        client = client.header(reqwest::header::RANGE, format!("bytes={}-", downloaded));
65    }
66
67    let mut res = client.send().await?.error_for_status()?;
68
69    let content_length = res.content_length().unwrap_or(0) + downloaded;
70
71    on_progress(downloaded, content_length);
72
73    while let Some(data) = res.chunk().await? {
74        file.write_all(&data).await?;
75        downloaded += data.len() as u64;
76        on_progress(downloaded, content_length);
77    }
78
79    if downloaded != content_length {
80        return Err(anyhow::anyhow!("Download failed").into());
81    }
82
83    fs::rename(path, dest).await?;
84    Ok(())
85}
86
87/// 下载信息结构体
88///
89/// 包含下载任务的相关信息,如文件路径、URL和取消令牌
90#[derive(Debug, Clone)]
91pub struct Download {
92    /// 文件路径
93    pub path: PathBuf,
94    /// 下载URL
95    pub url: String,
96    /// 取消令牌
97    pub token: CancellationToken,
98}
99
100impl Download {
101    /// 创建新的下载任务
102    ///
103    /// # 参数
104    /// * `path` - 文件保存路径
105    /// * `url` - 下载地址
106    ///
107    /// # 返回值
108    /// 返回新的Download实例
109    pub fn new<P: AsRef<Path>>(path: P, url: &str) -> Self {
110        Self {
111            path: path.as_ref().to_path_buf(),
112            token: CancellationToken::new(),
113            url: url.to_string(),
114        }
115    }
116
117    /// 检查文件是否已下载
118    ///
119    /// # 返回值
120    /// 如果文件已存在返回true,否则返回false
121    pub fn is_downloaded(&self) -> bool {
122        self.path.exists()
123    }
124
125    /// 取消下载任务
126    pub fn cancel_download(&self) {
127        self.token.cancel();
128    }
129
130    /// 启动下载任务(同步方式)
131    ///
132    /// # 参数
133    /// * `on_progress` - 进度回调函数
134    /// * `on_error` - 错误回调函数
135    ///
136    /// # 返回值
137    /// 返回Result,启动成功返回Ok,失败返回Err
138    pub fn download<F, E>(&mut self, on_progress: F, mut on_error: E)
139    where
140        F: FnMut(u64, u64) + Send + 'static,
141        E: FnMut(NovelTTSError) + Send + 'static,
142    {
143        let path = self.path.clone();
144        let cancel_token = CancellationToken::new();
145        self.token = cancel_token.clone();
146        let url = self.url.clone();
147
148        tokio::spawn(async move {
149            select! {
150                _ = cancel_token.cancelled() => {
151                    on_error(NovelTTSError::Cancel("download".into()));
152                }
153                res = download_from_url(&url, &path, on_progress) =>{
154                    if let Err(e) = res {
155                        on_error(e);
156                    }
157                }
158            }
159        });
160    }
161
162    /// 启动下载任务(异步方式)
163    ///
164    /// # 参数
165    /// * `on_progress` - 进度回调函数
166    ///
167    /// # 返回值
168    /// 返回Result,下载成功返回Ok,失败返回Err
169    pub async fn async_download<F>(&mut self, on_progress: F) -> Result<()>
170    where
171        F: FnMut(u64, u64) + Send + 'static,
172    {
173        let cancel_token = CancellationToken::new();
174        self.token = cancel_token.clone();
175        select! {
176            _ = self.token.cancelled() => {
177                Ok(())
178            }
179            res = download_from_url(&self.url, &self.path, on_progress) =>{
180                res
181            }
182        }
183    }
184}