1use crate::error::{EngineError, NetworkErrorKind, Result, StorageErrorKind};
8use crate::storage::Segment;
9use crate::types::DownloadProgress;
10
11use bytes::Bytes;
12use futures::stream::StreamExt;
13use parking_lot::RwLock;
14use reqwest::Client;
15use std::path::PathBuf;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use tokio::fs::{File, OpenOptions};
20use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
21use tokio::sync::Semaphore;
22use tokio_util::sync::CancellationToken;
23
/// Default minimum segment size: 1 MiB (2^20 bytes).
///
/// Suitable as the `min_segment_size` argument of [`calculate_segment_count`].
pub const MIN_SEGMENT_SIZE: u64 = 1 << 20;

/// Default number of parallel connections per download.
pub const DEFAULT_CONNECTIONS: usize = 16;

/// Minimum time between two progress callbacks emitted while segments are downloading.
const PROGRESS_INTERVAL: Duration = Duration::from_millis(250);

/// Minimum time between two positive answers from `SegmentedDownload::should_persist`.
const PERSISTENCE_INTERVAL: Duration = Duration::from_secs(5);
35
36struct SharedState {
38 downloaded: AtomicU64,
40 speed: AtomicU64,
42 active_connections: AtomicU64,
44 paused: AtomicBool,
46 segment_progress: RwLock<Vec<u64>>,
48 last_persistence: RwLock<Instant>,
50}
51
52pub struct SegmentedDownload {
54 url: String,
56 total_size: u64,
58 save_path: PathBuf,
60 segments: Vec<Segment>,
62 #[allow(dead_code)]
64 supports_range: bool,
65 etag: Option<String>,
67 #[allow(dead_code)]
69 last_modified: Option<String>,
70 state: Arc<SharedState>,
72}
73
/// What a `HEAD` probe learned about a remote resource (see `probe_server`).
#[derive(Debug, Clone)]
pub struct ServerCapabilities {
    /// Parsed `Content-Length`, if present and numeric.
    pub content_length: Option<u64>,
    /// `true` when `Accept-Ranges` mentions `bytes`.
    pub supports_range: bool,
    /// Raw `ETag` header value.
    pub etag: Option<String>,
    /// Raw `Last-Modified` header value.
    pub last_modified: Option<String>,
    /// File name extracted from `Content-Disposition`.
    pub suggested_filename: Option<String>,
}
88
89impl SegmentedDownload {
90 pub fn new(
92 url: String,
93 total_size: u64,
94 save_path: PathBuf,
95 supports_range: bool,
96 etag: Option<String>,
97 last_modified: Option<String>,
98 ) -> Self {
99 Self {
100 url,
101 total_size,
102 save_path,
103 segments: Vec::new(),
104 supports_range,
105 etag,
106 last_modified,
107 state: Arc::new(SharedState {
108 downloaded: AtomicU64::new(0),
109 speed: AtomicU64::new(0),
110 active_connections: AtomicU64::new(0),
111 paused: AtomicBool::new(false),
112 segment_progress: RwLock::new(Vec::new()),
113 last_persistence: RwLock::new(Instant::now()),
114 }),
115 }
116 }
117
118 pub fn init_segments(&mut self, max_connections: usize, min_segment_size: u64) {
120 let num_segments =
121 calculate_segment_count(self.total_size, max_connections, min_segment_size);
122 let segment_size = self.total_size / num_segments as u64;
123
124 let mut segments = Vec::with_capacity(num_segments);
125 for i in 0..num_segments {
126 let start = i as u64 * segment_size;
127 let end = if i == num_segments - 1 {
128 self.total_size - 1
129 } else {
130 (i as u64 + 1) * segment_size - 1
131 };
132 segments.push(Segment::new(i, start, end));
133 }
134
135 *self.state.segment_progress.write() = vec![0u64; num_segments];
137
138 self.segments = segments;
139 }
140
141 pub fn restore_segments(&mut self, saved_segments: Vec<Segment>) {
143 let downloaded: u64 = saved_segments.iter().map(|s| s.downloaded).sum();
145 self.state.downloaded.store(downloaded, Ordering::Relaxed);
146
147 let progress: Vec<u64> = saved_segments.iter().map(|s| s.downloaded).collect();
149 *self.state.segment_progress.write() = progress;
150
151 self.segments = saved_segments;
152 }
153
154 pub fn segments(&self) -> &[Segment] {
156 &self.segments
157 }
158
159 pub fn segments_with_progress(&self) -> Vec<Segment> {
163 let progress = self.state.segment_progress.read();
164 self.segments
165 .iter()
166 .enumerate()
167 .map(|(idx, s)| {
168 let mut segment = s.clone();
169 if let Some(&downloaded) = progress.get(idx) {
170 segment.downloaded = downloaded;
171 if segment.downloaded >= segment.size() {
172 segment.state = crate::storage::SegmentState::Completed;
173 } else if segment.downloaded > 0 {
174 segment.state = crate::storage::SegmentState::Downloading;
175 }
176 }
177 segment
178 })
179 .collect()
180 }
181
182 pub async fn start<F>(
184 &self,
185 client: &Client,
186 user_agent: &str,
187 headers: &[(String, String)],
188 max_connections: usize,
189 cancel_token: CancellationToken,
190 progress_callback: F,
191 ) -> Result<()>
192 where
193 F: Fn(DownloadProgress) + Send + Sync + 'static,
194 {
195 let file = self.prepare_file().await?;
197 let file = Arc::new(tokio::sync::Mutex::new(file));
198
199 let semaphore = Arc::new(Semaphore::new(max_connections));
201
202 let progress_callback = Arc::new(progress_callback);
204 let last_progress = Arc::new(RwLock::new(Instant::now()));
205 let bytes_since_progress = Arc::new(AtomicU64::new(0));
206
207 let segments_data: Vec<_> = self
209 .segments
210 .iter()
211 .enumerate()
212 .filter(|(_, s)| !s.is_complete())
213 .map(|(idx, s)| (idx, s.start, s.end, s.downloaded))
214 .collect();
215
216 let mut handles = Vec::new();
218
219 for (segment_idx, start, end, already_downloaded) in segments_data {
220 let client = client.clone();
221 let url = self.url.clone();
222 let user_agent = user_agent.to_string();
223 let headers = headers.to_vec();
224 let file = Arc::clone(&file);
225 let semaphore = Arc::clone(&semaphore);
226 let cancel_token = cancel_token.clone();
227 let etag = self.etag.clone();
228 let state = Arc::clone(&self.state);
229 let progress_callback = Arc::clone(&progress_callback);
230 let last_progress = Arc::clone(&last_progress);
231 let bytes_since_progress = Arc::clone(&bytes_since_progress);
232 let total_size = self.total_size;
233
234 let handle = tokio::spawn(async move {
235 let _permit = semaphore
237 .acquire()
238 .await
239 .map_err(|_| EngineError::Shutdown)?;
240
241 if cancel_token.is_cancelled() {
243 return Ok(());
244 }
245
246 if state.paused.load(Ordering::Relaxed) {
248 return Ok(());
249 }
250
251 state.active_connections.fetch_add(1, Ordering::Relaxed);
252
253 let resume_start = start + already_downloaded;
255 if resume_start > end {
256 state.active_connections.fetch_sub(1, Ordering::Relaxed);
258 return Ok(());
259 }
260
261 let mut request = client.get(&url);
263 request = request.header("User-Agent", &user_agent);
264 request = request.header("Range", format!("bytes={}-{}", resume_start, end));
265
266 if let Some(ref etag_val) = etag {
268 request = request.header("If-Range", etag_val);
269 }
270
271 for (name, value) in &headers {
273 request = request.header(name.as_str(), value.as_str());
274 }
275
276 let response = request.send().await.map_err(|e| {
278 EngineError::network(
279 NetworkErrorKind::Other,
280 format!("Segment {} request failed: {}", segment_idx, e),
281 )
282 })?;
283
284 let status = response.status();
285
286 if status == reqwest::StatusCode::RANGE_NOT_SATISFIABLE {
288 state.active_connections.fetch_sub(1, Ordering::Relaxed);
289 return Err(EngineError::network(
290 NetworkErrorKind::HttpStatus(416),
291 format!(
292 "Segment {} range not satisfiable (file may have changed on server)",
293 segment_idx
294 ),
295 ));
296 }
297
298 if !status.is_success() && status != reqwest::StatusCode::PARTIAL_CONTENT {
299 state.active_connections.fetch_sub(1, Ordering::Relaxed);
300 return Err(EngineError::network(
301 NetworkErrorKind::HttpStatus(status.as_u16()),
302 format!("Segment {} HTTP error: {}", segment_idx, status),
303 ));
304 }
305
306 if status == reqwest::StatusCode::PARTIAL_CONTENT {
308 if let Some(content_range) = response.headers().get("content-range") {
309 if let Ok(range_str) = content_range.to_str() {
310 if let Some(range_part) = range_str.strip_prefix("bytes ") {
312 if let Some((range, _)) = range_part.split_once('/') {
313 if let Some((start_str, end_str)) = range.split_once('-') {
314 let range_start: u64 = start_str.parse().unwrap_or(0);
315 let range_end: u64 = end_str.parse().unwrap_or(0);
316
317 if range_start != resume_start || range_end != end {
319 state
320 .active_connections
321 .fetch_sub(1, Ordering::Relaxed);
322 return Err(EngineError::network(
323 NetworkErrorKind::Other,
324 format!(
325 "Segment {} Content-Range mismatch: requested {}-{}, got {}-{}",
326 segment_idx, resume_start, end, range_start, range_end
327 ),
328 ));
329 }
330 }
331 }
332 }
333 }
334 }
335 }
336
337 let mut stream = response.bytes_stream();
339 let mut segment_bytes: u64 = already_downloaded;
340 let mut last_speed_update = Instant::now();
341 let mut bytes_for_speed: u64 = 0;
342
343 while let Some(chunk_result) = tokio::select! {
344 chunk = stream.next() => chunk,
345 _ = cancel_token.cancelled() => None,
346 } {
347 if state.paused.load(Ordering::Relaxed) {
349 break;
350 }
351
352 let chunk: Bytes = match chunk_result {
353 Ok(c) => c,
354 Err(e) => {
355 state.active_connections.fetch_sub(1, Ordering::Relaxed);
356 return Err(EngineError::network(
357 NetworkErrorKind::Other,
358 format!("Segment {} stream error: {}", segment_idx, e),
359 ));
360 }
361 };
362
363 let chunk_len = chunk.len() as u64;
364
365 {
367 let mut file = file.lock().await;
368 file.seek(SeekFrom::Start(start + segment_bytes))
369 .await
370 .map_err(|e| {
371 EngineError::storage(
372 StorageErrorKind::Io,
373 PathBuf::new(),
374 format!("Seek failed: {}", e),
375 )
376 })?;
377 file.write_all(&chunk).await.map_err(|e| {
378 EngineError::storage(
379 StorageErrorKind::Io,
380 PathBuf::new(),
381 format!("Write failed: {}", e),
382 )
383 })?;
384 }
385
386 segment_bytes += chunk_len;
387
388 {
390 let mut progress = state.segment_progress.write();
391 if let Some(p) = progress.get_mut(segment_idx) {
392 *p = segment_bytes;
393 }
394 }
395
396 state.downloaded.fetch_add(chunk_len, Ordering::Relaxed);
398 bytes_since_progress.fetch_add(chunk_len, Ordering::Relaxed);
399 bytes_for_speed += chunk_len;
400
401 let now = Instant::now();
403 let speed_elapsed = now.duration_since(last_speed_update);
404 if speed_elapsed >= Duration::from_millis(500) {
405 let current_speed =
406 (bytes_for_speed as f64 / speed_elapsed.as_secs_f64()) as u64;
407 state.speed.store(current_speed, Ordering::Relaxed);
408 bytes_for_speed = 0;
409 last_speed_update = now;
410 }
411
412 let should_emit = {
415 let mut last = last_progress.write();
416 if now.duration_since(*last) >= PROGRESS_INTERVAL {
417 *last = now;
418 bytes_since_progress.store(0, Ordering::Relaxed);
419 true
420 } else {
421 false
422 }
423 };
424
425 if should_emit {
426 let total_downloaded = state.downloaded.load(Ordering::Relaxed);
427 let current_speed = state.speed.load(Ordering::Relaxed);
428 let connections = state.active_connections.load(Ordering::Relaxed) as u32;
429
430 progress_callback(DownloadProgress {
431 total_size: Some(total_size),
432 completed_size: total_downloaded,
433 download_speed: current_speed,
434 upload_speed: 0,
435 connections,
436 seeders: 0,
437 peers: 0,
438 eta_seconds: if current_speed > 0 {
439 Some((total_size.saturating_sub(total_downloaded)) / current_speed)
440 } else {
441 None
442 },
443 });
444 }
445 }
446
447 state.active_connections.fetch_sub(1, Ordering::Relaxed);
448
449 Result::<()>::Ok(())
451 });
452
453 handles.push(handle);
454 }
455
456 let mut segment_errors: Vec<String> = Vec::new();
458 for (idx, handle) in handles.into_iter().enumerate() {
459 match handle.await {
460 Err(e) => {
461 tracing::error!("Segment {} task panicked: {:?}", idx, e);
463 segment_errors.push(format!("Segment {} panicked: {:?}", idx, e));
464 }
465 Ok(Err(e)) => {
466 tracing::error!("Segment {} failed: {:?}", idx, e);
468 segment_errors.push(format!("Segment {} failed: {}", idx, e));
469 }
470 Ok(Ok(())) => {
471 }
473 }
474 }
475
476 if !segment_errors.is_empty() {
478 return Err(EngineError::network(
479 NetworkErrorKind::Other,
480 format!(
481 "Download failed: {} segment(s) failed: {}",
482 segment_errors.len(),
483 segment_errors.join("; ")
484 ),
485 ));
486 }
487
488 {
490 let mut file = file.lock().await;
491 file.flush().await.map_err(|e| {
492 EngineError::storage(
493 StorageErrorKind::Io,
494 &self.save_path,
495 format!("Flush failed: {}", e),
496 )
497 })?;
498 file.sync_all().await.map_err(|e| {
499 EngineError::storage(
500 StorageErrorKind::Io,
501 &self.save_path,
502 format!("Sync failed: {}", e),
503 )
504 })?;
505 }
506
507 let total_downloaded = self.state.downloaded.load(Ordering::Relaxed);
509 progress_callback(DownloadProgress {
510 total_size: Some(self.total_size),
511 completed_size: total_downloaded,
512 download_speed: 0,
513 upload_speed: 0,
514 connections: 0,
515 seeders: 0,
516 peers: 0,
517 eta_seconds: None,
518 });
519
520 if total_downloaded >= self.total_size {
522 self.finalize().await?;
524 }
525
526 Ok(())
527 }
528
529 pub fn should_persist(&self) -> bool {
534 let mut last = self.state.last_persistence.write();
535 let now = Instant::now();
536 if now.duration_since(*last) >= PERSISTENCE_INTERVAL {
537 *last = now;
538 true
539 } else {
540 false
541 }
542 }
543
544 pub fn mark_persisted(&self) {
546 *self.state.last_persistence.write() = Instant::now();
547 }
548
549 async fn prepare_file(&self) -> Result<File> {
551 let part_path = self.part_path();
553
554 if let Some(parent) = part_path.parent() {
556 tokio::fs::create_dir_all(parent).await.map_err(|e| {
557 EngineError::storage(
558 StorageErrorKind::Io,
559 parent,
560 format!("Create dir failed: {}", e),
561 )
562 })?;
563 }
564
565 let file = if part_path.exists() {
567 OpenOptions::new()
568 .write(true)
569 .read(true)
570 .open(&part_path)
571 .await
572 .map_err(|e| {
573 EngineError::storage(
574 StorageErrorKind::Io,
575 &part_path,
576 format!("Open failed: {}", e),
577 )
578 })?
579 } else {
580 let file = File::create(&part_path).await.map_err(|e| {
582 EngineError::storage(
583 StorageErrorKind::Io,
584 &part_path,
585 format!("Create failed: {}", e),
586 )
587 })?;
588
589 file.set_len(self.total_size).await.map_err(|e| {
591 EngineError::storage(
592 StorageErrorKind::Io,
593 &part_path,
594 format!("Pre-allocate failed: {}", e),
595 )
596 })?;
597
598 file
599 };
600
601 Ok(file)
602 }
603
604 fn part_path(&self) -> PathBuf {
606 let ext = self
607 .save_path
608 .extension()
609 .map(|e| format!("{}.part", e.to_string_lossy()))
610 .unwrap_or_else(|| "part".to_string());
611 self.save_path.with_extension(ext)
612 }
613
614 async fn finalize(&self) -> Result<()> {
616 let part_path = self.part_path();
617 if part_path.exists() {
618 tokio::fs::rename(&part_path, &self.save_path)
619 .await
620 .map_err(|e| {
621 EngineError::storage(
622 StorageErrorKind::Io,
623 &self.save_path,
624 format!("Rename failed: {}", e),
625 )
626 })?;
627 }
628 Ok(())
629 }
630
631 pub fn pause(&self) {
633 self.state.paused.store(true, Ordering::Relaxed);
634 }
635
636 pub fn is_complete(&self) -> bool {
638 self.state.downloaded.load(Ordering::Relaxed) >= self.total_size
639 }
640
641 pub fn progress(&self) -> DownloadProgress {
643 DownloadProgress {
644 total_size: Some(self.total_size),
645 completed_size: self.state.downloaded.load(Ordering::Relaxed),
646 download_speed: self.state.speed.load(Ordering::Relaxed),
647 upload_speed: 0,
648 connections: self.state.active_connections.load(Ordering::Relaxed) as u32,
649 seeders: 0,
650 peers: 0,
651 eta_seconds: {
652 let speed = self.state.speed.load(Ordering::Relaxed);
653 let remaining = self
654 .total_size
655 .saturating_sub(self.state.downloaded.load(Ordering::Relaxed));
656 if speed > 0 {
657 Some(remaining / speed)
658 } else {
659 None
660 }
661 },
662 }
663 }
664}
665
/// Chooses how many segments a download of `total_size` bytes is split into.
///
/// The count is capped by `max_connections` and by how many segments of at
/// least `min_segment_size` bytes fit into the file, and is always at least 1.
/// A `min_segment_size` of 0 is treated as 1 instead of dividing by zero.
pub fn calculate_segment_count(
    total_size: u64,
    max_connections: usize,
    min_segment_size: u64,
) -> usize {
    if total_size == 0 {
        return 1;
    }

    let by_size = total_size / min_segment_size.max(1);
    // Saturate rather than truncate when usize is narrower than u64 (32-bit targets).
    let max_segments_by_size = if by_size > usize::MAX as u64 {
        usize::MAX
    } else {
        by_size as usize
    };

    max_connections.min(max_segments_by_size).max(1)
}
685
686pub async fn probe_server(
688 client: &Client,
689 url: &str,
690 user_agent: &str,
691) -> Result<ServerCapabilities> {
692 let response = client
693 .head(url)
694 .header("User-Agent", user_agent)
695 .send()
696 .await
697 .map_err(|e| {
698 EngineError::network(
699 NetworkErrorKind::Other,
700 format!("HEAD request failed: {}", e),
701 )
702 })?;
703
704 if !response.status().is_success() {
705 return Err(EngineError::network(
706 NetworkErrorKind::HttpStatus(response.status().as_u16()),
707 format!("HEAD request returned: {}", response.status()),
708 ));
709 }
710
711 let headers = response.headers();
712
713 let content_length = headers
714 .get("content-length")
715 .and_then(|v| v.to_str().ok())
716 .and_then(|s| s.parse::<u64>().ok());
717
718 let supports_range = headers
719 .get("accept-ranges")
720 .and_then(|v| v.to_str().ok())
721 .map(|v| v.contains("bytes"))
722 .unwrap_or(false);
723
724 let etag = headers
725 .get("etag")
726 .and_then(|v| v.to_str().ok())
727 .map(|s| s.to_string());
728
729 let last_modified = headers
730 .get("last-modified")
731 .and_then(|v| v.to_str().ok())
732 .map(|s| s.to_string());
733
734 let suggested_filename = headers
735 .get("content-disposition")
736 .and_then(|v| v.to_str().ok())
737 .and_then(parse_content_disposition);
738
739 Ok(ServerCapabilities {
740 content_length,
741 supports_range,
742 etag,
743 last_modified,
744 suggested_filename,
745 })
746}
747
748fn parse_content_disposition(header: &str) -> Option<String> {
750 if let Some(start) = header.find("filename=") {
752 let rest = &header[start + 9..];
753 if let Some(stripped) = rest.strip_prefix('"') {
754 let end = stripped.find('"')?;
755 return Some(stripped[..end].to_string());
756 } else {
757 let end = rest.find(';').unwrap_or(rest.len());
758 return Some(rest[..end].trim().to_string());
759 }
760 }
761
762 if let Some(start) = header.find("filename*=") {
763 let rest = &header[start + 10..];
764 if let Some(quote_start) = rest.find("''") {
765 let encoded = &rest[quote_start + 2..];
766 let end = encoded.find(';').unwrap_or(encoded.len());
767 if let Ok(decoded) = urlencoding::decode(&encoded[..end]) {
768 return Some(decoded.to_string());
769 }
770 }
771 }
772
773 None
774}
775
#[cfg(test)]
mod tests {
    use super::*;

    /// One mebibyte, used as the minimum segment size throughout these tests.
    const MIB: u64 = 1024 * 1024;

    #[test]
    fn test_calculate_segment_count() {
        // (total_size, expected count) with 16 connections and a 1 MiB minimum.
        let cases: [(u64, usize); 5] = [
            (100 * MIB, 16),       // capped by the connection limit
            (10 * MIB, 10),        // capped by the minimum segment size
            (512 * 1024, 1),       // smaller than one minimum segment
            (0, 1),                // empty file still yields one segment
            (10 * 1024 * MIB, 16), // very large file, connection limit again
        ];
        for &(total_size, expected) in &cases {
            assert_eq!(
                calculate_segment_count(total_size, 16, MIB),
                expected,
                "total_size = {}",
                total_size
            );
        }
    }

    #[test]
    fn test_segment_init() {
        let total = 100 * MIB;
        let mut download = SegmentedDownload::new(
            "https://example.com/file.zip".to_string(),
            total,
            PathBuf::from("/tmp/file.zip"),
            true,
            None,
            None,
        );
        download.init_segments(16, MIB);

        let segments = download.segments();
        assert_eq!(segments.len(), 16);
        assert_eq!(segments[0].start, 0);
        assert_eq!(segments[15].end, total - 1);

        // Neighbouring segments must be contiguous: no gaps and no overlap.
        for pair in segments.windows(2) {
            assert_eq!(pair[0].end + 1, pair[1].start);
        }
    }

    #[test]
    fn test_parse_content_disposition() {
        let cases = [
            ("attachment; filename=\"test.zip\"", "test.zip"),
            ("attachment; filename=test.zip", "test.zip"),
            ("attachment; filename*=UTF-8''test%20file.zip", "test file.zip"),
        ];
        for &(header, expected) in &cases {
            assert_eq!(
                parse_content_disposition(header),
                Some(expected.to_string()),
                "header = {}",
                header
            );
        }
    }
}