modde_sources/direct/
mod.rs1use std::path::Path;
5
6use futures::StreamExt;
7use reqwest::Client;
8use tokio::io::{AsyncReadExt as _, AsyncWriteExt};
9use tracing::{debug, info};
10use xxhash_rust::xxh3::Xxh3;
11use xxhash_rust::xxh64::Xxh64;
12
13use modde_core::manifest::wabbajack::DownloadDirective;
14
15use crate::common::{ensure_parent, with_retry};
16use crate::error::{SourceError, SourceResult, status_error};
17use crate::mirror::resolve_html_mirrors;
18use crate::traits::{DownloadHandle, DownloadSource, ProgressCallback, VerifiedFile};
19
20pub struct DirectSource {
22 client: Client,
23}
24
25impl DirectSource {
26 #[must_use]
28 pub fn new(client: Client) -> Self {
29 Self { client }
30 }
31}
32
33impl DownloadSource for DirectSource {
34 fn can_handle(&self, directive: &DownloadDirective) -> bool {
35 matches!(directive, DownloadDirective::DirectURL { .. })
36 }
37
38 async fn resolve(&self, directive: &DownloadDirective) -> SourceResult<DownloadHandle> {
39 let DownloadDirective::DirectURL {
40 url,
41 headers,
42 mirror_resolver,
43 hash,
44 } = directive
45 else {
46 return Err(SourceError::other(anyhow::anyhow!(
47 "not a DirectURL directive"
48 )));
49 };
50
51 let candidate_urls = if let Some(resolver) = mirror_resolver {
52 resolve_html_mirrors(&self.client, resolver)
53 .await
54 .map_err(SourceError::other)?
55 } else {
56 Vec::new()
57 };
58 let mut headers = headers.clone();
59 if let Some(resolver) = mirror_resolver
60 && let Some(user_agent) = &resolver.user_agent
61 {
62 headers
63 .entry("User-Agent".to_string())
64 .or_insert_with(|| user_agent.clone());
65 }
66
67 Ok(DownloadHandle {
68 url: url.clone(),
69 candidate_urls,
70 headers,
71 expected_hash: *hash,
72 size_hint: None,
73 })
74 }
75
76 async fn download_with_progress(
77 &self,
78 handle: DownloadHandle,
79 dest: &Path,
80 progress: ProgressCallback,
81 ) -> SourceResult<VerifiedFile> {
82 ensure_parent(dest).await?;
83
84 let client = self.client.clone();
85 let dest_ref = dest;
86 let progress_ref = &progress;
87 let candidates = if handle.candidate_urls.is_empty() {
88 vec![handle.url.clone()]
89 } else {
90 handle.candidate_urls.clone()
91 };
92
93 let mut errors = Vec::new();
94 let mut last_error = None;
95 for (idx, url) in candidates.iter().enumerate() {
96 let mut candidate = handle.clone();
97 candidate.url = url.clone();
98 let result = with_retry("direct download", || async {
99 download_with_resume(&client, &candidate, dest_ref, progress_ref).await
100 })
101 .await;
102 match result {
103 Ok(verified) => {
104 if idx > 0 {
105 info!(url = %candidate.url, "direct download mirror succeeded");
106 }
107 return Ok(verified);
108 }
109 Err(e) => {
110 errors.push(format!("{}: {e:#}", candidate.url));
111 last_error = Some(e);
112 let _ = tokio::fs::remove_file(dest).await;
113 }
114 }
115 }
116
117 if candidates.len() == 1
118 && let Some(error) = last_error
119 {
120 return Err(error);
121 }
122
123 Err(SourceError::other(anyhow::anyhow!(
124 "all {} direct download candidate(s) failed:\n{}",
125 candidates.len(),
126 errors
127 .iter()
128 .map(|error| format!(" - {error}"))
129 .collect::<Vec<_>>()
130 .join("\n")
131 )))
132 }
133}
134
135async fn download_with_resume(
136 client: &Client,
137 handle: &DownloadHandle,
138 dest: &Path,
139 progress: &ProgressCallback,
140) -> SourceResult<VerifiedFile> {
141 let existing_len = tokio::fs::metadata(dest).await.map_or(0, |m| m.len());
142
143 let mut req = client.get(&handle.url);
144 for (k, v) in &handle.headers {
145 req = req.header(k.as_str(), v.as_str());
146 }
147
148 if existing_len > 0 {
149 debug!(bytes = existing_len, "attempting range resume");
150 req = req.header("Range", format!("bytes={existing_len}-"));
151 }
152
153 let resp = status_error(req.send().await?)?;
154 let status = resp.status();
155 let total = resp.content_length().or(handle.size_hint).unwrap_or(0);
156
157 let (mut file, mut downloaded) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
158 debug!("server returned 206, resuming download");
159 let file = tokio::fs::OpenOptions::new()
160 .append(true)
161 .open(dest)
162 .await?;
163 (file, existing_len)
164 } else {
165 if existing_len > 0 {
166 debug!(
167 "server returned {}, restarting download from scratch",
168 status
169 );
170 }
171 let file = tokio::fs::File::create(dest).await?;
172 (file, 0u64)
173 };
174
175 let total_size = if status == reqwest::StatusCode::PARTIAL_CONTENT {
176 total + existing_len
177 } else {
178 total
179 };
180
181 let mut xxh64 = Xxh64::new(0);
182 let mut xxh3 = Xxh3::new();
183 if status == reqwest::StatusCode::PARTIAL_CONTENT && existing_len > 0 {
184 let mut existing = tokio::fs::File::open(dest).await?;
185 let mut buf = vec![0_u8; 1024 * 1024];
186 loop {
187 let read = existing.read(&mut buf).await?;
188 if read == 0 {
189 break;
190 }
191 xxh64.update(&buf[..read]);
192 xxh3.update(&buf[..read]);
193 }
194 }
195
196 let mut stream = resp.bytes_stream();
197 while let Some(chunk) = stream.next().await {
198 let chunk = chunk?;
199 xxh64.update(&chunk);
200 xxh3.update(&chunk);
201 file.write_all(&chunk).await?;
202 downloaded += chunk.len() as u64;
203 progress(downloaded, total_size);
204 }
205
206 file.flush().await?;
207 let h64 = xxh64.digest();
208 if h64 == handle.expected_hash || xxh3.digest() == handle.expected_hash {
209 return Ok(VerifiedFile {
210 path: dest.to_path_buf(),
211 hash: handle.expected_hash,
212 });
213 }
214
215 let _ = tokio::fs::remove_file(dest).await;
216 Err(SourceError::hash_mismatch(dest, handle.expected_hash, h64))
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::io::{Read as _, Write as _};
    use std::net::TcpListener;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use xxhash_rust::xxh64::xxh64;

    /// Candidate URLs must be tried in list order, and the handle's primary
    /// URL must be skipped when candidates are present.
    #[tokio::test]
    async fn direct_download_tries_candidate_urls_in_order() {
        let body = b"mirror payload";
        let expected_hash = xxh64(body, 0);
        let (base_url, requests) = start_fallback_server(body);
        let handle = DownloadHandle {
            url: format!("{base_url}/original"),
            candidate_urls: vec![format!("{base_url}/bad"), format!("{base_url}/good")],
            headers: HashMap::new(),
            expected_hash,
            size_hint: None,
        };
        let temp = tempfile::tempdir().unwrap();
        let dest = temp.path().join("download.archive");
        let source = DirectSource::new(Client::new());

        let verified = source.download(handle, &dest).await.unwrap();
        assert_eq!(verified.hash, expected_hash);
        assert_eq!(tokio::fs::read(&verified.path).await.unwrap(), body);

        // Snapshot the log; the guard is released at the end of the statement.
        let seen: Vec<String> = requests.lock().unwrap().clone();
        assert!(seen.len() >= 2);
        // `/bad` may be retried, so only the two ends of the log are pinned.
        assert_eq!(seen.first().map(String::as_str), Some("/bad"));
        assert_eq!(seen.last().map(String::as_str), Some("/good"));
        assert!(!seen.iter().any(|path| path == "/original"));
    }

    /// Starts a throwaway HTTP server on a random local port.
    ///
    /// `/bad` answers 500, `/good` answers 200 with `body`, anything else 404.
    /// Returns the base URL and a log of the requested paths, in arrival
    /// order. The server handles at most 8 connections.
    fn start_fallback_server(body: &'static [u8]) -> (String, Arc<Mutex<Vec<String>>>) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        let requests = Arc::new(Mutex::new(Vec::new()));
        let request_log = Arc::clone(&requests);
        thread::spawn(move || {
            for stream in listener.incoming().take(8) {
                let mut stream = stream.unwrap();
                let mut buf = [0_u8; 1024];
                let n = stream.read(&mut buf).unwrap();
                let request = String::from_utf8_lossy(&buf[..n]);
                // The request target is the second token of the request line.
                let path = request
                    .lines()
                    .next()
                    .and_then(|line| line.split_whitespace().nth(1))
                    .unwrap_or("/");
                // Record before responding so the log is complete by the time
                // the client sees the response.
                request_log.lock().unwrap().push(path.to_string());
                match path {
                    "/bad" => write_response(&mut stream, "500 Internal Server Error", b"bad"),
                    "/good" => write_response(&mut stream, "200 OK", body),
                    _ => write_response(&mut stream, "404 Not Found", b"missing"),
                }
            }
        });
        (format!("http://{addr}"), requests)
    }

    /// Writes a minimal HTTP/1.1 response with the given status line and body.
    fn write_response(stream: &mut std::net::TcpStream, status: &str, body: &[u8]) {
        let headers = format!(
            "HTTP/1.1 {status}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
            body.len()
        );
        stream.write_all(headers.as_bytes()).unwrap();
        stream.write_all(body).unwrap();
    }
}