use std::sync::Condvar;
use std::sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
};
use crate::loader::LoaderEvent;
use crate::reader::AppendableDataWrapper;
/// Lifecycle state of a `Downloader`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DownloadStatus {
    /// `download()` has not been called yet (the state set by `Downloader::new`).
    NotStarted,
    /// The request is in flight or the response body is being streamed.
    Downloading,
    /// The whole body was received and the buffer was marked complete.
    Completed,
    /// The download failed or was cancelled.
    Aborted,
}
/// Shared, lockable buffer that downloaded bytes are appended into.
type SharedDataBuffer = Arc<Mutex<Box<dyn AppendableDataWrapper + Send + 'static>>>;
/// Optional user callback for loader events, shared with the download task.
type SharedLoaderCallback = Arc<Mutex<Option<Box<dyn Fn(LoaderEvent) + Send + 'static>>>>;
/// Handle of the spawned body-streaming Tokio task, once one exists.
type SharedTaskHandle = Arc<Mutex<Option<tokio::task::JoinHandle<Result<(), ()>>>>>;

/// Streams an HTTP response body into an appendable buffer.
///
/// Every field is reference-counted so the state can be shared with the spawned
/// download task.
pub struct Downloader {
    /// Destination buffer for downloaded bytes.
    data: SharedDataBuffer,
    /// Notified by the download task as data arrives and when the download completes.
    condvar: Arc<Condvar>,
    /// Current lifecycle state.
    status: Arc<Mutex<DownloadStatus>>,
    /// Value of the response's `Content-Length` header (0 when unknown).
    total_bytes: Arc<AtomicU64>,
    /// Running count of body bytes appended so far.
    downloaded_bytes: Arc<AtomicU64>,
    /// Guards against calling `download()` more than once.
    download_called: Arc<AtomicBool>,
    /// Cancellation request flag polled by the download task.
    should_abort: Arc<AtomicBool>,
    /// Becomes `true` once the whole body has been appended.
    download_completed: Arc<AtomicBool>,
    /// Handle of the spawned download task (despite the name, a Tokio task, not a thread).
    thread_handle: SharedTaskHandle,
    /// Optional event callback.
    callback: SharedLoaderCallback,
}
impl Downloader {
pub fn new<T: AppendableDataWrapper + Send + 'static>(data: T) -> Self {
Self {
data: Arc::new(Mutex::new(Box::new(data))),
condvar: Arc::new(Condvar::new()),
status: Arc::new(Mutex::new(DownloadStatus::NotStarted)),
total_bytes: Arc::new(AtomicU64::new(0)),
downloaded_bytes: Arc::new(AtomicU64::new(0)),
download_called: Arc::new(AtomicBool::new(false)),
should_abort: Arc::new(AtomicBool::new(false)),
download_completed: Arc::new(AtomicBool::new(false)),
thread_handle: Arc::new(Mutex::new(None)),
callback: Arc::new(Mutex::new(None)),
}
}
pub fn status(&self) -> DownloadStatus {
*self.status.lock().unwrap()
}
pub fn total_bytes(&self) -> u64 {
self.total_bytes.load(Ordering::Relaxed)
}
pub fn downloaded_bytes(&self) -> u64 {
self.downloaded_bytes.load(Ordering::Relaxed)
}
pub fn data(&self) -> Arc<Mutex<Box<dyn AppendableDataWrapper + Send + 'static>>> {
Arc::clone(&self.data)
}
pub fn condvar(&self) -> Arc<Condvar> {
Arc::clone(&self.condvar)
}
pub fn download_completed(&self) -> Arc<AtomicBool> {
Arc::clone(&self.download_completed)
}
pub fn set_callback<F>(&self, callback: F)
where
F: Fn(LoaderEvent) + Send + 'static,
{
let mut cb = self.callback.lock().unwrap();
*cb = Some(Box::new(callback));
}
pub async fn download(
&self,
url: &str,
headers: Option<Vec<(String, String)>>,
) -> Result<(), ()> {
if self.download_called.swap(true, Ordering::SeqCst) {
panic!("download() can only be called once");
}
{
let mut status = self.status.lock().unwrap();
*status = DownloadStatus::Downloading;
}
let data = Arc::clone(&self.data);
let condvar = Arc::clone(&self.condvar);
let status = Arc::clone(&self.status);
let total_bytes = Arc::clone(&self.total_bytes);
let downloaded_bytes = Arc::clone(&self.downloaded_bytes);
let should_abort = Arc::clone(&self.should_abort);
let download_completed = Arc::clone(&self.download_completed);
let callback = Arc::clone(&self.callback);
use futures_util::StreamExt;
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.unwrap();
let mut request_builder = client.get(url);
if let Some(hdrs) = headers {
for (key, value) in hdrs {
request_builder = request_builder.header(key, value);
}
}
let response = match request_builder.send().await {
Ok(resp) => resp,
Err(e) => {
eprintln!("Failed to send request: {}", e);
let mut s = status.lock().unwrap();
*s = DownloadStatus::Aborted;
if let Some(ref cb) = *callback.lock().unwrap() {
cb(LoaderEvent::Aborted);
}
return Err(());
}
};
let content_length = response
.headers()
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
total_bytes.store(content_length, Ordering::Relaxed);
data.lock().unwrap().set_capacity(content_length as usize);
let handle = tokio::task::spawn(async move {
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
if should_abort.load(Ordering::Relaxed) {
let mut s = status.lock().unwrap();
*s = DownloadStatus::Aborted;
if let Some(ref cb) = *callback.lock().unwrap() {
cb(LoaderEvent::Aborted);
}
return Err(());
}
match chunk_result {
Ok(chunk) => {
let mut data_lock = data.lock().unwrap();
data_lock.append_data(&chunk);
drop(data_lock);
condvar.notify_all();
downloaded_bytes.fetch_add(chunk.len() as u64, Ordering::Relaxed);
}
Err(e) => {
eprintln!("Error reading chunk: {}", e);
let mut s = status.lock().unwrap();
*s = DownloadStatus::Aborted;
if let Some(ref cb) = *callback.lock().unwrap() {
cb(LoaderEvent::Aborted);
}
return Err(());
}
}
}
data.lock().unwrap().complete();
let mut s = status.lock().unwrap();
*s = DownloadStatus::Completed;
download_completed.store(true, Ordering::Release);
condvar.notify_all();
if let Some(ref cb) = *callback.lock().unwrap() {
cb(LoaderEvent::Completed);
}
return Ok(());
});
let mut th = self.thread_handle.lock().unwrap();
*th = Some(handle);
Ok(())
}
pub fn abort(&self) -> Result<(), DownloadStatus> {
let mut status = self.status.lock().unwrap();
if *status != DownloadStatus::Downloading {
return Err(status.clone());
}
self.should_abort.store(true, Ordering::SeqCst);
let mut th = self.thread_handle.lock().unwrap();
if let Some(handle) = th.take() {
let _ = handle.abort();
}
*status = DownloadStatus::Aborted;
Ok(())
}
}
impl Drop for Downloader {
    /// Cancels any spawned download task and marks the status as aborted.
    ///
    /// Locks are taken poison-tolerantly. `drop` can run during unwinding, and a
    /// second panic there (from `unwrap()` on a poisoned mutex) would abort the process.
    fn drop(&mut self) {
        {
            let mut status = self
                .status
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            self.should_abort.store(true, Ordering::SeqCst);
            let mut th = self
                .thread_handle
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            if let Some(handle) = th.take() {
                handle.abort();
            }
            *status = DownloadStatus::Aborted;
        }
        // Readers may still hold the shared condvar Arc after the downloader is gone;
        // wake them so they are not left waiting for data that will never arrive.
        self.condvar.notify_all();
    }
}