// datapress_client/client.rs

//! Asynchronous DataPress client.
2
3use serde_json::Value as JsonValue;
4
5use crate::error::{ClientError, Result};
6use crate::models::{QueryRequest, QueryResponse, SqlRequest, SqlResponse};
7
/// MIME type of an Arrow IPC stream. Sent as the `Accept` header by
/// [`Client::query_arrow_bytes`].
const ARROW_IPC_MIME: &str = "application/vnd.apache.arrow.stream";
9
10/// Builder for [`Client`].
11#[derive(Debug)]
12pub struct ClientBuilder {
13    base_url: String,
14    api_base: String,
15    admin_token: Option<String>,
16    bearer_token: Option<String>,
17    inner: reqwest::ClientBuilder,
18}
19
20impl ClientBuilder {
21    /// Start building a client for the given server base URL, e.g.
22    /// `http://127.0.0.1:8000`. A configured server prefix (e.g.
23    /// `/datapress`) should be included here.
24    pub fn new(base_url: impl Into<String>) -> Self {
25        Self {
26            base_url: base_url.into(),
27            api_base: "/api/v1".into(),
28            admin_token: None,
29            bearer_token: None,
30            inner: reqwest::Client::builder(),
31        }
32    }
33
34    /// Override the versioned API mount path. Defaults to `/api/v1`; pass
35    /// `/api` to target the legacy unversioned alias.
36    pub fn api_base(mut self, base: impl Into<String>) -> Self {
37        self.api_base = base.into();
38        self
39    }
40
41    /// Set the admin token sent as `X-Admin-Token` on mutating endpoints
42    /// (currently [`Client::reload`]).
43    pub fn admin_token(mut self, token: impl Into<String>) -> Self {
44        self.admin_token = Some(token.into());
45        self
46    }
47
48    /// Set an OAuth2 bearer token, attached as `Authorization: Bearer …`
49    /// to every request (for servers with `auth` enabled).
50    pub fn bearer_token(mut self, token: impl Into<String>) -> Self {
51        self.bearer_token = Some(token.into());
52        self
53    }
54
55    /// Set the per-request timeout.
56    pub fn timeout(mut self, dur: std::time::Duration) -> Self {
57        self.inner = self.inner.timeout(dur);
58        self
59    }
60
61    /// Provide a pre-configured [`reqwest::ClientBuilder`] to customise
62    /// the underlying HTTP client (proxies, TLS, pools, …).
63    pub fn reqwest_builder(mut self, b: reqwest::ClientBuilder) -> Self {
64        self.inner = b;
65        self
66    }
67
68    /// Finish building.
69    pub fn build(self) -> Result<Client> {
70        let base_url = self.base_url.trim_end_matches('/').to_string();
71        if !base_url.starts_with("http://") && !base_url.starts_with("https://") {
72            return Err(ClientError::InvalidBaseUrl(self.base_url));
73        }
74        let http = self.inner.build()?;
75        Ok(Client {
76            http,
77            base_url,
78            api_base: self.api_base.trim_end_matches('/').to_string(),
79            admin_token: self.admin_token,
80            bearer_token: self.bearer_token,
81        })
82    }
83}
84
85/// Asynchronous client for a running DataPress server.
86///
87/// Cheap to clone (wraps an `Arc` internally via [`reqwest::Client`]);
88/// share one instance across tasks.
89#[derive(Clone, Debug)]
90pub struct Client {
91    http: reqwest::Client,
92    base_url: String,
93    api_base: String,
94    admin_token: Option<String>,
95    bearer_token: Option<String>,
96}
97
98impl Client {
99    /// Construct a client with defaults for `base_url`.
100    pub fn new(base_url: impl Into<String>) -> Result<Self> {
101        ClientBuilder::new(base_url).build()
102    }
103
104    /// Start a [`ClientBuilder`].
105    pub fn builder(base_url: impl Into<String>) -> ClientBuilder {
106        ClientBuilder::new(base_url)
107    }
108
109    // ----------------------------------------------------------- urls --
110
111    fn api_url(&self, path: &str) -> String {
112        format!("{}{}{}", self.base_url, self.api_base, path)
113    }
114
115    fn root_url(&self, path: &str) -> String {
116        // /healthz and /readyz live at the host root, outside any prefix.
117        // Strip everything after the authority from base_url.
118        let without_scheme = self
119            .base_url
120            .split_once("://")
121            .unwrap_or(("http", self.base_url.as_str()));
122        let (scheme, rest) = without_scheme;
123        let authority = rest.split('/').next().unwrap_or(rest);
124        format!("{scheme}://{authority}{path}")
125    }
126
127    fn apply_headers(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
128        let mut req = req;
129        if let Some(t) = &self.admin_token {
130            req = req.header("X-Admin-Token", t);
131        }
132        if let Some(t) = &self.bearer_token {
133            req = req.bearer_auth(t);
134        }
135        req
136    }
137
138    // ------------------------------------------------------- requests --
139
140    async fn get_json(&self, url: String) -> Result<JsonValue> {
141        let req = self.apply_headers(self.http.get(&url).header("Accept", "application/json"));
142        Self::json_response(req.send().await?).await
143    }
144
145    async fn post_json<B: serde::Serialize>(&self, url: String, body: &B) -> Result<JsonValue> {
146        let req = self
147            .apply_headers(self.http.post(&url).header("Accept", "application/json"))
148            .json(body);
149        Self::json_response(req.send().await?).await
150    }
151
152    async fn json_response(resp: reqwest::Response) -> Result<JsonValue> {
153        let status = resp.status();
154        let body = resp.text().await?;
155        if !status.is_success() {
156            return Err(ClientError::from_response(status.as_u16(), body));
157        }
158        if body.is_empty() {
159            return Ok(JsonValue::Null);
160        }
161        serde_json::from_str(&body).map_err(|e| ClientError::Decode(e.to_string()))
162    }
163
164    // --------------------------------------------------------- probes --
165
166    /// Liveness probe — `GET /healthz` (always at the host root).
167    pub async fn healthz(&self) -> Result<JsonValue> {
168        self.get_json(self.root_url("/healthz")).await
169    }
170
171    /// Readiness probe — `GET /readyz`. Returns a `503` error while the
172    /// server is still loading datasets.
173    pub async fn readyz(&self) -> Result<JsonValue> {
174        self.get_json(self.root_url("/readyz")).await
175    }
176
177    // ------------------------------------------------------- metadata --
178
179    /// List registered dataset names.
180    pub async fn datasets(&self) -> Result<Vec<String>> {
181        let v = self.get_json(self.api_url("/datasets")).await?;
182        // Newer servers return `{"datasets": [ {name, …}, … ]}`; tolerate a
183        // bare array and a list of strings too.
184        let arr = match &v {
185            JsonValue::Object(map) => map.get("datasets").cloned().unwrap_or(JsonValue::Null),
186            other => other.clone(),
187        };
188        let names = match arr {
189            JsonValue::Array(items) => items
190                .into_iter()
191                .filter_map(|it| match it {
192                    JsonValue::String(s) => Some(s),
193                    JsonValue::Object(o) => o
194                        .get("name")
195                        .and_then(|n| n.as_str())
196                        .map(str::to_owned),
197                    _ => None,
198                })
199                .collect(),
200            _ => Vec::new(),
201        };
202        Ok(names)
203    }
204
205    /// Fetch the schema description for `dataset`.
206    pub async fn schema(&self, dataset: &str) -> Result<JsonValue> {
207        self.get_json(self.api_url(&format!("/datasets/{dataset}/schema")))
208            .await
209    }
210
211    /// Count matching rows. `predicates` is the same predicate shape used
212    /// by [`QueryRequest`]; `None`/empty = unfiltered.
213    pub async fn count(
214        &self,
215        dataset: &str,
216        predicates: &[crate::models::Predicate],
217    ) -> Result<u64> {
218        let body = serde_json::json!({ "predicates": predicates });
219        let out = self
220            .post_json(self.api_url(&format!("/datasets/{dataset}/count")), &body)
221            .await?;
222        out.get("count")
223            .and_then(JsonValue::as_u64)
224            .ok_or_else(|| ClientError::Decode("count: missing `count` field".into()))
225    }
226
227    // -------------------------------------------------------- queries --
228
229    /// Run a structured query and return the decoded JSON envelope.
230    pub async fn query_json(&self, dataset: &str, request: &QueryRequest) -> Result<QueryResponse> {
231        let v = self
232            .post_json(self.api_url(&format!("/datasets/{dataset}/query")), request)
233            .await?;
234        serde_json::from_value(v).map_err(|e| ClientError::Decode(e.to_string()))
235    }
236
237    /// Run a raw read-only SQL statement (`POST /sql`). The endpoint must
238    /// be enabled server-side (`[sql].enabled = true`), else a `404` is
239    /// returned.
240    pub async fn sql(&self, sql: impl Into<String>, max_rows: Option<u64>) -> Result<SqlResponse> {
241        let body = SqlRequest {
242            sql: sql.into(),
243            max_rows,
244        };
245        let v = self.post_json(self.api_url("/sql"), &body).await?;
246        serde_json::from_value(v).map_err(|e| ClientError::Decode(e.to_string()))
247    }
248
249    /// Trigger an in-place reload of `dataset` (requires `admin_token` or
250    /// the configured reload scopes).
251    pub async fn reload(&self, dataset: &str) -> Result<JsonValue> {
252        self.post_json(
253            self.api_url(&format!("/datasets/{dataset}/reload")),
254            &serde_json::json!({}),
255        )
256        .await
257    }
258
259    // ----------------------------------------------------------- arrow --
260
261    /// Run a structured query against the Arrow IPC streaming endpoint
262    /// (`POST /datasets/{name}/query/stream`), returning the raw IPC
263    /// stream bytes. Use [`Client::query_arrow`] to decode them into
264    /// record batches.
265    pub async fn query_arrow_bytes(
266        &self,
267        dataset: &str,
268        request: &QueryRequest,
269    ) -> Result<bytes::Bytes> {
270        let url = self.api_url(&format!("/datasets/{dataset}/query/stream"));
271        let req = self
272            .apply_headers(self.http.post(&url).header("Accept", ARROW_IPC_MIME))
273            .json(request);
274        let resp = req.send().await?;
275        let status = resp.status();
276        let ctype = resp
277            .headers()
278            .get(reqwest::header::CONTENT_TYPE)
279            .and_then(|v| v.to_str().ok())
280            .unwrap_or("")
281            .to_ascii_lowercase();
282        let body = resp.bytes().await?;
283        if !status.is_success() {
284            let text = String::from_utf8_lossy(&body).into_owned();
285            return Err(ClientError::from_response(status.as_u16(), text));
286        }
287        if !ctype.contains("arrow") {
288            return Err(ClientError::UnexpectedContentType(ctype));
289        }
290        Ok(body)
291    }
292
293    /// Run a structured query and decode the Arrow IPC response into a
294    /// vector of [`arrow::record_batch::RecordBatch`].
295    #[cfg(feature = "arrow")]
296    pub async fn query_arrow(
297        &self,
298        dataset: &str,
299        request: &QueryRequest,
300    ) -> Result<Vec<arrow::record_batch::RecordBatch>> {
301        let bytes = self.query_arrow_bytes(dataset, request).await?;
302        decode_ipc_stream(&bytes)
303    }
304}
305
/// Decode an Arrow IPC stream held in `bytes` into its record batches.
///
/// Decoding stops at the first batch that fails to read, and that error
/// is returned; a failure to open the stream is returned likewise.
#[cfg(feature = "arrow")]
pub fn decode_ipc_stream(bytes: &[u8]) -> Result<Vec<arrow::record_batch::RecordBatch>> {
    use arrow::ipc::reader::StreamReader;

    let cursor = std::io::Cursor::new(bytes);
    let stream = StreamReader::try_new(cursor, None)?;
    // Collecting into a `Result` short-circuits on the first failed batch,
    // exactly like pushing with `?` inside a loop.
    stream
        .map(|batch| batch.map_err(Into::into))
        .collect()
}