1use std::path::Path;
5
6use anyhow::Result;
7use futures::StreamExt;
8use tokio::io::AsyncWriteExt;
9use tracing::{debug, warn};
10use xxhash_rust::xxh3::Xxh3;
11use xxhash_rust::xxh64::Xxh64;
12
13use crate::error::{SourceError, SourceResult, status_error};
14use crate::traits::{DownloadHandle, ProgressCallback, VerifiedFile};
15
/// Total number of attempts [`with_retry`] makes, the first one included
/// (so `3` means one initial try plus up to two retries).
pub const MAX_RETRIES: u32 = 3;
/// Backoff delay, in milliseconds, before the first retry in [`with_retry`];
/// it is doubled for each later attempt and jitter is added on top.
pub const BACKOFF_BASE_MS: u64 = 1000;
18
19pub async fn stream_to_file(
21 resp: reqwest::Response,
22 dest: &Path,
23 total_hint: u64,
24 progress: &ProgressCallback,
25) -> Result<u64> {
26 let total = resp.content_length().unwrap_or(total_hint);
27 let mut file = tokio::fs::File::create(dest).await?;
28 let mut downloaded: u64 = 0;
29
30 let mut stream = resp.bytes_stream();
31 while let Some(chunk) = stream.next().await {
32 let chunk = chunk?;
33 file.write_all(&chunk).await?;
34 downloaded += chunk.len() as u64;
35 progress(downloaded, total);
36 }
37
38 file.flush().await?;
39 Ok(downloaded)
40}
41
42pub async fn stream_to_file_verified(
45 resp: reqwest::Response,
46 dest: &Path,
47 expected_hash: u64,
48 total_hint: u64,
49 progress: &ProgressCallback,
50) -> SourceResult<VerifiedFile> {
51 let total = resp.content_length().unwrap_or(total_hint);
52 let mut file = tokio::fs::File::create(dest).await?;
53 let mut downloaded: u64 = 0;
54 let mut xxh64 = Xxh64::new(0);
55 let mut xxh3 = Xxh3::new();
56
57 let mut stream = resp.bytes_stream();
58 while let Some(chunk) = stream.next().await {
59 let chunk = chunk?;
60 xxh64.update(&chunk);
61 xxh3.update(&chunk);
62 file.write_all(&chunk).await?;
63 downloaded += chunk.len() as u64;
64 progress(downloaded, total);
65 }
66 file.flush().await?;
67
68 let h64 = xxh64.digest();
69 if h64 == expected_hash || xxh3.digest() == expected_hash {
70 return Ok(VerifiedFile {
71 path: dest.to_path_buf(),
72 hash: expected_hash,
73 });
74 }
75
76 let _ = tokio::fs::remove_file(dest).await;
77 Err(SourceError::hash_mismatch(dest, expected_hash, h64))
78}
79
80pub async fn verify_and_wrap(dest: &Path, expected_hash: u64) -> SourceResult<VerifiedFile> {
82 modde_core::hash::verify_xxhash_compat(dest, expected_hash)
83 .await
84 .map_err(|error| match error {
85 modde_core::CoreError::HashMismatch { .. } => {
86 SourceError::HashMismatch { source: error }
87 }
88 error => SourceError::other(error),
89 })?;
90 Ok(VerifiedFile {
91 path: dest.to_path_buf(),
92 hash: expected_hash,
93 })
94}
95
96pub async fn ensure_parent(dest: &Path) -> SourceResult<()> {
98 if let Some(parent) = dest.parent() {
99 tokio::fs::create_dir_all(parent).await?;
100 }
101 Ok(())
102}
103
104pub async fn with_retry<F, Fut, T>(label: &str, f: F) -> SourceResult<T>
106where
107 F: Fn() -> Fut,
108 Fut: std::future::Future<Output = SourceResult<T>>,
109{
110 for attempt in 0..MAX_RETRIES {
111 match f().await {
112 Ok(val) => return Ok(val),
113 Err(e) => {
114 if !e.is_retryable() {
115 return Err(e);
116 }
117 if attempt + 1 < MAX_RETRIES {
118 let base_delay = BACKOFF_BASE_MS * (1 << attempt);
119 let jitter = base_delay / 4
121 + (base_delay / 2).wrapping_mul(u64::from(attempt) + 1)
122 % (base_delay / 2 + 1);
123 let delay = base_delay + jitter;
124 warn!(attempt = attempt + 1, delay_ms = delay, error = %e, "{label} failed, retrying");
125 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
126 } else {
127 return Err(SourceError::other(anyhow::anyhow!(
128 "{label} failed after {MAX_RETRIES} attempts: {e:#}"
129 )));
130 }
131 }
132 }
133 }
134 unreachable!("the 0..MAX_RETRIES loop always returns Ok or the final Err")
135}
136
137pub async fn simple_download(
141 client: &reqwest::Client,
142 handle: &DownloadHandle,
143 dest: &Path,
144 progress: &ProgressCallback,
145) -> SourceResult<VerifiedFile> {
146 ensure_parent(dest).await?;
147
148 let resp = status_error(client.get(&handle.url).send().await?)?;
149
150 let verified = stream_to_file_verified(
151 resp,
152 dest,
153 handle.expected_hash,
154 handle.size_hint.unwrap_or(0),
155 progress,
156 )
157 .await?;
158 debug!("download complete");
159 Ok(verified)
160}