1use std::{cmp::min, error::Error, sync::Arc, time::Duration};
2
3use reqwest::{
4 self, Url,
5 header::{CONTENT_DISPOSITION, CONTENT_RANGE, HeaderMap, RANGE},
6};
7use futures_util::{StreamExt, future};
9use indicatif::{HumanBytes, MultiProgress, ProgressBar, ProgressStyle};
10use tokio::fs::File;
11use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
12use tokio::sync::Mutex;
13
/// One contiguous byte range of the remote file, fetched by its own task.
///
/// Offsets are inclusive, as in an HTTP `Range: bytes=start-end` header.
#[derive(Debug)]
struct Chunk {
    /// Offset of the first byte of this range.
    start_byte: u64,
    /// Offset of the last byte of this range (inclusive).
    end_byte: u64,
    /// Number of body bytes received so far for this range.
    downloaded: u64,
}

impl Chunk {
    /// Creates the range `start_byte..=end_byte` with nothing received yet.
    fn new(start_byte: u64, end_byte: u64) -> Self {
        Chunk {
            downloaded: 0,
            start_byte,
            end_byte,
        }
    }
}
30
31#[derive(Debug)]
32pub struct Downloader {
33 url: String,
34 headers: HeaderMap,
35 file_size: Option<u64>,
36 filename: Option<String>,
37 chunks: Arc<Mutex<Vec<Chunk>>>, }
39
/// Reads download metadata out of HTTP response headers.
pub trait HeaderUtils {
    /// Returns the file name advertised by the headers (the `HeaderMap`
    /// implementation reads `Content-Disposition`).
    ///
    /// # Errors
    /// Implementations return an error when no usable file name is present.
    fn extract_filename(&self) -> Result<String, Box<dyn Error + Send + Sync>>;

    /// Returns the total size of the resource in bytes (the `HeaderMap`
    /// implementation reads `Content-Range`).
    ///
    /// # Errors
    /// Implementations return an error when the size cannot be determined.
    fn extract_file_size(&self) -> Result<u64, Box<dyn Error + Send + Sync>>;
}
62
63impl HeaderUtils for HeaderMap {
66 fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
67 if let Some(disposition) = &self.get(CONTENT_DISPOSITION) {
68 let value = disposition.to_str()?;
69 if let Some(filename) = value.split("filename=").nth(1) {
70 return Ok(filename.trim_matches('"').to_string());
71 }
72 }
73 return Err(Box::from("Unable to extract filename".to_owned()));
74 }
76
77 fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
78 let &cr = &self
79 .get(CONTENT_RANGE)
80 .ok_or_else(|| Box::<dyn Error + Send + Sync>::from("Content_range not found"))?;
81 let content_range =
82 cr.to_str()?.split("/").into_iter().last().ok_or_else(|| {
83 Box::<dyn Error + Send + Sync>::from("Invalid Content_range_format")
84 })?;
85 Ok(content_range.parse()?)
86 }
87}
88
89pub fn extract_filename_from_url(url: &str) -> Option<String> {
94 if let Ok(parsed_url) = Url::parse(&url) {
95 if let Some(segment) = parsed_url.path_segments().and_then(|s| s.last()) {
96 if !segment.is_empty() {
97 return Some(segment.to_string());
98 }
99 }
100 }
101 return None;
102}
103
104impl Downloader {
105 pub fn new(url: &str) -> Self {
106 Self {
107 url: url.to_owned(),
108 headers: HeaderMap::new(),
109 file_size: None,
110 filename: None,
111 chunks: Arc::new(Mutex::new(Vec::new())),
112 }
113 }
114
115 async fn get_chunk(
116 &self,
117 range: Option<(u64, u64)>,
118 progress_bar: Option<ProgressBar>,
119 file: Option<Arc<Mutex<File>>>,
120 chunk_index: Option<usize>,
121 ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
122 let client = reqwest::Client::new();
123 let mut builder = client.get(&self.url);
124 if let Some((start, end)) = range {
125 builder = builder.header(RANGE, &format!("bytes={start}-{end}"));
126 }
127 let response = builder.send().await?;
128 let mut stream = response.bytes_stream();
129 let mut downloaded = 0u64;
130 let mut chunk_data = Vec::new();
131
132 while let Some(chunk) = stream.next().await {
134 let chunk = chunk?;
135 chunk_data.extend_from_slice(&chunk);
136 downloaded += chunk.len() as u64;
137
138 if let Some(bar) = &progress_bar {
139 bar.inc(chunk.len() as u64);
140 }
141
142 if let Some(idx) = chunk_index {
144 let mut chunks = self.chunks.lock().await;
145 if idx < chunks.len() {
146 chunks[idx].downloaded = downloaded;
147 }
148 }
149 }
150
151 if let (Some(file), Some((start, _))) = (file, range) {
153 let mut f = file.lock().await;
154 f.seek(SeekFrom::Start(start)).await?;
155 f.write_all(&chunk_data).await?;
156 }
157
158 if let Some(bar) = progress_bar {
159 bar.finish();
160 }
161
162 Ok(downloaded)
163 }
164
165 pub async fn download(
166 &mut self,
167 path: &str,
168 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
169 let client = reqwest::Client::new();
170 let response = client
172 .get(&self.url)
173 .header(RANGE, "bytes=0-0")
174 .send()
175 .await?;
176 self.headers = response.headers().clone().to_owned();
177
178 let filename = match &self.headers.extract_filename() {
179 Ok(filename) => filename.to_owned(),
180 Err(_) => match extract_filename_from_url(&self.url) {
181 Some(filename) => filename,
182 None => "download.bin".to_owned(),
183 },
184 };
185 println!("⛔filename: {filename}");
186 self.filename = Some(format!("{path}/{filename}").replace("//", "/"));
188
189 if let Ok(file_size) = self.headers.extract_file_size() {
190 self.file_size = Some(file_size);
191 println!("⛔file size: {}", HumanBytes(file_size));
192 } else {
193 println!("⛔ Unable to determine the file size. skipping threads")
194 }
195
196 let file = Arc::new(Mutex::new(
197 File::create(self.filename.as_ref().unwrap()).await?,
198 ));
199
200 if let Some(file_size) = self.file_size {
202 file.lock().await.set_len(file_size).await?;
205
206 let mut start = 0;
207 let byte_size = file_size / 8;
208
209 while start < file_size {
211 let end = min(start + byte_size, file_size);
212 self.chunks.lock().await.push(Chunk::new(start, end));
213 start = end + 1;
214 }
215
216 let num_chunks = self.chunks.lock().await.len();
217 println!("Created {} chunks for download", num_chunks);
218
219 let multi_progress = Arc::new(MultiProgress::new());
220
221 let mut tasks = Vec::new();
223 let chunks_clone = Arc::clone(&self.chunks);
224
225 for i in 0..num_chunks {
226 let chunks = Arc::clone(&chunks_clone);
227 let file_clone = Arc::clone(&file);
228 let url = self.url.clone();
229 let multi_progress_clone = Arc::clone(&multi_progress);
230
231 let task = tokio::spawn(async move {
232 let (start, end) = {
234 let chunks_guard = chunks.lock().await;
235 if i >= chunks_guard.len() {
236 return Err("Chunk index out of bounds".into());
237 }
238 (chunks_guard[i].start_byte, chunks_guard[i].end_byte)
239 };
240
241 let chunk_size = end - start + 1;
243 let progress_bar = multi_progress_clone.add(ProgressBar::new(chunk_size));
244 progress_bar.set_style(ProgressStyle::with_template(
245 &format!("[Chunk {}] {{wide_bar:40.cyan/blue}} {{binary_bytes}}/{{binary_total_bytes}} ({{percent}}%)", i)
246 ).unwrap());
247
248 let downloader = Downloader {
250 url,
251 headers: HeaderMap::new(),
252 file_size: None,
253 filename: None,
254 chunks: chunks,
255 };
256
257 downloader
259 .get_chunk(
260 Some((start, end)),
261 Some(progress_bar),
262 Some(file_clone),
263 Some(i),
264 )
265 .await
266 });
267
268 tasks.push(task);
269 }
270
271 println!("Starting concurrent downloads...");
273 let results = future::try_join_all(tasks)
274 .await
275 .map_err(|e| format!("Task join error: {}", e))?;
276
277 let total_downloaded: u64 = results
278 .into_iter()
279 .collect::<Result<Vec<_>, _>>()?
280 .into_iter()
281 .sum();
282
283 println!(
284 "Download completed! Total bytes: {}",
285 HumanBytes(total_downloaded)
286 );
287 } else {
288 let file_clone = Arc::clone(&file);
289 let bar = ProgressBar::new_spinner();
290 bar.enable_steady_tick(Duration::from_millis(100));
291 println!("");
292 bar.set_style(
293 ProgressStyle::with_template(&format!(
294 "{{spinner:.cyan}} {:?} ({{binary_bytes}} downloaded)",
295 self.filename.as_ref().unwrap()
296 ))
297 .unwrap()
298 .tick_chars("🌑🌒🌓🌔🌕🌖🌗🌘"),
301 );
302 let _ = self
303 .get_chunk(None, Some(bar), Some(file_clone), None)
304 .await;
305 }
306
307 Ok(())
308 }
309}