1use std::time::Duration;
2
3use crate::error::{DatalabError, Result};
4use crate::types::{MarkerPollResponse, MarkerRequest, MarkerStatus, MarkerSubmitResponse, StepTypesResponse};
5
6const DEFAULT_BASE_URL: &str = "https://www.datalab.to";
7
8#[derive(Clone)]
38pub struct DatalabClient {
39 http: reqwest::Client,
40 api_key: String,
41 base_url: String,
42}
43
44impl DatalabClient {
45 pub fn new(api_key: impl Into<String>) -> Self {
47 Self {
48 http: reqwest::Client::new(),
49 api_key: api_key.into(),
50 base_url: DEFAULT_BASE_URL.to_string(),
51 }
52 }
53
54 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
56 self.base_url = url.into();
57 self
58 }
59
60 pub fn from_env() -> Result<Self> {
64 let key = std::env::var("DATALAB_API_KEY").map_err(|_| DatalabError::MissingApiKey)?;
65 Ok(Self::new(key))
66 }
67
68 pub async fn convert_document(&self, req: MarkerRequest) -> Result<MarkerPollResponse> {
74 let submit = self.submit_marker(req).await?;
75 let request_id = submit.request_id;
76
77 loop {
78 tokio::time::sleep(Duration::from_secs(2)).await;
79 let poll = self.get_marker_result(&request_id).await?;
80 match poll.status {
81 MarkerStatus::Complete => return Ok(poll),
82 MarkerStatus::Failed => {
83 return Err(DatalabError::Processing(
84 poll.error.unwrap_or_else(|| "unknown processing error".to_string()),
85 ));
86 }
87 MarkerStatus::Processing => continue,
88 }
89 }
90 }
91
92 pub async fn submit_marker(&self, req: MarkerRequest) -> Result<MarkerSubmitResponse> {
97 if req.file.is_none() && req.file_url.is_none() {
99 return Err(DatalabError::InvalidRequest);
100 }
101
102 let mut form = reqwest::multipart::Form::new();
103
104 if let Some(bytes) = req.file {
106 let filename = req.filename.unwrap_or_else(|| "document.pdf".to_string());
107 let part = reqwest::multipart::Part::bytes(bytes)
108 .file_name(filename)
109 .mime_str("application/pdf")
110 .map_err(|e| DatalabError::Http(e))?;
111 form = form.part("file", part);
112 } else if let Some(url) = req.file_url {
113 form = form.text("file_url", url);
114 }
115
116 let fmt = req.output_format.iter().map(|f| match f {
118 crate::types::OutputFormat::Markdown => "markdown",
119 crate::types::OutputFormat::Html => "html",
120 crate::types::OutputFormat::Json => "json",
121 crate::types::OutputFormat::Chunks => "chunks",
122 }).collect::<Vec<_>>().join(",");
123 form = form.text("output_format", fmt);
124
125 let mode = match req.mode {
127 crate::types::ProcessingMode::Fast => "fast",
128 crate::types::ProcessingMode::Balanced => "balanced",
129 crate::types::ProcessingMode::Accurate => "accurate",
130 };
131 form = form.text("mode", mode);
132
133 if let Some(max_pages) = req.max_pages {
135 form = form.text("max_pages", max_pages.to_string());
136 }
137 if let Some(page_range) = req.page_range {
138 form = form.text("page_range", page_range);
139 }
140 if req.paginate {
141 form = form.text("paginate", "true");
142 }
143 if req.skip_cache {
144 form = form.text("skip_cache", "true");
145 }
146 if req.disable_image_extraction {
147 form = form.text("disable_image_extraction", "true");
148 }
149 if req.disable_image_captions {
150 form = form.text("disable_image_captions", "true");
151 }
152 if req.save_checkpoint {
153 form = form.text("save_checkpoint", "true");
154 }
155 if req.add_block_ids {
156 form = form.text("add_block_ids", "true");
157 }
158 if req.include_markdown_in_chunks {
159 form = form.text("include_markdown_in_chunks", "true");
160 }
161 if req.keep_spreadsheet_formatting {
162 form = form.text("keep_spreadsheet_formatting", "true");
163 }
164 if req.fence_synthetic_captions {
165 form = form.text("fence_synthetic_captions", "true");
166 }
167 if let Some(schema) = req.page_schema {
168 form = form.text("page_schema", schema.to_string());
169 }
170 if let Some(seg_schema) = req.segmentation_schema {
171 form = form.text("segmentation_schema", seg_schema);
172 }
173 if let Some(config) = req.additional_config {
174 form = form.text("additional_config", config.to_string());
175 }
176 if let Some(extras) = req.extras {
177 form = form.text("extras", extras);
178 }
179 if let Some(webhook) = req.webhook_url {
180 form = form.text("webhook_url", webhook);
181 }
182
183 let url = format!("{}/api/v1/marker", self.base_url);
184 let resp = self
185 .http
186 .post(&url)
187 .header("X-API-Key", &self.api_key)
188 .multipart(form)
189 .send()
190 .await?;
191
192 let status = resp.status();
193 if !status.is_success() {
194 let message = resp.text().await.unwrap_or_default();
195 return Err(DatalabError::Api {
196 status: status.as_u16(),
197 message,
198 });
199 }
200
201 let body = resp.text().await?;
202 let submit = serde_json::from_str::<MarkerSubmitResponse>(&body)
203 .map_err(|e| DatalabError::Api { status: 0, message: format!("JSON parse error: {e}") })?;
204 if submit.success == Some(false) {
205 return Err(DatalabError::Api {
206 status: 0,
207 message: "submit returned success=false".to_string(),
208 });
209 }
210 Ok(submit)
211 }
212
213 pub async fn get_marker_result(&self, request_id: &str) -> Result<MarkerPollResponse> {
218 let url = format!("{}/api/v1/marker/{}", self.base_url, request_id);
219 let resp = self
220 .http
221 .get(&url)
222 .header("X-API-Key", &self.api_key)
223 .send()
224 .await?;
225
226 let status = resp.status();
227 if !status.is_success() {
228 let message = resp.text().await.unwrap_or_default();
229 return Err(DatalabError::Api {
230 status: status.as_u16(),
231 message,
232 });
233 }
234
235 let body = resp.text().await?;
236 serde_json::from_str::<MarkerPollResponse>(&body)
237 .map_err(|e| DatalabError::Api { status: 0, message: format!("JSON parse error: {e}") })
238 }
239
240 pub async fn list_step_types(&self) -> Result<StepTypesResponse> {
242 let url = format!("{}/api/v1/workflows/step-types", self.base_url);
243 let resp = self
244 .http
245 .get(&url)
246 .header("X-API-Key", &self.api_key)
247 .send()
248 .await?;
249
250 let status = resp.status();
251 if !status.is_success() {
252 let message = resp.text().await.unwrap_or_default();
253 return Err(DatalabError::Api {
254 status: status.as_u16(),
255 message,
256 });
257 }
258
259 Ok(resp.json::<StepTypesResponse>().await?)
260 }
261}