Skip to main content

paperless_api/
client.rs

1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use reqwest::{
4    Method, StatusCode,
5    header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
6    multipart,
7};
8use serde::Deserialize;
9use tracing::{debug, trace};
10
11use crate::{
12    Error, Result,
13    correspondent::{Correspondent, CorrespondentId},
14    custom_field::{CustomField, CustomFieldId},
15    document::{Document, DocumentData, DocumentId},
16    document_type::{DocumentType, DocumentTypeId},
17    tag::{Tag, TagId},
18    task::{Task, TaskId},
19};
20
/// Selects which cached metadata to refresh.
///
/// Passed (possibly several at once) to `PaperlessClient::refresh`; every selected kind is
/// re-fetched from the server and replaces the corresponding local cache.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RefreshData {
    /// The tag cache.
    Tags,
    /// The custom field cache.
    CustomFields,
    /// The correspondent cache.
    Correspondents,
    /// The document type cache.
    DocumentTypes,
}

impl RefreshData {
    /// Every variant, for refreshing all cached metadata in one call.
    ///
    /// Keep this in sync when adding a variant.
    pub const ALL: [Self; 4] = [
        Self::Tags,
        Self::CustomFields,
        Self::Correspondents,
        Self::DocumentTypes,
    ];
}
29
30/// Client to interact with Paperless.
31#[derive(Debug, Clone)]
32pub struct PaperlessClient {
33    client: reqwest::Client,
34    base_url: String,
35
36    correspondents: HashMap<CorrespondentId, Correspondent>,
37    document_types: HashMap<DocumentTypeId, DocumentType>,
38    tags: HashMap<TagId, Tag>,
39    custom_fields: HashMap<CustomFieldId, CustomField>,
40}
41
42#[derive(Debug, Deserialize)]
43struct PaginatedResponse<T> {
44    results: Vec<T>,
45    next: Option<String>,
46}
47
48impl PaperlessClient {
49    /// Create a new Paperless client.
50    pub fn new(
51        base_url: &str,
52        token: &str,
53        headers: Option<&HashMap<String, String>>,
54    ) -> std::result::Result<Self, String> {
55        let mut headers_map = HeaderMap::new();
56
57        // Add additional headers if provided
58        if let Some(headers) = headers {
59            for (key, value) in headers {
60                headers_map.insert(
61                    HeaderName::from_str(key).map_err(|err| err.to_string())?,
62                    value
63                        .parse()
64                        .map_err(|err: InvalidHeaderValue| err.to_string())?,
65                );
66            }
67        }
68
69        // Add the Paperless token header
70        headers_map.insert(
71            HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
72            format!("Token {token}")
73                .parse()
74                .map_err(|err: InvalidHeaderValue| err.to_string())?,
75        );
76
77        Ok(Self {
78            base_url: base_url.to_string(),
79            client: reqwest::Client::builder()
80                .default_headers(headers_map)
81                .build()
82                .map_err(|err| err.to_string())?,
83            tags: HashMap::new(),
84            custom_fields: HashMap::new(),
85            correspondents: HashMap::new(),
86            document_types: HashMap::new(),
87        })
88    }
89
90    async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
91        debug!("loading tags");
92        let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
93        Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
94    }
95
96    async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
97        debug!("loading custom fields");
98        let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
99        Ok(custom_fields
100            .into_iter()
101            .map(|custom_field| (custom_field.id, custom_field))
102            .collect())
103    }
104
105    async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
106        debug!("loading correspondents");
107        let correspondents: Vec<Correspondent> =
108            self.fetch_all_pages("/api/correspondents/").await?;
109        Ok(correspondents
110            .into_iter()
111            .map(|correspondent| (correspondent.id, correspondent))
112            .collect())
113    }
114
115    async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
116        debug!("loading document types");
117        let document_types: Vec<DocumentType> =
118            self.fetch_all_pages("/api/document_types/").await?;
119        Ok(document_types
120            .into_iter()
121            .map(|document_type| (document_type.id, document_type))
122            .collect())
123    }
124
125    /// Refresh selected cached metadata concurrently.
126    pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
127        let mut refresh_tags = false;
128        let mut refresh_custom_fields = false;
129        let mut refresh_correspondents = false;
130        let mut refresh_document_types = false;
131
132        for item in data {
133            match item {
134                RefreshData::Tags => refresh_tags = true,
135                RefreshData::CustomFields => refresh_custom_fields = true,
136                RefreshData::Correspondents => refresh_correspondents = true,
137                RefreshData::DocumentTypes => refresh_document_types = true,
138            }
139        }
140
141        let (tags, custom_fields, correspondents, document_types) = futures_util::try_join!(
142            async {
143                if refresh_tags {
144                    Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
145                } else {
146                    Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
147                }
148            },
149            async {
150                if refresh_custom_fields {
151                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
152                        self.load_custom_fields().await?,
153                    ))
154                } else {
155                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
156                }
157            },
158            async {
159                if refresh_correspondents {
160                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
161                        self.load_correspondents().await?,
162                    ))
163                } else {
164                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
165                }
166            },
167            async {
168                if refresh_document_types {
169                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
170                        self.load_document_types().await?,
171                    ))
172                } else {
173                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
174                }
175            },
176        )?;
177
178        if let Some(tags) = tags {
179            self.tags = tags;
180        }
181
182        if let Some(custom_fields) = custom_fields {
183            self.custom_fields = custom_fields;
184        }
185
186        if let Some(correspondents) = correspondents {
187            self.correspondents = correspondents;
188        }
189
190        if let Some(document_types) = document_types {
191            self.document_types = document_types;
192        }
193
194        Ok(())
195    }
196
197    /// List all tags.
198    #[inline]
199    pub async fn refresh_tags(&mut self) -> Result<()> {
200        self.refresh([RefreshData::Tags]).await
201    }
202
203    /// List all custom fields.
204    #[inline]
205    pub async fn refresh_custom_fields(&mut self) -> Result<()> {
206        self.refresh([RefreshData::CustomFields]).await
207    }
208
209    /// List all correspondents.
210    #[inline]
211    pub async fn refresh_correspondents(&mut self) -> Result<()> {
212        self.refresh([RefreshData::Correspondents]).await
213    }
214
215    /// List all document types.
216    #[inline]
217    pub async fn refresh_document_types(&mut self) -> Result<()> {
218        self.refresh([RefreshData::DocumentTypes]).await
219    }
220
221    /// Get all documents with any of the given tags.
222    pub async fn get_documents_by_tags(
223        &self,
224        tag_ids: &[TagId],
225        truncate_content: bool,
226    ) -> Result<Vec<Document>> {
227        let tag_id_str = tag_ids
228            .iter()
229            .map(|tag_id| tag_id.0.to_string())
230            .collect::<Vec<_>>()
231            .join(",");
232        let documents: Vec<_> = self
233            .fetch_all_pages::<DocumentData>(&format!(
234                "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
235            ))
236            .await?
237            .into_iter()
238            .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
239            .collect();
240
241        Ok(documents)
242    }
243
244    pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
245        let resp = self
246            .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
247            .await?;
248
249        let document_data: DocumentData = resp
250            .json()
251            .await
252            .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
253
254        Ok(document_data)
255    }
256
257    /// Get a document by its ID.
258    pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
259        Ok(Document::new(
260            self.get_document_data_by_id(id).await?,
261            Arc::new(self.clone()),
262            false,
263        ))
264    }
265
266    pub(crate) async fn request(
267        &self,
268        method: Method,
269        endpoint: &str,
270        body: Option<&serde_json::Value>,
271    ) -> Result<reqwest::Response> {
272        let mut req = self
273            .client
274            .request(method, format!("{}{endpoint}", self.base_url))
275            .header(ACCEPT, "application/json");
276
277        if let Some(json_body) = body {
278            req = req.json(json_body);
279        }
280
281        let resp = req
282            .send()
283            .await
284            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
285
286        if resp.status() == StatusCode::NOT_FOUND {
287            return Err(Error::NotFound);
288        }
289
290        if !resp.status().is_success() {
291            return Err(Error::Response {
292                status_code: resp.status().as_u16(),
293                body: resp.text().await.unwrap_or_default(),
294            });
295        }
296
297        Ok(resp)
298    }
299
300    pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
301        &self,
302        endpoint: &str,
303    ) -> Result<Vec<T>> {
304        let mut results = Vec::new();
305        let mut current_url = Some(endpoint.to_string());
306
307        while let Some(url) = current_url {
308            let resp = self.request(Method::GET, &url, None).await?;
309
310            let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
311                Error::Other(format!(
312                    "Failed to parse paginated response for {endpoint}: {e}"
313                ))
314            })?;
315
316            results.extend(page.results);
317
318            current_url = page.next.and_then(|next_url| {
319                // Extract just the path from the full URL
320                next_url
321                    .trim_start_matches(&self.base_url)
322                    .to_string()
323                    .into()
324            });
325        }
326
327        Ok(results)
328    }
329
330    /// Get all tasks with optional filtering by ID, name, or acknowledged status.
331    pub async fn get_task_status(
332        &self,
333        task_id: Option<&TaskId>,
334        task_name: Option<&str>,
335        acknowledged: Option<bool>,
336    ) -> Result<Vec<Task>> {
337        let mut query = Vec::new();
338
339        if let Some(id) = task_id {
340            query.push(("task_id", id.to_string()));
341        }
342
343        if let Some(name) = task_name {
344            query.push(("task_name", name.to_string()));
345        }
346
347        if let Some(ack) = acknowledged {
348            query.push(("acknowledged", ack.to_string()));
349        }
350
351        let resp = self
352            .request(
353                Method::GET,
354                &format!(
355                    "/api/tasks/?{}",
356                    serde_urlencoded::to_string(&query)
357                        .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
358                ),
359                None,
360            )
361            .await?;
362
363        trace!("get_task_status response: {:?}", resp);
364
365        let body = resp
366            .text()
367            .await
368            .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
369
370        let tasks: Vec<Task> = match serde_json::from_str(&body) {
371            Ok(t) => t,
372            Err(e) => {
373                return Err(Error::InvalidJson(format!(
374                    "Failed to parse response body: {e}"
375                )));
376            }
377        };
378
379        if tasks.is_empty() {
380            return Err(Error::NotFound);
381        }
382
383        Ok(tasks)
384    }
385
386    /// Upload a document to Paperless.
387    ///
388    /// Returns the task ID on success.
389    pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
390        let file_bytes = std::fs::read(file_path)
391            .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
392
393        let form = multipart::Form::new().part(
394            "document",
395            multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
396        );
397
398        let url = format!("{}/api/documents/post_document/", self.base_url);
399
400        let resp = self
401            .client
402            .post(&url)
403            .multipart(form)
404            .send()
405            .await
406            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
407
408        let status = resp.status();
409        if !resp.status().is_success() {
410            return Err(Error::Response {
411                status_code: status.as_u16(),
412                body: resp.text().await.unwrap_or_default(),
413            });
414        }
415
416        let task_id: String = resp
417            .json()
418            .await
419            .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
420        Ok(TaskId(task_id))
421    }
422
423    #[inline]
424    #[must_use]
425    pub fn tags(&self) -> &HashMap<TagId, Tag> {
426        &self.tags
427    }
428
429    #[must_use]
430    pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
431        self.tags.values().find(|tag| tag.name == name)
432    }
433
434    #[inline]
435    #[must_use]
436    pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
437        &self.document_types
438    }
439
440    #[must_use]
441    pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
442        self.document_types.values().find(|dt| dt.name == name)
443    }
444
445    #[inline]
446    #[must_use]
447    pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
448        &self.correspondents
449    }
450
451    #[inline]
452    #[must_use]
453    pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
454        &self.custom_fields
455    }
456
457    #[must_use]
458    pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
459        self.custom_fields.values().find(|field| field.name == name)
460    }
461}