1use crate::errors::{Error, Result};
2use crate::models::{Dataset, HealthResponse, ImportStatus, StreamsResponse};
3use reqwest::{
4 Client, Method, Response, Url,
5 header::{AUTHORIZATION, HeaderMap, HeaderName, HeaderValue, USER_AGENT},
6};
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use serde_json::Value;
10use std::time::Duration;
11
12const REQUEST_SOURCE_HEADER: HeaderName = HeaderName::from_static("x-request-source");
19const DEFAULT_REQUEST_SOURCE: HeaderValue =
20 HeaderValue::from_static(concat!("sdk/rust:", env!("CARGO_PKG_VERSION")));
21
/// Async client for the MarpleDB HTTP API.
///
/// Construct one with [`MarpleDB::new`] or [`MarpleDB::builder`]. `Clone` is derived, so a
/// clone holds clones of both `reqwest::Client` handles and the precomputed headers.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct MarpleDB {
    /// HTTP client used for authenticated API requests.
    pub(crate) client: Client,
    /// Second HTTP client, handed out by `MarpleDB::storage_client`.
    pub(crate) storage_client: Client,
    /// API root URL; `MarpleDBBuilder::build` normalizes it to end with exactly one `/`.
    pub(crate) base_url: String,
    /// Precomputed `Bearer <token>` header value, marked sensitive so `Debug` hides it.
    auth_header: HeaderValue,
    /// Value sent in the `x-request-source` header with every API request.
    request_source: HeaderValue,
}
32
33impl MarpleDB {
34 pub fn new(url: &str, token: &str) -> Result<Self> {
39 Self::builder().url(url).token(token).build()
40 }
41
42 pub fn builder() -> MarpleDBBuilder {
47 MarpleDBBuilder::default()
48 }
49
50 pub fn storage_client(&self) -> &Client {
55 &self.storage_client
56 }
57
58 fn url(&self, endpoint: &str) -> String {
59 self.base_url.clone() + endpoint.trim_start_matches('/')
60 }
61
62 fn auth(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
63 request
64 .header(AUTHORIZATION, self.auth_header.clone())
65 .header(REQUEST_SOURCE_HEADER, self.request_source.clone())
66 }
67
68 async fn send_json<R>(
69 &self,
70 endpoint: &str,
71 method: Method,
72 request: reqwest::RequestBuilder,
73 ) -> Result<R>
74 where
75 R: DeserializeOwned,
76 {
77 let response = request.send().await.map_err(|source| Error::Transport {
78 method: method.clone(),
79 endpoint: endpoint.to_string(),
80 source,
81 })?;
82 self.handle_response(endpoint, method, response).await
83 }
84
85 async fn handle_response<R>(
86 &self,
87 endpoint: &str,
88 method: Method,
89 response: Response,
90 ) -> Result<R>
91 where
92 R: DeserializeOwned,
93 {
94 let status = response.status();
95 let body = response.text().await.map_err(|source| Error::Transport {
96 method: method.clone(),
97 endpoint: endpoint.to_string(),
98 source,
99 })?;
100 if !status.is_success() {
101 return Err(Error::Api {
102 method,
103 endpoint: endpoint.to_string(),
104 status,
105 body,
106 });
107 }
108 Ok(serde_json::from_str(&body)?)
109 }
110
111 #[tracing::instrument(skip_all, fields(endpoint = %endpoint))]
116 pub async fn get<Q, R>(&self, endpoint: &str, query: &Q) -> Result<R>
117 where
118 Q: Serialize + ?Sized,
119 R: DeserializeOwned,
120 {
121 let request = self.auth(self.client.get(self.url(endpoint)).query(query));
122 self.send_json(endpoint, Method::GET, request).await
123 }
124
125 #[tracing::instrument(skip_all, fields(endpoint = %endpoint))]
130 pub async fn post<B, R>(&self, endpoint: &str, body: &B) -> Result<R>
131 where
132 B: Serialize + ?Sized,
133 R: DeserializeOwned,
134 {
135 let request = self.auth(self.client.post(self.url(endpoint)).json(body));
136 self.send_json(endpoint, Method::POST, request).await
137 }
138
139 #[tracing::instrument(skip_all, fields(endpoint = %endpoint))]
144 pub async fn delete<B, R>(&self, endpoint: &str, body: &B) -> Result<R>
145 where
146 B: Serialize + ?Sized,
147 R: DeserializeOwned,
148 {
149 let request = self.auth(self.client.delete(self.url(endpoint)).json(body));
150 self.send_json(endpoint, Method::DELETE, request).await
151 }
152
153 pub(crate) async fn post_json<B, R>(&self, endpoint: &str, body: &B) -> Result<R>
154 where
155 B: Serialize + ?Sized,
156 R: DeserializeOwned,
157 {
158 self.post(endpoint, body).await
159 }
160
161 #[tracing::instrument(skip_all, fields(endpoint = %endpoint))]
162 pub(crate) async fn post_multipart(
163 &self,
164 endpoint: &str,
165 form: reqwest::multipart::Form,
166 ) -> Result<Value> {
167 let request = self.auth(self.client.post(self.url(endpoint)).multipart(form));
168 self.send_json(endpoint, Method::POST, request).await
169 }
170
171 pub(crate) async fn get_json<Q, R>(&self, endpoint: &str, query: &Q) -> Result<R>
172 where
173 Q: Serialize + ?Sized,
174 R: DeserializeOwned,
175 {
176 self.get(endpoint, query).await
177 }
178
179 pub async fn health(&self) -> Result<HealthResponse> {
181 self.get("health", &()).await
182 }
183
184 pub async fn get_streams(&self) -> Result<Vec<crate::Stream>> {
186 let streams_response: StreamsResponse = self.get("streams", &()).await?;
187 Ok(streams_response.streams)
188 }
189
190 pub async fn get_stream(&self, stream_name: &str) -> Result<crate::Stream> {
192 let streams = self.get_streams().await?;
193 streams
194 .into_iter()
195 .find(|s| s.name == stream_name)
196 .ok_or_else(|| Error::StreamNotFound {
197 name: stream_name.to_string(),
198 })
199 }
200
201 pub async fn create_stream<S: Serialize + ?Sized>(
206 &self,
207 stream_name: &str,
208 options: &S,
209 ) -> Result<crate::Stream> {
210 let mut options = match serde_json::to_value(options)? {
211 Value::Object(options) => options,
212 _ => {
213 return Err(Error::Protocol(
214 "create_stream options must serialize to a JSON object".to_string(),
215 ));
216 }
217 };
218 options.insert("name".to_string(), Value::String(stream_name.to_string()));
219 self.post_json::<_, Value>("stream", &options).await?;
220 self.get_stream(stream_name).await
221 }
222
223 pub async fn update_stream<S: Serialize + ?Sized>(
228 &self,
229 stream_id: i32,
230 options: &S,
231 ) -> Result<crate::Stream> {
232 let endpoint = format!("stream/update/{}", stream_id);
233 self.post_json::<_, Value>(&endpoint, options).await?;
234 self.get_streams()
235 .await?
236 .into_iter()
237 .find(|stream| stream.id == stream_id)
238 .ok_or(Error::StreamIdNotFound { id: stream_id })
239 }
240
241 pub async fn get_datasets(&self, stream_id: i32) -> Result<Vec<Dataset>> {
243 self.get(&format!("stream/{}/datasets", stream_id), &())
244 .await
245 }
246
247 pub async fn get_datapool_datasets(&self, pool: &str) -> Result<Vec<Dataset>> {
249 self.get(&format!("datapool/{}/datasets", pool), &()).await
250 }
251
252 pub async fn get_datapool_ingest_queue(&self, pool: &str) -> Result<Vec<Dataset>> {
254 self.get(&format!("datapool/{}/ingest/queue", pool), &())
255 .await
256 }
257
258 pub async fn get_dataset(&self, stream_id: i32, dataset_id: i32) -> Result<Dataset> {
260 self.get(&format!("stream/{}/dataset/{}", stream_id, dataset_id), &())
261 .await
262 }
263
264 pub async fn get_download_link(&self, dataset: &Dataset) -> Result<Url> {
269 if dataset.backup_size.is_none() {
270 return Err(Error::NoBackup { id: dataset.id });
271 }
272 let endpoint = format!(
273 "stream/{}/dataset/{}/backup",
274 dataset.datastream_id, dataset.id
275 );
276 #[derive(serde::Deserialize)]
277 struct DownloadLink {
278 path: String,
279 }
280 let link: DownloadLink = self.get_json(&endpoint, &()).await?;
281 Ok(link.path.parse()?)
282 }
283
284 pub async fn wait_for_import(
289 &self,
290 stream_id: i32,
291 dataset_id: i32,
292 timeout: Duration,
293 ) -> Result<Dataset> {
294 let deadline = std::time::Instant::now() + timeout;
295 let mut last_status = "unknown".to_string();
296
297 while std::time::Instant::now() < deadline {
298 let dataset = self.get_dataset(stream_id, dataset_id).await?;
299 last_status = format!("{:?}", dataset.import_status);
300
301 match dataset.import_status {
302 ImportStatus::Finished | ImportStatus::Live => return Ok(dataset),
303 ImportStatus::Failed | ImportStatus::PostprocessingFailed => {
304 return Err(Error::ImportFailed {
305 id: dataset.id,
306 message: dataset
307 .import_message
308 .clone()
309 .unwrap_or_else(|| format!("{:?}", dataset.import_status)),
310 });
311 }
312 _ => tokio::time::sleep(Duration::from_millis(500)).await,
313 }
314 }
315
316 Err(Error::ImportTimeout {
317 timeout_secs: timeout.as_secs(),
318 last_status,
319 })
320 }
321}
322
323#[non_exhaustive]
325#[derive(Clone, Debug)]
326pub struct MarpleDBBuilder {
327 url: Option<String>,
328 token: Option<String>,
329 client: Option<Client>,
330 storage_client: Option<Client>,
331 timeout: Option<Duration>,
332 user_agent: Option<String>,
333 request_source: Option<String>,
334}
335
336impl Default for MarpleDBBuilder {
337 fn default() -> Self {
338 Self {
339 url: None,
340 token: None,
341 client: None,
342 storage_client: None,
343 timeout: None,
344 user_agent: Some(format!("marple-db/{}", env!("CARGO_PKG_VERSION"))),
345 request_source: None,
346 }
347 }
348}
349
350impl MarpleDBBuilder {
351 pub fn url(mut self, url: impl Into<String>) -> Self {
355 self.url = Some(url.into());
356 self
357 }
358
359 pub fn token(mut self, token: impl Into<String>) -> Self {
363 self.token = Some(token.into());
364 self
365 }
366
367 pub fn timeout(mut self, timeout: Duration) -> Self {
372 self.timeout = Some(timeout);
373 self
374 }
375
376 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
378 self.user_agent = Some(user_agent.into());
379 self
380 }
381
382 pub fn request_source(mut self, request_source: impl Into<String>) -> Self {
389 self.request_source = Some(request_source.into());
390 self
391 }
392
393 pub fn client(mut self, client: Client) -> Self {
397 self.client = Some(client);
398 self
399 }
400
401 pub fn storage_client(mut self, client: Client) -> Self {
406 self.storage_client = Some(client);
407 self
408 }
409
410 pub fn build(self) -> Result<MarpleDB> {
412 let url = self
413 .url
414 .ok_or_else(|| Error::Config("missing MarpleDB API URL".to_string()))?;
415 let token = self
416 .token
417 .ok_or_else(|| Error::Config("missing MarpleDB API token".to_string()))?;
418 let mut auth_header = HeaderValue::from_str(&format!("Bearer {}", token))?;
419 auth_header.set_sensitive(true);
420
421 let request_source = match self.request_source {
422 Some(value) => HeaderValue::from_str(&value)?,
423 None => DEFAULT_REQUEST_SOURCE,
424 };
425
426 let client = match self.client {
427 Some(client) => client,
428 None => build_client(self.timeout, self.user_agent.as_deref())?,
429 };
430 let storage_client = match self.storage_client {
431 Some(client) => client,
432 None => build_client(self.timeout, self.user_agent.as_deref())?,
433 };
434
435 Ok(MarpleDB {
436 client,
437 storage_client,
438 base_url: url.trim_end_matches('/').to_string() + "/",
439 auth_header,
440 request_source,
441 })
442 }
443}
444
445fn build_client(timeout: Option<Duration>, user_agent: Option<&str>) -> Result<Client> {
446 let mut builder = Client::builder();
447 if let Some(timeout) = timeout {
448 builder = builder.timeout(timeout);
449 }
450 if let Some(user_agent) = user_agent {
451 let mut headers = HeaderMap::new();
452 headers.insert(USER_AGENT, HeaderValue::from_str(user_agent)?);
453 builder = builder.default_headers(headers);
454 }
455 builder.build().map_err(|source| Error::Transport {
456 method: Method::GET,
457 endpoint: "client builder".to_string(),
458 source,
459 })
460}