use crate::error::{Result, LateJavaCoreError};
use std::path::Path;
use std::collections::HashMap;
use futures::StreamExt;
/// Describes one file for [`Downloader::download_file_multiple`] to fetch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadOptions {
    /// URL the file is requested from.
    pub url: String,
    /// Destination file path the response body is written to.
    pub path: String,
    /// Expected file size, presumably in bytes — not read by [`Downloader`] itself.
    pub length: Option<u64>,
    /// Directory created (including parents) before the download starts.
    pub folder: String,
    /// Optional category label copied into every [`DownloadProgress`] produced for this file.
    pub r#type: Option<String>,
}
/// HTTP downloader that attaches an optional, fixed set of extra headers to its requests.
///
/// `Downloader::default()` is equivalent to `Downloader::new()`: no custom headers.
/// NOTE: the derived `Debug` output includes the header values, so avoid logging it when the
/// headers carry credentials.
#[derive(Debug, Clone, Default)]
pub struct Downloader {
    /// Extra headers sent with every request; `None` sends only the HTTP client's defaults.
    custom_headers: Option<HashMap<String, String>>,
}
impl Downloader {
    /// Maximum time, in milliseconds, [`Downloader::check_mirror`] waits for each mirror.
    const MIRROR_CHECK_TIMEOUT_MS: u64 = 10_000;

    /// Number of recent speed samples averaged when estimating the remaining time.
    const SPEED_SAMPLES: usize = 5;

    /// Creates a downloader that sends no extra HTTP headers.
    pub fn new() -> Self {
        Self {
            custom_headers: None,
        }
    }

    /// Creates a downloader that adds `headers` to every request it sends.
    pub fn with_headers(headers: HashMap<String, String>) -> Self {
        Self {
            custom_headers: Some(headers),
        }
    }

    /// Adds every entry of `headers` (if any) to `request`.
    fn apply_headers(
        mut request: reqwest::RequestBuilder,
        headers: Option<&HashMap<String, String>>,
    ) -> reqwest::RequestBuilder {
        for (key, value) in headers.into_iter().flatten() {
            request = request.header(key, value);
        }
        request
    }

    /// Downloads `url` into `dir_path/file_name`, creating `dir_path` first if needed.
    ///
    /// # Errors
    /// Returns an error if the directory or file cannot be created or written, if the request
    /// fails, if the server answers with a 4xx/5xx status, or if the body stream breaks.
    pub async fn download_file(&self, url: &str, dir_path: &str, file_name: &str) -> Result<()> {
        use std::io::Write;

        // `create_dir_all` succeeds when the directory already exists, so no `exists()` probe.
        std::fs::create_dir_all(dir_path)?;
        let full_path = Path::new(dir_path).join(file_name);

        let request =
            Self::apply_headers(reqwest::Client::new().get(url), self.custom_headers.as_ref());
        // Reject error statuses so an error page is never saved as the requested file.
        let response = request.send().await?.error_for_status()?;

        let mut stream = response.bytes_stream();
        let mut file = std::fs::File::create(&full_path)?;
        while let Some(chunk) = stream.next().await {
            file.write_all(&chunk?)?;
        }
        Ok(())
    }

    /// Downloads every entry of `files`, running at most `limit` transfers at once.
    ///
    /// Equivalent to [`Downloader::download_file_multiple_with_progress`] with a callback that
    /// ignores progress reports.
    ///
    /// # Errors
    /// See [`Downloader::download_file_multiple_with_progress`].
    pub async fn download_file_multiple(
        &self,
        files: &[DownloadOptions],
        total_size: u64,
        limit: u32,
        timeout: u64,
    ) -> Result<()> {
        self.download_file_multiple_with_progress(files, total_size, limit, timeout, |_, _, _| {})
            .await
    }

    /// Downloads every entry of `files` concurrently and reports aggregate progress.
    ///
    /// * `total_size` – expected byte total of all files; used only for progress reporting.
    /// * `limit` – maximum number of simultaneous transfers (0 is treated as 1).
    /// * `timeout` – per-request timeout in milliseconds, passed to reqwest's
    ///   `RequestBuilder::timeout`.
    /// * `on_progress` – called on the calling task after each received chunk with a
    ///   [`DownloadProgress`] whose `downloaded` is the cumulative byte count across all files,
    ///   followed by the smoothed speed (bytes per second) and the estimated seconds remaining
    ///   (0.0 while the speed is unknown).
    ///
    /// Every file is attempted even if another one fails.
    ///
    /// # Errors
    /// After all transfers have finished, returns the error of the first failed file in `files`
    /// order (I/O, HTTP/network, 4xx/5xx status, or a panicked task).
    pub async fn download_file_multiple_with_progress<F>(
        &self,
        files: &[DownloadOptions],
        total_size: u64,
        limit: u32,
        timeout: u64,
        mut on_progress: F,
    ) -> Result<()>
    where
        F: FnMut(&DownloadProgress, f64, f64),
    {
        use std::collections::VecDeque;
        use std::sync::Arc;

        if files.is_empty() {
            return Ok(());
        }
        // At least one permit, otherwise no transfer could ever start.
        let concurrency = usize::try_from(limit)
            .unwrap_or(usize::MAX)
            .clamp(1, files.len());
        let permits = Arc::new(tokio::sync::Semaphore::new(concurrency));
        // One client for the whole batch so connections are pooled between files.
        let client = reqwest::Client::new();
        let headers = Arc::new(self.custom_headers.clone());
        let timeout = std::time::Duration::from_millis(timeout);
        let (tx, mut rx) = tokio::sync::mpsc::channel::<DownloadProgress>(100);

        // Every file gets a task up front; each acquires its permit inside the task. Acquiring
        // before spawning would stall this loop before progress is drained, and once the
        // bounded channel filled up the running transfers could never finish.
        let handles: Vec<_> = files
            .iter()
            .cloned()
            .map(|file| {
                let permits = Arc::clone(&permits);
                let client = client.clone();
                let headers = Arc::clone(&headers);
                let tx = tx.clone();
                tokio::spawn(async move {
                    // Held for the whole transfer; caps the number of concurrent downloads.
                    let _permit = permits.acquire_owned().await.map_err(|_| {
                        LateJavaCoreError::Download("Semaphore closed".to_string())
                    })?;
                    Self::fetch_to_file(
                        &client,
                        (*headers).as_ref(),
                        file,
                        timeout,
                        total_size,
                        &tx,
                    )
                    .await
                })
            })
            .collect();
        // Only the spawned tasks may keep the channel open: once they have all finished,
        // `recv` returns `None` and the progress loop below ends.
        drop(tx);

        let start = std::time::Instant::now();
        let mut downloaded = 0u64;
        let mut speeds: VecDeque<f64> = VecDeque::with_capacity(Self::SPEED_SAMPLES + 1);
        while let Some(chunk) = rx.recv().await {
            downloaded += chunk.downloaded;
            let elapsed = start.elapsed().as_secs_f64();
            if elapsed <= 0.0 {
                continue;
            }
            speeds.push_back(downloaded as f64 / elapsed);
            if speeds.len() > Self::SPEED_SAMPLES {
                speeds.pop_front();
            }
            let avg_speed = speeds.iter().sum::<f64>() / speeds.len() as f64;
            let time_remaining = if avg_speed > 0.0 {
                // saturating: servers may send more bytes than the caller's estimate.
                total_size.saturating_sub(downloaded) as f64 / avg_speed
            } else {
                0.0
            };
            let progress = DownloadProgress {
                downloaded,
                total: total_size,
                file_type: chunk.file_type,
            };
            on_progress(&progress, avg_speed, time_remaining);
        }

        for handle in handles {
            handle.await??;
        }
        Ok(())
    }

    /// Streams `file.url` into `file.path`, sending one [`DownloadProgress`] (chunk size in
    /// `downloaded`) per received chunk on `tx`.
    async fn fetch_to_file(
        client: &reqwest::Client,
        headers: Option<&HashMap<String, String>>,
        file: DownloadOptions,
        timeout: std::time::Duration,
        total_size: u64,
        tx: &tokio::sync::mpsc::Sender<DownloadProgress>,
    ) -> Result<()> {
        use std::io::Write;

        // NOTE(review): std::fs calls briefly block the executor thread; tokio::fs would avoid
        // that if tokio's `fs` feature is enabled — TODO confirm.
        std::fs::create_dir_all(&file.folder)?;
        let request = Self::apply_headers(client.get(&file.url), headers).timeout(timeout);
        let response = request.send().await?.error_for_status()?;

        let mut stream = response.bytes_stream();
        let mut out = std::fs::File::create(&file.path)?;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            out.write_all(&chunk)?;
            tx.send(DownloadProgress {
                downloaded: chunk.len() as u64,
                total: total_size,
                file_type: file.r#type.clone(),
            })
            .await
            .map_err(|_| LateJavaCoreError::Download("Channel closed".to_string()))?;
        }
        Ok(())
    }

    /// Sends a HEAD request to `url` with a `timeout` in milliseconds.
    ///
    /// Returns `Ok(Some(..))` with the Content-Length (0 when missing or malformed) and status
    /// code for a 2xx answer, and `Ok(None)` for any other status.
    ///
    /// # Errors
    /// Returns an error if the request itself fails (connection, timeout, invalid header, ...).
    pub async fn check_url(&self, url: &str, timeout: u64) -> Result<Option<UrlCheckResult>> {
        self.head_check(&reqwest::Client::new(), url, timeout).await
    }

    /// [`Downloader::check_url`] using a caller-supplied client.
    async fn head_check(
        &self,
        client: &reqwest::Client,
        url: &str,
        timeout: u64,
    ) -> Result<Option<UrlCheckResult>> {
        let request = Self::apply_headers(client.head(url), self.custom_headers.as_ref())
            .timeout(std::time::Duration::from_millis(timeout));
        let response = request.send().await?;
        let status = response.status();
        if !status.is_success() {
            return Ok(None);
        }
        let size = response
            .headers()
            .get("content-length")
            .and_then(|h| h.to_str().ok())
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(0);
        Ok(Some(UrlCheckResult {
            size,
            status: status.as_u16(),
        }))
    }

    /// Returns the first mirror (tried in order) that serves `base_url` with status 200.
    ///
    /// Each candidate URL is `"{mirror}/{base_url}"`. Mirrors that fail (network error,
    /// timeout, non-200 status) are skipped; `Ok(None)` means no mirror answered with 200.
    pub async fn check_mirror(
        &self,
        base_url: &str,
        mirrors: &[&str],
    ) -> Result<Option<MirrorCheckResult>> {
        let client = reqwest::Client::new();
        for mirror in mirrors {
            let test_url = format!("{}/{}", mirror, base_url);
            // An unreachable mirror must not abort the search; just try the next one.
            if let Ok(Some(result)) = self
                .head_check(&client, &test_url, Self::MIRROR_CHECK_TIMEOUT_MS)
                .await
            {
                if result.status == 200 {
                    return Ok(Some(MirrorCheckResult {
                        url: test_url,
                        size: result.size,
                        status: result.status,
                    }));
                }
            }
        }
        Ok(None)
    }
}
/// Progress report produced while a batch of files is downloaded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadProgress {
    /// Number of bytes this report accounts for (per-chunk or cumulative depending on the
    /// producer — see the method that emits it).
    pub downloaded: u64,
    /// Expected total byte count of the whole batch, as supplied by the caller.
    pub total: u64,
    /// `type` label of the file the bytes belong to, if it has one.
    pub file_type: Option<String>,
}
/// Outcome of a successful HEAD check of a URL.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UrlCheckResult {
    /// Value of the Content-Length header; 0 when it is missing or not a valid number.
    pub size: u64,
    /// HTTP status code returned by the server.
    pub status: u16,
}
/// A mirror that was found to serve the requested resource.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MirrorCheckResult {
    /// Full URL (mirror prefix plus resource path) that answered the check.
    pub url: String,
    /// Value of the Content-Length header; 0 when it is missing or not a valid number.
    pub size: u64,
    /// HTTP status code returned by the mirror.
    pub status: u16,
}