Skip to main content

paperless_api/
client.rs

1//! The central client for interacting with Paperless.
2
3use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
4
5use enum_iterator::Sequence;
6use reqwest::{
7    Method, StatusCode,
8    header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
9    multipart,
10};
11use serde::{Deserialize, de::DeserializeOwned};
12use tracing::{debug, trace};
13
14use crate::{
15    Error, Group, Result, SavedView, User,
16    document::{Document, DocumentData},
17    document_query::DocumentQueryBuilder,
18    dto::Item,
19    id::{
20        CorrespondentId, CustomFieldId, DocumentId, DocumentTypeId, GroupId, StoragePathId, TagId,
21        TaskId, UserId,
22    },
23    metadata::{
24        correspondent::Correspondent, custom_field::CustomField, document_type::DocumentType,
25        storage_path::StoragePath, tag::Tag,
26    },
27    task::Task,
28    util,
29    workflow::Workflow,
30};
31
/// Selects which cached metadata to refresh.
///
/// Cached data is data which is rarely updated;
/// refreshing it is normally not necessary on every request.
///
/// Each variant corresponds to one cache held by [`PaperlessClient`] and is
/// reloaded from the API by [`PaperlessClient::refresh`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Sequence)]
#[non_exhaustive]
pub enum RefreshMetaData {
    /// Reload the tag cache.
    Tags,
    /// Reload the custom field cache.
    CustomFields,
    /// Reload the correspondent cache.
    Correspondents,
    /// Reload the document type cache.
    DocumentTypes,
    /// Reload the group cache.
    Groups,
    /// Reload the user cache.
    Users,
    /// Reload the storage path cache.
    StoragePaths,
}
47
/// Client to interact with Paperless.
///
/// Cloning is cheap: the metadata cache is held behind an [`Arc`] and is
/// shared between clones until one of them refreshes its cache.
#[derive(Debug, Clone)]
pub struct PaperlessClient {
    /// Whether to request full permissions data for items.
    pub request_full_permissions: bool,

    /// Whether to always request the full document content.
    pub request_full_content: bool,

    /// Base URL of the Paperless instance; endpoint paths are appended to it verbatim.
    pub(crate) base_url: Arc<str>,

    /// HTTP client configured with the default headers (including the auth token).
    client: reqwest::Client,
    /// Metadata caches; empty until a refresh has been performed.
    cached_data: Arc<CachedData>,
}
62
/// Metadata caches of a [`PaperlessClient`], each keyed by the item's ID.
///
/// All maps start out empty and are replaced wholesale when the matching
/// metadata kind is refreshed.
#[derive(Debug, Clone)]
struct CachedData {
    correspondents: HashMap<CorrespondentId, Correspondent>,
    custom_fields: HashMap<CustomFieldId, CustomField>,
    document_types: HashMap<DocumentTypeId, DocumentType>,
    groups: HashMap<GroupId, Group>,
    storage_paths: HashMap<StoragePathId, StoragePath>,
    tags: HashMap<TagId, Tag>,
    users: HashMap<UserId, User>,
}
73
/// A single page of a paginated list response.
#[derive(Debug, Deserialize)]
struct PaginatedResponse<T> {
    /// The items contained in this page.
    results: Vec<T>,
    /// URL of the following page; `None` once the last page has been reached.
    next: Option<String>,
}
79
80impl PaperlessClient {
81    /// Create a new Paperless client.
82    ///
83    /// # Arguments
84    ///
85    /// * `base_url` - The base URL of the Paperless API.
86    /// * `token` - The authentication token for the Paperless API.
87    /// * `headers` - Optional additional headers to include in requests.
88    pub fn new(
89        base_url: &str,
90        token: &str,
91        headers: Option<&HashMap<String, String>>,
92    ) -> std::result::Result<Self, String> {
93        Self::new_with_client(
94            base_url,
95            token,
96            headers,
97            reqwest::Client::builder().zstd(true),
98        )
99    }
100
101    /// Create a new Paperless client.
102    ///
103    /// Provide a [`reqwest::ClientBuilder`] to customize the HTTP client,
104    /// such as adding custom headers or disabling compression.
105    ///
106    /// # Arguments
107    ///
108    /// * `base_url` - The base URL of the Paperless API.
109    /// * `token` - The authentication token for the Paperless API.
110    /// * `headers` - Optional additional headers to include in requests.
111    /// * `client_builder` - [`reqwest::ClientBuilder`] to use for creating the HTTP client.
112    pub fn new_with_client(
113        base_url: &str,
114        token: &str,
115        headers: Option<&HashMap<String, String>>,
116        client_builder: reqwest::ClientBuilder,
117    ) -> std::result::Result<Self, String> {
118        let mut headers_map = HeaderMap::new();
119
120        // Add additional headers if provided
121        if let Some(headers) = headers {
122            for (key, value) in headers {
123                headers_map.insert(
124                    HeaderName::from_str(key).map_err(|err| err.to_string())?,
125                    value
126                        .parse()
127                        .map_err(|err: InvalidHeaderValue| err.to_string())?,
128                );
129            }
130        }
131
132        // Add the Paperless token header
133        headers_map.insert(
134            HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
135            format!("Token {token}")
136                .parse()
137                .map_err(|err: InvalidHeaderValue| err.to_string())?,
138        );
139
140        Ok(Self {
141            request_full_permissions: false,
142            request_full_content: false,
143            base_url: base_url.into(),
144            client: client_builder
145                .default_headers(headers_map)
146                .build()
147                .map_err(|err| err.to_string())?,
148            cached_data: Arc::new(CachedData {
149                custom_fields: HashMap::new(),
150                correspondents: HashMap::new(),
151                document_types: HashMap::new(),
152                groups: HashMap::new(),
153                storage_paths: HashMap::new(),
154                tags: HashMap::new(),
155                users: HashMap::new(),
156            }),
157        })
158    }
159
160    /// Sets whether to request full permissions data for items during refresh.
161    ///
162    /// If not enabled only simple permission data is loaded.
163    /// See [`ItemPermissions`](crate::metadata::permission::ItemPermissions) for more details.
164    #[must_use]
165    pub fn with_full_permissions(mut self, req: bool) -> Self {
166        self.request_full_permissions = req;
167        self
168    }
169
170    #[must_use]
171    pub fn with_full_content(mut self, full_content: bool) -> Self {
172        self.request_full_content = full_content;
173        self
174    }
175
176    /// Loads all items of the given type from the API.
177    async fn load_items<T: Item + DeserializeOwned>(&self) -> Result<HashMap<T::Id, T>> {
178        debug!("Loading {}", T::endpoint());
179        let endpoint = format!("/api/{}/", T::endpoint());
180
181        let items: Vec<T> = self
182            .fetch_all_pages(&endpoint, self.default_query_params().as_deref())
183            .await?;
184        Ok(items.into_iter().map(|item| (item.id(), item)).collect())
185    }
186
187    fn default_query_params(&self) -> Option<Vec<(&'static str, &'static str)>> {
188        let mut params = Vec::with_capacity(2);
189
190        if self.request_full_permissions {
191            params.push((crate::document_query::QUERY_PARAM_FULL_PERMISSIONS, "true"));
192        }
193        if !self.request_full_content {
194            params.push((crate::document_query::QUERY_PARAM_TRUNCATE_CONTENT, "true"));
195        }
196
197        if params.is_empty() {
198            None
199        } else {
200            Some(params)
201        }
202    }
203
204    /// Refresh and cache all metadata.
205    ///
206    /// Only updates the cache for this instance, cloned instances will not see the changes.
207    pub async fn refresh_all(&mut self) -> Result<()> {
208        self.refresh(enum_iterator::all::<RefreshMetaData>()).await
209    }
210
    /// Refresh and cache the selected metadata.
    ///
    /// Only updates the cache for this instance, cloned instances will not see the changes.
    ///
    /// The selected kinds are loaded concurrently. The cache is only written
    /// after all of them have loaded successfully, so a failed refresh leaves
    /// the existing cache untouched.
    ///
    /// # Arguments
    ///
    /// * `data` - The metadata to refresh. Duplicates are ignored; an empty
    ///   selection is a no-op.
    ///
    /// # Errors
    ///
    /// Returns the first error encountered while loading the selected metadata.
    pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshMetaData>) -> Result<()> {
        // Non-generic worker taking a trait object, so the body below is not
        // generic over the caller's iterator type.
        #[rustfmt::skip]
        async fn inner(
            client: &mut PaperlessClient,
            data: &mut dyn Iterator<Item = RefreshMetaData>,
        ) -> Result<()> {
            // De-duplicate the selection.
            let selected: std::collections::HashSet<_> = data.into_iter().collect();

            if selected.is_empty() {
                return Ok(());
            }

            // Each branch yields `Some(map)` when its kind was selected and `None` otherwise.
            let (tags, custom_fields, correspondents, document_types, groups, users, storage_paths) =
                futures_util::try_join!(
                    async {
                        if selected.contains(&RefreshMetaData::Tags) {
                            Ok(Some(client.load_items::<Tag>().await?))
                        } else {
                            // Explicit turbofish names the error type for inference.
                            Ok::<Option<_>, Error>(None)
                        }
                    },
                    async {
                        if selected.contains(&RefreshMetaData::CustomFields) {
                            Ok(Some(client.load_items::<CustomField>().await?))
                        } else {
                            Ok(None)
                        }
                    },
                    async {
                        if selected.contains(&RefreshMetaData::Correspondents) {
                            Ok(Some(client.load_items::<Correspondent>().await?))
                        } else {
                            Ok(None)
                        }
                    },
                    async {
                        if selected.contains(&RefreshMetaData::DocumentTypes) {
                            Ok(Some(client.load_items::<DocumentType>().await?))
                        } else {
                            Ok(None)
                        }
                    },
                    async {
                        if selected.contains(&RefreshMetaData::Groups) {
                            Ok(Some(client.load_items::<Group>().await?))
                        } else {
                            Ok(None)
                        }
                    },
                    async {
                        if selected.contains(&RefreshMetaData::Users) {
                            Ok(Some(client.load_items::<User>().await?))
                        } else {
                            Ok(None)
                        }
                    },
                    async {
                        if selected.contains(&RefreshMetaData::StoragePaths) {
                            Ok(Some(client.load_items::<StoragePath>().await?))
                        } else {
                            Ok(None)
                        }
                    },
                )?;

            // `Arc::make_mut` clones the cache if it is still shared with other
            // client clones, which is why those clones do not observe the update.
            let cached_data = Arc::make_mut(&mut client.cached_data);

            if let Some(value) = custom_fields { cached_data.custom_fields = value; }
            if let Some(value) = correspondents { cached_data.correspondents = value; }
            if let Some(value) = document_types { cached_data.document_types = value; }
            if let Some(value) = groups { cached_data.groups = value; }
            if let Some(value) = storage_paths { cached_data.storage_paths = value; }
            if let Some(value) = tags { cached_data.tags = value; }
            if let Some(value) = users { cached_data.users = value; }

            Ok(())
        }

        inner(self, &mut data.into_iter()).await
    }
299
300    /// Query documents using the given [`DocumentQueryBuilder`].
301    pub async fn query_documents(&self, query: DocumentQueryBuilder) -> Result<Vec<Document>> {
302        let full_content = query.full_content;
303        let query_params = query.build();
304        let query_vec: Vec<_> = query_params
305            .query
306            .iter()
307            .map(|(k, v)| (*k, v.as_str()))
308            .collect();
309        let query_slice = query_vec.as_slice();
310
311        let documents: Vec<_> = self
312            .fetch_all_pages::<DocumentData>("/api/documents/", Some(query_slice))
313            .await?
314            .into_iter()
315            .map(|data| Document::new(data, Arc::new(self.clone()), !full_content))
316            .collect();
317
318        Ok(documents)
319    }
320
321    /// Get all documents with any of the given tags.
322    pub fn get_documents_by_tags(
323        &self,
324        tag_ids: &[TagId],
325    ) -> impl Future<Output = Result<Vec<Document>>> {
326        let query = DocumentQueryBuilder::default()
327            .full_content(self.request_full_content)
328            .full_permissions(self.request_full_permissions)
329            .tags_id_in(tag_ids.to_vec());
330
331        self.query_documents(query)
332    }
333
334    pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
335        self.request_json(
336            Method::GET,
337            &format!("/api/documents/{}/", id.0),
338            None,
339            self.default_query_params().as_deref(),
340        )
341        .await
342    }
343
344    /// Get a document by its ID.
345    pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
346        Ok(Document::new(
347            self.get_document_data_by_id(id).await?,
348            Arc::new(self.clone()),
349            false,
350        ))
351    }
352
353    /// Make a request and parse the response as JSON.
354    pub(crate) async fn request_json<T: serde::de::DeserializeOwned>(
355        &self,
356        method: Method,
357        endpoint: &str,
358        body: Option<&serde_json::Value>,
359        query_params: Option<&[(&str, &str)]>,
360    ) -> Result<T> {
361        let resp = self.request(method, endpoint, body, query_params).await?;
362
363        if tracing::enabled!(tracing::Level::TRACE) {
364            // Only log the response body if trace logging is enabled to avoid unnecessary overhead
365            let response_text = resp.text().await.unwrap_or_default();
366            trace!(body = %response_text, "Response");
367
368            Ok(serde_json::from_str(&response_text)
369                .map_err(|e| Error::InvalidJson(format!("Failed to parse response body: {e:?}")))?)
370        } else {
371            Ok(resp
372                .json()
373                .await
374                .map_err(|e| Error::InvalidJson(format!("Failed to parse response body: {e:?}")))?)
375        }
376    }
377
378    /// Make a request and return the raw [`reqwest::Response`].
379    pub(crate) async fn request(
380        &self,
381        method: Method,
382        endpoint: &str,
383        body: Option<&serde_json::Value>,
384        query_params: Option<&[(&str, &str)]>,
385    ) -> Result<reqwest::Response> {
386        let mut req = self
387            .client
388            .request(method, format!("{}{endpoint}", self.base_url))
389            .header(ACCEPT, "application/json");
390
391        if let Some(params) = query_params {
392            req = req.query(params);
393        }
394
395        // Set payload body if provided
396        if let Some(json_body) = body {
397            req = req.json(json_body);
398        }
399
400        let req = req.build().map_err(|e| Error::Request(e.into()))?;
401
402        if tracing::enabled!(tracing::Level::TRACE)
403            && let Some(body) = req.body().map(|b| b.as_bytes()).flatten()
404        {
405            trace!(
406                method = ?req.method(),
407                url = ?req.url(),
408                body = %String::from_utf8_lossy(body),
409                "Sending request to Paperless API");
410        } else {
411            debug!(
412                method = ?req.method(),
413                url = ?req.url(),
414                "Sending request to Paperless API");
415        }
416
417        let resp = self
418            .client
419            .execute(req)
420            .await
421            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
422
423        // Log the response body for debugging
424        debug!(status = ?resp.status(), "Response");
425
426        if resp.status() == StatusCode::NOT_FOUND {
427            return Err(Error::NotFound);
428        }
429
430        if !resp.status().is_success() {
431            return Err(Error::Response {
432                status_code: resp.status().as_u16(),
433                body: resp.text().await.unwrap_or_default(),
434            });
435        }
436
437        Ok(resp)
438    }
439
440    pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
441        &self,
442        endpoint: &str,
443        query_params: Option<&[(&str, &str)]>,
444    ) -> Result<Vec<T>> {
445        let mut results = vec![];
446        let mut all_query_params = self.default_query_params().unwrap_or_default();
447        all_query_params.extend(query_params.unwrap_or_default());
448        let mut all_query_params = Some(all_query_params);
449
450        let mut current_url = Some(endpoint.to_string());
451
452        while let Some(url) = current_url {
453            debug!("Fetching page: {url}");
454
455            let page: PaginatedResponse<T> = self
456                .request_json(Method::GET, &url, None, all_query_params.as_deref())
457                .await?;
458
459            results.extend(page.results);
460
461            current_url = page.next.and_then(|next_url| {
462                // Extract just the path from the full URL
463                next_url
464                    .strip_prefix(&*self.base_url)
465                    .unwrap_or(&next_url)
466                    .to_string()
467                    .into()
468            });
469            all_query_params = None;
470        }
471
472        Ok(results)
473    }
474
475    /// Get all tasks with optional filtering by ID, name, or acknowledged status.
476    pub async fn get_task_status(
477        &self,
478        task_id: Option<&TaskId>,
479        task_name: Option<&str>,
480        acknowledged: Option<bool>,
481    ) -> Result<Vec<Task>> {
482        let mut query = Vec::new();
483
484        if let Some(id) = task_id {
485            query.push(("task_id", id.to_string()));
486        }
487
488        if let Some(name) = task_name {
489            query.push(("task_name", name.to_string()));
490        }
491
492        if let Some(ack) = acknowledged {
493            query.push(("acknowledged", ack.to_string()));
494        }
495
496        let resp = self
497            .request(
498                Method::GET,
499                &format!(
500                    "/api/tasks/?{}",
501                    serde_urlencoded::to_string(&query)
502                        .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
503                ),
504                None::<&serde_json::Value>,
505                None,
506            )
507            .await?;
508
509        let body = resp
510            .text()
511            .await
512            .map_err(|e| Error::Other(format!("Failed to read response body: {e:?}")))?;
513
514        trace!("get_task_status response: {:?}", body);
515
516        let tasks: Vec<Task> = match serde_json::from_str(&body) {
517            Ok(t) => t,
518            Err(e) => {
519                return Err(Error::InvalidJson(format!(
520                    "Failed to parse response body: {e:?}"
521                )));
522            }
523        };
524
525        if tasks.is_empty() {
526            return Err(Error::NotFound);
527        }
528
529        Ok(tasks)
530    }
531
532    /// Get all workflows.
533    pub fn get_workflows(&self) -> impl Future<Output = Result<Vec<Workflow>>> {
534        self.fetch_all_pages("/api/workflows/", None)
535    }
536
537    /// Get all saved views.
538    pub fn get_saved_views(&self) -> impl Future<Output = Result<Vec<SavedView>>> {
539        self.fetch_all_pages("/api/saved_views/", None)
540    }
541
542    /// Get server statistics.
543    pub fn get_statistics(&self) -> impl Future<Output = Result<util::Statistics>> {
544        self.request_json(Method::GET, "/api/statistics/", None, None)
545    }
546
547    /// Get server status.
548    pub fn get_status(&self) -> impl Future<Output = Result<util::ServerStatus>> {
549        self.request_json(Method::GET, "/api/status/", None, None)
550    }
551
552    /// Create a new item in Paperless.
553    ///
554    /// All structs which implement [`CreateDtoObject`](crate::dto::CreateDtoObject) can be used as `new_item`.
555    ///
556    /// Returns the created item.
557    pub async fn create<T: Item>(&self, new_item: &T::CreateDto) -> Result<T::BaseType> {
558        let url = format!("/api/{}/", T::endpoint());
559        self.request_json(
560            Method::POST,
561            &url,
562            Some(&serde_json::to_value(&new_item).map_err(|e| Error::Other(e.to_string()))?),
563            None,
564        )
565        .await
566    }
567
568    /// Updates an existing item in Paperless.
569    ///
570    /// All structs which implement [`UpdateDtoObject`](crate::dto::UpdateDtoObject) can be used as `item`.
571    pub async fn update<T: Item>(&self, id: T::Id, update: &T::UpdateDto) -> Result<T::BaseType> {
572        let url = format!("/api/{}/{}/", T::endpoint(), id);
573        self.request_json::<T::BaseType>(
574            Method::PATCH,
575            &url,
576            Some(&serde_json::to_value(&update).map_err(|e| Error::Other(e.to_string()))?),
577            None,
578        )
579        .await
580    }
581
582    /// Deletes an existing item in Paperless.
583    pub async fn delete<T: Item>(&self, id: T::Id) -> Result<()> {
584        let url = format!("/api/{}/{}/", T::endpoint(), id);
585        self.request(Method::DELETE, &url, None, None).await?;
586        Ok(())
587    }
588
589    /// Upload a document to Paperless.
590    ///
591    /// Returns the task ID on success.
592    pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
593        let stream = tokio::fs::File::open(file_path)
594            .await
595            .map_err(|e| Error::Other(format!("Failed to open file: {e}")))?;
596
597        let form = multipart::Form::new().part(
598            "document",
599            multipart::Part::stream(stream).file_name(filename.to_string()),
600        );
601
602        let url = format!("{}/api/documents/post_document/", self.base_url);
603
604        let resp = self
605            .client
606            .post(&url)
607            .multipart(form)
608            .send()
609            .await
610            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
611
612        let status = resp.status();
613        if !resp.status().is_success() {
614            return Err(Error::Response {
615                status_code: status.as_u16(),
616                body: resp.text().await.unwrap_or_default(),
617            });
618        }
619
620        let task_id: String = resp
621            .json()
622            .await
623            .map_err(|e| Error::Other(format!("Failed to parse task ID: {e:?}")))?;
624        Ok(TaskId(task_id))
625    }
626
627    /// Get the tags cache.
628    #[inline]
629    #[must_use]
630    pub fn tags(&self) -> &HashMap<TagId, Tag> {
631        &self.cached_data.tags
632    }
633
634    /// Get the storage paths cache.
635    #[inline]
636    #[must_use]
637    pub fn storage_paths(&self) -> &HashMap<StoragePathId, StoragePath> {
638        &self.cached_data.storage_paths
639    }
640
641    /// Find a tag by its name.
642    #[must_use]
643    pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
644        self.cached_data.tags.values().find(|tag| tag.name == name)
645    }
646
647    /// Get the document types cache.
648    #[inline]
649    #[must_use]
650    pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
651        &self.cached_data.document_types
652    }
653
654    /// Find a document type by its name.
655    #[must_use]
656    pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
657        self.cached_data
658            .document_types
659            .values()
660            .find(|dt| dt.name == name)
661    }
662
663    /// Get the correspondents cache.
664    #[inline]
665    #[must_use]
666    pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
667        &self.cached_data.correspondents
668    }
669
670    /// Get the custom fields cache.
671    #[inline]
672    #[must_use]
673    pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
674        &self.cached_data.custom_fields
675    }
676
677    /// Find a custom field by its name.
678    #[must_use]
679    pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
680        self.cached_data
681            .custom_fields
682            .values()
683            .find(|field| field.name == name)
684    }
685
686    /// Get the users cache.
687    #[inline]
688    #[must_use]
689    pub fn users(&self) -> &HashMap<UserId, User> {
690        &self.cached_data.users
691    }
692
693    /// Get the groups cache.
694    #[inline]
695    #[must_use]
696    pub fn groups(&self) -> &HashMap<GroupId, Group> {
697        &self.cached_data.groups
698    }
699}