// kget/advanced_download.rs

use std::error::Error;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use hex;
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use reqwest::blocking::Client;
use reqwest::StatusCode;
use sha2::{Digest, Sha256};

use crate::config::ProxyConfig;
use crate::optimization::Optimizer;
use crate::utils::print;

#[cfg(target_family = "unix")]
use std::os::unix::fs::FileExt;
#[cfg(target_family = "windows")]
use std::os::windows::fs::FileExt;

/// Lower bound on the size of one parallel chunk (4 MiB): a transfer is never
/// split into pieces smaller than this, except for the final remainder.
const MIN_CHUNK_SIZE: u64 = 4 << 20;
/// Number of extra attempts a chunk gets after its first request fails.
const MAX_RETRIES: usize = 3;

24pub struct AdvancedDownloader {
25    client: Client,
26    url: String,
27    output_path: String,
28    quiet_mode: bool,
29    #[allow(dead_code)]
30    proxy: ProxyConfig,
31    optimizer: Optimizer,
32    progress_callback: Option<Arc<dyn Fn(f32) + Send + Sync>>,
33    status_callback: Option<Arc<dyn Fn(String) + Send + Sync>>,
34}
35
36impl AdvancedDownloader {
37    pub fn new(url: String, output_path: String, quiet_mode: bool, proxy_config: ProxyConfig, optimizer: Optimizer) -> Self {
38        let is_iso = url.to_lowercase().ends_with(".iso");
39        
40        let mut client_builder = Client::builder()
41            .timeout(std::time::Duration::from_secs(300))
42            .connect_timeout(std::time::Duration::from_secs(20))
43            .user_agent("KGet/1.0")
44            .no_gzip() 
45            .no_deflate();
46
47        if proxy_config.enabled {
48            if let Some(proxy_url) = &proxy_config.url {
49                let proxy = match proxy_config.proxy_type {
50                    crate::config::ProxyType::Http => reqwest::Proxy::http(proxy_url),
51                    crate::config::ProxyType::Https => reqwest::Proxy::https(proxy_url),
52                    crate::config::ProxyType::Socks5 => reqwest::Proxy::all(proxy_url),
53                };
54                
55                if let Ok(mut proxy) = proxy {
56                    if let (Some(username), Some(password)) = (&proxy_config.username, &proxy_config.password) {
57                        proxy = proxy.basic_auth(username, password);
58                    }
59                    client_builder = client_builder.proxy(proxy);
60                }
61            }
62        }
63        
64        let client = client_builder.build()
65            .expect("Failed to create HTTP client");
66        
67        Self {
68            client,
69            url,
70            output_path,
71            quiet_mode,
72            proxy: proxy_config,
73            optimizer,
74            progress_callback: None,
75            status_callback: None,
76        }
77    }
78
79    pub fn set_progress_callback(&mut self, callback: impl Fn(f32) + Send + Sync + 'static) {
80        self.progress_callback = Some(Arc::new(callback));
81    }
82
83    pub fn set_status_callback(&mut self, callback: impl Fn(String) + Send + Sync + 'static) {
84        self.status_callback = Some(Arc::new(callback));
85    }
86
87    fn send_status(&self, msg: &str) {
88        if let Some(cb) = &self.status_callback {
89            cb(msg.to_string());
90        }
91        if !self.quiet_mode {
92            println!("{}", msg);
93        }
94    }
95
96
97    pub fn download(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
98        let is_iso = self.url.to_lowercase().ends_with(".iso");
99        if !self.quiet_mode {
100            println!("Starting advanced download for: {}", self.url);
101            if is_iso {
102                println!("Warning: ISO mode active. Disabling optimizations that could corrupt binary data.");
103            }
104        }
105
106        // Verify if the output path is valid
107        let existing_size = if Path::new(&self.output_path).exists() {
108            let size = std::fs::metadata(&self.output_path)?.len();
109            if !self.quiet_mode {
110                println!("Existing file found with size: {} bytes", size);
111            }
112            Some(size)
113        } else {
114            if !self.quiet_mode {
115                println!("Output file does not exist, starting fresh download");
116            }
117            None
118        };
119
120        // Get the total file size and range support
121        if !self.quiet_mode {
122            println!("Querying server for file size and range support...");
123        }
124        let (total_size, supports_range) = self.get_file_size_and_range()?;
125        if !self.quiet_mode {
126            println!("Total file size: {} bytes", total_size);
127            println!("Server supports range requests: {}", supports_range);
128        }
129
130        if let Some(size) = existing_size {
131            if size > total_size {
132                return Err("Existing file is larger than remote; aborting".into());
133            }
134            if !self.quiet_mode {
135                println!("Resuming download from byte: {}", size);
136            }
137        }
138
139        // Create a progress bar if not quiet or if we have a callback
140        let progress = if !self.quiet_mode || self.progress_callback.is_some() {
141            let bar = ProgressBar::new(total_size);
142            if self.quiet_mode {
143                bar.set_draw_target(indicatif::ProgressDrawTarget::hidden());
144            } else {
145                bar.set_style(ProgressStyle::with_template(
146                    "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})"
147                ).unwrap().progress_chars("#>-"));
148            }
149            Some(Arc::new(Mutex::new(bar)))
150        } else {
151            None
152        };
153
154        // Create or open the output file and preallocate
155        if !self.quiet_mode {
156            println!("Preparing output file: {}", self.output_path);
157        }
158        let file = if existing_size.is_some() {
159            File::options().read(true).write(true).open(&self.output_path)?
160        } else {
161            File::create(&self.output_path)?
162        };
163        file.set_len(total_size)?;
164        if !self.quiet_mode {
165            println!("File preallocated to {} bytes", total_size);
166        }
167
168        // If range not supported, do a single download
169        if !supports_range {
170            if !self.quiet_mode {
171                println!("Range requests not supported, falling back to single-threaded download");
172            }
173            self.download_whole(&file, existing_size.unwrap_or(0), progress.clone())?;
174            if let Some(ref bar) = progress {
175                bar.lock().unwrap().finish_with_message("Download completed");
176            }
177            if !self.quiet_mode {
178                println!("Single-threaded download completed");
179            }
180            return Ok(());
181        }
182
183        // Calculate chunks for parallel download
184        if !self.quiet_mode {
185            println!("Calculating download chunks...");
186        }
187        let chunks = self.calculate_chunks(total_size, existing_size)?;
188        if !self.quiet_mode {
189            println!("Download will be split into {} chunks", chunks.len());
190        }
191
192        // Download parallel chunks
193        if !self.quiet_mode {
194            println!("Starting parallel chunk downloads...");
195        }
196        self.download_chunks_parallel(chunks, &file, progress.clone())?;
197
198        if let Some(ref bar) = progress {
199            bar.lock().unwrap().finish_with_message("Download completed");
200        }
201
202        // Verify download integrity
203        if !self.quiet_mode || self.status_callback.is_some() {
204            if is_iso {
205                
206                let should_verify = if self.status_callback.is_some() {
207                    true 
208                } else {
209                    println!("\nThis is an ISO file. Would you like to verify its integrity? (y/N)");
210                    let mut input = String::new();
211                    std::io::stdin().read_line(&mut input).is_ok() && input.trim().to_lowercase() == "y"
212                };
213
214                if should_verify {
215                    self.verify_integrity(total_size)?;
216                }
217            } else {
218                let metadata = std::fs::metadata(&self.output_path)?;
219                if metadata.len() != total_size {
220                    return Err(format!("File size mismatch: expected {} bytes, got {} bytes", total_size, metadata.len()).into());
221                }
222            }
223            self.send_status("Advanced download completed successfully!");
224        }
225
226        Ok(())
227    }
228
229    fn get_file_size_and_range(&self) -> Result<(u64, bool), Box<dyn Error + Send + Sync>> {
230        let response = self.client.head(&self.url).send()?;
231        let content_length = response.headers()
232            .get(reqwest::header::CONTENT_LENGTH)
233            .and_then(|v| v.to_str().ok())
234            .and_then(|s| s.parse::<u64>().ok())
235            .ok_or("Could not determine file size")?;
236
237        let accepts_range = response.headers()
238            .get(reqwest::header::ACCEPT_RANGES)
239            .and_then(|v| v.to_str().ok())
240            .map(|s| s.eq_ignore_ascii_case("bytes"))
241            .unwrap_or(false);
242
243        Ok((content_length, accepts_range))
244    }
245
246    fn calculate_chunks(&self, total_size: u64, existing_size: Option<u64>) -> Result<Vec<(u64, u64)>, Box<dyn Error + Send + Sync>> {
247        let mut chunks = Vec::new();
248        let start_from = existing_size.unwrap_or(0);
249
250        
251        let parallelism = rayon::current_num_threads() as u64;
252        let target_chunks = parallelism.saturating_mul(2).max(2); // Reduced to avoid overwhelming servers
253        let chunk_size = ((total_size / target_chunks).max(MIN_CHUNK_SIZE)).min(64 * 1024 * 1024);
254
255        let mut start = start_from;
256        while start < total_size {
257            let end = (start + chunk_size).min(total_size);
258            chunks.push((start, end));
259            start = end;
260        }
261
262        Ok(chunks)
263    }
264
265    fn download_whole(&self, file: &File, offset: u64, progress: Option<Arc<Mutex<ProgressBar>>>) -> Result<(), Box<dyn Error + Send + Sync>> {
266        let mut response = self.client.get(&self.url).send()?;
267        if offset > 0 {
268            // Resume not possible without range; warn
269            return Err("Server does not support range; cannot resume partial file".into());
270        }
271
272        let mut reader = BufReader::new(response);
273        let mut f = file.try_clone()?;
274        f.seek(SeekFrom::Start(0))?;
275
276        struct ProgressWriter<'a, W> {
277            inner: W,
278            progress: Option<Arc<Mutex<ProgressBar>>>,
279            callback: Option<&'a Arc<dyn Fn(f32) + Send + Sync>>,
280        }
281
282        impl<'a, W: Write> Write for ProgressWriter<'a, W> {
283            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
284                let n = self.inner.write(buf)?;
285                if let Some(ref bar) = self.progress {
286                    let mut guard = bar.lock().unwrap();
287                    guard.inc(n as u64);
288                    if let Some(cb) = self.callback {
289                        let pos = guard.position();
290                        let len = guard.length().unwrap_or(1);
291                        drop(guard);
292                        (cb)(pos as f32 / len as f32);
293                    }
294                }
295                Ok(n)
296            }
297
298            fn flush(&mut self) -> std::io::Result<()> {
299                self.inner.flush()
300            }
301        }
302
303        let mut writer = ProgressWriter { 
304            inner: f, 
305            progress,
306            callback: self.progress_callback.as_ref(),
307        };
308        std::io::copy(&mut reader, &mut writer)?;
309
310        Ok(())
311    }
312
313    fn download_chunks_parallel(&self, chunks: Vec<(u64, u64)>, file: &File, progress: Option<Arc<Mutex<ProgressBar>>>) -> Result<(), Box<dyn Error + Send + Sync>> {
314        let file = Arc::new(file);
315        let client = Arc::new(self.client.clone());
316        let url = Arc::new(self.url.clone());
317        let _optimizer = Arc::new(self.optimizer.clone());
318        let progress_callback = self.progress_callback.clone();
319
320        chunks.par_iter().try_for_each(|&(start, end)| {
321            let range = format!("bytes={}-{}", start, end - 1);
322            let range_header = reqwest::header::HeaderValue::from_str(&range)
323                .map_err(|e| format!("Invalid range header {}: {}", range, e))?;
324
325            for retry in 0..=MAX_RETRIES {
326                let request = client.get(url.as_str());
327                let request = request.header(reqwest::header::RANGE, range_header.clone());
328
329                match request.send() {
330                    Ok(mut response) => {
331                        let status = response.status();
332                        if status.is_success() {
333                            // Use FileExt to write at specific offset without seeking shared cursor
334                            // This prevents race conditions when multiple threads write to the same file
335                            
336                            let mut current_pos = start;
337                            let mut buffer = [0u8; 16384]; 
338                            
339                            while current_pos < end {
340                                let limit = (end - current_pos).min(buffer.len() as u64);
341                                let n = response.read(&mut buffer[..limit as usize])?;
342                                if n == 0 { break; }
343                                
344                                #[cfg(target_family = "unix")]
345                                file.write_at(&buffer[..n], current_pos)?;
346                                
347                                #[cfg(target_family = "windows")]
348                                file.seek_write(&buffer[..n], current_pos)?;
349                                
350                                current_pos += n as u64;
351                            
352                                if let Some(ref bar) = progress {
353                                    let mut guard = bar.lock().unwrap();
354                                    guard.inc(n as u64);
355                                    if let Some(ref cb) = progress_callback {
356                                        let pos = guard.position();
357                                        let len = guard.length().unwrap_or(1);
358                                        drop(guard);
359                                        (cb)(pos as f32 / len as f32);
360                                    }
361                                }
362                            }
363
364                            return Ok::<(), Box<dyn Error + Send + Sync>>(());
365                        } else if status.as_u16() == 416 {
366                            if retry == MAX_RETRIES {
367                                return Err(format!("Failed to download chunk {}-{}: HTTP {}", start, end, status).into());
368                            }
369                            std::thread::sleep(Duration::from_millis(250 * (retry as u64 + 1)));
370                        }
371                    }
372                    Err(e) => {
373                        if retry == MAX_RETRIES {
374                            return Err(format!("Failed to download chunk {}-{}: {}", start, end, e).into());
375                        }
376                        std::thread::sleep(Duration::from_millis(250 * (retry as u64 + 1)));
377                    }
378                }
379            }
380            Err(format!("Failed to download chunk {}-{} after retries", start, end).into())
381        })?;
382
383        Ok(())
384    }
385
386        fn verify_integrity(&self, expected_size: u64) -> Result<(), Box<dyn Error + Send + Sync>> {
387            let metadata = std::fs::metadata(&self.output_path)?;
388        let actual_size = metadata.len();
389        
390        if actual_size != expected_size {
391            return Err(format!("File size mismatch: expected {} bytes, got {} bytes", expected_size, actual_size).into());
392        }
393        
394        self.send_status(&format!("File size verified: {} bytes", actual_size));
395
396        // Calculate SHA256 hash for corruption check
397        self.send_status("Calculating SHA256 hash...");
398        
399        let mut file = File::open(&self.output_path)?;
400        let mut hasher = Sha256::new();
401        let mut buffer = [0; 8192];
402        loop {
403            let n = file.read(&mut buffer)?;
404            if n == 0 {
405                break;
406            }
407            hasher.update(&buffer[..n]);
408        }
409        let hash = hasher.finalize();
410        let hash_hex = hex::encode(hash);
411        
412        self.send_status(&format!("SHA256 hash: {}", hash_hex));
413        self.send_status("Integrity check passed - file is not corrupted");
414
415        Ok(())
416    }
417}