Skip to main content

kget/
advanced_download.rs

1use std::error::Error;
2use std::fs::File;
3use std::io::{BufReader, Read, Seek, SeekFrom, Write};
4use std::path::Path;
5use std::sync::{Arc, Mutex};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use rayon::prelude::*;
9use reqwest::blocking::Client;
10use indicatif::{ProgressBar, ProgressStyle};
11use crate::config::ProxyConfig;
12use sha2::{Sha256, Digest};
13use hex;
14use crate::optimization::Optimizer;
15
16#[cfg(target_family = "unix")]
17use std::os::unix::fs::FileExt;
18#[cfg(target_family = "windows")]
19use std::os::windows::fs::FileExt;
20
/// Lower bound for the size of one parallel range request: 4 MiB.
const MIN_CHUNK_SIZE: u64 = 4 << 20;
/// How many times a failed chunk request is retried after its first attempt.
const MAX_RETRIES: usize = 3;
23
24pub struct AdvancedDownloader {
25    client: Client,
26    url: String,
27    output_path: String,
28    quiet_mode: bool,
29    #[allow(dead_code)]
30    proxy: ProxyConfig,
31    optimizer: Optimizer,
32    progress_callback: Option<Arc<dyn Fn(f32) + Send + Sync>>,
33    status_callback: Option<Arc<dyn Fn(String) + Send + Sync>>,
34    cancel_token: Arc<AtomicBool>,
35}
36
37impl AdvancedDownloader {
38    pub fn new(url: String, output_path: String, quiet_mode: bool, proxy_config: ProxyConfig, optimizer: Optimizer) -> Self {
39        let _is_iso = url.to_lowercase().ends_with(".iso");
40        
41        let mut client_builder = Client::builder()
42            .timeout(std::time::Duration::from_secs(300))
43            .connect_timeout(std::time::Duration::from_secs(20))
44            .user_agent("KGet/1.0")
45            .no_gzip() 
46            .no_deflate();
47
48        if proxy_config.enabled {
49            if let Some(proxy_url) = &proxy_config.url {
50                let proxy = match proxy_config.proxy_type {
51                    crate::config::ProxyType::Http => reqwest::Proxy::http(proxy_url),
52                    crate::config::ProxyType::Https => reqwest::Proxy::https(proxy_url),
53                    crate::config::ProxyType::Socks5 => reqwest::Proxy::all(proxy_url),
54                };
55                
56                if let Ok(mut proxy) = proxy {
57                    if let (Some(username), Some(password)) = (&proxy_config.username, &proxy_config.password) {
58                        proxy = proxy.basic_auth(username, password);
59                    }
60                    client_builder = client_builder.proxy(proxy);
61                }
62            }
63        }
64        
65        let client = client_builder.build()
66            .expect("Failed to create HTTP client");
67        
68        Self {
69            client,
70            url,
71            output_path,
72            quiet_mode,
73            proxy: proxy_config,
74            optimizer,
75            progress_callback: None,
76            status_callback: None,
77            cancel_token: Arc::new(AtomicBool::new(false)),
78        }
79    }
80
81    pub fn set_cancel_token(&mut self, token: Arc<AtomicBool>) {
82        self.cancel_token = token;
83    }
84
85    pub fn is_cancelled(&self) -> bool {
86        self.cancel_token.load(Ordering::Relaxed)
87    }
88
89    pub fn set_progress_callback(&mut self, callback: impl Fn(f32) + Send + Sync + 'static) {
90        self.progress_callback = Some(Arc::new(callback));
91    }
92
93    pub fn set_status_callback(&mut self, callback: impl Fn(String) + Send + Sync + 'static) {
94        self.status_callback = Some(Arc::new(callback));
95    }
96
97    fn send_status(&self, msg: &str) {
98        if let Some(cb) = &self.status_callback {
99            cb(msg.to_string());
100        }
101        if !self.quiet_mode {
102            println!("{}", msg);
103        }
104    }
105
106
107    pub fn download(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
108        let is_iso = self.url.to_lowercase().ends_with(".iso");
109        if !self.quiet_mode {
110            println!("Starting advanced download for: {}", self.url);
111            if is_iso {
112                println!("Warning: ISO mode active. Disabling optimizations that could corrupt binary data.");
113            }
114        }
115
116        // Verify if the output path is valid
117        let existing_size = if Path::new(&self.output_path).exists() {
118            let size = std::fs::metadata(&self.output_path)?.len();
119            if !self.quiet_mode {
120                println!("Existing file found with size: {} bytes", size);
121            }
122            Some(size)
123        } else {
124            if !self.quiet_mode {
125                println!("Output file does not exist, starting fresh download");
126            }
127            None
128        };
129
130        // Get the total file size and range support
131        if !self.quiet_mode {
132            println!("Querying server for file size and range support...");
133        }
134        let (total_size, supports_range) = self.get_file_size_and_range()?;
135        if !self.quiet_mode {
136            println!("Total file size: {} bytes", total_size);
137            println!("Server supports range requests: {}", supports_range);
138        }
139
140        if let Some(size) = existing_size {
141            if size > total_size {
142                return Err("Existing file is larger than remote; aborting".into());
143            }
144            if !self.quiet_mode {
145                println!("Resuming download from byte: {}", size);
146            }
147        }
148
149        // Create a progress bar if not quiet or if we have a callback
150        let progress = if !self.quiet_mode || self.progress_callback.is_some() {
151            let bar = ProgressBar::new(total_size);
152            if self.quiet_mode {
153                bar.set_draw_target(indicatif::ProgressDrawTarget::hidden());
154            } else {
155                bar.set_style(ProgressStyle::with_template(
156                    "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})"
157                ).unwrap().progress_chars("#>-"));
158            }
159            Some(Arc::new(Mutex::new(bar)))
160        } else {
161            None
162        };
163
164        // Create or open the output file and preallocate
165        if !self.quiet_mode {
166            println!("Preparing output file: {}", self.output_path);
167        }
168        let file = if existing_size.is_some() {
169            File::options().read(true).write(true).open(&self.output_path)?
170        } else {
171            File::create(&self.output_path)?
172        };
173        file.set_len(total_size)?;
174        if !self.quiet_mode {
175            println!("File preallocated to {} bytes", total_size);
176        }
177
178        // If range not supported, do a single download
179        if !supports_range {
180            if !self.quiet_mode {
181                println!("Range requests not supported, falling back to single-threaded download");
182            }
183            self.download_whole(&file, existing_size.unwrap_or(0), progress.clone())?;
184            if let Some(ref bar) = progress {
185                bar.lock().unwrap().finish_with_message("Download completed");
186            }
187            if !self.quiet_mode {
188                println!("Single-threaded download completed");
189            }
190            return Ok(());
191        }
192
193        // Calculate chunks for parallel download
194        if !self.quiet_mode {
195            println!("Calculating download chunks...");
196        }
197        let chunks = self.calculate_chunks(total_size, existing_size)?;
198        if !self.quiet_mode {
199            println!("Download will be split into {} chunks", chunks.len());
200        }
201
202        // Download parallel chunks
203        if !self.quiet_mode {
204            println!("Starting parallel chunk downloads...");
205        }
206        self.download_chunks_parallel(chunks, &file, progress.clone())?;
207
208        if let Some(ref bar) = progress {
209            bar.lock().unwrap().finish_with_message("Download completed");
210        }
211
212        // Verify download integrity
213        if !self.quiet_mode || self.status_callback.is_some() {
214            if is_iso {
215                
216                let should_verify = if self.status_callback.is_some() {
217                    true 
218                } else {
219                    println!("\nThis is an ISO file. Would you like to verify its integrity? (y/N)");
220                    let mut input = String::new();
221                    std::io::stdin().read_line(&mut input).is_ok() && input.trim().to_lowercase() == "y"
222                };
223
224                if should_verify {
225                    self.verify_integrity(total_size)?;
226                }
227            } else {
228                let metadata = std::fs::metadata(&self.output_path)?;
229                if metadata.len() != total_size {
230                    return Err(format!("File size mismatch: expected {} bytes, got {} bytes", total_size, metadata.len()).into());
231                }
232            }
233            self.send_status("Advanced download completed successfully!");
234        }
235
236        Ok(())
237    }
238
239    fn get_file_size_and_range(&self) -> Result<(u64, bool), Box<dyn Error + Send + Sync>> {
240        let response = self.client.head(&self.url).send()?;
241        let content_length = response.headers()
242            .get(reqwest::header::CONTENT_LENGTH)
243            .and_then(|v| v.to_str().ok())
244            .and_then(|s| s.parse::<u64>().ok())
245            .ok_or("Could not determine file size")?;
246
247        let accepts_range = response.headers()
248            .get(reqwest::header::ACCEPT_RANGES)
249            .and_then(|v| v.to_str().ok())
250            .map(|s| s.eq_ignore_ascii_case("bytes"))
251            .unwrap_or(false);
252
253        Ok((content_length, accepts_range))
254    }
255
256    fn calculate_chunks(&self, total_size: u64, existing_size: Option<u64>) -> Result<Vec<(u64, u64)>, Box<dyn Error + Send + Sync>> {
257        let mut chunks = Vec::new();
258        let start_from = existing_size.unwrap_or(0);
259
260        
261        let parallelism = rayon::current_num_threads() as u64;
262        let target_chunks = parallelism.saturating_mul(2).max(2); // Reduced to avoid overwhelming servers
263        let chunk_size = ((total_size / target_chunks).max(MIN_CHUNK_SIZE)).min(64 * 1024 * 1024);
264
265        let mut start = start_from;
266        while start < total_size {
267            let end = (start + chunk_size).min(total_size);
268            chunks.push((start, end));
269            start = end;
270        }
271
272        Ok(chunks)
273    }
274
275    fn download_whole(&self, file: &File, offset: u64, progress: Option<Arc<Mutex<ProgressBar>>>) -> Result<(), Box<dyn Error + Send + Sync>> {
276        let response = self.client.get(&self.url).send()?;
277        if offset > 0 {
278            // Resume not possible without range; warn
279            return Err("Server does not support range; cannot resume partial file".into());
280        }
281
282        let mut reader = BufReader::new(response);
283        let mut f = file.try_clone()?;
284        f.seek(SeekFrom::Start(0))?;
285
286        struct ProgressWriter<'a, W> {
287            inner: W,
288            progress: Option<Arc<Mutex<ProgressBar>>>,
289            callback: Option<&'a Arc<dyn Fn(f32) + Send + Sync>>,
290        }
291
292        impl<'a, W: Write> Write for ProgressWriter<'a, W> {
293            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
294                let n = self.inner.write(buf)?;
295                if let Some(ref bar) = self.progress {
296                    let guard = bar.lock().unwrap();
297                    guard.inc(n as u64);
298                    if let Some(cb) = self.callback {
299                        let pos = guard.position();
300                        let len = guard.length().unwrap_or(1);
301                        drop(guard);
302                        (cb)(pos as f32 / len as f32);
303                    }
304                }
305                Ok(n)
306            }
307
308            fn flush(&mut self) -> std::io::Result<()> {
309                self.inner.flush()
310            }
311        }
312
313        let mut writer = ProgressWriter { 
314            inner: f, 
315            progress,
316            callback: self.progress_callback.as_ref(),
317        };
318        std::io::copy(&mut reader, &mut writer)?;
319
320        Ok(())
321    }
322
323    fn download_chunks_parallel(&self, chunks: Vec<(u64, u64)>, file: &File, progress: Option<Arc<Mutex<ProgressBar>>>) -> Result<(), Box<dyn Error + Send + Sync>> {
324        let file = Arc::new(file);
325        let client = Arc::new(self.client.clone());
326        let url = Arc::new(self.url.clone());
327        let _optimizer = Arc::new(self.optimizer.clone());
328        let progress_callback = self.progress_callback.clone();
329        let cancel_token = self.cancel_token.clone();
330
331        chunks.par_iter().try_for_each(|&(start, end)| {
332            // Check for cancellation before starting chunk
333            if cancel_token.load(Ordering::Relaxed) {
334                return Err("Download cancelled".into());
335            }
336
337            let range = format!("bytes={}-{}", start, end - 1);
338            let range_header = reqwest::header::HeaderValue::from_str(&range)
339                .map_err(|e| format!("Invalid range header {}: {}", range, e))?;
340
341            for retry in 0..=MAX_RETRIES {
342                // Check for cancellation on each retry
343                if cancel_token.load(Ordering::Relaxed) {
344                    return Err("Download cancelled".into());
345                }
346
347                let request = client.get(url.as_str());
348                let request = request.header(reqwest::header::RANGE, range_header.clone());
349
350                match request.send() {
351                    Ok(mut response) => {
352                        let status = response.status();
353                        if status.is_success() {
354                            // Use FileExt to write at specific offset without seeking shared cursor
355                            // This prevents race conditions when multiple threads write to the same file
356                            
357                            let mut current_pos = start;
358                            let mut buffer = [0u8; 16384]; 
359                            
360                            while current_pos < end {
361                                // Check for cancellation periodically during download
362                                if cancel_token.load(Ordering::Relaxed) {
363                                    return Err("Download cancelled".into());
364                                }
365
366                                let limit = (end - current_pos).min(buffer.len() as u64);
367                                let n = response.read(&mut buffer[..limit as usize])?;
368                                if n == 0 { break; }
369                                
370                                #[cfg(target_family = "unix")]
371                                file.write_at(&buffer[..n], current_pos)?;
372                                
373                                #[cfg(target_family = "windows")]
374                                file.seek_write(&buffer[..n], current_pos)?;
375                                
376                                current_pos += n as u64;
377                            
378                                if let Some(ref bar) = progress {
379                                    let guard = bar.lock().unwrap();
380                                    guard.inc(n as u64);
381                                    if let Some(ref cb) = progress_callback {
382                                        let pos = guard.position();
383                                        let len = guard.length().unwrap_or(1);
384                                        drop(guard);
385                                        (cb)(pos as f32 / len as f32);
386                                    }
387                                }
388                            }
389
390                            return Ok::<(), Box<dyn Error + Send + Sync>>(());
391                        } else if status.as_u16() == 416 {
392                            if retry == MAX_RETRIES {
393                                return Err(format!("Failed to download chunk {}-{}: HTTP {}", start, end, status).into());
394                            }
395                            std::thread::sleep(Duration::from_millis(250 * (retry as u64 + 1)));
396                        }
397                    }
398                    Err(e) => {
399                        if retry == MAX_RETRIES {
400                            return Err(format!("Failed to download chunk {}-{}: {}", start, end, e).into());
401                        }
402                        std::thread::sleep(Duration::from_millis(250 * (retry as u64 + 1)));
403                    }
404                }
405            }
406            Err(format!("Failed to download chunk {}-{} after retries", start, end).into())
407        })?;
408
409        Ok(())
410    }
411
412        fn verify_integrity(&self, expected_size: u64) -> Result<(), Box<dyn Error + Send + Sync>> {
413            let metadata = std::fs::metadata(&self.output_path)?;
414        let actual_size = metadata.len();
415        
416        if actual_size != expected_size {
417            return Err(format!("File size mismatch: expected {} bytes, got {} bytes", expected_size, actual_size).into());
418        }
419        
420        self.send_status(&format!("File size verified: {} bytes", actual_size));
421
422        // Calculate SHA256 hash for corruption check
423        self.send_status("Calculating SHA256 hash...");
424        
425        let mut file = File::open(&self.output_path)?;
426        let mut hasher = Sha256::new();
427        let mut buffer = [0; 8192];
428        loop {
429            let n = file.read(&mut buffer)?;
430            if n == 0 {
431                break;
432            }
433            hasher.update(&buffer[..n]);
434        }
435        let hash = hasher.finalize();
436        let hash_hex = hex::encode(hash);
437        
438        self.send_status(&format!("SHA256 hash: {}", hash_hex));
439        self.send_status("Integrity check passed - file is not corrupted");
440
441        Ok(())
442    }
443}