1use reqwest::Url;
2use std::path::Path;
3use std::time::Instant;
4use tokio::io::AsyncWriteExt as _;
5
6use crate::config::normalize_romm_origin;
7use crate::core::interrupt::CancelledByUser;
8use crate::error::{ApiError, DownloadError};
9
10use super::response::{api_error_from_response, read_error_response_text};
11use super::RommClient;
12
13impl RommClient {
14 pub async fn download_rom<F>(
16 &self,
17 rom_id: u64,
18 save_path: &Path,
19 mut on_progress: F,
20 ) -> Result<(), DownloadError>
21 where
22 F: FnMut(u64, u64) + Send,
23 {
24 self.download_rom_with_cancel(rom_id, save_path, |_, _| false, &mut on_progress)
25 .await
26 }
27
28 pub async fn download_rom_with_cancel<F, C>(
29 &self,
30 rom_id: u64,
31 save_path: &Path,
32 is_cancelled: C,
33 on_progress: &mut F,
34 ) -> Result<(), DownloadError>
35 where
36 F: FnMut(u64, u64) + Send,
37 C: FnMut(u64, u64) -> bool + Send,
38 {
39 let filename = filename_hint(save_path);
40 let query = vec![
41 ("rom_ids".to_string(), rom_id.to_string()),
42 ("filename".to_string(), filename),
43 ];
44 self.download_url_with_query_with_cancel(
45 "/api/roms/download",
46 &query,
47 save_path,
48 is_cancelled,
49 on_progress,
50 )
51 .await
52 }
53
54 pub async fn download_url_with_cancel<F, C>(
56 &self,
57 url: &str,
58 save_path: &Path,
59 is_cancelled: C,
60 on_progress: &mut F,
61 ) -> Result<(), DownloadError>
62 where
63 F: FnMut(u64, u64) + Send,
64 C: FnMut(u64, u64) -> bool + Send,
65 {
66 self.download_url_with_query_with_cancel(url, &[], save_path, is_cancelled, on_progress)
67 .await
68 }
69
70 pub async fn download_url_with_query_with_cancel<F, C>(
72 &self,
73 url: &str,
74 query: &[(String, String)],
75 save_path: &Path,
76 mut is_cancelled: C,
77 on_progress: &mut F,
78 ) -> Result<(), DownloadError>
79 where
80 F: FnMut(u64, u64) + Send,
81 C: FnMut(u64, u64) -> bool + Send,
82 {
83 let url = self.resolve_download_url(url)?;
84 let filename = filename_hint(save_path);
85 let mut headers = if self.should_send_auth_to_download_url(&url) {
86 self.build_headers()?
87 } else {
88 reqwest::header::HeaderMap::new()
89 };
90
91 let existing_len = tokio::fs::metadata(save_path)
92 .await
93 .map(|m| m.len())
94 .unwrap_or(0);
95
96 if existing_len > 0 {
97 let range = format!("bytes={existing_len}-");
98 if let Ok(v) = reqwest::header::HeaderValue::from_str(&range) {
99 headers.insert(reqwest::header::RANGE, v);
100 }
101 }
102
103 if let Some(parent) = save_path.parent() {
104 tokio::fs::create_dir_all(parent)
105 .await
106 .map_err(|e| DownloadError::IoContext {
107 context: format!("create download parent dir {parent:?}"),
108 source: e,
109 })?;
110 }
111
112 let t0 = Instant::now();
113 let mut resp = self
114 .http
115 .get(&url)
116 .headers(headers)
117 .query(query)
118 .send()
119 .await?;
120
121 let status = resp.status();
122 if self.verbose {
123 tracing::info!(
124 "[romm-cli] GET {} filename={:?} -> {} ({}ms)",
125 crate::log_redact::redact_url_for_log(&url),
126 filename,
127 status.as_u16(),
128 t0.elapsed().as_millis()
129 );
130 }
131 if !status.is_success() {
132 let body = read_error_response_text(resp).await;
133 return Err(DownloadError::Api(api_error_from_response(status, &body)));
134 }
135
136 let (mut received, total, mut file) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
137 let remaining = resp.content_length().unwrap_or(0);
138 let total = existing_len + remaining;
139 let file = tokio::fs::OpenOptions::new()
140 .append(true)
141 .open(save_path)
142 .await
143 .map_err(|e| DownloadError::IoContext {
144 context: format!("open file for append {save_path:?}"),
145 source: e,
146 })?;
147 (existing_len, total, file)
148 } else {
149 let total = resp.content_length().unwrap_or(0);
150 let file =
151 tokio::fs::File::create(save_path)
152 .await
153 .map_err(|e| DownloadError::IoContext {
154 context: format!("create file {save_path:?}"),
155 source: e,
156 })?;
157 (0u64, total, file)
158 };
159
160 if is_cancelled(received, total) {
161 return Err(DownloadError::Cancelled(CancelledByUser));
162 }
163
164 while let Some(chunk) = resp.chunk().await? {
165 if is_cancelled(received, total) {
166 return Err(DownloadError::Cancelled(CancelledByUser));
167 }
168 file.write_all(&chunk)
169 .await
170 .map_err(|e| DownloadError::IoContext {
171 context: format!("write chunk {save_path:?}"),
172 source: e,
173 })?;
174 received += chunk.len() as u64;
175 on_progress(received, total);
176 }
177
178 Ok(())
179 }
180
181 fn resolve_download_url(&self, url: &str) -> Result<String, DownloadError> {
182 let trimmed = url.trim();
183 if trimmed.is_empty() {
184 return Err(DownloadError::Api(ApiError::UnexpectedResponse(
185 "download URL cannot be empty".into(),
186 )));
187 }
188 if let Ok(parsed) = Url::parse(trimmed) {
189 return Ok(parsed.to_string());
190 }
191
192 let base = Url::parse(&normalize_romm_origin(&self.base_url)).map_err(|e| {
193 DownloadError::Api(ApiError::UnexpectedResponse(format!(
194 "invalid RomM base URL: {e}"
195 )))
196 })?;
197 let joined = base.join(trimmed).map_err(|e| {
198 DownloadError::Api(ApiError::UnexpectedResponse(format!(
199 "could not resolve download URL {trimmed:?}: {e}"
200 )))
201 })?;
202 Ok(joined.to_string())
203 }
204
205 fn should_send_auth_to_download_url(&self, url: &str) -> bool {
206 let Ok(download_url) = Url::parse(url) else {
207 return true;
208 };
209 let Ok(base_url) = Url::parse(&normalize_romm_origin(&self.base_url)) else {
210 return false;
211 };
212
213 download_url.scheme() == base_url.scheme()
214 && download_url.host_str() == base_url.host_str()
215 && download_url.port_or_known_default() == base_url.port_or_known_default()
216 }
217}
218
/// Derives a file-name string from the last component of `save_path`.
///
/// Falls back to `"download.bin"` when the path has no final file name (for
/// example `/`, an empty path, or a path ending in `..`) or when that name is
/// not valid UTF-8.
fn filename_hint(save_path: &Path) -> String {
    match save_path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name.to_owned(),
        None => String::from("download.bin"),
    }
}
226
#[cfg(test)]
mod tests {
    use crate::config::{AuthConfig, Config, ExtrasDefaults};

    use super::*;

    /// Builds a client for `base_url` with a bearer token configured, so the
    /// auth-forwarding checks below have credentials to protect.
    fn client_for(base_url: &str) -> RommClient {
        RommClient::new(
            &Config {
                base_url: base_url.to_string(),
                download_dir: ".".to_string(),
                use_https: true,
                auth: Some(AuthConfig::Bearer {
                    token: "secret".to_string(),
                }),
                extras_defaults: ExtrasDefaults::default(),
                save_sync: Default::default(),
                roms_layout: Default::default(),
                theme: crate::config::default_theme_id(),
                tui_layout: Default::default(),
            },
            false,
        )
        .expect("client")
    }

    #[test]
    fn download_auth_allowed_for_same_origin_absolute_url() {
        let client = client_for("https://romm.example:8443/api");
        assert!(client.should_send_auth_to_download_url("https://romm.example:8443/files/a.zip"));
    }

    #[test]
    fn download_auth_blocked_for_off_origin_absolute_url() {
        let client = client_for("https://romm.example/api");
        assert!(!client.should_send_auth_to_download_url("https://cdn.example/files/a.zip"));
    }

    #[test]
    fn download_auth_blocked_for_same_host_different_port() {
        let client = client_for("https://romm.example:8443/api");
        assert!(!client.should_send_auth_to_download_url("https://romm.example/files/a.zip"));
    }

    #[test]
    fn download_auth_blocked_for_scheme_downgrade() {
        let client = client_for("https://romm.example/api");
        assert!(!client.should_send_auth_to_download_url("http://romm.example/files/a.zip"));
    }

    #[test]
    fn resolve_download_url_joins_root_relative_path_onto_origin() {
        let client = client_for("https://romm.example/api");
        assert_eq!(
            client.resolve_download_url("/api/roms/1/content/a.zip").ok(),
            Some("https://romm.example/api/roms/1/content/a.zip".to_string())
        );
    }

    #[test]
    fn resolve_download_url_keeps_absolute_urls() {
        let client = client_for("https://romm.example/api");
        assert_eq!(
            client
                .resolve_download_url("  https://cdn.example/files/a.zip  ")
                .ok(),
            Some("https://cdn.example/files/a.zip".to_string())
        );
    }

    #[test]
    fn resolve_download_url_rejects_blank_input() {
        let client = client_for("https://romm.example/api");
        assert!(client.resolve_download_url("   ").is_err());
    }
}