gosh_dl/http/
segment.rs

1//! Segmented Download Support
2//!
3//! This module provides multi-connection segmented downloads for faster
4//! HTTP/HTTPS transfers. It splits files into segments and downloads
5//! them in parallel using multiple connections.
6
7use crate::error::{EngineError, NetworkErrorKind, Result, StorageErrorKind};
8use crate::storage::Segment;
9use crate::types::DownloadProgress;
10
11use bytes::Bytes;
12use futures::stream::StreamExt;
13use parking_lot::RwLock;
14use reqwest::Client;
15use std::path::PathBuf;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tokio::fs::{File, OpenOptions};
20use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
21use tokio::sync::Semaphore;
22use tokio_util::sync::CancellationToken;
23
/// Smallest segment size, in bytes (1 MiB).
pub const MIN_SEGMENT_SIZE: u64 = 1 << 20;

/// Default number of parallel connections for one download.
pub const DEFAULT_CONNECTIONS: usize = 16;

/// Minimum time between two progress callbacks.
const PROGRESS_INTERVAL: Duration = Duration::from_millis(250);

/// Minimum time between two persistence passes over the segment state.
const PERSISTENCE_INTERVAL: Duration = Duration::from_secs(5);
35
/// Shared state for a segmented download
///
/// A single instance is held in an `Arc` by [`SegmentedDownload`] and cloned
/// into every segment task, so all fields are either atomics or sit behind a
/// lock.
struct SharedState {
    /// Total bytes downloaded across all segments
    downloaded: AtomicU64,
    /// Current download speed (bytes/sec)
    speed: AtomicU64,
    /// Number of active connections
    active_connections: AtomicU64,
    /// Whether download is paused (set by `pause`, polled by segment tasks)
    paused: AtomicBool,
    /// Per-segment downloaded bytes (for tracking progress), indexed by
    /// segment position in `SegmentedDownload::segments`
    segment_progress: RwLock<Vec<u64>>,
    /// Last persistence time (compared against `PERSISTENCE_INTERVAL`)
    last_persistence: RwLock<Instant>,
}
51
/// Segmented download manager
///
/// Holds the immutable description of one download (URL, size, target path,
/// validators) together with its segment list and the state shared with the
/// tasks spawned by `start`.
pub struct SegmentedDownload {
    /// URL to download from
    url: String,
    /// Total file size in bytes
    total_size: u64,
    /// Path to save the file (data is written to a `.part` sibling first)
    save_path: PathBuf,
    /// Segments, as produced by `init_segments` or `restore_segments`
    segments: Vec<Segment>,
    /// Whether server supports range requests (stored for resume validation)
    #[allow(dead_code)]
    supports_range: bool,
    /// ETag for validation (sent as `If-Range` on segment requests)
    etag: Option<String>,
    /// Last-Modified for validation (stored for resume validation)
    #[allow(dead_code)]
    last_modified: Option<String>,
    /// Shared state (wrapped in Arc for task sharing)
    state: Arc<SharedState>,
}
73
/// Server capabilities determined from a HEAD request.
///
/// Plain data: every field is taken from one response header and is `None`
/// (or `false`) when that header was missing or unreadable. `Default` yields
/// exactly that "nothing known" value, and `PartialEq`/`Eq` allow two probe
/// results to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerCapabilities {
    /// Content-Length header value
    pub content_length: Option<u64>,
    /// Whether server supports Range requests
    pub supports_range: bool,
    /// ETag header for validation
    pub etag: Option<String>,
    /// Last-Modified header for validation
    pub last_modified: Option<String>,
    /// Suggested filename from Content-Disposition
    pub suggested_filename: Option<String>,
}
88
89impl SegmentedDownload {
90    /// Create a new segmented download
91    pub fn new(
92        url: String,
93        total_size: u64,
94        save_path: PathBuf,
95        supports_range: bool,
96        etag: Option<String>,
97        last_modified: Option<String>,
98    ) -> Self {
99        Self {
100            url,
101            total_size,
102            save_path,
103            segments: Vec::new(),
104            supports_range,
105            etag,
106            last_modified,
107            state: Arc::new(SharedState {
108                downloaded: AtomicU64::new(0),
109                speed: AtomicU64::new(0),
110                active_connections: AtomicU64::new(0),
111                paused: AtomicBool::new(false),
112                segment_progress: RwLock::new(Vec::new()),
113                last_persistence: RwLock::new(Instant::now()),
114            }),
115        }
116    }
117
118    /// Initialize segments for a new download
119    pub fn init_segments(&mut self, max_connections: usize, min_segment_size: u64) {
120        let num_segments =
121            calculate_segment_count(self.total_size, max_connections, min_segment_size);
122        let segment_size = self.total_size / num_segments as u64;
123
124        let mut segments = Vec::with_capacity(num_segments);
125        for i in 0..num_segments {
126            let start = i as u64 * segment_size;
127            let end = if i == num_segments - 1 {
128                self.total_size - 1
129            } else {
130                (i as u64 + 1) * segment_size - 1
131            };
132            segments.push(Segment::new(i, start, end));
133        }
134
135        // Initialize segment progress tracking
136        *self.state.segment_progress.write() = vec![0u64; num_segments];
137
138        self.segments = segments;
139    }
140
141    /// Restore segments from saved state
142    pub fn restore_segments(&mut self, saved_segments: Vec<Segment>) {
143        // Calculate total already downloaded
144        let downloaded: u64 = saved_segments.iter().map(|s| s.downloaded).sum();
145        self.state.downloaded.store(downloaded, Ordering::Relaxed);
146
147        // Initialize segment progress tracking with saved values
148        let progress: Vec<u64> = saved_segments.iter().map(|s| s.downloaded).collect();
149        *self.state.segment_progress.write() = progress;
150
151        self.segments = saved_segments;
152    }
153
154    /// Get current segments
155    pub fn segments(&self) -> &[Segment] {
156        &self.segments
157    }
158
159    /// Get segments with current progress updated
160    ///
161    /// This creates a snapshot of the current segment state for persistence.
162    pub fn segments_with_progress(&self) -> Vec<Segment> {
163        let progress = self.state.segment_progress.read();
164        self.segments
165            .iter()
166            .enumerate()
167            .map(|(idx, s)| {
168                let mut segment = s.clone();
169                if let Some(&downloaded) = progress.get(idx) {
170                    segment.downloaded = downloaded;
171                    if segment.downloaded >= segment.size() {
172                        segment.state = crate::storage::SegmentState::Completed;
173                    } else if segment.downloaded > 0 {
174                        segment.state = crate::storage::SegmentState::Downloading;
175                    }
176                }
177                segment
178            })
179            .collect()
180    }
181
182    /// Start the segmented download
183    pub async fn start<F>(
184        &self,
185        client: &Client,
186        user_agent: &str,
187        headers: &[(String, String)],
188        max_connections: usize,
189        cancel_token: CancellationToken,
190        progress_callback: F,
191    ) -> Result<()>
192    where
193        F: Fn(DownloadProgress) + Send + Sync + 'static,
194    {
195        // Create/open the file and pre-allocate space
196        let file = self.prepare_file().await?;
197        let file = Arc::new(tokio::sync::Mutex::new(file));
198
199        // Create semaphore for connection limiting
200        let semaphore = Arc::new(Semaphore::new(max_connections));
201
202        // Shared state for progress tracking
203        let progress_callback = Arc::new(progress_callback);
204        let last_progress = Arc::new(RwLock::new(Instant::now()));
205        let bytes_since_progress = Arc::new(AtomicU64::new(0));
206
207        // Clone segments data for tasks
208        let segments_data: Vec<_> = self
209            .segments
210            .iter()
211            .enumerate()
212            .filter(|(_, s)| !s.is_complete())
213            .map(|(idx, s)| (idx, s.start, s.end, s.downloaded))
214            .collect();
215
216        // Spawn tasks for each pending segment
217        let mut handles = Vec::new();
218
219        for (segment_idx, start, end, already_downloaded) in segments_data {
220            let client = client.clone();
221            let url = self.url.clone();
222            let user_agent = user_agent.to_string();
223            let headers = headers.to_vec();
224            let file = Arc::clone(&file);
225            let semaphore = Arc::clone(&semaphore);
226            let cancel_token = cancel_token.clone();
227            let etag = self.etag.clone();
228            let state = Arc::clone(&self.state);
229            let progress_callback = Arc::clone(&progress_callback);
230            let last_progress = Arc::clone(&last_progress);
231            let bytes_since_progress = Arc::clone(&bytes_since_progress);
232            let total_size = self.total_size;
233
234            let handle = tokio::spawn(async move {
235                // Acquire permit
236                let _permit = semaphore
237                    .acquire()
238                    .await
239                    .map_err(|_| EngineError::Shutdown)?;
240
241                // Check cancellation
242                if cancel_token.is_cancelled() {
243                    return Ok(());
244                }
245
246                // Check if paused
247                if state.paused.load(Ordering::Relaxed) {
248                    return Ok(());
249                }
250
251                state.active_connections.fetch_add(1, Ordering::Relaxed);
252
253                // Adjusted start position for resume
254                let resume_start = start + already_downloaded;
255                if resume_start > end {
256                    // Already complete
257                    state.active_connections.fetch_sub(1, Ordering::Relaxed);
258                    return Ok(());
259                }
260
261                // Build request with Range header
262                let mut request = client.get(&url);
263                request = request.header("User-Agent", &user_agent);
264                request = request.header("Range", format!("bytes={}-{}", resume_start, end));
265
266                // Add ETag for validation if available
267                if let Some(ref etag_val) = etag {
268                    request = request.header("If-Range", etag_val);
269                }
270
271                // Add custom headers
272                for (name, value) in &headers {
273                    request = request.header(name.as_str(), value.as_str());
274                }
275
276                // Send request
277                let response = request.send().await.map_err(|e| {
278                    EngineError::network(
279                        NetworkErrorKind::Other,
280                        format!("Segment {} request failed: {}", segment_idx, e),
281                    )
282                })?;
283
284                let status = response.status();
285
286                // Handle 416 Range Not Satisfiable - file may have changed on server
287                if status == reqwest::StatusCode::RANGE_NOT_SATISFIABLE {
288                    state.active_connections.fetch_sub(1, Ordering::Relaxed);
289                    return Err(EngineError::network(
290                        NetworkErrorKind::HttpStatus(416),
291                        format!(
292                            "Segment {} range not satisfiable (file may have changed on server)",
293                            segment_idx
294                        ),
295                    ));
296                }
297
298                if !status.is_success() && status != reqwest::StatusCode::PARTIAL_CONTENT {
299                    state.active_connections.fetch_sub(1, Ordering::Relaxed);
300                    return Err(EngineError::network(
301                        NetworkErrorKind::HttpStatus(status.as_u16()),
302                        format!("Segment {} HTTP error: {}", segment_idx, status),
303                    ));
304                }
305
306                // Validate Content-Range header matches our request (security check)
307                if status == reqwest::StatusCode::PARTIAL_CONTENT {
308                    if let Some(content_range) = response.headers().get("content-range") {
309                        if let Ok(range_str) = content_range.to_str() {
310                            // Expected format: "bytes START-END/TOTAL" or "bytes START-END/*"
311                            if let Some(range_part) = range_str.strip_prefix("bytes ") {
312                                if let Some((range, _)) = range_part.split_once('/') {
313                                    if let Some((start_str, end_str)) = range.split_once('-') {
314                                        let range_start: u64 = start_str.parse().unwrap_or(0);
315                                        let range_end: u64 = end_str.parse().unwrap_or(0);
316
317                                        // Verify the server is sending the range we requested
318                                        if range_start != resume_start || range_end != end {
319                                            state
320                                                .active_connections
321                                                .fetch_sub(1, Ordering::Relaxed);
322                                            return Err(EngineError::network(
323                                                NetworkErrorKind::Other,
324                                                format!(
325                                                    "Segment {} Content-Range mismatch: requested {}-{}, got {}-{}",
326                                                    segment_idx, resume_start, end, range_start, range_end
327                                                ),
328                                            ));
329                                        }
330                                    }
331                                }
332                            }
333                        }
334                    }
335                }
336
337                // Stream data to file
338                let mut stream = response.bytes_stream();
339                let mut segment_bytes: u64 = already_downloaded;
340                let mut last_speed_update = Instant::now();
341                let mut bytes_for_speed: u64 = 0;
342
343                while let Some(chunk_result) = tokio::select! {
344                    chunk = stream.next() => chunk,
345                    _ = cancel_token.cancelled() => None,
346                } {
347                    // Check pause
348                    if state.paused.load(Ordering::Relaxed) {
349                        break;
350                    }
351
352                    let chunk: Bytes = match chunk_result {
353                        Ok(c) => c,
354                        Err(e) => {
355                            state.active_connections.fetch_sub(1, Ordering::Relaxed);
356                            return Err(EngineError::network(
357                                NetworkErrorKind::Other,
358                                format!("Segment {} stream error: {}", segment_idx, e),
359                            ));
360                        }
361                    };
362
363                    let chunk_len = chunk.len() as u64;
364
365                    // Write to file at correct offset
366                    {
367                        let mut file = file.lock().await;
368                        file.seek(SeekFrom::Start(start + segment_bytes))
369                            .await
370                            .map_err(|e| {
371                                EngineError::storage(
372                                    StorageErrorKind::Io,
373                                    PathBuf::new(),
374                                    format!("Seek failed: {}", e),
375                                )
376                            })?;
377                        file.write_all(&chunk).await.map_err(|e| {
378                            EngineError::storage(
379                                StorageErrorKind::Io,
380                                PathBuf::new(),
381                                format!("Write failed: {}", e),
382                            )
383                        })?;
384                    }
385
386                    segment_bytes += chunk_len;
387
388                    // Update segment progress for persistence
389                    {
390                        let mut progress = state.segment_progress.write();
391                        if let Some(p) = progress.get_mut(segment_idx) {
392                            *p = segment_bytes;
393                        }
394                    }
395
396                    // Update global counters
397                    state.downloaded.fetch_add(chunk_len, Ordering::Relaxed);
398                    bytes_since_progress.fetch_add(chunk_len, Ordering::Relaxed);
399                    bytes_for_speed += chunk_len;
400
401                    // Update speed calculation
402                    let now = Instant::now();
403                    let speed_elapsed = now.duration_since(last_speed_update);
404                    if speed_elapsed >= Duration::from_millis(500) {
405                        let current_speed =
406                            (bytes_for_speed as f64 / speed_elapsed.as_secs_f64()) as u64;
407                        state.speed.store(current_speed, Ordering::Relaxed);
408                        bytes_for_speed = 0;
409                        last_speed_update = now;
410                    }
411
412                    // Emit progress at intervals
413                    // Calculate values and check if we should emit, then release lock before callback
414                    let should_emit = {
415                        let mut last = last_progress.write();
416                        if now.duration_since(*last) >= PROGRESS_INTERVAL {
417                            *last = now;
418                            bytes_since_progress.store(0, Ordering::Relaxed);
419                            true
420                        } else {
421                            false
422                        }
423                    };
424
425                    if should_emit {
426                        let total_downloaded = state.downloaded.load(Ordering::Relaxed);
427                        let current_speed = state.speed.load(Ordering::Relaxed);
428                        let connections = state.active_connections.load(Ordering::Relaxed) as u32;
429
430                        progress_callback(DownloadProgress {
431                            total_size: Some(total_size),
432                            completed_size: total_downloaded,
433                            download_speed: current_speed,
434                            upload_speed: 0,
435                            connections,
436                            seeders: 0,
437                            peers: 0,
438                            eta_seconds: if current_speed > 0 {
439                                Some((total_size.saturating_sub(total_downloaded)) / current_speed)
440                            } else {
441                                None
442                            },
443                        });
444                    }
445                }
446
447                state.active_connections.fetch_sub(1, Ordering::Relaxed);
448
449                // Segment task completed (either fully or paused/cancelled)
450                Result::<()>::Ok(())
451            });
452
453            handles.push(handle);
454        }
455
456        // Wait for all segment tasks to complete and collect errors
457        let mut segment_errors: Vec<String> = Vec::new();
458        for (idx, handle) in handles.into_iter().enumerate() {
459            match handle.await {
460                Err(e) => {
461                    // Task panicked
462                    tracing::error!("Segment {} task panicked: {:?}", idx, e);
463                    segment_errors.push(format!("Segment {} panicked: {:?}", idx, e));
464                }
465                Ok(Err(e)) => {
466                    // Task returned an error
467                    tracing::error!("Segment {} failed: {:?}", idx, e);
468                    segment_errors.push(format!("Segment {} failed: {}", idx, e));
469                }
470                Ok(Ok(())) => {
471                    // Task completed successfully
472                }
473            }
474        }
475
476        // If any segments failed, return error
477        if !segment_errors.is_empty() {
478            return Err(EngineError::network(
479                NetworkErrorKind::Other,
480                format!(
481                    "Download failed: {} segment(s) failed: {}",
482                    segment_errors.len(),
483                    segment_errors.join("; ")
484                ),
485            ));
486        }
487
488        // Sync file to disk
489        {
490            let mut file = file.lock().await;
491            file.flush().await.map_err(|e| {
492                EngineError::storage(
493                    StorageErrorKind::Io,
494                    &self.save_path,
495                    format!("Flush failed: {}", e),
496                )
497            })?;
498            file.sync_all().await.map_err(|e| {
499                EngineError::storage(
500                    StorageErrorKind::Io,
501                    &self.save_path,
502                    format!("Sync failed: {}", e),
503                )
504            })?;
505        }
506
507        // Final progress update
508        let total_downloaded = self.state.downloaded.load(Ordering::Relaxed);
509        progress_callback(DownloadProgress {
510            total_size: Some(self.total_size),
511            completed_size: total_downloaded,
512            download_speed: 0,
513            upload_speed: 0,
514            connections: 0,
515            seeders: 0,
516            peers: 0,
517            eta_seconds: None,
518        });
519
520        // Check if complete
521        if total_downloaded >= self.total_size {
522            // Rename from .part to final name
523            self.finalize().await?;
524        }
525
526        Ok(())
527    }
528
529    /// Check if persistence is due based on the time interval.
530    ///
531    /// Returns true if enough time has passed since the last persistence,
532    /// and resets the timer if so.
533    pub fn should_persist(&self) -> bool {
534        let mut last = self.state.last_persistence.write();
535        let now = Instant::now();
536        if now.duration_since(*last) >= PERSISTENCE_INTERVAL {
537            *last = now;
538            true
539        } else {
540            false
541        }
542    }
543
544    /// Force mark persistence as done (call after successful save).
545    pub fn mark_persisted(&self) {
546        *self.state.last_persistence.write() = Instant::now();
547    }
548
549    /// Prepare the output file
550    async fn prepare_file(&self) -> Result<File> {
551        // Use .part extension during download
552        let part_path = self.part_path();
553
554        // Ensure parent directory exists
555        if let Some(parent) = part_path.parent() {
556            tokio::fs::create_dir_all(parent).await.map_err(|e| {
557                EngineError::storage(
558                    StorageErrorKind::Io,
559                    parent,
560                    format!("Create dir failed: {}", e),
561                )
562            })?;
563        }
564
565        // Check if file exists (for resume)
566        let file = if part_path.exists() {
567            OpenOptions::new()
568                .write(true)
569                .read(true)
570                .open(&part_path)
571                .await
572                .map_err(|e| {
573                    EngineError::storage(
574                        StorageErrorKind::Io,
575                        &part_path,
576                        format!("Open failed: {}", e),
577                    )
578                })?
579        } else {
580            // Create new file and pre-allocate
581            let file = File::create(&part_path).await.map_err(|e| {
582                EngineError::storage(
583                    StorageErrorKind::Io,
584                    &part_path,
585                    format!("Create failed: {}", e),
586                )
587            })?;
588
589            // Pre-allocate space
590            file.set_len(self.total_size).await.map_err(|e| {
591                EngineError::storage(
592                    StorageErrorKind::Io,
593                    &part_path,
594                    format!("Pre-allocate failed: {}", e),
595                )
596            })?;
597
598            file
599        };
600
601        Ok(file)
602    }
603
604    /// Get the .part file path
605    fn part_path(&self) -> PathBuf {
606        let ext = self
607            .save_path
608            .extension()
609            .map(|e| format!("{}.part", e.to_string_lossy()))
610            .unwrap_or_else(|| "part".to_string());
611        self.save_path.with_extension(ext)
612    }
613
614    /// Rename .part file to final name
615    async fn finalize(&self) -> Result<()> {
616        let part_path = self.part_path();
617        if part_path.exists() {
618            tokio::fs::rename(&part_path, &self.save_path)
619                .await
620                .map_err(|e| {
621                    EngineError::storage(
622                        StorageErrorKind::Io,
623                        &self.save_path,
624                        format!("Rename failed: {}", e),
625                    )
626                })?;
627        }
628        Ok(())
629    }
630
631    /// Pause the download
632    pub fn pause(&self) {
633        self.state.paused.store(true, Ordering::Relaxed);
634    }
635
636    /// Check if download is complete
637    pub fn is_complete(&self) -> bool {
638        self.state.downloaded.load(Ordering::Relaxed) >= self.total_size
639    }
640
641    /// Get current progress
642    pub fn progress(&self) -> DownloadProgress {
643        DownloadProgress {
644            total_size: Some(self.total_size),
645            completed_size: self.state.downloaded.load(Ordering::Relaxed),
646            download_speed: self.state.speed.load(Ordering::Relaxed),
647            upload_speed: 0,
648            connections: self.state.active_connections.load(Ordering::Relaxed) as u32,
649            seeders: 0,
650            peers: 0,
651            eta_seconds: {
652                let speed = self.state.speed.load(Ordering::Relaxed);
653                let remaining = self
654                    .total_size
655                    .saturating_sub(self.state.downloaded.load(Ordering::Relaxed));
656                if speed > 0 {
657                    Some(remaining / speed)
658                } else {
659                    None
660                }
661            },
662        }
663    }
664}
665
/// Calculate the number of segments for a file of `total_size` bytes.
///
/// The result is the smaller of `max_connections` and the number of whole
/// `min_segment_size` pieces that fit in the file, and is never less than 1.
/// An empty file yields 1.
///
/// A `min_segment_size` of zero is treated as one byte rather than dividing by
/// zero, and the size-based limit saturates at `usize::MAX` instead of being
/// truncated on targets where `usize` is narrower than `u64`.
pub fn calculate_segment_count(
    total_size: u64,
    max_connections: usize,
    min_segment_size: u64,
) -> usize {
    if total_size == 0 {
        return 1;
    }

    // Upper bound imposed by the minimum segment size.
    let pieces = total_size / min_segment_size.max(1);
    let max_segments_by_size = pieces.min(usize::MAX as u64) as usize;

    // Respect both limits, but always produce at least one segment.
    max_connections.min(max_segments_by_size).max(1)
}
685
686/// Probe server capabilities with a HEAD request
687pub async fn probe_server(
688    client: &Client,
689    url: &str,
690    user_agent: &str,
691) -> Result<ServerCapabilities> {
692    let response = client
693        .head(url)
694        .header("User-Agent", user_agent)
695        .send()
696        .await
697        .map_err(|e| {
698            EngineError::network(
699                NetworkErrorKind::Other,
700                format!("HEAD request failed: {}", e),
701            )
702        })?;
703
704    if !response.status().is_success() {
705        return Err(EngineError::network(
706            NetworkErrorKind::HttpStatus(response.status().as_u16()),
707            format!("HEAD request returned: {}", response.status()),
708        ));
709    }
710
711    let headers = response.headers();
712
713    let content_length = headers
714        .get("content-length")
715        .and_then(|v| v.to_str().ok())
716        .and_then(|s| s.parse::<u64>().ok());
717
718    let supports_range = headers
719        .get("accept-ranges")
720        .and_then(|v| v.to_str().ok())
721        .map(|v| v.contains("bytes"))
722        .unwrap_or(false);
723
724    let etag = headers
725        .get("etag")
726        .and_then(|v| v.to_str().ok())
727        .map(|s| s.to_string());
728
729    let last_modified = headers
730        .get("last-modified")
731        .and_then(|v| v.to_str().ok())
732        .map(|s| s.to_string());
733
734    let suggested_filename = headers
735        .get("content-disposition")
736        .and_then(|v| v.to_str().ok())
737        .and_then(parse_content_disposition);
738
739    Ok(ServerCapabilities {
740        content_length,
741        supports_range,
742        etag,
743        last_modified,
744        suggested_filename,
745    })
746}
747
748/// Parse filename from Content-Disposition header
749fn parse_content_disposition(header: &str) -> Option<String> {
750    // Look for filename="..." or filename*=UTF-8''...
751    if let Some(start) = header.find("filename=") {
752        let rest = &header[start + 9..];
753        if let Some(stripped) = rest.strip_prefix('"') {
754            let end = stripped.find('"')?;
755            return Some(stripped[..end].to_string());
756        } else {
757            let end = rest.find(';').unwrap_or(rest.len());
758            return Some(rest[..end].trim().to_string());
759        }
760    }
761
762    if let Some(start) = header.find("filename*=") {
763        let rest = &header[start + 10..];
764        if let Some(quote_start) = rest.find("''") {
765            let encoded = &rest[quote_start + 2..];
766            let end = encoded.find(';').unwrap_or(encoded.len());
767            if let Ok(decoded) = urlencoding::decode(&encoded[..end]) {
768                return Some(decoded.to_string());
769            }
770        }
771    }
772
773    None
774}
775
#[cfg(test)]
mod tests {
    use super::*;

    const MIB: u64 = 1024 * 1024;

    /// Segment count is bounded by both the connection limit and the minimum segment size.
    #[test]
    fn test_calculate_segment_count() {
        // (total_size, max_connections, min_segment_size, expected)
        let cases: [(u64, usize, u64, usize); 5] = [
            // 100MB file, 16 connections, 1MB min
            (100 * MIB, 16, MIB, 16),
            // 10MB file, 16 connections, 1MB min -> only 10 segments
            (10 * MIB, 16, MIB, 10),
            // 500KB file, 16 connections, 1MB min -> 1 segment
            (512 * 1024, 16, MIB, 1),
            // Empty file
            (0, 16, MIB, 1),
            // Very large file
            (10 * 1024 * MIB, 16, MIB, 16),
        ];

        for &(total_size, max_connections, min_segment_size, expected) in cases.iter() {
            assert_eq!(
                calculate_segment_count(total_size, max_connections, min_segment_size),
                expected
            );
        }
    }

    /// Segments of a fresh download tile the whole file without gaps.
    #[test]
    fn test_segment_init() {
        let total = 100 * MIB; // 100MB
        let mut download = SegmentedDownload::new(
            "https://example.com/file.zip".to_string(),
            total,
            PathBuf::from("/tmp/file.zip"),
            true,
            None,
            None,
        );

        download.init_segments(16, MIB);

        let segments = download.segments();
        assert_eq!(segments.len(), 16);

        // Check segment boundaries
        assert_eq!(segments[0].start, 0);
        assert_eq!(segments[15].end, total - 1);

        // Check segments are contiguous
        for pair in segments.windows(2) {
            assert_eq!(pair[0].end + 1, pair[1].start);
        }
    }

    /// Quoted, bare and RFC 5987 extended file names are all recognised.
    #[test]
    fn test_parse_content_disposition() {
        let cases = [
            ("attachment; filename=\"test.zip\"", "test.zip"),
            ("attachment; filename=test.zip", "test.zip"),
            ("attachment; filename*=UTF-8''test%20file.zip", "test file.zip"),
        ];

        for &(header, expected) in cases.iter() {
            assert_eq!(
                parse_content_disposition(header),
                Some(expected.to_string())
            );
        }
    }
}