1use std::error::Error;
2use std::fs::File;
3use std::io::{BufReader, Read, Seek, SeekFrom, Write, BufWriter};
4use std::path::Path;
5use std::sync::{Arc, Mutex};
6use std::time::Duration;
7use rayon::prelude::*;
8use reqwest::blocking::Client;
9use indicatif::{ProgressBar, ProgressStyle};
10use crate::utils::print;
11use crate::config::ProxyConfig;
12use sha2::{Sha256, Digest};
13use hex;
14use crate::optimization::Optimizer;
15
16#[cfg(target_family = "unix")]
17use std::os::unix::fs::FileExt;
18#[cfg(target_family = "windows")]
19use std::os::windows::fs::FileExt;
20
/// Lower bound on the size computed for each parallel byte-range chunk (4 MiB).
/// The final chunk of a file may still be shorter than this.
const MIN_CHUNK_SIZE: u64 = 4 << 20;
/// Number of additional attempts after a failed request, i.e. up to `MAX_RETRIES + 1` tries.
const MAX_RETRIES: usize = 3;
23
24pub struct AdvancedDownloader {
25 client: Client,
26 url: String,
27 output_path: String,
28 quiet_mode: bool,
29 #[allow(dead_code)]
30 proxy: ProxyConfig,
31 optimizer: Optimizer,
32 progress_callback: Option<Arc<dyn Fn(f32) + Send + Sync>>,
33 status_callback: Option<Arc<dyn Fn(String) + Send + Sync>>,
34}
35
36impl AdvancedDownloader {
37 pub fn new(url: String, output_path: String, quiet_mode: bool, proxy_config: ProxyConfig, optimizer: Optimizer) -> Self {
38 let is_iso = url.to_lowercase().ends_with(".iso");
39
40 let mut client_builder = Client::builder()
41 .timeout(std::time::Duration::from_secs(300))
42 .connect_timeout(std::time::Duration::from_secs(20))
43 .user_agent("KGet/1.0")
44 .no_gzip()
45 .no_deflate();
46
47 if proxy_config.enabled {
48 if let Some(proxy_url) = &proxy_config.url {
49 let proxy = match proxy_config.proxy_type {
50 crate::config::ProxyType::Http => reqwest::Proxy::http(proxy_url),
51 crate::config::ProxyType::Https => reqwest::Proxy::https(proxy_url),
52 crate::config::ProxyType::Socks5 => reqwest::Proxy::all(proxy_url),
53 };
54
55 if let Ok(mut proxy) = proxy {
56 if let (Some(username), Some(password)) = (&proxy_config.username, &proxy_config.password) {
57 proxy = proxy.basic_auth(username, password);
58 }
59 client_builder = client_builder.proxy(proxy);
60 }
61 }
62 }
63
64 let client = client_builder.build()
65 .expect("Failed to create HTTP client");
66
67 Self {
68 client,
69 url,
70 output_path,
71 quiet_mode,
72 proxy: proxy_config,
73 optimizer,
74 progress_callback: None,
75 status_callback: None,
76 }
77 }
78
79 pub fn set_progress_callback(&mut self, callback: impl Fn(f32) + Send + Sync + 'static) {
80 self.progress_callback = Some(Arc::new(callback));
81 }
82
83 pub fn set_status_callback(&mut self, callback: impl Fn(String) + Send + Sync + 'static) {
84 self.status_callback = Some(Arc::new(callback));
85 }
86
87 fn send_status(&self, msg: &str) {
88 if let Some(cb) = &self.status_callback {
89 cb(msg.to_string());
90 }
91 if !self.quiet_mode {
92 println!("{}", msg);
93 }
94 }
95
96
97 pub fn download(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
98 let is_iso = self.url.to_lowercase().ends_with(".iso");
99 if !self.quiet_mode {
100 println!("Starting advanced download for: {}", self.url);
101 if is_iso {
102 println!("Warning: ISO mode active. Disabling optimizations that could corrupt binary data.");
103 }
104 }
105
106 let existing_size = if Path::new(&self.output_path).exists() {
108 let size = std::fs::metadata(&self.output_path)?.len();
109 if !self.quiet_mode {
110 println!("Existing file found with size: {} bytes", size);
111 }
112 Some(size)
113 } else {
114 if !self.quiet_mode {
115 println!("Output file does not exist, starting fresh download");
116 }
117 None
118 };
119
120 if !self.quiet_mode {
122 println!("Querying server for file size and range support...");
123 }
124 let (total_size, supports_range) = self.get_file_size_and_range()?;
125 if !self.quiet_mode {
126 println!("Total file size: {} bytes", total_size);
127 println!("Server supports range requests: {}", supports_range);
128 }
129
130 if let Some(size) = existing_size {
131 if size > total_size {
132 return Err("Existing file is larger than remote; aborting".into());
133 }
134 if !self.quiet_mode {
135 println!("Resuming download from byte: {}", size);
136 }
137 }
138
139 let progress = if !self.quiet_mode || self.progress_callback.is_some() {
141 let bar = ProgressBar::new(total_size);
142 if self.quiet_mode {
143 bar.set_draw_target(indicatif::ProgressDrawTarget::hidden());
144 } else {
145 bar.set_style(ProgressStyle::with_template(
146 "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({bytes_per_sec}, {eta})"
147 ).unwrap().progress_chars("#>-"));
148 }
149 Some(Arc::new(Mutex::new(bar)))
150 } else {
151 None
152 };
153
154 if !self.quiet_mode {
156 println!("Preparing output file: {}", self.output_path);
157 }
158 let file = if existing_size.is_some() {
159 File::options().read(true).write(true).open(&self.output_path)?
160 } else {
161 File::create(&self.output_path)?
162 };
163 file.set_len(total_size)?;
164 if !self.quiet_mode {
165 println!("File preallocated to {} bytes", total_size);
166 }
167
168 if !supports_range {
170 if !self.quiet_mode {
171 println!("Range requests not supported, falling back to single-threaded download");
172 }
173 self.download_whole(&file, existing_size.unwrap_or(0), progress.clone())?;
174 if let Some(ref bar) = progress {
175 bar.lock().unwrap().finish_with_message("Download completed");
176 }
177 if !self.quiet_mode {
178 println!("Single-threaded download completed");
179 }
180 return Ok(());
181 }
182
183 if !self.quiet_mode {
185 println!("Calculating download chunks...");
186 }
187 let chunks = self.calculate_chunks(total_size, existing_size)?;
188 if !self.quiet_mode {
189 println!("Download will be split into {} chunks", chunks.len());
190 }
191
192 if !self.quiet_mode {
194 println!("Starting parallel chunk downloads...");
195 }
196 self.download_chunks_parallel(chunks, &file, progress.clone())?;
197
198 if let Some(ref bar) = progress {
199 bar.lock().unwrap().finish_with_message("Download completed");
200 }
201
202 if !self.quiet_mode || self.status_callback.is_some() {
204 if is_iso {
205
206 let should_verify = if self.status_callback.is_some() {
207 true
208 } else {
209 println!("\nThis is an ISO file. Would you like to verify its integrity? (y/N)");
210 let mut input = String::new();
211 std::io::stdin().read_line(&mut input).is_ok() && input.trim().to_lowercase() == "y"
212 };
213
214 if should_verify {
215 self.verify_integrity(total_size)?;
216 }
217 } else {
218 let metadata = std::fs::metadata(&self.output_path)?;
219 if metadata.len() != total_size {
220 return Err(format!("File size mismatch: expected {} bytes, got {} bytes", total_size, metadata.len()).into());
221 }
222 }
223 self.send_status("Advanced download completed successfully!");
224 }
225
226 Ok(())
227 }
228
229 fn get_file_size_and_range(&self) -> Result<(u64, bool), Box<dyn Error + Send + Sync>> {
230 let response = self.client.head(&self.url).send()?;
231 let content_length = response.headers()
232 .get(reqwest::header::CONTENT_LENGTH)
233 .and_then(|v| v.to_str().ok())
234 .and_then(|s| s.parse::<u64>().ok())
235 .ok_or("Could not determine file size")?;
236
237 let accepts_range = response.headers()
238 .get(reqwest::header::ACCEPT_RANGES)
239 .and_then(|v| v.to_str().ok())
240 .map(|s| s.eq_ignore_ascii_case("bytes"))
241 .unwrap_or(false);
242
243 Ok((content_length, accepts_range))
244 }
245
246 fn calculate_chunks(&self, total_size: u64, existing_size: Option<u64>) -> Result<Vec<(u64, u64)>, Box<dyn Error + Send + Sync>> {
247 let mut chunks = Vec::new();
248 let start_from = existing_size.unwrap_or(0);
249
250
251 let parallelism = rayon::current_num_threads() as u64;
252 let target_chunks = parallelism.saturating_mul(2).max(2); let chunk_size = ((total_size / target_chunks).max(MIN_CHUNK_SIZE)).min(64 * 1024 * 1024);
254
255 let mut start = start_from;
256 while start < total_size {
257 let end = (start + chunk_size).min(total_size);
258 chunks.push((start, end));
259 start = end;
260 }
261
262 Ok(chunks)
263 }
264
265 fn download_whole(&self, file: &File, offset: u64, progress: Option<Arc<Mutex<ProgressBar>>>) -> Result<(), Box<dyn Error + Send + Sync>> {
266 let mut response = self.client.get(&self.url).send()?;
267 if offset > 0 {
268 return Err("Server does not support range; cannot resume partial file".into());
270 }
271
272 let mut reader = BufReader::new(response);
273 let mut f = file.try_clone()?;
274 f.seek(SeekFrom::Start(0))?;
275
276 struct ProgressWriter<'a, W> {
277 inner: W,
278 progress: Option<Arc<Mutex<ProgressBar>>>,
279 callback: Option<&'a Arc<dyn Fn(f32) + Send + Sync>>,
280 }
281
282 impl<'a, W: Write> Write for ProgressWriter<'a, W> {
283 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
284 let n = self.inner.write(buf)?;
285 if let Some(ref bar) = self.progress {
286 let mut guard = bar.lock().unwrap();
287 guard.inc(n as u64);
288 if let Some(cb) = self.callback {
289 let pos = guard.position();
290 let len = guard.length().unwrap_or(1);
291 drop(guard);
292 (cb)(pos as f32 / len as f32);
293 }
294 }
295 Ok(n)
296 }
297
298 fn flush(&mut self) -> std::io::Result<()> {
299 self.inner.flush()
300 }
301 }
302
303 let mut writer = ProgressWriter {
304 inner: f,
305 progress,
306 callback: self.progress_callback.as_ref(),
307 };
308 std::io::copy(&mut reader, &mut writer)?;
309
310 Ok(())
311 }
312
313 fn download_chunks_parallel(&self, chunks: Vec<(u64, u64)>, file: &File, progress: Option<Arc<Mutex<ProgressBar>>>) -> Result<(), Box<dyn Error + Send + Sync>> {
314 let file = Arc::new(file);
315 let client = Arc::new(self.client.clone());
316 let url = Arc::new(self.url.clone());
317 let _optimizer = Arc::new(self.optimizer.clone());
318 let progress_callback = self.progress_callback.clone();
319
320 chunks.par_iter().try_for_each(|&(start, end)| {
321 let range = format!("bytes={}-{}", start, end - 1);
322 let range_header = reqwest::header::HeaderValue::from_str(&range)
323 .map_err(|e| format!("Invalid range header {}: {}", range, e))?;
324
325 for retry in 0..=MAX_RETRIES {
326 let request = client.get(url.as_str());
327 let request = request.header(reqwest::header::RANGE, range_header.clone());
328
329 match request.send() {
330 Ok(mut response) => {
331 let status = response.status();
332 if status.is_success() {
333 let mut current_pos = start;
337 let mut buffer = [0u8; 16384];
338
339 while current_pos < end {
340 let limit = (end - current_pos).min(buffer.len() as u64);
341 let n = response.read(&mut buffer[..limit as usize])?;
342 if n == 0 { break; }
343
344 #[cfg(target_family = "unix")]
345 file.write_at(&buffer[..n], current_pos)?;
346
347 #[cfg(target_family = "windows")]
348 file.seek_write(&buffer[..n], current_pos)?;
349
350 current_pos += n as u64;
351
352 if let Some(ref bar) = progress {
353 let mut guard = bar.lock().unwrap();
354 guard.inc(n as u64);
355 if let Some(ref cb) = progress_callback {
356 let pos = guard.position();
357 let len = guard.length().unwrap_or(1);
358 drop(guard);
359 (cb)(pos as f32 / len as f32);
360 }
361 }
362 }
363
364 return Ok::<(), Box<dyn Error + Send + Sync>>(());
365 } else if status.as_u16() == 416 {
366 if retry == MAX_RETRIES {
367 return Err(format!("Failed to download chunk {}-{}: HTTP {}", start, end, status).into());
368 }
369 std::thread::sleep(Duration::from_millis(250 * (retry as u64 + 1)));
370 }
371 }
372 Err(e) => {
373 if retry == MAX_RETRIES {
374 return Err(format!("Failed to download chunk {}-{}: {}", start, end, e).into());
375 }
376 std::thread::sleep(Duration::from_millis(250 * (retry as u64 + 1)));
377 }
378 }
379 }
380 Err(format!("Failed to download chunk {}-{} after retries", start, end).into())
381 })?;
382
383 Ok(())
384 }
385
386 fn verify_integrity(&self, expected_size: u64) -> Result<(), Box<dyn Error + Send + Sync>> {
387 let metadata = std::fs::metadata(&self.output_path)?;
388 let actual_size = metadata.len();
389
390 if actual_size != expected_size {
391 return Err(format!("File size mismatch: expected {} bytes, got {} bytes", expected_size, actual_size).into());
392 }
393
394 self.send_status(&format!("File size verified: {} bytes", actual_size));
395
396 self.send_status("Calculating SHA256 hash...");
398
399 let mut file = File::open(&self.output_path)?;
400 let mut hasher = Sha256::new();
401 let mut buffer = [0; 8192];
402 loop {
403 let n = file.read(&mut buffer)?;
404 if n == 0 {
405 break;
406 }
407 hasher.update(&buffer[..n]);
408 }
409 let hash = hasher.finalize();
410 let hash_hex = hex::encode(hash);
411
412 self.send_status(&format!("SHA256 hash: {}", hash_hex));
413 self.send_status("Integrity check passed - file is not corrupted");
414
415 Ok(())
416 }
417}