1use std::path::{Path, PathBuf};
7use std::sync::Arc;
8
9use futures_util::StreamExt;
10use futures_util::stream::FuturesUnordered;
11use pulith_fs::workflow::Workspace;
12use pulith_verify::{Hasher, Sha256Hasher};
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tokio::sync::Semaphore;
15
16use crate::config::{FetchOptions, FetchPhase};
17use crate::error::{Error, Result};
18use crate::net::http::HttpClient;
19use crate::progress::Progress;
20use crate::segment::{Segment, calculate_segments};
21
/// Tuning knobs for a segmented (multi-range) download.
///
/// `PartialEq`/`Eq` are derived so that option values can be compared directly,
/// e.g. in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentedOptions {
    /// Number of segments the remote resource is split into.
    pub num_segments: u32,
    /// Upper bound on the number of segment downloads in flight at once.
    pub max_concurrent: usize,
}
30
31impl Default for SegmentedOptions {
32 fn default() -> Self {
33 Self {
34 num_segments: 4,
35 max_concurrent: 4,
36 }
37 }
38}
39
40pub struct SegmentedFetcher<C: HttpClient> {
42 client: Arc<C>,
43 workspace_root: PathBuf,
44}
45
46impl<C: HttpClient + 'static> SegmentedFetcher<C> {
47 pub fn new(client: C, workspace_root: impl Into<PathBuf>) -> Self {
49 Self {
50 client: Arc::new(client),
51 workspace_root: workspace_root.into(),
52 }
53 }
54
55 pub async fn fetch_segmented(
57 &self,
58 url: &str,
59 destination: &Path,
60 options: SegmentedOptions,
61 fetch_options: FetchOptions,
62 ) -> Result<PathBuf> {
63 let total_bytes = self
65 .client
66 .head(url)
67 .await
68 .map_err(|e| Error::Network(e.to_string()))?;
69
70 let segments = calculate_segments(total_bytes.unwrap_or(0), options.num_segments)?;
72
73 let staging_dir = self.workspace_root.join("staging");
75 let workspace = Workspace::new(
76 &staging_dir,
77 destination.parent().unwrap_or_else(|| Path::new(".")),
78 )?;
79
80 let segment_files = self
82 .download_segments(
83 url,
84 &segments,
85 &workspace,
86 &fetch_options,
87 options.max_concurrent,
88 )
89 .await?;
90
91 self.reassemble_segments(
93 &segment_files,
94 destination,
95 workspace,
96 &fetch_options,
97 total_bytes,
98 )
99 .await?;
100
101 Ok(destination.to_path_buf())
102 }
103
104 async fn download_segments(
106 &self,
107 url: &str,
108 segments: &[Segment],
109 workspace: &Workspace,
110 options: &FetchOptions,
111 max_concurrent: usize,
112 ) -> Result<Vec<PathBuf>> {
113 let semaphore = Arc::new(Semaphore::new(max_concurrent));
114 let mut futures = FuturesUnordered::new();
115
116 for segment in segments {
117 let permit = semaphore
118 .clone()
119 .acquire_owned()
120 .await
121 .map_err(|e| Error::Network(e.to_string()))?;
122 let client = self.client.clone();
123 let url = url.to_string();
124 let workspace_path = workspace.path().to_path_buf();
125 let segment_clone = segment.clone();
126 let options_clone = options.clone();
127
128 let future = tokio::spawn(async move {
129 let _permit = permit;
130 let segment_path = workspace_path.join(format!("segment_{}", segment_clone.index));
131
132 let range_header =
134 format!("bytes={}-{}", segment_clone.start, segment_clone.end - 1);
135 let mut segment_options = options_clone;
136 let mut headers: Vec<_> = segment_options.headers.iter().cloned().collect();
137 headers.push(("Range".to_string(), range_header));
138 segment_options.headers = Arc::from(headers);
139
140 let mut stream = client
142 .stream(&url, &segment_options.headers)
143 .await
144 .map_err(|e| Error::Network(e.to_string()))?;
145 let mut file = tokio::fs::File::create(&segment_path)
146 .await
147 .map_err(|e| Error::Network(e.to_string()))?;
148
149 while let Some(chunk_result) = stream.next().await {
150 let chunk = chunk_result.map_err(|e| Error::Network(e.to_string()))?;
151 file.write_all(&chunk)
152 .await
153 .map_err(|e| Error::Network(e.to_string()))?;
154 }
155
156 Ok::<PathBuf, Error>(segment_path)
157 });
158
159 futures.push(future);
160 }
161
162 let mut segment_files = Vec::with_capacity(segments.len());
164 while let Some(result) = futures.next().await {
165 match result {
166 Ok(segment_result) => match segment_result {
167 Ok(path) => segment_files.push(path),
168 Err(e) => return Err(e),
169 },
170 Err(e) => return Err(Error::Network(e.to_string())),
171 }
172 }
173
174 segment_files.sort_by_key(|path| {
176 let filename = path.file_name().unwrap().to_str().unwrap();
177 filename
178 .split('_')
179 .next_back()
180 .unwrap()
181 .parse::<u32>()
182 .unwrap()
183 });
184
185 Ok(segment_files)
186 }
187
188 async fn reassemble_segments(
190 &self,
191 segment_files: &[PathBuf],
192 destination: &Path,
193 workspace: Workspace,
194 options: &FetchOptions,
195 total_bytes: Option<u64>,
196 ) -> Result<()> {
197 let staging_file_path = workspace.path().join(
198 destination
199 .file_name()
200 .unwrap_or_else(|| std::ffi::OsStr::new("download")),
201 );
202 let mut output_file = tokio::fs::File::create(&staging_file_path)
203 .await
204 .map_err(|e| Error::Network(e.to_string()))?;
205 let mut hasher = Sha256Hasher::new();
206 let mut bytes_downloaded = 0u64;
207
208 self.report_progress(
210 options,
211 Progress {
212 phase: FetchPhase::Downloading,
213 bytes_downloaded: 0,
214 total_bytes,
215 retry_count: 0,
216 performance_metrics: None,
217 },
218 );
219
220 for segment_path in segment_files {
222 let mut segment_file = tokio::fs::File::open(segment_path)
223 .await
224 .map_err(|e| Error::Network(e.to_string()))?;
225
226 let mut buffer = vec![0u8; 65536]; loop {
228 let n = segment_file
229 .read(&mut buffer)
230 .await
231 .map_err(|e| Error::Network(e.to_string()))?;
232 if n == 0 {
233 break;
234 }
235
236 hasher.update(&buffer[..n]);
237 output_file
238 .write_all(&buffer[..n])
239 .await
240 .map_err(|e| Error::Network(e.to_string()))?;
241 bytes_downloaded += n as u64;
242
243 self.report_progress(
245 options,
246 Progress {
247 phase: FetchPhase::Downloading,
248 bytes_downloaded,
249 total_bytes,
250 retry_count: 0,
251 performance_metrics: None,
252 },
253 );
254 }
255
256 tokio::fs::remove_file(segment_path)
258 .await
259 .map_err(|e| Error::Network(e.to_string()))?;
260 }
261
262 if let Some(expected_checksum) = options.checksum {
264 self.report_progress(
265 options,
266 Progress {
267 phase: FetchPhase::Verifying,
268 bytes_downloaded,
269 total_bytes,
270 retry_count: 0,
271 performance_metrics: None,
272 },
273 );
274
275 let actual_checksum = hasher.finalize();
276 if actual_checksum != expected_checksum {
277 return Err(Error::ChecksumMismatch {
278 expected: hex::encode(expected_checksum),
279 actual: hex::encode(actual_checksum),
280 });
281 }
282 }
283
284 self.report_progress(
286 options,
287 Progress {
288 phase: FetchPhase::Committing,
289 bytes_downloaded,
290 total_bytes,
291 retry_count: 0,
292 performance_metrics: None,
293 },
294 );
295
296 tokio::fs::rename(&staging_file_path, destination)
298 .await
299 .map_err(|e| Error::Network(e.to_string()))?;
300 workspace
301 .commit()
302 .map_err(|e| Error::Network(e.to_string()))?;
303
304 self.report_progress(
305 options,
306 Progress {
307 phase: FetchPhase::Completed,
308 bytes_downloaded,
309 total_bytes,
310 retry_count: 0,
311 performance_metrics: None,
312 },
313 );
314
315 tokio::fs::rename(&staging_file_path, destination)
316 .await
317 .map_err(|e| Error::Network(e.to_string()))?;
318
319 self.report_progress(
320 options,
321 Progress {
322 phase: FetchPhase::Completed,
323 bytes_downloaded,
324 total_bytes,
325 retry_count: 0,
326 performance_metrics: None,
327 },
328 );
329
330 Ok(())
331 }
332
333 fn report_progress(&self, options: &FetchOptions, progress: Progress) {
335 if let Some(ref callback) = options.on_progress {
336 callback(&progress);
337 }
338 }
339}
340
#[cfg(test)]
mod tests {
    use crate::calculate_segments;

    /// Covers an even split, an uneven split and the zero-length input.
    #[test]
    fn test_segment_calculation() {
        // 100 bytes over 4 segments: four equal 25-byte ranges.
        let even = calculate_segments(100, 4).unwrap();
        assert_eq!(even.len(), 4);
        let (first, last) = (&even[0], &even[3]);
        assert_eq!((first.start, first.end), (0, 25));
        assert_eq!((last.start, last.end), (75, 100));

        // 10 bytes over 3 segments: exclusive ends fall at 4, 7 and 10.
        let uneven = calculate_segments(10, 3).unwrap();
        assert_eq!(uneven.len(), 3);
        let ends: Vec<_> = uneven.iter().map(|segment| segment.end).collect();
        assert_eq!(ends, [4, 7, 10]);

        // Zero bytes collapses to a single empty segment.
        let empty = calculate_segments(0, 4).unwrap();
        assert_eq!(empty.len(), 1);
        let only = &empty[0];
        assert_eq!((only.start, only.end), (0, 0));
    }

    /// Asking for zero segments is rejected.
    #[test]
    fn test_segment_calculation_errors() {
        assert!(calculate_segments(100, 0).is_err());
    }
}