modde_sources/direct/
mod.rs1use std::path::Path;
2
3use anyhow::Result;
4use futures::StreamExt;
5use reqwest::Client;
6use tokio::io::AsyncWriteExt;
7use tracing::debug;
8
9use modde_core::manifest::wabbajack::DownloadDirective;
10
11use crate::common::{ensure_parent, verify_and_wrap, with_retry};
12use crate::traits::{DownloadHandle, DownloadSource, ProgressCallback, VerifiedFile};
13
14pub struct DirectSource {
16 client: Client,
17}
18
19impl DirectSource {
20 pub fn new(client: Client) -> Self {
21 Self { client }
22 }
23}
24
25impl DownloadSource for DirectSource {
26 fn can_handle(&self, directive: &DownloadDirective) -> bool {
27 matches!(directive, DownloadDirective::DirectURL { .. })
28 }
29
30 async fn resolve(&self, directive: &DownloadDirective) -> Result<DownloadHandle> {
31 let DownloadDirective::DirectURL { url, headers, hash } = directive else {
32 anyhow::bail!("not a DirectURL directive");
33 };
34
35 Ok(DownloadHandle {
36 url: url.clone(),
37 headers: headers.clone(),
38 expected_hash: *hash,
39 size_hint: None,
40 })
41 }
42
43 async fn download_with_progress(
44 &self,
45 handle: DownloadHandle,
46 dest: &Path,
47 progress: ProgressCallback,
48 ) -> Result<VerifiedFile> {
49 ensure_parent(dest).await?;
50
51 let client = self.client.clone();
52 let handle_ref = &handle;
53 let dest_ref = dest;
54 let progress_ref = &progress;
55
56 with_retry("direct download", || async {
57 download_with_resume(&client, handle_ref, dest_ref, progress_ref).await
58 })
59 .await?;
60
61 verify_and_wrap(dest, handle.expected_hash).await
62 }
63}
64
65async fn download_with_resume(
66 client: &Client,
67 handle: &DownloadHandle,
68 dest: &Path,
69 progress: &ProgressCallback,
70) -> Result<()> {
71 let existing_len = tokio::fs::metadata(dest).await.map(|m| m.len()).unwrap_or(0);
72
73 let mut req = client.get(&handle.url);
74 for (k, v) in &handle.headers {
75 req = req.header(k.as_str(), v.as_str());
76 }
77
78 if existing_len > 0 {
79 debug!(bytes = existing_len, "attempting range resume");
80 req = req.header("Range", format!("bytes={existing_len}-"));
81 }
82
83 let resp = req.send().await?.error_for_status()?;
84 let status = resp.status();
85 let total = resp.content_length().or(handle.size_hint).unwrap_or(0);
86
87 let (mut file, mut downloaded) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
88 debug!("server returned 206, resuming download");
89 let file = tokio::fs::OpenOptions::new()
90 .append(true)
91 .open(dest)
92 .await?;
93 (file, existing_len)
94 } else {
95 if existing_len > 0 {
96 debug!("server returned {}, restarting download from scratch", status);
97 }
98 let file = tokio::fs::File::create(dest).await?;
99 (file, 0u64)
100 };
101
102 let total_size = if status == reqwest::StatusCode::PARTIAL_CONTENT {
103 total + existing_len
104 } else {
105 total
106 };
107
108 let mut stream = resp.bytes_stream();
109 while let Some(chunk) = stream.next().await {
110 let chunk = chunk?;
111 file.write_all(&chunk).await?;
112 downloaded += chunk.len() as u64;
113 progress(downloaded, total_size);
114 }
115
116 file.flush().await?;
117 Ok(())
118}