Skip to main content

furl_core/
engine.rs

1use std::{
2    cmp::min,
3    error::Error,
4    sync::{
5        Arc,
6        atomic::{AtomicUsize, Ordering},
7    },
8    time::Duration,
9};
10
11use reqwest::{
12    self, Url,
13    header::{CONTENT_DISPOSITION, CONTENT_RANGE, HeaderMap, RANGE},
14};
15// use tokio::stream;
16use futures_util::{StreamExt, future};
17use indicatif::{HumanBytes, MultiProgress, ProgressBar, ProgressStyle};
18use tokio::fs::File;
19use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
20use tokio::sync::Mutex;
21
/// One mebibyte in bytes. Files smaller than this are downloaded as a single chunk.
const _1MB: u64 = 1 << 20;
/// Ten mebibytes in bytes. Upper bound applied to the per-thread chunk size.
const _10MB: u64 = 10 * _1MB;
24
/// One byte range of the remote file plus the progress made on it.
#[derive(Debug)]
struct Chunk {
    /// Offset of the first byte of the range.
    start_byte: u64,
    /// Offset of the last byte of the range (sent as the end of the HTTP `Range` header).
    end_byte: u64,
    /// Number of bytes of this range received so far.
    downloaded: u64,
}

impl Chunk {
    /// Builds a chunk for the given byte offsets with nothing downloaded yet.
    fn new(start_byte: u64, end_byte: u64) -> Self {
        Chunk {
            downloaded: 0,
            start_byte,
            end_byte,
        }
    }
}
41
/// State for downloading the resource behind a single URL.
#[derive(Debug)]
pub struct Downloader {
    /// URL of the resource to download.
    url: String,
    /// Response headers captured by the probe request issued in `download`.
    headers: HeaderMap,
    /// Total size of the resource in bytes; `None` until it has been read from the headers.
    file_size: Option<u64>,
    /// Path the download is written to; `None` until `download` has chosen one.
    filename: Option<String>,
    /// Byte ranges the download is split into. Each entry also records how many
    /// bytes of its range have been received; shared with the worker tasks.
    chunks: Arc<Mutex<Vec<Chunk>>>,
}
50
/// Helpers for reading download metadata out of HTTP response headers.
pub trait HeaderUtils {
    /// # `extract_filename`
    ///
    /// Extracts the file name advertised by the response headers, for example
    /// through `Content-Disposition`.
    ///
    /// # Errors
    ///
    /// Returns an error when the headers do not provide a usable file name.
    /// Callers can then fall back to guessing the name from the URL.
    fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;

    /// # `extract_file_size`
    ///
    /// Extracts the total size of the resource in bytes from the
    /// `Content-Range` response header.
    ///
    /// Example response header: `Content-Range: bytes 0-0/360996864`.
    /// The number after the `/` (here `360996864`) is the total size.
    ///
    /// Knowing the size makes it possible to download the file as several
    /// partial ranges concurrently.
    ///
    /// # Errors
    ///
    /// Returns an error when the size cannot be extracted.
    fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>>;
}
73
74/// Since HeaderMap is imported from Reqwest, we need to define trait and then
75/// implement it to the HeaderMap struct.
76impl HeaderUtils for HeaderMap {
77    fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
78        if let Some(disposition) = &self.get(CONTENT_DISPOSITION) {
79            let value = disposition.to_str()?;
80            if let Some(filename) = value.split("filename=").nth(1) {
81                return Ok(filename.trim_matches('"').to_string());
82            }
83        }
84        Err(Box::from("Unable to extract filename".to_owned()))
85        // TODO: guess filename from content type
86    }
87
88    /// Returns the file size in bytes.
89    ///
90    /// If the content-length is is not found or not extracted, it just returns
91    /// Error
92    fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
93        let &cr = &self
94            .get(CONTENT_RANGE)
95            .ok_or_else(|| Box::<dyn Error + Send + Sync>::from("Content_range not found"))?;
96        let content_range =
97            cr.to_str()?.split("/").last().ok_or_else(|| {
98                Box::<dyn Error + Send + Sync>::from("Invalid Content_range_format")
99            })?;
100        Ok(content_range.parse()?)
101    }
102}
103
104/// # Extract file name from Urls
105/// This method is used when we do not have any headers passed for file name
106/// For example: if content disposition is not provided, but there is a valid
107/// filename in the request url
108pub fn extract_filename_from_url(url: &str) -> Option<String> {
109    if let Ok(parsed_url) = Url::parse(url)
110        && let Some(segment) = parsed_url.path_segments().and_then(|mut s| s.next_back())
111        && !segment.is_empty()
112    {
113        return Some(segment.to_string());
114    }
115
116    None
117}
118
119impl Downloader {
120    pub fn new<S: Into<String>>(url: S) -> Self {
121        Self {
122            url: url.into(),
123            headers: HeaderMap::new(),
124            file_size: None,
125            filename: None,
126            chunks: Arc::new(Mutex::new(Vec::new())),
127        }
128    }
129
130    async fn get_chunk(
131        &self,
132        range: Option<(u64, u64)>,
133        progress_bar: Option<ProgressBar>,
134        file: Option<Arc<Mutex<File>>>,
135        chunk_index: Option<usize>,
136    ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
137        let client = reqwest::Client::new();
138        let mut builder = client.get(&self.url);
139        if let Some((start, end)) = range {
140            builder = builder.header(RANGE, &format!("bytes={start}-{end}"));
141        }
142        let response = builder.send().await?;
143        let mut stream = response.bytes_stream();
144        let mut downloaded = 0u64;
145        let mut chunk_data = Vec::new();
146
147        // progress_bar
148        while let Some(chunk) = stream.next().await {
149            let chunk = chunk?;
150            chunk_data.extend_from_slice(&chunk);
151            downloaded += chunk.len() as u64;
152
153            if let Some(bar) = &progress_bar {
154                bar.inc(chunk.len() as u64);
155            }
156
157            // Update chunk progress in shared state
158            if let Some(idx) = chunk_index {
159                let mut chunks = self.chunks.lock().await;
160                if idx < chunks.len() {
161                    chunks[idx].downloaded = downloaded;
162                }
163            }
164        }
165
166        // Write to file at correct position
167        if let (Some(file), Some((start, _))) = (file, range) {
168            let mut f = file.lock().await;
169            f.seek(SeekFrom::Start(start)).await?;
170            f.write_all(&chunk_data).await?;
171        }
172
173        if let Some(bar) = progress_bar {
174            bar.finish();
175        }
176
177        Ok(downloaded)
178    }
179
180    /// downloads the file into the provided path
181    /// # Arguments
182    ///
183    /// * `path`: download path
184    /// * `threads`: number of threads to use for downloading.
185    ///   if you pass None, or Some(0), it will defaults to 8
186    ///
187    /// Note: If the download size is less than 1 MB, then it will completely
188    /// ignore threads, and download it as a single thread.
189    ///
190    /// If the file size is unknown at the moment it gets the header, it will
191    /// also ignore threads and skips the progress bar and just shows a simple
192    /// ticker as a feedback to let user know that the process is not is in a
193    /// deadlock state.
194    ///
195    pub async fn download(
196        &mut self,
197        path: &str,
198        threads: Option<u8>,
199    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
200        let client = reqwest::Client::new();
201        let threads: u64 = match threads {
202            Some(0) => 8,
203            Some(count) => count as u64,
204            None => 8,
205        };
206        // get response headers to get file name, length, etc.
207        let response = client
208            .get(&self.url)
209            .header(RANGE, "bytes=0-0")
210            .send()
211            .await?;
212        self.headers = response.headers().clone().to_owned();
213
214        let filename = match &self.headers.extract_filename() {
215            Ok(filename) => filename.to_owned(),
216            Err(_) => match extract_filename_from_url(&self.url) {
217                Some(filename) => filename,
218                None => "download.bin".to_owned(),
219            },
220        };
221        println!("Downloading \"{filename}\"");
222        self.filename = Some(format!("{path}/{filename}").replace("//", "/"));
223
224        if let Ok(file_size) = self.headers.extract_file_size() {
225            self.file_size = Some(file_size);
226            println!("file size: {}", HumanBytes(file_size));
227        } else {
228            println!("Unable to determine the file size. skipping threads");
229        }
230
231        let file = Arc::new(Mutex::new(
232            File::create(self.filename.as_ref().unwrap()).await?,
233        ));
234
235        // handle chunks with threads
236        if let Some(file_size) = self.file_size {
237            // allocate file's size if size is known
238            // this helps seeking to the position and writing the chunk at that position
239            file.lock().await.set_len(file_size).await?;
240
241            let mut start = 0;
242            let thread_size = file_size / threads;
243            let mut byte_size = thread_size;
244
245            //ignore threads if the file is less than a MB.
246            if file_size < _1MB {
247                println!("ℹ️ The file is smaller than 1 MB, so skipping threads.");
248                byte_size = file_size;
249            }
250
251            // if the byte size is larger than 10 MB, split into 10 MB chunks
252            // so that memory consumption is less.
253            if thread_size > _10MB {
254                byte_size = _10MB
255            }
256
257            // split chunks to download
258            while start < file_size {
259                let end = min(start + byte_size, file_size);
260                self.chunks.lock().await.push(Chunk::new(start, end));
261                start = end + 1;
262            }
263
264            let num_chunks = self.chunks.lock().await.len();
265            println!("Created {} chunks for download", num_chunks);
266
267            let multi_progress = Arc::new(MultiProgress::new());
268
269            // Create tasks for concurrent downloading
270            let mut tasks = Vec::new();
271            let chunks_clone = Arc::clone(&self.chunks);
272
273            // Use a fixed-size worker pool and an atomic index so we don't spawn one task per chunk.
274            // Each worker atomically pulls the next chunk index and processes it, which creates
275            // queue-like behavior while keeping the number of concurrent tasks limited to `threads`.
276            let index = Arc::new(AtomicUsize::new(0));
277
278            // limit number of workers to at most num_chunks
279            let threads_to_spawn = std::cmp::min(threads as usize, num_chunks);
280
281            for _ in 0..threads_to_spawn {
282                let chunks = Arc::clone(&chunks_clone);
283                let file_clone = Arc::clone(&file);
284                let url = self.url.clone();
285                let multi_progress_clone = Arc::clone(&multi_progress);
286                let index_clone = Arc::clone(&index);
287
288                let task = tokio::spawn(async move {
289                    let mut worker_total: u64 = 0;
290                    loop {
291                        // fetch next chunk index
292                        let i = index_clone.fetch_add(1, Ordering::SeqCst);
293                        if i >= num_chunks {
294                            break;
295                        }
296
297                        // Get chunk info
298                        let (start, end) = {
299                            let chunks_guard = chunks.lock().await;
300                            (chunks_guard[i].start_byte, chunks_guard[i].end_byte)
301                        };
302
303                        // Create progress bar for this chunk
304                        let chunk_size = end - start + 1;
305                        let progress_bar = multi_progress_clone.add(ProgressBar::new(chunk_size));
306                        progress_bar.set_style(ProgressStyle::with_template(
307                            &format!(
308                                "[Chunk {:03}] {{wide_bar:40.cyan/blue}} {{binary_bytes}}/{{binary_total_bytes}} ({{percent}}%)",
309                                // chunk index starts from 0, but 1 seems natural for human
310                                i + 1
311                            )
312                        ).unwrap());
313
314                        // Create a downloader instance for this chunk
315                        let downloader = Downloader {
316                            url: url.clone(),
317                            headers: HeaderMap::new(),
318                            file_size: None,
319                            filename: None,
320                            chunks: Arc::clone(&chunks),
321                        };
322
323                        // Download the chunk and accumulate the bytes downloaded by this worker
324                        let downloaded = downloader
325                            .get_chunk(
326                                Some((start, end)),
327                                Some(progress_bar),
328                                Some(Arc::clone(&file_clone)),
329                                Some(i),
330                            )
331                            .await?;
332                        worker_total += downloaded;
333                    }
334
335                    Ok::<u64, Box<dyn std::error::Error + Send + Sync>>(worker_total)
336                });
337
338                tasks.push(task);
339            }
340
341            // Wait for all downloads to complete
342            println!("Starting concurrent downloads...");
343            let results = future::try_join_all(tasks)
344                .await
345                .map_err(|e| format!("Task join error: {}", e))?;
346
347            let total_downloaded: u64 = results
348                .into_iter()
349                .collect::<Result<Vec<_>, _>>()?
350                .into_iter()
351                .sum();
352
353            println!(
354                "Download completed! Total bytes: {}",
355                HumanBytes(total_downloaded)
356            );
357        } else {
358            // continue without threads when the file size is unknown
359            // and just display ticks instead of progressbar since the size is unknown.
360            let file_clone = Arc::clone(&file);
361            let bar = ProgressBar::new_spinner();
362            bar.enable_steady_tick(Duration::from_millis(100));
363            println!("\n");
364            bar.set_style(
365                ProgressStyle::with_template(&format!(
366                    "{{spinner:.cyan}} {:?} ({{binary_bytes}} downloaded)",
367                    self.filename.as_ref().unwrap()
368                ))
369                .unwrap()
370                .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
371            );
372            let _ = self
373                .get_chunk(None, Some(bar), Some(file_clone), None)
374                .await;
375        }
376
377        Ok(())
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::HeaderMap;
    use tokio::runtime::Runtime;

    /// The last path segment is the file name; a trailing slash yields none.
    #[test]
    fn test_extract_filename_from_url() {
        let with_file = extract_filename_from_url("https://example.com/path/to/file.txt");
        assert_eq!(with_file, Some("file.txt".to_string()));

        let with_trailing_slash = extract_filename_from_url("https://example.com/path/to/");
        assert_eq!(with_trailing_slash, None);
    }

    /// A quoted `filename=` parameter is returned without its quotes.
    #[test]
    fn test_header_extract_filename() {
        let mut header_map = HeaderMap::new();
        let disposition = "attachment; filename=\"myfile.bin\"".parse().unwrap();
        header_map.insert(reqwest::header::CONTENT_DISPOSITION, disposition);

        assert_eq!(header_map.extract_filename().unwrap(), "myfile.bin");
    }

    /// The total after the `/` in `Content-Range` is the file size.
    #[test]
    fn test_header_extract_file_size() {
        let mut header_map = HeaderMap::new();
        let content_range = "bytes 0-0/12345".parse().unwrap();
        header_map.insert(reqwest::header::CONTENT_RANGE, content_range);

        assert_eq!(header_map.extract_file_size().unwrap(), 12345u64);
    }

    /// A fresh downloader knows its URL and nothing else.
    #[test]
    fn test_downloader_new_and_defaults() {
        let fresh = Downloader::new("https://example.com/file");
        assert_eq!(fresh.url, "https://example.com/file");
        assert!(fresh.filename.is_none());
        assert!(fresh.file_size.is_none());
    }

    // Placeholder async test for download-related behavior; does not perform network IO.
    #[test]
    fn test_download_placeholder() {
        // Create a runtime to run async parts if needed.
        let runtime = Runtime::new().unwrap();
        runtime.block_on(async {
            // Populate the fields directly so no real network operation is needed.
            let mut dl = Downloader::new("https://example.com/file");
            dl.headers = HeaderMap::new();
            dl.filename = Some("tmp_download.bin".to_string());
            dl.file_size = Some(0);

            // The values read back are the ones just stored.
            assert_eq!(dl.filename.as_deref(), Some("tmp_download.bin"));
            assert_eq!(dl.file_size, Some(0));
        });
    }
}