furl_core/
engine.rs

1use std::{cmp::min, error::Error, sync::Arc, time::Duration};
2
3use reqwest::{
4    self, Url,
5    header::{CONTENT_DISPOSITION, CONTENT_RANGE, HeaderMap, RANGE},
6};
7// use tokio::stream;
8use futures_util::{StreamExt, future};
9use indicatif::{HumanBytes, MultiProgress, ProgressBar, ProgressStyle};
10use tokio::fs::File;
11use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
12use tokio::sync::Mutex;
13
/// Byte range of one download segment together with its progress counter.
#[derive(Debug)]
struct Chunk {
    /// Offset of the first byte of the range.
    start_byte: u64,
    /// Offset used as the upper bound of the HTTP `Range` header.
    end_byte: u64,
    /// Number of bytes received so far for this range.
    downloaded: u64,
}

impl Chunk {
    /// Builds a chunk for the given byte offsets with no progress recorded yet.
    fn new(start_byte: u64, end_byte: u64) -> Self {
        let downloaded = 0;
        Chunk {
            start_byte,
            end_byte,
            downloaded,
        }
    }
}
30
31#[derive(Debug)]
32pub struct Downloader {
33    url: String,
34    headers: HeaderMap,
35    file_size: Option<u64>,
36    filename: Option<String>,
37    chunks: Arc<Mutex<Vec<Chunk>>>, // this stores downloaded chunk size
38}
39
/// Helpers for reading download metadata out of HTTP response headers.
pub trait HeaderUtils {
    /// Returns the file name announced by the response headers.
    ///
    /// Headers such as `Content-Disposition` may carry the name (or type) of
    /// the file being served; when they do, it is extracted here. The name
    /// can alternatively be guessed from the URL or the content type.
    fn extract_filename(&self) -> Result<String, Box<dyn Error + Send + Sync>>;

    /// Returns the full size of the resource in bytes.
    ///
    /// A `Content-Range` response header states the total size after the
    /// slash, for example:
    ///
    /// `Content-Range: bytes 0-0/360996864` yields `360996864`.
    ///
    /// Knowing the size up front is what makes it possible to download the
    /// data in several partial requests at once.
    fn extract_file_size(&self) -> Result<u64, Box<dyn Error + Send + Sync>>;
}
62
63/// Since HeaderMap is imported from Reqwest, we need to define trait and then
64/// implement it to the HeaderMap struct.
65impl HeaderUtils for HeaderMap {
66    fn extract_filename(&self) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
67        if let Some(disposition) = &self.get(CONTENT_DISPOSITION) {
68            let value = disposition.to_str()?;
69            if let Some(filename) = value.split("filename=").nth(1) {
70                return Ok(filename.trim_matches('"').to_string());
71            }
72        }
73        return Err(Box::from("Unable to extract filename".to_owned()));
74        // TODO: guess filename from content type
75    }
76
77    /// Returns the file size in bytes.
78    ///
79    /// If the content-length is is not found or not extracted, it just returns
80    /// Error
81    fn extract_file_size(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
82        let &cr = &self
83            .get(CONTENT_RANGE)
84            .ok_or_else(|| Box::<dyn Error + Send + Sync>::from("Content_range not found"))?;
85        let content_range =
86            cr.to_str()?.split("/").into_iter().last().ok_or_else(|| {
87                Box::<dyn Error + Send + Sync>::from("Invalid Content_range_format")
88            })?;
89        Ok(content_range.parse()?)
90    }
91}
92
93/// # Extract file name from Urls
94/// This method is used when we do not have any headers passed for file name
95/// For example: if content disposition is not provided, but there is a valid
96/// filename in the request url
97pub fn extract_filename_from_url(url: &str) -> Option<String> {
98    if let Ok(parsed_url) = Url::parse(&url) {
99        if let Some(segment) = parsed_url.path_segments().and_then(|s| s.last()) {
100            if !segment.is_empty() {
101                return Some(segment.to_string());
102            }
103        }
104    }
105    return None;
106}
107
108impl Downloader {
109    pub fn new<S: Into<String>>(url: S) -> Self {
110        Self {
111            url: url.into(),
112            headers: HeaderMap::new(),
113            file_size: None,
114            filename: None,
115            chunks: Arc::new(Mutex::new(Vec::new())),
116        }
117    }
118
119    async fn get_chunk(
120        &self,
121        range: Option<(u64, u64)>,
122        progress_bar: Option<ProgressBar>,
123        file: Option<Arc<Mutex<File>>>,
124        chunk_index: Option<usize>,
125    ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
126        let client = reqwest::Client::new();
127        let mut builder = client.get(&self.url);
128        if let Some((start, end)) = range {
129            builder = builder.header(RANGE, &format!("bytes={start}-{end}"));
130        }
131        let response = builder.send().await?;
132        let mut stream = response.bytes_stream();
133        let mut downloaded = 0u64;
134        let mut chunk_data = Vec::new();
135
136        // progress_bar
137        while let Some(chunk) = stream.next().await {
138            let chunk = chunk?;
139            chunk_data.extend_from_slice(&chunk);
140            downloaded += chunk.len() as u64;
141
142            if let Some(bar) = &progress_bar {
143                bar.inc(chunk.len() as u64);
144            }
145
146            // Update chunk progress in shared state
147            if let Some(idx) = chunk_index {
148                let mut chunks = self.chunks.lock().await;
149                if idx < chunks.len() {
150                    chunks[idx].downloaded = downloaded;
151                }
152            }
153        }
154
155        // Write to file at correct position
156        if let (Some(file), Some((start, _))) = (file, range) {
157            let mut f = file.lock().await;
158            f.seek(SeekFrom::Start(start)).await?;
159            f.write_all(&chunk_data).await?;
160        }
161
162        if let Some(bar) = progress_bar {
163            bar.finish();
164        }
165
166        Ok(downloaded)
167    }
168
169    /// downloads the file into the provided path
170    /// # Arguments
171    ///
172    /// * `path`: download path
173    /// * `threads`: number of threads to use for downloading.
174    ///   if you pass None, or Some(0), it will defaults to 8
175    ///
176    /// Note: If the download size is less than 1 MB, then it will completely
177    /// ignore threads, and download it as a single thread.
178    ///
179    /// If the file size is unknown at the moment it gets the header, it will
180    /// also ignore threads and skips the progress bar and just shows a simple
181    /// ticker as a feedback to let user know that the process is not is in a
182    /// deadlock state.
183    ///
184    pub async fn download(
185        &mut self,
186        path: &str,
187        threads: Option<u8>,
188    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
189        let client = reqwest::Client::new();
190        let threads: u64 = match threads {
191            Some(0) => 8,
192            Some(count) => count as u64,
193            None => 8,
194        };
195        // get response headers to get file name, length, etc.
196        let response = client
197            .get(&self.url)
198            .header(RANGE, "bytes=0-0")
199            .send()
200            .await?;
201        self.headers = response.headers().clone().to_owned();
202
203        let filename = match &self.headers.extract_filename() {
204            Ok(filename) => filename.to_owned(),
205            Err(_) => match extract_filename_from_url(&self.url) {
206                Some(filename) => filename,
207                None => "download.bin".to_owned(),
208            },
209        };
210        println!("Downloading \"{filename}\"");
211        self.filename = Some(format!("{path}/{filename}").replace("//", "/"));
212
213        if let Ok(file_size) = self.headers.extract_file_size() {
214            self.file_size = Some(file_size);
215            println!("file size: {}", HumanBytes(file_size));
216        } else {
217            println!("Unable to determine the file size. skipping threads");
218        }
219
220        let file = Arc::new(Mutex::new(
221            File::create(self.filename.as_ref().unwrap()).await?,
222        ));
223
224        // handle chunks with threads
225        if let Some(file_size) = self.file_size {
226            // allocate file's size if size is known
227            // this helps seeking to the position and writing the chunk at that position
228            file.lock().await.set_len(file_size).await?;
229
230            let mut start = 0;
231            let mut byte_size = file_size / threads;
232
233            //ignore threads if the file is less than a MB.
234            if file_size < 1024 * 1024 {
235                println!("ℹ️ The file is smaller than 1 MB, so skipping threads.");
236                byte_size = file_size;
237            }
238
239            // split chunks to download
240            while start < file_size {
241                let end = min(start + byte_size, file_size);
242                self.chunks.lock().await.push(Chunk::new(start, end));
243                start = end + 1;
244            }
245
246            let num_chunks = self.chunks.lock().await.len();
247            println!("Created {} chunks for download", num_chunks);
248
249            let multi_progress = Arc::new(MultiProgress::new());
250
251            // Create tasks for concurrent downloading
252            let mut tasks = Vec::new();
253            let chunks_clone = Arc::clone(&self.chunks);
254
255            for i in 0..num_chunks {
256                let chunks = Arc::clone(&chunks_clone);
257                let file_clone = Arc::clone(&file);
258                let url = self.url.clone();
259                let multi_progress_clone = Arc::clone(&multi_progress);
260
261                let task = tokio::spawn(async move {
262                    // Get chunk info
263                    let (start, end) = {
264                        let chunks_guard = chunks.lock().await;
265                        if i >= chunks_guard.len() {
266                            return Err("Chunk index out of bounds".into());
267                        }
268                        (chunks_guard[i].start_byte, chunks_guard[i].end_byte)
269                    };
270
271                    // Create progress bar for this chunk
272                    let chunk_size = end - start + 1;
273                    let progress_bar = multi_progress_clone.add(ProgressBar::new(chunk_size));
274                    progress_bar.set_style(ProgressStyle::with_template(
275                        &format!(
276                            "[Chunk {:03}] {{wide_bar:40.cyan/blue}} {{binary_bytes}}/{{binary_total_bytes}} ({{percent}}%)",
277                            // chunk index starts from 0, but 1 seems natural for human
278                            i+1
279                        )
280                    ).unwrap());
281
282                    // Create a downloader instance for this chunk
283                    let downloader = Downloader {
284                        url,
285                        headers: HeaderMap::new(),
286                        file_size: None,
287                        filename: None,
288                        chunks: chunks,
289                    };
290
291                    // Download the chunk
292                    downloader
293                        .get_chunk(
294                            Some((start, end)),
295                            Some(progress_bar),
296                            Some(file_clone),
297                            Some(i),
298                        )
299                        .await
300                });
301
302                tasks.push(task);
303            }
304
305            // Wait for all downloads to complete
306            println!("Starting concurrent downloads...");
307            let results = future::try_join_all(tasks)
308                .await
309                .map_err(|e| format!("Task join error: {}", e))?;
310
311            let total_downloaded: u64 = results
312                .into_iter()
313                .collect::<Result<Vec<_>, _>>()?
314                .into_iter()
315                .sum();
316
317            println!(
318                "Download completed! Total bytes: {}",
319                HumanBytes(total_downloaded)
320            );
321        } else {
322            // continue without threads when the file size is unknown
323            // and just display ticks instead of progressbar since the size is unknown.
324            let file_clone = Arc::clone(&file);
325            let bar = ProgressBar::new_spinner();
326            bar.enable_steady_tick(Duration::from_millis(100));
327            println!("");
328            bar.set_style(
329                ProgressStyle::with_template(&format!(
330                    "{{spinner:.cyan}} {:?} ({{binary_bytes}} downloaded)",
331                    self.filename.as_ref().unwrap()
332                ))
333                .unwrap()
334                .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
335            );
336            let _ = self
337                .get_chunk(None, Some(bar), Some(file_clone), None)
338                .await;
339        }
340
341        Ok(())
342    }
343}