use crate::error::{DatalabError, Result};
use crate::output::Progress;
use reqwest::multipart::{Form, Part};
use reqwest::{Client, Response};
use serde::Deserialize;
use std::path::PathBuf;
use std::time::Duration;
use tokio::time::sleep;
/// API root used when the `DATALAB_BASE_URL` environment variable is not set.
const DEFAULT_BASE_URL: &str = "https://www.datalab.to/api/v1";
/// Default timeout in seconds. It is used both as the per-request HTTP timeout and as the
/// overall deadline while polling for a result.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
/// Delay before the first status poll, in milliseconds.
const INITIAL_POLL_DELAY_MS: u64 = 500;
/// Upper bound on the delay between status polls, in milliseconds.
const MAX_POLL_DELAY_MS: u64 = 5000;
/// Growth factor applied to the poll delay between successive status polls.
const POLL_BACKOFF_MULTIPLIER: f64 = 1.5;
/// Acknowledgement returned by the API when a processing request is submitted.
#[derive(Debug, Deserialize)]
pub struct SubmitResponse {
/// `false` means the server rejected the submission; the `error` entry of `extra`
/// (if it is a string) is then used as the failure message.
pub success: bool,
/// Server-assigned request id; reported to progress output when present.
pub request_id: Option<String>,
/// URL to poll for the request's status. Submission is treated as failed without it.
pub request_check_url: Option<String>,
/// Every other top-level field of the response (for example `error`), collected by
/// serde's `flatten`; it does not contain the three named fields above.
#[serde(flatten)]
pub extra: serde_json::Value,
}
/// One status-check response for a submitted request.
#[derive(Debug, Deserialize)]
pub struct PollResponse {
/// Processing state. `"complete"` and `"failed"` end polling; any other value means
/// the request is still in progress.
pub status: String,
/// Outcome flag; `Some(false)` on a `"complete"` status is treated as a failure.
pub success: Option<bool>,
/// Every top-level field other than `status` and `success` (serde's `flatten` hands
/// those to the named fields). This is the value returned once polling completes, and
/// its `error` entry supplies the failure message.
#[serde(flatten)]
pub data: serde_json::Value,
}
/// Authenticated client for the Datalab HTTP API.
pub struct DatalabClient {
/// Shared HTTP client, built with a per-request timeout of `timeout_secs`.
client: Client,
/// Value of `DATALAB_API_KEY`, sent as the `X-API-Key` header on API calls.
api_key: String,
/// API root that relative endpoint paths are joined onto.
base_url: String,
/// Timeout in seconds: the HTTP request timeout and the overall polling deadline.
timeout_secs: u64,
}
impl DatalabClient {
pub fn new(timeout_secs: Option<u64>) -> Result<Self> {
let api_key = std::env::var("DATALAB_API_KEY").map_err(|_| DatalabError::MissingApiKey)?;
let base_url =
std::env::var("DATALAB_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
let client = Client::builder()
.timeout(Duration::from_secs(
timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS),
))
.build()?;
Ok(Self {
client,
api_key,
base_url,
timeout_secs: timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS),
})
}
fn endpoint(&self, path: &str) -> String {
format!(
"{}/{}",
self.base_url.trim_end_matches('/'),
path.trim_start_matches('/')
)
}
async fn handle_response(&self, response: Response) -> Result<serde_json::Value> {
let status = response.status();
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse().ok());
return Err(DatalabError::RateLimited { retry_after });
}
let body: serde_json::Value = response.json().await?;
if !status.is_success() {
let message = body
.get("error")
.or_else(|| body.get("message"))
.and_then(|v| v.as_str())
.unwrap_or("Unknown error")
.to_string();
return Err(DatalabError::ApiError {
status: status.as_u16(),
message,
});
}
Ok(body)
}
pub async fn get(&self, path: &str) -> Result<serde_json::Value> {
let response = self
.client
.get(self.endpoint(path))
.header("X-API-Key", &self.api_key)
.send()
.await?;
self.handle_response(response).await
}
pub async fn delete(&self, path: &str) -> Result<serde_json::Value> {
let response = self
.client
.delete(self.endpoint(path))
.header("X-API-Key", &self.api_key)
.send()
.await?;
self.handle_response(response).await
}
pub async fn post_json(
&self,
path: &str,
body: &serde_json::Value,
) -> Result<serde_json::Value> {
let response = self
.client
.post(self.endpoint(path))
.header("X-API-Key", &self.api_key)
.json(body)
.send()
.await?;
self.handle_response(response).await
}
pub async fn post_form(&self, path: &str, form: Form) -> Result<serde_json::Value> {
let response = self
.client
.post(self.endpoint(path))
.header("X-API-Key", &self.api_key)
.multipart(form)
.send()
.await?;
self.handle_response(response).await
}
pub async fn submit_and_poll(
&self,
path: &str,
form: Form,
progress: &Progress,
) -> Result<serde_json::Value> {
let submit_response: SubmitResponse =
serde_json::from_value(self.post_form(path, form).await?)?;
if !submit_response.success {
let error_msg = submit_response
.extra
.get("error")
.and_then(|v| v.as_str())
.unwrap_or("Request submission failed");
return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
}
if let Some(ref request_id) = submit_response.request_id {
progress.submit(request_id);
}
let check_url = submit_response
.request_check_url
.ok_or_else(|| DatalabError::ProcessingFailed("No check URL returned".to_string()))?;
self.poll_until_complete(&check_url, progress).await
}
#[allow(dead_code)]
pub async fn submit_json_and_poll(
&self,
path: &str,
body: &serde_json::Value,
progress: &Progress,
) -> Result<serde_json::Value> {
let submit_response: SubmitResponse =
serde_json::from_value(self.post_json(path, body).await?)?;
if !submit_response.success {
let error_msg = submit_response
.extra
.get("error")
.and_then(|v| v.as_str())
.unwrap_or("Request submission failed");
return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
}
if let Some(ref request_id) = submit_response.request_id {
progress.submit(request_id);
}
let check_url = submit_response
.request_check_url
.ok_or_else(|| DatalabError::ProcessingFailed("No check URL returned".to_string()))?;
self.poll_until_complete(&check_url, progress).await
}
async fn poll_until_complete(
&self,
check_url: &str,
progress: &Progress,
) -> Result<serde_json::Value> {
let mut delay_ms = INITIAL_POLL_DELAY_MS;
let start = std::time::Instant::now();
let timeout = Duration::from_secs(self.timeout_secs);
loop {
if start.elapsed() > timeout {
return Err(DatalabError::Timeout {
seconds: self.timeout_secs,
});
}
sleep(Duration::from_millis(delay_ms)).await;
let response = self
.client
.get(check_url)
.header("X-API-Key", &self.api_key)
.send()
.await?;
let poll_response: PollResponse =
serde_json::from_value(self.handle_response(response).await?)?;
progress.poll(&poll_response.status);
match poll_response.status.as_str() {
"complete" => {
if poll_response.success == Some(false) {
let error_msg = poll_response
.data
.get("error")
.and_then(|v| v.as_str())
.unwrap_or("Processing failed");
return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
}
return Ok(poll_response.data);
}
"failed" => {
let error_msg = poll_response
.data
.get("error")
.and_then(|v| v.as_str())
.unwrap_or("Processing failed");
return Err(DatalabError::ProcessingFailed(error_msg.to_string()));
}
_ => {
delay_ms = ((delay_ms as f64) * POLL_BACKOFF_MULTIPLIER) as u64;
delay_ms = delay_ms.min(MAX_POLL_DELAY_MS);
}
}
}
}
pub async fn upload_file_to_presigned_url(
&self,
upload_url: &str,
file_path: &PathBuf,
content_type: &str,
progress: &Progress,
) -> Result<()> {
let file_content = tokio::fs::read(file_path).await?;
let total_bytes = file_content.len() as u64;
progress.upload(0, total_bytes);
let response = self
.client
.put(upload_url)
.header("Content-Type", content_type)
.body(file_content)
.send()
.await?;
progress.upload(total_bytes, total_bytes);
if !response.status().is_success() {
return Err(DatalabError::ApiError {
status: response.status().as_u16(),
message: "Failed to upload file to presigned URL".to_string(),
});
}
Ok(())
}
}
/// Reads `file_path` and wraps its bytes in a multipart [`Form`] under the field `file`.
///
/// The part's file name is the path's last component, converted lossily to UTF-8, or
/// `"file"` when the path has none. Its MIME type is guessed from the extension and falls
/// back to `application/octet-stream`. The raw file bytes are returned alongside the form.
///
/// # Errors
/// - [`DatalabError::FileNotFound`] when the file does not exist.
/// - [`DatalabError::InvalidInput`] when the file exists but cannot be read (e.g.
///   permission denied, or the path is a directory), or when the MIME type is rejected.
pub fn build_form_with_file(file_path: &PathBuf) -> Result<(Form, Vec<u8>)> {
    let file_content = std::fs::read(file_path).map_err(|err| {
        // Only a missing file is "not found"; other failures keep their real cause.
        if err.kind() == std::io::ErrorKind::NotFound {
            DatalabError::FileNotFound(file_path.clone())
        } else {
            DatalabError::InvalidInput(format!("cannot read {}: {}", file_path.display(), err))
        }
    })?;
    let file_name = file_path
        .file_name()
        .map(|name| name.to_string_lossy().into_owned())
        .unwrap_or_else(|| "file".to_string());
    let mime_type = mime_guess::from_path(file_path)
        .first_or_octet_stream()
        .to_string();
    // `Part::bytes` needs owned data, and the caller also gets the bytes back, so one
    // copy is unavoidable here.
    let part = Part::bytes(file_content.clone())
        .file_name(file_name)
        .mime_str(&mime_type)
        .map_err(|e| DatalabError::InvalidInput(e.to_string()))?;
    let form = Form::new().part("file", part);
    Ok((form, file_content))
}
/// Appends the text field `name = value` to `form` and returns the extended form.
pub fn add_form_field(form: Form, name: &str, value: &str) -> Form {
    let (field_name, field_value) = (String::from(name), String::from(value));
    form.text(field_name, field_value)
}