Skip to main content

modde_sources/direct/
mod.rs

1//! Plain HTTPS download source with `Range`-based resume and HTML mirror
2//! resolution.
3
4use std::path::Path;
5
6use futures::StreamExt;
7use reqwest::Client;
8use tokio::io::{AsyncReadExt as _, AsyncWriteExt};
9use tracing::{debug, info};
10use xxhash_rust::xxh3::Xxh3;
11use xxhash_rust::xxh64::Xxh64;
12
13use modde_core::manifest::wabbajack::DownloadDirective;
14
15use crate::common::{ensure_parent, with_retry};
16use crate::error::{SourceError, SourceResult, status_error};
17use crate::mirror::resolve_html_mirrors;
18use crate::traits::{DownloadHandle, DownloadSource, ProgressCallback, VerifiedFile};
19
20/// Plain HTTPS download source with Range header support for resume.
21pub struct DirectSource {
22    client: Client,
23}
24
25impl DirectSource {
26    /// Create a source that downloads over the given HTTP `client`.
27    #[must_use]
28    pub fn new(client: Client) -> Self {
29        Self { client }
30    }
31}
32
33impl DownloadSource for DirectSource {
34    fn can_handle(&self, directive: &DownloadDirective) -> bool {
35        matches!(directive, DownloadDirective::DirectURL { .. })
36    }
37
38    async fn resolve(&self, directive: &DownloadDirective) -> SourceResult<DownloadHandle> {
39        let DownloadDirective::DirectURL {
40            url,
41            headers,
42            mirror_resolver,
43            hash,
44        } = directive
45        else {
46            return Err(SourceError::other(anyhow::anyhow!(
47                "not a DirectURL directive"
48            )));
49        };
50
51        let candidate_urls = if let Some(resolver) = mirror_resolver {
52            resolve_html_mirrors(&self.client, resolver)
53                .await
54                .map_err(SourceError::other)?
55        } else {
56            Vec::new()
57        };
58        let mut headers = headers.clone();
59        if let Some(resolver) = mirror_resolver
60            && let Some(user_agent) = &resolver.user_agent
61        {
62            headers
63                .entry("User-Agent".to_string())
64                .or_insert_with(|| user_agent.clone());
65        }
66
67        Ok(DownloadHandle {
68            url: url.clone(),
69            candidate_urls,
70            headers,
71            expected_hash: *hash,
72            size_hint: None,
73        })
74    }
75
76    async fn download_with_progress(
77        &self,
78        handle: DownloadHandle,
79        dest: &Path,
80        progress: ProgressCallback,
81    ) -> SourceResult<VerifiedFile> {
82        ensure_parent(dest).await?;
83
84        let client = self.client.clone();
85        let dest_ref = dest;
86        let progress_ref = &progress;
87        let candidates = if handle.candidate_urls.is_empty() {
88            vec![handle.url.clone()]
89        } else {
90            handle.candidate_urls.clone()
91        };
92
93        let mut errors = Vec::new();
94        let mut last_error = None;
95        for (idx, url) in candidates.iter().enumerate() {
96            let mut candidate = handle.clone();
97            candidate.url = url.clone();
98            let result = with_retry("direct download", || async {
99                download_with_resume(&client, &candidate, dest_ref, progress_ref).await
100            })
101            .await;
102            match result {
103                Ok(verified) => {
104                    if idx > 0 {
105                        info!(url = %candidate.url, "direct download mirror succeeded");
106                    }
107                    return Ok(verified);
108                }
109                Err(e) => {
110                    errors.push(format!("{}: {e:#}", candidate.url));
111                    last_error = Some(e);
112                    let _ = tokio::fs::remove_file(dest).await;
113                }
114            }
115        }
116
117        if candidates.len() == 1
118            && let Some(error) = last_error
119        {
120            return Err(error);
121        }
122
123        Err(SourceError::other(anyhow::anyhow!(
124            "all {} direct download candidate(s) failed:\n{}",
125            candidates.len(),
126            errors
127                .iter()
128                .map(|error| format!("  - {error}"))
129                .collect::<Vec<_>>()
130                .join("\n")
131        )))
132    }
133}
134
135async fn download_with_resume(
136    client: &Client,
137    handle: &DownloadHandle,
138    dest: &Path,
139    progress: &ProgressCallback,
140) -> SourceResult<VerifiedFile> {
141    let existing_len = tokio::fs::metadata(dest).await.map_or(0, |m| m.len());
142
143    let mut req = client.get(&handle.url);
144    for (k, v) in &handle.headers {
145        req = req.header(k.as_str(), v.as_str());
146    }
147
148    if existing_len > 0 {
149        debug!(bytes = existing_len, "attempting range resume");
150        req = req.header("Range", format!("bytes={existing_len}-"));
151    }
152
153    let resp = status_error(req.send().await?)?;
154    let status = resp.status();
155    let total = resp.content_length().or(handle.size_hint).unwrap_or(0);
156
157    let (mut file, mut downloaded) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
158        debug!("server returned 206, resuming download");
159        let file = tokio::fs::OpenOptions::new()
160            .append(true)
161            .open(dest)
162            .await?;
163        (file, existing_len)
164    } else {
165        if existing_len > 0 {
166            debug!(
167                "server returned {}, restarting download from scratch",
168                status
169            );
170        }
171        let file = tokio::fs::File::create(dest).await?;
172        (file, 0u64)
173    };
174
175    let total_size = if status == reqwest::StatusCode::PARTIAL_CONTENT {
176        total + existing_len
177    } else {
178        total
179    };
180
181    let mut xxh64 = Xxh64::new(0);
182    let mut xxh3 = Xxh3::new();
183    if status == reqwest::StatusCode::PARTIAL_CONTENT && existing_len > 0 {
184        let mut existing = tokio::fs::File::open(dest).await?;
185        let mut buf = vec![0_u8; 1024 * 1024];
186        loop {
187            let read = existing.read(&mut buf).await?;
188            if read == 0 {
189                break;
190            }
191            xxh64.update(&buf[..read]);
192            xxh3.update(&buf[..read]);
193        }
194    }
195
196    let mut stream = resp.bytes_stream();
197    while let Some(chunk) = stream.next().await {
198        let chunk = chunk?;
199        xxh64.update(&chunk);
200        xxh3.update(&chunk);
201        file.write_all(&chunk).await?;
202        downloaded += chunk.len() as u64;
203        progress(downloaded, total_size);
204    }
205
206    file.flush().await?;
207    let h64 = xxh64.digest();
208    if h64 == handle.expected_hash || xxh3.digest() == handle.expected_hash {
209        return Ok(VerifiedFile {
210            path: dest.to_path_buf(),
211            hash: handle.expected_hash,
212        });
213    }
214
215    let _ = tokio::fs::remove_file(dest).await;
216    Err(SourceError::hash_mismatch(dest, handle.expected_hash, h64))
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::io::{Read as _, Write as _};
    use std::net::{TcpListener, TcpStream};
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };
    use std::thread;
    use xxhash_rust::xxh64::xxh64;

    /// A failing first candidate must fall through to the next one, and the
    /// file that ends up on disk must be the good mirror's payload.
    #[tokio::test]
    async fn direct_download_tries_candidate_urls_in_order() {
        let body = b"mirror payload";
        let expected_hash = xxh64(body, 0);
        let (base_url, requests) = start_fallback_server(body);

        let candidate_urls = vec![format!("{base_url}/bad"), format!("{base_url}/good")];
        let handle = DownloadHandle {
            url: format!("{base_url}/original"),
            candidate_urls,
            headers: HashMap::new(),
            expected_hash,
            size_hint: None,
        };

        let temp = tempfile::tempdir().unwrap();
        let dest = temp.path().join("download.archive");

        let verified = DirectSource::new(Client::new())
            .download(handle, &dest)
            .await
            .unwrap();

        assert_eq!(verified.hash, expected_hash);
        assert_eq!(tokio::fs::read(&verified.path).await.unwrap(), body);
        assert!(requests.load(Ordering::SeqCst) >= 2);
    }

    /// Spawn a throwaway HTTP server on a loopback port.
    ///
    /// It answers `/bad` with a 500, `/good` with `body`, and anything else
    /// with a 404, for at most eight connections. Returns the base URL and a
    /// counter of the requests handled so far.
    fn start_fallback_server(body: &'static [u8]) -> (String, Arc<AtomicUsize>) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let base_url = format!("http://{}", listener.local_addr().unwrap());
        let requests = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&requests);

        thread::spawn(move || {
            for conn in listener.incoming().take(8) {
                let mut conn = conn.unwrap();
                let path = read_request_path(&mut conn);
                counter.fetch_add(1, Ordering::SeqCst);

                let (status, payload) = match path.as_str() {
                    "/bad" => ("500 Internal Server Error", b"bad".as_slice()),
                    "/good" => ("200 OK", body),
                    _ => ("404 Not Found", b"missing".as_slice()),
                };
                write_response(&mut conn, status, payload);
            }
        });

        (base_url, requests)
    }

    /// Read one request from `stream` and return the path from its request
    /// line, or `/` when none can be parsed.
    fn read_request_path(stream: &mut TcpStream) -> String {
        let mut buf = [0_u8; 1024];
        let n = stream.read(&mut buf).unwrap();
        let request = String::from_utf8_lossy(&buf[..n]);
        let path = request
            .lines()
            .next()
            .and_then(|line| line.split_whitespace().nth(1))
            .unwrap_or("/")
            .to_string();
        path
    }

    /// Write a minimal HTTP/1.1 response with the given status line and body.
    fn write_response(stream: &mut TcpStream, status: &str, body: &[u8]) {
        let head = format!(
            "HTTP/1.1 {status}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
            body.len()
        );
        stream.write_all(head.as_bytes()).unwrap();
        stream.write_all(body).unwrap();
    }
}