1use crate::errors::{Error, Result};
2use crate::models::{Dataset, HealthResponse, ImportStatus, StreamsResponse};
3use reqwest::{
4 Client, Method, Response, Url,
5 header::{AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT},
6};
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use serde_json::Value;
10use std::time::Duration;
11
12#[non_exhaustive]
14#[derive(Clone, Debug)]
15pub struct MarpleDB {
16 pub(crate) client: Client,
17 pub(crate) storage_client: Client,
18 pub(crate) base_url: String,
19 auth_header: HeaderValue,
20}
21
22impl MarpleDB {
23 pub fn new(url: &str, token: &str) -> Result<Self> {
28 Self::builder().url(url).token(token).build()
29 }
30
31 pub fn builder() -> MarpleDBBuilder {
36 MarpleDBBuilder::default()
37 }
38
39 pub fn storage_client(&self) -> &Client {
44 &self.storage_client
45 }
46
47 fn url(&self, endpoint: &str) -> String {
48 self.base_url.clone() + endpoint.trim_start_matches('/')
49 }
50
51 fn auth(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
52 request.header(AUTHORIZATION, self.auth_header.clone())
53 }
54
55 async fn send_json<R>(
56 &self,
57 endpoint: &str,
58 method: Method,
59 request: reqwest::RequestBuilder,
60 ) -> Result<R>
61 where
62 R: DeserializeOwned,
63 {
64 let response = request.send().await.map_err(|source| Error::Transport {
65 method: method.clone(),
66 endpoint: endpoint.to_string(),
67 source,
68 })?;
69 self.handle_response(endpoint, method, response).await
70 }
71
72 async fn handle_response<R>(
73 &self,
74 endpoint: &str,
75 method: Method,
76 response: Response,
77 ) -> Result<R>
78 where
79 R: DeserializeOwned,
80 {
81 let status = response.status();
82 let body = response.text().await.map_err(|source| Error::Transport {
83 method: method.clone(),
84 endpoint: endpoint.to_string(),
85 source,
86 })?;
87 if !status.is_success() {
88 return Err(Error::Api {
89 method,
90 endpoint: endpoint.to_string(),
91 status,
92 body,
93 });
94 }
95 Ok(serde_json::from_str(&body)?)
96 }
97
98 #[tracing::instrument(skip_all, fields(endpoint = %endpoint))]
103 pub async fn get<Q, R>(&self, endpoint: &str, query: &Q) -> Result<R>
104 where
105 Q: Serialize + ?Sized,
106 R: DeserializeOwned,
107 {
108 let request = self.auth(self.client.get(self.url(endpoint)).query(query));
109 self.send_json(endpoint, Method::GET, request).await
110 }
111
112 #[tracing::instrument(skip_all, fields(endpoint = %endpoint))]
117 pub async fn post<B, R>(&self, endpoint: &str, body: &B) -> Result<R>
118 where
119 B: Serialize + ?Sized,
120 R: DeserializeOwned,
121 {
122 let request = self.auth(self.client.post(self.url(endpoint)).json(body));
123 self.send_json(endpoint, Method::POST, request).await
124 }
125
126 #[tracing::instrument(skip_all, fields(endpoint = %endpoint))]
131 pub async fn delete<B, R>(&self, endpoint: &str, body: &B) -> Result<R>
132 where
133 B: Serialize + ?Sized,
134 R: DeserializeOwned,
135 {
136 let request = self.auth(self.client.delete(self.url(endpoint)).json(body));
137 self.send_json(endpoint, Method::DELETE, request).await
138 }
139
140 pub(crate) async fn post_json<B, R>(&self, endpoint: &str, body: &B) -> Result<R>
141 where
142 B: Serialize + ?Sized,
143 R: DeserializeOwned,
144 {
145 self.post(endpoint, body).await
146 }
147
148 #[tracing::instrument(skip_all, fields(endpoint = %endpoint))]
149 pub(crate) async fn post_multipart(
150 &self,
151 endpoint: &str,
152 form: reqwest::multipart::Form,
153 ) -> Result<Value> {
154 let request = self.auth(self.client.post(self.url(endpoint)).multipart(form));
155 self.send_json(endpoint, Method::POST, request).await
156 }
157
158 pub(crate) async fn get_json<Q, R>(&self, endpoint: &str, query: &Q) -> Result<R>
159 where
160 Q: Serialize + ?Sized,
161 R: DeserializeOwned,
162 {
163 self.get(endpoint, query).await
164 }
165
166 pub async fn health(&self) -> Result<HealthResponse> {
168 self.get("health", &()).await
169 }
170
171 pub async fn get_streams(&self) -> Result<Vec<crate::Stream>> {
173 let streams_response: StreamsResponse = self.get("streams", &()).await?;
174 Ok(streams_response.streams)
175 }
176
177 pub async fn get_stream(&self, stream_name: &str) -> Result<crate::Stream> {
179 let streams = self.get_streams().await?;
180 streams
181 .into_iter()
182 .find(|s| s.name == stream_name)
183 .ok_or_else(|| Error::StreamNotFound {
184 name: stream_name.to_string(),
185 })
186 }
187
188 pub async fn create_stream<S: Serialize + ?Sized>(
193 &self,
194 stream_name: &str,
195 options: &S,
196 ) -> Result<crate::Stream> {
197 let mut options = match serde_json::to_value(options)? {
198 Value::Object(options) => options,
199 _ => {
200 return Err(Error::Protocol(
201 "create_stream options must serialize to a JSON object".to_string(),
202 ));
203 }
204 };
205 options.insert("name".to_string(), Value::String(stream_name.to_string()));
206 self.post_json::<_, Value>("stream", &options).await?;
207 self.get_stream(stream_name).await
208 }
209
210 pub async fn update_stream<S: Serialize + ?Sized>(
215 &self,
216 stream_id: i32,
217 options: &S,
218 ) -> Result<crate::Stream> {
219 let endpoint = format!("stream/update/{}", stream_id);
220 self.post_json::<_, Value>(&endpoint, options).await?;
221 self.get_streams()
222 .await?
223 .into_iter()
224 .find(|stream| stream.id == stream_id)
225 .ok_or(Error::StreamIdNotFound { id: stream_id })
226 }
227
228 pub async fn get_datasets(&self, stream_id: i32) -> Result<Vec<Dataset>> {
230 self.get(&format!("stream/{}/datasets", stream_id), &())
231 .await
232 }
233
234 pub async fn get_datapool_datasets(&self, pool: &str) -> Result<Vec<Dataset>> {
236 self.get(&format!("datapool/{}/datasets", pool), &()).await
237 }
238
239 pub async fn get_datapool_ingest_queue(&self, pool: &str) -> Result<Vec<Dataset>> {
241 self.get(&format!("datapool/{}/ingest/queue", pool), &())
242 .await
243 }
244
245 pub async fn get_dataset(&self, stream_id: i32, dataset_id: i32) -> Result<Dataset> {
247 self.get(&format!("stream/{}/dataset/{}", stream_id, dataset_id), &())
248 .await
249 }
250
251 pub async fn get_download_link(&self, dataset: &Dataset) -> Result<Url> {
256 if dataset.backup_size.is_none() {
257 return Err(Error::NoBackup { id: dataset.id });
258 }
259 let endpoint = format!(
260 "stream/{}/dataset/{}/backup",
261 dataset.datastream_id, dataset.id
262 );
263 #[derive(serde::Deserialize)]
264 struct DownloadLink {
265 path: String,
266 }
267 let link: DownloadLink = self.get_json(&endpoint, &()).await?;
268 Ok(link.path.parse()?)
269 }
270
271 pub async fn wait_for_import(
276 &self,
277 stream_id: i32,
278 dataset_id: i32,
279 timeout: Duration,
280 ) -> Result<Dataset> {
281 let deadline = std::time::Instant::now() + timeout;
282 let mut last_status = "unknown".to_string();
283
284 while std::time::Instant::now() < deadline {
285 let dataset = self.get_dataset(stream_id, dataset_id).await?;
286 last_status = format!("{:?}", dataset.import_status);
287
288 match dataset.import_status {
289 ImportStatus::Finished | ImportStatus::Live => return Ok(dataset),
290 ImportStatus::Failed | ImportStatus::PostprocessingFailed => {
291 return Err(Error::ImportFailed {
292 id: dataset.id,
293 message: dataset
294 .import_message
295 .clone()
296 .unwrap_or_else(|| format!("{:?}", dataset.import_status)),
297 });
298 }
299 _ => tokio::time::sleep(Duration::from_millis(500)).await,
300 }
301 }
302
303 Err(Error::ImportTimeout {
304 timeout_secs: timeout.as_secs(),
305 last_status,
306 })
307 }
308}
309
310#[non_exhaustive]
312#[derive(Clone, Debug)]
313pub struct MarpleDBBuilder {
314 url: Option<String>,
315 token: Option<String>,
316 client: Option<Client>,
317 storage_client: Option<Client>,
318 timeout: Option<Duration>,
319 user_agent: Option<String>,
320}
321
322impl Default for MarpleDBBuilder {
323 fn default() -> Self {
324 Self {
325 url: None,
326 token: None,
327 client: None,
328 storage_client: None,
329 timeout: None,
330 user_agent: Some(format!("marple-db/{}", env!("CARGO_PKG_VERSION"))),
331 }
332 }
333}
334
335impl MarpleDBBuilder {
336 pub fn url(mut self, url: impl Into<String>) -> Self {
340 self.url = Some(url.into());
341 self
342 }
343
344 pub fn token(mut self, token: impl Into<String>) -> Self {
348 self.token = Some(token.into());
349 self
350 }
351
352 pub fn timeout(mut self, timeout: Duration) -> Self {
357 self.timeout = Some(timeout);
358 self
359 }
360
361 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
363 self.user_agent = Some(user_agent.into());
364 self
365 }
366
367 pub fn client(mut self, client: Client) -> Self {
371 self.client = Some(client);
372 self
373 }
374
375 pub fn storage_client(mut self, client: Client) -> Self {
380 self.storage_client = Some(client);
381 self
382 }
383
384 pub fn build(self) -> Result<MarpleDB> {
386 let url = self
387 .url
388 .ok_or_else(|| Error::Config("missing MarpleDB API URL".to_string()))?;
389 let token = self
390 .token
391 .ok_or_else(|| Error::Config("missing MarpleDB API token".to_string()))?;
392 let mut auth_header = HeaderValue::from_str(&format!("Bearer {}", token))?;
393 auth_header.set_sensitive(true);
394
395 let client = match self.client {
396 Some(client) => client,
397 None => build_client(self.timeout, self.user_agent.as_deref())?,
398 };
399 let storage_client = match self.storage_client {
400 Some(client) => client,
401 None => build_client(self.timeout, self.user_agent.as_deref())?,
402 };
403
404 Ok(MarpleDB {
405 client,
406 storage_client,
407 base_url: url.trim_end_matches('/').to_string() + "/",
408 auth_header,
409 })
410 }
411}
412
413fn build_client(timeout: Option<Duration>, user_agent: Option<&str>) -> Result<Client> {
414 let mut builder = Client::builder();
415 if let Some(timeout) = timeout {
416 builder = builder.timeout(timeout);
417 }
418 if let Some(user_agent) = user_agent {
419 let mut headers = HeaderMap::new();
420 headers.insert(USER_AGENT, HeaderValue::from_str(user_agent)?);
421 builder = builder.default_headers(headers);
422 }
423 builder.build().map_err(|source| Error::Transport {
424 method: Method::GET,
425 endpoint: "client builder".to_string(),
426 source,
427 })
428}