Skip to main content

chartml_core/resolver/
builtin.rs

1//! Built-in providers: `InlineProvider` and `HttpProvider`.
2//!
3//! Both are pre-registered on `ChartML::new()` under their respective dispatch
4//! keys (`"inline"`, `"http"`). Hosts can override either by calling
5//! `register_provider("inline", ...)` / `register_provider("http", ...)` with
6//! their own implementation — e.g., an HTTP provider that adds OAuth refresh.
7
8use std::collections::HashMap;
9
10use async_trait::async_trait;
11
12use crate::data::{DataTable, Row};
13
14use super::{DataSourceProvider, FetchError, FetchRequest, FetchResult};
15
/// Provider for `data: { rows: [...] }` shapes. Materializes the inline rows
/// into a `DataTable` via `DataTable::from_rows`. Missing or empty rows produce
/// an empty table — the same as the legacy inline path.
///
/// The provider carries no state, so it is trivially copyable and its `Debug`
/// output is just `InlineProvider`.
#[derive(Debug, Clone, Copy)]
pub struct InlineProvider;
20
21impl InlineProvider {
22    pub fn new() -> Self {
23        Self
24    }
25}
26
27impl Default for InlineProvider {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
34#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
35impl DataSourceProvider for InlineProvider {
36    async fn fetch(&self, request: FetchRequest) -> Result<FetchResult, FetchError> {
37        let rows_value = request.spec.rows.unwrap_or_default();
38        let rows: Vec<Row> = rows_value
39            .into_iter()
40            .map(|value| match value {
41                serde_json::Value::Object(map) => Ok(map.into_iter().collect::<Row>()),
42                other => Err(FetchError::DecodeFailed(format!(
43                    "InlineProvider expects each row to be a JSON object, got: {other}"
44                ))),
45            })
46            .collect::<Result<Vec<Row>, FetchError>>()?;
47
48        let data = DataTable::from_rows(&rows)
49            .map_err(|e| FetchError::DecodeFailed(format!("from_rows failed: {e}")))?;
50        Ok(FetchResult {
51            data,
52            metadata: HashMap::new(),
53        })
54    }
55}
56
57/// Provider for `data: { url: "..." }` shapes. Issues a GET via `reqwest`
58/// (works on both native and WASM with no feature-flag branching).
59///
60/// Header handling:
61/// - `with_default_headers(...)` sets defaults that apply to every request.
62/// - `FetchRequest.headers` overrides any default with the same name on a
63///   per-request basis (matches HTTP intuition: per-call wins over default).
64///
65/// Decode rule:
66/// - `Content-Type: application/vnd.apache.arrow.*` → decode as Arrow IPC
67///   bytes via `DataTable::from_ipc_bytes`.
68/// - anything else → parse as JSON. JSON arrays of objects flow through
69///   `DataTable::from_rows`; JSON objects with a top-level array key (`rows`,
70///   `data`, or `results`) are unwrapped automatically to match the most
71///   common API conventions.
72pub struct HttpProvider {
73    client: reqwest::Client,
74    default_headers: HashMap<String, String>,
75}
76
77impl HttpProvider {
78    /// New provider with no default headers. Convenience for the default
79    /// `register_provider("http", HttpProvider::new())` registration.
80    pub fn new() -> Self {
81        Self {
82            client: reqwest::Client::new(),
83            default_headers: HashMap::new(),
84        }
85    }
86
87    /// Builder: attach default headers (e.g. `Authorization`, `User-Agent`)
88    /// applied to every request unless overridden by `FetchRequest.headers`.
89    pub fn with_default_headers(mut self, headers: HashMap<String, String>) -> Self {
90        self.default_headers = headers;
91        self
92    }
93}
94
95impl Default for HttpProvider {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
102#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
103impl DataSourceProvider for HttpProvider {
104    async fn fetch(&self, request: FetchRequest) -> Result<FetchResult, FetchError> {
105        let url = request
106            .spec
107            .url
108            .as_deref()
109            .ok_or_else(|| FetchError::Other(
110                "HttpProvider requires `url` in the data spec".to_string(),
111            ))?;
112
113        // Merge headers: defaults first, then per-request overrides.
114        let mut merged: HashMap<String, String> = self.default_headers.clone();
115        for (k, v) in &request.headers {
116            merged.insert(k.clone(), v.clone());
117        }
118
119        let mut req = self.client.get(url);
120        for (name, value) in &merged {
121            req = req.header(name, value);
122        }
123
124        let response = req
125            .send()
126            .await
127            .map_err(|e| FetchError::QueryFailed(format!("HTTP GET {url} failed: {e}")))?;
128
129        let status = response.status();
130        if !status.is_success() {
131            return Err(FetchError::QueryFailed(format!(
132                "HTTP GET {url} returned status {status}"
133            )));
134        }
135
136        let content_type = response
137            .headers()
138            .get(reqwest::header::CONTENT_TYPE)
139            .and_then(|v| v.to_str().ok())
140            .map(|s| s.to_lowercase())
141            .unwrap_or_default();
142
143        let bytes = response
144            .bytes()
145            .await
146            .map_err(|e| FetchError::DecodeFailed(format!("body read failed: {e}")))?;
147
148        let data = if content_type.starts_with("application/vnd.apache.arrow") {
149            DataTable::from_ipc_bytes(&bytes)
150                .map_err(|e| FetchError::DecodeFailed(format!("Arrow IPC decode failed: {e}")))?
151        } else {
152            decode_json_to_table(&bytes)?
153        };
154
155        Ok(FetchResult {
156            data,
157            metadata: HashMap::new(),
158        })
159    }
160}
161
162/// Decode an HTTP response body as JSON and convert to a `DataTable`.
163///
164/// Accepts three top-level shapes for compatibility with common API
165/// conventions:
166/// - `[ {...}, {...} ]` → array of objects (the canonical chartml shape).
167/// - `{ "rows": [ ... ] }`, `{ "data": [ ... ] }`, `{ "results": [ ... ] }` →
168///   unwrap the array key and treat as the canonical shape.
169/// - anything else → `DecodeFailed` with the discovered shape in the error.
170fn decode_json_to_table(bytes: &[u8]) -> Result<DataTable, FetchError> {
171    let value: serde_json::Value = serde_json::from_slice(bytes)
172        .map_err(|e| FetchError::DecodeFailed(format!("JSON parse failed: {e}")))?;
173
174    let array = match value {
175        serde_json::Value::Array(arr) => arr,
176        serde_json::Value::Object(mut obj) => {
177            // Common top-level wrapper conventions, in order of preference.
178            const ARRAY_KEYS: [&str; 3] = ["rows", "data", "results"];
179            let mut found: Option<Vec<serde_json::Value>> = None;
180            for key in ARRAY_KEYS {
181                if let Some(serde_json::Value::Array(arr)) = obj.remove(key) {
182                    found = Some(arr);
183                    break;
184                }
185            }
186            found.ok_or_else(|| {
187                FetchError::DecodeFailed(
188                    "JSON object response must have a top-level `rows`, `data`, or `results` array key"
189                        .to_string(),
190                )
191            })?
192        }
193        other => {
194            return Err(FetchError::DecodeFailed(format!(
195                "JSON response must be an array of objects or an object with a `rows`/`data`/`results` array; got: {}",
196                discriminant_name(&other),
197            )));
198        }
199    };
200
201    let rows: Vec<Row> = array
202        .into_iter()
203        .map(|v| match v {
204            serde_json::Value::Object(map) => Ok(map.into_iter().collect::<Row>()),
205            other => Err(FetchError::DecodeFailed(format!(
206                "JSON array entries must be objects, got: {}",
207                discriminant_name(&other),
208            ))),
209        })
210        .collect::<Result<Vec<Row>, FetchError>>()?;
211
212    DataTable::from_rows(&rows)
213        .map_err(|e| FetchError::DecodeFailed(format!("from_rows failed: {e}")))
214}
215
216/// Pretty-print a JSON value's variant for error messages.
217fn discriminant_name(value: &serde_json::Value) -> &'static str {
218    match value {
219        serde_json::Value::Null => "null",
220        serde_json::Value::Bool(_) => "bool",
221        serde_json::Value::Number(_) => "number",
222        serde_json::Value::String(_) => "string",
223        serde_json::Value::Array(_) => "array",
224        serde_json::Value::Object(_) => "object",
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::resolver::FetchRequest;
    use crate::spec::InlineData;
    use serde_json::json;

    /// Wrap `spec` in a `FetchRequest` whose other fields are all empty.
    fn empty_request(spec: InlineData) -> FetchRequest {
        FetchRequest {
            source_name: None,
            spec,
            cache: None,
            headers: HashMap::new(),
            namespace: None,
            cancel_token: None,
        }
    }

    /// An `"inline"` provider spec carrying `rows` and nothing else.
    fn inline_spec(rows: Option<Vec<serde_json::Value>>) -> InlineData {
        InlineData {
            provider: Some("inline".into()),
            rows,
            url: None,
            endpoint: None,
            cache: None,
            datasource: None,
            query: None,
        }
    }

    #[tokio::test]
    async fn inline_provider_basic_rows() {
        let provider = InlineProvider::new();
        let spec = inline_spec(Some(vec![
            json!({"x": "A", "y": 1}),
            json!({"x": "B", "y": 2}),
        ]));
        let result = provider.fetch(empty_request(spec)).await.unwrap();
        assert_eq!(result.data.num_rows(), 2);
    }

    #[tokio::test]
    async fn inline_provider_empty_rows() {
        let provider = InlineProvider::new();
        let spec = inline_spec(Some(vec![]));
        let result = provider.fetch(empty_request(spec)).await.unwrap();
        assert_eq!(result.data.num_rows(), 0);
    }

    #[tokio::test]
    async fn inline_provider_missing_rows_is_empty() {
        let provider = InlineProvider::new();
        let spec = inline_spec(None);
        let result = provider.fetch(empty_request(spec)).await.unwrap();
        assert_eq!(result.data.num_rows(), 0);
    }

    #[tokio::test]
    async fn inline_provider_rejects_non_object_rows() {
        let provider = InlineProvider::new();
        let spec = inline_spec(Some(vec![json!(42)]));
        let err = provider.fetch(empty_request(spec)).await.unwrap_err();
        assert!(matches!(err, FetchError::DecodeFailed(_)));
    }

    #[test]
    fn json_decode_array_of_objects() {
        let body = b"[{\"x\":1},{\"x\":2}]";
        let table = decode_json_to_table(body).unwrap();
        assert_eq!(table.num_rows(), 2);
    }

    #[test]
    fn json_decode_empty_array() {
        let table = decode_json_to_table(b"[]").unwrap();
        assert_eq!(table.num_rows(), 0);
    }

    #[test]
    fn json_decode_rows_wrapper() {
        let body = b"{\"rows\":[{\"x\":1}]}";
        let table = decode_json_to_table(body).unwrap();
        assert_eq!(table.num_rows(), 1);
    }

    #[test]
    fn json_decode_data_wrapper() {
        let body = b"{\"data\":[{\"x\":1},{\"x\":2}]}";
        let table = decode_json_to_table(body).unwrap();
        assert_eq!(table.num_rows(), 2);
    }

    #[test]
    fn json_decode_results_wrapper() {
        let body = b"{\"results\":[{\"x\":1},{\"x\":2},{\"x\":3}]}";
        let table = decode_json_to_table(body).unwrap();
        assert_eq!(table.num_rows(), 3);
    }

    #[test]
    fn json_decode_prefers_rows_over_data() {
        let body = b"{\"rows\":[{\"x\":1}],\"data\":[{\"x\":1},{\"x\":2}]}";
        let table = decode_json_to_table(body).unwrap();
        assert_eq!(table.num_rows(), 1);
    }

    #[test]
    fn json_decode_skips_non_array_wrapper_key() {
        let body = b"{\"rows\":\"nope\",\"data\":[{\"x\":1},{\"x\":2}]}";
        let table = decode_json_to_table(body).unwrap();
        assert_eq!(table.num_rows(), 2);
    }

    #[test]
    fn json_decode_rejects_bare_object() {
        let body = b"{\"foo\":\"bar\"}";
        let err = decode_json_to_table(body).unwrap_err();
        assert!(matches!(err, FetchError::DecodeFailed(_)));
    }

    #[test]
    fn json_decode_rejects_scalar() {
        let body = b"42";
        let err = decode_json_to_table(body).unwrap_err();
        assert!(matches!(err, FetchError::DecodeFailed(_)));
    }

    #[test]
    fn json_decode_rejects_non_object_entries() {
        let body = b"[{\"x\":1},2]";
        let err = decode_json_to_table(body).unwrap_err();
        assert!(matches!(err, FetchError::DecodeFailed(_)));
    }

    #[test]
    fn json_decode_rejects_invalid_json() {
        let err = decode_json_to_table(b"{not json").unwrap_err();
        assert!(matches!(err, FetchError::DecodeFailed(_)));
    }
}