Skip to main content

pulith_fetch/fetch/
segmented.rs

1//! Segmented download functionality.
2//!
3//! This module provides the ability to download files in parallel
4//! segments for improved performance.
5
6use std::path::{Path, PathBuf};
7use std::sync::Arc;
8
9use futures_util::StreamExt;
10use futures_util::stream::FuturesUnordered;
11use pulith_fs::workflow::Workspace;
12use pulith_verify::{Hasher, Sha256Hasher};
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::sync::Semaphore;
15
16use crate::config::{FetchOptions, FetchPhase};
17use crate::error::{Error, Result};
18use crate::net::http::HttpClient;
19use crate::progress::Progress;
20use crate::segment::{Segment, calculate_segments};
21
/// Tuning knobs for a segmented (multi-range) download.
#[derive(Debug, Clone)]
pub struct SegmentedOptions {
    /// How many byte ranges the remote file is split into.
    pub num_segments: u32,
    /// Upper bound on how many of those ranges are fetched at the same time.
    pub max_concurrent: usize,
}

impl Default for SegmentedOptions {
    /// Four segments, all of which may be in flight simultaneously.
    fn default() -> Self {
        let num_segments = 4;
        let max_concurrent = 4;
        SegmentedOptions {
            num_segments,
            max_concurrent,
        }
    }
}
39
/// Downloads a file as several byte ranges in parallel and stitches them together.
///
/// The `HttpClient` requirement lives on the `impl` block that needs it rather than
/// on the type itself, so merely naming `SegmentedFetcher<C>` imposes no bounds.
pub struct SegmentedFetcher<C> {
    /// Shared HTTP client; cloned (by reference count) into each segment task.
    client: Arc<C>,
    /// Root directory under which the staging workspace is created.
    workspace_root: PathBuf,
}
45
46impl<C: HttpClient + 'static> SegmentedFetcher<C> {
47    /// Create a new segmented fetcher.
48    pub fn new(client: C, workspace_root: impl Into<PathBuf>) -> Self {
49        Self {
50            client: Arc::new(client),
51            workspace_root: workspace_root.into(),
52        }
53    }
54
55    /// Fetch a file using segmented downloads.
56    pub async fn fetch_segmented(
57        &self,
58        url: &str,
59        destination: &Path,
60        options: SegmentedOptions,
61        fetch_options: FetchOptions,
62    ) -> Result<PathBuf> {
63        // Get file size first
64        let total_bytes = self
65            .client
66            .head(url)
67            .await
68            .map_err(|e| Error::Network(e.to_string()))?;
69
70        // Calculate segments
71        let segments = calculate_segments(total_bytes.unwrap_or(0), options.num_segments)?;
72
73        // Create workspace for staging
74        let staging_dir = self.workspace_root.join("staging");
75        let workspace = Workspace::new(
76            &staging_dir,
77            destination.parent().unwrap_or_else(|| Path::new(".")),
78        )?;
79
80        // Download segments in parallel
81        let segment_files = self
82            .download_segments(
83                url,
84                &segments,
85                &workspace,
86                &fetch_options,
87                options.max_concurrent,
88            )
89            .await?;
90
91        // Reassemble segments and commit workspace
92        self.reassemble_segments(
93            &segment_files,
94            destination,
95            workspace,
96            &fetch_options,
97            total_bytes,
98        )
99        .await?;
100
101        Ok(destination.to_path_buf())
102    }
103
104    /// Download all segments in parallel.
105    async fn download_segments(
106        &self,
107        url: &str,
108        segments: &[Segment],
109        workspace: &Workspace,
110        options: &FetchOptions,
111        max_concurrent: usize,
112    ) -> Result<Vec<PathBuf>> {
113        let semaphore = Arc::new(Semaphore::new(max_concurrent));
114        let mut futures = FuturesUnordered::new();
115
116        for segment in segments {
117            let permit = semaphore
118                .clone()
119                .acquire_owned()
120                .await
121                .map_err(|e| Error::Network(e.to_string()))?;
122            let client = self.client.clone();
123            let url = url.to_string();
124            let workspace_path = workspace.path().to_path_buf();
125            let segment_clone = segment.clone();
126            let options_clone = options.clone();
127
128            let future = tokio::spawn(async move {
129                let _permit = permit;
130                let segment_path = workspace_path.join(format!("segment_{}", segment_clone.index));
131
132                // Create Range header for this segment
133                let range_header =
134                    format!("bytes={}-{}", segment_clone.start, segment_clone.end - 1);
135                let mut segment_options = options_clone;
136                let mut headers: Vec<_> = segment_options.headers.iter().cloned().collect();
137                headers.push(("Range".to_string(), range_header));
138                segment_options.headers = Arc::from(headers);
139
140                // Download the segment
141                let mut stream = client
142                    .stream(&url, &segment_options.headers)
143                    .await
144                    .map_err(|e| Error::Network(e.to_string()))?;
145                let mut file = tokio::fs::File::create(&segment_path)
146                    .await
147                    .map_err(|e| Error::Network(e.to_string()))?;
148
149                while let Some(chunk_result) = stream.next().await {
150                    let chunk = chunk_result.map_err(|e| Error::Network(e.to_string()))?;
151                    file.write_all(&chunk)
152                        .await
153                        .map_err(|e| Error::Network(e.to_string()))?;
154                }
155
156                Ok::<PathBuf, Error>(segment_path)
157            });
158
159            futures.push(future);
160        }
161
162        // Wait for all downloads to complete
163        let mut segment_files = Vec::with_capacity(segments.len());
164        while let Some(result) = futures.next().await {
165            match result {
166                Ok(segment_result) => match segment_result {
167                    Ok(path) => segment_files.push(path),
168                    Err(e) => return Err(e),
169                },
170                Err(e) => return Err(Error::Network(e.to_string())),
171            }
172        }
173
174        // Sort by segment index to ensure correct order
175        segment_files.sort_by_key(|path| {
176            let filename = path.file_name().unwrap().to_str().unwrap();
177            filename
178                .split('_')
179                .next_back()
180                .unwrap()
181                .parse::<u32>()
182                .unwrap()
183        });
184
185        Ok(segment_files)
186    }
187
188    /// Reassemble segments into the final file.
189    async fn reassemble_segments(
190        &self,
191        segment_files: &[PathBuf],
192        destination: &Path,
193        workspace: Workspace,
194        options: &FetchOptions,
195        total_bytes: Option<u64>,
196    ) -> Result<()> {
197        let staging_file_path = workspace.path().join(
198            destination
199                .file_name()
200                .unwrap_or_else(|| std::ffi::OsStr::new("download")),
201        );
202        let mut output_file = tokio::fs::File::create(&staging_file_path)
203            .await
204            .map_err(|e| Error::Network(e.to_string()))?;
205        let mut hasher = Sha256Hasher::new();
206        let mut bytes_downloaded = 0u64;
207
208        // Report initial progress
209        self.report_progress(
210            options,
211            Progress {
212                phase: FetchPhase::Downloading,
213                bytes_downloaded: 0,
214                total_bytes,
215                retry_count: 0,
216                performance_metrics: None,
217            },
218        );
219
220        // Copy segments in order
221        for segment_path in segment_files {
222            let mut segment_file = tokio::fs::File::open(segment_path)
223                .await
224                .map_err(|e| Error::Network(e.to_string()))?;
225
226            let mut buffer = vec![0u8; 65536]; // 64KB buffer for better I/O performance
227            loop {
228                let n = segment_file
229                    .read(&mut buffer)
230                    .await
231                    .map_err(|e| Error::Network(e.to_string()))?;
232                if n == 0 {
233                    break;
234                }
235
236                hasher.update(&buffer[..n]);
237                output_file
238                    .write_all(&buffer[..n])
239                    .await
240                    .map_err(|e| Error::Network(e.to_string()))?;
241                bytes_downloaded += n as u64;
242
243                // Report progress
244                self.report_progress(
245                    options,
246                    Progress {
247                        phase: FetchPhase::Downloading,
248                        bytes_downloaded,
249                        total_bytes,
250                        retry_count: 0,
251                        performance_metrics: None,
252                    },
253                );
254            }
255
256            // Clean up segment file
257            tokio::fs::remove_file(segment_path)
258                .await
259                .map_err(|e| Error::Network(e.to_string()))?;
260        }
261
262        // Verify checksum if provided
263        if let Some(expected_checksum) = options.checksum {
264            self.report_progress(
265                options,
266                Progress {
267                    phase: FetchPhase::Verifying,
268                    bytes_downloaded,
269                    total_bytes,
270                    retry_count: 0,
271                    performance_metrics: None,
272                },
273            );
274
275            let actual_checksum = hasher.finalize();
276            if actual_checksum != expected_checksum {
277                return Err(Error::ChecksumMismatch {
278                    expected: hex::encode(expected_checksum),
279                    actual: hex::encode(actual_checksum),
280                });
281            }
282        }
283
284        // Move to final destination
285        self.report_progress(
286            options,
287            Progress {
288                phase: FetchPhase::Committing,
289                bytes_downloaded,
290                total_bytes,
291                retry_count: 0,
292                performance_metrics: None,
293            },
294        );
295
296        // Move the file to the final destination
297        tokio::fs::rename(&staging_file_path, destination)
298            .await
299            .map_err(|e| Error::Network(e.to_string()))?;
300        workspace
301            .commit()
302            .map_err(|e| Error::Network(e.to_string()))?;
303
304        self.report_progress(
305            options,
306            Progress {
307                phase: FetchPhase::Completed,
308                bytes_downloaded,
309                total_bytes,
310                retry_count: 0,
311                performance_metrics: None,
312            },
313        );
314
315        tokio::fs::rename(&staging_file_path, destination)
316            .await
317            .map_err(|e| Error::Network(e.to_string()))?;
318
319        self.report_progress(
320            options,
321            Progress {
322                phase: FetchPhase::Completed,
323                bytes_downloaded,
324                total_bytes,
325                retry_count: 0,
326                performance_metrics: None,
327            },
328        );
329
330        Ok(())
331    }
332
333    /// Report progress if callback is configured.
334    fn report_progress(&self, options: &FetchOptions, progress: Progress) {
335        if let Some(ref callback) = options.on_progress {
336            callback(&progress);
337        }
338    }
339}
340
#[cfg(test)]
mod tests {
    use crate::calculate_segments;

    #[test]
    fn test_segment_calculation() {
        // Even split: 100 bytes into four 25-byte ranges.
        let even = calculate_segments(100, 4).unwrap();
        assert_eq!(even.len(), 4);
        let (first, last) = (&even[0], &even[3]);
        assert_eq!((first.start, first.end), (0, 25));
        assert_eq!((last.start, last.end), (75, 100));

        // Uneven split: the leading segments each absorb one byte of the remainder.
        let uneven = calculate_segments(10, 3).unwrap();
        assert_eq!(uneven.len(), 3);
        let ends: Vec<_> = uneven.iter().map(|segment| segment.end).collect();
        assert_eq!(ends, [4, 7, 10]);

        // An empty file still yields exactly one (empty) segment.
        let empty = calculate_segments(0, 4).unwrap();
        assert_eq!(empty.len(), 1);
        assert_eq!((empty[0].start, empty[0].end), (0, 0));
    }

    #[test]
    fn test_segment_calculation_errors() {
        // Asking for zero segments is rejected.
        assert!(calculate_segments(100, 0).is_err());
    }
}