1use crate::error::{DatalabError, Result};
2use crate::output::Progress;
3use reqwest::multipart::{Form, Part};
4use reqwest::{Client, Response};
5use serde::Deserialize;
6use std::path::PathBuf;
7use std::time::Duration;
8use tokio::time::sleep;
9
/// API root used when `DATALAB_BASE_URL` is not set.
const DEFAULT_BASE_URL: &str = "https://www.datalab.to/api/v1";
/// Default timeout in seconds (five minutes), used for HTTP requests and for job polling.
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
/// Delay before the first status poll, in milliseconds.
const INITIAL_POLL_DELAY_MS: u64 = 500;
/// Ceiling for the growing delay between status polls, in milliseconds.
const MAX_POLL_DELAY_MS: u64 = 5_000;
/// Factor by which the poll delay grows after each still-pending status.
const POLL_BACKOFF_MULTIPLIER: f64 = 1.5;
15
16#[derive(Debug, Deserialize)]
17pub struct SubmitResponse {
18 pub success: bool,
19 pub request_id: Option<String>,
20 pub request_check_url: Option<String>,
21 #[serde(flatten)]
22 pub extra: serde_json::Value,
23}
24
25#[derive(Debug, Deserialize)]
26pub struct PollResponse {
27 pub status: String,
28 pub success: Option<bool>,
29 #[serde(flatten)]
30 pub data: serde_json::Value,
31}
32
33pub struct DatalabClient {
34 client: Client,
35 api_key: String,
36 base_url: String,
37 timeout_secs: u64,
38}
39
40impl DatalabClient {
41 pub fn new(timeout_secs: Option<u64>) -> Result<Self> {
42 let api_key = std::env::var("DATALAB_API_KEY").map_err(|_| DatalabError::MissingApiKey)?;
43
44 let base_url =
45 std::env::var("DATALAB_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
46
47 let client = Client::builder()
48 .timeout(Duration::from_secs(
49 timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS),
50 ))
51 .build()?;
52
53 Ok(Self {
54 client,
55 api_key,
56 base_url,
57 timeout_secs: timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS),
58 })
59 }
60
61 fn endpoint(&self, path: &str) -> String {
62 format!(
63 "{}/{}",
64 self.base_url.trim_end_matches('/'),
65 path.trim_start_matches('/')
66 )
67 }
68
69 async fn handle_response(&self, response: Response) -> Result<serde_json::Value> {
70 let status = response.status();
71
72 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
73 let retry_after = response
74 .headers()
75 .get("retry-after")
76 .and_then(|v| v.to_str().ok())
77 .and_then(|v| v.parse().ok());
78 return Err(DatalabError::RateLimited { retry_after });
79 }
80
81 let body: serde_json::Value = response.json().await?;
82
83 if !status.is_success() {
84 let message = body
85 .get("error")
86 .or_else(|| body.get("message"))
87 .and_then(|v| v.as_str())
88 .unwrap_or("Unknown error")
89 .to_string();
90 return Err(DatalabError::ApiError {
91 status: status.as_u16(),
92 message,
93 });
94 }
95
96 Ok(body)
97 }
98
99 pub async fn get(&self, path: &str) -> Result<serde_json::Value> {
100 let response = self
101 .client
102 .get(self.endpoint(path))
103 .header("X-API-Key", &self.api_key)
104 .send()
105 .await?;
106
107 self.handle_response(response).await
108 }
109
110 pub async fn delete(&self, path: &str) -> Result<serde_json::Value> {
111 let response = self
112 .client
113 .delete(self.endpoint(path))
114 .header("X-API-Key", &self.api_key)
115 .send()
116 .await?;
117
118 self.handle_response(response).await
119 }
120
121 pub async fn post_json(
122 &self,
123 path: &str,
124 body: &serde_json::Value,
125 ) -> Result<serde_json::Value> {
126 let response = self
127 .client
128 .post(self.endpoint(path))
129 .header("X-API-Key", &self.api_key)
130 .json(body)
131 .send()
132 .await?;
133
134 self.handle_response(response).await
135 }
136
137 pub async fn post_form(&self, path: &str, form: Form) -> Result<serde_json::Value> {
138 let response = self
139 .client
140 .post(self.endpoint(path))
141 .header("X-API-Key", &self.api_key)
142 .multipart(form)
143 .send()
144 .await?;
145
146 self.handle_response(response).await
147 }
148
149 pub async fn submit_and_poll(
150 &self,
151 path: &str,
152 form: Form,
153 progress: &Progress,
154 ) -> Result<serde_json::Value> {
155 let submit_response: SubmitResponse =
156 serde_json::from_value(self.post_form(path, form).await?)?;
157
158 if !submit_response.success {
159 let error_msg = submit_response
160 .extra
161 .get("error")
162 .and_then(|v| v.as_str())
163 .unwrap_or("Request submission failed");
164 return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
165 }
166
167 if let Some(ref request_id) = submit_response.request_id {
169 progress.submit(request_id);
170 }
171
172 let check_url = submit_response
173 .request_check_url
174 .ok_or_else(|| DatalabError::ProcessingFailed("No check URL returned".to_string()))?;
175
176 self.poll_until_complete(&check_url, progress).await
177 }
178
179 #[allow(dead_code)]
180 pub async fn submit_json_and_poll(
181 &self,
182 path: &str,
183 body: &serde_json::Value,
184 progress: &Progress,
185 ) -> Result<serde_json::Value> {
186 let submit_response: SubmitResponse =
187 serde_json::from_value(self.post_json(path, body).await?)?;
188
189 if !submit_response.success {
190 let error_msg = submit_response
191 .extra
192 .get("error")
193 .and_then(|v| v.as_str())
194 .unwrap_or("Request submission failed");
195 return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
196 }
197
198 if let Some(ref request_id) = submit_response.request_id {
200 progress.submit(request_id);
201 }
202
203 let check_url = submit_response
204 .request_check_url
205 .ok_or_else(|| DatalabError::ProcessingFailed("No check URL returned".to_string()))?;
206
207 self.poll_until_complete(&check_url, progress).await
208 }
209
210 async fn poll_until_complete(
211 &self,
212 check_url: &str,
213 progress: &Progress,
214 ) -> Result<serde_json::Value> {
215 let mut delay_ms = INITIAL_POLL_DELAY_MS;
216 let start = std::time::Instant::now();
217 let timeout = Duration::from_secs(self.timeout_secs);
218
219 loop {
220 if start.elapsed() > timeout {
221 return Err(DatalabError::Timeout {
222 seconds: self.timeout_secs,
223 });
224 }
225
226 sleep(Duration::from_millis(delay_ms)).await;
227
228 let response = self
229 .client
230 .get(check_url)
231 .header("X-API-Key", &self.api_key)
232 .send()
233 .await?;
234
235 let poll_response: PollResponse =
236 serde_json::from_value(self.handle_response(response).await?)?;
237
238 progress.poll(&poll_response.status);
240
241 match poll_response.status.as_str() {
242 "complete" => {
243 if poll_response.success == Some(false) {
244 let error_msg = poll_response
245 .data
246 .get("error")
247 .and_then(|v| v.as_str())
248 .unwrap_or("Processing failed");
249 return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
250 }
251 return Ok(poll_response.data);
252 }
253 "failed" => {
254 let error_msg = poll_response
255 .data
256 .get("error")
257 .and_then(|v| v.as_str())
258 .unwrap_or("Processing failed");
259 return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
260 }
261 _ => {
262 delay_ms = ((delay_ms as f64) * POLL_BACKOFF_MULTIPLIER) as u64;
263 delay_ms = delay_ms.min(MAX_POLL_DELAY_MS);
264 }
265 }
266 }
267 }
268
269 pub async fn upload_file_to_presigned_url(
270 &self,
271 upload_url: &str,
272 file_path: &PathBuf,
273 content_type: &str,
274 progress: &Progress,
275 ) -> Result<()> {
276 let file_content = tokio::fs::read(file_path).await?;
277 let total_bytes = file_content.len() as u64;
278
279 progress.upload(0, total_bytes);
281
282 let response = self
283 .client
284 .put(upload_url)
285 .header("Content-Type", content_type)
286 .body(file_content)
287 .send()
288 .await?;
289
290 progress.upload(total_bytes, total_bytes);
292
293 if !response.status().is_success() {
294 return Err(DatalabError::ApiError {
295 status: response.status().as_u16(),
296 message: "Failed to upload file to presigned URL".to_string(),
297 });
298 }
299
300 Ok(())
301 }
302}
303
304pub fn build_form_with_file(file_path: &PathBuf) -> Result<(Form, Vec<u8>)> {
305 let file_content =
306 std::fs::read(file_path).map_err(|_| DatalabError::FileNotFound(file_path.clone()))?;
307
308 let file_name = file_path
309 .file_name()
310 .and_then(|n| n.to_str())
311 .unwrap_or("file")
312 .to_string();
313
314 let mime_type = mime_guess::from_path(file_path)
315 .first_or_octet_stream()
316 .to_string();
317
318 let part = Part::bytes(file_content.clone())
319 .file_name(file_name)
320 .mime_str(&mime_type)
321 .map_err(|e| DatalabError::InvalidInput(e.to_string()))?;
322
323 let form = Form::new().part("file", part);
324
325 Ok((form, file_content))
326}
327
328pub fn add_form_field(form: Form, name: &str, value: &str) -> Form {
329 form.text(name.to_string(), value.to_string())
330}