datapress_client/
client.rs1use serde_json::Value as JsonValue;
4
5use crate::error::{ClientError, Result};
6use crate::models::{QueryRequest, QueryResponse, SqlRequest, SqlResponse};
7
/// Media type sent in the `Accept` header when asking the server to stream
/// query results back as an Arrow IPC stream.
const ARROW_IPC_MIME: &str = "application/vnd.apache.arrow.stream";
9
10#[derive(Debug)]
12pub struct ClientBuilder {
13 base_url: String,
14 api_base: String,
15 admin_token: Option<String>,
16 bearer_token: Option<String>,
17 inner: reqwest::ClientBuilder,
18}
19
20impl ClientBuilder {
21 pub fn new(base_url: impl Into<String>) -> Self {
25 Self {
26 base_url: base_url.into(),
27 api_base: "/api/v1".into(),
28 admin_token: None,
29 bearer_token: None,
30 inner: reqwest::Client::builder(),
31 }
32 }
33
34 pub fn api_base(mut self, base: impl Into<String>) -> Self {
37 self.api_base = base.into();
38 self
39 }
40
41 pub fn admin_token(mut self, token: impl Into<String>) -> Self {
44 self.admin_token = Some(token.into());
45 self
46 }
47
48 pub fn bearer_token(mut self, token: impl Into<String>) -> Self {
51 self.bearer_token = Some(token.into());
52 self
53 }
54
55 pub fn timeout(mut self, dur: std::time::Duration) -> Self {
57 self.inner = self.inner.timeout(dur);
58 self
59 }
60
61 pub fn reqwest_builder(mut self, b: reqwest::ClientBuilder) -> Self {
64 self.inner = b;
65 self
66 }
67
68 pub fn build(self) -> Result<Client> {
70 let base_url = self.base_url.trim_end_matches('/').to_string();
71 if !base_url.starts_with("http://") && !base_url.starts_with("https://") {
72 return Err(ClientError::InvalidBaseUrl(self.base_url));
73 }
74 let http = self.inner.build()?;
75 Ok(Client {
76 http,
77 base_url,
78 api_base: self.api_base.trim_end_matches('/').to_string(),
79 admin_token: self.admin_token,
80 bearer_token: self.bearer_token,
81 })
82 }
83}
84
85#[derive(Clone, Debug)]
90pub struct Client {
91 http: reqwest::Client,
92 base_url: String,
93 api_base: String,
94 admin_token: Option<String>,
95 bearer_token: Option<String>,
96}
97
98impl Client {
99 pub fn new(base_url: impl Into<String>) -> Result<Self> {
101 ClientBuilder::new(base_url).build()
102 }
103
104 pub fn builder(base_url: impl Into<String>) -> ClientBuilder {
106 ClientBuilder::new(base_url)
107 }
108
109 fn api_url(&self, path: &str) -> String {
112 format!("{}{}{}", self.base_url, self.api_base, path)
113 }
114
115 fn root_url(&self, path: &str) -> String {
116 let without_scheme = self
119 .base_url
120 .split_once("://")
121 .unwrap_or(("http", self.base_url.as_str()));
122 let (scheme, rest) = without_scheme;
123 let authority = rest.split('/').next().unwrap_or(rest);
124 format!("{scheme}://{authority}{path}")
125 }
126
127 fn apply_headers(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
128 let mut req = req;
129 if let Some(t) = &self.admin_token {
130 req = req.header("X-Admin-Token", t);
131 }
132 if let Some(t) = &self.bearer_token {
133 req = req.bearer_auth(t);
134 }
135 req
136 }
137
138 async fn get_json(&self, url: String) -> Result<JsonValue> {
141 let req = self.apply_headers(self.http.get(&url).header("Accept", "application/json"));
142 Self::json_response(req.send().await?).await
143 }
144
145 async fn post_json<B: serde::Serialize>(&self, url: String, body: &B) -> Result<JsonValue> {
146 let req = self
147 .apply_headers(self.http.post(&url).header("Accept", "application/json"))
148 .json(body);
149 Self::json_response(req.send().await?).await
150 }
151
152 async fn json_response(resp: reqwest::Response) -> Result<JsonValue> {
153 let status = resp.status();
154 let body = resp.text().await?;
155 if !status.is_success() {
156 return Err(ClientError::from_response(status.as_u16(), body));
157 }
158 if body.is_empty() {
159 return Ok(JsonValue::Null);
160 }
161 serde_json::from_str(&body).map_err(|e| ClientError::Decode(e.to_string()))
162 }
163
164 pub async fn healthz(&self) -> Result<JsonValue> {
168 self.get_json(self.root_url("/healthz")).await
169 }
170
171 pub async fn readyz(&self) -> Result<JsonValue> {
174 self.get_json(self.root_url("/readyz")).await
175 }
176
177 pub async fn datasets(&self) -> Result<Vec<String>> {
181 let v = self.get_json(self.api_url("/datasets")).await?;
182 let arr = match &v {
185 JsonValue::Object(map) => map.get("datasets").cloned().unwrap_or(JsonValue::Null),
186 other => other.clone(),
187 };
188 let names = match arr {
189 JsonValue::Array(items) => items
190 .into_iter()
191 .filter_map(|it| match it {
192 JsonValue::String(s) => Some(s),
193 JsonValue::Object(o) => o
194 .get("name")
195 .and_then(|n| n.as_str())
196 .map(str::to_owned),
197 _ => None,
198 })
199 .collect(),
200 _ => Vec::new(),
201 };
202 Ok(names)
203 }
204
205 pub async fn schema(&self, dataset: &str) -> Result<JsonValue> {
207 self.get_json(self.api_url(&format!("/datasets/{dataset}/schema")))
208 .await
209 }
210
211 pub async fn count(
214 &self,
215 dataset: &str,
216 predicates: &[crate::models::Predicate],
217 ) -> Result<u64> {
218 let body = serde_json::json!({ "predicates": predicates });
219 let out = self
220 .post_json(self.api_url(&format!("/datasets/{dataset}/count")), &body)
221 .await?;
222 out.get("count")
223 .and_then(JsonValue::as_u64)
224 .ok_or_else(|| ClientError::Decode("count: missing `count` field".into()))
225 }
226
227 pub async fn query_json(&self, dataset: &str, request: &QueryRequest) -> Result<QueryResponse> {
231 let v = self
232 .post_json(self.api_url(&format!("/datasets/{dataset}/query")), request)
233 .await?;
234 serde_json::from_value(v).map_err(|e| ClientError::Decode(e.to_string()))
235 }
236
237 pub async fn sql(&self, sql: impl Into<String>, max_rows: Option<u64>) -> Result<SqlResponse> {
241 let body = SqlRequest {
242 sql: sql.into(),
243 max_rows,
244 };
245 let v = self.post_json(self.api_url("/sql"), &body).await?;
246 serde_json::from_value(v).map_err(|e| ClientError::Decode(e.to_string()))
247 }
248
249 pub async fn reload(&self, dataset: &str) -> Result<JsonValue> {
252 self.post_json(
253 self.api_url(&format!("/datasets/{dataset}/reload")),
254 &serde_json::json!({}),
255 )
256 .await
257 }
258
259 pub async fn query_arrow_bytes(
266 &self,
267 dataset: &str,
268 request: &QueryRequest,
269 ) -> Result<bytes::Bytes> {
270 let url = self.api_url(&format!("/datasets/{dataset}/query/stream"));
271 let req = self
272 .apply_headers(self.http.post(&url).header("Accept", ARROW_IPC_MIME))
273 .json(request);
274 let resp = req.send().await?;
275 let status = resp.status();
276 let ctype = resp
277 .headers()
278 .get(reqwest::header::CONTENT_TYPE)
279 .and_then(|v| v.to_str().ok())
280 .unwrap_or("")
281 .to_ascii_lowercase();
282 let body = resp.bytes().await?;
283 if !status.is_success() {
284 let text = String::from_utf8_lossy(&body).into_owned();
285 return Err(ClientError::from_response(status.as_u16(), text));
286 }
287 if !ctype.contains("arrow") {
288 return Err(ClientError::UnexpectedContentType(ctype));
289 }
290 Ok(body)
291 }
292
293 #[cfg(feature = "arrow")]
296 pub async fn query_arrow(
297 &self,
298 dataset: &str,
299 request: &QueryRequest,
300 ) -> Result<Vec<arrow::record_batch::RecordBatch>> {
301 let bytes = self.query_arrow_bytes(dataset, request).await?;
302 decode_ipc_stream(&bytes)
303 }
304}
305
/// Decodes an in-memory Arrow IPC stream into its record batches, in stream
/// order. Requires the `arrow` feature.
///
/// # Errors
/// Any Arrow error raised while opening the stream or reading a batch is
/// converted into the crate's error type via `?`. Decoding stops at the first
/// failure.
#[cfg(feature = "arrow")]
pub fn decode_ipc_stream(bytes: &[u8]) -> Result<Vec<arrow::record_batch::RecordBatch>> {
    use arrow::ipc::reader::StreamReader;

    // `None` projection: decode every column.
    let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?;
    let batches = reader.collect::<std::result::Result<Vec<_>, _>>()?;
    Ok(batches)
}
316}