Skip to main content

modde_sources/
common.rs

//! Shared download helpers reused across sources: streaming a response body to
//! disk with hash verification, retry/backoff, and parent-directory setup.
3
4use std::path::Path;
5
6use anyhow::Result;
7use futures::StreamExt;
8use tokio::io::AsyncWriteExt;
9use tracing::{debug, warn};
10use xxhash_rust::xxh3::Xxh3;
11use xxhash_rust::xxh64::Xxh64;
12
13use crate::error::{SourceError, SourceResult, status_error};
14use crate::traits::{DownloadHandle, ProgressCallback, VerifiedFile};
15
/// Total number of times [`with_retry`] invokes its operation, counting the first call.
pub const MAX_RETRIES: u32 = 3;
/// Delay in milliseconds before the first retry in [`with_retry`]; the base doubles
/// for every subsequent retry, before any jitter is added.
pub const BACKOFF_BASE_MS: u64 = 1_000;
18
19/// Stream a response body to a file, reporting progress.
20pub async fn stream_to_file(
21    resp: reqwest::Response,
22    dest: &Path,
23    total_hint: u64,
24    progress: &ProgressCallback,
25) -> Result<u64> {
26    let total = resp.content_length().unwrap_or(total_hint);
27    let mut file = tokio::fs::File::create(dest).await?;
28    let mut downloaded: u64 = 0;
29
30    let mut stream = resp.bytes_stream();
31    while let Some(chunk) = stream.next().await {
32        let chunk = chunk?;
33        file.write_all(&chunk).await?;
34        downloaded += chunk.len() as u64;
35        progress(downloaded, total);
36    }
37
38    file.flush().await?;
39    Ok(downloaded)
40}
41
42/// Stream a response body to a file and verify Wabbajack-compatible hashes
43/// while bytes are still hot in memory.
44pub async fn stream_to_file_verified(
45    resp: reqwest::Response,
46    dest: &Path,
47    expected_hash: u64,
48    total_hint: u64,
49    progress: &ProgressCallback,
50) -> SourceResult<VerifiedFile> {
51    let total = resp.content_length().unwrap_or(total_hint);
52    let mut file = tokio::fs::File::create(dest).await?;
53    let mut downloaded: u64 = 0;
54    let mut xxh64 = Xxh64::new(0);
55    let mut xxh3 = Xxh3::new();
56
57    let mut stream = resp.bytes_stream();
58    while let Some(chunk) = stream.next().await {
59        let chunk = chunk?;
60        xxh64.update(&chunk);
61        xxh3.update(&chunk);
62        file.write_all(&chunk).await?;
63        downloaded += chunk.len() as u64;
64        progress(downloaded, total);
65    }
66    file.flush().await?;
67
68    let h64 = xxh64.digest();
69    if h64 == expected_hash || xxh3.digest() == expected_hash {
70        return Ok(VerifiedFile {
71            path: dest.to_path_buf(),
72            hash: expected_hash,
73        });
74    }
75
76    let _ = tokio::fs::remove_file(dest).await;
77    Err(SourceError::hash_mismatch(dest, expected_hash, h64))
78}
79
80/// Verify hash (xxh64 then xxh3 fallback) and return a `VerifiedFile`.
81pub async fn verify_and_wrap(dest: &Path, expected_hash: u64) -> SourceResult<VerifiedFile> {
82    modde_core::hash::verify_xxhash_compat(dest, expected_hash)
83        .await
84        .map_err(|error| match error {
85            modde_core::CoreError::HashMismatch { .. } => {
86                SourceError::HashMismatch { source: error }
87            }
88            error => SourceError::other(error),
89        })?;
90    Ok(VerifiedFile {
91        path: dest.to_path_buf(),
92        hash: expected_hash,
93    })
94}
95
96/// Ensure parent directory exists.
97pub async fn ensure_parent(dest: &Path) -> SourceResult<()> {
98    if let Some(parent) = dest.parent() {
99        tokio::fs::create_dir_all(parent).await?;
100    }
101    Ok(())
102}
103
104/// Execute an async operation with exponential backoff retries.
105pub async fn with_retry<F, Fut, T>(label: &str, f: F) -> SourceResult<T>
106where
107    F: Fn() -> Fut,
108    Fut: std::future::Future<Output = SourceResult<T>>,
109{
110    for attempt in 0..MAX_RETRIES {
111        match f().await {
112            Ok(val) => return Ok(val),
113            Err(e) => {
114                if !e.is_retryable() {
115                    return Err(e);
116                }
117                if attempt + 1 < MAX_RETRIES {
118                    let base_delay = BACKOFF_BASE_MS * (1 << attempt);
119                    // Add deterministic jitter: up to 50% of base delay, derived from attempt index
120                    let jitter = base_delay / 4
121                        + (base_delay / 2).wrapping_mul(u64::from(attempt) + 1)
122                            % (base_delay / 2 + 1);
123                    let delay = base_delay + jitter;
124                    warn!(attempt = attempt + 1, delay_ms = delay, error = %e, "{label} failed, retrying");
125                    tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
126                } else {
127                    return Err(SourceError::other(anyhow::anyhow!(
128                        "{label} failed after {MAX_RETRIES} attempts: {e:#}"
129                    )));
130                }
131            }
132        }
133    }
134    unreachable!("the 0..MAX_RETRIES loop always returns Ok or the final Err")
135}
136
137/// Simple download: GET, stream to file, verify hash, return `VerifiedFile`.
138///
139/// Handles `ensure_parent` internally — callers should not call it separately.
140pub async fn simple_download(
141    client: &reqwest::Client,
142    handle: &DownloadHandle,
143    dest: &Path,
144    progress: &ProgressCallback,
145) -> SourceResult<VerifiedFile> {
146    ensure_parent(dest).await?;
147
148    let resp = status_error(client.get(&handle.url).send().await?)?;
149
150    let verified = stream_to_file_verified(
151        resp,
152        dest,
153        handle.expected_hash,
154        handle.size_hint.unwrap_or(0),
155        progress,
156    )
157    .await?;
158    debug!("download complete");
159    Ok(verified)
160}