use super::Result;
use crate::NovelTTSError;
use std::{
io::SeekFrom,
path::{Path, PathBuf},
};
use tokio::fs;
use tokio::{
io::{AsyncSeekExt, AsyncWriteExt},
select,
};
use tokio_util::sync::CancellationToken;
/// Name of the cache directory; [`get_cache_dir`] joins it onto the user's home directory.
pub static CACHE_DIR: &str = ".novel-tts";
/// Returns the cache directory path: [`CACHE_DIR`] inside the user's home directory.
///
/// Only the path is computed; the directory itself is not created here.
///
/// # Errors
///
/// Returns an error when no home directory can be determined.
pub fn get_cache_dir() -> Result<PathBuf> {
    let Some(home) = dirs::home_dir() else {
        return Err(anyhow::anyhow!("No home directory found").into());
    };
    Ok(home.join(CACHE_DIR))
}
/// Downloads `url` to `dest`, resuming a previous partial transfer when possible.
///
/// Data is streamed into a sibling `<dest>.download` file, which is renamed to
/// `dest` only after the transfer has completed. If that temporary file already
/// exists, a `Range` request is sent to continue from its current length; when
/// the server does not honour the range, the transfer restarts from scratch.
///
/// `on_progress(downloaded, total)` is called once before the first chunk and
/// again after every chunk. `total` is `0` when the server reports no length.
///
/// # Errors
///
/// Returns an error if a filesystem operation fails, if the request fails or
/// the server answers with an error status, or if the number of bytes received
/// does not match the length announced by the server.
pub async fn download_from_url<F>(url: &str, dest: &PathBuf, mut on_progress: F) -> Result<()>
where
    F: FnMut(u64, u64),
{
    if let Some(parent) = dest.parent() {
        // `create_dir_all` succeeds when the directory already exists, so no
        // blocking `exists()` probe is needed on the async executor.
        fs::create_dir_all(parent).await?;
    }

    // Build the temporary path from the raw OS string so that non-UTF-8
    // destination paths are preserved exactly.
    let mut partial = dest.as_os_str().to_owned();
    partial.push(".download");
    let partial = PathBuf::from(partial);

    // Length of an earlier partial download, or 0 when there is none.
    let resume_from = fs::metadata(&partial).await.map_or(0, |meta| meta.len());

    let client = reqwest::Client::new();
    let mut request = client.get(url);
    if resume_from > 0 {
        request = request.header(reqwest::header::RANGE, format!("bytes={resume_from}-"));
    }
    let mut res = request.send().await?;
    if resume_from > 0 && res.status() == reqwest::StatusCode::RANGE_NOT_SATISFIABLE {
        // The partial file is already as long as (or longer than) the resource,
        // so it cannot be continued; fetch the whole resource again.
        res = client.get(url).send().await?;
    }
    let mut res = res.error_for_status()?;

    // Only a 206 response carries the requested tail. Any other success status
    // delivers the full body, which must replace the partial file rather than
    // be written after it.
    let resumed = resume_from > 0 && res.status() == reqwest::StatusCode::PARTIAL_CONTENT;
    let (mut downloaded, mut file) = if resumed {
        let mut file = fs::File::options().write(true).open(&partial).await?;
        file.seek(SeekFrom::Start(resume_from)).await?;
        (resume_from, file)
    } else {
        (0, fs::File::create(&partial).await?)
    };

    // `None` means the server did not announce a length; the final size check
    // is skipped in that case instead of failing a successful transfer.
    let total = res.content_length().map(|remaining| remaining + downloaded);
    let reported_total = total.unwrap_or(0);

    on_progress(downloaded, reported_total);
    while let Some(chunk) = res.chunk().await? {
        file.write_all(&chunk).await?;
        downloaded += chunk.len() as u64;
        on_progress(downloaded, reported_total);
    }

    // Tokio files complete writes in the background; flushing waits for the
    // last write and surfaces its error before the file is moved into place.
    file.flush().await?;
    drop(file);

    if total.is_some_and(|expected| expected != downloaded) {
        return Err(anyhow::anyhow!("Download failed").into());
    }

    fs::rename(&partial, dest).await?;
    Ok(())
}
/// A file download described by its source URL and destination path, together
/// with the token used to cancel it.
#[derive(Debug, Clone)]
pub struct Download {
/// Destination path handed to `download_from_url`.
pub path: PathBuf,
/// URL the file is fetched from.
pub url: String,
/// Token that `cancel_download` triggers to cancel a download.
pub token: CancellationToken,
}
impl Download {
    /// Creates a download description for fetching `url` into `path`.
    ///
    /// Nothing is transferred until [`Download::download`] or
    /// [`Download::async_download`] is called.
    pub fn new<P: AsRef<Path>>(path: P, url: &str) -> Self {
        Self {
            path: path.as_ref().to_path_buf(),
            token: CancellationToken::new(),
            url: url.to_string(),
        }
    }

    /// Returns `true` when the destination path exists.
    pub fn is_downloaded(&self) -> bool {
        self.path.exists()
    }

    /// Cancels the download that is currently tied to this value's token.
    pub fn cancel_download(&self) {
        self.token.cancel();
    }

    /// Starts the download on a spawned Tokio task and returns immediately.
    ///
    /// A fresh cancellation token is installed on `self`, so a later
    /// [`Download::cancel_download`] call cancels the task started here.
    ///
    /// * `on_progress(downloaded, total)` is forwarded to `download_from_url`.
    /// * `on_error` receives `NovelTTSError::Cancel` when the task is cancelled,
    ///   or the error returned by the download when it fails.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside a Tokio runtime.
    pub fn download<F, E>(&mut self, on_progress: F, mut on_error: E)
    where
        F: FnMut(u64, u64) + Send + 'static,
        E: FnMut(NovelTTSError) + Send + 'static,
    {
        let path = self.path.clone();
        let cancel_token = CancellationToken::new();
        self.token = cancel_token.clone();
        let url = self.url.clone();
        tokio::spawn(async move {
            select! {
                _ = cancel_token.cancelled() => {
                    on_error(NovelTTSError::Cancel("download".into()));
                }
                res = download_from_url(&url, &path, on_progress) => {
                    if let Err(e) = res {
                        on_error(e);
                    }
                }
            }
        });
    }

    /// Runs the download to completion on the current task.
    ///
    /// To cancel it, clone [`Download::token`] (or the whole `Download`) before
    /// calling this method and cancel that clone while the future is running.
    ///
    /// # Errors
    ///
    /// Returns `NovelTTSError::Cancel` when the token is cancelled before the
    /// download finishes, or the error produced by `download_from_url`.
    pub async fn async_download<F>(&mut self, on_progress: F) -> Result<()>
    where
        F: FnMut(u64, u64) + Send + 'static,
    {
        // `self` stays exclusively borrowed while this future runs, so a token
        // created here could never be reached by a caller. Keep the current
        // token, which callers may have cloned, and replace it only when an
        // earlier cancellation has already used it up.
        if self.token.is_cancelled() {
            self.token = CancellationToken::new();
        }
        select! {
            // Report cancellation as an error, matching `download`, so that a
            // cancelled transfer is never mistaken for a completed one.
            _ = self.token.cancelled() => Err(NovelTTSError::Cancel("download".into())),
            res = download_from_url(&self.url, &self.path, on_progress) => res,
        }
    }
}