use std::collections::HashMap;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::PathBuf;
use std::sync::Arc;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use minimo::showln;
use reqwest::header;
use reqwest::{Client, Response};
use tokio::sync::Mutex;
use tokio::task;
/// Bytes in one kibibyte; `Download::memory` is multiplied by this to get a byte count.
const ONE_KB: u64 = 1024;
/// One download job: where to fetch from, where to write, and how to split the body.
pub struct Download {
/// URL to fetch; it is probed for its length, then requested range by range.
pub url: String,
/// Output file path; it is also printed in the header line before the bars.
pub filename: String,
/// Multiplied by `ONE_KB` to form `max_buffer_size` in `spawn_threads`
/// (`Default` uses 256). NOTE(review): that value only feeds the per-range
/// chunk/remainder bookkeeping; nothing in this file uses it to size an I/O
/// buffer, so it does not currently bound memory use.
pub memory: u64,
/// Number of byte ranges the body is split into; one Tokio task is spawned
/// per range (`Default` uses 1).
pub threads: u64,
/// HTTP client wrapper used for the length probe and the ranged requests.
pub network: Network,
/// Progress display; one bar is registered per range.
pub progress: Progress,
}
impl Default for Download {
fn default() -> Download {
Download {
url: "".to_string(),
filename: "".to_string(),
memory: 256,
threads: 1,
network: Network {
..Default::default()
},
progress: Progress {
..Default::default()
},
}
}
}
impl Download {
    /// Downloads `self.url` into `self.filename` using `self.threads`
    /// concurrent ranged requests.
    ///
    /// The content length is probed first; when it is unknown a message is
    /// printed and nothing is downloaded. After every task has completed, the
    /// number of tasks that failed (panicked or were cancelled) is reported on
    /// stderr, because in that case the output file is incomplete.
    ///
    /// # Panics
    /// Panics if the probe request cannot be sent or the output file cannot be
    /// created. Failures inside individual range tasks are reported, not re-raised.
    pub async fn get(self) {
        match self.network.get_content_length(&self.url).await {
            Some(content_length) => {
                let children = self.spawn_threads(content_length).await;
                let failed = futures::future::join_all(children)
                    .await
                    .into_iter()
                    .filter(Result::is_err)
                    .count();
                if failed > 0 {
                    eprintln!(
                        "{} download task(s) failed; the output file may be incomplete.",
                        failed
                    );
                }
            }
            None => println!("Content length is not present for this URL. Support for this type of hosted file will be added in the future."),
        }
    }

    /// Splits `content_length` bytes into one inclusive HTTP byte range per
    /// worker and registers a progress bar (sized in bytes) for each.
    ///
    /// Returns the updated `progress` and one tuple per worker:
    /// `(range_header, range_start, thread_number, buffer_chunks, chunk_remainder)`.
    /// - `range_header` is a `Range` value such as `bytes=0-499`.
    /// - `thread_number` is 1-based and is the progress-bar key.
    /// - `buffer_chunks` / `chunk_remainder` are the quotient and remainder of the
    ///   range's byte count by `max_buffer_size`.
    ///
    /// The worker count is clamped to `1..=content_length` so every range is
    /// non-empty. The last worker also takes the division remainder. A
    /// zero-length body yields no ranges.
    fn calculate_ranges(
        threads: u64,
        content_length: u64,
        max_buffer_size: u64,
        mut progress: Progress,
    ) -> (Progress, Vec<(String, u64, u64, u64, u64)>) {
        if content_length == 0 {
            // No valid byte range exists for an empty body.
            return (progress, Vec::new());
        }
        // `threads == 0` used to divide by zero, and `content_length < threads`
        // made `content_length / threads - 1` underflow.
        let threads = threads.clamp(1, content_length);
        // Guard the `/` and `%` below against `memory == 0`.
        let max_buffer_size = max_buffer_size.max(1);
        let chunk_size = content_length / threads;

        let mut ranges = Vec::new();
        for thread in 0..threads {
            let range_start = thread * chunk_size;
            // HTTP byte ranges are inclusive on both ends, so the final byte is
            // `content_length - 1` (the old code asked for one byte past the end).
            let range_end = if thread == threads - 1 {
                content_length - 1
            } else {
                range_start + chunk_size - 1
            };
            let thread_number = thread + 1;
            // Inclusive range => +1; the old code undercounted non-last ranges.
            let range_to_process = range_end - range_start + 1;
            ranges.push((
                format!("bytes={}-{}", range_start, range_end),
                range_start,
                thread_number,
                range_to_process / max_buffer_size,
                range_to_process % max_buffer_size,
            ));
            progress.add(range_to_process, &thread_number);
        }
        (progress, ranges)
    }

    /// Creates the output file and spawns one Tokio task per byte range.
    ///
    /// Each task opens the file without truncating it, seeks to its range
    /// start, streams the ranged response into it, and advances its own
    /// progress bar. The returned handles are not awaited here.
    ///
    /// # Panics
    /// Panics if the output file cannot be created or resized. Request, stream
    /// and write errors panic inside the individual tasks.
    async fn spawn_threads(self, content_length: u64) -> Vec<task::JoinHandle<()>> {
        let max_buffer_size = ONE_KB.saturating_mul(self.memory);
        let (progress, ranges) =
            Self::calculate_ranges(self.threads, content_length, max_buffer_size, self.progress);
        let progress_arc = Arc::new(Mutex::new(progress));
        let network_arc = Arc::new(self.network);
        let filename_arc = Arc::new(self.filename);
        showln!(yellow_bold, "╭─ 📦 ",cyan_bold,&filename_arc);

        // Create (and truncate) the file exactly once, before any task runs.
        // Previously every task called `File::create`, so a task that started
        // late could truncate bytes other tasks had already written.
        let path = Arc::new(PathBuf::from(filename_arc.as_str()));
        File::create(path.as_path())
            .and_then(|file| file.set_len(content_length))
            .expect("Could not create the output file.");

        let mut children = vec![];
        for (range, range_start, thread_number, _buffer_chunks, _chunk_remainder) in ranges {
            let network_ref = network_arc.clone();
            let progress_ref = progress_arc.clone();
            let path_ref = path.clone();
            let url_ref = self.url.clone();
            children.push(task::spawn(async move {
                // Open for writing without truncation; this task only touches its range.
                let mut file_handle = File::options()
                    .write(true)
                    .open(path_ref.as_path())
                    .unwrap();
                file_handle.seek(SeekFrom::Start(range_start)).unwrap();
                let mut file_range_resp = network_ref.make_request(&url_ref, range).await;
                // NOTE(review): this assumes the server honoured `Range` (206).
                // A server that ignores it and sends the whole body would make
                // every task write the full file at its own offset; consider
                // checking `file_range_resp.status()`.
                // NOTE(review): these are blocking std::fs writes inside an async task.
                while let Some(chunk) = file_range_resp.chunk().await.unwrap() {
                    file_handle.write_all(&chunk).unwrap();
                    progress_ref.lock().await.inc(chunk.len() as u64, &thread_number);
                }
                // The previous `set_position(chunk_remainder)` here moved the
                // completed bar backwards to an unrelated value; `finish` is enough.
                progress_ref.lock().await.finish(&thread_number);
            }));
        }
        // NOTE(review): this runs right after spawning, before any task has
        // finished; nothing here waits for the bars. If the intent was to clear
        // them after completion, this belongs after `join_all` in `get`.
        progress_arc.lock().await.join_and_clear();
        children
    }
}
/// HTTP access for a download; wraps a single reqwest `Client`.
pub struct Network {
/// Client used for every request (`Default` builds it with `Client::new()`).
pub client: Client,
}
impl Default for Network {
    /// Wraps a freshly constructed reqwest client.
    fn default() -> Self {
        let client = Client::new();
        Self { client }
    }
}
impl Network {
    /// Sends a GET for `url` with `range` as the `Range` header value
    /// (e.g. `bytes=0-499`) and returns the response.
    ///
    /// # Panics
    /// Panics with "Could not send request." if the request fails.
    pub async fn make_request(&self, url: &str, range: String) -> Response {
        self.client
            .get(url)
            .header(header::RANGE, range)
            .send()
            .await
            .expect("Could not send request.")
    }

    /// Probes `url` with a plain GET and returns `Response::content_length`,
    /// or `None` when reqwest cannot determine it.
    ///
    /// The response status is not inspected.
    ///
    /// # Panics
    /// Panics with "Could not send request." if the request fails.
    pub async fn get_content_length(&self, url: &str) -> Option<u64> {
        // No `Range` header here. The previous code sent an empty `Range:` value,
        // which is not a valid range specifier. A server that rejects it (e.g.
        // 416/400) reports the length of its error body instead of the file's.
        self.client
            .get(url)
            .send()
            .await
            .expect("Could not send request.")
            .content_length()
    }
}
/// Terminal progress display with one bar per download range.
pub struct Progress {
/// Container every bar is added to in `Progress::add`.
pub multi_progress: MultiProgress,
/// Bars keyed by the 1-based thread number assigned in `calculate_ranges`.
pub progress_bars: HashMap<u64, ProgressBar>,
}
impl Default for Progress {
    /// Starts with an empty `MultiProgress` and no registered bars.
    fn default() -> Self {
        let multi_progress = MultiProgress::new();
        let progress_bars = HashMap::default();
        Self {
            multi_progress,
            progress_bars,
        }
    }
}
impl Progress {
    /// Creates a bar of length `range` (a byte count), adds it to
    /// `multi_progress`, styles it, and stores it under `thread_number`.
    ///
    /// The template ends with `{msg}`. indicatif only draws the message passed
    /// to `finish_with_message` when that placeholder is present, so without it
    /// the "🚀 Done!" text set by `Progress::finish` was never shown.
    ///
    /// # Panics
    /// Only if the constant template fails to parse, which would be a bug here.
    pub fn add(&mut self, range: u64, thread_number: &u64) {
        let pb = self.multi_progress.add(ProgressBar::new(range));
        let style = ProgressStyle::default_bar()
            .template("|{percent:3}% {spinner:.white}|{bar:35.white/white}| {bytes:2} | {eta:2} {msg}")
            .expect("progress bar template is a valid constant")
            .progress_chars("█▓▒ ")
            .tick_chars("⣿⠿⠟⢿⡿⣻⣽⣾⣷⣯⣟⡿⢿⠿⠻⠟⠋⠙⠹⡁⢁⠄⠂⠂");
        pb.set_style(style);
        self.progress_bars.insert(*thread_number, pb);
    }

    /// Advances the bar for `thread_number` by `amount`; no-op if there is no such bar.
    pub fn inc(&self, amount: u64, thread_number: &u64) {
        if let Some(pb) = self.progress_bars.get(thread_number) {
            pb.inc(amount);
        }
    }

    /// Sets the absolute position of the bar for `thread_number`; no-op if there is no such bar.
    pub fn set_position(&self, amount: u64, thread_number: &u64) {
        if let Some(pb) = self.progress_bars.get(thread_number) {
            pb.set_position(amount);
        }
    }

    /// Finishes the bar for `thread_number` with the message "🚀 Done!";
    /// no-op if there is no such bar.
    pub fn finish(&self, thread_number: &u64) {
        if let Some(pb) = self.progress_bars.get(thread_number) {
            pb.finish_with_message("🚀 Done!");
        }
    }

    /// Clears every bar managed by `multi_progress` from the terminal.
    ///
    /// NOTE(review): despite the name, nothing here waits for the bars to
    /// finish; `MultiProgress::clear` only clears the display.
    ///
    /// # Panics
    /// Panics if `MultiProgress::clear` returns an I/O error.
    pub fn join_and_clear(&self) {
        self.multi_progress.clear().unwrap();
    }
}