Skip to main content

paperless_api/
client.rs

1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use enum_iterator::Sequence;
4use reqwest::{
5    Method, StatusCode,
6    header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
7    multipart,
8};
9use serde::Deserialize;
10use tracing::{debug, trace};
11
12use crate::{
13    Error, Result, User,
14    correspondent::Correspondent,
15    custom_field::CustomField,
16    document::{Document, DocumentData},
17    document_type::DocumentType,
18    id::{
19        CorrespondentId, CustomFieldId, DocumentId, DocumentTypeId, StoragePathId, TagId, TaskId,
20        UserId,
21    },
22    storage_path::StoragePath,
23    tag::Tag,
24    task::Task,
25    workflow::Workflow,
26};
27
/// Selects which cached metadata to refresh.
///
/// Pass any combination of variants to [`PaperlessClient::refresh`];
/// [`PaperlessClient::refresh_all`] uses every variant, in declaration order, via the derived
/// [`Sequence`] implementation.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Sequence)]
#[non_exhaustive]
pub enum RefreshData {
    /// Tags, loaded from `/api/tags/`.
    Tags,
    /// Custom field definitions, loaded from `/api/custom_fields/`.
    CustomFields,
    /// Correspondents, loaded from `/api/correspondents/`.
    Correspondents,
    /// Document types, loaded from `/api/document_types/`.
    DocumentTypes,
    /// Users, loaded from `/api/users/`.
    Users,
    /// Storage paths, loaded from `/api/storage_paths/`.
    StoragePaths,
}
39
/// Client to interact with Paperless.
///
/// Bundles the HTTP client (configured with the API token and any extra default headers), the
/// server base URL, and a cache of metadata (tags, custom fields, correspondents, document
/// types, users, storage paths). The cache starts empty and is filled by
/// [`PaperlessClient::refresh`] or [`PaperlessClient::refresh_all`].
///
/// Cloning shares the cached metadata through an `Arc` instead of deep-copying it.
#[derive(Debug, Clone)]
pub struct PaperlessClient {
    // HTTP client; the auth token and extra headers are installed as its default headers.
    client: reqwest::Client,
    // Server root that endpoint paths such as `/api/tags/` are appended to.
    pub(crate) base_url: Box<str>,
    // Cached metadata; `refresh` swaps in a new snapshot rather than mutating this one.
    cached_data: Arc<CachedData>,
}
47
48#[derive(Debug)]
49struct CachedData {
50    correspondents: HashMap<CorrespondentId, Correspondent>,
51    document_types: HashMap<DocumentTypeId, DocumentType>,
52    tags: HashMap<TagId, Tag>,
53    custom_fields: HashMap<CustomFieldId, CustomField>,
54    users: HashMap<UserId, User>,
55    storage_paths: HashMap<StoragePathId, StoragePath>,
56}
57
58#[derive(Debug, Deserialize)]
59struct PaginatedResponse<T> {
60    results: Vec<T>,
61    next: Option<String>,
62}
63
64impl PaperlessClient {
65    /// Create a new Paperless client.
66    pub fn new(
67        base_url: &str,
68        token: &str,
69        headers: Option<&HashMap<String, String>>,
70    ) -> std::result::Result<Self, String> {
71        let mut headers_map = HeaderMap::new();
72
73        // Add additional headers if provided
74        if let Some(headers) = headers {
75            for (key, value) in headers {
76                headers_map.insert(
77                    HeaderName::from_str(key).map_err(|err| err.to_string())?,
78                    value
79                        .parse()
80                        .map_err(|err: InvalidHeaderValue| err.to_string())?,
81                );
82            }
83        }
84
85        // Add the Paperless token header
86        headers_map.insert(
87            HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
88            format!("Token {token}")
89                .parse()
90                .map_err(|err: InvalidHeaderValue| err.to_string())?,
91        );
92
93        Ok(Self {
94            base_url: base_url.into(),
95            client: reqwest::Client::builder()
96                .default_headers(headers_map)
97                .zstd(true)
98                .build()
99                .map_err(|err| err.to_string())?,
100            cached_data: Arc::new(CachedData {
101                tags: HashMap::new(),
102                custom_fields: HashMap::new(),
103                correspondents: HashMap::new(),
104                document_types: HashMap::new(),
105                users: HashMap::new(),
106                storage_paths: HashMap::new(),
107            }),
108        })
109    }
110
111    async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
112        debug!("loading tags");
113        let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
114        Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
115    }
116
117    async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
118        debug!("loading custom fields");
119        let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
120        Ok(custom_fields
121            .into_iter()
122            .map(|custom_field| (custom_field.id, custom_field))
123            .collect())
124    }
125
126    async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
127        debug!("loading correspondents");
128        let correspondents: Vec<Correspondent> =
129            self.fetch_all_pages("/api/correspondents/").await?;
130        Ok(correspondents
131            .into_iter()
132            .map(|correspondent| (correspondent.id, correspondent))
133            .collect())
134    }
135
136    async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
137        debug!("loading document types");
138        let document_types: Vec<DocumentType> =
139            self.fetch_all_pages("/api/document_types/").await?;
140        Ok(document_types
141            .into_iter()
142            .map(|document_type| (document_type.id, document_type))
143            .collect())
144    }
145
146    async fn load_users(&self) -> Result<HashMap<UserId, User>> {
147        debug!("loading users");
148        let users: Vec<User> = self.fetch_all_pages("/api/users/").await?;
149        Ok(users.into_iter().map(|user| (user.id, user)).collect())
150    }
151
152    async fn load_storage_paths(&self) -> Result<HashMap<StoragePathId, StoragePath>> {
153        debug!("loading storage paths");
154        let storage_paths: Vec<StoragePath> = self.fetch_all_pages("/api/storage_paths/").await?;
155        Ok(storage_paths
156            .into_iter()
157            .map(|storage_path| (storage_path.id, storage_path))
158            .collect())
159    }
160
161    pub async fn refresh_all(&mut self) -> Result<()> {
162        self.refresh(enum_iterator::all::<RefreshData>()).await
163    }
164
165    /// Refresh selected cached metadata concurrently.
166    pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
167        let mut refresh_tags = false;
168        let mut refresh_custom_fields = false;
169        let mut refresh_correspondents = false;
170        let mut refresh_document_types = false;
171        let mut refresh_users = false;
172        let mut refresh_storage_paths = false;
173
174        for item in data {
175            match item {
176                RefreshData::Tags => refresh_tags = true,
177                RefreshData::CustomFields => refresh_custom_fields = true,
178                RefreshData::Correspondents => refresh_correspondents = true,
179                RefreshData::DocumentTypes => refresh_document_types = true,
180                RefreshData::Users => refresh_users = true,
181                RefreshData::StoragePaths => refresh_storage_paths = true,
182            }
183        }
184
185        let (tags, custom_fields, correspondents, document_types, users, storage_paths) = futures_util::try_join!(
186            async {
187                if refresh_tags {
188                    Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
189                } else {
190                    Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
191                }
192            },
193            async {
194                if refresh_custom_fields {
195                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
196                        self.load_custom_fields().await?,
197                    ))
198                } else {
199                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
200                }
201            },
202            async {
203                if refresh_correspondents {
204                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
205                        self.load_correspondents().await?,
206                    ))
207                } else {
208                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
209                }
210            },
211            async {
212                if refresh_document_types {
213                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
214                        self.load_document_types().await?,
215                    ))
216                } else {
217                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
218                }
219            },
220            async {
221                if refresh_users {
222                    Ok::<Option<HashMap<UserId, User>>, Error>(Some(self.load_users().await?))
223                } else {
224                    Ok::<Option<HashMap<UserId, User>>, Error>(None)
225                }
226            },
227            async {
228                if refresh_storage_paths {
229                    Ok::<Option<HashMap<StoragePathId, StoragePath>>, Error>(Some(
230                        self.load_storage_paths().await?,
231                    ))
232                } else {
233                    Ok::<Option<HashMap<StoragePathId, StoragePath>>, Error>(None)
234                }
235            },
236        )?;
237
238        self.cached_data = Arc::new(CachedData {
239            correspondents: correspondents.unwrap_or_default(),
240            document_types: document_types.unwrap_or_default(),
241            tags: tags.unwrap_or_default(),
242            custom_fields: custom_fields.unwrap_or_default(),
243            users: users.unwrap_or_default(),
244            storage_paths: storage_paths.unwrap_or_default(),
245        });
246
247        Ok(())
248    }
249
250    /// Get all documents with any of the given tags.
251    pub async fn get_documents_by_tags(
252        &self,
253        tag_ids: &[TagId],
254        truncate_content: bool,
255    ) -> Result<Vec<Document>> {
256        let tag_id_str = tag_ids
257            .iter()
258            .map(|tag_id| tag_id.0.to_string())
259            .collect::<Vec<_>>()
260            .join(",");
261        let documents: Vec<_> = self
262            .fetch_all_pages::<DocumentData>(&format!(
263                "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
264            ))
265            .await?
266            .into_iter()
267            .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
268            .collect();
269
270        Ok(documents)
271    }
272
273    pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
274        let resp = self
275            .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
276            .await?;
277
278        let document_data: DocumentData = resp
279            .json()
280            .await
281            .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
282
283        Ok(document_data)
284    }
285
286    /// Get a document by its ID.
287    pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
288        Ok(Document::new(
289            self.get_document_data_by_id(id).await?,
290            Arc::new(self.clone()),
291            false,
292        ))
293    }
294
295    pub(crate) async fn request(
296        &self,
297        method: Method,
298        endpoint: &str,
299        body: Option<&serde_json::Value>,
300    ) -> Result<reqwest::Response> {
301        let mut req = self
302            .client
303            .request(method, format!("{}{endpoint}", self.base_url))
304            .header(ACCEPT, "application/json");
305
306        // Set payload body if provided
307        if let Some(json_body) = body {
308            req = req.json(json_body);
309        }
310
311        let resp = req
312            .send()
313            .await
314            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
315
316        if resp.status() == StatusCode::NOT_FOUND {
317            return Err(Error::NotFound);
318        }
319
320        if !resp.status().is_success() {
321            return Err(Error::Response {
322                status_code: resp.status().as_u16(),
323                body: resp.text().await.unwrap_or_default(),
324            });
325        }
326
327        Ok(resp)
328    }
329
330    pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
331        &self,
332        endpoint: &str,
333    ) -> Result<Vec<T>> {
334        let mut results = Vec::new();
335        let mut current_url = Some(endpoint.to_string());
336
337        while let Some(url) = current_url {
338            let resp = self.request(Method::GET, &url, None).await?;
339
340            let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
341                Error::InvalidJson(format!(
342                    "Failed to parse paginated response for {endpoint}: {e}"
343                ))
344            })?;
345
346            results.extend(page.results);
347
348            current_url = page.next.and_then(|next_url| {
349                // Extract just the path from the full URL
350                next_url
351                    .trim_start_matches(&*self.base_url)
352                    .to_string()
353                    .into()
354            });
355        }
356
357        Ok(results)
358    }
359
360    /// Get all tasks with optional filtering by ID, name, or acknowledged status.
361    pub async fn get_task_status(
362        &self,
363        task_id: Option<&TaskId>,
364        task_name: Option<&str>,
365        acknowledged: Option<bool>,
366    ) -> Result<Vec<Task>> {
367        let mut query = Vec::new();
368
369        if let Some(id) = task_id {
370            query.push(("task_id", id.to_string()));
371        }
372
373        if let Some(name) = task_name {
374            query.push(("task_name", name.to_string()));
375        }
376
377        if let Some(ack) = acknowledged {
378            query.push(("acknowledged", ack.to_string()));
379        }
380
381        let resp = self
382            .request(
383                Method::GET,
384                &format!(
385                    "/api/tasks/?{}",
386                    serde_urlencoded::to_string(&query)
387                        .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
388                ),
389                None,
390            )
391            .await?;
392
393        trace!("get_task_status response: {:?}", resp);
394
395        let body = resp
396            .text()
397            .await
398            .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
399
400        let tasks: Vec<Task> = match serde_json::from_str(&body) {
401            Ok(t) => t,
402            Err(e) => {
403                return Err(Error::InvalidJson(format!(
404                    "Failed to parse response body: {e}"
405                )));
406            }
407        };
408
409        if tasks.is_empty() {
410            return Err(Error::NotFound);
411        }
412
413        Ok(tasks)
414    }
415
416    pub fn get_workflows(&self) -> impl Future<Output = Result<Vec<Workflow>>> {
417        self.fetch_all_pages("/api/workflows/")
418    }
419
420    /// Upload a document to Paperless.
421    ///
422    /// Returns the task ID on success.
423    pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
424        let file_bytes = std::fs::read(file_path)
425            .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
426
427        let form = multipart::Form::new().part(
428            "document",
429            multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
430        );
431
432        let url = format!("{}/api/documents/post_document/", self.base_url);
433
434        let resp = self
435            .client
436            .post(&url)
437            .multipart(form)
438            .send()
439            .await
440            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
441
442        let status = resp.status();
443        if !resp.status().is_success() {
444            return Err(Error::Response {
445                status_code: status.as_u16(),
446                body: resp.text().await.unwrap_or_default(),
447            });
448        }
449
450        let task_id: String = resp
451            .json()
452            .await
453            .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
454        Ok(TaskId(task_id))
455    }
456
457    #[inline]
458    #[must_use]
459    pub fn tags(&self) -> &HashMap<TagId, Tag> {
460        &self.cached_data.tags
461    }
462
463    #[inline]
464    #[must_use]
465    pub fn storage_paths(&self) -> &HashMap<StoragePathId, StoragePath> {
466        &self.cached_data.storage_paths
467    }
468
469    #[must_use]
470    pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
471        self.cached_data.tags.values().find(|tag| tag.name == name)
472    }
473
474    #[inline]
475    #[must_use]
476    pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
477        &self.cached_data.document_types
478    }
479
480    #[must_use]
481    pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
482        self.cached_data
483            .document_types
484            .values()
485            .find(|dt| dt.name == name)
486    }
487
488    #[inline]
489    #[must_use]
490    pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
491        &self.cached_data.correspondents
492    }
493
494    #[inline]
495    #[must_use]
496    pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
497        &self.cached_data.custom_fields
498    }
499
500    #[must_use]
501    pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
502        self.cached_data
503            .custom_fields
504            .values()
505            .find(|field| field.name == name)
506    }
507
508    #[inline]
509    #[must_use]
510    pub fn users(&self) -> &HashMap<UserId, User> {
511        &self.cached_data.users
512    }
513}