Skip to main content

datalab_cli/
client.rs

1use crate::error::{DatalabError, Result};
2use crate::output::Progress;
3use reqwest::multipart::{Form, Part};
4use reqwest::{Client, Response};
5use serde::Deserialize;
6use std::path::PathBuf;
7use std::time::Duration;
8use tokio::time::sleep;
9
/// API root used when `DATALAB_BASE_URL` is not set.
const DEFAULT_BASE_URL: &str = "https://www.datalab.to/api/v1";
/// Fallback timeout in seconds (five minutes); applied both to each HTTP request
/// and to the overall status-polling deadline.
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
/// Delay before the first status poll, in milliseconds.
const INITIAL_POLL_DELAY_MS: u64 = 500;
/// Upper bound on the delay between status polls, in milliseconds.
const MAX_POLL_DELAY_MS: u64 = 5_000;
/// Factor by which the poll delay grows after each non-terminal status.
const POLL_BACKOFF_MULTIPLIER: f64 = 1.5;
15
16#[derive(Debug, Deserialize)]
17pub struct SubmitResponse {
18    pub success: bool,
19    pub request_id: Option<String>,
20    pub request_check_url: Option<String>,
21    #[serde(flatten)]
22    pub extra: serde_json::Value,
23}
24
25#[derive(Debug, Deserialize)]
26pub struct PollResponse {
27    pub status: String,
28    pub success: Option<bool>,
29    #[serde(flatten)]
30    pub data: serde_json::Value,
31}
32
33pub struct DatalabClient {
34    client: Client,
35    api_key: String,
36    base_url: String,
37    timeout_secs: u64,
38}
39
40impl DatalabClient {
41    pub fn new(timeout_secs: Option<u64>) -> Result<Self> {
42        let api_key = std::env::var("DATALAB_API_KEY").map_err(|_| DatalabError::MissingApiKey)?;
43
44        let base_url =
45            std::env::var("DATALAB_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
46
47        let client = Client::builder()
48            .timeout(Duration::from_secs(
49                timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS),
50            ))
51            .build()?;
52
53        Ok(Self {
54            client,
55            api_key,
56            base_url,
57            timeout_secs: timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS),
58        })
59    }
60
61    fn endpoint(&self, path: &str) -> String {
62        format!(
63            "{}/{}",
64            self.base_url.trim_end_matches('/'),
65            path.trim_start_matches('/')
66        )
67    }
68
69    async fn handle_response(&self, response: Response) -> Result<serde_json::Value> {
70        let status = response.status();
71
72        if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
73            let retry_after = response
74                .headers()
75                .get("retry-after")
76                .and_then(|v| v.to_str().ok())
77                .and_then(|v| v.parse().ok());
78            return Err(DatalabError::RateLimited { retry_after });
79        }
80
81        let body: serde_json::Value = response.json().await?;
82
83        if !status.is_success() {
84            let message = body
85                .get("error")
86                .or_else(|| body.get("message"))
87                .and_then(|v| v.as_str())
88                .unwrap_or("Unknown error")
89                .to_string();
90            return Err(DatalabError::ApiError {
91                status: status.as_u16(),
92                message,
93            });
94        }
95
96        Ok(body)
97    }
98
99    pub async fn get(&self, path: &str) -> Result<serde_json::Value> {
100        let response = self
101            .client
102            .get(self.endpoint(path))
103            .header("X-API-Key", &self.api_key)
104            .send()
105            .await?;
106
107        self.handle_response(response).await
108    }
109
110    pub async fn delete(&self, path: &str) -> Result<serde_json::Value> {
111        let response = self
112            .client
113            .delete(self.endpoint(path))
114            .header("X-API-Key", &self.api_key)
115            .send()
116            .await?;
117
118        self.handle_response(response).await
119    }
120
121    pub async fn post_json(
122        &self,
123        path: &str,
124        body: &serde_json::Value,
125    ) -> Result<serde_json::Value> {
126        let response = self
127            .client
128            .post(self.endpoint(path))
129            .header("X-API-Key", &self.api_key)
130            .json(body)
131            .send()
132            .await?;
133
134        self.handle_response(response).await
135    }
136
137    pub async fn post_form(&self, path: &str, form: Form) -> Result<serde_json::Value> {
138        let response = self
139            .client
140            .post(self.endpoint(path))
141            .header("X-API-Key", &self.api_key)
142            .multipart(form)
143            .send()
144            .await?;
145
146        self.handle_response(response).await
147    }
148
149    pub async fn submit_and_poll(
150        &self,
151        path: &str,
152        form: Form,
153        progress: &Progress,
154    ) -> Result<serde_json::Value> {
155        let submit_response: SubmitResponse =
156            serde_json::from_value(self.post_form(path, form).await?)?;
157
158        if !submit_response.success {
159            let error_msg = submit_response
160                .extra
161                .get("error")
162                .and_then(|v| v.as_str())
163                .unwrap_or("Request submission failed");
164            return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
165        }
166
167        // Emit submit progress event
168        if let Some(ref request_id) = submit_response.request_id {
169            progress.submit(request_id);
170        }
171
172        let check_url = submit_response
173            .request_check_url
174            .ok_or_else(|| DatalabError::ProcessingFailed("No check URL returned".to_string()))?;
175
176        self.poll_until_complete(&check_url, progress).await
177    }
178
179    #[allow(dead_code)]
180    pub async fn submit_json_and_poll(
181        &self,
182        path: &str,
183        body: &serde_json::Value,
184        progress: &Progress,
185    ) -> Result<serde_json::Value> {
186        let submit_response: SubmitResponse =
187            serde_json::from_value(self.post_json(path, body).await?)?;
188
189        if !submit_response.success {
190            let error_msg = submit_response
191                .extra
192                .get("error")
193                .and_then(|v| v.as_str())
194                .unwrap_or("Request submission failed");
195            return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
196        }
197
198        // Emit submit progress event
199        if let Some(ref request_id) = submit_response.request_id {
200            progress.submit(request_id);
201        }
202
203        let check_url = submit_response
204            .request_check_url
205            .ok_or_else(|| DatalabError::ProcessingFailed("No check URL returned".to_string()))?;
206
207        self.poll_until_complete(&check_url, progress).await
208    }
209
210    async fn poll_until_complete(
211        &self,
212        check_url: &str,
213        progress: &Progress,
214    ) -> Result<serde_json::Value> {
215        let mut delay_ms = INITIAL_POLL_DELAY_MS;
216        let start = std::time::Instant::now();
217        let timeout = Duration::from_secs(self.timeout_secs);
218
219        loop {
220            if start.elapsed() > timeout {
221                return Err(DatalabError::Timeout {
222                    seconds: self.timeout_secs,
223                });
224            }
225
226            sleep(Duration::from_millis(delay_ms)).await;
227
228            let response = self
229                .client
230                .get(check_url)
231                .header("X-API-Key", &self.api_key)
232                .send()
233                .await?;
234
235            let poll_response: PollResponse =
236                serde_json::from_value(self.handle_response(response).await?)?;
237
238            // Emit poll progress event
239            progress.poll(&poll_response.status);
240
241            match poll_response.status.as_str() {
242                "complete" => {
243                    if poll_response.success == Some(false) {
244                        let error_msg = poll_response
245                            .data
246                            .get("error")
247                            .and_then(|v| v.as_str())
248                            .unwrap_or("Processing failed");
249                        return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
250                    }
251                    return Ok(poll_response.data);
252                }
253                "failed" => {
254                    let error_msg = poll_response
255                        .data
256                        .get("error")
257                        .and_then(|v| v.as_str())
258                        .unwrap_or("Processing failed");
259                    return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
260                }
261                _ => {
262                    delay_ms = ((delay_ms as f64) * POLL_BACKOFF_MULTIPLIER) as u64;
263                    delay_ms = delay_ms.min(MAX_POLL_DELAY_MS);
264                }
265            }
266        }
267    }
268
269    pub async fn upload_file_to_presigned_url(
270        &self,
271        upload_url: &str,
272        file_path: &PathBuf,
273        content_type: &str,
274        progress: &Progress,
275    ) -> Result<()> {
276        let file_content = tokio::fs::read(file_path).await?;
277        let total_bytes = file_content.len() as u64;
278
279        // Emit upload start progress
280        progress.upload(0, total_bytes);
281
282        let response = self
283            .client
284            .put(upload_url)
285            .header("Content-Type", content_type)
286            .body(file_content)
287            .send()
288            .await?;
289
290        // Emit upload complete progress
291        progress.upload(total_bytes, total_bytes);
292
293        if !response.status().is_success() {
294            return Err(DatalabError::ApiError {
295                status: response.status().as_u16(),
296                message: "Failed to upload file to presigned URL".to_string(),
297            });
298        }
299
300        Ok(())
301    }
302}
303
304pub fn build_form_with_file(file_path: &PathBuf) -> Result<(Form, Vec<u8>)> {
305    let file_content =
306        std::fs::read(file_path).map_err(|_| DatalabError::FileNotFound(file_path.clone()))?;
307
308    let file_name = file_path
309        .file_name()
310        .and_then(|n| n.to_str())
311        .unwrap_or("file")
312        .to_string();
313
314    let mime_type = mime_guess::from_path(file_path)
315        .first_or_octet_stream()
316        .to_string();
317
318    let part = Part::bytes(file_content.clone())
319        .file_name(file_name)
320        .mime_str(&mime_type)
321        .map_err(|e| DatalabError::InvalidInput(e.to_string()))?;
322
323    let form = Form::new().part("file", part);
324
325    Ok((form, file_content))
326}
327
328pub fn add_form_field(form: Form, name: &str, value: &str) -> Form {
329    form.text(name.to_string(), value.to_string())
330}