1use anyhow::{anyhow, Result};
8use base64::{engine::general_purpose, Engine as _};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use reqwest::{Client as HttpClient, Method};
11use serde_json::Value;
12use std::path::Path;
13use std::time::Instant;
14use tokio::io::AsyncWriteExt as _;
15
16use crate::config::{normalize_romm_origin, AuthConfig, Config};
17use crate::endpoints::Endpoint;
18
19fn http_user_agent() -> String {
22 match std::env::var("ROMM_USER_AGENT") {
23 Ok(s) if !s.trim().is_empty() => s,
24 _ => format!(
25 "Mozilla/5.0 (compatible; romm-cli/{}; +https://github.com/patricksmill/romm-cli)",
26 env!("CARGO_PKG_VERSION")
27 ),
28 }
29}
30
31fn decode_json_response_body(bytes: &[u8]) -> Value {
36 if bytes.is_empty() || bytes.iter().all(|b| b.is_ascii_whitespace()) {
37 return Value::Null;
38 }
39 serde_json::from_slice(bytes).unwrap_or_else(|_| {
40 serde_json::json!({
41 "_non_json_body": String::from_utf8_lossy(bytes).to_string()
42 })
43 })
44}
45
/// HTTP client for a RomM server.
///
/// Bundles the shared `reqwest` client with the server base URL, optional
/// credentials and the verbose-logging flag. It derives `Clone` but
/// intentionally not `Debug`, since `auth` holds secrets.
#[derive(Clone)]
pub struct RommClient {
    // Shared reqwest client; built once with the romm-cli user agent.
    http: HttpClient,
    // Base URL that request paths are joined onto.
    base_url: String,
    // Credentials turned into request headers; `None` means anonymous requests.
    auth: Option<AuthConfig>,
    // When true, each request logs its method, path, status and duration.
    verbose: bool,
}
57
58pub fn api_root_url(base_url: &str) -> String {
60 normalize_romm_origin(base_url)
61}
62
/// Returns `root` with its scheme flipped between `http://` and `https://`.
///
/// Returns `None` if `root` starts with neither prefix. The match is
/// case-sensitive, so `HTTP://...` also yields `None`.
fn alternate_http_scheme_root(root: &str) -> Option<String> {
    // (prefix to match, replacement). "https://" can never match "http://"
    // because the fifth byte differs, so the order only documents intent.
    const SCHEME_SWAPS: [(&str, &str); 2] = [("http://", "https://"), ("https://", "http://")];

    SCHEME_SWAPS.iter().find_map(|&(from, to)| {
        root.strip_prefix(from)
            .map(|remainder| format!("{to}{remainder}"))
    })
}
71
72pub fn resolve_openapi_root(api_base_url: &str) -> String {
78 if let Ok(s) = std::env::var("ROMM_OPENAPI_BASE_URL") {
79 let t = s.trim();
80 if !t.is_empty() {
81 return normalize_romm_origin(t);
82 }
83 }
84 normalize_romm_origin(api_base_url)
85}
86
87pub fn openapi_spec_urls(api_root: &str) -> Vec<String> {
91 let root = api_root.trim_end_matches('/').to_string();
92 let mut roots = vec![root.clone()];
93 if let Some(alt) = alternate_http_scheme_root(&root) {
94 if alt != root {
95 roots.push(alt);
96 }
97 }
98
99 let mut urls = Vec::new();
100 for r in roots {
101 let b = r.trim_end_matches('/');
102 urls.push(format!("{b}/openapi.json"));
103 urls.push(format!("{b}/api/openapi.json"));
104 }
105 urls
106}
107
108impl RommClient {
109 pub fn new(config: &Config, verbose: bool) -> Result<Self> {
115 let http = HttpClient::builder()
116 .user_agent(http_user_agent())
117 .build()?;
118 Ok(Self {
119 http,
120 base_url: config.base_url.clone(),
121 auth: config.auth.clone(),
122 verbose,
123 })
124 }
125
126 pub fn verbose(&self) -> bool {
127 self.verbose
128 }
129
130 fn build_headers(&self) -> Result<HeaderMap> {
135 let mut headers = HeaderMap::new();
136
137 if let Some(auth) = &self.auth {
138 match auth {
139 AuthConfig::Basic { username, password } => {
140 let creds = format!("{username}:{password}");
141 let encoded = general_purpose::STANDARD.encode(creds.as_bytes());
142 let value = format!("Basic {encoded}");
143 headers.insert(
144 AUTHORIZATION,
145 HeaderValue::from_str(&value)
146 .map_err(|_| anyhow!("invalid basic auth header value"))?,
147 );
148 }
149 AuthConfig::Bearer { token } => {
150 let value = format!("Bearer {token}");
151 headers.insert(
152 AUTHORIZATION,
153 HeaderValue::from_str(&value)
154 .map_err(|_| anyhow!("invalid bearer auth header value"))?,
155 );
156 }
157 AuthConfig::ApiKey { header, key } => {
158 let name = reqwest::header::HeaderName::from_bytes(header.as_bytes()).map_err(
159 |_| anyhow!("invalid API_KEY_HEADER, must be a valid HTTP header name"),
160 )?;
161 headers.insert(
162 name,
163 HeaderValue::from_str(key)
164 .map_err(|_| anyhow!("invalid API_KEY header value"))?,
165 );
166 }
167 }
168 }
169
170 Ok(headers)
171 }
172
173 pub async fn call<E>(&self, ep: &E) -> anyhow::Result<E::Output>
175 where
176 E: Endpoint,
177 E::Output: serde::de::DeserializeOwned,
178 {
179 let method = ep.method();
180 let path = ep.path();
181 let query = ep.query();
182 let body = ep.body();
183
184 let value = self.request_json(method, &path, &query, body).await?;
185 let output = serde_json::from_value(value)
186 .map_err(|e| anyhow!("failed to decode response for {} {}: {}", method, path, e))?;
187
188 Ok(output)
189 }
190
191 pub async fn request_json(
196 &self,
197 method: &str,
198 path: &str,
199 query: &[(String, String)],
200 body: Option<Value>,
201 ) -> Result<Value> {
202 let url = format!(
203 "{}/{}",
204 self.base_url.trim_end_matches('/'),
205 path.trim_start_matches('/')
206 );
207 let headers = self.build_headers()?;
208
209 let http_method = Method::from_bytes(method.as_bytes())
210 .map_err(|_| anyhow!("invalid HTTP method: {method}"))?;
211
212 let query_refs: Vec<(&str, &str)> = query
215 .iter()
216 .map(|(k, v)| (k.as_str(), v.as_str()))
217 .collect();
218
219 let mut req = self
220 .http
221 .request(http_method, &url)
222 .headers(headers)
223 .query(&query_refs);
224
225 if let Some(body) = body {
226 req = req.json(&body);
227 }
228
229 let t0 = Instant::now();
230 let resp = req
231 .send()
232 .await
233 .map_err(|e| anyhow!("request error: {e}"))?;
234
235 let status = resp.status();
236 if self.verbose {
237 let keys: Vec<&str> = query.iter().map(|(k, _)| k.as_str()).collect();
238 tracing::info!(
239 "[romm-cli] {} {} query_keys={:?} -> {} ({}ms)",
240 method,
241 path,
242 keys,
243 status.as_u16(),
244 t0.elapsed().as_millis()
245 );
246 }
247 if !status.is_success() {
248 let body = resp.text().await.unwrap_or_default();
249 return Err(anyhow!(
250 "ROMM API error: {} {} - {}",
251 status.as_u16(),
252 status.canonical_reason().unwrap_or(""),
253 body
254 ));
255 }
256
257 let bytes = resp
258 .bytes()
259 .await
260 .map_err(|e| anyhow!("read response body: {e}"))?;
261
262 Ok(decode_json_response_body(&bytes))
263 }
264
265 pub async fn request_json_unauthenticated(
266 &self,
267 method: &str,
268 path: &str,
269 query: &[(String, String)],
270 body: Option<Value>,
271 ) -> Result<Value> {
272 let url = format!(
273 "{}/{}",
274 self.base_url.trim_end_matches('/'),
275 path.trim_start_matches('/')
276 );
277 let headers = HeaderMap::new();
278
279 let http_method = Method::from_bytes(method.as_bytes())
280 .map_err(|_| anyhow!("invalid HTTP method: {method}"))?;
281
282 let query_refs: Vec<(&str, &str)> = query
285 .iter()
286 .map(|(k, v)| (k.as_str(), v.as_str()))
287 .collect();
288
289 let mut req = self
290 .http
291 .request(http_method, &url)
292 .headers(headers)
293 .query(&query_refs);
294
295 if let Some(body) = body {
296 req = req.json(&body);
297 }
298
299 let t0 = Instant::now();
300 let resp = req
301 .send()
302 .await
303 .map_err(|e| anyhow!("request error: {e}"))?;
304
305 let status = resp.status();
306 if self.verbose {
307 let keys: Vec<&str> = query.iter().map(|(k, _)| k.as_str()).collect();
308 tracing::info!(
309 "[romm-cli] {} {} query_keys={:?} -> {} ({}ms)",
310 method,
311 path,
312 keys,
313 status.as_u16(),
314 t0.elapsed().as_millis()
315 );
316 }
317 if !status.is_success() {
318 let body = resp.text().await.unwrap_or_default();
319 return Err(anyhow!(
320 "ROMM API error: {} {} - {}",
321 status.as_u16(),
322 status.canonical_reason().unwrap_or(""),
323 body
324 ));
325 }
326
327 let bytes = resp
328 .bytes()
329 .await
330 .map_err(|e| anyhow!("read response body: {e}"))?;
331
332 Ok(decode_json_response_body(&bytes))
333 }
334
335 pub async fn fetch_openapi_json(&self) -> Result<String> {
338 let root = resolve_openapi_root(&self.base_url);
339 let urls = openapi_spec_urls(&root);
340 let mut failures = Vec::new();
341 for url in &urls {
342 match self.fetch_openapi_json_once(url).await {
343 Ok(body) => return Ok(body),
344 Err(e) => failures.push(format!("{url}: {e:#}")),
345 }
346 }
347 Err(anyhow!(
348 "could not download OpenAPI ({} attempt(s)): {}",
349 failures.len(),
350 failures.join(" | ")
351 ))
352 }
353
354 async fn fetch_openapi_json_once(&self, url: &str) -> Result<String> {
355 let headers = self.build_headers()?;
356
357 let t0 = Instant::now();
358 let resp = self
359 .http
360 .get(url)
361 .headers(headers)
362 .send()
363 .await
364 .map_err(|e| anyhow!("request failed: {e}"))?;
365
366 let status = resp.status();
367 if self.verbose {
368 tracing::info!(
369 "[romm-cli] GET {} -> {} ({}ms)",
370 url,
371 status.as_u16(),
372 t0.elapsed().as_millis()
373 );
374 }
375 if !status.is_success() {
376 let body = resp.text().await.unwrap_or_default();
377 return Err(anyhow!(
378 "HTTP {} {} - {}",
379 status.as_u16(),
380 status.canonical_reason().unwrap_or(""),
381 body.chars().take(500).collect::<String>()
382 ));
383 }
384
385 resp.text()
386 .await
387 .map_err(|e| anyhow!("read OpenAPI body: {e}"))
388 }
389
390 pub async fn download_rom<F>(
399 &self,
400 rom_id: u64,
401 save_path: &Path,
402 mut on_progress: F,
403 ) -> Result<()>
404 where
405 F: FnMut(u64, u64) + Send,
406 {
407 let path = "/api/roms/download";
408 let url = format!(
409 "{}/{}",
410 self.base_url.trim_end_matches('/'),
411 path.trim_start_matches('/')
412 );
413 let mut headers = self.build_headers()?;
414
415 let filename = save_path
416 .file_name()
417 .and_then(|n| n.to_str())
418 .unwrap_or("download.zip");
419
420 let existing_len = tokio::fs::metadata(save_path)
422 .await
423 .map(|m| m.len())
424 .unwrap_or(0);
425
426 if existing_len > 0 {
427 let range = format!("bytes={existing_len}-");
428 if let Ok(v) = reqwest::header::HeaderValue::from_str(&range) {
429 headers.insert(reqwest::header::RANGE, v);
430 }
431 }
432
433 let t0 = Instant::now();
434 let mut resp = self
435 .http
436 .get(&url)
437 .headers(headers)
438 .query(&[
439 ("rom_ids", rom_id.to_string()),
440 ("filename", filename.to_string()),
441 ])
442 .send()
443 .await
444 .map_err(|e| anyhow!("download request error: {e}"))?;
445
446 let status = resp.status();
447 if self.verbose {
448 tracing::info!(
449 "[romm-cli] GET /api/roms/download rom_id={} filename={:?} -> {} ({}ms)",
450 rom_id,
451 filename,
452 status.as_u16(),
453 t0.elapsed().as_millis()
454 );
455 }
456 if !status.is_success() {
457 let body = resp.text().await.unwrap_or_default();
458 return Err(anyhow!(
459 "ROMM API error: {} {} - {}",
460 status.as_u16(),
461 status.canonical_reason().unwrap_or(""),
462 body
463 ));
464 }
465
466 let (mut received, total, mut file) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
468 let remaining = resp.content_length().unwrap_or(0);
470 let total = existing_len + remaining;
471 let file = tokio::fs::OpenOptions::new()
472 .append(true)
473 .open(save_path)
474 .await
475 .map_err(|e| anyhow!("open file for append {:?}: {e}", save_path))?;
476 (existing_len, total, file)
477 } else {
478 let total = resp.content_length().unwrap_or(0);
480 let file = tokio::fs::File::create(save_path)
481 .await
482 .map_err(|e| anyhow!("create file {:?}: {e}", save_path))?;
483 (0u64, total, file)
484 };
485
486 while let Some(chunk) = resp.chunk().await.map_err(|e| anyhow!("read chunk: {e}"))? {
487 file.write_all(&chunk)
488 .await
489 .map_err(|e| anyhow!("write chunk {:?}: {e}", save_path))?;
490 received += chunk.len() as u64;
491 on_progress(received, total);
492 }
493
494 Ok(())
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Blank bodies (empty or whitespace-only) decode to JSON null.
    #[test]
    fn decode_json_empty_and_whitespace_to_null() {
        for blank in [&b""[..], &b" \n\t "[..]] {
            assert_eq!(decode_json_response_body(blank), Value::Null);
        }
    }

    /// A valid JSON object is returned as parsed.
    #[test]
    fn decode_json_object_roundtrip() {
        let decoded = decode_json_response_body(br#"{"a":1}"#);
        assert_eq!(decoded["a"], 1);
    }

    /// Non-JSON text is wrapped under the `_non_json_body` key.
    #[test]
    fn decode_non_json_wrapped() {
        let decoded = decode_json_response_body(b"plain text");
        assert_eq!(decoded["_non_json_body"], "plain text");
    }

    /// `/api` and `/api/` suffixes are removed; a bare origin is unchanged.
    #[test]
    fn api_root_url_strips_trailing_api() {
        let cases = [
            ("http://localhost:8080/api", "http://localhost:8080"),
            ("http://localhost:8080/api/", "http://localhost:8080"),
            ("http://localhost:8080", "http://localhost:8080"),
        ];
        for (input, expected) in cases {
            assert_eq!(super::api_root_url(input), expected, "input: {input}");
        }
    }

    /// The primary scheme's two candidates come first; the https alternate is present.
    #[test]
    fn openapi_spec_urls_try_primary_scheme_then_alt() {
        let urls = super::openapi_spec_urls("http://example.test");
        let expected_head = [
            "http://example.test/openapi.json",
            "http://example.test/api/openapi.json",
        ];
        for (idx, want) in expected_head.iter().enumerate() {
            assert_eq!(urls[idx], *want);
        }
        let https_spec = String::from("https://example.test/openapi.json");
        assert!(urls.contains(&https_spec), "{urls:?}");
    }
}