1use std::{cmp::min, error::Error, sync::Arc, time::Duration};
2
3use reqwest::{
4 self, Url,
5 header::{CONTENT_DISPOSITION, CONTENT_RANGE, HeaderMap, RANGE},
6};
7use futures_util::{StreamExt, future};
9use indicatif::{HumanBytes, MultiProgress, ProgressBar, ProgressStyle};
10use tokio::fs::File;
11use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
12use tokio::sync::Mutex;
13
/// HTTP-style inclusive byte range `start_byte..=end_byte` of the target file, together
/// with how many bytes of that range have been received so far.
#[derive(Debug)]
struct Chunk {
    /// First byte offset requested for this range.
    start_byte: u64,
    /// Last byte offset requested for this range (inclusive).
    end_byte: u64,
    /// Number of bytes of this range received so far.
    downloaded: u64,
}

impl Chunk {
    /// Builds a chunk covering `start_byte..=end_byte` with nothing received yet.
    fn new(start_byte: u64, end_byte: u64) -> Self {
        Chunk {
            downloaded: 0,
            start_byte,
            end_byte,
        }
    }
}
30
31#[derive(Debug)]
32pub struct Downloader {
33 url: String,
34 headers: HeaderMap,
35 file_size: Option<u64>,
36 filename: Option<String>,
37 chunks: Arc<Mutex<Vec<Chunk>>>, }
39
/// Helpers for reading download metadata out of a set of response headers.
pub trait HeaderUtils {
    /// Returns the file name advertised by `self`.
    ///
    /// # Errors
    /// Returns an error when no usable file name can be found.
    fn extract_filename(&self) -> Result<String, Box<dyn Error + Send + Sync>>;

    /// Returns the total size, in bytes, of the resource described by `self`.
    ///
    /// # Errors
    /// Returns an error when the size is missing or cannot be parsed.
    fn extract_file_size(&self) -> Result<u64, Box<dyn Error + Send + Sync>>;
}
62
63impl HeaderUtils for HeaderMap {
66 fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
67 if let Some(disposition) = &self.get(CONTENT_DISPOSITION) {
68 let value = disposition.to_str()?;
69 if let Some(filename) = value.split("filename=").nth(1) {
70 return Ok(filename.trim_matches('"').to_string());
71 }
72 }
73 return Err(Box::from("Unable to extract filename".to_owned()));
74 }
76
77 fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
78 let &cr = &self
79 .get(CONTENT_RANGE)
80 .ok_or_else(|| Box::<dyn Error + Send + Sync>::from("Content_range not found"))?;
81 let content_range =
82 cr.to_str()?.split("/").into_iter().last().ok_or_else(|| {
83 Box::<dyn Error + Send + Sync>::from("Invalid Content_range_format")
84 })?;
85 Ok(content_range.parse()?)
86 }
87}
88
89pub fn extract_filename_from_url(url: &str) -> Option<String> {
94 if let Ok(parsed_url) = Url::parse(&url) {
95 if let Some(segment) = parsed_url.path_segments().and_then(|s| s.last()) {
96 if !segment.is_empty() {
97 return Some(segment.to_string());
98 }
99 }
100 }
101 return None;
102}
103
104impl Downloader {
105 pub fn new<S: Into<String>>(url: S) -> Self {
106 Self {
107 url: url.into(),
108 headers: HeaderMap::new(),
109 file_size: None,
110 filename: None,
111 chunks: Arc::new(Mutex::new(Vec::new())),
112 }
113 }
114
115 async fn get_chunk(
116 &self,
117 range: Option<(u64, u64)>,
118 progress_bar: Option<ProgressBar>,
119 file: Option<Arc<Mutex<File>>>,
120 chunk_index: Option<usize>,
121 ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
122 let client = reqwest::Client::new();
123 let mut builder = client.get(&self.url);
124 if let Some((start, end)) = range {
125 builder = builder.header(RANGE, &format!("bytes={start}-{end}"));
126 }
127 let response = builder.send().await?;
128 let mut stream = response.bytes_stream();
129 let mut downloaded = 0u64;
130 let mut chunk_data = Vec::new();
131
132 while let Some(chunk) = stream.next().await {
134 let chunk = chunk?;
135 chunk_data.extend_from_slice(&chunk);
136 downloaded += chunk.len() as u64;
137
138 if let Some(bar) = &progress_bar {
139 bar.inc(chunk.len() as u64);
140 }
141
142 if let Some(idx) = chunk_index {
144 let mut chunks = self.chunks.lock().await;
145 if idx < chunks.len() {
146 chunks[idx].downloaded = downloaded;
147 }
148 }
149 }
150
151 if let (Some(file), Some((start, _))) = (file, range) {
153 let mut f = file.lock().await;
154 f.seek(SeekFrom::Start(start)).await?;
155 f.write_all(&chunk_data).await?;
156 }
157
158 if let Some(bar) = progress_bar {
159 bar.finish();
160 }
161
162 Ok(downloaded)
163 }
164
165 pub async fn download(
173 &mut self,
174 path: &str,
175 threads: Option<u8>,
176 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
177 let client = reqwest::Client::new();
178 let threads: u64 = match threads {
179 Some(0) => 8,
180 Some(count) => count as u64,
181 None => 8,
182 };
183 let response = client
185 .get(&self.url)
186 .header(RANGE, "bytes=0-0")
187 .send()
188 .await?;
189 self.headers = response.headers().clone().to_owned();
190
191 let filename = match &self.headers.extract_filename() {
192 Ok(filename) => filename.to_owned(),
193 Err(_) => match extract_filename_from_url(&self.url) {
194 Some(filename) => filename,
195 None => "download.bin".to_owned(),
196 },
197 };
198 println!("⛔filename: {filename}");
199 self.filename = Some(format!("{path}/{filename}").replace("//", "/"));
201
202 if let Ok(file_size) = self.headers.extract_file_size() {
203 self.file_size = Some(file_size);
204 println!("⛔file size: {}", HumanBytes(file_size));
205 } else {
206 println!("⛔ Unable to determine the file size. skipping threads")
207 }
208
209 let file = Arc::new(Mutex::new(
210 File::create(self.filename.as_ref().unwrap()).await?,
211 ));
212
213 if let Some(file_size) = self.file_size {
215 file.lock().await.set_len(file_size).await?;
218
219 let mut start = 0;
220 let byte_size = file_size / threads;
221
222 while start < file_size {
224 let end = min(start + byte_size, file_size);
225 self.chunks.lock().await.push(Chunk::new(start, end));
226 start = end + 1;
227 }
228
229 let num_chunks = self.chunks.lock().await.len();
230 println!("Created {} chunks for download", num_chunks);
231
232 let multi_progress = Arc::new(MultiProgress::new());
233
234 let mut tasks = Vec::new();
236 let chunks_clone = Arc::clone(&self.chunks);
237
238 for i in 0..num_chunks {
239 let chunks = Arc::clone(&chunks_clone);
240 let file_clone = Arc::clone(&file);
241 let url = self.url.clone();
242 let multi_progress_clone = Arc::clone(&multi_progress);
243
244 let task = tokio::spawn(async move {
245 let (start, end) = {
247 let chunks_guard = chunks.lock().await;
248 if i >= chunks_guard.len() {
249 return Err("Chunk index out of bounds".into());
250 }
251 (chunks_guard[i].start_byte, chunks_guard[i].end_byte)
252 };
253
254 let chunk_size = end - start + 1;
256 let progress_bar = multi_progress_clone.add(ProgressBar::new(chunk_size));
257 progress_bar.set_style(ProgressStyle::with_template(
258 &format!(
259 "[Chunk {}] {{wide_bar:40.cyan/blue}} {{binary_bytes}}/{{binary_total_bytes}} ({{percent}}%)",
260 i+1
262 )
263 ).unwrap());
264
265 let downloader = Downloader {
267 url,
268 headers: HeaderMap::new(),
269 file_size: None,
270 filename: None,
271 chunks: chunks,
272 };
273
274 downloader
276 .get_chunk(
277 Some((start, end)),
278 Some(progress_bar),
279 Some(file_clone),
280 Some(i),
281 )
282 .await
283 });
284
285 tasks.push(task);
286 }
287
288 println!("Starting concurrent downloads...");
290 let results = future::try_join_all(tasks)
291 .await
292 .map_err(|e| format!("Task join error: {}", e))?;
293
294 let total_downloaded: u64 = results
295 .into_iter()
296 .collect::<Result<Vec<_>, _>>()?
297 .into_iter()
298 .sum();
299
300 println!(
301 "Download completed! Total bytes: {}",
302 HumanBytes(total_downloaded)
303 );
304 } else {
305 let file_clone = Arc::clone(&file);
306 let bar = ProgressBar::new_spinner();
307 bar.enable_steady_tick(Duration::from_millis(100));
308 println!("");
309 bar.set_style(
310 ProgressStyle::with_template(&format!(
311 "{{spinner:.cyan}} {:?} ({{binary_bytes}} downloaded)",
312 self.filename.as_ref().unwrap()
313 ))
314 .unwrap()
315 .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
316 );
317 let _ = self
318 .get_chunk(None, Some(bar), Some(file_clone), None)
319 .await;
320 }
321
322 Ok(())
323 }
324}