Skip to main content

papers_datalab/
client.rs

1use std::time::Duration;
2
3use crate::error::{DatalabError, Result};
4use crate::types::{MarkerPollResponse, MarkerRequest, MarkerStatus, MarkerSubmitResponse, StepTypesResponse};
5
/// Base URL that [`DatalabClient::new`] targets when no override is supplied.
///
/// Stored without a trailing slash because endpoint paths are appended as
/// `"{base_url}/api/v1/..."`.
const DEFAULT_BASE_URL: &str = "https://www.datalab.to";
7
8/// Async client for the DataLab Marker REST API.
9///
10/// # Authentication
11///
12/// All requests require an API key sent via the `X-API-Key` header.
13/// Create the client with [`DatalabClient::new`] or load from the
14/// `DATALAB_API_KEY` environment variable with [`DatalabClient::from_env`].
15///
16/// # Usage
17///
18/// ```no_run
19/// # async fn example() -> papers_datalab::Result<()> {
20/// use papers_datalab::{DatalabClient, MarkerRequest, OutputFormat, ProcessingMode};
21///
22/// let client = DatalabClient::from_env()?;
23/// let pdf_bytes = std::fs::read("paper.pdf").unwrap();
24///
25/// let result = client.convert_document(MarkerRequest {
26///     file: Some(pdf_bytes),
27///     filename: Some("paper.pdf".into()),
28///     output_format: vec![OutputFormat::Markdown],
29///     mode: ProcessingMode::Accurate,
30///     ..Default::default()
31/// }).await?;
32///
33/// println!("{}", result.markdown.unwrap_or_default());
34/// # Ok(())
35/// # }
36/// ```
37#[derive(Clone)]
38pub struct DatalabClient {
39    http: reqwest::Client,
40    api_key: String,
41    base_url: String,
42}
43
44impl DatalabClient {
45    /// Create a new client with an explicit API key.
46    pub fn new(api_key: impl Into<String>) -> Self {
47        Self {
48            http: reqwest::Client::new(),
49            api_key: api_key.into(),
50            base_url: DEFAULT_BASE_URL.to_string(),
51        }
52    }
53
54    /// Override the base URL. Useful for testing with a mock server.
55    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
56        self.base_url = url.into();
57        self
58    }
59
60    /// Create a client from the `DATALAB_API_KEY` environment variable.
61    ///
62    /// Returns [`DatalabError::MissingApiKey`] if the variable is not set.
63    pub fn from_env() -> Result<Self> {
64        let key = std::env::var("DATALAB_API_KEY").map_err(|_| DatalabError::MissingApiKey)?;
65        Ok(Self::new(key))
66    }
67
68    /// High-level: submit a document and poll until conversion is complete.
69    ///
70    /// Uses a 2-second poll interval. Returns the completed [`MarkerPollResponse`]
71    /// or an error if the job fails. No timeout is applied — the caller is
72    /// responsible for cancellation if needed.
73    pub async fn convert_document(&self, req: MarkerRequest) -> Result<MarkerPollResponse> {
74        let submit = self.submit_marker(req).await?;
75        let request_id = submit.request_id;
76
77        loop {
78            tokio::time::sleep(Duration::from_secs(2)).await;
79            let poll = self.get_marker_result(&request_id).await?;
80            match poll.status {
81                MarkerStatus::Complete => return Ok(poll),
82                MarkerStatus::Failed => {
83                    return Err(DatalabError::Processing(
84                        poll.error.unwrap_or_else(|| "unknown processing error".to_string()),
85                    ));
86                }
87                MarkerStatus::Processing => continue,
88            }
89        }
90    }
91
92    /// POST /api/v1/marker — submit a conversion job.
93    ///
94    /// Returns immediately with a `request_id`. Use [`get_marker_result`](Self::get_marker_result)
95    /// to poll for the result, or call [`convert_document`](Self::convert_document) to do both.
96    pub async fn submit_marker(&self, req: MarkerRequest) -> Result<MarkerSubmitResponse> {
97        // Validate: exactly one of file or file_url must be provided
98        if req.file.is_none() && req.file_url.is_none() {
99            return Err(DatalabError::InvalidRequest);
100        }
101
102        let mut form = reqwest::multipart::Form::new();
103
104        // File source
105        if let Some(bytes) = req.file {
106            let filename = req.filename.unwrap_or_else(|| "document.pdf".to_string());
107            let part = reqwest::multipart::Part::bytes(bytes)
108                .file_name(filename)
109                .mime_str("application/pdf")
110                .map_err(|e| DatalabError::Http(e))?;
111            form = form.part("file", part);
112        } else if let Some(url) = req.file_url {
113            form = form.text("file_url", url);
114        }
115
116        // Output format (serialize to comma-joined string)
117        let fmt = req.output_format.iter().map(|f| match f {
118            crate::types::OutputFormat::Markdown => "markdown",
119            crate::types::OutputFormat::Html => "html",
120            crate::types::OutputFormat::Json => "json",
121            crate::types::OutputFormat::Chunks => "chunks",
122        }).collect::<Vec<_>>().join(",");
123        form = form.text("output_format", fmt);
124
125        // Processing mode
126        let mode = match req.mode {
127            crate::types::ProcessingMode::Fast => "fast",
128            crate::types::ProcessingMode::Balanced => "balanced",
129            crate::types::ProcessingMode::Accurate => "accurate",
130        };
131        form = form.text("mode", mode);
132
133        // Optional scalar fields
134        if let Some(max_pages) = req.max_pages {
135            form = form.text("max_pages", max_pages.to_string());
136        }
137        if let Some(page_range) = req.page_range {
138            form = form.text("page_range", page_range);
139        }
140        if req.paginate {
141            form = form.text("paginate", "true");
142        }
143        if req.skip_cache {
144            form = form.text("skip_cache", "true");
145        }
146        if req.disable_image_extraction {
147            form = form.text("disable_image_extraction", "true");
148        }
149        if req.disable_image_captions {
150            form = form.text("disable_image_captions", "true");
151        }
152        if req.save_checkpoint {
153            form = form.text("save_checkpoint", "true");
154        }
155        if req.add_block_ids {
156            form = form.text("add_block_ids", "true");
157        }
158        if req.include_markdown_in_chunks {
159            form = form.text("include_markdown_in_chunks", "true");
160        }
161        if req.keep_spreadsheet_formatting {
162            form = form.text("keep_spreadsheet_formatting", "true");
163        }
164        if req.fence_synthetic_captions {
165            form = form.text("fence_synthetic_captions", "true");
166        }
167        if let Some(schema) = req.page_schema {
168            form = form.text("page_schema", schema.to_string());
169        }
170        if let Some(seg_schema) = req.segmentation_schema {
171            form = form.text("segmentation_schema", seg_schema);
172        }
173        if let Some(config) = req.additional_config {
174            form = form.text("additional_config", config.to_string());
175        }
176        if let Some(extras) = req.extras {
177            form = form.text("extras", extras);
178        }
179        if let Some(webhook) = req.webhook_url {
180            form = form.text("webhook_url", webhook);
181        }
182
183        let url = format!("{}/api/v1/marker", self.base_url);
184        let resp = self
185            .http
186            .post(&url)
187            .header("X-API-Key", &self.api_key)
188            .multipart(form)
189            .send()
190            .await?;
191
192        let status = resp.status();
193        if !status.is_success() {
194            let message = resp.text().await.unwrap_or_default();
195            return Err(DatalabError::Api {
196                status: status.as_u16(),
197                message,
198            });
199        }
200
201        let body = resp.text().await?;
202        let submit = serde_json::from_str::<MarkerSubmitResponse>(&body)
203            .map_err(|e| DatalabError::Api { status: 0, message: format!("JSON parse error: {e}") })?;
204        if submit.success == Some(false) {
205            return Err(DatalabError::Api {
206                status: 0,
207                message: "submit returned success=false".to_string(),
208            });
209        }
210        Ok(submit)
211    }
212
213    /// GET /api/v1/marker/{request_id} — poll for a single conversion result.
214    ///
215    /// Returns the current state of the job. `status` will be `processing`,
216    /// `complete`, or `failed`. Poll every 2 seconds until `complete` or `failed`.
217    pub async fn get_marker_result(&self, request_id: &str) -> Result<MarkerPollResponse> {
218        let url = format!("{}/api/v1/marker/{}", self.base_url, request_id);
219        let resp = self
220            .http
221            .get(&url)
222            .header("X-API-Key", &self.api_key)
223            .send()
224            .await?;
225
226        let status = resp.status();
227        if !status.is_success() {
228            let message = resp.text().await.unwrap_or_default();
229            return Err(DatalabError::Api {
230                status: status.as_u16(),
231                message,
232            });
233        }
234
235        let body = resp.text().await?;
236        serde_json::from_str::<MarkerPollResponse>(&body)
237            .map_err(|e| DatalabError::Api { status: 0, message: format!("JSON parse error: {e}") })
238    }
239
240    /// GET /api/v1/workflows/step-types — list available workflow step types.
241    pub async fn list_step_types(&self) -> Result<StepTypesResponse> {
242        let url = format!("{}/api/v1/workflows/step-types", self.base_url);
243        let resp = self
244            .http
245            .get(&url)
246            .header("X-API-Key", &self.api_key)
247            .send()
248            .await?;
249
250        let status = resp.status();
251        if !status.is_success() {
252            let message = resp.text().await.unwrap_or_default();
253            return Err(DatalabError::Api {
254                status: status.as_u16(),
255                message,
256            });
257        }
258
259        Ok(resp.json::<StepTypesResponse>().await?)
260    }
261}