1use std::{cmp::min, error::Error, sync::Arc, time::Duration};
2
3use reqwest::{
4 self, Url,
5 header::{CONTENT_DISPOSITION, CONTENT_RANGE, HeaderMap, RANGE},
6};
7use futures_util::{StreamExt, future};
9use indicatif::{HumanBytes, MultiProgress, ProgressBar, ProgressStyle};
10use tokio::fs::File;
11use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
12use tokio::sync::Mutex;
13
/// Progress record for one byte range of a download.
///
/// `start_byte` and `end_byte` are the bounds sent in the HTTP `Range` header
/// (`bytes=start-end`, both inclusive); `downloaded` counts body bytes received
/// so far for this range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Chunk {
    start_byte: u64,
    end_byte: u64,
    downloaded: u64,
}
20
21impl Chunk {
22 fn new(start_byte: u64, end_byte: u64) -> Self {
23 Self {
24 start_byte,
25 end_byte,
26 downloaded: 0,
27 }
28 }
29}
30
31#[derive(Debug)]
32pub struct Downloader {
33 url: String,
34 headers: HeaderMap,
35 file_size: Option<u64>,
36 filename: Option<String>,
37 chunks: Arc<Mutex<Vec<Chunk>>>, }
39
/// Helpers for reading download metadata out of HTTP response headers.
pub trait HeaderUtils {
    /// Returns the file name advertised by the `Content-Disposition` header.
    ///
    /// # Errors
    ///
    /// Fails when no file name can be extracted.
    fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;

    /// Returns the total size of the resource in bytes, as reported by the
    /// `Content-Range` header.
    ///
    /// # Errors
    ///
    /// Fails when the header is absent or its total cannot be parsed.
    fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>>;
}
62
63impl HeaderUtils for HeaderMap {
66 fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
67 if let Some(disposition) = &self.get(CONTENT_DISPOSITION) {
68 let value = disposition.to_str()?;
69 if let Some(filename) = value.split("filename=").nth(1) {
70 return Ok(filename.trim_matches('"').to_string());
71 }
72 }
73 return Err(Box::from("Unable to extract filename".to_owned()));
74 }
76
77 fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
82 let &cr = &self
83 .get(CONTENT_RANGE)
84 .ok_or_else(|| Box::<dyn Error + Send + Sync>::from("Content_range not found"))?;
85 let content_range =
86 cr.to_str()?.split("/").into_iter().last().ok_or_else(|| {
87 Box::<dyn Error + Send + Sync>::from("Invalid Content_range_format")
88 })?;
89 Ok(content_range.parse()?)
90 }
91}
92
93pub fn extract_filename_from_url(url: &str) -> Option<String> {
98 if let Ok(parsed_url) = Url::parse(&url) {
99 if let Some(segment) = parsed_url.path_segments().and_then(|s| s.last()) {
100 if !segment.is_empty() {
101 return Some(segment.to_string());
102 }
103 }
104 }
105 return None;
106}
107
108impl Downloader {
109 pub fn new<S: Into<String>>(url: S) -> Self {
110 Self {
111 url: url.into(),
112 headers: HeaderMap::new(),
113 file_size: None,
114 filename: None,
115 chunks: Arc::new(Mutex::new(Vec::new())),
116 }
117 }
118
119 async fn get_chunk(
120 &self,
121 range: Option<(u64, u64)>,
122 progress_bar: Option<ProgressBar>,
123 file: Option<Arc<Mutex<File>>>,
124 chunk_index: Option<usize>,
125 ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
126 let client = reqwest::Client::new();
127 let mut builder = client.get(&self.url);
128 if let Some((start, end)) = range {
129 builder = builder.header(RANGE, &format!("bytes={start}-{end}"));
130 }
131 let response = builder.send().await?;
132 let mut stream = response.bytes_stream();
133 let mut downloaded = 0u64;
134 let mut chunk_data = Vec::new();
135
136 while let Some(chunk) = stream.next().await {
138 let chunk = chunk?;
139 chunk_data.extend_from_slice(&chunk);
140 downloaded += chunk.len() as u64;
141
142 if let Some(bar) = &progress_bar {
143 bar.inc(chunk.len() as u64);
144 }
145
146 if let Some(idx) = chunk_index {
148 let mut chunks = self.chunks.lock().await;
149 if idx < chunks.len() {
150 chunks[idx].downloaded = downloaded;
151 }
152 }
153 }
154
155 if let (Some(file), Some((start, _))) = (file, range) {
157 let mut f = file.lock().await;
158 f.seek(SeekFrom::Start(start)).await?;
159 f.write_all(&chunk_data).await?;
160 }
161
162 if let Some(bar) = progress_bar {
163 bar.finish();
164 }
165
166 Ok(downloaded)
167 }
168
169 pub async fn download(
185 &mut self,
186 path: &str,
187 threads: Option<u8>,
188 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
189 let client = reqwest::Client::new();
190 let threads: u64 = match threads {
191 Some(0) => 8,
192 Some(count) => count as u64,
193 None => 8,
194 };
195 let response = client
197 .get(&self.url)
198 .header(RANGE, "bytes=0-0")
199 .send()
200 .await?;
201 self.headers = response.headers().clone().to_owned();
202
203 let filename = match &self.headers.extract_filename() {
204 Ok(filename) => filename.to_owned(),
205 Err(_) => match extract_filename_from_url(&self.url) {
206 Some(filename) => filename,
207 None => "download.bin".to_owned(),
208 },
209 };
210 println!("Downloading \"{filename}\"");
211 self.filename = Some(format!("{path}/{filename}").replace("//", "/"));
212
213 if let Ok(file_size) = self.headers.extract_file_size() {
214 self.file_size = Some(file_size);
215 println!("file size: {}", HumanBytes(file_size));
216 } else {
217 println!("Unable to determine the file size. skipping threads");
218 }
219
220 let file = Arc::new(Mutex::new(
221 File::create(self.filename.as_ref().unwrap()).await?,
222 ));
223
224 if let Some(file_size) = self.file_size {
226 file.lock().await.set_len(file_size).await?;
229
230 let mut start = 0;
231 let mut byte_size = file_size / threads;
232
233 if file_size < 1024 * 1024 {
235 println!("ℹ️ The file is smaller than 1 MB, so skipping threads.");
236 byte_size = file_size;
237 }
238
239 while start < file_size {
241 let end = min(start + byte_size, file_size);
242 self.chunks.lock().await.push(Chunk::new(start, end));
243 start = end + 1;
244 }
245
246 let num_chunks = self.chunks.lock().await.len();
247 println!("Created {} chunks for download", num_chunks);
248
249 let multi_progress = Arc::new(MultiProgress::new());
250
251 let mut tasks = Vec::new();
253 let chunks_clone = Arc::clone(&self.chunks);
254
255 for i in 0..num_chunks {
256 let chunks = Arc::clone(&chunks_clone);
257 let file_clone = Arc::clone(&file);
258 let url = self.url.clone();
259 let multi_progress_clone = Arc::clone(&multi_progress);
260
261 let task = tokio::spawn(async move {
262 let (start, end) = {
264 let chunks_guard = chunks.lock().await;
265 if i >= chunks_guard.len() {
266 return Err("Chunk index out of bounds".into());
267 }
268 (chunks_guard[i].start_byte, chunks_guard[i].end_byte)
269 };
270
271 let chunk_size = end - start + 1;
273 let progress_bar = multi_progress_clone.add(ProgressBar::new(chunk_size));
274 progress_bar.set_style(ProgressStyle::with_template(
275 &format!(
276 "[Chunk {:03}] {{wide_bar:40.cyan/blue}} {{binary_bytes}}/{{binary_total_bytes}} ({{percent}}%)",
277 i+1
279 )
280 ).unwrap());
281
282 let downloader = Downloader {
284 url,
285 headers: HeaderMap::new(),
286 file_size: None,
287 filename: None,
288 chunks: chunks,
289 };
290
291 downloader
293 .get_chunk(
294 Some((start, end)),
295 Some(progress_bar),
296 Some(file_clone),
297 Some(i),
298 )
299 .await
300 });
301
302 tasks.push(task);
303 }
304
305 println!("Starting concurrent downloads...");
307 let results = future::try_join_all(tasks)
308 .await
309 .map_err(|e| format!("Task join error: {}", e))?;
310
311 let total_downloaded: u64 = results
312 .into_iter()
313 .collect::<Result<Vec<_>, _>>()?
314 .into_iter()
315 .sum();
316
317 println!(
318 "Download completed! Total bytes: {}",
319 HumanBytes(total_downloaded)
320 );
321 } else {
322 let file_clone = Arc::clone(&file);
325 let bar = ProgressBar::new_spinner();
326 bar.enable_steady_tick(Duration::from_millis(100));
327 println!("");
328 bar.set_style(
329 ProgressStyle::with_template(&format!(
330 "{{spinner:.cyan}} {:?} ({{binary_bytes}} downloaded)",
331 self.filename.as_ref().unwrap()
332 ))
333 .unwrap()
334 .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
335 );
336 let _ = self
337 .get_chunk(None, Some(bar), Some(file_clone), None)
338 .await;
339 }
340
341 Ok(())
342 }
343}