Skip to main content

romm_cli/
client.rs

//! HTTP client wrapper around the ROMM API.
//!
//! `RommClient` owns a configured `reqwest::Client` plus base URL and
//! authentication settings. Frontends (CLI, TUI, or a future GUI) depend
//! on this type instead of talking to `reqwest` directly.
6
7use anyhow::{anyhow, Result};
8use base64::{engine::general_purpose, Engine as _};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use reqwest::{Client as HttpClient, Method};
11use serde_json::Value;
12use std::path::Path;
13use std::time::Instant;
14use tokio::io::AsyncWriteExt as _;
15
16use crate::config::{normalize_romm_origin, AuthConfig, Config};
17use crate::endpoints::Endpoint;
18
19/// Default `User-Agent` for every request. The stock `reqwest` UA is sometimes blocked at the HTTP
20/// layer (403, etc.) by reverse proxies; override with env `ROMM_USER_AGENT` if needed.
21fn http_user_agent() -> String {
22    match std::env::var("ROMM_USER_AGENT") {
23        Ok(s) if !s.trim().is_empty() => s,
24        _ => format!(
25            "Mozilla/5.0 (compatible; romm-cli/{}; +https://github.com/patricksmill/romm-cli)",
26            env!("CARGO_PKG_VERSION")
27        ),
28    }
29}
30
31/// Map a successful HTTP response body to JSON [`Value`].
32///
33/// Empty or whitespace-only bodies become [`Value::Null`] (e.g. HTTP 204).
34/// Non-JSON UTF-8 bodies are wrapped as `{"_non_json_body": "..."}`.
35fn decode_json_response_body(bytes: &[u8]) -> Value {
36    if bytes.is_empty() || bytes.iter().all(|b| b.is_ascii_whitespace()) {
37        return Value::Null;
38    }
39    serde_json::from_slice(bytes).unwrap_or_else(|_| {
40        serde_json::json!({
41            "_non_json_body": String::from_utf8_lossy(bytes).to_string()
42        })
43    })
44}
45
46/// High-level HTTP client for the ROMM API.
47///
48/// This type hides the details of `reqwest` and authentication headers
49/// behind a small, easy-to-mock interface that all frontends can share.
50#[derive(Clone)]
51pub struct RommClient {
52    http: HttpClient,
53    base_url: String,
54    auth: Option<AuthConfig>,
55    verbose: bool,
56}
57
58/// Same as [`crate::config::normalize_romm_origin`]: browser-style origin for RomM (no `/api` suffix).
59pub fn api_root_url(base_url: &str) -> String {
60    normalize_romm_origin(base_url)
61}
62
/// Return `root` with its scheme flipped between `http://` and `https://`.
///
/// Returns `None` when `root` starts with neither prefix (the match is case-sensitive).
fn alternate_http_scheme_root(root: &str) -> Option<String> {
    if let Some(host) = root.strip_prefix("http://") {
        Some(format!("https://{host}"))
    } else if let Some(host) = root.strip_prefix("https://") {
        Some(format!("http://{host}"))
    } else {
        None
    }
}
71
72/// Origin used to fetch `/openapi.json` (same as the RomM website). Normally equals
73/// [`normalize_romm_origin`] applied to `API_BASE_URL`.
74///
75/// Set `ROMM_OPENAPI_BASE_URL` only when that origin differs (wrong host in `API_BASE_URL`, split
76/// DNS, etc.).
77pub fn resolve_openapi_root(api_base_url: &str) -> String {
78    if let Ok(s) = std::env::var("ROMM_OPENAPI_BASE_URL") {
79        let t = s.trim();
80        if !t.is_empty() {
81            return normalize_romm_origin(t);
82        }
83    }
84    normalize_romm_origin(api_base_url)
85}
86
87/// URLs to try for the OpenAPI JSON document (scheme fallback and alternate paths).
88///
89/// `api_root` is an origin such as `https://example.com` (see [`resolve_openapi_root`]).
90pub fn openapi_spec_urls(api_root: &str) -> Vec<String> {
91    let root = api_root.trim_end_matches('/').to_string();
92    let mut roots = vec![root.clone()];
93    if let Some(alt) = alternate_http_scheme_root(&root) {
94        if alt != root {
95            roots.push(alt);
96        }
97    }
98
99    let mut urls = Vec::new();
100    for r in roots {
101        let b = r.trim_end_matches('/');
102        urls.push(format!("{b}/openapi.json"));
103        urls.push(format!("{b}/api/openapi.json"));
104    }
105    urls
106}
107
108impl RommClient {
109    /// Construct a new client from the high-level [`Config`].
110    ///
111    /// `verbose` enables stderr request logging (method, path, query key names, status, timing).
112    /// This is typically done once in `main` and the resulting `RommClient` is shared
113    /// (by reference or cloning) with the chosen frontend.
114    pub fn new(config: &Config, verbose: bool) -> Result<Self> {
115        let http = HttpClient::builder()
116            .user_agent(http_user_agent())
117            .build()?;
118        Ok(Self {
119            http,
120            base_url: config.base_url.clone(),
121            auth: config.auth.clone(),
122            verbose,
123        })
124    }
125
126    pub fn verbose(&self) -> bool {
127        self.verbose
128    }
129
130    /// Build the HTTP headers for the current authentication mode.
131    ///
132    /// This helper centralises all auth logic so that the rest of the
133    /// code never needs to worry about `Basic` vs `Bearer` vs API key.
134    fn build_headers(&self) -> Result<HeaderMap> {
135        let mut headers = HeaderMap::new();
136
137        if let Some(auth) = &self.auth {
138            match auth {
139                AuthConfig::Basic { username, password } => {
140                    let creds = format!("{username}:{password}");
141                    let encoded = general_purpose::STANDARD.encode(creds.as_bytes());
142                    let value = format!("Basic {encoded}");
143                    headers.insert(
144                        AUTHORIZATION,
145                        HeaderValue::from_str(&value)
146                            .map_err(|_| anyhow!("invalid basic auth header value"))?,
147                    );
148                }
149                AuthConfig::Bearer { token } => {
150                    let value = format!("Bearer {token}");
151                    headers.insert(
152                        AUTHORIZATION,
153                        HeaderValue::from_str(&value)
154                            .map_err(|_| anyhow!("invalid bearer auth header value"))?,
155                    );
156                }
157                AuthConfig::ApiKey { header, key } => {
158                    let name = reqwest::header::HeaderName::from_bytes(header.as_bytes()).map_err(
159                        |_| anyhow!("invalid API_KEY_HEADER, must be a valid HTTP header name"),
160                    )?;
161                    headers.insert(
162                        name,
163                        HeaderValue::from_str(key)
164                            .map_err(|_| anyhow!("invalid API_KEY header value"))?,
165                    );
166                }
167            }
168        }
169
170        Ok(headers)
171    }
172
173    /// Call a typed endpoint using the low-level `request_json` primitive.
174    pub async fn call<E>(&self, ep: &E) -> anyhow::Result<E::Output>
175    where
176        E: Endpoint,
177        E::Output: serde::de::DeserializeOwned,
178    {
179        let method = ep.method();
180        let path = ep.path();
181        let query = ep.query();
182        let body = ep.body();
183
184        let value = self.request_json(method, &path, &query, body).await?;
185        let output = serde_json::from_value(value)
186            .map_err(|e| anyhow!("failed to decode response for {} {}: {}", method, path, e))?;
187
188        Ok(output)
189    }
190
191    /// Low-level helper that issues an HTTP request and returns raw JSON.
192    ///
193    /// Higher-level helpers (such as typed `Endpoint` implementations)
194    /// should prefer [`RommClient::call`] instead of using this directly.
195    pub async fn request_json(
196        &self,
197        method: &str,
198        path: &str,
199        query: &[(String, String)],
200        body: Option<Value>,
201    ) -> Result<Value> {
202        let url = format!(
203            "{}/{}",
204            self.base_url.trim_end_matches('/'),
205            path.trim_start_matches('/')
206        );
207        let headers = self.build_headers()?;
208
209        let http_method = Method::from_bytes(method.as_bytes())
210            .map_err(|_| anyhow!("invalid HTTP method: {method}"))?;
211
212        // Ensure query params serialize as key=value pairs (reqwest/serde_urlencoded
213        // expect sequences of (key, value); using &[(&str, &str)] guarantees correct encoding).
214        let query_refs: Vec<(&str, &str)> = query
215            .iter()
216            .map(|(k, v)| (k.as_str(), v.as_str()))
217            .collect();
218
219        let mut req = self
220            .http
221            .request(http_method, &url)
222            .headers(headers)
223            .query(&query_refs);
224
225        if let Some(body) = body {
226            req = req.json(&body);
227        }
228
229        let t0 = Instant::now();
230        let resp = req
231            .send()
232            .await
233            .map_err(|e| anyhow!("request error: {e}"))?;
234
235        let status = resp.status();
236        if self.verbose {
237            let keys: Vec<&str> = query.iter().map(|(k, _)| k.as_str()).collect();
238            tracing::info!(
239                "[romm-cli] {} {} query_keys={:?} -> {} ({}ms)",
240                method,
241                path,
242                keys,
243                status.as_u16(),
244                t0.elapsed().as_millis()
245            );
246        }
247        if !status.is_success() {
248            let body = resp.text().await.unwrap_or_default();
249            return Err(anyhow!(
250                "ROMM API error: {} {} - {}",
251                status.as_u16(),
252                status.canonical_reason().unwrap_or(""),
253                body
254            ));
255        }
256
257        let bytes = resp
258            .bytes()
259            .await
260            .map_err(|e| anyhow!("read response body: {e}"))?;
261
262        Ok(decode_json_response_body(&bytes))
263    }
264
265    pub async fn request_json_unauthenticated(
266        &self,
267        method: &str,
268        path: &str,
269        query: &[(String, String)],
270        body: Option<Value>,
271    ) -> Result<Value> {
272        let url = format!(
273            "{}/{}",
274            self.base_url.trim_end_matches('/'),
275            path.trim_start_matches('/')
276        );
277        let headers = HeaderMap::new();
278
279        let http_method = Method::from_bytes(method.as_bytes())
280            .map_err(|_| anyhow!("invalid HTTP method: {method}"))?;
281
282        // Ensure query params serialize as key=value pairs (reqwest/serde_urlencoded
283        // expect sequences of (key, value); using &[(&str, &str)] guarantees correct encoding).
284        let query_refs: Vec<(&str, &str)> = query
285            .iter()
286            .map(|(k, v)| (k.as_str(), v.as_str()))
287            .collect();
288
289        let mut req = self
290            .http
291            .request(http_method, &url)
292            .headers(headers)
293            .query(&query_refs);
294
295        if let Some(body) = body {
296            req = req.json(&body);
297        }
298
299        let t0 = Instant::now();
300        let resp = req
301            .send()
302            .await
303            .map_err(|e| anyhow!("request error: {e}"))?;
304
305        let status = resp.status();
306        if self.verbose {
307            let keys: Vec<&str> = query.iter().map(|(k, _)| k.as_str()).collect();
308            tracing::info!(
309                "[romm-cli] {} {} query_keys={:?} -> {} ({}ms)",
310                method,
311                path,
312                keys,
313                status.as_u16(),
314                t0.elapsed().as_millis()
315            );
316        }
317        if !status.is_success() {
318            let body = resp.text().await.unwrap_or_default();
319            return Err(anyhow!(
320                "ROMM API error: {} {} - {}",
321                status.as_u16(),
322                status.canonical_reason().unwrap_or(""),
323                body
324            ));
325        }
326
327        let bytes = resp
328            .bytes()
329            .await
330            .map_err(|e| anyhow!("read response body: {e}"))?;
331
332        Ok(decode_json_response_body(&bytes))
333    }
334
335    /// GET the OpenAPI spec from the server. Tries [`openapi_spec_urls`] in order (HTTP/HTTPS and
336    /// `/openapi.json` vs `/api/openapi.json`). Uses [`resolve_openapi_root`] for the origin.
337    pub async fn fetch_openapi_json(&self) -> Result<String> {
338        let root = resolve_openapi_root(&self.base_url);
339        let urls = openapi_spec_urls(&root);
340        let mut failures = Vec::new();
341        for url in &urls {
342            match self.fetch_openapi_json_once(url).await {
343                Ok(body) => return Ok(body),
344                Err(e) => failures.push(format!("{url}: {e:#}")),
345            }
346        }
347        Err(anyhow!(
348            "could not download OpenAPI ({} attempt(s)): {}",
349            failures.len(),
350            failures.join(" | ")
351        ))
352    }
353
354    async fn fetch_openapi_json_once(&self, url: &str) -> Result<String> {
355        let headers = self.build_headers()?;
356
357        let t0 = Instant::now();
358        let resp = self
359            .http
360            .get(url)
361            .headers(headers)
362            .send()
363            .await
364            .map_err(|e| anyhow!("request failed: {e}"))?;
365
366        let status = resp.status();
367        if self.verbose {
368            tracing::info!(
369                "[romm-cli] GET {} -> {} ({}ms)",
370                url,
371                status.as_u16(),
372                t0.elapsed().as_millis()
373            );
374        }
375        if !status.is_success() {
376            let body = resp.text().await.unwrap_or_default();
377            return Err(anyhow!(
378                "HTTP {} {} - {}",
379                status.as_u16(),
380                status.canonical_reason().unwrap_or(""),
381                body.chars().take(500).collect::<String>()
382            ));
383        }
384
385        resp.text()
386            .await
387            .map_err(|e| anyhow!("read OpenAPI body: {e}"))
388    }
389
390    /// Download ROM(s) as a zip file to `save_path`, calling `on_progress(received, total)`.
391    /// Uses GET /api/roms/download?rom_ids={id}&filename=... per RomM OpenAPI.
392    ///
393    /// If `save_path` already exists on disk (e.g. from a previous interrupted
394    /// download), the client sends an HTTP `Range` header to resume from the
395    /// existing byte offset. The server may reply with `206 Partial Content`
396    /// (resume works) or `200 OK` (server doesn't support ranges — restart
397    /// from scratch).
398    pub async fn download_rom<F>(
399        &self,
400        rom_id: u64,
401        save_path: &Path,
402        mut on_progress: F,
403    ) -> Result<()>
404    where
405        F: FnMut(u64, u64) + Send,
406    {
407        let path = "/api/roms/download";
408        let url = format!(
409            "{}/{}",
410            self.base_url.trim_end_matches('/'),
411            path.trim_start_matches('/')
412        );
413        let mut headers = self.build_headers()?;
414
415        let filename = save_path
416            .file_name()
417            .and_then(|n| n.to_str())
418            .unwrap_or("download.zip");
419
420        // Check for an existing partial file to resume from.
421        let existing_len = tokio::fs::metadata(save_path)
422            .await
423            .map(|m| m.len())
424            .unwrap_or(0);
425
426        if existing_len > 0 {
427            let range = format!("bytes={existing_len}-");
428            if let Ok(v) = reqwest::header::HeaderValue::from_str(&range) {
429                headers.insert(reqwest::header::RANGE, v);
430            }
431        }
432
433        let t0 = Instant::now();
434        let mut resp = self
435            .http
436            .get(&url)
437            .headers(headers)
438            .query(&[
439                ("rom_ids", rom_id.to_string()),
440                ("filename", filename.to_string()),
441            ])
442            .send()
443            .await
444            .map_err(|e| anyhow!("download request error: {e}"))?;
445
446        let status = resp.status();
447        if self.verbose {
448            tracing::info!(
449                "[romm-cli] GET /api/roms/download rom_id={} filename={:?} -> {} ({}ms)",
450                rom_id,
451                filename,
452                status.as_u16(),
453                t0.elapsed().as_millis()
454            );
455        }
456        if !status.is_success() {
457            let body = resp.text().await.unwrap_or_default();
458            return Err(anyhow!(
459                "ROMM API error: {} {} - {}",
460                status.as_u16(),
461                status.canonical_reason().unwrap_or(""),
462                body
463            ));
464        }
465
466        // Determine whether the server honoured our Range header.
467        let (mut received, total, mut file) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
468            // 206 — resume: content_length is the *remaining* bytes.
469            let remaining = resp.content_length().unwrap_or(0);
470            let total = existing_len + remaining;
471            let file = tokio::fs::OpenOptions::new()
472                .append(true)
473                .open(save_path)
474                .await
475                .map_err(|e| anyhow!("open file for append {:?}: {e}", save_path))?;
476            (existing_len, total, file)
477        } else {
478            // 200 — server doesn't support ranges; start from scratch.
479            let total = resp.content_length().unwrap_or(0);
480            let file = tokio::fs::File::create(save_path)
481                .await
482                .map_err(|e| anyhow!("create file {:?}: {e}", save_path))?;
483            (0u64, total, file)
484        };
485
486        while let Some(chunk) = resp.chunk().await.map_err(|e| anyhow!("read chunk: {e}"))? {
487            file.write_all(&chunk)
488                .await
489                .map_err(|e| anyhow!("write chunk {:?}: {e}", save_path))?;
490            received += chunk.len() as u64;
491            on_progress(received, total);
492        }
493
494        Ok(())
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode_json_empty_and_whitespace_to_null() {
        assert_eq!(decode_json_response_body(b""), Value::Null);
        assert_eq!(decode_json_response_body(b"  \n\t "), Value::Null);
    }

    #[test]
    fn decode_json_object_roundtrip() {
        let v = decode_json_response_body(br#"{"a":1}"#);
        assert_eq!(v["a"], 1);
    }

    #[test]
    fn decode_non_json_wrapped() {
        let v = decode_json_response_body(b"plain text");
        assert_eq!(v["_non_json_body"], "plain text");
    }

    #[test]
    fn api_root_url_strips_trailing_api() {
        assert_eq!(
            super::api_root_url("http://localhost:8080/api"),
            "http://localhost:8080"
        );
        assert_eq!(
            super::api_root_url("http://localhost:8080/api/"),
            "http://localhost:8080"
        );
        assert_eq!(
            super::api_root_url("http://localhost:8080"),
            "http://localhost:8080"
        );
    }

    #[test]
    fn alternate_http_scheme_root_swaps_http_and_https() {
        assert_eq!(
            super::alternate_http_scheme_root("http://example.test"),
            Some("https://example.test".to_string())
        );
        assert_eq!(
            super::alternate_http_scheme_root("https://example.test:8443"),
            Some("http://example.test:8443".to_string())
        );
        assert_eq!(super::alternate_http_scheme_root("ftp://example.test"), None);
    }

    #[test]
    fn openapi_spec_urls_try_primary_scheme_then_alt() {
        let urls = super::openapi_spec_urls("http://example.test");
        assert_eq!(urls[0], "http://example.test/openapi.json");
        assert_eq!(urls[1], "http://example.test/api/openapi.json");
        assert!(
            urls.iter()
                .any(|u| u == "https://example.test/openapi.json"),
            "{urls:?}"
        );
    }

    #[test]
    fn openapi_spec_urls_trims_trailing_slash_and_orders_attempts() {
        assert_eq!(
            super::openapi_spec_urls("https://example.test/"),
            vec![
                "https://example.test/openapi.json",
                "https://example.test/api/openapi.json",
                "http://example.test/openapi.json",
                "http://example.test/api/openapi.json",
            ]
        );
    }

    #[test]
    fn openapi_spec_urls_without_http_scheme_has_no_fallback() {
        assert_eq!(
            super::openapi_spec_urls("example.test"),
            vec!["example.test/openapi.json", "example.test/api/openapi.json"]
        );
    }
}