Skip to main content

romm_cli/
client.rs

1//! HTTP client wrapper around the ROMM API.
2//!
3//! `RommClient` owns a configured `reqwest::Client` plus base URL and
4//! authentication settings. Frontends (CLI, TUI, or a future GUI) depend
5//! on this type instead of talking to `reqwest` directly.
6
7use anyhow::{anyhow, Result};
8use base64::{engine::general_purpose, Engine as _};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use reqwest::{Client as HttpClient, Method};
11use serde_json::Value;
12use std::path::Path;
13use std::time::Instant;
14use tokio::io::AsyncWriteExt as _;
15
16use crate::config::{normalize_romm_origin, AuthConfig, Config};
17use crate::endpoints::Endpoint;
18
19/// Default `User-Agent` for every request. The stock `reqwest` UA is sometimes blocked at the HTTP
20/// layer (403, etc.) by reverse proxies; override with env `ROMM_USER_AGENT` if needed.
21fn http_user_agent() -> String {
22    match std::env::var("ROMM_USER_AGENT") {
23        Ok(s) if !s.trim().is_empty() => s,
24        _ => format!(
25            "Mozilla/5.0 (compatible; romm-cli/{}; +https://github.com/patricksmill/romm-cli)",
26            env!("CARGO_PKG_VERSION")
27        ),
28    }
29}
30
31/// Map a successful HTTP response body to JSON [`Value`].
32///
33/// Empty or whitespace-only bodies become [`Value::Null`] (e.g. HTTP 204).
34/// Non-JSON UTF-8 bodies are wrapped as `{"_non_json_body": "..."}`.
35fn decode_json_response_body(bytes: &[u8]) -> Value {
36    if bytes.is_empty() || bytes.iter().all(|b| b.is_ascii_whitespace()) {
37        return Value::Null;
38    }
39    serde_json::from_slice(bytes).unwrap_or_else(|_| {
40        serde_json::json!({
41            "_non_json_body": String::from_utf8_lossy(bytes).to_string()
42        })
43    })
44}
45
46fn version_from_heartbeat_json(v: &Value) -> Option<String> {
47    v.get("SYSTEM")?.get("VERSION")?.as_str().map(String::from)
48}
49
50/// High-level HTTP client for the ROMM API.
51///
52/// This type hides the details of `reqwest` and authentication headers
53/// behind a small, easy-to-mock interface that all frontends can share.
54#[derive(Clone)]
55pub struct RommClient {
56    http: HttpClient,
57    base_url: String,
58    auth: Option<AuthConfig>,
59    verbose: bool,
60}
61
62/// Same as [`crate::config::normalize_romm_origin`]: browser-style origin for RomM (no `/api` suffix).
63pub fn api_root_url(base_url: &str) -> String {
64    normalize_romm_origin(base_url)
65}
66
/// Swaps the scheme of `root` between `http://` and `https://`.
///
/// Returns `None` when `root` starts with neither (lowercase) prefix.
fn alternate_http_scheme_root(root: &str) -> Option<String> {
    if let Some(rest) = root.strip_prefix("http://") {
        Some(format!("https://{rest}"))
    } else if let Some(rest) = root.strip_prefix("https://") {
        Some(format!("http://{rest}"))
    } else {
        None
    }
}
75
76/// Origin used to fetch `/openapi.json` (same as the RomM website). Normally equals
77/// [`normalize_romm_origin`] applied to `API_BASE_URL`.
78///
79/// Set `ROMM_OPENAPI_BASE_URL` only when that origin differs (wrong host in `API_BASE_URL`, split
80/// DNS, etc.).
81pub fn resolve_openapi_root(api_base_url: &str) -> String {
82    if let Ok(s) = std::env::var("ROMM_OPENAPI_BASE_URL") {
83        let t = s.trim();
84        if !t.is_empty() {
85            return normalize_romm_origin(t);
86        }
87    }
88    normalize_romm_origin(api_base_url)
89}
90
91/// URLs to try for the OpenAPI JSON document (scheme fallback and alternate paths).
92///
93/// `api_root` is an origin such as `https://example.com` (see [`resolve_openapi_root`]).
94pub fn openapi_spec_urls(api_root: &str) -> Vec<String> {
95    let root = api_root.trim_end_matches('/').to_string();
96    let mut roots = vec![root.clone()];
97    if let Some(alt) = alternate_http_scheme_root(&root) {
98        if alt != root {
99            roots.push(alt);
100        }
101    }
102
103    let mut urls = Vec::new();
104    for r in roots {
105        let b = r.trim_end_matches('/');
106        urls.push(format!("{b}/openapi.json"));
107        urls.push(format!("{b}/api/openapi.json"));
108    }
109    urls
110}
111
112impl RommClient {
113    /// Construct a new client from the high-level [`Config`].
114    ///
115    /// `verbose` enables stderr request logging (method, path, query key names, status, timing).
116    /// This is typically done once in `main` and the resulting `RommClient` is shared
117    /// (by reference or cloning) with the chosen frontend.
118    pub fn new(config: &Config, verbose: bool) -> Result<Self> {
119        let http = HttpClient::builder()
120            .user_agent(http_user_agent())
121            .build()?;
122        Ok(Self {
123            http,
124            base_url: config.base_url.clone(),
125            auth: config.auth.clone(),
126            verbose,
127        })
128    }
129
130    pub fn verbose(&self) -> bool {
131        self.verbose
132    }
133
134    /// Build the HTTP headers for the current authentication mode.
135    ///
136    /// This helper centralises all auth logic so that the rest of the
137    /// code never needs to worry about `Basic` vs `Bearer` vs API key.
138    fn build_headers(&self) -> Result<HeaderMap> {
139        let mut headers = HeaderMap::new();
140
141        if let Some(auth) = &self.auth {
142            match auth {
143                AuthConfig::Basic { username, password } => {
144                    let creds = format!("{username}:{password}");
145                    let encoded = general_purpose::STANDARD.encode(creds.as_bytes());
146                    let value = format!("Basic {encoded}");
147                    headers.insert(
148                        AUTHORIZATION,
149                        HeaderValue::from_str(&value)
150                            .map_err(|_| anyhow!("invalid basic auth header value"))?,
151                    );
152                }
153                AuthConfig::Bearer { token } => {
154                    let value = format!("Bearer {token}");
155                    headers.insert(
156                        AUTHORIZATION,
157                        HeaderValue::from_str(&value)
158                            .map_err(|_| anyhow!("invalid bearer auth header value"))?,
159                    );
160                }
161                AuthConfig::ApiKey { header, key } => {
162                    let name = reqwest::header::HeaderName::from_bytes(header.as_bytes()).map_err(
163                        |_| anyhow!("invalid API_KEY_HEADER, must be a valid HTTP header name"),
164                    )?;
165                    headers.insert(
166                        name,
167                        HeaderValue::from_str(key)
168                            .map_err(|_| anyhow!("invalid API_KEY header value"))?,
169                    );
170                }
171            }
172        }
173
174        Ok(headers)
175    }
176
177    /// Call a typed endpoint using the low-level `request_json` primitive.
178    pub async fn call<E>(&self, ep: &E) -> anyhow::Result<E::Output>
179    where
180        E: Endpoint,
181        E::Output: serde::de::DeserializeOwned,
182    {
183        let method = ep.method();
184        let path = ep.path();
185        let query = ep.query();
186        let body = ep.body();
187
188        let value = self.request_json(method, &path, &query, body).await?;
189        let output = serde_json::from_value(value)
190            .map_err(|e| anyhow!("failed to decode response for {} {}: {}", method, path, e))?;
191
192        Ok(output)
193    }
194
195    /// Low-level helper that issues an HTTP request and returns raw JSON.
196    ///
197    /// Higher-level helpers (such as typed `Endpoint` implementations)
198    /// should prefer [`RommClient::call`] instead of using this directly.
199    pub async fn request_json(
200        &self,
201        method: &str,
202        path: &str,
203        query: &[(String, String)],
204        body: Option<Value>,
205    ) -> Result<Value> {
206        let url = format!(
207            "{}/{}",
208            self.base_url.trim_end_matches('/'),
209            path.trim_start_matches('/')
210        );
211        let headers = self.build_headers()?;
212
213        let http_method = Method::from_bytes(method.as_bytes())
214            .map_err(|_| anyhow!("invalid HTTP method: {method}"))?;
215
216        // Ensure query params serialize as key=value pairs (reqwest/serde_urlencoded
217        // expect sequences of (key, value); using &[(&str, &str)] guarantees correct encoding).
218        let query_refs: Vec<(&str, &str)> = query
219            .iter()
220            .map(|(k, v)| (k.as_str(), v.as_str()))
221            .collect();
222
223        let mut req = self
224            .http
225            .request(http_method, &url)
226            .headers(headers)
227            .query(&query_refs);
228
229        if let Some(body) = body {
230            req = req.json(&body);
231        }
232
233        let t0 = Instant::now();
234        let resp = req
235            .send()
236            .await
237            .map_err(|e| anyhow!("request error: {e}"))?;
238
239        let status = resp.status();
240        if self.verbose {
241            let keys: Vec<&str> = query.iter().map(|(k, _)| k.as_str()).collect();
242            tracing::info!(
243                "[romm-cli] {} {} query_keys={:?} -> {} ({}ms)",
244                method,
245                path,
246                keys,
247                status.as_u16(),
248                t0.elapsed().as_millis()
249            );
250        }
251        if !status.is_success() {
252            let body = resp.text().await.unwrap_or_default();
253            return Err(anyhow!(
254                "ROMM API error: {} {} - {}",
255                status.as_u16(),
256                status.canonical_reason().unwrap_or(""),
257                body
258            ));
259        }
260
261        let bytes = resp
262            .bytes()
263            .await
264            .map_err(|e| anyhow!("read response body: {e}"))?;
265
266        Ok(decode_json_response_body(&bytes))
267    }
268
269    pub async fn request_json_unauthenticated(
270        &self,
271        method: &str,
272        path: &str,
273        query: &[(String, String)],
274        body: Option<Value>,
275    ) -> Result<Value> {
276        let url = format!(
277            "{}/{}",
278            self.base_url.trim_end_matches('/'),
279            path.trim_start_matches('/')
280        );
281        let headers = HeaderMap::new();
282
283        let http_method = Method::from_bytes(method.as_bytes())
284            .map_err(|_| anyhow!("invalid HTTP method: {method}"))?;
285
286        // Ensure query params serialize as key=value pairs (reqwest/serde_urlencoded
287        // expect sequences of (key, value); using &[(&str, &str)] guarantees correct encoding).
288        let query_refs: Vec<(&str, &str)> = query
289            .iter()
290            .map(|(k, v)| (k.as_str(), v.as_str()))
291            .collect();
292
293        let mut req = self
294            .http
295            .request(http_method, &url)
296            .headers(headers)
297            .query(&query_refs);
298
299        if let Some(body) = body {
300            req = req.json(&body);
301        }
302
303        let t0 = Instant::now();
304        let resp = req
305            .send()
306            .await
307            .map_err(|e| anyhow!("request error: {e}"))?;
308
309        let status = resp.status();
310        if self.verbose {
311            let keys: Vec<&str> = query.iter().map(|(k, _)| k.as_str()).collect();
312            tracing::info!(
313                "[romm-cli] {} {} query_keys={:?} -> {} ({}ms)",
314                method,
315                path,
316                keys,
317                status.as_u16(),
318                t0.elapsed().as_millis()
319            );
320        }
321        if !status.is_success() {
322            let body = resp.text().await.unwrap_or_default();
323            return Err(anyhow!(
324                "ROMM API error: {} {} - {}",
325                status.as_u16(),
326                status.canonical_reason().unwrap_or(""),
327                body
328            ));
329        }
330
331        let bytes = resp
332            .bytes()
333            .await
334            .map_err(|e| anyhow!("read response body: {e}"))?;
335
336        Ok(decode_json_response_body(&bytes))
337    }
338
339    /// RomM application version from `GET /api/heartbeat` (`SYSTEM.VERSION`), if the endpoint succeeds.
340    pub async fn rom_server_version_from_heartbeat(&self) -> Option<String> {
341        let v = self
342            .request_json_unauthenticated("GET", "/api/heartbeat", &[], None)
343            .await
344            .ok()?;
345        version_from_heartbeat_json(&v)
346    }
347
348    /// GET the OpenAPI spec from the server. Tries [`openapi_spec_urls`] in order (HTTP/HTTPS and
349    /// `/openapi.json` vs `/api/openapi.json`). Uses [`resolve_openapi_root`] for the origin.
350    pub async fn fetch_openapi_json(&self) -> Result<String> {
351        let root = resolve_openapi_root(&self.base_url);
352        let urls = openapi_spec_urls(&root);
353        let mut failures = Vec::new();
354        for url in &urls {
355            match self.fetch_openapi_json_once(url).await {
356                Ok(body) => return Ok(body),
357                Err(e) => failures.push(format!("{url}: {e:#}")),
358            }
359        }
360        Err(anyhow!(
361            "could not download OpenAPI ({} attempt(s)): {}",
362            failures.len(),
363            failures.join(" | ")
364        ))
365    }
366
367    async fn fetch_openapi_json_once(&self, url: &str) -> Result<String> {
368        let headers = self.build_headers()?;
369
370        let t0 = Instant::now();
371        let resp = self
372            .http
373            .get(url)
374            .headers(headers)
375            .send()
376            .await
377            .map_err(|e| anyhow!("request failed: {e}"))?;
378
379        let status = resp.status();
380        if self.verbose {
381            tracing::info!(
382                "[romm-cli] GET {} -> {} ({}ms)",
383                url,
384                status.as_u16(),
385                t0.elapsed().as_millis()
386            );
387        }
388        if !status.is_success() {
389            let body = resp.text().await.unwrap_or_default();
390            return Err(anyhow!(
391                "HTTP {} {} - {}",
392                status.as_u16(),
393                status.canonical_reason().unwrap_or(""),
394                body.chars().take(500).collect::<String>()
395            ));
396        }
397
398        resp.text()
399            .await
400            .map_err(|e| anyhow!("read OpenAPI body: {e}"))
401    }
402
403    /// Download ROM(s) as a zip file to `save_path`, calling `on_progress(received, total)`.
404    /// Uses GET /api/roms/download?rom_ids={id}&filename=... per RomM OpenAPI.
405    ///
406    /// If `save_path` already exists on disk (e.g. from a previous interrupted
407    /// download), the client sends an HTTP `Range` header to resume from the
408    /// existing byte offset. The server may reply with `206 Partial Content`
409    /// (resume works) or `200 OK` (server doesn't support ranges — restart
410    /// from scratch).
411    pub async fn download_rom<F>(
412        &self,
413        rom_id: u64,
414        save_path: &Path,
415        mut on_progress: F,
416    ) -> Result<()>
417    where
418        F: FnMut(u64, u64) + Send,
419    {
420        let path = "/api/roms/download";
421        let url = format!(
422            "{}/{}",
423            self.base_url.trim_end_matches('/'),
424            path.trim_start_matches('/')
425        );
426        let mut headers = self.build_headers()?;
427
428        let filename = save_path
429            .file_name()
430            .and_then(|n| n.to_str())
431            .unwrap_or("download.zip");
432
433        // Check for an existing partial file to resume from.
434        let existing_len = tokio::fs::metadata(save_path)
435            .await
436            .map(|m| m.len())
437            .unwrap_or(0);
438
439        if existing_len > 0 {
440            let range = format!("bytes={existing_len}-");
441            if let Ok(v) = reqwest::header::HeaderValue::from_str(&range) {
442                headers.insert(reqwest::header::RANGE, v);
443            }
444        }
445
446        let t0 = Instant::now();
447        let mut resp = self
448            .http
449            .get(&url)
450            .headers(headers)
451            .query(&[
452                ("rom_ids", rom_id.to_string()),
453                ("filename", filename.to_string()),
454            ])
455            .send()
456            .await
457            .map_err(|e| anyhow!("download request error: {e}"))?;
458
459        let status = resp.status();
460        if self.verbose {
461            tracing::info!(
462                "[romm-cli] GET /api/roms/download rom_id={} filename={:?} -> {} ({}ms)",
463                rom_id,
464                filename,
465                status.as_u16(),
466                t0.elapsed().as_millis()
467            );
468        }
469        if !status.is_success() {
470            let body = resp.text().await.unwrap_or_default();
471            return Err(anyhow!(
472                "ROMM API error: {} {} - {}",
473                status.as_u16(),
474                status.canonical_reason().unwrap_or(""),
475                body
476            ));
477        }
478
479        // Determine whether the server honoured our Range header.
480        let (mut received, total, mut file) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
481            // 206 — resume: content_length is the *remaining* bytes.
482            let remaining = resp.content_length().unwrap_or(0);
483            let total = existing_len + remaining;
484            let file = tokio::fs::OpenOptions::new()
485                .append(true)
486                .open(save_path)
487                .await
488                .map_err(|e| anyhow!("open file for append {:?}: {e}", save_path))?;
489            (existing_len, total, file)
490        } else {
491            // 200 — server doesn't support ranges; start from scratch.
492            let total = resp.content_length().unwrap_or(0);
493            let file = tokio::fs::File::create(save_path)
494                .await
495                .map_err(|e| anyhow!("create file {:?}: {e}", save_path))?;
496            (0u64, total, file)
497        };
498
499        while let Some(chunk) = resp.chunk().await.map_err(|e| anyhow!("read chunk: {e}"))? {
500            file.write_all(&chunk)
501                .await
502                .map_err(|e| anyhow!("write chunk {:?}: {e}", save_path))?;
503            received += chunk.len() as u64;
504            on_progress(received, total);
505        }
506
507        Ok(())
508    }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode_json_empty_and_whitespace_to_null() {
        assert_eq!(decode_json_response_body(b""), Value::Null);
        assert_eq!(decode_json_response_body(b"  \n\t "), Value::Null);
    }

    #[test]
    fn decode_json_object_roundtrip() {
        let v = decode_json_response_body(br#"{"a":1}"#);
        assert_eq!(v["a"], 1);
    }

    #[test]
    fn decode_non_json_wrapped() {
        let v = decode_json_response_body(b"plain text");
        assert_eq!(v["_non_json_body"], "plain text");
    }

    #[test]
    fn api_root_url_strips_trailing_api() {
        assert_eq!(
            super::api_root_url("http://localhost:8080/api"),
            "http://localhost:8080"
        );
        assert_eq!(
            super::api_root_url("http://localhost:8080/api/"),
            "http://localhost:8080"
        );
        assert_eq!(
            super::api_root_url("http://localhost:8080"),
            "http://localhost:8080"
        );
    }

    #[test]
    fn openapi_spec_urls_try_primary_scheme_then_alt() {
        let urls = super::openapi_spec_urls("http://example.test");
        assert_eq!(urls[0], "http://example.test/openapi.json");
        assert_eq!(urls[1], "http://example.test/api/openapi.json");
        assert!(
            urls.iter()
                .any(|u| u == "https://example.test/openapi.json"),
            "{urls:?}"
        );
    }

    #[test]
    fn openapi_spec_urls_https_root_falls_back_to_http_in_order() {
        let urls = super::openapi_spec_urls("https://example.test/");
        assert_eq!(
            urls,
            vec![
                "https://example.test/openapi.json",
                "https://example.test/api/openapi.json",
                "http://example.test/openapi.json",
                "http://example.test/api/openapi.json",
            ]
        );
    }

    #[test]
    fn openapi_spec_urls_without_http_scheme_has_no_fallback() {
        let urls = super::openapi_spec_urls("example.test");
        assert_eq!(
            urls,
            vec!["example.test/openapi.json", "example.test/api/openapi.json"]
        );
    }

    #[test]
    fn alternate_scheme_swaps_http_and_https() {
        assert_eq!(
            alternate_http_scheme_root("http://h:1").as_deref(),
            Some("https://h:1")
        );
        assert_eq!(
            alternate_http_scheme_root("https://h").as_deref(),
            Some("http://h")
        );
        assert_eq!(alternate_http_scheme_root("ftp://h"), None);
    }

    #[test]
    fn heartbeat_version_extracted_only_when_string() {
        let ok = serde_json::json!({ "SYSTEM": { "VERSION": "3.7.0" } });
        assert_eq!(version_from_heartbeat_json(&ok).as_deref(), Some("3.7.0"));
        assert_eq!(version_from_heartbeat_json(&serde_json::json!({})), None);
        let numeric = serde_json::json!({ "SYSTEM": { "VERSION": 3 } });
        assert_eq!(version_from_heartbeat_json(&numeric), None);
    }
}
560}