Skip to main content

modde_sources/direct/
mod.rs

1use std::path::Path;
2
3use anyhow::Result;
4use futures::StreamExt;
5use reqwest::Client;
6use tokio::io::AsyncWriteExt;
7use tracing::debug;
8
9use modde_core::manifest::wabbajack::DownloadDirective;
10
11use crate::common::{ensure_parent, verify_and_wrap, with_retry};
12use crate::traits::{DownloadHandle, DownloadSource, ProgressCallback, VerifiedFile};
13
14/// Plain HTTPS download source with Range header support for resume.
15pub struct DirectSource {
16    client: Client,
17}
18
19impl DirectSource {
20    pub fn new(client: Client) -> Self {
21        Self { client }
22    }
23}
24
25impl DownloadSource for DirectSource {
26    fn can_handle(&self, directive: &DownloadDirective) -> bool {
27        matches!(directive, DownloadDirective::DirectURL { .. })
28    }
29
30    async fn resolve(&self, directive: &DownloadDirective) -> Result<DownloadHandle> {
31        let DownloadDirective::DirectURL { url, headers, hash } = directive else {
32            anyhow::bail!("not a DirectURL directive");
33        };
34
35        Ok(DownloadHandle {
36            url: url.clone(),
37            headers: headers.clone(),
38            expected_hash: *hash,
39            size_hint: None,
40        })
41    }
42
43    async fn download_with_progress(
44        &self,
45        handle: DownloadHandle,
46        dest: &Path,
47        progress: ProgressCallback,
48    ) -> Result<VerifiedFile> {
49        ensure_parent(dest).await?;
50
51        let client = self.client.clone();
52        let handle_ref = &handle;
53        let dest_ref = dest;
54        let progress_ref = &progress;
55
56        with_retry("direct download", || async {
57            download_with_resume(&client, handle_ref, dest_ref, progress_ref).await
58        })
59        .await?;
60
61        verify_and_wrap(dest, handle.expected_hash).await
62    }
63}
64
65async fn download_with_resume(
66    client: &Client,
67    handle: &DownloadHandle,
68    dest: &Path,
69    progress: &ProgressCallback,
70) -> Result<()> {
71    let existing_len = tokio::fs::metadata(dest).await.map(|m| m.len()).unwrap_or(0);
72
73    let mut req = client.get(&handle.url);
74    for (k, v) in &handle.headers {
75        req = req.header(k.as_str(), v.as_str());
76    }
77
78    if existing_len > 0 {
79        debug!(bytes = existing_len, "attempting range resume");
80        req = req.header("Range", format!("bytes={existing_len}-"));
81    }
82
83    let resp = req.send().await?.error_for_status()?;
84    let status = resp.status();
85    let total = resp.content_length().or(handle.size_hint).unwrap_or(0);
86
87    let (mut file, mut downloaded) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
88        debug!("server returned 206, resuming download");
89        let file = tokio::fs::OpenOptions::new()
90            .append(true)
91            .open(dest)
92            .await?;
93        (file, existing_len)
94    } else {
95        if existing_len > 0 {
96            debug!("server returned {}, restarting download from scratch", status);
97        }
98        let file = tokio::fs::File::create(dest).await?;
99        (file, 0u64)
100    };
101
102    let total_size = if status == reqwest::StatusCode::PARTIAL_CONTENT {
103        total + existing_len
104    } else {
105        total
106    };
107
108    let mut stream = resp.bytes_stream();
109    while let Some(chunk) = stream.next().await {
110        let chunk = chunk?;
111        file.write_all(&chunk).await?;
112        downloaded += chunk.len() as u64;
113        progress(downloaded, total_size);
114    }
115
116    file.flush().await?;
117    Ok(())
118}