furl_core/
engine.rs

1use std::{cmp::min, error::Error, sync::Arc, time::Duration};
2
3use reqwest::{
4    self, Url,
5    header::{CONTENT_DISPOSITION, CONTENT_RANGE, HeaderMap, RANGE},
6};
7// use tokio::stream;
8use futures_util::{StreamExt, future};
9use indicatif::{HumanBytes, MultiProgress, ProgressBar, ProgressStyle};
10use tokio::fs::File;
11use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
12use tokio::sync::Mutex;
13
/// One contiguous byte range of the remote file together with its download progress.
#[derive(Debug)]
struct Chunk {
    /// Offset of the first byte of the range.
    start_byte: u64,
    /// Offset of the last byte of the range.
    end_byte: u64,
    /// Number of bytes of this range received so far.
    downloaded: u64,
}

impl Chunk {
    /// Builds a chunk spanning `start_byte..=end_byte` with no bytes downloaded yet.
    fn new(start_byte: u64, end_byte: u64) -> Self {
        Chunk {
            downloaded: 0,
            start_byte,
            end_byte,
        }
    }
}
30
31#[derive(Debug)]
32pub struct Downloader {
33    url: String,
34    headers: HeaderMap,
35    file_size: Option<u64>,
36    filename: Option<String>,
37    chunks: Arc<Mutex<Vec<Chunk>>>, // this stores downloaded chunk size
38}
39
/// Helpers for reading download metadata out of a set of response headers.
pub trait HeaderUtils {
    /// Returns the file name advertised by the response headers.
    ///
    /// The name is expected to come from `Content-Disposition` (or a similar
    /// header that names the file). Callers may fall back to guessing the name
    /// from the URL or the content type when this fails.
    ///
    /// # Errors
    ///
    /// Returns an error when no file name can be extracted.
    fn extract_filename(&self) -> Result<String, Box<dyn Error + Send + Sync>>;

    /// Returns the total size of the file in bytes.
    ///
    /// The size is taken from the `Content-Range` header. For example, a
    /// response carrying `Content-Range: bytes 0-0/360996864` describes a file
    /// of `360996864` bytes. Knowing the size is what makes it possible to
    /// download the file as several partial ranges in parallel.
    ///
    /// # Errors
    ///
    /// Returns an error when the size cannot be determined.
    fn extract_file_size(&self) -> Result<u64, Box<dyn Error + Send + Sync>>;
}
62
63/// Since HeaderMap is imported from Reqwest, we need to define trait and then
64/// implement it to the HeaderMap struct.
65impl HeaderUtils for HeaderMap {
66    fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
67        if let Some(disposition) = &self.get(CONTENT_DISPOSITION) {
68            let value = disposition.to_str()?;
69            if let Some(filename) = value.split("filename=").nth(1) {
70                return Ok(filename.trim_matches('"').to_string());
71            }
72        }
73        return Err(Box::from("Unable to extract filename".to_owned()));
74        // TODO: guess filename from content type
75    }
76
77    fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
78        let &cr = &self
79            .get(CONTENT_RANGE)
80            .ok_or_else(|| Box::<dyn Error + Send + Sync>::from("Content_range not found"))?;
81        let content_range =
82            cr.to_str()?.split("/").into_iter().last().ok_or_else(|| {
83                Box::<dyn Error + Send + Sync>::from("Invalid Content_range_format")
84            })?;
85        Ok(content_range.parse()?)
86    }
87}
88
89/// # Extract file name from Urls
90/// This method is used when we do not have any headers passed for file name
91/// For example: if content disposition is not provided, but there is a valid
92/// filename in the request url
93pub fn extract_filename_from_url(url: &str) -> Option<String> {
94    if let Ok(parsed_url) = Url::parse(&url) {
95        if let Some(segment) = parsed_url.path_segments().and_then(|s| s.last()) {
96            if !segment.is_empty() {
97                return Some(segment.to_string());
98            }
99        }
100    }
101    return None;
102}
103
104impl Downloader {
105    pub fn new(url: &str) -> Self {
106        Self {
107            url: url.to_owned(),
108            headers: HeaderMap::new(),
109            file_size: None,
110            filename: None,
111            chunks: Arc::new(Mutex::new(Vec::new())),
112        }
113    }
114
115    async fn get_chunk(
116        &self,
117        range: Option<(u64, u64)>,
118        progress_bar: Option<ProgressBar>,
119        file: Option<Arc<Mutex<File>>>,
120        chunk_index: Option<usize>,
121    ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
122        let client = reqwest::Client::new();
123        let mut builder = client.get(&self.url);
124        if let Some((start, end)) = range {
125            builder = builder.header(RANGE, &format!("bytes={start}-{end}"));
126        }
127        let response = builder.send().await?;
128        let mut stream = response.bytes_stream();
129        let mut downloaded = 0u64;
130        let mut chunk_data = Vec::new();
131
132        // progress_bar
133        while let Some(chunk) = stream.next().await {
134            let chunk = chunk?;
135            chunk_data.extend_from_slice(&chunk);
136            downloaded += chunk.len() as u64;
137
138            if let Some(bar) = &progress_bar {
139                bar.inc(chunk.len() as u64);
140            }
141
142            // Update chunk progress in shared state
143            if let Some(idx) = chunk_index {
144                let mut chunks = self.chunks.lock().await;
145                if idx < chunks.len() {
146                    chunks[idx].downloaded = downloaded;
147                }
148            }
149        }
150
151        // Write to file at correct position
152        if let (Some(file), Some((start, _))) = (file, range) {
153            let mut f = file.lock().await;
154            f.seek(SeekFrom::Start(start)).await?;
155            f.write_all(&chunk_data).await?;
156        }
157
158        if let Some(bar) = progress_bar {
159            bar.finish();
160        }
161
162        Ok(downloaded)
163    }
164
165    pub async fn download(
166        &mut self,
167        path: &str,
168    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
169        let client = reqwest::Client::new();
170        // get response headers to get file name, length, etc.
171        let response = client
172            .get(&self.url)
173            .header(RANGE, "bytes=0-0")
174            .send()
175            .await?;
176        self.headers = response.headers().clone().to_owned();
177
178        let filename = match &self.headers.extract_filename() {
179            Ok(filename) => filename.to_owned(),
180            Err(_) => match extract_filename_from_url(&self.url) {
181                Some(filename) => filename,
182                None => "download.bin".to_owned(),
183            },
184        };
185        println!("⛔filename: {filename}");
186        // trim trailing / from original path
187        self.filename = Some(format!("{path}/{filename}").replace("//", "/"));
188
189        if let Ok(file_size) = self.headers.extract_file_size() {
190            self.file_size = Some(file_size);
191            println!("⛔file size: {}", HumanBytes(file_size));
192        } else {
193            println!("⛔ Unable to determine the file size. skipping threads")
194        }
195
196        let file = Arc::new(Mutex::new(
197            File::create(self.filename.as_ref().unwrap()).await?,
198        ));
199
200        // handle chunks with threads
201        if let Some(file_size) = self.file_size {
202            // allocate file's size if size is known
203            // this helps seeking to the position and writing the chunk at that position
204            file.lock().await.set_len(file_size).await?;
205
206            let mut start = 0;
207            let byte_size = file_size / 8;
208
209            // split chunks to download
210            while start < file_size {
211                let end = min(start + byte_size, file_size);
212                self.chunks.lock().await.push(Chunk::new(start, end));
213                start = end + 1;
214            }
215
216            let num_chunks = self.chunks.lock().await.len();
217            println!("Created {} chunks for download", num_chunks);
218
219            let multi_progress = Arc::new(MultiProgress::new());
220
221            // Create tasks for concurrent downloading
222            let mut tasks = Vec::new();
223            let chunks_clone = Arc::clone(&self.chunks);
224
225            for i in 0..num_chunks {
226                let chunks = Arc::clone(&chunks_clone);
227                let file_clone = Arc::clone(&file);
228                let url = self.url.clone();
229                let multi_progress_clone = Arc::clone(&multi_progress);
230
231                let task = tokio::spawn(async move {
232                    // Get chunk info
233                    let (start, end) = {
234                        let chunks_guard = chunks.lock().await;
235                        if i >= chunks_guard.len() {
236                            return Err("Chunk index out of bounds".into());
237                        }
238                        (chunks_guard[i].start_byte, chunks_guard[i].end_byte)
239                    };
240
241                    // Create progress bar for this chunk
242                    let chunk_size = end - start + 1;
243                    let progress_bar = multi_progress_clone.add(ProgressBar::new(chunk_size));
244                    progress_bar.set_style(ProgressStyle::with_template(
245                        &format!("[Chunk {}] {{wide_bar:40.cyan/blue}} {{binary_bytes}}/{{binary_total_bytes}} ({{percent}}%)", i)
246                    ).unwrap());
247
248                    // Create a downloader instance for this chunk
249                    let downloader = Downloader {
250                        url,
251                        headers: HeaderMap::new(),
252                        file_size: None,
253                        filename: None,
254                        chunks: chunks,
255                    };
256
257                    // Download the chunk
258                    downloader
259                        .get_chunk(
260                            Some((start, end)),
261                            Some(progress_bar),
262                            Some(file_clone),
263                            Some(i),
264                        )
265                        .await
266                });
267
268                tasks.push(task);
269            }
270
271            // Wait for all downloads to complete
272            println!("Starting concurrent downloads...");
273            let results = future::try_join_all(tasks)
274                .await
275                .map_err(|e| format!("Task join error: {}", e))?;
276
277            let total_downloaded: u64 = results
278                .into_iter()
279                .collect::<Result<Vec<_>, _>>()?
280                .into_iter()
281                .sum();
282
283            println!(
284                "Download completed! Total bytes: {}",
285                HumanBytes(total_downloaded)
286            );
287        } else {
288            let file_clone = Arc::clone(&file);
289            let bar = ProgressBar::new_spinner();
290            bar.enable_steady_tick(Duration::from_millis(100));
291            println!("");
292            bar.set_style(
293                ProgressStyle::with_template(&format!(
294                    "{{spinner:.cyan}} {:?} ({{binary_bytes}} downloaded)",
295                    self.filename.as_ref().unwrap()
296                ))
297                .unwrap()
298                // .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
299                // set tick character as a moon's phase as progress indicator
300                .tick_chars("🌑🌒🌓🌔🌕🌖🌗🌘"),
301            );
302            let _ = self
303                .get_chunk(None, Some(bar), Some(file_clone), None)
304                .await;
305        }
306
307        Ok(())
308    }
309}