Skip to main content

paperless_api/
client.rs

1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use reqwest::{
4    Method, StatusCode,
5    header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
6    multipart,
7};
8use serde::Deserialize;
9use tracing::{debug, trace};
10
11use crate::{
12    Error, Result, User,
13    correspondent::{Correspondent, CorrespondentId},
14    custom_field::{CustomField, CustomFieldId},
15    document::{Document, DocumentData, DocumentId},
16    document_type::{DocumentType, DocumentTypeId},
17    tag::{Tag, TagId},
18    task::{Task, TaskId},
19    user::UserId,
20};
21
/// Selects which cached metadata to refresh.
///
/// Each variant names exactly one of the metadata caches held by the client; pass any
/// combination of them to the client's `refresh` method.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RefreshData {
    /// Reload the tag cache.
    Tags,
    /// Reload the custom-field cache.
    CustomFields,
    /// Reload the correspondent cache.
    Correspondents,
    /// Reload the document-type cache.
    DocumentTypes,
    /// Reload the user cache.
    Users,
}
31
32/// Client to interact with Paperless.
33#[derive(Debug, Clone)]
34pub struct PaperlessClient {
35    client: reqwest::Client,
36    base_url: String,
37
38    correspondents: HashMap<CorrespondentId, Correspondent>,
39    document_types: HashMap<DocumentTypeId, DocumentType>,
40    tags: HashMap<TagId, Tag>,
41    custom_fields: HashMap<CustomFieldId, CustomField>,
42    users: HashMap<UserId, User>,
43}
44
45#[derive(Debug, Deserialize)]
46struct PaginatedResponse<T> {
47    results: Vec<T>,
48    next: Option<String>,
49}
50
51impl PaperlessClient {
52    /// Create a new Paperless client.
53    pub fn new(
54        base_url: &str,
55        token: &str,
56        headers: Option<&HashMap<String, String>>,
57    ) -> std::result::Result<Self, String> {
58        let mut headers_map = HeaderMap::new();
59
60        // Add additional headers if provided
61        if let Some(headers) = headers {
62            for (key, value) in headers {
63                headers_map.insert(
64                    HeaderName::from_str(key).map_err(|err| err.to_string())?,
65                    value
66                        .parse()
67                        .map_err(|err: InvalidHeaderValue| err.to_string())?,
68                );
69            }
70        }
71
72        // Add the Paperless token header
73        headers_map.insert(
74            HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
75            format!("Token {token}")
76                .parse()
77                .map_err(|err: InvalidHeaderValue| err.to_string())?,
78        );
79
80        Ok(Self {
81            base_url: base_url.to_string(),
82            client: reqwest::Client::builder()
83                .default_headers(headers_map)
84                .build()
85                .map_err(|err| err.to_string())?,
86            tags: HashMap::new(),
87            custom_fields: HashMap::new(),
88            correspondents: HashMap::new(),
89            document_types: HashMap::new(),
90            users: HashMap::new(),
91        })
92    }
93
94    async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
95        debug!("loading tags");
96        let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
97        Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
98    }
99
100    async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
101        debug!("loading custom fields");
102        let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
103        Ok(custom_fields
104            .into_iter()
105            .map(|custom_field| (custom_field.id, custom_field))
106            .collect())
107    }
108
109    async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
110        debug!("loading correspondents");
111        let correspondents: Vec<Correspondent> =
112            self.fetch_all_pages("/api/correspondents/").await?;
113        Ok(correspondents
114            .into_iter()
115            .map(|correspondent| (correspondent.id, correspondent))
116            .collect())
117    }
118
119    async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
120        debug!("loading document types");
121        let document_types: Vec<DocumentType> =
122            self.fetch_all_pages("/api/document_types/").await?;
123        Ok(document_types
124            .into_iter()
125            .map(|document_type| (document_type.id, document_type))
126            .collect())
127    }
128
129    async fn load_users(&self) -> Result<HashMap<UserId, User>> {
130        debug!("loading users");
131        let users: Vec<User> = self.fetch_all_pages("/api/users/").await?;
132        Ok(users.into_iter().map(|user| (user.id, user)).collect())
133    }
134
135    pub async fn refresh_all(&mut self) -> Result<()> {
136        self.refresh([
137            RefreshData::Tags,
138            RefreshData::CustomFields,
139            RefreshData::Correspondents,
140            RefreshData::DocumentTypes,
141            RefreshData::Users,
142        ])
143        .await
144    }
145
146    /// Refresh selected cached metadata concurrently.
147    pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
148        let mut refresh_tags = false;
149        let mut refresh_custom_fields = false;
150        let mut refresh_correspondents = false;
151        let mut refresh_document_types = false;
152        let mut refresh_users = false;
153
154        for item in data {
155            match item {
156                RefreshData::Tags => refresh_tags = true,
157                RefreshData::CustomFields => refresh_custom_fields = true,
158                RefreshData::Correspondents => refresh_correspondents = true,
159                RefreshData::DocumentTypes => refresh_document_types = true,
160                RefreshData::Users => refresh_users = true,
161            }
162        }
163
164        let (tags, custom_fields, correspondents, document_types, users) = futures_util::try_join!(
165            async {
166                if refresh_tags {
167                    Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
168                } else {
169                    Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
170                }
171            },
172            async {
173                if refresh_custom_fields {
174                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
175                        self.load_custom_fields().await?,
176                    ))
177                } else {
178                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
179                }
180            },
181            async {
182                if refresh_correspondents {
183                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
184                        self.load_correspondents().await?,
185                    ))
186                } else {
187                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
188                }
189            },
190            async {
191                if refresh_document_types {
192                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
193                        self.load_document_types().await?,
194                    ))
195                } else {
196                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
197                }
198            },
199            async {
200                if refresh_users {
201                    Ok::<Option<HashMap<UserId, User>>, Error>(Some(self.load_users().await?))
202                } else {
203                    Ok::<Option<HashMap<UserId, User>>, Error>(None)
204                }
205            },
206        )?;
207
208        if let Some(tags) = tags {
209            self.tags = tags;
210        }
211
212        if let Some(custom_fields) = custom_fields {
213            self.custom_fields = custom_fields;
214        }
215
216        if let Some(correspondents) = correspondents {
217            self.correspondents = correspondents;
218        }
219
220        if let Some(document_types) = document_types {
221            self.document_types = document_types;
222        }
223
224        if let Some(users) = users {
225            self.users = users;
226        }
227
228        Ok(())
229    }
230
231    /// Refresh tags.
232    #[inline]
233    pub async fn refresh_tags(&mut self) -> Result<()> {
234        self.refresh([RefreshData::Tags]).await
235    }
236
237    /// Refresh custom fields.
238    #[inline]
239    pub async fn refresh_custom_fields(&mut self) -> Result<()> {
240        self.refresh([RefreshData::CustomFields]).await
241    }
242
243    /// Refresh correspondents.
244    #[inline]
245    pub async fn refresh_correspondents(&mut self) -> Result<()> {
246        self.refresh([RefreshData::Correspondents]).await
247    }
248
249    /// Refresh document types.
250    #[inline]
251    pub async fn refresh_document_types(&mut self) -> Result<()> {
252        self.refresh([RefreshData::DocumentTypes]).await
253    }
254
255    /// Refresh users.
256    #[inline]
257    pub async fn refresh_users(&mut self) -> Result<()> {
258        self.refresh([RefreshData::Users]).await
259    }
260
261    /// Get all documents with any of the given tags.
262    pub async fn get_documents_by_tags(
263        &self,
264        tag_ids: &[TagId],
265        truncate_content: bool,
266    ) -> Result<Vec<Document>> {
267        let tag_id_str = tag_ids
268            .iter()
269            .map(|tag_id| tag_id.0.to_string())
270            .collect::<Vec<_>>()
271            .join(",");
272        let documents: Vec<_> = self
273            .fetch_all_pages::<DocumentData>(&format!(
274                "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
275            ))
276            .await?
277            .into_iter()
278            .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
279            .collect();
280
281        Ok(documents)
282    }
283
284    pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
285        let resp = self
286            .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
287            .await?;
288
289        let document_data: DocumentData = resp
290            .json()
291            .await
292            .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
293
294        Ok(document_data)
295    }
296
297    /// Get a document by its ID.
298    pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
299        Ok(Document::new(
300            self.get_document_data_by_id(id).await?,
301            Arc::new(self.clone()),
302            false,
303        ))
304    }
305
306    pub(crate) async fn request(
307        &self,
308        method: Method,
309        endpoint: &str,
310        body: Option<&serde_json::Value>,
311    ) -> Result<reqwest::Response> {
312        let mut req = self
313            .client
314            .request(method, format!("{}{endpoint}", self.base_url))
315            .header(ACCEPT, "application/json");
316
317        if let Some(json_body) = body {
318            req = req.json(json_body);
319        }
320
321        let resp = req
322            .send()
323            .await
324            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
325
326        if resp.status() == StatusCode::NOT_FOUND {
327            return Err(Error::NotFound);
328        }
329
330        if !resp.status().is_success() {
331            return Err(Error::Response {
332                status_code: resp.status().as_u16(),
333                body: resp.text().await.unwrap_or_default(),
334            });
335        }
336
337        Ok(resp)
338    }
339
340    pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
341        &self,
342        endpoint: &str,
343    ) -> Result<Vec<T>> {
344        let mut results = Vec::new();
345        let mut current_url = Some(endpoint.to_string());
346
347        while let Some(url) = current_url {
348            let resp = self.request(Method::GET, &url, None).await?;
349
350            let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
351                Error::Other(format!(
352                    "Failed to parse paginated response for {endpoint}: {e}"
353                ))
354            })?;
355
356            results.extend(page.results);
357
358            current_url = page.next.and_then(|next_url| {
359                // Extract just the path from the full URL
360                next_url
361                    .trim_start_matches(&self.base_url)
362                    .to_string()
363                    .into()
364            });
365        }
366
367        Ok(results)
368    }
369
370    /// Get all tasks with optional filtering by ID, name, or acknowledged status.
371    pub async fn get_task_status(
372        &self,
373        task_id: Option<&TaskId>,
374        task_name: Option<&str>,
375        acknowledged: Option<bool>,
376    ) -> Result<Vec<Task>> {
377        let mut query = Vec::new();
378
379        if let Some(id) = task_id {
380            query.push(("task_id", id.to_string()));
381        }
382
383        if let Some(name) = task_name {
384            query.push(("task_name", name.to_string()));
385        }
386
387        if let Some(ack) = acknowledged {
388            query.push(("acknowledged", ack.to_string()));
389        }
390
391        let resp = self
392            .request(
393                Method::GET,
394                &format!(
395                    "/api/tasks/?{}",
396                    serde_urlencoded::to_string(&query)
397                        .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
398                ),
399                None,
400            )
401            .await?;
402
403        trace!("get_task_status response: {:?}", resp);
404
405        let body = resp
406            .text()
407            .await
408            .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
409
410        let tasks: Vec<Task> = match serde_json::from_str(&body) {
411            Ok(t) => t,
412            Err(e) => {
413                return Err(Error::InvalidJson(format!(
414                    "Failed to parse response body: {e}"
415                )));
416            }
417        };
418
419        if tasks.is_empty() {
420            return Err(Error::NotFound);
421        }
422
423        Ok(tasks)
424    }
425
426    /// Upload a document to Paperless.
427    ///
428    /// Returns the task ID on success.
429    pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
430        let file_bytes = std::fs::read(file_path)
431            .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
432
433        let form = multipart::Form::new().part(
434            "document",
435            multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
436        );
437
438        let url = format!("{}/api/documents/post_document/", self.base_url);
439
440        let resp = self
441            .client
442            .post(&url)
443            .multipart(form)
444            .send()
445            .await
446            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
447
448        let status = resp.status();
449        if !resp.status().is_success() {
450            return Err(Error::Response {
451                status_code: status.as_u16(),
452                body: resp.text().await.unwrap_or_default(),
453            });
454        }
455
456        let task_id: String = resp
457            .json()
458            .await
459            .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
460        Ok(TaskId(task_id))
461    }
462
463    #[inline]
464    #[must_use]
465    pub fn tags(&self) -> &HashMap<TagId, Tag> {
466        &self.tags
467    }
468
469    #[must_use]
470    pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
471        self.tags.values().find(|tag| tag.name == name)
472    }
473
474    #[inline]
475    #[must_use]
476    pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
477        &self.document_types
478    }
479
480    #[must_use]
481    pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
482        self.document_types.values().find(|dt| dt.name == name)
483    }
484
485    #[inline]
486    #[must_use]
487    pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
488        &self.correspondents
489    }
490
491    #[inline]
492    #[must_use]
493    pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
494        &self.custom_fields
495    }
496
497    #[must_use]
498    pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
499        self.custom_fields.values().find(|field| field.name == name)
500    }
501
502    #[inline]
503    #[must_use]
504    pub fn users(&self) -> &HashMap<UserId, User> {
505        &self.users
506    }
507}