1use anyhow::{anyhow, Result};
8use base64::{engine::general_purpose, Engine as _};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use reqwest::{Client as HttpClient, Method};
11use serde_json::Value;
12use std::path::Path;
13use std::time::Instant;
14use tokio::io::AsyncWriteExt as _;
15
16use crate::config::{AuthConfig, Config};
17use crate::endpoints::Endpoint;
18
19fn decode_json_response_body(bytes: &[u8]) -> Value {
24 if bytes.is_empty() || bytes.iter().all(|b| b.is_ascii_whitespace()) {
25 return Value::Null;
26 }
27 serde_json::from_slice(bytes).unwrap_or_else(|_| {
28 serde_json::json!({
29 "_non_json_body": String::from_utf8_lossy(bytes).to_string()
30 })
31 })
32}
33
34#[derive(Clone)]
39pub struct RommClient {
40 http: HttpClient,
41 base_url: String,
42 auth: Option<AuthConfig>,
43 verbose: bool,
44}
45
46impl RommClient {
47 pub fn new(config: &Config, verbose: bool) -> Result<Self> {
53 let http = HttpClient::builder().build()?;
54 Ok(Self {
55 http,
56 base_url: config.base_url.clone(),
57 auth: config.auth.clone(),
58 verbose,
59 })
60 }
61
62 fn build_headers(&self) -> Result<HeaderMap> {
67 let mut headers = HeaderMap::new();
68
69 if let Some(auth) = &self.auth {
70 match auth {
71 AuthConfig::Basic { username, password } => {
72 let creds = format!("{username}:{password}");
73 let encoded = general_purpose::STANDARD.encode(creds.as_bytes());
74 let value = format!("Basic {encoded}");
75 headers.insert(
76 AUTHORIZATION,
77 HeaderValue::from_str(&value)
78 .map_err(|_| anyhow!("invalid basic auth header value"))?,
79 );
80 }
81 AuthConfig::Bearer { token } => {
82 let value = format!("Bearer {token}");
83 headers.insert(
84 AUTHORIZATION,
85 HeaderValue::from_str(&value)
86 .map_err(|_| anyhow!("invalid bearer auth header value"))?,
87 );
88 }
89 AuthConfig::ApiKey { header, key } => {
90 let name = reqwest::header::HeaderName::from_bytes(header.as_bytes()).map_err(
91 |_| anyhow!("invalid API_KEY_HEADER, must be a valid HTTP header name"),
92 )?;
93 headers.insert(
94 name,
95 HeaderValue::from_str(key)
96 .map_err(|_| anyhow!("invalid API_KEY header value"))?,
97 );
98 }
99 }
100 }
101
102 Ok(headers)
103 }
104
105 pub async fn call<E>(&self, ep: &E) -> anyhow::Result<E::Output>
107 where
108 E: Endpoint,
109 E::Output: serde::de::DeserializeOwned,
110 {
111 let method = ep.method();
112 let path = ep.path();
113 let query = ep.query();
114 let body = ep.body();
115
116 let value = self.request_json(method, &path, &query, body).await?;
117 let output = serde_json::from_value(value)
118 .map_err(|e| anyhow!("failed to decode response for {} {}: {}", method, path, e))?;
119
120 Ok(output)
121 }
122
123 pub async fn request_json(
128 &self,
129 method: &str,
130 path: &str,
131 query: &[(String, String)],
132 body: Option<Value>,
133 ) -> Result<Value> {
134 let url = format!(
135 "{}/{}",
136 self.base_url.trim_end_matches('/'),
137 path.trim_start_matches('/')
138 );
139 let headers = self.build_headers()?;
140
141 let http_method = Method::from_bytes(method.as_bytes())
142 .map_err(|_| anyhow!("invalid HTTP method: {method}"))?;
143
144 let query_refs: Vec<(&str, &str)> = query
147 .iter()
148 .map(|(k, v)| (k.as_str(), v.as_str()))
149 .collect();
150
151 let mut req = self
152 .http
153 .request(http_method, &url)
154 .headers(headers)
155 .query(&query_refs);
156
157 if let Some(body) = body {
158 req = req.json(&body);
159 }
160
161 let t0 = Instant::now();
162 let resp = req
163 .send()
164 .await
165 .map_err(|e| anyhow!("request error: {e}"))?;
166
167 let status = resp.status();
168 if self.verbose {
169 let keys: Vec<&str> = query.iter().map(|(k, _)| k.as_str()).collect();
170 tracing::info!(
171 "[romm-cli] {} {} query_keys={:?} -> {} ({}ms)",
172 method,
173 path,
174 keys,
175 status.as_u16(),
176 t0.elapsed().as_millis()
177 );
178 }
179 if !status.is_success() {
180 let body = resp.text().await.unwrap_or_default();
181 return Err(anyhow!(
182 "ROMM API error: {} {} - {}",
183 status.as_u16(),
184 status.canonical_reason().unwrap_or(""),
185 body
186 ));
187 }
188
189 let bytes = resp
190 .bytes()
191 .await
192 .map_err(|e| anyhow!("read response body: {e}"))?;
193
194 Ok(decode_json_response_body(&bytes))
195 }
196
197 pub async fn download_rom<F>(
206 &self,
207 rom_id: u64,
208 save_path: &Path,
209 mut on_progress: F,
210 ) -> Result<()>
211 where
212 F: FnMut(u64, u64) + Send,
213 {
214 let path = "/api/roms/download";
215 let url = format!(
216 "{}/{}",
217 self.base_url.trim_end_matches('/'),
218 path.trim_start_matches('/')
219 );
220 let mut headers = self.build_headers()?;
221
222 let filename = save_path
223 .file_name()
224 .and_then(|n| n.to_str())
225 .unwrap_or("download.zip");
226
227 let existing_len = tokio::fs::metadata(save_path)
229 .await
230 .map(|m| m.len())
231 .unwrap_or(0);
232
233 if existing_len > 0 {
234 let range = format!("bytes={existing_len}-");
235 if let Ok(v) = reqwest::header::HeaderValue::from_str(&range) {
236 headers.insert(reqwest::header::RANGE, v);
237 }
238 }
239
240 let t0 = Instant::now();
241 let mut resp = self
242 .http
243 .get(&url)
244 .headers(headers)
245 .query(&[
246 ("rom_ids", rom_id.to_string()),
247 ("filename", filename.to_string()),
248 ])
249 .send()
250 .await
251 .map_err(|e| anyhow!("download request error: {e}"))?;
252
253 let status = resp.status();
254 if self.verbose {
255 tracing::info!(
256 "[romm-cli] GET /api/roms/download rom_id={} filename={:?} -> {} ({}ms)",
257 rom_id,
258 filename,
259 status.as_u16(),
260 t0.elapsed().as_millis()
261 );
262 }
263 if !status.is_success() {
264 let body = resp.text().await.unwrap_or_default();
265 return Err(anyhow!(
266 "ROMM API error: {} {} - {}",
267 status.as_u16(),
268 status.canonical_reason().unwrap_or(""),
269 body
270 ));
271 }
272
273 let (mut received, total, mut file) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
275 let remaining = resp.content_length().unwrap_or(0);
277 let total = existing_len + remaining;
278 let file = tokio::fs::OpenOptions::new()
279 .append(true)
280 .open(save_path)
281 .await
282 .map_err(|e| anyhow!("open file for append {:?}: {e}", save_path))?;
283 (existing_len, total, file)
284 } else {
285 let total = resp.content_length().unwrap_or(0);
287 let file = tokio::fs::File::create(save_path)
288 .await
289 .map_err(|e| anyhow!("create file {:?}: {e}", save_path))?;
290 (0u64, total, file)
291 };
292
293 while let Some(chunk) = resp.chunk().await.map_err(|e| anyhow!("read chunk: {e}"))? {
294 file.write_all(&chunk)
295 .await
296 .map_err(|e| anyhow!("write chunk {:?}: {e}", save_path))?;
297 received += chunk.len() as u64;
298 on_progress(received, total);
299 }
300
301 Ok(())
302 }
303}
304
#[cfg(test)]
mod tests {
    //! Unit tests for `decode_json_response_body`: blank bodies, valid JSON of several
    //! shapes, and the `_non_json_body` fallback (including lossy UTF-8 decoding).
    use super::*;

    #[test]
    fn decode_json_empty_and_whitespace_to_null() {
        assert_eq!(decode_json_response_body(b""), Value::Null);
        assert_eq!(decode_json_response_body(b" \n\t "), Value::Null);
    }

    #[test]
    fn decode_json_object_roundtrip() {
        let v = decode_json_response_body(br#"{"a":1}"#);
        assert_eq!(v["a"], 1);
    }

    #[test]
    fn decode_json_non_object_values() {
        // Surrounding whitespace around valid JSON is accepted by serde_json.
        assert_eq!(
            decode_json_response_body(b" [1, 2] \n"),
            serde_json::json!([1, 2])
        );
        assert_eq!(decode_json_response_body(b"42"), serde_json::json!(42));
        assert_eq!(decode_json_response_body(b"null"), Value::Null);
    }

    #[test]
    fn decode_non_json_wrapped() {
        let v = decode_json_response_body(b"plain text");
        assert_eq!(v["_non_json_body"], "plain text");
    }

    #[test]
    fn decode_truncated_json_wrapped() {
        let v = decode_json_response_body(br#"{"a":"#);
        assert_eq!(v["_non_json_body"], r#"{"a":"#);
    }

    #[test]
    fn decode_non_utf8_body_is_lossy() {
        // Each invalid byte becomes U+FFFD REPLACEMENT CHARACTER.
        let v = decode_json_response_body(b"\xff\xfe");
        assert_eq!(v["_non_json_body"], "\u{FFFD}\u{FFFD}");
    }
}