Skip to main content

romm_cli/
client.rs

1//! HTTP client wrapper around the ROMM API.
2//!
3//! `RommClient` owns a configured `reqwest::Client` plus base URL and
4//! authentication settings. Frontends (CLI, TUI, or a future GUI) depend
5//! on this type instead of talking to `reqwest` directly.
6
7use anyhow::{anyhow, Result};
8use base64::{engine::general_purpose, Engine as _};
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
10use reqwest::{Client as HttpClient, Method};
11use serde_json::Value;
12use std::path::Path;
13use std::time::Instant;
14use tokio::io::AsyncWriteExt as _;
15
16use crate::config::{AuthConfig, Config};
17use crate::endpoints::Endpoint;
18
19/// Map a successful HTTP response body to JSON [`Value`].
20///
21/// Empty or whitespace-only bodies become [`Value::Null`] (e.g. HTTP 204).
22/// Non-JSON UTF-8 bodies are wrapped as `{"_non_json_body": "..."}`.
23fn decode_json_response_body(bytes: &[u8]) -> Value {
24    if bytes.is_empty() || bytes.iter().all(|b| b.is_ascii_whitespace()) {
25        return Value::Null;
26    }
27    serde_json::from_slice(bytes).unwrap_or_else(|_| {
28        serde_json::json!({
29            "_non_json_body": String::from_utf8_lossy(bytes).to_string()
30        })
31    })
32}
33
34/// High-level HTTP client for the ROMM API.
35///
36/// This type hides the details of `reqwest` and authentication headers
37/// behind a small, easy-to-mock interface that all frontends can share.
38#[derive(Clone)]
39pub struct RommClient {
40    http: HttpClient,
41    base_url: String,
42    auth: Option<AuthConfig>,
43    verbose: bool,
44}
45
46impl RommClient {
47    /// Construct a new client from the high-level [`Config`].
48    ///
49    /// `verbose` enables stderr request logging (method, path, query key names, status, timing).
50    /// This is typically done once in `main` and the resulting `RommClient` is shared
51    /// (by reference or cloning) with the chosen frontend.
52    pub fn new(config: &Config, verbose: bool) -> Result<Self> {
53        let http = HttpClient::builder().build()?;
54        Ok(Self {
55            http,
56            base_url: config.base_url.clone(),
57            auth: config.auth.clone(),
58            verbose,
59        })
60    }
61
62    /// Build the HTTP headers for the current authentication mode.
63    ///
64    /// This helper centralises all auth logic so that the rest of the
65    /// code never needs to worry about `Basic` vs `Bearer` vs API key.
66    fn build_headers(&self) -> Result<HeaderMap> {
67        let mut headers = HeaderMap::new();
68
69        if let Some(auth) = &self.auth {
70            match auth {
71                AuthConfig::Basic { username, password } => {
72                    let creds = format!("{username}:{password}");
73                    let encoded = general_purpose::STANDARD.encode(creds.as_bytes());
74                    let value = format!("Basic {encoded}");
75                    headers.insert(
76                        AUTHORIZATION,
77                        HeaderValue::from_str(&value)
78                            .map_err(|_| anyhow!("invalid basic auth header value"))?,
79                    );
80                }
81                AuthConfig::Bearer { token } => {
82                    let value = format!("Bearer {token}");
83                    headers.insert(
84                        AUTHORIZATION,
85                        HeaderValue::from_str(&value)
86                            .map_err(|_| anyhow!("invalid bearer auth header value"))?,
87                    );
88                }
89                AuthConfig::ApiKey { header, key } => {
90                    let name = reqwest::header::HeaderName::from_bytes(header.as_bytes()).map_err(
91                        |_| anyhow!("invalid API_KEY_HEADER, must be a valid HTTP header name"),
92                    )?;
93                    headers.insert(
94                        name,
95                        HeaderValue::from_str(key)
96                            .map_err(|_| anyhow!("invalid API_KEY header value"))?,
97                    );
98                }
99            }
100        }
101
102        Ok(headers)
103    }
104
105    /// Call a typed endpoint using the low-level `request_json` primitive.
106    pub async fn call<E>(&self, ep: &E) -> anyhow::Result<E::Output>
107    where
108        E: Endpoint,
109        E::Output: serde::de::DeserializeOwned,
110    {
111        let method = ep.method();
112        let path = ep.path();
113        let query = ep.query();
114        let body = ep.body();
115
116        let value = self.request_json(method, &path, &query, body).await?;
117        let output = serde_json::from_value(value)
118            .map_err(|e| anyhow!("failed to decode response for {} {}: {}", method, path, e))?;
119
120        Ok(output)
121    }
122
123    /// Low-level helper that issues an HTTP request and returns raw JSON.
124    ///
125    /// Higher-level helpers (such as typed `Endpoint` implementations)
126    /// should prefer [`RommClient::call`] instead of using this directly.
127    pub async fn request_json(
128        &self,
129        method: &str,
130        path: &str,
131        query: &[(String, String)],
132        body: Option<Value>,
133    ) -> Result<Value> {
134        let url = format!(
135            "{}/{}",
136            self.base_url.trim_end_matches('/'),
137            path.trim_start_matches('/')
138        );
139        let headers = self.build_headers()?;
140
141        let http_method = Method::from_bytes(method.as_bytes())
142            .map_err(|_| anyhow!("invalid HTTP method: {method}"))?;
143
144        // Ensure query params serialize as key=value pairs (reqwest/serde_urlencoded
145        // expect sequences of (key, value); using &[(&str, &str)] guarantees correct encoding).
146        let query_refs: Vec<(&str, &str)> = query
147            .iter()
148            .map(|(k, v)| (k.as_str(), v.as_str()))
149            .collect();
150
151        let mut req = self
152            .http
153            .request(http_method, &url)
154            .headers(headers)
155            .query(&query_refs);
156
157        if let Some(body) = body {
158            req = req.json(&body);
159        }
160
161        let t0 = Instant::now();
162        let resp = req
163            .send()
164            .await
165            .map_err(|e| anyhow!("request error: {e}"))?;
166
167        let status = resp.status();
168        if self.verbose {
169            let keys: Vec<&str> = query.iter().map(|(k, _)| k.as_str()).collect();
170            tracing::info!(
171                "[romm-cli] {} {} query_keys={:?} -> {} ({}ms)",
172                method,
173                path,
174                keys,
175                status.as_u16(),
176                t0.elapsed().as_millis()
177            );
178        }
179        if !status.is_success() {
180            let body = resp.text().await.unwrap_or_default();
181            return Err(anyhow!(
182                "ROMM API error: {} {} - {}",
183                status.as_u16(),
184                status.canonical_reason().unwrap_or(""),
185                body
186            ));
187        }
188
189        let bytes = resp
190            .bytes()
191            .await
192            .map_err(|e| anyhow!("read response body: {e}"))?;
193
194        Ok(decode_json_response_body(&bytes))
195    }
196
197    /// Download ROM(s) as a zip file to `save_path`, calling `on_progress(received, total)`.
198    /// Uses GET /api/roms/download?rom_ids={id}&filename=... per RomM OpenAPI.
199    ///
200    /// If `save_path` already exists on disk (e.g. from a previous interrupted
201    /// download), the client sends an HTTP `Range` header to resume from the
202    /// existing byte offset. The server may reply with `206 Partial Content`
203    /// (resume works) or `200 OK` (server doesn't support ranges — restart
204    /// from scratch).
205    pub async fn download_rom<F>(
206        &self,
207        rom_id: u64,
208        save_path: &Path,
209        mut on_progress: F,
210    ) -> Result<()>
211    where
212        F: FnMut(u64, u64) + Send,
213    {
214        let path = "/api/roms/download";
215        let url = format!(
216            "{}/{}",
217            self.base_url.trim_end_matches('/'),
218            path.trim_start_matches('/')
219        );
220        let mut headers = self.build_headers()?;
221
222        let filename = save_path
223            .file_name()
224            .and_then(|n| n.to_str())
225            .unwrap_or("download.zip");
226
227        // Check for an existing partial file to resume from.
228        let existing_len = tokio::fs::metadata(save_path)
229            .await
230            .map(|m| m.len())
231            .unwrap_or(0);
232
233        if existing_len > 0 {
234            let range = format!("bytes={existing_len}-");
235            if let Ok(v) = reqwest::header::HeaderValue::from_str(&range) {
236                headers.insert(reqwest::header::RANGE, v);
237            }
238        }
239
240        let t0 = Instant::now();
241        let mut resp = self
242            .http
243            .get(&url)
244            .headers(headers)
245            .query(&[
246                ("rom_ids", rom_id.to_string()),
247                ("filename", filename.to_string()),
248            ])
249            .send()
250            .await
251            .map_err(|e| anyhow!("download request error: {e}"))?;
252
253        let status = resp.status();
254        if self.verbose {
255            tracing::info!(
256                "[romm-cli] GET /api/roms/download rom_id={} filename={:?} -> {} ({}ms)",
257                rom_id,
258                filename,
259                status.as_u16(),
260                t0.elapsed().as_millis()
261            );
262        }
263        if !status.is_success() {
264            let body = resp.text().await.unwrap_or_default();
265            return Err(anyhow!(
266                "ROMM API error: {} {} - {}",
267                status.as_u16(),
268                status.canonical_reason().unwrap_or(""),
269                body
270            ));
271        }
272
273        // Determine whether the server honoured our Range header.
274        let (mut received, total, mut file) = if status == reqwest::StatusCode::PARTIAL_CONTENT {
275            // 206 — resume: content_length is the *remaining* bytes.
276            let remaining = resp.content_length().unwrap_or(0);
277            let total = existing_len + remaining;
278            let file = tokio::fs::OpenOptions::new()
279                .append(true)
280                .open(save_path)
281                .await
282                .map_err(|e| anyhow!("open file for append {:?}: {e}", save_path))?;
283            (existing_len, total, file)
284        } else {
285            // 200 — server doesn't support ranges; start from scratch.
286            let total = resp.content_length().unwrap_or(0);
287            let file = tokio::fs::File::create(save_path)
288                .await
289                .map_err(|e| anyhow!("create file {:?}: {e}", save_path))?;
290            (0u64, total, file)
291        };
292
293        while let Some(chunk) = resp.chunk().await.map_err(|e| anyhow!("read chunk: {e}"))? {
294            file.write_all(&chunk)
295                .await
296                .map_err(|e| anyhow!("write chunk {:?}: {e}", save_path))?;
297            received += chunk.len() as u64;
298            on_progress(received, total);
299        }
300
301        Ok(())
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode_json_empty_and_whitespace_to_null() {
        assert_eq!(decode_json_response_body(b""), Value::Null);
        assert_eq!(decode_json_response_body(b"  \n\t "), Value::Null);
    }

    #[test]
    fn decode_json_object_roundtrip() {
        let v = decode_json_response_body(br#"{"a":1}"#);
        assert_eq!(v["a"], 1);
    }

    #[test]
    fn decode_json_tolerates_surrounding_whitespace() {
        // Whitespace around real JSON must not trigger the `Null` shortcut or the wrapper.
        let v = decode_json_response_body(b"  [1, 2]\n");
        assert_eq!(v, serde_json::json!([1, 2]));
    }

    #[test]
    fn decode_non_json_wrapped() {
        let v = decode_json_response_body(b"plain text");
        assert_eq!(v["_non_json_body"], "plain text");
    }

    #[test]
    fn decode_invalid_utf8_wrapped_lossily() {
        // Each invalid byte is replaced by U+FFFD rather than causing a panic or error.
        let v = decode_json_response_body(&[0xff, 0xfe]);
        assert_eq!(v["_non_json_body"], "\u{FFFD}\u{FFFD}");
    }
}