// furl_core/engine.rs

1use std::{
2    cmp::min,
3    error::Error,
4    sync::{
5        Arc,
6        atomic::{AtomicUsize, Ordering},
7    },
8    time::Duration,
9};
10
11use reqwest::{
12    self, Url,
13    header::{CONTENT_DISPOSITION, CONTENT_RANGE, HeaderMap, RANGE},
14};
15// use tokio::stream;
16use futures_util::{StreamExt, future};
17use indicatif::{HumanBytes, MultiProgress, ProgressBar, ProgressStyle};
18use tokio::fs::File;
19use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
20use tokio::sync::Mutex;
21
/// One mebibyte (2^20 bytes).
const _1MB: u64 = 1 << 20;
/// Ten mebibytes; the largest byte range a single download chunk is given.
const _10MB: u64 = 10 * _1MB;
24
/// One inclusive byte range (`start_byte..=end_byte`) of the remote file, together
/// with how many bytes of that range have been received so far.
#[derive(Debug)]
struct Chunk {
    start_byte: u64,
    end_byte: u64,
    downloaded: u64,
}

impl Chunk {
    /// Builds a chunk for the given range with no progress recorded yet.
    fn new(start_byte: u64, end_byte: u64) -> Self {
        Chunk {
            downloaded: 0,
            start_byte,
            end_byte,
        }
    }
}
41
42#[derive(Debug)]
43pub struct Downloader {
44    url: String,
45    headers: HeaderMap,
46    file_size: Option<u64>,
47    filename: Option<String>,
48    chunks: Arc<Mutex<Vec<Chunk>>>, // this stores downloaded chunk size
49}
50
/// Helpers for reading download metadata out of HTTP response headers.
pub trait HeaderUtils {
    /// Returns the file name the server suggests for the response body, e.g. the
    /// `filename` parameter of a `Content-Disposition` header.
    ///
    /// # Errors
    ///
    /// Returns an error when no usable file name can be found in the headers.
    fn extract_filename(&self) -> Result<String, Box<dyn Error + Send + Sync>>;

    /// Returns the total size of the remote file in bytes, as reported in the
    /// `Content-Range` header of a ranged response.
    ///
    /// For `Content-Range: bytes 0-0/360996864` the result is `360996864`. Knowing
    /// the total lets the caller split the download into byte ranges that can be
    /// fetched concurrently.
    ///
    /// # Errors
    ///
    /// Returns an error when the header is absent or does not carry a numeric total.
    fn extract_file_size(&self) -> Result<u64, Box<dyn Error + Send + Sync>>;
}
73
74/// Since HeaderMap is imported from Reqwest, we need to define trait and then
75/// implement it to the HeaderMap struct.
76impl HeaderUtils for HeaderMap {
77    fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
78        if let Some(disposition) = &self.get(CONTENT_DISPOSITION) {
79            let value = disposition.to_str()?;
80            if let Some(filename) = value.split("filename=").nth(1) {
81                return Ok(filename.trim_matches('"').to_string());
82            }
83        }
84        Err(Box::from("Unable to extract filename".to_owned()))
85        // TODO: guess filename from content type
86    }
87
88    /// Returns the file size in bytes.
89    ///
90    /// If the content-length is is not found or not extracted, it just returns
91    /// Error
92    fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
93        let &cr = &self
94            .get(CONTENT_RANGE)
95            .ok_or_else(|| Box::<dyn Error + Send + Sync>::from("Content_range not found"))?;
96        let content_range =
97            cr.to_str()?.split("/").last().ok_or_else(|| {
98                Box::<dyn Error + Send + Sync>::from("Invalid Content_range_format")
99            })?;
100        Ok(content_range.parse()?)
101    }
102}
103
104/// # Extract file name from Urls
105/// This method is used when we do not have any headers passed for file name
106/// For example: if content disposition is not provided, but there is a valid
107/// filename in the request url
108pub fn extract_filename_from_url(url: &str) -> Option<String> {
109    if let Ok(parsed_url) = Url::parse(url)
110        && let Some(segment) = parsed_url.path_segments().and_then(|mut s| s.next_back())
111        && !segment.is_empty()
112    {
113        return Some(segment.to_string());
114    }
115
116    None
117}
118
119impl Downloader {
120    pub fn new<S: Into<String>>(url: S) -> Self {
121        Self {
122            url: url.into(),
123            headers: HeaderMap::new(),
124            file_size: None,
125            filename: None,
126            chunks: Arc::new(Mutex::new(Vec::new())),
127        }
128    }
129
130    async fn get_chunk(
131        &self,
132        range: Option<(u64, u64)>,
133        progress_bar: Option<ProgressBar>,
134        file: Option<Arc<Mutex<File>>>,
135        chunk_index: Option<usize>,
136    ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
137        let client = reqwest::Client::new();
138        let mut builder = client.get(&self.url);
139        if let Some((start, end)) = range {
140            builder = builder.header(RANGE, &format!("bytes={start}-{end}"));
141        }
142        let response = builder.send().await?;
143        let mut stream = response.bytes_stream();
144        let mut downloaded = 0u64;
145        let mut chunk_data = Vec::new();
146
147        // progress_bar
148        while let Some(chunk) = stream.next().await {
149            let chunk = chunk?;
150            chunk_data.extend_from_slice(&chunk);
151            downloaded += chunk.len() as u64;
152
153            if let Some(bar) = &progress_bar {
154                bar.inc(chunk.len() as u64);
155            }
156
157            // Update chunk progress in shared state
158            if let Some(idx) = chunk_index {
159                let mut chunks = self.chunks.lock().await;
160                if idx < chunks.len() {
161                    chunks[idx].downloaded = downloaded;
162                }
163            }
164        }
165
166        // Write to file at correct position
167        if let (Some(file), Some((start, _))) = (file, range) {
168            let mut f = file.lock().await;
169            f.seek(SeekFrom::Start(start)).await?;
170            f.write_all(&chunk_data).await?;
171        }
172
173        if let Some(bar) = progress_bar {
174            bar.finish();
175        }
176
177        Ok(downloaded)
178    }
179
180    /// downloads the file into the provided path
181    /// # Arguments
182    ///
183    /// * `path`: download path
184    /// * `threads`: number of threads to use for downloading.
185    ///   if you pass None, or Some(0), it will defaults to 8
186    ///
187    /// Note: If the download size is less than 1 MB, then it will completely
188    /// ignore threads, and download it as a single thread.
189    ///
190    /// If the file size is unknown at the moment it gets the header, it will
191    /// also ignore threads and skips the progress bar and just shows a simple
192    /// ticker as a feedback to let user know that the process is not is in a
193    /// deadlock state.
194    ///
195    pub async fn download(
196        &mut self,
197        path: &str,
198        filename: Option<String>,
199        threads: Option<u8>,
200    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
201        let client = reqwest::Client::new();
202        let threads: u64 = match threads {
203            Some(0) => 8,
204            Some(count) => count as u64,
205            None => 8,
206        };
207        // get response headers to get file name, length, etc.
208        let response = client
209            .get(&self.url)
210            .header(RANGE, "bytes=0-0")
211            .send()
212            .await?;
213        self.headers = response.headers().clone().to_owned();
214
215        let filename = if let Some(filename) = filename {
216            filename
217        } else {
218            match &self.headers.extract_filename() {
219                Ok(filename) => filename.to_owned(),
220                Err(_) => match extract_filename_from_url(&self.url) {
221                    Some(filename) => filename,
222                    None => "download.dat".to_owned(),
223                },
224            }
225        };
226        println!("Downloading \"{filename}\"");
227        self.filename = Some(format!("{path}/{filename}").replace("//", "/"));
228
229        if let Ok(file_size) = self.headers.extract_file_size() {
230            self.file_size = Some(file_size);
231            println!("file size: {}", HumanBytes(file_size));
232        } else {
233            println!("Unable to determine the file size. skipping threads");
234        }
235
236        let file = Arc::new(Mutex::new(
237            File::create(self.filename.as_ref().unwrap()).await?,
238        ));
239
240        // handle chunks with threads
241        if let Some(file_size) = self.file_size {
242            // allocate file's size if size is known
243            // this helps seeking to the position and writing the chunk at that position
244            file.lock().await.set_len(file_size).await?;
245
246            let mut start = 0;
247            let thread_size = file_size / threads;
248            let mut byte_size = thread_size;
249
250            //ignore threads if the file is less than a MB.
251            if file_size < _1MB {
252                println!("ℹ️ The file is smaller than 1 MB, so skipping threads.");
253                byte_size = file_size;
254            }
255
256            // if the byte size is larger than 10 MB, split into 10 MB chunks
257            // so that memory consumption is less.
258            if thread_size > _10MB {
259                byte_size = _10MB
260            }
261
262            // split chunks to download
263            while start < file_size {
264                let end = min(start + byte_size, file_size);
265                self.chunks.lock().await.push(Chunk::new(start, end));
266                start = end + 1;
267            }
268
269            let num_chunks = self.chunks.lock().await.len();
270            println!("Created {} chunks for download", num_chunks);
271
272            let multi_progress = Arc::new(MultiProgress::new());
273
274            // Create tasks for concurrent downloading
275            let mut tasks = Vec::new();
276            let chunks_clone = Arc::clone(&self.chunks);
277
278            // Use a fixed-size worker pool and an atomic index so we don't spawn one task per chunk.
279            // Each worker atomically pulls the next chunk index and processes it, which creates
280            // queue-like behavior while keeping the number of concurrent tasks limited to `threads`.
281            let index = Arc::new(AtomicUsize::new(0));
282
283            // limit number of workers to at most num_chunks
284            let threads_to_spawn = std::cmp::min(threads as usize, num_chunks);
285
286            for _ in 0..threads_to_spawn {
287                let chunks = Arc::clone(&chunks_clone);
288                let file_clone = Arc::clone(&file);
289                let url = self.url.clone();
290                let multi_progress_clone = Arc::clone(&multi_progress);
291                let index_clone = Arc::clone(&index);
292
293                let task = tokio::spawn(async move {
294                    let mut worker_total: u64 = 0;
295                    loop {
296                        // fetch next chunk index
297                        let i = index_clone.fetch_add(1, Ordering::SeqCst);
298                        if i >= num_chunks {
299                            break;
300                        }
301
302                        // Get chunk info
303                        let (start, end) = {
304                            let chunks_guard = chunks.lock().await;
305                            (chunks_guard[i].start_byte, chunks_guard[i].end_byte)
306                        };
307
308                        // Create progress bar for this chunk
309                        let chunk_size = end - start + 1;
310                        let progress_bar = multi_progress_clone.add(ProgressBar::new(chunk_size));
311                        progress_bar.set_style(ProgressStyle::with_template(
312                            &format!(
313                                "[Chunk {:03}] {{wide_bar:40.cyan/blue}} {{binary_bytes}}/{{binary_total_bytes}} ({{percent}}%)",
314                                // chunk index starts from 0, but 1 seems natural for human
315                                i + 1
316                            )
317                        ).unwrap());
318
319                        // Create a downloader instance for this chunk
320                        let downloader = Downloader {
321                            url: url.clone(),
322                            headers: HeaderMap::new(),
323                            file_size: None,
324                            filename: None,
325                            chunks: Arc::clone(&chunks),
326                        };
327
328                        // Download the chunk and accumulate the bytes downloaded by this worker
329                        let downloaded = downloader
330                            .get_chunk(
331                                Some((start, end)),
332                                Some(progress_bar),
333                                Some(Arc::clone(&file_clone)),
334                                Some(i),
335                            )
336                            .await?;
337                        worker_total += downloaded;
338                    }
339
340                    Ok::<u64, Box<dyn std::error::Error + Send + Sync>>(worker_total)
341                });
342
343                tasks.push(task);
344            }
345
346            // Wait for all downloads to complete
347            println!("Starting concurrent downloads...");
348            let results = future::try_join_all(tasks)
349                .await
350                .map_err(|e| format!("Task join error: {}", e))?;
351
352            let total_downloaded: u64 = results
353                .into_iter()
354                .collect::<Result<Vec<_>, _>>()?
355                .into_iter()
356                .sum();
357
358            println!(
359                "Download completed! Total bytes: {}",
360                HumanBytes(total_downloaded)
361            );
362        } else {
363            // continue without threads when the file size is unknown
364            // and just display ticks instead of progressbar since the size is unknown.
365            let file_clone = Arc::clone(&file);
366            let bar = ProgressBar::new_spinner();
367            bar.enable_steady_tick(Duration::from_millis(100));
368            println!("\n");
369            bar.set_style(
370                ProgressStyle::with_template(&format!(
371                    "{{spinner:.cyan}} {:?} ({{binary_bytes}} downloaded)",
372                    self.filename.as_ref().unwrap()
373                ))
374                .unwrap()
375                .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
376            );
377            let _ = self
378                .get_chunk(None, Some(bar), Some(file_clone), None)
379                .await;
380        }
381
382        Ok(())
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;

    #[test]
    fn test_extract_filename_from_url() {
        let url = "https://example.com/path/to/file.txt";
        assert_eq!(extract_filename_from_url(url), Some("file.txt".to_string()));
        let url2 = "https://example.com/path/to/";
        assert_eq!(extract_filename_from_url(url2), None);
    }

    #[test]
    fn test_extract_filename_from_url_edge_cases() {
        // The query string is not part of the path.
        assert_eq!(
            extract_filename_from_url("https://example.com/files/archive.tar.gz?token=abc"),
            Some("archive.tar.gz".to_string())
        );
        // A bare host has only an empty path segment.
        assert_eq!(extract_filename_from_url("https://example.com"), None);
        // Unparseable input.
        assert_eq!(extract_filename_from_url("not a url"), None);
    }

    #[test]
    fn test_header_extract_filename() {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_DISPOSITION,
            "attachment; filename=\"myfile.bin\"".parse().unwrap(),
        );
        let name = headers.extract_filename().unwrap();
        assert_eq!(name, "myfile.bin");
    }

    #[test]
    fn test_header_extract_filename_unquoted() {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_DISPOSITION,
            "attachment; filename=plain.txt".parse().unwrap(),
        );
        assert_eq!(headers.extract_filename().unwrap(), "plain.txt");
    }

    #[test]
    fn test_header_extract_filename_missing() {
        assert!(HeaderMap::new().extract_filename().is_err());
    }

    #[test]
    fn test_header_extract_file_size() {
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_RANGE,
            "bytes 0-0/12345".parse().unwrap(),
        );
        let size = headers.extract_file_size().unwrap();
        assert_eq!(size, 12345u64);
    }

    #[test]
    fn test_header_extract_file_size_missing_or_unknown() {
        assert!(HeaderMap::new().extract_file_size().is_err());

        // `*` means the server does not know the total length.
        let mut headers = HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_RANGE,
            "bytes 0-0/*".parse().unwrap(),
        );
        assert!(headers.extract_file_size().is_err());
    }

    #[test]
    fn test_downloader_new_and_defaults() {
        let d = Downloader::new("https://example.com/file");
        assert_eq!(d.url, "https://example.com/file");
        assert!(d.filename.is_none());
        assert!(d.file_size.is_none());
        assert!(d.headers.is_empty());
        assert!(d.chunks.try_lock().expect("no other holder").is_empty());
    }
}