Skip to main content

paperless_api/
client.rs

1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use reqwest::{
4    Method, StatusCode,
5    header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
6    multipart,
7};
8use serde::Deserialize;
9use tracing::{debug, trace};
10
11use crate::{
12    Error, Result, User,
13    correspondent::Correspondent,
14    custom_field::CustomField,
15    document::{Document, DocumentData},
16    document_type::DocumentType,
17    id::{CorrespondentId, CustomFieldId, DocumentId, DocumentTypeId, TagId, TaskId, UserId},
18    tag::Tag,
19    task::Task,
20    workflow::Workflow,
21};
22
/// Selects which cached metadata to refresh.
///
/// Pass one or more of these to `PaperlessClient::refresh`; each variant maps to
/// exactly one cached collection on the client.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RefreshData {
    /// The cached tag list.
    Tags,
    /// The cached custom field definitions.
    CustomFields,
    /// The cached correspondent list.
    Correspondents,
    /// The cached document type list.
    DocumentTypes,
    /// The cached user list.
    Users,
}
32
33/// Client to interact with Paperless.
34#[derive(Debug, Clone)]
35pub struct PaperlessClient {
36    client: reqwest::Client,
37    pub(crate) base_url: String,
38
39    correspondents: HashMap<CorrespondentId, Correspondent>,
40    document_types: HashMap<DocumentTypeId, DocumentType>,
41    tags: HashMap<TagId, Tag>,
42    custom_fields: HashMap<CustomFieldId, CustomField>,
43    users: HashMap<UserId, User>,
44}
45
46#[derive(Debug, Deserialize)]
47struct PaginatedResponse<T> {
48    results: Vec<T>,
49    next: Option<String>,
50}
51
52impl PaperlessClient {
53    /// Create a new Paperless client.
54    pub fn new(
55        base_url: &str,
56        token: &str,
57        headers: Option<&HashMap<String, String>>,
58    ) -> std::result::Result<Self, String> {
59        let mut headers_map = HeaderMap::new();
60
61        // Add additional headers if provided
62        if let Some(headers) = headers {
63            for (key, value) in headers {
64                headers_map.insert(
65                    HeaderName::from_str(key).map_err(|err| err.to_string())?,
66                    value
67                        .parse()
68                        .map_err(|err: InvalidHeaderValue| err.to_string())?,
69                );
70            }
71        }
72
73        // Add the Paperless token header
74        headers_map.insert(
75            HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
76            format!("Token {token}")
77                .parse()
78                .map_err(|err: InvalidHeaderValue| err.to_string())?,
79        );
80
81        Ok(Self {
82            base_url: base_url.to_string(),
83            client: reqwest::Client::builder()
84                .default_headers(headers_map)
85                .zstd(true)
86                .build()
87                .map_err(|err| err.to_string())?,
88            tags: HashMap::new(),
89            custom_fields: HashMap::new(),
90            correspondents: HashMap::new(),
91            document_types: HashMap::new(),
92            users: HashMap::new(),
93        })
94    }
95
96    async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
97        debug!("loading tags");
98        let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
99        Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
100    }
101
102    async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
103        debug!("loading custom fields");
104        let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
105        Ok(custom_fields
106            .into_iter()
107            .map(|custom_field| (custom_field.id, custom_field))
108            .collect())
109    }
110
111    async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
112        debug!("loading correspondents");
113        let correspondents: Vec<Correspondent> =
114            self.fetch_all_pages("/api/correspondents/").await?;
115        Ok(correspondents
116            .into_iter()
117            .map(|correspondent| (correspondent.id, correspondent))
118            .collect())
119    }
120
121    async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
122        debug!("loading document types");
123        let document_types: Vec<DocumentType> =
124            self.fetch_all_pages("/api/document_types/").await?;
125        Ok(document_types
126            .into_iter()
127            .map(|document_type| (document_type.id, document_type))
128            .collect())
129    }
130
131    async fn load_users(&self) -> Result<HashMap<UserId, User>> {
132        debug!("loading users");
133        let users: Vec<User> = self.fetch_all_pages("/api/users/").await?;
134        Ok(users.into_iter().map(|user| (user.id, user)).collect())
135    }
136
137    pub async fn refresh_all(&mut self) -> Result<()> {
138        self.refresh([
139            RefreshData::Tags,
140            RefreshData::CustomFields,
141            RefreshData::Correspondents,
142            RefreshData::DocumentTypes,
143            RefreshData::Users,
144        ])
145        .await
146    }
147
148    /// Refresh selected cached metadata concurrently.
149    pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
150        let mut refresh_tags = false;
151        let mut refresh_custom_fields = false;
152        let mut refresh_correspondents = false;
153        let mut refresh_document_types = false;
154        let mut refresh_users = false;
155
156        for item in data {
157            match item {
158                RefreshData::Tags => refresh_tags = true,
159                RefreshData::CustomFields => refresh_custom_fields = true,
160                RefreshData::Correspondents => refresh_correspondents = true,
161                RefreshData::DocumentTypes => refresh_document_types = true,
162                RefreshData::Users => refresh_users = true,
163            }
164        }
165
166        let (tags, custom_fields, correspondents, document_types, users) = futures_util::try_join!(
167            async {
168                if refresh_tags {
169                    Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
170                } else {
171                    Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
172                }
173            },
174            async {
175                if refresh_custom_fields {
176                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
177                        self.load_custom_fields().await?,
178                    ))
179                } else {
180                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
181                }
182            },
183            async {
184                if refresh_correspondents {
185                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
186                        self.load_correspondents().await?,
187                    ))
188                } else {
189                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
190                }
191            },
192            async {
193                if refresh_document_types {
194                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
195                        self.load_document_types().await?,
196                    ))
197                } else {
198                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
199                }
200            },
201            async {
202                if refresh_users {
203                    Ok::<Option<HashMap<UserId, User>>, Error>(Some(self.load_users().await?))
204                } else {
205                    Ok::<Option<HashMap<UserId, User>>, Error>(None)
206                }
207            },
208        )?;
209
210        if let Some(tags) = tags {
211            self.tags = tags;
212        }
213
214        if let Some(custom_fields) = custom_fields {
215            self.custom_fields = custom_fields;
216        }
217
218        if let Some(correspondents) = correspondents {
219            self.correspondents = correspondents;
220        }
221
222        if let Some(document_types) = document_types {
223            self.document_types = document_types;
224        }
225
226        if let Some(users) = users {
227            self.users = users;
228        }
229
230        Ok(())
231    }
232
233    /// Refresh tags.
234    #[inline]
235    pub async fn refresh_tags(&mut self) -> Result<()> {
236        self.refresh([RefreshData::Tags]).await
237    }
238
239    /// Refresh custom fields.
240    #[inline]
241    pub async fn refresh_custom_fields(&mut self) -> Result<()> {
242        self.refresh([RefreshData::CustomFields]).await
243    }
244
245    /// Refresh correspondents.
246    #[inline]
247    pub async fn refresh_correspondents(&mut self) -> Result<()> {
248        self.refresh([RefreshData::Correspondents]).await
249    }
250
251    /// Refresh document types.
252    #[inline]
253    pub async fn refresh_document_types(&mut self) -> Result<()> {
254        self.refresh([RefreshData::DocumentTypes]).await
255    }
256
257    /// Refresh users.
258    #[inline]
259    pub async fn refresh_users(&mut self) -> Result<()> {
260        self.refresh([RefreshData::Users]).await
261    }
262
263    /// Get all documents with any of the given tags.
264    pub async fn get_documents_by_tags(
265        &self,
266        tag_ids: &[TagId],
267        truncate_content: bool,
268    ) -> Result<Vec<Document>> {
269        let tag_id_str = tag_ids
270            .iter()
271            .map(|tag_id| tag_id.0.to_string())
272            .collect::<Vec<_>>()
273            .join(",");
274        let documents: Vec<_> = self
275            .fetch_all_pages::<DocumentData>(&format!(
276                "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
277            ))
278            .await?
279            .into_iter()
280            .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
281            .collect();
282
283        Ok(documents)
284    }
285
286    pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
287        let resp = self
288            .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
289            .await?;
290
291        let document_data: DocumentData = resp
292            .json()
293            .await
294            .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
295
296        Ok(document_data)
297    }
298
299    /// Get a document by its ID.
300    pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
301        Ok(Document::new(
302            self.get_document_data_by_id(id).await?,
303            Arc::new(self.clone()),
304            false,
305        ))
306    }
307
308    pub(crate) async fn request(
309        &self,
310        method: Method,
311        endpoint: &str,
312        body: Option<&serde_json::Value>,
313    ) -> Result<reqwest::Response> {
314        let mut req = self
315            .client
316            .request(method, format!("{}{endpoint}", self.base_url))
317            .header(ACCEPT, "application/json");
318
319        // Set payload body if provided
320        if let Some(json_body) = body {
321            req = req.json(json_body);
322        }
323
324        let resp = req
325            .send()
326            .await
327            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
328
329        if resp.status() == StatusCode::NOT_FOUND {
330            return Err(Error::NotFound);
331        }
332
333        if !resp.status().is_success() {
334            return Err(Error::Response {
335                status_code: resp.status().as_u16(),
336                body: resp.text().await.unwrap_or_default(),
337            });
338        }
339
340        Ok(resp)
341    }
342
343    pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
344        &self,
345        endpoint: &str,
346    ) -> Result<Vec<T>> {
347        let mut results = Vec::new();
348        let mut current_url = Some(endpoint.to_string());
349
350        while let Some(url) = current_url {
351            let resp = self.request(Method::GET, &url, None).await?;
352
353            let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
354                Error::InvalidJson(format!(
355                    "Failed to parse paginated response for {endpoint}: {e}"
356                ))
357            })?;
358
359            results.extend(page.results);
360
361            current_url = page.next.and_then(|next_url| {
362                // Extract just the path from the full URL
363                next_url
364                    .trim_start_matches(&self.base_url)
365                    .to_string()
366                    .into()
367            });
368        }
369
370        Ok(results)
371    }
372
373    /// Get all tasks with optional filtering by ID, name, or acknowledged status.
374    pub async fn get_task_status(
375        &self,
376        task_id: Option<&TaskId>,
377        task_name: Option<&str>,
378        acknowledged: Option<bool>,
379    ) -> Result<Vec<Task>> {
380        let mut query = Vec::new();
381
382        if let Some(id) = task_id {
383            query.push(("task_id", id.to_string()));
384        }
385
386        if let Some(name) = task_name {
387            query.push(("task_name", name.to_string()));
388        }
389
390        if let Some(ack) = acknowledged {
391            query.push(("acknowledged", ack.to_string()));
392        }
393
394        let resp = self
395            .request(
396                Method::GET,
397                &format!(
398                    "/api/tasks/?{}",
399                    serde_urlencoded::to_string(&query)
400                        .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
401                ),
402                None,
403            )
404            .await?;
405
406        trace!("get_task_status response: {:?}", resp);
407
408        let body = resp
409            .text()
410            .await
411            .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
412
413        let tasks: Vec<Task> = match serde_json::from_str(&body) {
414            Ok(t) => t,
415            Err(e) => {
416                return Err(Error::InvalidJson(format!(
417                    "Failed to parse response body: {e}"
418                )));
419            }
420        };
421
422        if tasks.is_empty() {
423            return Err(Error::NotFound);
424        }
425
426        Ok(tasks)
427    }
428
429    pub async fn get_workflows(&self) -> Result<Vec<Workflow>> {
430        self.fetch_all_pages("/api/workflows/").await
431    }
432
433    /// Upload a document to Paperless.
434    ///
435    /// Returns the task ID on success.
436    pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
437        let file_bytes = std::fs::read(file_path)
438            .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
439
440        let form = multipart::Form::new().part(
441            "document",
442            multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
443        );
444
445        let url = format!("{}/api/documents/post_document/", self.base_url);
446
447        let resp = self
448            .client
449            .post(&url)
450            .multipart(form)
451            .send()
452            .await
453            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
454
455        let status = resp.status();
456        if !resp.status().is_success() {
457            return Err(Error::Response {
458                status_code: status.as_u16(),
459                body: resp.text().await.unwrap_or_default(),
460            });
461        }
462
463        let task_id: String = resp
464            .json()
465            .await
466            .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
467        Ok(TaskId(task_id))
468    }
469
470    #[inline]
471    #[must_use]
472    pub fn tags(&self) -> &HashMap<TagId, Tag> {
473        &self.tags
474    }
475
476    #[must_use]
477    pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
478        self.tags.values().find(|tag| tag.name == name)
479    }
480
481    #[inline]
482    #[must_use]
483    pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
484        &self.document_types
485    }
486
487    #[must_use]
488    pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
489        self.document_types.values().find(|dt| dt.name == name)
490    }
491
492    #[inline]
493    #[must_use]
494    pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
495        &self.correspondents
496    }
497
498    #[inline]
499    #[must_use]
500    pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
501        &self.custom_fields
502    }
503
504    #[must_use]
505    pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
506        self.custom_fields.values().find(|field| field.name == name)
507    }
508
509    #[inline]
510    #[must_use]
511    pub fn users(&self) -> &HashMap<UserId, User> {
512        &self.users
513    }
514}