Skip to main content

paperless_api/
client.rs

1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use reqwest::{
4    Method, StatusCode,
5    header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
6    multipart,
7};
8use serde::Deserialize;
9use tracing::{debug, trace};
10
11use crate::{
12    Error, Result, User,
13    correspondent::{Correspondent, CorrespondentId},
14    custom_field::{CustomField, CustomFieldId},
15    document::{Document, DocumentData, DocumentId},
16    document_type::{DocumentType, DocumentTypeId},
17    tag::{Tag, TagId},
18    task::{Task, TaskId},
19    user::UserId,
20};
21
/// Selects which cached metadata to refresh.
///
/// Each variant names one of the metadata caches held by the client; pass one
/// or more of them to `PaperlessClient::refresh`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RefreshData {
    /// Reload the tag cache from `/api/tags/`.
    Tags,
    /// Reload the custom field cache from `/api/custom_fields/`.
    CustomFields,
    /// Reload the correspondent cache from `/api/correspondents/`.
    Correspondents,
    /// Reload the document type cache from `/api/document_types/`.
    DocumentTypes,
    /// Reload the user cache from `/api/users/`.
    Users,
}
31
/// Client to interact with Paperless.
///
/// Besides the HTTP client, it holds in-memory caches of server metadata.
/// The caches start out empty and are only filled by the `refresh*` methods.
#[derive(Debug, Clone)]
pub struct PaperlessClient {
    // HTTP client carrying the default headers (including the auth token).
    client: reqwest::Client,
    // Prefix prepended to every endpoint path when building request URLs.
    base_url: String,

    // Cached metadata, keyed by ID.
    correspondents: HashMap<CorrespondentId, Correspondent>,
    document_types: HashMap<DocumentTypeId, DocumentType>,
    tags: HashMap<TagId, Tag>,
    custom_fields: HashMap<CustomFieldId, CustomField>,
    users: HashMap<UserId, User>,
}
44
/// One page of a paginated list response.
#[derive(Debug, Deserialize)]
struct PaginatedResponse<T> {
    // Items contained in this page.
    results: Vec<T>,
    // URL of the following page; `None` ends pagination.
    next: Option<String>,
}
50
51impl PaperlessClient {
52    /// Create a new Paperless client.
53    pub fn new(
54        base_url: &str,
55        token: &str,
56        headers: Option<&HashMap<String, String>>,
57    ) -> std::result::Result<Self, String> {
58        let mut headers_map = HeaderMap::new();
59
60        // Add additional headers if provided
61        if let Some(headers) = headers {
62            for (key, value) in headers {
63                headers_map.insert(
64                    HeaderName::from_str(key).map_err(|err| err.to_string())?,
65                    value
66                        .parse()
67                        .map_err(|err: InvalidHeaderValue| err.to_string())?,
68                );
69            }
70        }
71
72        // Add the Paperless token header
73        headers_map.insert(
74            HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
75            format!("Token {token}")
76                .parse()
77                .map_err(|err: InvalidHeaderValue| err.to_string())?,
78        );
79
80        Ok(Self {
81            base_url: base_url.to_string(),
82            client: reqwest::Client::builder()
83                .default_headers(headers_map)
84                .zstd(true)
85                .build()
86                .map_err(|err| err.to_string())?,
87            tags: HashMap::new(),
88            custom_fields: HashMap::new(),
89            correspondents: HashMap::new(),
90            document_types: HashMap::new(),
91            users: HashMap::new(),
92        })
93    }
94
95    async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
96        debug!("loading tags");
97        let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
98        Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
99    }
100
101    async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
102        debug!("loading custom fields");
103        let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
104        Ok(custom_fields
105            .into_iter()
106            .map(|custom_field| (custom_field.id, custom_field))
107            .collect())
108    }
109
110    async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
111        debug!("loading correspondents");
112        let correspondents: Vec<Correspondent> =
113            self.fetch_all_pages("/api/correspondents/").await?;
114        Ok(correspondents
115            .into_iter()
116            .map(|correspondent| (correspondent.id, correspondent))
117            .collect())
118    }
119
120    async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
121        debug!("loading document types");
122        let document_types: Vec<DocumentType> =
123            self.fetch_all_pages("/api/document_types/").await?;
124        Ok(document_types
125            .into_iter()
126            .map(|document_type| (document_type.id, document_type))
127            .collect())
128    }
129
130    async fn load_users(&self) -> Result<HashMap<UserId, User>> {
131        debug!("loading users");
132        let users: Vec<User> = self.fetch_all_pages("/api/users/").await?;
133        Ok(users.into_iter().map(|user| (user.id, user)).collect())
134    }
135
136    pub async fn refresh_all(&mut self) -> Result<()> {
137        self.refresh([
138            RefreshData::Tags,
139            RefreshData::CustomFields,
140            RefreshData::Correspondents,
141            RefreshData::DocumentTypes,
142            RefreshData::Users,
143        ])
144        .await
145    }
146
147    /// Refresh selected cached metadata concurrently.
148    pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
149        let mut refresh_tags = false;
150        let mut refresh_custom_fields = false;
151        let mut refresh_correspondents = false;
152        let mut refresh_document_types = false;
153        let mut refresh_users = false;
154
155        for item in data {
156            match item {
157                RefreshData::Tags => refresh_tags = true,
158                RefreshData::CustomFields => refresh_custom_fields = true,
159                RefreshData::Correspondents => refresh_correspondents = true,
160                RefreshData::DocumentTypes => refresh_document_types = true,
161                RefreshData::Users => refresh_users = true,
162            }
163        }
164
165        let (tags, custom_fields, correspondents, document_types, users) = futures_util::try_join!(
166            async {
167                if refresh_tags {
168                    Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
169                } else {
170                    Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
171                }
172            },
173            async {
174                if refresh_custom_fields {
175                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
176                        self.load_custom_fields().await?,
177                    ))
178                } else {
179                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
180                }
181            },
182            async {
183                if refresh_correspondents {
184                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
185                        self.load_correspondents().await?,
186                    ))
187                } else {
188                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
189                }
190            },
191            async {
192                if refresh_document_types {
193                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
194                        self.load_document_types().await?,
195                    ))
196                } else {
197                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
198                }
199            },
200            async {
201                if refresh_users {
202                    Ok::<Option<HashMap<UserId, User>>, Error>(Some(self.load_users().await?))
203                } else {
204                    Ok::<Option<HashMap<UserId, User>>, Error>(None)
205                }
206            },
207        )?;
208
209        if let Some(tags) = tags {
210            self.tags = tags;
211        }
212
213        if let Some(custom_fields) = custom_fields {
214            self.custom_fields = custom_fields;
215        }
216
217        if let Some(correspondents) = correspondents {
218            self.correspondents = correspondents;
219        }
220
221        if let Some(document_types) = document_types {
222            self.document_types = document_types;
223        }
224
225        if let Some(users) = users {
226            self.users = users;
227        }
228
229        Ok(())
230    }
231
232    /// Refresh tags.
233    #[inline]
234    pub async fn refresh_tags(&mut self) -> Result<()> {
235        self.refresh([RefreshData::Tags]).await
236    }
237
238    /// Refresh custom fields.
239    #[inline]
240    pub async fn refresh_custom_fields(&mut self) -> Result<()> {
241        self.refresh([RefreshData::CustomFields]).await
242    }
243
244    /// Refresh correspondents.
245    #[inline]
246    pub async fn refresh_correspondents(&mut self) -> Result<()> {
247        self.refresh([RefreshData::Correspondents]).await
248    }
249
250    /// Refresh document types.
251    #[inline]
252    pub async fn refresh_document_types(&mut self) -> Result<()> {
253        self.refresh([RefreshData::DocumentTypes]).await
254    }
255
256    /// Refresh users.
257    #[inline]
258    pub async fn refresh_users(&mut self) -> Result<()> {
259        self.refresh([RefreshData::Users]).await
260    }
261
262    /// Get all documents with any of the given tags.
263    pub async fn get_documents_by_tags(
264        &self,
265        tag_ids: &[TagId],
266        truncate_content: bool,
267    ) -> Result<Vec<Document>> {
268        let tag_id_str = tag_ids
269            .iter()
270            .map(|tag_id| tag_id.0.to_string())
271            .collect::<Vec<_>>()
272            .join(",");
273        let documents: Vec<_> = self
274            .fetch_all_pages::<DocumentData>(&format!(
275                "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
276            ))
277            .await?
278            .into_iter()
279            .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
280            .collect();
281
282        Ok(documents)
283    }
284
285    pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
286        let resp = self
287            .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
288            .await?;
289
290        let document_data: DocumentData = resp
291            .json()
292            .await
293            .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
294
295        Ok(document_data)
296    }
297
298    /// Get a document by its ID.
299    pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
300        Ok(Document::new(
301            self.get_document_data_by_id(id).await?,
302            Arc::new(self.clone()),
303            false,
304        ))
305    }
306
307    pub(crate) async fn request(
308        &self,
309        method: Method,
310        endpoint: &str,
311        body: Option<&serde_json::Value>,
312    ) -> Result<reqwest::Response> {
313        let mut req = self
314            .client
315            .request(method, format!("{}{endpoint}", self.base_url))
316            .header(ACCEPT, "application/json");
317
318        // Set payload body if provided
319        if let Some(json_body) = body {
320            req = req.json(json_body);
321        }
322
323        let resp = req
324            .send()
325            .await
326            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
327
328        if resp.status() == StatusCode::NOT_FOUND {
329            return Err(Error::NotFound);
330        }
331
332        if !resp.status().is_success() {
333            return Err(Error::Response {
334                status_code: resp.status().as_u16(),
335                body: resp.text().await.unwrap_or_default(),
336            });
337        }
338
339        Ok(resp)
340    }
341
342    pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
343        &self,
344        endpoint: &str,
345    ) -> Result<Vec<T>> {
346        let mut results = Vec::new();
347        let mut current_url = Some(endpoint.to_string());
348
349        while let Some(url) = current_url {
350            let resp = self.request(Method::GET, &url, None).await?;
351
352            let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
353                Error::Other(format!(
354                    "Failed to parse paginated response for {endpoint}: {e}"
355                ))
356            })?;
357
358            results.extend(page.results);
359
360            current_url = page.next.and_then(|next_url| {
361                // Extract just the path from the full URL
362                next_url
363                    .trim_start_matches(&self.base_url)
364                    .to_string()
365                    .into()
366            });
367        }
368
369        Ok(results)
370    }
371
372    /// Get all tasks with optional filtering by ID, name, or acknowledged status.
373    pub async fn get_task_status(
374        &self,
375        task_id: Option<&TaskId>,
376        task_name: Option<&str>,
377        acknowledged: Option<bool>,
378    ) -> Result<Vec<Task>> {
379        let mut query = Vec::new();
380
381        if let Some(id) = task_id {
382            query.push(("task_id", id.to_string()));
383        }
384
385        if let Some(name) = task_name {
386            query.push(("task_name", name.to_string()));
387        }
388
389        if let Some(ack) = acknowledged {
390            query.push(("acknowledged", ack.to_string()));
391        }
392
393        let resp = self
394            .request(
395                Method::GET,
396                &format!(
397                    "/api/tasks/?{}",
398                    serde_urlencoded::to_string(&query)
399                        .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
400                ),
401                None,
402            )
403            .await?;
404
405        trace!("get_task_status response: {:?}", resp);
406
407        let body = resp
408            .text()
409            .await
410            .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
411
412        let tasks: Vec<Task> = match serde_json::from_str(&body) {
413            Ok(t) => t,
414            Err(e) => {
415                return Err(Error::InvalidJson(format!(
416                    "Failed to parse response body: {e}"
417                )));
418            }
419        };
420
421        if tasks.is_empty() {
422            return Err(Error::NotFound);
423        }
424
425        Ok(tasks)
426    }
427
428    /// Upload a document to Paperless.
429    ///
430    /// Returns the task ID on success.
431    pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
432        let file_bytes = std::fs::read(file_path)
433            .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
434
435        let form = multipart::Form::new().part(
436            "document",
437            multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
438        );
439
440        let url = format!("{}/api/documents/post_document/", self.base_url);
441
442        let resp = self
443            .client
444            .post(&url)
445            .multipart(form)
446            .send()
447            .await
448            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
449
450        let status = resp.status();
451        if !resp.status().is_success() {
452            return Err(Error::Response {
453                status_code: status.as_u16(),
454                body: resp.text().await.unwrap_or_default(),
455            });
456        }
457
458        let task_id: String = resp
459            .json()
460            .await
461            .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
462        Ok(TaskId(task_id))
463    }
464
    /// Cached tags, keyed by tag ID.
    ///
    /// Empty until the tag cache has been refreshed.
    #[inline]
    #[must_use]
    pub fn tags(&self) -> &HashMap<TagId, Tag> {
        &self.tags
    }
470
471    #[must_use]
472    pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
473        self.tags.values().find(|tag| tag.name == name)
474    }
475
    /// Cached document types, keyed by document type ID.
    ///
    /// Empty until the document type cache has been refreshed.
    #[inline]
    #[must_use]
    pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
        &self.document_types
    }
481
482    #[must_use]
483    pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
484        self.document_types.values().find(|dt| dt.name == name)
485    }
486
    /// Cached correspondents, keyed by correspondent ID.
    ///
    /// Empty until the correspondent cache has been refreshed.
    #[inline]
    #[must_use]
    pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
        &self.correspondents
    }
492
    /// Cached custom fields, keyed by custom field ID.
    ///
    /// Empty until the custom field cache has been refreshed.
    #[inline]
    #[must_use]
    pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
        &self.custom_fields
    }
498
499    #[must_use]
500    pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
501        self.custom_fields.values().find(|field| field.name == name)
502    }
503
    /// Cached users, keyed by user ID.
    ///
    /// Empty until the user cache has been refreshed.
    #[inline]
    #[must_use]
    pub fn users(&self) -> &HashMap<UserId, User> {
        &self.users
    }
509}