// gosh_dl/http/mod.rs

1//! HTTP Download Engine
2//!
3//! This module handles HTTP/HTTPS downloads with support for:
4//! - Single and multi-connection (segmented) downloads
5//! - Resume capability via Range headers
6//! - Progress tracking
7//! - Custom headers, user-agent, referer
8//! - Connection pooling with rate limiting
9//! - Retry logic with exponential backoff
10//! - Checksum verification (MD5/SHA256)
11
12pub mod checksum;
13pub mod connection;
14pub mod mirror;
15pub mod resume;
16pub mod segment;
17
18pub use checksum::{compute_checksum, verify_checksum, ChecksumAlgorithm, ExpectedChecksum};
19pub use connection::{ConnectionPool, RetryPolicy, SpeedCalculator};
20pub use mirror::MirrorManager;
21pub use resume::{check_resume, ResumeInfo};
22pub use segment::{calculate_segment_count, probe_server, SegmentedDownload, ServerCapabilities};
23
24use crate::config::EngineConfig;
25use crate::error::{EngineError, NetworkErrorKind, Result, StorageErrorKind};
26use crate::storage::Segment;
27use crate::types::DownloadProgress;
28
29use futures::StreamExt;
30use parking_lot::RwLock;
31use reqwest::{Client, Response};
32use std::path::{Path, PathBuf};
33use std::sync::atomic::{AtomicU64, Ordering};
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use tokio::fs::{File, OpenOptions};
37use tokio::io::AsyncWriteExt;
38use tokio_util::sync::CancellationToken;
39
/// HTTP Downloader
///
/// Thin facade over a shared [`ConnectionPool`]: builds requests, streams
/// bodies to disk, and reports progress. Cheap to share via the internal
/// `Arc`-held pool.
pub struct HttpDownloader {
    // Shared HTTP client plus rate limiting and transfer accounting
    // (see acquire_download / record_download in stream_to_file).
    pool: Arc<ConnectionPool>,
    // Per-downloader settings (timeouts, redirect cap, default User-Agent).
    config: HttpDownloaderConfig,
    // Retry parameters (max attempts, base/max delay) built from EngineConfig.
    retry_policy: RetryPolicy,
}
46
/// Configuration for HTTP downloader
#[derive(Debug, Clone)]
pub struct HttpDownloaderConfig {
    /// Connection-establishment timeout (from `EngineConfig::http`;
    /// not read within this module — presumably consumed by the pool/client).
    pub connect_timeout: Duration,
    /// Read timeout (from `EngineConfig::http`; not read within this module).
    pub read_timeout: Duration,
    /// Maximum redirects to follow (not read within this module —
    /// NOTE(review): verify it is applied when the client is built).
    pub max_redirects: usize,
    /// User-Agent sent when the caller does not supply one.
    pub default_user_agent: String,
}
55
56impl HttpDownloader {
57    /// Create a new HTTP downloader
58    pub fn new(config: &EngineConfig) -> Result<Self> {
59        // Create connection pool with rate limiting if configured
60        let pool = ConnectionPool::with_limits(
61            &config.http,
62            config.global_download_limit,
63            config.global_upload_limit,
64        )?;
65
66        // Create retry policy from config
67        let retry_policy = RetryPolicy::new(
68            config.http.max_retries as u32,
69            config.http.retry_delay_ms,
70            config.http.max_retry_delay_ms,
71        );
72
73        Ok(Self {
74            pool: Arc::new(pool),
75            config: HttpDownloaderConfig {
76                connect_timeout: Duration::from_secs(config.http.connect_timeout),
77                read_timeout: Duration::from_secs(config.http.read_timeout),
78                max_redirects: config.http.max_redirects,
79                default_user_agent: config.user_agent.clone(),
80            },
81            retry_policy,
82        })
83    }
84
85    /// Get the underlying client
86    fn client(&self) -> &Client {
87        self.pool.client()
88    }
89
90    /// Get the retry policy
91    pub fn retry_policy(&self) -> &RetryPolicy {
92        &self.retry_policy
93    }
94
95    /// Download a file from a URL
96    ///
97    /// Returns the final path of the downloaded file
98    #[allow(clippy::too_many_arguments)]
99    pub async fn download<F>(
100        &self,
101        url: &str,
102        save_dir: &Path,
103        filename: Option<&str>,
104        user_agent: Option<&str>,
105        referer: Option<&str>,
106        headers: &[(String, String)],
107        cookies: Option<&[String]>,
108        checksum: Option<&ExpectedChecksum>,
109        cancel_token: CancellationToken,
110        progress_callback: F,
111    ) -> Result<PathBuf>
112    where
113        F: Fn(DownloadProgress) + Send + 'static,
114    {
115        // Build the request
116        let mut request = self.client().get(url);
117
118        // Set user agent
119        let ua = user_agent.unwrap_or(&self.config.default_user_agent);
120        request = request.header("User-Agent", ua);
121
122        // Set referer if provided
123        if let Some(ref_url) = referer {
124            request = request.header("Referer", ref_url);
125        }
126
127        // Add custom headers
128        for (name, value) in headers {
129            request = request.header(name.as_str(), value.as_str());
130        }
131
132        // Add cookies if provided
133        if let Some(cookie_list) = cookies {
134            if !cookie_list.is_empty() {
135                let cookie_value = cookie_list.join("; ");
136                request = request.header("Cookie", cookie_value);
137            }
138        }
139
140        // Send HEAD request first to get metadata
141        let mut head_request = self.client().head(url).header("User-Agent", ua);
142        if let Some(cookie_list) = cookies {
143            if !cookie_list.is_empty() {
144                head_request = head_request.header("Cookie", cookie_list.join("; "));
145            }
146        }
147        let head_response = head_request.send().await;
148
149        let (content_length, supports_range, suggested_filename) = match head_response {
150            Ok(resp) => {
151                let length = resp
152                    .headers()
153                    .get("content-length")
154                    .and_then(|v| v.to_str().ok())
155                    .and_then(|s| s.parse::<u64>().ok());
156
157                let supports_range = resp
158                    .headers()
159                    .get("accept-ranges")
160                    .and_then(|v| v.to_str().ok())
161                    .map(|v| v.contains("bytes"))
162                    .unwrap_or(false);
163
164                // Try to get filename from Content-Disposition
165                let suggested = resp
166                    .headers()
167                    .get("content-disposition")
168                    .and_then(|v| v.to_str().ok())
169                    .and_then(parse_content_disposition);
170
171                (length, supports_range, suggested)
172            }
173            Err(_) => {
174                // HEAD failed, we'll get metadata from GET response
175                (None, false, None)
176            }
177        };
178
179        // Check for cancellation
180        if cancel_token.is_cancelled() {
181            return Err(EngineError::Shutdown);
182        }
183
184        // Determine filename
185        let final_filename = filename
186            .map(|s| s.to_string())
187            .or(suggested_filename)
188            .or_else(|| extract_filename_from_url(url))
189            .unwrap_or_else(|| "download".to_string());
190
191        // Ensure save directory exists
192        if !save_dir.exists() {
193            tokio::fs::create_dir_all(save_dir).await.map_err(|e| {
194                EngineError::storage(
195                    StorageErrorKind::Io,
196                    save_dir,
197                    format!("Failed to create directory: {}", e),
198                )
199            })?;
200        }
201
202        // Validate filename for path traversal attacks (security)
203        // Check each path component to prevent directory traversal
204        use std::path::Component;
205        for component in Path::new(&final_filename).components() {
206            match component {
207                Component::ParentDir => {
208                    return Err(EngineError::storage(
209                        StorageErrorKind::PathTraversal,
210                        Path::new(&final_filename),
211                        "Invalid filename: contains parent directory reference (..)",
212                    ));
213                }
214                Component::RootDir | Component::Prefix(_) => {
215                    return Err(EngineError::storage(
216                        StorageErrorKind::PathTraversal,
217                        Path::new(&final_filename),
218                        "Invalid filename: contains absolute path",
219                    ));
220                }
221                _ => {}
222            }
223        }
224
225        let save_path = save_dir.join(&final_filename);
226
227        // Use .part extension during download
228        let part_path = save_path.with_extension(
229            save_path
230                .extension()
231                .map(|e| format!("{}.part", e.to_string_lossy()))
232                .unwrap_or_else(|| "part".to_string()),
233        );
234
235        // Check if we can resume
236        let existing_size = if supports_range && part_path.exists() {
237            tokio::fs::metadata(&part_path)
238                .await
239                .map(|m| m.len())
240                .unwrap_or(0)
241        } else {
242            0
243        };
244
245        // Add Range header if resuming
246        if existing_size > 0 {
247            request = request.header("Range", format!("bytes={}-", existing_size));
248        }
249
250        // Send the request
251        let response = request.send().await?;
252
253        // Check response status
254        let status = response.status();
255        if !status.is_success() && status != reqwest::StatusCode::PARTIAL_CONTENT {
256            return Err(EngineError::network(
257                NetworkErrorKind::HttpStatus(status.as_u16()),
258                format!("HTTP error: {}", status),
259            ));
260        }
261
262        // Get actual content length from response if not from HEAD
263        let total_size = content_length.or_else(|| {
264            response
265                .headers()
266                .get("content-length")
267                .and_then(|v| v.to_str().ok())
268                .and_then(|s| s.parse::<u64>().ok())
269                .map(|len| len + existing_size)
270        });
271
272        // Open file for writing
273        let file = if existing_size > 0 && status == reqwest::StatusCode::PARTIAL_CONTENT {
274            // Append mode for resume
275            OpenOptions::new()
276                .write(true)
277                .append(true)
278                .open(&part_path)
279                .await
280                .map_err(|e| {
281                    EngineError::storage(
282                        StorageErrorKind::Io,
283                        &part_path,
284                        format!("Failed to open file for append: {}", e),
285                    )
286                })?
287        } else {
288            // Create new file
289            File::create(&part_path).await.map_err(|e| {
290                EngineError::storage(
291                    StorageErrorKind::Io,
292                    &part_path,
293                    format!("Failed to create file: {}", e),
294                )
295            })?
296        };
297
298        // Download with progress tracking
299        let downloaded = Arc::new(AtomicU64::new(existing_size));
300
301        // Stream the response body
302        let result = self
303            .stream_to_file(
304                response,
305                file,
306                downloaded.clone(),
307                total_size,
308                cancel_token.clone(),
309                move |completed, speed| {
310                    progress_callback(DownloadProgress {
311                        total_size,
312                        completed_size: completed,
313                        download_speed: speed,
314                        upload_speed: 0,
315                        connections: 1,
316                        seeders: 0,
317                        peers: 0,
318                        eta_seconds: total_size.and_then(|total| {
319                            if speed > 0 {
320                                Some((total.saturating_sub(completed)) / speed)
321                            } else {
322                                None
323                            }
324                        }),
325                    });
326                },
327            )
328            .await;
329
330        match result {
331            Ok(_) => {
332                // Verify checksum before renaming (if checksum was provided)
333                if let Some(expected) = checksum {
334                    let verified = verify_checksum(&part_path, expected).await?;
335                    if !verified {
336                        let actual = compute_checksum(&part_path, expected.algorithm).await?;
337                        return Err(checksum::checksum_mismatch_error(&expected.value, &actual));
338                    }
339                    tracing::debug!("Checksum verified: {} matches expected", expected.algorithm);
340                }
341
342                // Rename .part file to final name
343                tokio::fs::rename(&part_path, &save_path)
344                    .await
345                    .map_err(|e| {
346                        EngineError::storage(
347                            StorageErrorKind::Io,
348                            &save_path,
349                            format!("Failed to rename file: {}", e),
350                        )
351                    })?;
352
353                Ok(save_path)
354            }
355            Err(e) => {
356                // Keep .part file for potential resume
357                Err(e)
358            }
359        }
360    }
361
362    /// Stream response body to file with progress tracking
363    async fn stream_to_file<F>(
364        &self,
365        response: Response,
366        mut file: File,
367        downloaded: Arc<AtomicU64>,
368        total_size: Option<u64>,
369        cancel_token: CancellationToken,
370        progress_callback: F,
371    ) -> Result<()>
372    where
373        F: Fn(u64, u64) + Send,
374    {
375        let mut stream = response.bytes_stream();
376        let mut last_update = Instant::now();
377        let mut bytes_since_update: u64 = 0;
378        let update_interval = Duration::from_millis(250); // Update progress 4 times per second
379
380        while let Some(chunk_result) = tokio::select! {
381            chunk = stream.next() => chunk,
382            _ = cancel_token.cancelled() => {
383                file.flush().await.ok();
384                return Err(EngineError::Shutdown);
385            }
386        } {
387            let chunk: bytes::Bytes = chunk_result.map_err(|e: reqwest::Error| {
388                EngineError::network(NetworkErrorKind::Other, format!("Stream error: {}", e))
389            })?;
390
391            let chunk_len = chunk.len() as u64;
392
393            // Apply rate limiting if configured
394            self.pool.acquire_download(chunk_len).await;
395
396            // Write chunk to file
397            file.write_all(&chunk).await.map_err(|e| {
398                EngineError::storage(
399                    StorageErrorKind::Io,
400                    PathBuf::new(),
401                    format!("Failed to write: {}", e),
402                )
403            })?;
404
405            // Record downloaded bytes for stats
406            self.pool.record_download(chunk_len);
407
408            // Update counters
409            let new_total = downloaded.fetch_add(chunk_len, Ordering::Relaxed) + chunk_len;
410            bytes_since_update += chunk_len;
411
412            // Emit progress at intervals
413            let now = Instant::now();
414            if now.duration_since(last_update) >= update_interval {
415                let elapsed_secs = now.duration_since(last_update).as_secs_f64();
416                let speed = if elapsed_secs > 0.0 {
417                    (bytes_since_update as f64 / elapsed_secs) as u64
418                } else {
419                    0
420                };
421
422                progress_callback(new_total, speed);
423
424                last_update = now;
425                bytes_since_update = 0;
426            }
427        }
428
429        // Flush and sync
430        file.flush().await.map_err(|e| {
431            EngineError::storage(
432                StorageErrorKind::Io,
433                PathBuf::new(),
434                format!("Failed to flush: {}", e),
435            )
436        })?;
437
438        file.sync_all().await.map_err(|e| {
439            EngineError::storage(
440                StorageErrorKind::Io,
441                PathBuf::new(),
442                format!("Failed to sync: {}", e),
443            )
444        })?;
445
446        // Final progress update
447        let final_size = downloaded.load(Ordering::Relaxed);
448        progress_callback(final_size, 0);
449
450        // Validate received size matches expected (if known)
451        if let Some(expected) = total_size {
452            if final_size < expected {
453                return Err(EngineError::network(
454                    NetworkErrorKind::Other,
455                    format!(
456                        "Incomplete download: received {} bytes, expected {} bytes",
457                        final_size, expected
458                    ),
459                ));
460            }
461        }
462
463        Ok(())
464    }
465
466    /// Download a file using multiple connections (segmented download)
467    ///
468    /// This method probes the server first and uses segmented downloads
469    /// if the server supports Range requests and the file is large enough.
470    #[allow(clippy::too_many_arguments)]
471    /// Download with segmented multi-connection support.
472    ///
473    /// Returns the final path and optionally an Arc reference to the SegmentedDownload
474    /// (only when using segmented download mode).
475    pub async fn download_segmented<F>(
476        &self,
477        url: &str,
478        save_dir: &Path,
479        filename: Option<&str>,
480        user_agent: Option<&str>,
481        referer: Option<&str>,
482        headers: &[(String, String)],
483        cookies: Option<&[String]>,
484        checksum: Option<&ExpectedChecksum>,
485        max_connections: usize,
486        min_segment_size: u64,
487        cancel_token: CancellationToken,
488        saved_segments: Option<Vec<Segment>>,
489        progress_callback: F,
490        segmented_ref: Option<Arc<RwLock<Option<Arc<SegmentedDownload>>>>>,
491    ) -> Result<(PathBuf, Option<Arc<SegmentedDownload>>)>
492    where
493        F: Fn(DownloadProgress) + Send + Sync + 'static,
494    {
495        let ua = user_agent.unwrap_or(&self.config.default_user_agent);
496
497        // Probe server capabilities
498        let capabilities = probe_server(self.client(), url, ua).await?;
499
500        // Determine filename
501        let final_filename = filename
502            .map(|s| s.to_string())
503            .or(capabilities.suggested_filename.clone())
504            .or_else(|| extract_filename_from_url(url))
505            .unwrap_or_else(|| "download".to_string());
506
507        // Ensure save directory exists
508        if !save_dir.exists() {
509            tokio::fs::create_dir_all(save_dir).await.map_err(|e| {
510                EngineError::storage(
511                    StorageErrorKind::Io,
512                    save_dir,
513                    format!("Failed to create directory: {}", e),
514                )
515            })?;
516        }
517
518        let save_path = save_dir.join(&final_filename);
519
520        // Decide whether to use segmented download
521        let use_segmented = capabilities.supports_range
522            && capabilities
523                .content_length
524                .map(|l| l > min_segment_size)
525                .unwrap_or(false);
526
527        if use_segmented {
528            let total_size = capabilities.content_length.unwrap();
529
530            // Create segmented download
531            let mut download = SegmentedDownload::new(
532                url.to_string(),
533                total_size,
534                save_path.clone(),
535                true,
536                capabilities.etag,
537                capabilities.last_modified,
538            );
539
540            // Restore or initialize segments
541            if let Some(segments) = saved_segments {
542                tracing::debug!("Restoring {} saved segments", segments.len());
543                download.restore_segments(segments);
544            } else {
545                download.init_segments(max_connections, min_segment_size);
546            }
547
548            // Wrap in Arc for sharing
549            let download = Arc::new(download);
550            let download_ref = Arc::clone(&download);
551
552            // Populate shared reference for external access during download (for persistence)
553            if let Some(ref slot) = segmented_ref {
554                *slot.write() = Some(Arc::clone(&download));
555            }
556
557            // Build headers vec
558            let mut all_headers = headers.to_vec();
559            if let Some(r) = referer {
560                all_headers.push(("Referer".to_string(), r.to_string()));
561            }
562            // Add cookies to headers
563            if let Some(cookie_list) = cookies {
564                if !cookie_list.is_empty() {
565                    all_headers.push(("Cookie".to_string(), cookie_list.join("; ")));
566                }
567            }
568
569            // Start download
570            download
571                .start(
572                    self.client(),
573                    ua,
574                    &all_headers,
575                    max_connections,
576                    cancel_token,
577                    progress_callback,
578                )
579                .await?;
580
581            // Verify checksum if provided
582            if let Some(expected) = checksum {
583                let verified = verify_checksum(&save_path, expected).await?;
584                if !verified {
585                    let actual = compute_checksum(&save_path, expected.algorithm).await?;
586                    return Err(checksum::checksum_mismatch_error(&expected.value, &actual));
587                }
588                tracing::debug!("Checksum verified: {} matches expected", expected.algorithm);
589            }
590
591            Ok((save_path, Some(download_ref)))
592        } else {
593            // Fall back to single-connection download
594            let path = self
595                .download(
596                    url,
597                    save_dir,
598                    Some(&final_filename),
599                    user_agent,
600                    referer,
601                    headers,
602                    cookies,
603                    checksum,
604                    cancel_token,
605                    progress_callback,
606                )
607                .await?;
608            Ok((path, None))
609        }
610    }
611}
612
613/// Parse filename from Content-Disposition header
614fn parse_content_disposition(header: &str) -> Option<String> {
615    // Look for filename="..." or filename*=UTF-8''...
616    if let Some(start) = header.find("filename=") {
617        let rest = &header[start + 9..];
618        if let Some(stripped) = rest.strip_prefix('"') {
619            // Quoted filename
620            let end = stripped.find('"')?;
621            return Some(stripped[..end].to_string());
622        } else {
623            // Unquoted filename
624            let end = rest.find(';').unwrap_or(rest.len());
625            return Some(rest[..end].trim().to_string());
626        }
627    }
628
629    if let Some(start) = header.find("filename*=") {
630        let rest = &header[start + 10..];
631        // UTF-8'' prefix
632        if let Some(quote_start) = rest.find("''") {
633            let encoded = &rest[quote_start + 2..];
634            let end = encoded.find(';').unwrap_or(encoded.len());
635            // URL decode
636            if let Ok(decoded) = urlencoding::decode(&encoded[..end]) {
637                return Some(decoded.to_string());
638            }
639        }
640    }
641
642    None
643}
644
645/// Extract filename from URL path
646fn extract_filename_from_url(url: &str) -> Option<String> {
647    url::Url::parse(url)
648        .ok()?
649        .path_segments()?
650        .next_back()
651        .filter(|s| !s.is_empty())
652        .map(|s| {
653            // URL decode the filename
654            urlencoding::decode(s)
655                .map(|d| d.to_string())
656                .unwrap_or_else(|_| s.to_string())
657        })
658}
659
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_content_disposition() {
        // Quoted filename
        assert_eq!(
            parse_content_disposition("attachment; filename=\"test.zip\""),
            Some("test.zip".to_string())
        );

        // Unquoted filename
        assert_eq!(
            parse_content_disposition("attachment; filename=test.zip"),
            Some("test.zip".to_string())
        );

        // Unquoted filename followed by another parameter
        assert_eq!(
            parse_content_disposition("attachment; filename=test.zip; size=100"),
            Some("test.zip".to_string())
        );

        // RFC 5987 extended form with percent-encoding
        assert_eq!(
            parse_content_disposition("attachment; filename*=UTF-8''na%C3%AFve.txt"),
            Some("na\u{ef}ve.txt".to_string())
        );

        // No filename parameter at all
        assert_eq!(parse_content_disposition("inline"), None);
    }

    #[test]
    fn test_extract_filename_from_url() {
        assert_eq!(
            extract_filename_from_url("https://example.com/path/to/file.zip"),
            Some("file.zip".to_string())
        );

        // Percent-encoded names are decoded
        assert_eq!(
            extract_filename_from_url("https://example.com/path/to/file%20name.zip"),
            Some("file name.zip".to_string())
        );

        // Trailing slash -> no usable filename
        assert_eq!(extract_filename_from_url("https://example.com/dir/"), None);

        // Query strings are not part of the path, hence not of the filename
        assert_eq!(
            extract_filename_from_url("https://example.com/a/file.zip?token=abc"),
            Some("file.zip".to_string())
        );
    }
}