Skip to main content

romm_cli/
client.rs

1//! HTTP client wrapper around the ROMM API.
2//!
3//! `RommClient` owns a configured `reqwest::Client` plus base URL and
4//! authentication settings. Frontends (CLI, TUI, or a future GUI) depend
5//! on this type instead of talking to `reqwest` directly.
6
7use anyhow::{anyhow, Result};
8use base64::{engine::general_purpose, Engine as _};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use reqwest::{Client as HttpClient, Method};
11use serde_json::Value;
12use std::path::Path;
13use std::time::Instant;
14use tokio::io::AsyncWriteExt as _;
15
16use crate::config::{normalize_romm_origin, AuthConfig, Config};
17use crate::endpoints::Endpoint;
18
19/// Default `User-Agent` for every request. The stock `reqwest` UA is sometimes blocked at the HTTP
20/// layer (403, etc.) by reverse proxies; override with env `ROMM_USER_AGENT` if needed.
21fn http_user_agent() -> String {
22    match std::env::var("ROMM_USER_AGENT") {
23        Ok(s) if !s.trim().is_empty() => s,
24        _ => format!(
25            "Mozilla/5.0 (compatible; romm-cli/{}; +https://github.com/patricksmill/romm-cli)",
26            env!("CARGO_PKG_VERSION")
27        ),
28    }
29}
30
31/// Map a successful HTTP response body to JSON [`Value`].
32///
33/// Empty or whitespace-only bodies become [`Value::Null`] (e.g. HTTP 204).
34/// Non-JSON UTF-8 bodies are wrapped as `{"_non_json_body": "..."}`.
35fn decode_json_response_body(bytes: &[u8]) -> Value {
36    if bytes.is_empty() || bytes.iter().all(|b| b.is_ascii_whitespace()) {
37        return Value::Null;
38    }
39    serde_json::from_slice(bytes).unwrap_or_else(|_| {
40        serde_json::json!({
41            "_non_json_body": String::from_utf8_lossy(bytes).to_string()
42        })
43    })
44}
45
46/// High-level HTTP client for the ROMM API.
47///
48/// This type hides the details of `reqwest` and authentication headers
49/// behind a small, easy-to-mock interface that all frontends can share.
50#[derive(Clone)]
51pub struct RommClient {
52    http: HttpClient,
53    base_url: String,
54    auth: Option<AuthConfig>,
55    verbose: bool,
56}
57
58/// Same as [`crate::config::normalize_romm_origin`]: browser-style origin for RomM (no `/api` suffix).
59pub fn api_root_url(base_url: &str) -> String {
60    normalize_romm_origin(base_url)
61}
62
/// Swaps the scheme of `root` between `http://` and `https://`, keeping the rest unchanged.
///
/// Returns `None` when `root` starts with neither prefix. Matching is case-sensitive.
fn alternate_http_scheme_root(root: &str) -> Option<String> {
    if let Some(rest) = root.strip_prefix("http://") {
        return Some(format!("https://{}", rest));
    }
    let rest = root.strip_prefix("https://")?;
    Some(format!("http://{}", rest))
}
71
72/// Origin used to fetch `/openapi.json` (same as the RomM website). Normally equals
73/// [`normalize_romm_origin`] applied to `API_BASE_URL`.
74///
75/// Set `ROMM_OPENAPI_BASE_URL` only when that origin differs (wrong host in `API_BASE_URL`, split
76/// DNS, etc.).
77pub fn resolve_openapi_root(api_base_url: &str) -> String {
78    if let Ok(s) = std::env::var("ROMM_OPENAPI_BASE_URL") {
79        let t = s.trim();
80        if !t.is_empty() {
81            return normalize_romm_origin(t);
82        }
83    }
84    normalize_romm_origin(api_base_url)
85}
86
87/// URLs to try for the OpenAPI JSON document (scheme fallback and alternate paths).
88///
89/// `api_root` is an origin such as `https://example.com` (see [`resolve_openapi_root`]).
90pub fn openapi_spec_urls(api_root: &str) -> Vec<String> {
91    let root = api_root.trim_end_matches('/').to_string();
92    let mut roots = vec![root.clone()];
93    if let Some(alt) = alternate_http_scheme_root(&root) {
94        if alt != root {
95            roots.push(alt);
96        }
97    }
98
99    let mut urls = Vec::new();
100    for r in roots {
101        let b = r.trim_end_matches('/');
102        urls.push(format!("{b}/openapi.json"));
103        urls.push(format!("{b}/api/openapi.json"));
104    }
105    urls
106}
107
108impl RommClient {
109    /// Construct a new client from the high-level [`Config`].
110    ///
111    /// `verbose` enables stderr request logging (method, path, query key names, status, timing).
112    /// This is typically done once in `main` and the resulting `RommClient` is shared
113    /// (by reference or cloning) with the chosen frontend.
114    pub fn new(config: &Config, verbose: bool) -> Result<Self> {
115        let http = HttpClient::builder()
116            .user_agent(http_user_agent())
117            .build()?;
118        Ok(Self {
119            http,
120            base_url: config.base_url.clone(),
121            auth: config.auth.clone(),
122            verbose,
123        })
124    }
125
126    /// Build the HTTP headers for the current authentication mode.
127    ///
128    /// This helper centralises all auth logic so that the rest of the
129    /// code never needs to worry about `Basic` vs `Bearer` vs API key.
130    fn build_headers(&self) -> Result<HeaderMap> {
131        let mut headers = HeaderMap::new();
132
133        if let Some(auth) = &self.auth {
134            match auth {
135                AuthConfig::Basic { username, password } => {
136                    let creds = format!("{username}:{password}");
137                    let encoded = general_purpose::STANDARD.encode(creds.as_bytes());
138                    let value = format!("Basic {encoded}");
139                    headers.insert(
140                        AUTHORIZATION,
141                        HeaderValue::from_str(&value)
142                            .map_err(|_| anyhow!("invalid basic auth header value"))?,
143                    );
144                }
145                AuthConfig::Bearer { token } => {
146                    let value = format!("Bearer {token}");
147                    headers.insert(
148                        AUTHORIZATION,
149                        HeaderValue::from_str(&value)
150                            .map_err(|_| anyhow!("invalid bearer auth header value"))?,
151                    );
152                }
153                AuthConfig::ApiKey { header, key } => {
154                    let name = reqwest::header::HeaderName::from_bytes(header.as_bytes()).map_err(
155                        |_| anyhow!("invalid API_KEY_HEADER, must be a valid HTTP header name"),
156                    )?;
157                    headers.insert(
158                        name,
159                        HeaderValue::from_str(key)
160                            .map_err(|_| anyhow!("invalid API_KEY header value"))?,
161                    );
162                }
163            }
164        }
165
166        Ok(headers)
167    }
168
169    /// Call a typed endpoint using the low-level `request_json` primitive.
170    pub async fn call<E>(&self, ep: &E) -> anyhow::Result<E::Output>
171    where
172        E: Endpoint,
173        E::Output: serde::de::DeserializeOwned,
174    {
175        let method = ep.method();
176        let path = ep.path();
177        let query = ep.query();
178        let body = ep.body();
179
180        let value = self.request_json(method, &path, &query, body).await?;
181        let output = serde_json::from_value(value)
182            .map_err(|e| anyhow!("failed to decode response for {} {}: {}", method, path, e))?;
183
184        Ok(output)
185    }
186
187    /// Low-level helper that issues an HTTP request and returns raw JSON.
188    ///
189    /// Higher-level helpers (such as typed `Endpoint` implementations)
190    /// should prefer [`RommClient::call`] instead of using this directly.
191    pub async fn request_json(
192        &self,
193        method: &str,
194        path: &str,
195        query: &[(String, String)],
196        body: Option<Value>,
197    ) -> Result<Value> {
198        let url = format!(
199            "{}/{}",
200            self.base_url.trim_end_matches('/'),
201            path.trim_start_matches('/')
202        );
203        let headers = self.build_headers()?;
204
205        let http_method = Method::from_bytes(method.as_bytes())
206            .map_err(|_| anyhow!("invalid HTTP method: {method}"))?;
207
208        // Ensure query params serialize as key=value pairs (reqwest/serde_urlencoded
209        // expect sequences of (key, value); using &[(&str, &str)] guarantees correct encoding).
210        let query_refs: Vec<(&str, &str)> = query
211            .iter()
212            .map(|(k, v)| (k.as_str(), v.as_str()))
213            .collect();
214
215        let mut req = self
216            .http
217            .request(http_method, &url)
218            .headers(headers)
219            .query(&query_refs);
220
221        if let Some(body) = body {
222            req = req.json(&body);
223        }
224
225        let t0 = Instant::now();
226        let resp = req
227            .send()
228            .await
229            .map_err(|e| anyhow!("request error: {e}"))?;
230
231        let status = resp.status();
232        if self.verbose {
233            let keys: Vec<&str> = query.iter().map(|(k, _)| k.as_str()).collect();
234            tracing::info!(
235                "[romm-cli] {} {} query_keys={:?} -> {} ({}ms)",
236                method,
237                path,
238                keys,
239                status.as_u16(),
240                t0.elapsed().as_millis()
241            );
242        }
243        if !status.is_success() {
244            let body = resp.text().await.unwrap_or_default();
245            return Err(anyhow!(
246                "ROMM API error: {} {} - {}",
247                status.as_u16(),
248                status.canonical_reason().unwrap_or(""),
249                body
250            ));
251        }
252
253        let bytes = resp
254            .bytes()
255            .await
256            .map_err(|e| anyhow!("read response body: {e}"))?;
257
258        Ok(decode_json_response_body(&bytes))
259    }
260
261    /// GET the OpenAPI spec from the server. Tries [`openapi_spec_urls`] in order (HTTP/HTTPS and
262    /// `/openapi.json` vs `/api/openapi.json`). Uses [`resolve_openapi_root`] for the origin.
263    pub async fn fetch_openapi_json(&self) -> Result<String> {
264        let root = resolve_openapi_root(&self.base_url);
265        let urls = openapi_spec_urls(&root);
266        let mut failures = Vec::new();
267        for url in &urls {
268            match self.fetch_openapi_json_once(url).await {
269                Ok(body) => return Ok(body),
270                Err(e) => failures.push(format!("{url}: {e:#}")),
271            }
272        }
273        Err(anyhow!(
274            "could not download OpenAPI ({} attempt(s)): {}",
275            failures.len(),
276            failures.join(" | ")
277        ))
278    }
279
280    async fn fetch_openapi_json_once(&self, url: &str) -> Result<String> {
281        let headers = self.build_headers()?;
282
283        let t0 = Instant::now();
284        let resp = self
285            .http
286            .get(url)
287            .headers(headers)
288            .send()
289            .await
290            .map_err(|e| anyhow!("request failed: {e}"))?;
291
292        let status = resp.status();
293        if self.verbose {
294            tracing::info!(
295                "[romm-cli] GET {} -> {} ({}ms)",
296                url,
297                status.as_u16(),
298                t0.elapsed().as_millis()
299            );
300        }
301        if !status.is_success() {
302            let body = resp.text().await.unwrap_or_default();
303            return Err(anyhow!(
304                "HTTP {} {} - {}",
305                status.as_u16(),
306                status.canonical_reason().unwrap_or(""),
307                body.chars().take(500).collect::<String>()
308            ));
309        }
310
311        resp.text()
312            .await
313            .map_err(|e| anyhow!("read OpenAPI body: {e}"))
314    }
315
316    /// Download ROM(s) as a zip file to `save_path`, calling `on_progress(received, total)`.
317    /// Uses GET /api/roms/download?rom_ids={id}&filename=... per RomM OpenAPI.
318    ///
319    /// If `save_path` already exists on disk (e.g. from a previous interrupted
320    /// download), the client sends an HTTP `Range` header to resume from the
321    /// existing byte offset. The server may reply with `206 Partial Content`
322    /// (resume works) or `200 OK` (server doesn't support ranges — restart
323    /// from scratch).
324    pub async fn download_rom<F>(
325        &self,
326        rom_id: u64,
327        save_path: &Path,
328        mut on_progress: F,
329    ) -> Result<()>
330    where
331        F: FnMut(u64, u64) + Send,
332    {
333        let path = "/api/roms/download";
334        let url = format!(
335            "{}/{}",
336            self.base_url.trim_end_matches('/'),
337            path.trim_start_matches('/')
338        );
339        let mut headers = self.build_headers()?;
340
341        let filename = save_path
342            .file_name()
343            .and_then(|n| n.to_str())
344            .unwrap_or("download.zip");
345
346        // Check for an existing partial file to resume from.
347        let existing_len = tokio::fs::metadata(save_path)
348            .await
349            .map(|m| m.len())
350            .unwrap_or(0);
351
352        if existing_len > 0 {
353            let range = format!("bytes={existing_len}-");
354            if let Ok(v) = reqwest::header::HeaderValue::from_str(&range) {
355                headers.insert(reqwest::header::RANGE, v);
356            }
357        }
358
359        let t0 = Instant::now();
360        let mut resp = self
361            .http
362            .get(&url)
363            .headers(headers)
364            .query(&[
365                ("rom_ids", rom_id.to_string()),
366                ("filename", filename.to_string()),
367            ])
368            .send()
369            .await
370            .map_err(|e| anyhow!("download request error: {e}"))?;
371
372        let status = resp.status();
373        if self.verbose {
374            tracing::info!(
375                "[romm-cli] GET /api/roms/download rom_id={} filename={:?} -> {} ({}ms)",
376                rom_id,
377                filename,
378                status.as_u16(),
379                t0.elapsed().as_millis()
380            );
381        }
382        if !status.is_success() {
383            let body = resp.text().await.unwrap_or_default();
384            return Err(anyhow!(
385                "ROMM API error: {} {} - {}",
386                status.as_u16(),
387                status.canonical_reason().unwrap_or(""),
388                body
389            ));
390        }
391
392        // Determine whether the server honoured our Range header.
393        let (mut received, total, mut file) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
394            // 206 — resume: content_length is the *remaining* bytes.
395            let remaining = resp.content_length().unwrap_or(0);
396            let total = existing_len + remaining;
397            let file = tokio::fs::OpenOptions::new()
398                .append(true)
399                .open(save_path)
400                .await
401                .map_err(|e| anyhow!("open file for append {:?}: {e}", save_path))?;
402            (existing_len, total, file)
403        } else {
404            // 200 — server doesn't support ranges; start from scratch.
405            let total = resp.content_length().unwrap_or(0);
406            let file = tokio::fs::File::create(save_path)
407                .await
408                .map_err(|e| anyhow!("create file {:?}: {e}", save_path))?;
409            (0u64, total, file)
410        };
411
412        while let Some(chunk) = resp.chunk().await.map_err(|e| anyhow!("read chunk: {e}"))? {
413            file.write_all(&chunk)
414                .await
415                .map_err(|e| anyhow!("write chunk {:?}: {e}", save_path))?;
416            received += chunk.len() as u64;
417            on_progress(received, total);
418        }
419
420        Ok(())
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty and whitespace-only bodies decode to JSON `null`.
    #[test]
    fn decode_json_empty_and_whitespace_to_null() {
        let blank_bodies: [&[u8]; 2] = [b"", b"  \n\t "];
        for body in blank_bodies {
            assert_eq!(decode_json_response_body(body), Value::Null);
        }
    }

    /// A valid JSON object is returned as parsed.
    #[test]
    fn decode_json_object_roundtrip() {
        let decoded = decode_json_response_body(br#"{"a":1}"#);
        assert_eq!(decoded["a"], 1);
    }

    /// A non-JSON body is wrapped under the `_non_json_body` key.
    #[test]
    fn decode_non_json_wrapped() {
        let wrapped = decode_json_response_body(b"plain text");
        assert_eq!(wrapped["_non_json_body"], "plain text");
    }

    /// A trailing `/api` (with or without a final slash) is removed; a bare origin is kept.
    #[test]
    fn api_root_url_strips_trailing_api() {
        let inputs = [
            "http://localhost:8080/api",
            "http://localhost:8080/api/",
            "http://localhost:8080",
        ];
        for input in inputs {
            assert_eq!(api_root_url(input), "http://localhost:8080");
        }
    }

    /// The configured scheme is tried first; the swapped scheme appears later in the list.
    #[test]
    fn openapi_spec_urls_try_primary_scheme_then_alt() {
        let urls = openapi_spec_urls("http://example.test");
        assert_eq!(urls[0], "http://example.test/openapi.json");
        assert_eq!(urls[1], "http://example.test/api/openapi.json");

        let has_https_variant = urls
            .iter()
            .any(|candidate| candidate == "https://example.test/openapi.json");
        assert!(has_https_variant, "{urls:?}");
    }
}