1use std::{
2 cmp::min,
3 error::Error,
4 sync::{
5 Arc,
6 atomic::{AtomicUsize, Ordering},
7 },
8 time::Duration,
9};
10
11use reqwest::{
12 self, Url,
13 header::{CONTENT_DISPOSITION, CONTENT_RANGE, HeaderMap, RANGE},
14};
15use futures_util::{StreamExt, future};
17use indicatif::{HumanBytes, MultiProgress, ProgressBar, ProgressStyle};
18use tokio::fs::File;
19use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
20use tokio::sync::Mutex;
21
/// One mebibyte, in bytes. `Downloader::download` does not split files smaller than this.
const _1MB: u64 = 1 << 20;
/// Ten mebibytes, in bytes. `Downloader::download` caps the per-chunk step at this size.
const _10MB: u64 = 10 * _1MB;
24
/// A contiguous byte range of the remote file together with its download progress.
#[derive(Debug)]
struct Chunk {
    /// Offset of the first byte of the range.
    start_byte: u64,
    /// Offset of the last byte of the range (sent as the inclusive end of the `Range` header).
    end_byte: u64,
    /// Number of bytes received so far for this range.
    downloaded: u64,
}

impl Chunk {
    /// Builds a chunk covering `start_byte..=end_byte` with nothing downloaded yet.
    fn new(start_byte: u64, end_byte: u64) -> Self {
        let downloaded = 0;
        Chunk {
            start_byte,
            end_byte,
            downloaded,
        }
    }
}
41
42#[derive(Debug)]
43pub struct Downloader {
44 url: String,
45 headers: HeaderMap,
46 file_size: Option<u64>,
47 filename: Option<String>,
48 chunks: Arc<Mutex<Vec<Chunk>>>, }
50
/// Helpers for pulling download metadata out of a set of HTTP response headers.
pub trait HeaderUtils {
    /// Returns the file name advertised by the headers.
    ///
    /// # Errors
    /// Returns an error when no file name can be extracted.
    fn extract_filename(&self) -> Result<String, Box<dyn Error + Send + Sync>>;

    /// Returns the total size of the resource in bytes as advertised by the headers.
    ///
    /// # Errors
    /// Returns an error when the size is missing or cannot be parsed.
    fn extract_file_size(&self) -> Result<u64, Box<dyn Error + Send + Sync>>;
}
73
74impl HeaderUtils for HeaderMap {
77 fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
78 if let Some(disposition) = &self.get(CONTENT_DISPOSITION) {
79 let value = disposition.to_str()?;
80 if let Some(filename) = value.split("filename=").nth(1) {
81 return Ok(filename.trim_matches('"').to_string());
82 }
83 }
84 Err(Box::from("Unable to extract filename".to_owned()))
85 }
87
88 fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
93 let &cr = &self
94 .get(CONTENT_RANGE)
95 .ok_or_else(|| Box::<dyn Error + Send + Sync>::from("Content_range not found"))?;
96 let content_range =
97 cr.to_str()?.split("/").last().ok_or_else(|| {
98 Box::<dyn Error + Send + Sync>::from("Invalid Content_range_format")
99 })?;
100 Ok(content_range.parse()?)
101 }
102}
103
104pub fn extract_filename_from_url(url: &str) -> Option<String> {
109 if let Ok(parsed_url) = Url::parse(url)
110 && let Some(segment) = parsed_url.path_segments().and_then(|mut s| s.next_back())
111 && !segment.is_empty()
112 {
113 return Some(segment.to_string());
114 }
115
116 None
117}
118
119impl Downloader {
120 pub fn new<S: Into<String>>(url: S) -> Self {
121 Self {
122 url: url.into(),
123 headers: HeaderMap::new(),
124 file_size: None,
125 filename: None,
126 chunks: Arc::new(Mutex::new(Vec::new())),
127 }
128 }
129
130 async fn get_chunk(
131 &self,
132 range: Option<(u64, u64)>,
133 progress_bar: Option<ProgressBar>,
134 file: Option<Arc<Mutex<File>>>,
135 chunk_index: Option<usize>,
136 ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
137 let client = reqwest::Client::new();
138 let mut builder = client.get(&self.url);
139 if let Some((start, end)) = range {
140 builder = builder.header(RANGE, &format!("bytes={start}-{end}"));
141 }
142 let response = builder.send().await?;
143 let mut stream = response.bytes_stream();
144 let mut downloaded = 0u64;
145 let mut chunk_data = Vec::new();
146
147 while let Some(chunk) = stream.next().await {
149 let chunk = chunk?;
150 chunk_data.extend_from_slice(&chunk);
151 downloaded += chunk.len() as u64;
152
153 if let Some(bar) = &progress_bar {
154 bar.inc(chunk.len() as u64);
155 }
156
157 if let Some(idx) = chunk_index {
159 let mut chunks = self.chunks.lock().await;
160 if idx < chunks.len() {
161 chunks[idx].downloaded = downloaded;
162 }
163 }
164 }
165
166 if let (Some(file), Some((start, _))) = (file, range) {
168 let mut f = file.lock().await;
169 f.seek(SeekFrom::Start(start)).await?;
170 f.write_all(&chunk_data).await?;
171 }
172
173 if let Some(bar) = progress_bar {
174 bar.finish();
175 }
176
177 Ok(downloaded)
178 }
179
180 pub async fn download(
196 &mut self,
197 path: &str,
198 filename: Option<String>,
199 threads: Option<u8>,
200 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
201 let client = reqwest::Client::new();
202 let threads: u64 = match threads {
203 Some(0) => 8,
204 Some(count) => count as u64,
205 None => 8,
206 };
207 let response = client
209 .get(&self.url)
210 .header(RANGE, "bytes=0-0")
211 .send()
212 .await?;
213 self.headers = response.headers().clone().to_owned();
214
215 let filename = if let Some(filename) = filename {
216 filename
217 } else {
218 match &self.headers.extract_filename() {
219 Ok(filename) => filename.to_owned(),
220 Err(_) => match extract_filename_from_url(&self.url) {
221 Some(filename) => filename,
222 None => "download.dat".to_owned(),
223 },
224 }
225 };
226 println!("Downloading \"{filename}\"");
227 self.filename = Some(format!("{path}/{filename}").replace("//", "/"));
228
229 if let Ok(file_size) = self.headers.extract_file_size() {
230 self.file_size = Some(file_size);
231 println!("file size: {}", HumanBytes(file_size));
232 } else {
233 println!("Unable to determine the file size. skipping threads");
234 }
235
236 let file = Arc::new(Mutex::new(
237 File::create(self.filename.as_ref().unwrap()).await?,
238 ));
239
240 if let Some(file_size) = self.file_size {
242 file.lock().await.set_len(file_size).await?;
245
246 let mut start = 0;
247 let thread_size = file_size / threads;
248 let mut byte_size = thread_size;
249
250 if file_size < _1MB {
252 println!("ℹ️ The file is smaller than 1 MB, so skipping threads.");
253 byte_size = file_size;
254 }
255
256 if thread_size > _10MB {
259 byte_size = _10MB
260 }
261
262 while start < file_size {
264 let end = min(start + byte_size, file_size);
265 self.chunks.lock().await.push(Chunk::new(start, end));
266 start = end + 1;
267 }
268
269 let num_chunks = self.chunks.lock().await.len();
270 println!("Created {} chunks for download", num_chunks);
271
272 let multi_progress = Arc::new(MultiProgress::new());
273
274 let mut tasks = Vec::new();
276 let chunks_clone = Arc::clone(&self.chunks);
277
278 let index = Arc::new(AtomicUsize::new(0));
282
283 let threads_to_spawn = std::cmp::min(threads as usize, num_chunks);
285
286 for _ in 0..threads_to_spawn {
287 let chunks = Arc::clone(&chunks_clone);
288 let file_clone = Arc::clone(&file);
289 let url = self.url.clone();
290 let multi_progress_clone = Arc::clone(&multi_progress);
291 let index_clone = Arc::clone(&index);
292
293 let task = tokio::spawn(async move {
294 let mut worker_total: u64 = 0;
295 loop {
296 let i = index_clone.fetch_add(1, Ordering::SeqCst);
298 if i >= num_chunks {
299 break;
300 }
301
302 let (start, end) = {
304 let chunks_guard = chunks.lock().await;
305 (chunks_guard[i].start_byte, chunks_guard[i].end_byte)
306 };
307
308 let chunk_size = end - start + 1;
310 let progress_bar = multi_progress_clone.add(ProgressBar::new(chunk_size));
311 progress_bar.set_style(ProgressStyle::with_template(
312 &format!(
313 "[Chunk {:03}] {{wide_bar:40.cyan/blue}} {{binary_bytes}}/{{binary_total_bytes}} ({{percent}}%)",
314 i + 1
316 )
317 ).unwrap());
318
319 let downloader = Downloader {
321 url: url.clone(),
322 headers: HeaderMap::new(),
323 file_size: None,
324 filename: None,
325 chunks: Arc::clone(&chunks),
326 };
327
328 let downloaded = downloader
330 .get_chunk(
331 Some((start, end)),
332 Some(progress_bar),
333 Some(Arc::clone(&file_clone)),
334 Some(i),
335 )
336 .await?;
337 worker_total += downloaded;
338 }
339
340 Ok::<u64, Box<dyn std::error::Error + Send + Sync>>(worker_total)
341 });
342
343 tasks.push(task);
344 }
345
346 println!("Starting concurrent downloads...");
348 let results = future::try_join_all(tasks)
349 .await
350 .map_err(|e| format!("Task join error: {}", e))?;
351
352 let total_downloaded: u64 = results
353 .into_iter()
354 .collect::<Result<Vec<_>, _>>()?
355 .into_iter()
356 .sum();
357
358 println!(
359 "Download completed! Total bytes: {}",
360 HumanBytes(total_downloaded)
361 );
362 } else {
363 let file_clone = Arc::clone(&file);
366 let bar = ProgressBar::new_spinner();
367 bar.enable_steady_tick(Duration::from_millis(100));
368 println!("\n");
369 bar.set_style(
370 ProgressStyle::with_template(&format!(
371 "{{spinner:.cyan}} {:?} ({{binary_bytes}} downloaded)",
372 self.filename.as_ref().unwrap()
373 ))
374 .unwrap()
375 .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
376 );
377 let _ = self
378 .get_chunk(None, Some(bar), Some(file_clone), None)
379 .await;
380 }
381
382 Ok(())
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;
    use tokio::runtime::Runtime;

    /// The last URL path segment is the file name; a trailing slash yields none.
    #[test]
    fn test_extract_filename_from_url() {
        let with_file = extract_filename_from_url("https://example.com/path/to/file.txt");
        assert_eq!(with_file, Some("file.txt".to_string()));

        let trailing_slash = extract_filename_from_url("https://example.com/path/to/");
        assert_eq!(trailing_slash, None);
    }

    /// A quoted `filename` parameter is returned without its quotes.
    #[test]
    fn test_header_extract_filename() {
        let mut map = HeaderMap::new();
        let disposition = "attachment; filename=\"myfile.bin\"".parse().unwrap();
        map.insert(reqwest::header::CONTENT_DISPOSITION, disposition);

        assert_eq!(map.extract_filename().unwrap(), "myfile.bin");
    }

    /// The total after the `/` in `Content-Range` is the file size.
    #[test]
    fn test_header_extract_file_size() {
        let mut map = HeaderMap::new();
        let content_range = "bytes 0-0/12345".parse().unwrap();
        map.insert(reqwest::header::CONTENT_RANGE, content_range);

        assert_eq!(map.extract_file_size().unwrap(), 12345u64);
    }

    /// A fresh downloader stores the URL and has no file name or size yet.
    #[test]
    fn test_downloader_new_and_defaults() {
        let fresh = Downloader::new("https://example.com/file");
        assert_eq!(fresh.url, "https://example.com/file");
        assert!(fresh.filename.is_none());
        assert!(fresh.file_size.is_none());
    }

    /// Fields can be set and read back inside a Tokio runtime; no network is touched.
    #[test]
    fn test_download_placeholder() {
        let runtime = Runtime::new().unwrap();
        runtime.block_on(async {
            let mut subject = Downloader::new("https://example.com/file");
            subject.headers = HeaderMap::new();
            subject.filename = Some("tmp_download.bin".to_string());
            subject.file_size = Some(0);

            assert_eq!(subject.filename.as_deref(), Some("tmp_download.bin"));
            assert_eq!(subject.file_size, Some(0));
        });
    }
}