1pub mod checksum;
13pub mod connection;
14pub mod mirror;
15pub mod resume;
16pub mod segment;
17
18pub use checksum::{compute_checksum, verify_checksum, ChecksumAlgorithm, ExpectedChecksum};
19pub use connection::{ConnectionPool, RetryPolicy, SpeedCalculator};
20pub use mirror::MirrorManager;
21pub use resume::{check_resume, ResumeInfo};
22pub use segment::{calculate_segment_count, probe_server, SegmentedDownload, ServerCapabilities};
23
24use crate::config::EngineConfig;
25use crate::error::{EngineError, NetworkErrorKind, Result, StorageErrorKind};
26use crate::storage::Segment;
27use crate::types::DownloadProgress;
28
29use futures::StreamExt;
30use parking_lot::RwLock;
31use reqwest::{Client, Response};
32use std::path::{Path, PathBuf};
33use std::sync::atomic::{AtomicU64, Ordering};
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use tokio::fs::{File, OpenOptions};
37use tokio::io::AsyncWriteExt;
38use tokio_util::sync::CancellationToken;
39
/// HTTP(S) file downloader built on a shared `reqwest` client.
///
/// Supports single-connection downloads with best-effort resume
/// (`download`) and multi-connection segmented downloads
/// (`download_segmented`).
pub struct HttpDownloader {
    // Shared connection pool; constructed with the engine's global
    // download/upload limits and consulted per-chunk while streaming.
    pool: Arc<ConnectionPool>,
    // Settings snapshotted from `EngineConfig` at construction time.
    config: HttpDownloaderConfig,
    // Retry/backoff parameters, exposed read-only via `retry_policy()`.
    retry_policy: RetryPolicy,
}
46
/// HTTP client settings snapshotted from the engine configuration when
/// the downloader is built.
#[derive(Debug, Clone)]
pub struct HttpDownloaderConfig {
    /// Time budget for establishing a connection.
    pub connect_timeout: Duration,
    /// Time budget for waiting on data from an established connection.
    pub read_timeout: Duration,
    /// Maximum number of HTTP redirects to follow.
    pub max_redirects: usize,
    /// `User-Agent` value used when the caller does not supply one.
    pub default_user_agent: String,
}
55
56impl HttpDownloader {
57 pub fn new(config: &EngineConfig) -> Result<Self> {
59 let pool = ConnectionPool::with_limits(
61 &config.http,
62 config.global_download_limit,
63 config.global_upload_limit,
64 )?;
65
66 let retry_policy = RetryPolicy::new(
68 config.http.max_retries as u32,
69 config.http.retry_delay_ms,
70 config.http.max_retry_delay_ms,
71 );
72
73 Ok(Self {
74 pool: Arc::new(pool),
75 config: HttpDownloaderConfig {
76 connect_timeout: Duration::from_secs(config.http.connect_timeout),
77 read_timeout: Duration::from_secs(config.http.read_timeout),
78 max_redirects: config.http.max_redirects,
79 default_user_agent: config.user_agent.clone(),
80 },
81 retry_policy,
82 })
83 }
84
    /// The shared `reqwest` client owned by the connection pool.
    fn client(&self) -> &Client {
        self.pool.client()
    }

    /// Retry/backoff settings derived from the engine configuration.
    pub fn retry_policy(&self) -> &RetryPolicy {
        &self.retry_policy
    }
94
95 #[allow(clippy::too_many_arguments)]
99 pub async fn download<F>(
100 &self,
101 url: &str,
102 save_dir: &Path,
103 filename: Option<&str>,
104 user_agent: Option<&str>,
105 referer: Option<&str>,
106 headers: &[(String, String)],
107 cookies: Option<&[String]>,
108 checksum: Option<&ExpectedChecksum>,
109 cancel_token: CancellationToken,
110 progress_callback: F,
111 ) -> Result<PathBuf>
112 where
113 F: Fn(DownloadProgress) + Send + 'static,
114 {
115 let mut request = self.client().get(url);
117
118 let ua = user_agent.unwrap_or(&self.config.default_user_agent);
120 request = request.header("User-Agent", ua);
121
122 if let Some(ref_url) = referer {
124 request = request.header("Referer", ref_url);
125 }
126
127 for (name, value) in headers {
129 request = request.header(name.as_str(), value.as_str());
130 }
131
132 if let Some(cookie_list) = cookies {
134 if !cookie_list.is_empty() {
135 let cookie_value = cookie_list.join("; ");
136 request = request.header("Cookie", cookie_value);
137 }
138 }
139
140 let mut head_request = self.client().head(url).header("User-Agent", ua);
142 if let Some(cookie_list) = cookies {
143 if !cookie_list.is_empty() {
144 head_request = head_request.header("Cookie", cookie_list.join("; "));
145 }
146 }
147 let head_response = head_request.send().await;
148
149 let (content_length, supports_range, suggested_filename) = match head_response {
150 Ok(resp) => {
151 let length = resp
152 .headers()
153 .get("content-length")
154 .and_then(|v| v.to_str().ok())
155 .and_then(|s| s.parse::<u64>().ok());
156
157 let supports_range = resp
158 .headers()
159 .get("accept-ranges")
160 .and_then(|v| v.to_str().ok())
161 .map(|v| v.contains("bytes"))
162 .unwrap_or(false);
163
164 let suggested = resp
166 .headers()
167 .get("content-disposition")
168 .and_then(|v| v.to_str().ok())
169 .and_then(parse_content_disposition);
170
171 (length, supports_range, suggested)
172 }
173 Err(_) => {
174 (None, false, None)
176 }
177 };
178
179 if cancel_token.is_cancelled() {
181 return Err(EngineError::Shutdown);
182 }
183
184 let final_filename = filename
186 .map(|s| s.to_string())
187 .or(suggested_filename)
188 .or_else(|| extract_filename_from_url(url))
189 .unwrap_or_else(|| "download".to_string());
190
191 if !save_dir.exists() {
193 tokio::fs::create_dir_all(save_dir).await.map_err(|e| {
194 EngineError::storage(
195 StorageErrorKind::Io,
196 save_dir,
197 format!("Failed to create directory: {}", e),
198 )
199 })?;
200 }
201
202 use std::path::Component;
205 for component in Path::new(&final_filename).components() {
206 match component {
207 Component::ParentDir => {
208 return Err(EngineError::storage(
209 StorageErrorKind::PathTraversal,
210 Path::new(&final_filename),
211 "Invalid filename: contains parent directory reference (..)",
212 ));
213 }
214 Component::RootDir | Component::Prefix(_) => {
215 return Err(EngineError::storage(
216 StorageErrorKind::PathTraversal,
217 Path::new(&final_filename),
218 "Invalid filename: contains absolute path",
219 ));
220 }
221 _ => {}
222 }
223 }
224
225 let save_path = save_dir.join(&final_filename);
226
227 let part_path = save_path.with_extension(
229 save_path
230 .extension()
231 .map(|e| format!("{}.part", e.to_string_lossy()))
232 .unwrap_or_else(|| "part".to_string()),
233 );
234
235 let existing_size = if supports_range && part_path.exists() {
237 tokio::fs::metadata(&part_path)
238 .await
239 .map(|m| m.len())
240 .unwrap_or(0)
241 } else {
242 0
243 };
244
245 if existing_size > 0 {
247 request = request.header("Range", format!("bytes={}-", existing_size));
248 }
249
250 let response = request.send().await?;
252
253 let status = response.status();
255 if !status.is_success() && status != reqwest::StatusCode::PARTIAL_CONTENT {
256 return Err(EngineError::network(
257 NetworkErrorKind::HttpStatus(status.as_u16()),
258 format!("HTTP error: {}", status),
259 ));
260 }
261
262 let total_size = content_length.or_else(|| {
264 response
265 .headers()
266 .get("content-length")
267 .and_then(|v| v.to_str().ok())
268 .and_then(|s| s.parse::<u64>().ok())
269 .map(|len| len + existing_size)
270 });
271
272 let file = if existing_size > 0 && status == reqwest::StatusCode::PARTIAL_CONTENT {
274 OpenOptions::new()
276 .write(true)
277 .append(true)
278 .open(&part_path)
279 .await
280 .map_err(|e| {
281 EngineError::storage(
282 StorageErrorKind::Io,
283 &part_path,
284 format!("Failed to open file for append: {}", e),
285 )
286 })?
287 } else {
288 File::create(&part_path).await.map_err(|e| {
290 EngineError::storage(
291 StorageErrorKind::Io,
292 &part_path,
293 format!("Failed to create file: {}", e),
294 )
295 })?
296 };
297
298 let downloaded = Arc::new(AtomicU64::new(existing_size));
300
301 let result = self
303 .stream_to_file(
304 response,
305 file,
306 downloaded.clone(),
307 total_size,
308 cancel_token.clone(),
309 move |completed, speed| {
310 progress_callback(DownloadProgress {
311 total_size,
312 completed_size: completed,
313 download_speed: speed,
314 upload_speed: 0,
315 connections: 1,
316 seeders: 0,
317 peers: 0,
318 eta_seconds: total_size.and_then(|total| {
319 if speed > 0 {
320 Some((total.saturating_sub(completed)) / speed)
321 } else {
322 None
323 }
324 }),
325 });
326 },
327 )
328 .await;
329
330 match result {
331 Ok(_) => {
332 if let Some(expected) = checksum {
334 let verified = verify_checksum(&part_path, expected).await?;
335 if !verified {
336 let actual = compute_checksum(&part_path, expected.algorithm).await?;
337 return Err(checksum::checksum_mismatch_error(&expected.value, &actual));
338 }
339 tracing::debug!("Checksum verified: {} matches expected", expected.algorithm);
340 }
341
342 tokio::fs::rename(&part_path, &save_path)
344 .await
345 .map_err(|e| {
346 EngineError::storage(
347 StorageErrorKind::Io,
348 &save_path,
349 format!("Failed to rename file: {}", e),
350 )
351 })?;
352
353 Ok(save_path)
354 }
355 Err(e) => {
356 Err(e)
358 }
359 }
360 }
361
    /// Streams the response body into `file`, reporting
    /// `(completed_bytes, bytes_per_sec)` through `progress_callback` at
    /// most once per 250 ms, plus a final zero-speed report after the
    /// stream ends.
    ///
    /// Each chunk passes through the pool's `acquire_download` /
    /// `record_download` (global transfer accounting; `acquire_download`
    /// is awaited, so it can throttle). Cancellation is observed between
    /// chunks: the file is flushed best-effort and `EngineError::Shutdown`
    /// returned. After a clean drain the file is flushed and fsynced, then
    /// the byte count is validated against `total_size` when known.
    /// `downloaded` carries the running total across calls (pre-seeded by
    /// the caller for resumed downloads).
    async fn stream_to_file<F>(
        &self,
        response: Response,
        mut file: File,
        downloaded: Arc<AtomicU64>,
        total_size: Option<u64>,
        cancel_token: CancellationToken,
        progress_callback: F,
    ) -> Result<()>
    where
        F: Fn(u64, u64) + Send,
    {
        let mut stream = response.bytes_stream();
        let mut last_update = Instant::now();
        let mut bytes_since_update: u64 = 0;
        // Progress callbacks are rate-limited to one per 250 ms.
        let update_interval = Duration::from_millis(250);

        // select! lets cancellation interrupt the wait for the next chunk.
        while let Some(chunk_result) = tokio::select! {
            chunk = stream.next() => chunk,
            _ = cancel_token.cancelled() => {
                // Best-effort flush of whatever was already written.
                file.flush().await.ok();
                return Err(EngineError::Shutdown);
            }
        } {
            let chunk: bytes::Bytes = chunk_result.map_err(|e: reqwest::Error| {
                EngineError::network(NetworkErrorKind::Other, format!("Stream error: {}", e))
            })?;

            let chunk_len = chunk.len() as u64;

            // Awaited before writing; may delay the chunk under a global limit.
            self.pool.acquire_download(chunk_len).await;

            file.write_all(&chunk).await.map_err(|e| {
                EngineError::storage(
                    StorageErrorKind::Io,
                    PathBuf::new(), // actual path is not available in this scope
                    format!("Failed to write: {}", e),
                )
            })?;

            self.pool.record_download(chunk_len);

            // fetch_add returns the previous value; add chunk_len back to get
            // the new running total.
            let new_total = downloaded.fetch_add(chunk_len, Ordering::Relaxed) + chunk_len;
            bytes_since_update += chunk_len;

            let now = Instant::now();
            if now.duration_since(last_update) >= update_interval {
                let elapsed_secs = now.duration_since(last_update).as_secs_f64();
                let speed = if elapsed_secs > 0.0 {
                    (bytes_since_update as f64 / elapsed_secs) as u64
                } else {
                    0
                };

                progress_callback(new_total, speed);

                last_update = now;
                bytes_since_update = 0;
            }
        }

        file.flush().await.map_err(|e| {
            EngineError::storage(
                StorageErrorKind::Io,
                PathBuf::new(),
                format!("Failed to flush: {}", e),
            )
        })?;

        // fsync so a crash after return cannot lose the tail of the file.
        file.sync_all().await.map_err(|e| {
            EngineError::storage(
                StorageErrorKind::Io,
                PathBuf::new(),
                format!("Failed to sync: {}", e),
            )
        })?;

        let final_size = downloaded.load(Ordering::Relaxed);
        progress_callback(final_size, 0);

        // A short body (e.g. truncated connection) is reported as an error.
        if let Some(expected) = total_size {
            if final_size < expected {
                return Err(EngineError::network(
                    NetworkErrorKind::Other,
                    format!(
                        "Incomplete download: received {} bytes, expected {} bytes",
                        final_size, expected
                    ),
                ));
            }
        }

        Ok(())
    }
465
466 #[allow(clippy::too_many_arguments)]
471 pub async fn download_segmented<F>(
476 &self,
477 url: &str,
478 save_dir: &Path,
479 filename: Option<&str>,
480 user_agent: Option<&str>,
481 referer: Option<&str>,
482 headers: &[(String, String)],
483 cookies: Option<&[String]>,
484 checksum: Option<&ExpectedChecksum>,
485 max_connections: usize,
486 min_segment_size: u64,
487 cancel_token: CancellationToken,
488 saved_segments: Option<Vec<Segment>>,
489 progress_callback: F,
490 segmented_ref: Option<Arc<RwLock<Option<Arc<SegmentedDownload>>>>>,
491 ) -> Result<(PathBuf, Option<Arc<SegmentedDownload>>)>
492 where
493 F: Fn(DownloadProgress) + Send + Sync + 'static,
494 {
495 let ua = user_agent.unwrap_or(&self.config.default_user_agent);
496
497 let capabilities = probe_server(self.client(), url, ua).await?;
499
500 let final_filename = filename
502 .map(|s| s.to_string())
503 .or(capabilities.suggested_filename.clone())
504 .or_else(|| extract_filename_from_url(url))
505 .unwrap_or_else(|| "download".to_string());
506
507 if !save_dir.exists() {
509 tokio::fs::create_dir_all(save_dir).await.map_err(|e| {
510 EngineError::storage(
511 StorageErrorKind::Io,
512 save_dir,
513 format!("Failed to create directory: {}", e),
514 )
515 })?;
516 }
517
518 let save_path = save_dir.join(&final_filename);
519
520 let use_segmented = capabilities.supports_range
522 && capabilities
523 .content_length
524 .map(|l| l > min_segment_size)
525 .unwrap_or(false);
526
527 if use_segmented {
528 let total_size = capabilities.content_length.unwrap();
529
530 let mut download = SegmentedDownload::new(
532 url.to_string(),
533 total_size,
534 save_path.clone(),
535 true,
536 capabilities.etag,
537 capabilities.last_modified,
538 );
539
540 if let Some(segments) = saved_segments {
542 tracing::debug!("Restoring {} saved segments", segments.len());
543 download.restore_segments(segments);
544 } else {
545 download.init_segments(max_connections, min_segment_size);
546 }
547
548 let download = Arc::new(download);
550 let download_ref = Arc::clone(&download);
551
552 if let Some(ref slot) = segmented_ref {
554 *slot.write() = Some(Arc::clone(&download));
555 }
556
557 let mut all_headers = headers.to_vec();
559 if let Some(r) = referer {
560 all_headers.push(("Referer".to_string(), r.to_string()));
561 }
562 if let Some(cookie_list) = cookies {
564 if !cookie_list.is_empty() {
565 all_headers.push(("Cookie".to_string(), cookie_list.join("; ")));
566 }
567 }
568
569 download
571 .start(
572 self.client(),
573 ua,
574 &all_headers,
575 max_connections,
576 cancel_token,
577 progress_callback,
578 )
579 .await?;
580
581 if let Some(expected) = checksum {
583 let verified = verify_checksum(&save_path, expected).await?;
584 if !verified {
585 let actual = compute_checksum(&save_path, expected.algorithm).await?;
586 return Err(checksum::checksum_mismatch_error(&expected.value, &actual));
587 }
588 tracing::debug!("Checksum verified: {} matches expected", expected.algorithm);
589 }
590
591 Ok((save_path, Some(download_ref)))
592 } else {
593 let path = self
595 .download(
596 url,
597 save_dir,
598 Some(&final_filename),
599 user_agent,
600 referer,
601 headers,
602 cookies,
603 checksum,
604 cancel_token,
605 progress_callback,
606 )
607 .await?;
608 Ok((path, None))
609 }
610 }
611}
612
613fn parse_content_disposition(header: &str) -> Option<String> {
615 if let Some(start) = header.find("filename=") {
617 let rest = &header[start + 9..];
618 if let Some(stripped) = rest.strip_prefix('"') {
619 let end = stripped.find('"')?;
621 return Some(stripped[..end].to_string());
622 } else {
623 let end = rest.find(';').unwrap_or(rest.len());
625 return Some(rest[..end].trim().to_string());
626 }
627 }
628
629 if let Some(start) = header.find("filename*=") {
630 let rest = &header[start + 10..];
631 if let Some(quote_start) = rest.find("''") {
633 let encoded = &rest[quote_start + 2..];
634 let end = encoded.find(';').unwrap_or(encoded.len());
635 if let Ok(decoded) = urlencoding::decode(&encoded[..end]) {
637 return Some(decoded.to_string());
638 }
639 }
640 }
641
642 None
643}
644
645fn extract_filename_from_url(url: &str) -> Option<String> {
647 url::Url::parse(url)
648 .ok()?
649 .path_segments()?
650 .next_back()
651 .filter(|s| !s.is_empty())
652 .map(|s| {
653 urlencoding::decode(s)
655 .map(|d| d.to_string())
656 .unwrap_or_else(|_| s.to_string())
657 })
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_content_disposition() {
        // (header value, expected extracted filename)
        let cases = [
            ("attachment; filename=\"test.zip\"", "test.zip"),
            ("attachment; filename=test.zip", "test.zip"),
        ];
        for (header, expected) in cases {
            assert_eq!(
                parse_content_disposition(header).as_deref(),
                Some(expected),
            );
        }
    }

    #[test]
    fn test_extract_filename_from_url() {
        // (url, expected filename — percent escapes decoded)
        let cases = [
            ("https://example.com/path/to/file.zip", "file.zip"),
            ("https://example.com/path/to/file%20name.zip", "file name.zip"),
        ];
        for (url, expected) in cases {
            assert_eq!(extract_filename_from_url(url).as_deref(), Some(expected));
        }
    }
}