// paperless_api/client.rs

1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use enum_iterator::Sequence;
4use reqwest::{
5    Method, StatusCode,
6    header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
7    multipart,
8};
9use serde::Deserialize;
10use tracing::{debug, trace};
11
12use crate::{
13    Error, Result, User,
14    correspondent::Correspondent,
15    custom_field::CustomField,
16    document::{Document, DocumentData},
17    document_type::DocumentType,
18    id::{
19        CorrespondentId, CustomFieldId, DocumentId, DocumentTypeId, StoragePathId, TagId, TaskId,
20        UserId,
21    },
22    storage_path::StoragePath,
23    tag::Tag,
24    task::Task,
25    workflow::Workflow,
26};
27
28/// Selects which cached metadata to refresh.
29#[derive(Copy, Clone, Debug, PartialEq, Eq, Sequence)]
30#[non_exhaustive]
31pub enum RefreshData {
32    Tags,
33    CustomFields,
34    Correspondents,
35    DocumentTypes,
36    Users,
37    StoragePaths,
38}
39
40/// Client to interact with Paperless.
41#[derive(Debug, Clone)]
42pub struct PaperlessClient {
43    client: reqwest::Client,
44    pub(crate) base_url: Box<str>,
45    cached_data: Arc<CachedData>,
46}
47
48#[derive(Debug, Clone)]
49struct CachedData {
50    correspondents: HashMap<CorrespondentId, Correspondent>,
51    document_types: HashMap<DocumentTypeId, DocumentType>,
52    tags: HashMap<TagId, Tag>,
53    custom_fields: HashMap<CustomFieldId, CustomField>,
54    users: HashMap<UserId, User>,
55    storage_paths: HashMap<StoragePathId, StoragePath>,
56}
57
58#[derive(Debug, Deserialize)]
59struct PaginatedResponse<T> {
60    results: Vec<T>,
61    next: Option<String>,
62}
63
64impl PaperlessClient {
65    /// Create a new Paperless client.
66    pub fn new(
67        base_url: &str,
68        token: &str,
69        headers: Option<&HashMap<String, String>>,
70    ) -> std::result::Result<Self, String> {
71        let mut headers_map = HeaderMap::new();
72
73        // Add additional headers if provided
74        if let Some(headers) = headers {
75            for (key, value) in headers {
76                headers_map.insert(
77                    HeaderName::from_str(key).map_err(|err| err.to_string())?,
78                    value
79                        .parse()
80                        .map_err(|err: InvalidHeaderValue| err.to_string())?,
81                );
82            }
83        }
84
85        // Add the Paperless token header
86        headers_map.insert(
87            HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
88            format!("Token {token}")
89                .parse()
90                .map_err(|err: InvalidHeaderValue| err.to_string())?,
91        );
92
93        Ok(Self {
94            base_url: base_url.into(),
95            client: reqwest::Client::builder()
96                .default_headers(headers_map)
97                .zstd(true)
98                .build()
99                .map_err(|err| err.to_string())?,
100            cached_data: Arc::new(CachedData {
101                tags: HashMap::new(),
102                custom_fields: HashMap::new(),
103                correspondents: HashMap::new(),
104                document_types: HashMap::new(),
105                users: HashMap::new(),
106                storage_paths: HashMap::new(),
107            }),
108        })
109    }
110
111    async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
112        debug!("loading tags");
113        let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
114        Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
115    }
116
117    async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
118        debug!("loading custom fields");
119        let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
120        Ok(custom_fields
121            .into_iter()
122            .map(|custom_field| (custom_field.id, custom_field))
123            .collect())
124    }
125
126    async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
127        debug!("loading correspondents");
128        let correspondents: Vec<Correspondent> =
129            self.fetch_all_pages("/api/correspondents/").await?;
130        Ok(correspondents
131            .into_iter()
132            .map(|correspondent| (correspondent.id, correspondent))
133            .collect())
134    }
135
136    async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
137        debug!("loading document types");
138        let document_types: Vec<DocumentType> =
139            self.fetch_all_pages("/api/document_types/").await?;
140        Ok(document_types
141            .into_iter()
142            .map(|document_type| (document_type.id, document_type))
143            .collect())
144    }
145
146    async fn load_users(&self) -> Result<HashMap<UserId, User>> {
147        debug!("loading users");
148        let users: Vec<User> = self.fetch_all_pages("/api/users/").await?;
149        Ok(users.into_iter().map(|user| (user.id, user)).collect())
150    }
151
152    async fn load_storage_paths(&self) -> Result<HashMap<StoragePathId, StoragePath>> {
153        debug!("loading storage paths");
154        let storage_paths: Vec<StoragePath> = self.fetch_all_pages("/api/storage_paths/").await?;
155        Ok(storage_paths
156            .into_iter()
157            .map(|storage_path| (storage_path.id, storage_path))
158            .collect())
159    }
160
161    pub async fn refresh_all(&mut self) -> Result<()> {
162        self.refresh(enum_iterator::all::<RefreshData>()).await
163    }
164
165    /// Refresh selected cached metadata concurrently.
166    pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
167        let mut refresh_tags = false;
168        let mut refresh_custom_fields = false;
169        let mut refresh_correspondents = false;
170        let mut refresh_document_types = false;
171        let mut refresh_users = false;
172        let mut refresh_storage_paths = false;
173
174        for item in data {
175            match item {
176                RefreshData::Tags => refresh_tags = true,
177                RefreshData::CustomFields => refresh_custom_fields = true,
178                RefreshData::Correspondents => refresh_correspondents = true,
179                RefreshData::DocumentTypes => refresh_document_types = true,
180                RefreshData::Users => refresh_users = true,
181                RefreshData::StoragePaths => refresh_storage_paths = true,
182            }
183        }
184
185        let (tags, custom_fields, correspondents, document_types, users, storage_paths) = futures_util::try_join!(
186            async {
187                if refresh_tags {
188                    Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
189                } else {
190                    Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
191                }
192            },
193            async {
194                if refresh_custom_fields {
195                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
196                        self.load_custom_fields().await?,
197                    ))
198                } else {
199                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
200                }
201            },
202            async {
203                if refresh_correspondents {
204                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
205                        self.load_correspondents().await?,
206                    ))
207                } else {
208                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
209                }
210            },
211            async {
212                if refresh_document_types {
213                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
214                        self.load_document_types().await?,
215                    ))
216                } else {
217                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
218                }
219            },
220            async {
221                if refresh_users {
222                    Ok::<Option<HashMap<UserId, User>>, Error>(Some(self.load_users().await?))
223                } else {
224                    Ok::<Option<HashMap<UserId, User>>, Error>(None)
225                }
226            },
227            async {
228                if refresh_storage_paths {
229                    Ok::<Option<HashMap<StoragePathId, StoragePath>>, Error>(Some(
230                        self.load_storage_paths().await?,
231                    ))
232                } else {
233                    Ok::<Option<HashMap<StoragePathId, StoragePath>>, Error>(None)
234                }
235            },
236        )?;
237
238        // Try to get a mutable reference to the cached data and update it
239        // If the cache is still referenced only this client will see the changes
240        let cached_data = Arc::make_mut(&mut self.cached_data);
241
242        if let Some(correspondents) = correspondents {
243            cached_data.correspondents = correspondents;
244        }
245        if let Some(document_types) = document_types {
246            cached_data.document_types = document_types;
247        }
248        if let Some(tags) = tags {
249            cached_data.tags = tags;
250        }
251        if let Some(custom_fields) = custom_fields {
252            cached_data.custom_fields = custom_fields;
253        }
254        if let Some(users) = users {
255            cached_data.users = users;
256        }
257        if let Some(storage_paths) = storage_paths {
258            cached_data.storage_paths = storage_paths;
259        }
260
261        Ok(())
262    }
263
264    /// Get all documents with any of the given tags.
265    pub async fn get_documents_by_tags(
266        &self,
267        tag_ids: &[TagId],
268        truncate_content: bool,
269    ) -> Result<Vec<Document>> {
270        let tag_id_str = tag_ids
271            .iter()
272            .map(|tag_id| tag_id.0.to_string())
273            .collect::<Vec<_>>()
274            .join(",");
275        let documents: Vec<_> = self
276            .fetch_all_pages::<DocumentData>(&format!(
277                "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
278            ))
279            .await?
280            .into_iter()
281            .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
282            .collect();
283
284        Ok(documents)
285    }
286
287    pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
288        let resp = self
289            .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
290            .await?;
291
292        let document_data: DocumentData = resp
293            .json()
294            .await
295            .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
296
297        Ok(document_data)
298    }
299
300    /// Get a document by its ID.
301    pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
302        Ok(Document::new(
303            self.get_document_data_by_id(id).await?,
304            Arc::new(self.clone()),
305            false,
306        ))
307    }
308
309    pub(crate) async fn request(
310        &self,
311        method: Method,
312        endpoint: &str,
313        body: Option<&serde_json::Value>,
314    ) -> Result<reqwest::Response> {
315        let mut req = self
316            .client
317            .request(method, format!("{}{endpoint}", self.base_url))
318            .header(ACCEPT, "application/json");
319
320        // Set payload body if provided
321        if let Some(json_body) = body {
322            req = req.json(json_body);
323        }
324
325        let resp = req
326            .send()
327            .await
328            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
329
330        if resp.status() == StatusCode::NOT_FOUND {
331            return Err(Error::NotFound);
332        }
333
334        if !resp.status().is_success() {
335            return Err(Error::Response {
336                status_code: resp.status().as_u16(),
337                body: resp.text().await.unwrap_or_default(),
338            });
339        }
340
341        Ok(resp)
342    }
343
344    pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
345        &self,
346        endpoint: &str,
347    ) -> Result<Vec<T>> {
348        let mut results = Vec::new();
349        let mut current_url = Some(endpoint.to_string());
350
351        while let Some(url) = current_url {
352            let resp = self.request(Method::GET, &url, None).await?;
353
354            let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
355                Error::InvalidJson(format!(
356                    "Failed to parse paginated response for {endpoint}: {e}"
357                ))
358            })?;
359
360            results.extend(page.results);
361
362            current_url = page.next.and_then(|next_url| {
363                // Extract just the path from the full URL
364                next_url
365                    .trim_start_matches(&*self.base_url)
366                    .to_string()
367                    .into()
368            });
369        }
370
371        Ok(results)
372    }
373
374    /// Get all tasks with optional filtering by ID, name, or acknowledged status.
375    pub async fn get_task_status(
376        &self,
377        task_id: Option<&TaskId>,
378        task_name: Option<&str>,
379        acknowledged: Option<bool>,
380    ) -> Result<Vec<Task>> {
381        let mut query = Vec::new();
382
383        if let Some(id) = task_id {
384            query.push(("task_id", id.to_string()));
385        }
386
387        if let Some(name) = task_name {
388            query.push(("task_name", name.to_string()));
389        }
390
391        if let Some(ack) = acknowledged {
392            query.push(("acknowledged", ack.to_string()));
393        }
394
395        let resp = self
396            .request(
397                Method::GET,
398                &format!(
399                    "/api/tasks/?{}",
400                    serde_urlencoded::to_string(&query)
401                        .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
402                ),
403                None,
404            )
405            .await?;
406
407        trace!("get_task_status response: {:?}", resp);
408
409        let body = resp
410            .text()
411            .await
412            .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
413
414        let tasks: Vec<Task> = match serde_json::from_str(&body) {
415            Ok(t) => t,
416            Err(e) => {
417                return Err(Error::InvalidJson(format!(
418                    "Failed to parse response body: {e}"
419                )));
420            }
421        };
422
423        if tasks.is_empty() {
424            return Err(Error::NotFound);
425        }
426
427        Ok(tasks)
428    }
429
430    pub fn get_workflows(&self) -> impl Future<Output = Result<Vec<Workflow>>> {
431        self.fetch_all_pages("/api/workflows/")
432    }
433
434    /// Upload a document to Paperless.
435    ///
436    /// Returns the task ID on success.
437    pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
438        let file_bytes = std::fs::read(file_path)
439            .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
440
441        let form = multipart::Form::new().part(
442            "document",
443            multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
444        );
445
446        let url = format!("{}/api/documents/post_document/", self.base_url);
447
448        let resp = self
449            .client
450            .post(&url)
451            .multipart(form)
452            .send()
453            .await
454            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
455
456        let status = resp.status();
457        if !resp.status().is_success() {
458            return Err(Error::Response {
459                status_code: status.as_u16(),
460                body: resp.text().await.unwrap_or_default(),
461            });
462        }
463
464        let task_id: String = resp
465            .json()
466            .await
467            .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
468        Ok(TaskId(task_id))
469    }
470
471    #[inline]
472    #[must_use]
473    pub fn tags(&self) -> &HashMap<TagId, Tag> {
474        &self.cached_data.tags
475    }
476
477    #[inline]
478    #[must_use]
479    pub fn storage_paths(&self) -> &HashMap<StoragePathId, StoragePath> {
480        &self.cached_data.storage_paths
481    }
482
483    #[must_use]
484    pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
485        self.cached_data.tags.values().find(|tag| tag.name == name)
486    }
487
488    #[inline]
489    #[must_use]
490    pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
491        &self.cached_data.document_types
492    }
493
494    #[must_use]
495    pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
496        self.cached_data
497            .document_types
498            .values()
499            .find(|dt| dt.name == name)
500    }
501
502    #[inline]
503    #[must_use]
504    pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
505        &self.cached_data.correspondents
506    }
507
508    #[inline]
509    #[must_use]
510    pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
511        &self.cached_data.custom_fields
512    }
513
514    #[must_use]
515    pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
516        self.cached_data
517            .custom_fields
518            .values()
519            .find(|field| field.name == name)
520    }
521
522    #[inline]
523    #[must_use]
524    pub fn users(&self) -> &HashMap<UserId, User> {
525        &self.cached_data.users
526    }
527}