Skip to main content

paperless_api/
client.rs

1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use reqwest::{
4    Method, StatusCode,
5    header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
6    multipart,
7};
8use serde::Deserialize;
9use tracing::{debug, trace};
10
11use crate::{
12    Error, Result,
13    correspondent::{Correspondent, CorrespondentId},
14    custom_field::{CustomField, CustomFieldId},
15    document::{Document, DocumentId},
16    document_type::{DocumentType, DocumentTypeId},
17    tag::{Tag, TagId},
18    task::{Task, TaskId},
19};
20
/// Selects which cached metadata to refresh.
///
/// Each variant names one of the in-memory caches held by [`PaperlessClient`];
/// pass one or more of them to [`PaperlessClient::refresh`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RefreshData {
    /// Reload the tag cache.
    Tags,
    /// Reload the custom field cache.
    CustomFields,
    /// Reload the correspondent cache.
    Correspondents,
    /// Reload the document type cache.
    DocumentTypes,
}
29
/// Client to interact with Paperless.
///
/// Besides issuing HTTP requests, the client keeps in-memory maps of tags,
/// custom fields, correspondents and document types. These maps start out
/// empty in [`PaperlessClient::new`] and are replaced wholesale by
/// [`PaperlessClient::refresh`].
#[derive(Debug, Clone)]
pub struct PaperlessClient {
    /// HTTP client carrying the default headers assembled in `new`
    /// (caller-supplied headers plus the `Authorization` token header).
    client: reqwest::Client,
    /// Prefix prepended verbatim to every endpoint path when building a URL.
    base_url: String,

    /// Cached correspondents, keyed by ID.
    correspondents: HashMap<CorrespondentId, Correspondent>,
    /// Cached document types, keyed by ID.
    document_types: HashMap<DocumentTypeId, DocumentType>,
    /// Cached tags, keyed by ID.
    tags: HashMap<TagId, Tag>,
    /// Cached custom fields, keyed by ID.
    custom_fields: HashMap<CustomFieldId, CustomField>,
}
41
/// One page of a paginated list response, as consumed by `fetch_all_pages`.
#[derive(Debug, Deserialize)]
struct PaginatedResponse<T> {
    /// Items contained in this page.
    results: Vec<T>,
    /// Link to the following page; `None` ends the pagination loop.
    next: Option<String>,
}
47
48impl PaperlessClient {
49    /// Create a new Paperless client.
50    pub fn new(
51        base_url: &str,
52        token: &str,
53        headers: Option<&HashMap<String, String>>,
54    ) -> std::result::Result<Self, String> {
55        let mut headers_map = HeaderMap::new();
56
57        // Add additional headers if provided
58        if let Some(headers) = headers {
59            for (key, value) in headers {
60                headers_map.insert(
61                    HeaderName::from_str(key).map_err(|err| err.to_string())?,
62                    value
63                        .parse()
64                        .map_err(|err: InvalidHeaderValue| err.to_string())?,
65                );
66            }
67        }
68
69        // Add the Paperless token header
70        headers_map.insert(
71            HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
72            format!("Token {token}")
73                .parse()
74                .map_err(|err: InvalidHeaderValue| err.to_string())?,
75        );
76
77        Ok(Self {
78            base_url: base_url.to_string(),
79            client: reqwest::Client::builder()
80                .default_headers(headers_map)
81                .build()
82                .map_err(|err| err.to_string())?,
83            tags: HashMap::new(),
84            custom_fields: HashMap::new(),
85            correspondents: HashMap::new(),
86            document_types: HashMap::new(),
87        })
88    }
89
90    async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
91        debug!("loading tags");
92        let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
93        Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
94    }
95
96    async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
97        debug!("loading custom fields");
98        let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
99        Ok(custom_fields
100            .into_iter()
101            .map(|custom_field| (custom_field.id, custom_field))
102            .collect())
103    }
104
105    async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
106        debug!("loading correspondents");
107        let correspondents: Vec<Correspondent> =
108            self.fetch_all_pages("/api/correspondents/").await?;
109        Ok(correspondents
110            .into_iter()
111            .map(|correspondent| (correspondent.id, correspondent))
112            .collect())
113    }
114
115    async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
116        debug!("loading document types");
117        let document_types: Vec<DocumentType> =
118            self.fetch_all_pages("/api/document_types/").await?;
119        Ok(document_types
120            .into_iter()
121            .map(|document_type| (document_type.id, document_type))
122            .collect())
123    }
124
125    /// Refresh selected cached metadata concurrently.
126    pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
127        let mut refresh_tags = false;
128        let mut refresh_custom_fields = false;
129        let mut refresh_correspondents = false;
130        let mut refresh_document_types = false;
131
132        for item in data {
133            match item {
134                RefreshData::Tags => refresh_tags = true,
135                RefreshData::CustomFields => refresh_custom_fields = true,
136                RefreshData::Correspondents => refresh_correspondents = true,
137                RefreshData::DocumentTypes => refresh_document_types = true,
138            }
139        }
140
141        let (tags, custom_fields, correspondents, document_types) = futures_util::try_join!(
142            async {
143                if refresh_tags {
144                    Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
145                } else {
146                    Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
147                }
148            },
149            async {
150                if refresh_custom_fields {
151                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
152                        self.load_custom_fields().await?,
153                    ))
154                } else {
155                    Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
156                }
157            },
158            async {
159                if refresh_correspondents {
160                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
161                        self.load_correspondents().await?,
162                    ))
163                } else {
164                    Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
165                }
166            },
167            async {
168                if refresh_document_types {
169                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
170                        self.load_document_types().await?,
171                    ))
172                } else {
173                    Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
174                }
175            },
176        )?;
177
178        if let Some(tags) = tags {
179            self.tags = tags;
180        }
181
182        if let Some(custom_fields) = custom_fields {
183            self.custom_fields = custom_fields;
184        }
185
186        if let Some(correspondents) = correspondents {
187            self.correspondents = correspondents;
188        }
189
190        if let Some(document_types) = document_types {
191            self.document_types = document_types;
192        }
193
194        Ok(())
195    }
196
197    /// List all tags.
198    #[inline]
199    pub async fn refresh_tags(&mut self) -> Result<()> {
200        self.refresh([RefreshData::Tags]).await
201    }
202
203    /// List all custom fields.
204    #[inline]
205    pub async fn refresh_custom_fields(&mut self) -> Result<()> {
206        self.refresh([RefreshData::CustomFields]).await
207    }
208
209    /// List all correspondents.
210    #[inline]
211    pub async fn refresh_correspondents(&mut self) -> Result<()> {
212        self.refresh([RefreshData::Correspondents]).await
213    }
214
215    /// List all document types.
216    #[inline]
217    pub async fn refresh_document_types(&mut self) -> Result<()> {
218        self.refresh([RefreshData::DocumentTypes]).await
219    }
220
221    /// Get all documents with any of the given tags.
222    pub async fn get_documents_by_tags(
223        &self,
224        tag_ids: &[TagId],
225        truncate_content: bool,
226    ) -> Result<Vec<Document>> {
227        let tag_id_str = tag_ids
228            .iter()
229            .map(|tag_id| tag_id.0.to_string())
230            .collect::<Vec<_>>()
231            .join(",");
232        let mut documents: Vec<Document> = self
233            .fetch_all_pages(&format!(
234                "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
235            ))
236            .await?;
237
238        for doc in &mut documents {
239            doc.client = Some(Arc::new(self.clone()));
240            doc.content_is_truncated = truncate_content;
241        }
242
243        Ok(documents)
244    }
245
246    /// Get a document by its ID.
247    pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
248        let resp = self
249            .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
250            .await?;
251
252        let mut document: Document = resp
253            .json()
254            .await
255            .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
256
257        document.client = Some(Arc::new(self.clone()));
258        Ok(document)
259    }
260
261    pub(crate) async fn request(
262        &self,
263        method: Method,
264        endpoint: &str,
265        body: Option<&serde_json::Value>,
266    ) -> Result<reqwest::Response> {
267        let mut req = self
268            .client
269            .request(method, format!("{}{endpoint}", self.base_url))
270            .header(ACCEPT, "application/json");
271
272        if let Some(json_body) = body {
273            req = req.json(json_body);
274        }
275
276        let resp = req
277            .send()
278            .await
279            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
280
281        if resp.status() == StatusCode::NOT_FOUND {
282            return Err(Error::NotFound);
283        }
284
285        if !resp.status().is_success() {
286            return Err(Error::Response {
287                status_code: resp.status().as_u16(),
288                body: resp.text().await.unwrap_or_default(),
289            });
290        }
291
292        Ok(resp)
293    }
294
295    pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
296        &self,
297        endpoint: &str,
298    ) -> Result<Vec<T>> {
299        let mut results = Vec::new();
300        let mut current_url = Some(endpoint.to_string());
301
302        while let Some(url) = current_url {
303            let resp = self.request(Method::GET, &url, None).await?;
304
305            let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
306                Error::Other(format!(
307                    "Failed to parse paginated response for {endpoint}: {e}"
308                ))
309            })?;
310
311            results.extend(page.results);
312
313            current_url = page.next.and_then(|next_url| {
314                // Extract just the path from the full URL
315                next_url
316                    .trim_start_matches(&self.base_url)
317                    .to_string()
318                    .into()
319            });
320        }
321
322        Ok(results)
323    }
324
325    /// Get all tasks with optional filtering by ID, name, or acknowledged status.
326    pub async fn get_task_status(
327        &self,
328        task_id: Option<&TaskId>,
329        task_name: Option<&str>,
330        acknowledged: Option<bool>,
331    ) -> Result<Vec<Task>> {
332        let mut query = Vec::new();
333
334        if let Some(id) = task_id {
335            query.push(("task_id", id.to_string()));
336        }
337
338        if let Some(name) = task_name {
339            query.push(("task_name", name.to_string()));
340        }
341
342        if let Some(ack) = acknowledged {
343            query.push(("acknowledged", ack.to_string()));
344        }
345
346        let resp = self
347            .request(
348                Method::GET,
349                &format!(
350                    "/api/tasks/?{}",
351                    serde_urlencoded::to_string(&query)
352                        .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
353                ),
354                None,
355            )
356            .await?;
357
358        trace!("get_task_status response: {:?}", resp);
359
360        let body = resp
361            .text()
362            .await
363            .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
364
365        let tasks: Vec<Task> = match serde_json::from_str(&body) {
366            Ok(t) => t,
367            Err(e) => {
368                return Err(Error::InvalidJson(format!(
369                    "Failed to parse response body: {e}"
370                )));
371            }
372        };
373
374        if tasks.is_empty() {
375            return Err(Error::NotFound);
376        }
377
378        Ok(tasks)
379    }
380
381    /// Upload a document to Paperless.
382    ///
383    /// Returns the task ID on success.
384    pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
385        let file_bytes = std::fs::read(file_path)
386            .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
387
388        let form = multipart::Form::new().part(
389            "document",
390            multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
391        );
392
393        let url = format!("{}/api/documents/post_document/", self.base_url);
394
395        let resp = self
396            .client
397            .post(&url)
398            .multipart(form)
399            .send()
400            .await
401            .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
402
403        let status = resp.status();
404        if !resp.status().is_success() {
405            return Err(Error::Response {
406                status_code: status.as_u16(),
407                body: resp.text().await.unwrap_or_default(),
408            });
409        }
410
411        let task_id: String = resp
412            .json()
413            .await
414            .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
415        Ok(TaskId(task_id))
416    }
417
    /// Returns the cached tags, keyed by ID.
    ///
    /// The map is empty until a refresh that includes [`RefreshData::Tags`]
    /// has succeeded.
    pub fn tags(&self) -> &HashMap<TagId, Tag> {
        &self.tags
    }
421
422    pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
423        self.tags.values().find(|tag| tag.name == name)
424    }
425
    /// Returns the cached document types, keyed by ID.
    ///
    /// The map is empty until a refresh that includes
    /// [`RefreshData::DocumentTypes`] has succeeded.
    pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
        &self.document_types
    }
429
    /// Returns the cached correspondents, keyed by ID.
    ///
    /// The map is empty until a refresh that includes
    /// [`RefreshData::Correspondents`] has succeeded.
    pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
        &self.correspondents
    }
433
    /// Returns the cached custom fields, keyed by ID.
    ///
    /// The map is empty until a refresh that includes
    /// [`RefreshData::CustomFields`] has succeeded.
    pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
        &self.custom_fields
    }
437
438    pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
439        self.custom_fields.values().find(|field| field.name == name)
440    }
441}