Skip to main content

datapress_client/
client.rs

1//! Asynchronous DataPress client.
2
3use serde_json::Value as JsonValue;
4
5use crate::error::{ClientError, Result};
6use crate::models::{QueryRequest, QueryResponse, SqlRequest, SqlResponse};
7
/// Media type of the Arrow IPC *streaming* format; sent as the `Accept`
/// header by [`Client::query_arrow_bytes`].
const ARROW_IPC_MIME: &str = "application/vnd.apache.arrow.stream";
9
10/// Builder for [`Client`].
11#[derive(Debug)]
12pub struct ClientBuilder {
13    base_url: String,
14    api_base: String,
15    admin_token: Option<String>,
16    bearer_token: Option<String>,
17    inner: reqwest::ClientBuilder,
18}
19
20impl ClientBuilder {
21    /// Start building a client for the given server base URL, e.g.
22    /// `http://127.0.0.1:8000`. A configured server prefix (e.g.
23    /// `/datapress`) should be included here.
24    pub fn new(base_url: impl Into<String>) -> Self {
25        Self {
26            base_url: base_url.into(),
27            api_base: "/api/v1".into(),
28            admin_token: None,
29            bearer_token: None,
30            inner: reqwest::Client::builder(),
31        }
32    }
33
34    /// Override the versioned API mount path. Defaults to `/api/v1`; pass
35    /// `/api` to target the legacy unversioned alias.
36    pub fn api_base(mut self, base: impl Into<String>) -> Self {
37        self.api_base = base.into();
38        self
39    }
40
41    /// Set the admin token sent as `X-Admin-Token` on mutating endpoints
42    /// (currently [`Client::reload`]).
43    pub fn admin_token(mut self, token: impl Into<String>) -> Self {
44        self.admin_token = Some(token.into());
45        self
46    }
47
48    /// Set an OAuth2 bearer token, attached as `Authorization: Bearer …`
49    /// to every request (for servers with `auth` enabled).
50    pub fn bearer_token(mut self, token: impl Into<String>) -> Self {
51        self.bearer_token = Some(token.into());
52        self
53    }
54
55    /// Set the per-request timeout.
56    pub fn timeout(mut self, dur: std::time::Duration) -> Self {
57        self.inner = self.inner.timeout(dur);
58        self
59    }
60
61    /// Accept invalid or untrusted TLS certificates.
62    ///
63    /// **Dangerous**: disables certificate verification, leaving the
64    /// connection open to man-in-the-middle attacks. Intended only for
65    /// development against servers using self-signed certificates.
66    pub fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
67        self.inner = self.inner.danger_accept_invalid_certs(accept);
68        self
69    }
70
71    /// Provide a pre-configured [`reqwest::ClientBuilder`] to customise
72    /// the underlying HTTP client (proxies, TLS, pools, …).
73    pub fn reqwest_builder(mut self, b: reqwest::ClientBuilder) -> Self {
74        self.inner = b;
75        self
76    }
77
78    /// Finish building.
79    pub fn build(self) -> Result<Client> {
80        let base_url = self.base_url.trim_end_matches('/').to_string();
81        if !base_url.starts_with("http://") && !base_url.starts_with("https://") {
82            return Err(ClientError::InvalidBaseUrl(self.base_url));
83        }
84        let http = self.inner.build()?;
85        Ok(Client {
86            http,
87            base_url,
88            api_base: self.api_base.trim_end_matches('/').to_string(),
89            admin_token: self.admin_token,
90            bearer_token: self.bearer_token,
91        })
92    }
93}
94
95/// Asynchronous client for a running DataPress server.
96///
97/// Cheap to clone (wraps an `Arc` internally via [`reqwest::Client`]);
98/// share one instance across tasks.
99#[derive(Clone, Debug)]
100pub struct Client {
101    http: reqwest::Client,
102    base_url: String,
103    api_base: String,
104    admin_token: Option<String>,
105    bearer_token: Option<String>,
106}
107
108impl Client {
109    /// Construct a client with defaults for `base_url`.
110    pub fn new(base_url: impl Into<String>) -> Result<Self> {
111        ClientBuilder::new(base_url).build()
112    }
113
114    /// Start a [`ClientBuilder`].
115    pub fn builder(base_url: impl Into<String>) -> ClientBuilder {
116        ClientBuilder::new(base_url)
117    }
118
119    // ----------------------------------------------------------- urls --
120
121    fn api_url(&self, path: &str) -> String {
122        format!("{}{}{}", self.base_url, self.api_base, path)
123    }
124
125    fn root_url(&self, path: &str) -> String {
126        // /healthz and /readyz live at the host root, outside any prefix.
127        // Strip everything after the authority from base_url.
128        let without_scheme = self
129            .base_url
130            .split_once("://")
131            .unwrap_or(("http", self.base_url.as_str()));
132        let (scheme, rest) = without_scheme;
133        let authority = rest.split('/').next().unwrap_or(rest);
134        format!("{scheme}://{authority}{path}")
135    }
136
137    fn apply_headers(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
138        let mut req = req;
139        if let Some(t) = &self.admin_token {
140            req = req.header("X-Admin-Token", t);
141        }
142        if let Some(t) = &self.bearer_token {
143            req = req.bearer_auth(t);
144        }
145        req
146    }
147
148    // ------------------------------------------------------- requests --
149
150    async fn get_json(&self, url: String) -> Result<JsonValue> {
151        let req = self.apply_headers(self.http.get(&url).header("Accept", "application/json"));
152        Self::json_response(req.send().await?).await
153    }
154
155    async fn post_json<B: serde::Serialize>(&self, url: String, body: &B) -> Result<JsonValue> {
156        let req = self
157            .apply_headers(self.http.post(&url).header("Accept", "application/json"))
158            .json(body);
159        Self::json_response(req.send().await?).await
160    }
161
162    async fn json_response(resp: reqwest::Response) -> Result<JsonValue> {
163        let status = resp.status();
164        let body = resp.text().await?;
165        if !status.is_success() {
166            return Err(ClientError::from_response(status.as_u16(), body));
167        }
168        if body.is_empty() {
169            return Ok(JsonValue::Null);
170        }
171        serde_json::from_str(&body).map_err(|e| ClientError::Decode(e.to_string()))
172    }
173
174    // --------------------------------------------------------- probes --
175
176    /// Liveness probe — `GET /healthz` (always at the host root).
177    pub async fn healthz(&self) -> Result<JsonValue> {
178        self.get_json(self.root_url("/healthz")).await
179    }
180
181    /// Readiness probe — `GET /readyz`. Returns a `503` error while the
182    /// server is still loading datasets.
183    pub async fn readyz(&self) -> Result<JsonValue> {
184        self.get_json(self.root_url("/readyz")).await
185    }
186
187    // ------------------------------------------------------- metadata --
188
189    /// List registered dataset names.
190    pub async fn datasets(&self) -> Result<Vec<String>> {
191        let v = self.get_json(self.api_url("/datasets")).await?;
192        // Newer servers return `{"datasets": [ {name, …}, … ]}`; tolerate a
193        // bare array and a list of strings too.
194        let arr = match &v {
195            JsonValue::Object(map) => map.get("datasets").cloned().unwrap_or(JsonValue::Null),
196            other => other.clone(),
197        };
198        let names = match arr {
199            JsonValue::Array(items) => items
200                .into_iter()
201                .filter_map(|it| match it {
202                    JsonValue::String(s) => Some(s),
203                    JsonValue::Object(o) => o
204                        .get("name")
205                        .and_then(|n| n.as_str())
206                        .map(str::to_owned),
207                    _ => None,
208                })
209                .collect(),
210            _ => Vec::new(),
211        };
212        Ok(names)
213    }
214
215    /// Fetch the schema description for `dataset`.
216    pub async fn schema(&self, dataset: &str) -> Result<JsonValue> {
217        self.get_json(self.api_url(&format!("/datasets/{dataset}/schema")))
218            .await
219    }
220
221    /// Count matching rows. `predicates` is the same predicate shape used
222    /// by [`QueryRequest`]; `None`/empty = unfiltered.
223    pub async fn count(
224        &self,
225        dataset: &str,
226        predicates: &[crate::models::Predicate],
227    ) -> Result<u64> {
228        let body = serde_json::json!({ "predicates": predicates });
229        let out = self
230            .post_json(self.api_url(&format!("/datasets/{dataset}/count")), &body)
231            .await?;
232        out.get("count")
233            .and_then(JsonValue::as_u64)
234            .ok_or_else(|| ClientError::Decode("count: missing `count` field".into()))
235    }
236
237    // -------------------------------------------------------- queries --
238
239    /// Run a structured query and return the decoded JSON envelope.
240    pub async fn query_json(&self, dataset: &str, request: &QueryRequest) -> Result<QueryResponse> {
241        let v = self
242            .post_json(self.api_url(&format!("/datasets/{dataset}/query")), request)
243            .await?;
244        serde_json::from_value(v).map_err(|e| ClientError::Decode(e.to_string()))
245    }
246
247    /// Run a raw read-only SQL statement (`POST /sql`). The endpoint must
248    /// be enabled server-side (`[sql].enabled = true`), else a `404` is
249    /// returned.
250    pub async fn sql(&self, sql: impl Into<String>, max_rows: Option<u64>) -> Result<SqlResponse> {
251        let body = SqlRequest {
252            sql: sql.into(),
253            max_rows,
254        };
255        let v = self.post_json(self.api_url("/sql"), &body).await?;
256        serde_json::from_value(v).map_err(|e| ClientError::Decode(e.to_string()))
257    }
258
259    /// Trigger an in-place reload of `dataset` (requires `admin_token` or
260    /// the configured reload scopes).
261    pub async fn reload(&self, dataset: &str) -> Result<JsonValue> {
262        self.post_json(
263            self.api_url(&format!("/datasets/{dataset}/reload")),
264            &serde_json::json!({}),
265        )
266        .await
267    }
268
269    // ----------------------------------------------------------- arrow --
270
271    /// Run a structured query against the Arrow IPC streaming endpoint
272    /// (`POST /datasets/{name}/query/stream`), returning the raw IPC
273    /// stream bytes. Use [`Client::query_arrow`] to decode them into
274    /// record batches.
275    pub async fn query_arrow_bytes(
276        &self,
277        dataset: &str,
278        request: &QueryRequest,
279    ) -> Result<bytes::Bytes> {
280        let url = self.api_url(&format!("/datasets/{dataset}/query/stream"));
281        let req = self
282            .apply_headers(self.http.post(&url).header("Accept", ARROW_IPC_MIME))
283            .json(request);
284        let resp = req.send().await?;
285        let status = resp.status();
286        let ctype = resp
287            .headers()
288            .get(reqwest::header::CONTENT_TYPE)
289            .and_then(|v| v.to_str().ok())
290            .unwrap_or("")
291            .to_ascii_lowercase();
292        let body = resp.bytes().await?;
293        if !status.is_success() {
294            let text = String::from_utf8_lossy(&body).into_owned();
295            return Err(ClientError::from_response(status.as_u16(), text));
296        }
297        if !ctype.contains("arrow") {
298            return Err(ClientError::UnexpectedContentType(ctype));
299        }
300        Ok(body)
301    }
302
303    /// Run a structured query and decode the Arrow IPC response into a
304    /// vector of [`arrow::record_batch::RecordBatch`].
305    #[cfg(feature = "arrow")]
306    pub async fn query_arrow(
307        &self,
308        dataset: &str,
309        request: &QueryRequest,
310    ) -> Result<Vec<arrow::record_batch::RecordBatch>> {
311        let bytes = self.query_arrow_bytes(dataset, request).await?;
312        decode_ipc_stream(&bytes)
313    }
314}
315
/// Decode an Arrow IPC *stream* into its record batches, in stream order.
///
/// This is the decoding step used by `Client::query_arrow` on the bytes from
/// [`Client::query_arrow_bytes`].
///
/// # Errors
///
/// Fails if the stream header cannot be read or any batch fails to decode.
/// Decoding stops at the first such error, which is returned.
#[cfg(feature = "arrow")]
pub fn decode_ipc_stream(bytes: &[u8]) -> Result<Vec<arrow::record_batch::RecordBatch>> {
    let cursor = std::io::Cursor::new(bytes);
    let reader = arrow::ipc::reader::StreamReader::try_new(cursor, None)?;
    // Collecting into a `Result` short-circuits on the first failed batch.
    let batches = reader.collect::<std::result::Result<Vec<_>, _>>()?;
    Ok(batches)
}