1use std::{
2 cmp::min,
3 error::Error,
4 sync::{
5 Arc,
6 atomic::{AtomicUsize, Ordering},
7 },
8 time::Duration,
9};
10
11use reqwest::{
12 self, Url,
13 header::{CONTENT_DISPOSITION, CONTENT_RANGE, HeaderMap, RANGE},
14};
15use futures_util::{StreamExt, future};
17use indicatif::{HumanBytes, MultiProgress, ProgressBar, ProgressStyle};
18use tokio::fs::File;
19use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
20use tokio::sync::Mutex;
21
/// One mebibyte (2^20 bytes); files smaller than this are fetched as a single chunk.
const _1MB: u64 = 1 << 20;
/// Upper bound on the size of a single ranged chunk (10 MiB).
const _10MB: u64 = 10 * _1MB;
24
/// One byte range of the remote file that a worker fetches with an HTTP `Range` request,
/// together with the number of bytes received for it so far.
#[derive(Debug)]
struct Chunk {
    /// Offset sent as the start of the `Range` header.
    start_byte: u64,
    /// Offset sent as the (inclusive) end of the `Range` header.
    end_byte: u64,
    /// Bytes of this range received so far.
    downloaded: u64,
}

impl Chunk {
    /// Builds a chunk covering `start_byte..=end_byte` with nothing downloaded yet.
    fn new(start_byte: u64, end_byte: u64) -> Self {
        let downloaded = 0;
        Chunk {
            start_byte,
            end_byte,
            downloaded,
        }
    }
}
41
42#[derive(Debug)]
43pub struct Downloader {
44 url: String,
45 headers: HeaderMap,
46 file_size: Option<u64>,
47 filename: Option<String>,
48 chunks: Arc<Mutex<Vec<Chunk>>>, }
50
/// Helpers for reading download metadata out of HTTP response headers.
pub trait HeaderUtils {
    /// Returns the file name announced by the `Content-Disposition` header.
    ///
    /// # Errors
    /// Returns an error when no usable file name can be extracted.
    fn extract_filename(&self) -> Result<String, Box<dyn Error + Send + Sync>>;

    /// Returns the total resource size announced by the `Content-Range` header.
    ///
    /// # Errors
    /// Returns an error when the header is absent or its length cannot be parsed.
    fn extract_file_size(&self) -> Result<u64, Box<dyn Error + Send + Sync>>;
}
73
74impl HeaderUtils for HeaderMap {
77 fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
78 if let Some(disposition) = &self.get(CONTENT_DISPOSITION) {
79 let value = disposition.to_str()?;
80 if let Some(filename) = value.split("filename=").nth(1) {
81 return Ok(filename.trim_matches('"').to_string());
82 }
83 }
84 Err(Box::from("Unable to extract filename".to_owned()))
85 }
87
88 fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
93 let &cr = &self
94 .get(CONTENT_RANGE)
95 .ok_or_else(|| Box::<dyn Error + Send + Sync>::from("Content_range not found"))?;
96 let content_range =
97 cr.to_str()?.split("/").last().ok_or_else(|| {
98 Box::<dyn Error + Send + Sync>::from("Invalid Content_range_format")
99 })?;
100 Ok(content_range.parse()?)
101 }
102}
103
104pub fn extract_filename_from_url(url: &str) -> Option<String> {
109 if let Ok(parsed_url) = Url::parse(url)
110 && let Some(segment) = parsed_url.path_segments().and_then(|mut s| s.next_back())
111 && !segment.is_empty()
112 {
113 return Some(segment.to_string());
114 }
115
116 None
117}
118
119impl Downloader {
120 pub fn new<S: Into<String>>(url: S) -> Self {
121 Self {
122 url: url.into(),
123 headers: HeaderMap::new(),
124 file_size: None,
125 filename: None,
126 chunks: Arc::new(Mutex::new(Vec::new())),
127 }
128 }
129
130 async fn get_chunk(
131 &self,
132 range: Option<(u64, u64)>,
133 progress_bar: Option<ProgressBar>,
134 file: Option<Arc<Mutex<File>>>,
135 chunk_index: Option<usize>,
136 ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
137 let client = reqwest::Client::new();
138 let mut builder = client.get(&self.url);
139 if let Some((start, end)) = range {
140 builder = builder.header(RANGE, &format!("bytes={start}-{end}"));
141 }
142 let response = builder.send().await?;
143 let mut stream = response.bytes_stream();
144 let mut downloaded = 0u64;
145 let mut chunk_data = Vec::new();
146
147 while let Some(chunk) = stream.next().await {
149 let chunk = chunk?;
150 chunk_data.extend_from_slice(&chunk);
151 downloaded += chunk.len() as u64;
152
153 if let Some(bar) = &progress_bar {
154 bar.inc(chunk.len() as u64);
155 }
156
157 if let Some(idx) = chunk_index {
159 let mut chunks = self.chunks.lock().await;
160 if idx < chunks.len() {
161 chunks[idx].downloaded = downloaded;
162 }
163 }
164 }
165
166 if let (Some(file), Some((start, _))) = (file, range) {
168 let mut f = file.lock().await;
169 f.seek(SeekFrom::Start(start)).await?;
170 f.write_all(&chunk_data).await?;
171 }
172
173 if let Some(bar) = progress_bar {
174 bar.finish();
175 }
176
177 Ok(downloaded)
178 }
179
180 pub async fn download(
196 &mut self,
197 path: &str,
198 threads: Option<u8>,
199 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
200 let client = reqwest::Client::new();
201 let threads: u64 = match threads {
202 Some(0) => 8,
203 Some(count) => count as u64,
204 None => 8,
205 };
206 let response = client
208 .get(&self.url)
209 .header(RANGE, "bytes=0-0")
210 .send()
211 .await?;
212 self.headers = response.headers().clone().to_owned();
213
214 let filename = match &self.headers.extract_filename() {
215 Ok(filename) => filename.to_owned(),
216 Err(_) => match extract_filename_from_url(&self.url) {
217 Some(filename) => filename,
218 None => "download.bin".to_owned(),
219 },
220 };
221 println!("Downloading \"{filename}\"");
222 self.filename = Some(format!("{path}/{filename}").replace("//", "/"));
223
224 if let Ok(file_size) = self.headers.extract_file_size() {
225 self.file_size = Some(file_size);
226 println!("file size: {}", HumanBytes(file_size));
227 } else {
228 println!("Unable to determine the file size. skipping threads");
229 }
230
231 let file = Arc::new(Mutex::new(
232 File::create(self.filename.as_ref().unwrap()).await?,
233 ));
234
235 if let Some(file_size) = self.file_size {
237 file.lock().await.set_len(file_size).await?;
240
241 let mut start = 0;
242 let thread_size = file_size / threads;
243 let mut byte_size = thread_size;
244
245 if file_size < _1MB {
247 println!("ℹ️ The file is smaller than 1 MB, so skipping threads.");
248 byte_size = file_size;
249 }
250
251 if thread_size > _10MB {
254 byte_size = _10MB
255 }
256
257 while start < file_size {
259 let end = min(start + byte_size, file_size);
260 self.chunks.lock().await.push(Chunk::new(start, end));
261 start = end + 1;
262 }
263
264 let num_chunks = self.chunks.lock().await.len();
265 println!("Created {} chunks for download", num_chunks);
266
267 let multi_progress = Arc::new(MultiProgress::new());
268
269 let mut tasks = Vec::new();
271 let chunks_clone = Arc::clone(&self.chunks);
272
273 let index = Arc::new(AtomicUsize::new(0));
277
278 let threads_to_spawn = std::cmp::min(threads as usize, num_chunks);
280
281 for _ in 0..threads_to_spawn {
282 let chunks = Arc::clone(&chunks_clone);
283 let file_clone = Arc::clone(&file);
284 let url = self.url.clone();
285 let multi_progress_clone = Arc::clone(&multi_progress);
286 let index_clone = Arc::clone(&index);
287
288 let task = tokio::spawn(async move {
289 let mut worker_total: u64 = 0;
290 loop {
291 let i = index_clone.fetch_add(1, Ordering::SeqCst);
293 if i >= num_chunks {
294 break;
295 }
296
297 let (start, end) = {
299 let chunks_guard = chunks.lock().await;
300 (chunks_guard[i].start_byte, chunks_guard[i].end_byte)
301 };
302
303 let chunk_size = end - start + 1;
305 let progress_bar = multi_progress_clone.add(ProgressBar::new(chunk_size));
306 progress_bar.set_style(ProgressStyle::with_template(
307 &format!(
308 "[Chunk {:03}] {{wide_bar:40.cyan/blue}} {{binary_bytes}}/{{binary_total_bytes}} ({{percent}}%)",
309 i + 1
311 )
312 ).unwrap());
313
314 let downloader = Downloader {
316 url: url.clone(),
317 headers: HeaderMap::new(),
318 file_size: None,
319 filename: None,
320 chunks: Arc::clone(&chunks),
321 };
322
323 let downloaded = downloader
325 .get_chunk(
326 Some((start, end)),
327 Some(progress_bar),
328 Some(Arc::clone(&file_clone)),
329 Some(i),
330 )
331 .await?;
332 worker_total += downloaded;
333 }
334
335 Ok::<u64, Box<dyn std::error::Error + Send + Sync>>(worker_total)
336 });
337
338 tasks.push(task);
339 }
340
341 println!("Starting concurrent downloads...");
343 let results = future::try_join_all(tasks)
344 .await
345 .map_err(|e| format!("Task join error: {}", e))?;
346
347 let total_downloaded: u64 = results
348 .into_iter()
349 .collect::<Result<Vec<_>, _>>()?
350 .into_iter()
351 .sum();
352
353 println!(
354 "Download completed! Total bytes: {}",
355 HumanBytes(total_downloaded)
356 );
357 } else {
358 let file_clone = Arc::clone(&file);
361 let bar = ProgressBar::new_spinner();
362 bar.enable_steady_tick(Duration::from_millis(100));
363 println!("\n");
364 bar.set_style(
365 ProgressStyle::with_template(&format!(
366 "{{spinner:.cyan}} {:?} ({{binary_bytes}} downloaded)",
367 self.filename.as_ref().unwrap()
368 ))
369 .unwrap()
370 .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
371 );
372 let _ = self
373 .get_chunk(None, Some(bar), Some(file_clone), None)
374 .await;
375 }
376
377 Ok(())
378 }
379}
380
#[cfg(test)]
mod tests {
    //! Unit tests for header parsing, URL fallback naming and downloader construction.
    use super::*;
    use reqwest::header::HeaderMap;
    use tokio::runtime::Runtime;

    #[test]
    fn test_extract_filename_from_url() {
        let url = "https://example.com/path/to/file.txt";
        assert_eq!(extract_filename_from_url(url), Some("file.txt".to_string()));
        let url2 = "https://example.com/path/to/";
        assert_eq!(extract_filename_from_url(url2), None);
    }

    #[test]
    fn test_extract_filename_from_url_ignores_query_and_invalid_input() {
        assert_eq!(
            extract_filename_from_url("https://example.com/a/b.zip?token=1#frag"),
            Some("b.zip".to_string())
        );
        assert_eq!(extract_filename_from_url("not a url"), None);
    }

    #[test]
    fn test_header_extract_filename() {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_DISPOSITION,
            "attachment; filename=\"myfile.bin\"".parse().unwrap(),
        );
        let name = headers.extract_filename().unwrap();
        assert_eq!(name, "myfile.bin");
    }

    #[test]
    fn test_header_extract_filename_unquoted() {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_DISPOSITION,
            "attachment; filename=plain.txt".parse().unwrap(),
        );
        assert_eq!(headers.extract_filename().unwrap(), "plain.txt");
    }

    #[test]
    fn test_header_extract_filename_missing() {
        assert!(HeaderMap::new().extract_filename().is_err());

        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_DISPOSITION,
            "inline".parse().unwrap(),
        );
        assert!(headers.extract_filename().is_err());
    }

    #[test]
    fn test_header_extract_file_size() {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_RANGE,
            "bytes 0-0/12345".parse().unwrap(),
        );
        let size = headers.extract_file_size().unwrap();
        assert_eq!(size, 12345u64);
    }

    #[test]
    fn test_header_extract_file_size_unknown_or_missing() {
        assert!(HeaderMap::new().extract_file_size().is_err());

        // `*` means the server does not know the complete length.
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_RANGE,
            "bytes 0-0/*".parse().unwrap(),
        );
        assert!(headers.extract_file_size().is_err());
    }

    #[test]
    fn test_chunk_new_starts_empty() {
        let c = Chunk::new(5, 9);
        assert_eq!((c.start_byte, c.end_byte, c.downloaded), (5, 9, 0));
    }

    #[test]
    fn test_downloader_new_and_defaults() {
        let d = Downloader::new("https://example.com/file");
        assert_eq!(d.url, "https://example.com/file");
        assert!(d.filename.is_none());
        assert!(d.file_size.is_none());
    }

    #[test]
    fn test_download_placeholder() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let mut downloader = Downloader::new("https://example.com/file");
            downloader.headers = HeaderMap::new();
            downloader.filename = Some("tmp_download.bin".to_string());
            downloader.file_size = Some(0);

            assert_eq!(downloader.filename.as_deref(), Some("tmp_download.bin"));
            assert_eq!(downloader.file_size, Some(0));
            // A fresh downloader has not planned any chunks yet.
            assert!(downloader.chunks.lock().await.is_empty());
        });
    }
}