Skip to main content

kget/
advanced_download.rs

1//! Advanced parallel download functionality with resume support.
2//!
3//! The [`AdvancedDownloader`] provides high-performance downloads using:
4//! - **Parallel connections**: Split files into chunks downloaded simultaneously
5//! - **Resume support**: Continue interrupted downloads from where they left off
6//! - **Progress callbacks**: Real-time progress and status updates
7//! - **Cancellation**: Graceful download cancellation via atomic tokens
8//!
9//! # Example
10//!
11//! ```rust,no_run
12//! use kget::{AdvancedDownloader, ProxyConfig, Optimizer};
13//! use std::sync::Arc;
14//!
15//! let mut downloader = AdvancedDownloader::new(
16//!     "https://releases.ubuntu.com/22.04/ubuntu-22.04-desktop-amd64.iso".to_string(),
17//!     "ubuntu.iso".to_string(),
18//!     false,
19//!     ProxyConfig::default(),
20//!     Optimizer::new(),
21//! );
22//!
//! // Set progress callback (0.0 to 1.0).
//! // The setters accept any closure that is `Fn + Send + Sync + 'static`.
//! downloader.set_progress_callback(|progress| {
//!     println!("Progress: {:.1}%", progress * 100.0);
//! });
//!
//! // Set status callback for messages
//! downloader.set_status_callback(|msg| {
//!     println!("Status: {}", msg);
//! });
32//!
33//! // Start download
34//! downloader.download().unwrap();
35//! ```
36//!
37//! # Parallel Downloads
38//!
39//! The downloader automatically determines the optimal number of connections
40//! based on the [`Optimizer`] configuration. For large files,
41//! this can provide significant speed improvements.
42
43use std::error::Error;
44use std::fs::File;
45use std::io::{BufReader, Read, Seek, SeekFrom, Write};
46use std::path::Path;
47use std::sync::{Arc, Mutex};
48use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
49use std::time::{Duration, Instant};
50use rayon::prelude::*;
51use reqwest::blocking::Client;
52use indicatif::{ProgressBar, ProgressStyle};
53use crate::config::ProxyConfig;
54use sha2::{Sha256, Digest};
55use hex;
56use crate::optimization::Optimizer;
57
58#[cfg(target_family = "unix")]
59use std::os::unix::fs::FileExt;
60#[cfg(target_family = "windows")]
61use std::os::windows::fs::FileExt;
62
/// Minimum chunk size for parallel downloads (4 MB).
/// Chunks smaller than this are not worth the per-connection overhead.
const MIN_CHUNK_SIZE: u64 = 4 * 1024 * 1024;
/// Maximum retry attempts per chunk before the chunk (and the whole
/// download) is reported as failed.
const MAX_RETRIES: usize = 3;
67
/// High-performance downloader with parallel connections and resume support.
///
/// `AdvancedDownloader` is the recommended way to download large files. It provides:
///
/// - **Parallel chunk downloads**: Splits files into segments downloaded simultaneously
/// - **Automatic resume**: Detects existing partial files and resumes from last position
/// - **Server compatibility**: Falls back to single-stream if server doesn't support ranges
/// - **ISO optimization**: Disables compression for binary files to prevent corruption
/// - **Progress tracking**: Real-time callbacks for UI integration
/// - **Cancellation support**: Stop downloads gracefully via atomic cancel token
///
/// # Example
///
/// ```rust,no_run
/// use kget::{AdvancedDownloader, ProxyConfig, Optimizer};
///
/// let downloader = AdvancedDownloader::new(
///     "https://example.com/large-file.zip".to_string(),
///     "large-file.zip".to_string(),
///     false,  // quiet_mode
///     ProxyConfig::default(),
///     Optimizer::new(),
/// );
///
/// downloader.download().expect("Download failed");
/// ```
///
/// # With Progress Tracking
///
/// The callback setters accept any closure that is `Fn + Send + Sync + 'static`:
///
/// ```rust,no_run
/// use kget::{AdvancedDownloader, ProxyConfig, Optimizer};
///
/// let mut dl = AdvancedDownloader::new(
///     "https://example.com/file.iso".to_string(),
///     "file.iso".to_string(),
///     true, // quiet mode (no stdout)
///     ProxyConfig::default(),
///     Optimizer::new(),
/// );
///
/// dl.set_progress_callback(|p| {
///     // p is 0.0 to 1.0
///     update_ui_progress(p);
/// });
///
/// dl.set_status_callback(|msg| {
///     log::info!("{}", msg);
/// });
///
/// dl.download().unwrap();
///
/// fn update_ui_progress(p: f32) {
///     // Update your UI here
/// }
/// ```
pub struct AdvancedDownloader {
    /// Blocking HTTP client, built once in `new` with compression disabled.
    client: Client,
    /// Source URL to download from.
    url: String,
    /// Destination path on the local filesystem.
    output_path: String,
    /// When true, suppresses console output (callbacks still fire).
    quiet_mode: bool,
    // NOTE(review): only consumed during `new` when building the client;
    // stored here presumably for future use — confirm.
    #[allow(dead_code)]
    proxy: ProxyConfig,
    /// Connection-tuning configuration.
    optimizer: Optimizer,
    /// Optional progress callback; receives the completed fraction (0.0..=1.0).
    progress_callback: Option<Arc<dyn Fn(f32) + Send + Sync>>,
    /// Optional human-readable status-message callback.
    status_callback: Option<Arc<dyn Fn(String) + Send + Sync>>,
    /// Cooperative cancellation flag, polled throughout the download.
    cancel_token: Arc<AtomicBool>,
}
136
137impl AdvancedDownloader {
138    /// Create a new `AdvancedDownloader` instance.
139    ///
140    /// # Arguments
141    ///
142    /// * `url` - URL to download from
143    /// * `output_path` - Local path for the downloaded file
144    /// * `quiet_mode` - If true, suppress console output
145    /// * `proxy_config` - Proxy settings (use `ProxyConfig::default()` for direct connection)
146    /// * `optimizer` - Optimizer for connection settings
147    ///
148    /// # Example
149    ///
150    /// ```rust,no_run
151    /// use kget::{AdvancedDownloader, ProxyConfig, Optimizer};
152    ///
153    /// let dl = AdvancedDownloader::new(
154    ///     "https://example.com/file.zip".to_string(),
155    ///     "./downloads/file.zip".to_string(),
156    ///     false,
157    ///     ProxyConfig::default(),
158    ///     Optimizer::new(),
159    /// );
160    /// ```
161    pub fn new(url: String, output_path: String, quiet_mode: bool, proxy_config: ProxyConfig, optimizer: Optimizer) -> Self {
162        let _is_iso = url.to_lowercase().ends_with(".iso");
163        
164        let mut client_builder = Client::builder()
165            .timeout(std::time::Duration::from_secs(300))
166            .connect_timeout(std::time::Duration::from_secs(20))
167            .user_agent("KGet/1.0")
168            .no_gzip() 
169            .no_deflate();
170
171        if proxy_config.enabled {
172            if let Some(proxy_url) = &proxy_config.url {
173                let proxy = match proxy_config.proxy_type {
174                    crate::config::ProxyType::Http => reqwest::Proxy::http(proxy_url),
175                    crate::config::ProxyType::Https => reqwest::Proxy::https(proxy_url),
176                    crate::config::ProxyType::Socks5 => reqwest::Proxy::all(proxy_url),
177                };
178                
179                if let Ok(mut proxy) = proxy {
180                    if let (Some(username), Some(password)) = (&proxy_config.username, &proxy_config.password) {
181                        proxy = proxy.basic_auth(username, password);
182                    }
183                    client_builder = client_builder.proxy(proxy);
184                }
185            }
186        }
187        
188        let client = client_builder.build()
189            .expect("Failed to create HTTP client");
190        
191        Self {
192            client,
193            url,
194            output_path,
195            quiet_mode,
196            proxy: proxy_config,
197            optimizer,
198            progress_callback: None,
199            status_callback: None,
200            cancel_token: Arc::new(AtomicBool::new(false)),
201        }
202    }
203
    /// Set a custom cancellation token for graceful download interruption.
    ///
    /// When the token is set to `true`, the download will stop at the next
    /// checkpoint and return an error. Replaces the token created in `new`,
    /// so the caller can share one flag across several downloaders.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use kget::{AdvancedDownloader, ProxyConfig, Optimizer};
    /// use std::sync::Arc;
    /// use std::sync::atomic::AtomicBool;
    ///
    /// let cancel = Arc::new(AtomicBool::new(false));
    /// let mut dl = AdvancedDownloader::new(/* ... */
    /// #    "".to_string(), "".to_string(), false, ProxyConfig::default(), Optimizer::new()
    /// );
    /// dl.set_cancel_token(cancel.clone());
    ///
    /// // In another thread:
    /// // cancel.store(true, std::sync::atomic::Ordering::Relaxed);
    /// ```
    pub fn set_cancel_token(&mut self, token: Arc<AtomicBool>) {
        self.cancel_token = token;
    }
228
    /// Check if the download has been cancelled.
    ///
    /// Reads the cancel token with `Ordering::Relaxed` — a simple flag,
    /// no synchronization of other data is implied.
    pub fn is_cancelled(&self) -> bool {
        self.cancel_token.load(Ordering::Relaxed)
    }
233
    /// Set a callback for progress updates.
    ///
    /// The callback receives a value from 0.0 (start) to 1.0 (complete).
    /// It may be invoked concurrently from multiple worker threads, hence
    /// the `Send + Sync` bounds.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # use kget::{AdvancedDownloader, ProxyConfig, Optimizer};
    /// # let mut dl = AdvancedDownloader::new("".to_string(), "".to_string(), false, ProxyConfig::default(), Optimizer::new());
    /// dl.set_progress_callback(|progress| {
    ///     println!("Downloaded: {:.1}%", progress * 100.0);
    /// });
    /// ```
    pub fn set_progress_callback(&mut self, callback: impl Fn(f32) + Send + Sync + 'static) {
        self.progress_callback = Some(Arc::new(callback));
    }
250
    /// Set a callback for status messages.
    ///
    /// Receives human-readable status updates during the download. Also used
    /// as a signal that a non-interactive frontend is attached (see the ISO
    /// verification logic in `download`).
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # use kget::{AdvancedDownloader, ProxyConfig, Optimizer};
    /// # let mut dl = AdvancedDownloader::new("".to_string(), "".to_string(), false, ProxyConfig::default(), Optimizer::new());
    /// dl.set_status_callback(|msg| {
    ///     log::info!("Download status: {}", msg);
    /// });
    /// ```
    pub fn set_status_callback(&mut self, callback: impl Fn(String) + Send + Sync + 'static) {
        self.status_callback = Some(Arc::new(callback));
    }
267
268    fn send_status(&self, msg: &str) {
269        if let Some(cb) = &self.status_callback {
270            cb(msg.to_string());
271        }
272        if !self.quiet_mode {
273            println!("{}", msg);
274        }
275    }
276
    /// Start the download.
    ///
    /// This method:
    /// 1. Checks for existing partial file (resume support)
    /// 2. Queries server for file size and range support
    /// 3. Downloads using parallel connections if supported
    /// 4. Falls back to single-stream if server doesn't support ranges
    ///
    /// # Returns
    ///
    /// Returns `Ok(())` on successful download, or an error if the download fails.
    ///
    /// # Errors
    ///
    /// - Network connection failures
    /// - Existing file larger than remote (corrupted state)
    /// - Cancellation via cancel token
    /// - Disk I/O errors
    pub fn download(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
        // ISO detection is name-based only (".iso" suffix on the URL); it
        // gates the warning below and the integrity check at the end.
        let is_iso = self.url.to_lowercase().ends_with(".iso");
        if !self.quiet_mode {
            println!("Starting advanced download for: {}", self.url);
            if is_iso {
                println!("Warning: ISO mode active. Disabling optimizations that could corrupt binary data.");
            }
        }

        // Resume support: an existing output file is assumed to be a partial
        // download; its current length becomes the resume offset.
        let existing_size = if Path::new(&self.output_path).exists() {
            let size = std::fs::metadata(&self.output_path)?.len();
            if !self.quiet_mode {
                println!("Existing file found with size: {} bytes", size);
            }
            Some(size)
        } else {
            if !self.quiet_mode {
                println!("Output file does not exist, starting fresh download");
            }
            None
        };

        // HEAD request: total size (required) plus Accept-Ranges support.
        if !self.quiet_mode {
            println!("Querying server for file size and range support...");
        }
        let (total_size, supports_range) = self.get_file_size_and_range()?;
        if !self.quiet_mode {
            println!("Total file size: {} bytes", total_size);
            println!("Server supports range requests: {}", supports_range);
        }

        // A local file larger than the remote cannot be a valid partial
        // download; refuse rather than truncate silently.
        if let Some(size) = existing_size {
            if size > total_size {
                return Err("Existing file is larger than remote; aborting".into());
            }
            if !self.quiet_mode {
                println!("Resuming download from byte: {}", size);
            }
        }

        // Progress bar doubles as the shared byte counter behind the progress
        // callback; in quiet mode it is created hidden so callbacks still fire.
        let progress = if !self.quiet_mode || self.progress_callback.is_some() {
            let bar = ProgressBar::new(total_size);
            if self.quiet_mode {
                bar.set_draw_target(indicatif::ProgressDrawTarget::hidden());
            } else {
                bar.set_style(ProgressStyle::with_template(
                    "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})"
                ).unwrap().progress_chars("#>-"));
            }
            Some(Arc::new(Mutex::new(bar)))
        } else {
            None
        };

        // Create or open the output file and preallocate it to the final size
        // so parallel chunk workers can write at arbitrary offsets.
        if !self.quiet_mode {
            println!("Preparing output file: {}", self.output_path);
        }
        let file = if existing_size.is_some() {
            File::options().read(true).write(true).open(&self.output_path)?
        } else {
            File::create(&self.output_path)?
        };
        file.set_len(total_size)?;
        if !self.quiet_mode {
            println!("File preallocated to {} bytes", total_size);
        }

        // Without range support there is exactly one stream (and no resume).
        if !supports_range {
            if !self.quiet_mode {
                println!("Range requests not supported, falling back to single-threaded download");
            }
            self.download_whole(&file, existing_size.unwrap_or(0), progress.clone())?;
            if let Some(ref bar) = progress {
                bar.lock().unwrap().finish_with_message("Download completed");
            }
            if !self.quiet_mode {
                println!("Single-threaded download completed");
            }
            return Ok(());
        }

        // Split the remaining byte range into chunks for parallel workers.
        if !self.quiet_mode {
            println!("Calculating download chunks...");
        }
        let chunks = self.calculate_chunks(total_size, existing_size)?;
        if !self.quiet_mode {
            println!("Download will be split into {} chunks", chunks.len());
        }

        // Download parallel chunks
        if !self.quiet_mode {
            println!("Starting parallel chunk downloads...");
        }
        self.download_chunks_parallel(chunks, &file, progress.clone())?;

        if let Some(ref bar) = progress {
            bar.lock().unwrap().finish_with_message("Download completed");
        }

        // Verify download integrity.
        // NOTE(review): in quiet mode with no status callback this whole block
        // is skipped, including the non-ISO size check — confirm intended.
        if !self.quiet_mode || self.status_callback.is_some() {
            if is_iso {

                // With a status callback attached (headless/GUI use) always
                // verify; otherwise ask the user interactively on stdin.
                let should_verify = if self.status_callback.is_some() {
                    true
                } else {
                    println!("\nThis is an ISO file. Would you like to verify its integrity? (y/N)");
                    let mut input = String::new();
                    std::io::stdin().read_line(&mut input).is_ok() && input.trim().to_lowercase() == "y"
                };

                if should_verify {
                    self.verify_integrity(total_size)?;
                }
            } else {
                // Non-ISO files only get a cheap size comparison.
                let metadata = std::fs::metadata(&self.output_path)?;
                if metadata.len() != total_size {
                    return Err(format!("File size mismatch: expected {} bytes, got {} bytes", total_size, metadata.len()).into());
                }
            }
            self.send_status("Advanced download completed successfully!");
        }

        Ok(())
    }
426
427    fn get_file_size_and_range(&self) -> Result<(u64, bool), Box<dyn Error + Send + Sync>> {
428        let response = self.client.head(&self.url).send()?;
429        let content_length = response.headers()
430            .get(reqwest::header::CONTENT_LENGTH)
431            .and_then(|v| v.to_str().ok())
432            .and_then(|s| s.parse::<u64>().ok())
433            .ok_or("Could not determine file size")?;
434
435        let accepts_range = response.headers()
436            .get(reqwest::header::ACCEPT_RANGES)
437            .and_then(|v| v.to_str().ok())
438            .map(|s| s.eq_ignore_ascii_case("bytes"))
439            .unwrap_or(false);
440
441        Ok((content_length, accepts_range))
442    }
443
444    fn calculate_chunks(&self, total_size: u64, existing_size: Option<u64>) -> Result<Vec<(u64, u64)>, Box<dyn Error + Send + Sync>> {
445        let mut chunks = Vec::new();
446        let start_from = existing_size.unwrap_or(0);
447
448        
449        let parallelism = rayon::current_num_threads() as u64;
450        let target_chunks = parallelism.saturating_mul(2).max(2); // Reduced to avoid overwhelming servers
451        let chunk_size = ((total_size / target_chunks).max(MIN_CHUNK_SIZE)).min(64 * 1024 * 1024);
452
453        let mut start = start_from;
454        while start < total_size {
455            let end = (start + chunk_size).min(total_size);
456            chunks.push((start, end));
457            start = end;
458        }
459
460        Ok(chunks)
461    }
462
463    fn download_whole(&self, file: &File, offset: u64, progress: Option<Arc<Mutex<ProgressBar>>>) -> Result<(), Box<dyn Error + Send + Sync>> {
464        let response = self.client.get(&self.url).send()?;
465        if offset > 0 {
466            // Resume not possible without range; warn
467            return Err("Server does not support range; cannot resume partial file".into());
468        }
469
470        let mut reader = BufReader::new(response);
471        let mut f = file.try_clone()?;
472        f.seek(SeekFrom::Start(0))?;
473
474        struct ProgressWriter<'a, W> {
475            inner: W,
476            progress: Option<Arc<Mutex<ProgressBar>>>,
477            callback: Option<&'a Arc<dyn Fn(f32) + Send + Sync>>,
478        }
479
480        impl<'a, W: Write> Write for ProgressWriter<'a, W> {
481            fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
482                let n = self.inner.write(buf)?;
483                if let Some(ref bar) = self.progress {
484                    let guard = bar.lock().unwrap();
485                    guard.inc(n as u64);
486                    if let Some(cb) = self.callback {
487                        let pos = guard.position();
488                        let len = guard.length().unwrap_or(1);
489                        drop(guard);
490                        (cb)(pos as f32 / len as f32);
491                    }
492                }
493                Ok(n)
494            }
495
496            fn flush(&mut self) -> std::io::Result<()> {
497                self.inner.flush()
498            }
499        }
500
501        let mut writer = ProgressWriter { 
502            inner: f, 
503            progress,
504            callback: self.progress_callback.as_ref(),
505        };
506        std::io::copy(&mut reader, &mut writer)?;
507
508        Ok(())
509    }
510
    /// Download `chunks` concurrently on rayon's thread pool, writing each
    /// chunk at its absolute offset in `file` via positional writes
    /// (`write_at` on Unix, `seek_write` on Windows) so no shared cursor or
    /// seek lock is required.
    ///
    /// Each chunk is retried up to [`MAX_RETRIES`] times with linearly
    /// increasing sleep (250 ms * attempt). The cancel token is polled before
    /// each chunk, on every retry, and between buffer reads; the first chunk
    /// error aborts the whole `try_for_each`.
    fn download_chunks_parallel(&self, chunks: Vec<(u64, u64)>, file: &File, progress: Option<Arc<Mutex<ProgressBar>>>) -> Result<(), Box<dyn Error + Send + Sync>> {
        // Clone everything the worker closures capture so they are 'static
        // with respect to `self`.
        let file = Arc::new(file);
        let client = Arc::new(self.client.clone());
        let url = Arc::new(self.url.clone());
        let _optimizer = Arc::new(self.optimizer.clone());
        let progress_callback = self.progress_callback.clone();
        let cancel_token = self.cancel_token.clone();

        // Shared progress counter for pipe-friendly output
        let total_bytes: u64 = chunks.iter().map(|(s, e)| e - s).sum();
        let downloaded_bytes = Arc::new(AtomicU64::new(0));
        let last_print_time = Arc::new(Mutex::new(Instant::now()));

        chunks.par_iter().try_for_each(|&(start, end)| {
            // Check for cancellation before starting chunk
            if cancel_token.load(Ordering::Relaxed) {
                return Err("Download cancelled".into());
            }

            // HTTP Range is inclusive on both ends, hence `end - 1`.
            let range = format!("bytes={}-{}", start, end - 1);
            let range_header = reqwest::header::HeaderValue::from_str(&range)
                .map_err(|e| format!("Invalid range header {}: {}", range, e))?;

            for retry in 0..=MAX_RETRIES {
                // Check for cancellation on each retry
                if cancel_token.load(Ordering::Relaxed) {
                    return Err("Download cancelled".into());
                }

                let request = client.get(url.as_str());
                let request = request.header(reqwest::header::RANGE, range_header.clone());

                match request.send() {
                    Ok(mut response) => {
                        let status = response.status();
                        if status.is_success() {
                            // Use FileExt to write at specific offset without seeking shared cursor
                            // This prevents race conditions when multiple threads write to the same file

                            let mut current_pos = start;
                            let mut buffer = [0u8; 16384];

                            while current_pos < end {
                                // Check for cancellation periodically during download
                                if cancel_token.load(Ordering::Relaxed) {
                                    return Err("Download cancelled".into());
                                }

                                // Never read past this chunk's end offset.
                                let limit = (end - current_pos).min(buffer.len() as u64);
                                let n = response.read(&mut buffer[..limit as usize])?;
                                // EOF before the chunk is complete: treat as done
                                // rather than error (server sent a short body).
                                if n == 0 { break; }

                                #[cfg(target_family = "unix")]
                                file.write_at(&buffer[..n], current_pos)?;

                                #[cfg(target_family = "windows")]
                                file.seek_write(&buffer[..n], current_pos)?;

                                current_pos += n as u64;

                                // Update shared progress counter
                                let new_downloaded = downloaded_bytes.fetch_add(n as u64, Ordering::Relaxed) + n as u64;

                                // Print progress periodically (every 200ms) for pipe-friendly output
                                // NOTE(review): this prints regardless of quiet_mode — presumably
                                // consumed by a wrapping process (the Swift parser mentioned
                                // below); confirm intended.
                                {
                                    let mut last_time = last_print_time.lock().unwrap();
                                    if last_time.elapsed() >= Duration::from_millis(200) {
                                        let percent = (new_downloaded as f64 / total_bytes as f64 * 100.0).min(100.0);
                                        // PROGRESS: format that Swift can parse
                                        println!("PROGRESS: {:.1}% ({}/{})", percent, new_downloaded, total_bytes);
                                        *last_time = Instant::now();
                                    }
                                }

                                if let Some(ref bar) = progress {
                                    let guard = bar.lock().unwrap();
                                    guard.inc(n as u64);
                                    if let Some(ref cb) = progress_callback {
                                        let pos = guard.position();
                                        let len = guard.length().unwrap_or(1);
                                        // Release the lock before user code runs.
                                        drop(guard);
                                        (cb)(pos as f32 / len as f32);
                                    }
                                }
                            }

                            return Ok::<(), Box<dyn Error + Send + Sync>>(());
                        } else if status.as_u16() == 416 {
                            // NOTE(review): only 416 (Range Not Satisfiable) gets explicit
                            // backoff; other non-success statuses fall through and retry
                            // immediately with no sleep — confirm intended.
                            if retry == MAX_RETRIES {
                                return Err(format!("Failed to download chunk {}-{}: HTTP {}", start, end, status).into());
                            }
                            std::thread::sleep(Duration::from_millis(250 * (retry as u64 + 1)));
                        }
                    }
                    Err(e) => {
                        if retry == MAX_RETRIES {
                            return Err(format!("Failed to download chunk {}-{}: {}", start, end, e).into());
                        }
                        std::thread::sleep(Duration::from_millis(250 * (retry as u64 + 1)));
                    }
                }
            }
            Err(format!("Failed to download chunk {}-{} after retries", start, end).into())
        })?;

        Ok(())
    }
618
619        fn verify_integrity(&self, expected_size: u64) -> Result<(), Box<dyn Error + Send + Sync>> {
620            let metadata = std::fs::metadata(&self.output_path)?;
621        let actual_size = metadata.len();
622        
623        if actual_size != expected_size {
624            return Err(format!("File size mismatch: expected {} bytes, got {} bytes", expected_size, actual_size).into());
625        }
626        
627        self.send_status(&format!("File size verified: {} bytes", actual_size));
628
629        // Calculate SHA256 hash for corruption check
630        self.send_status("Calculating SHA256 hash...");
631        
632        let mut file = File::open(&self.output_path)?;
633        let mut hasher = Sha256::new();
634        let mut buffer = [0; 8192];
635        loop {
636            let n = file.read(&mut buffer)?;
637            if n == 0 {
638                break;
639            }
640            hasher.update(&buffer[..n]);
641        }
642        let hash = hasher.finalize();
643        let hash_hex = hex::encode(hash);
644        
645        self.send_status(&format!("SHA256 hash: {}", hash_hex));
646        self.send_status("Integrity check passed - file is not corrupted");
647
648        Ok(())
649    }
650}