1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
4
5use enum_iterator::Sequence;
6use reqwest::{
7 Method, StatusCode,
8 header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
9 multipart,
10};
11use serde::{Deserialize, de::DeserializeOwned};
12use tracing::{debug, trace};
13
14use crate::{
15 Error, Group, Result, SavedView, User,
16 document::{Document, DocumentData},
17 document_query::DocumentQueryBuilder,
18 dto::Item,
19 id::{
20 CorrespondentId, CustomFieldId, DocumentId, DocumentTypeId, GroupId, StoragePathId, TagId,
21 TaskId, UserId,
22 },
23 metadata::{
24 correspondent::Correspondent, custom_field::CustomField, document_type::DocumentType,
25 storage_path::StoragePath, tag::Tag,
26 },
27 task::Task,
28 util,
29 workflow::Workflow,
30};
31
32#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Sequence)]
37#[non_exhaustive]
38pub enum RefreshMetaData {
39 Tags,
40 CustomFields,
41 Correspondents,
42 DocumentTypes,
43 Groups,
44 Users,
45 StoragePaths,
46}
47
48#[derive(Debug, Clone)]
50pub struct PaperlessClient {
51 pub request_full_permissions: bool,
53
54 pub request_full_content: bool,
56
57 pub(crate) base_url: Arc<str>,
58
59 client: reqwest::Client,
60 cached_data: Arc<CachedData>,
61}
62
63#[derive(Debug, Clone)]
64struct CachedData {
65 correspondents: HashMap<CorrespondentId, Correspondent>,
66 custom_fields: HashMap<CustomFieldId, CustomField>,
67 document_types: HashMap<DocumentTypeId, DocumentType>,
68 groups: HashMap<GroupId, Group>,
69 storage_paths: HashMap<StoragePathId, StoragePath>,
70 tags: HashMap<TagId, Tag>,
71 users: HashMap<UserId, User>,
72}
73
74#[derive(Debug, Deserialize)]
75struct PaginatedResponse<T> {
76 results: Vec<T>,
77 next: Option<String>,
78}
79
80impl PaperlessClient {
81 pub fn new(
89 base_url: &str,
90 token: &str,
91 headers: Option<&HashMap<String, String>>,
92 ) -> std::result::Result<Self, String> {
93 Self::new_with_client(
94 base_url,
95 token,
96 headers,
97 reqwest::Client::builder().zstd(true),
98 )
99 }
100
101 pub fn new_with_client(
113 base_url: &str,
114 token: &str,
115 headers: Option<&HashMap<String, String>>,
116 client_builder: reqwest::ClientBuilder,
117 ) -> std::result::Result<Self, String> {
118 let mut headers_map = HeaderMap::new();
119
120 if let Some(headers) = headers {
122 for (key, value) in headers {
123 headers_map.insert(
124 HeaderName::from_str(key).map_err(|err| err.to_string())?,
125 value
126 .parse()
127 .map_err(|err: InvalidHeaderValue| err.to_string())?,
128 );
129 }
130 }
131
132 headers_map.insert(
134 HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
135 format!("Token {token}")
136 .parse()
137 .map_err(|err: InvalidHeaderValue| err.to_string())?,
138 );
139
140 Ok(Self {
141 request_full_permissions: false,
142 request_full_content: false,
143 base_url: base_url.into(),
144 client: client_builder
145 .default_headers(headers_map)
146 .build()
147 .map_err(|err| err.to_string())?,
148 cached_data: Arc::new(CachedData {
149 custom_fields: HashMap::new(),
150 correspondents: HashMap::new(),
151 document_types: HashMap::new(),
152 groups: HashMap::new(),
153 storage_paths: HashMap::new(),
154 tags: HashMap::new(),
155 users: HashMap::new(),
156 }),
157 })
158 }
159
160 #[must_use]
165 pub fn with_full_permissions(mut self, req: bool) -> Self {
166 self.request_full_permissions = req;
167 self
168 }
169
170 #[must_use]
171 pub fn with_full_content(mut self, full_content: bool) -> Self {
172 self.request_full_content = full_content;
173 self
174 }
175
176 async fn load_items<T: Item + DeserializeOwned>(&self) -> Result<HashMap<T::Id, T>> {
178 debug!("Loading {}", T::endpoint());
179 let endpoint = format!("/api/{}/", T::endpoint());
180
181 let items: Vec<T> = self
182 .fetch_all_pages(&endpoint, self.default_query_params().as_deref())
183 .await?;
184 Ok(items.into_iter().map(|item| (item.id(), item)).collect())
185 }
186
187 fn default_query_params(&self) -> Option<Vec<(&'static str, &'static str)>> {
188 let mut params = Vec::with_capacity(2);
189
190 if self.request_full_permissions {
191 params.push((crate::document_query::QUERY_PARAM_FULL_PERMISSIONS, "true"));
192 }
193 if !self.request_full_content {
194 params.push((crate::document_query::QUERY_PARAM_TRUNCATE_CONTENT, "true"));
195 }
196
197 if params.is_empty() {
198 None
199 } else {
200 Some(params)
201 }
202 }
203
204 pub async fn refresh_all(&mut self) -> Result<()> {
212 self.refresh(enum_iterator::all::<RefreshMetaData>()).await
213 }
214
215 pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshMetaData>) -> Result<()> {
224 #[rustfmt::skip]
225 async fn inner(
226 client: &mut PaperlessClient,
227 data: &mut dyn Iterator<Item = RefreshMetaData>,
228 ) -> Result<()> {
229 let selected: std::collections::HashSet<_> = data.into_iter().collect();
230
231 if selected.is_empty() {
232 return Ok(());
233 }
234
235 let (tags, custom_fields, correspondents, document_types, groups, users, storage_paths) =
236 futures_util::try_join!(
237 async {
238 if selected.contains(&RefreshMetaData::Tags) {
239 Ok(Some(client.load_items::<Tag>().await?))
240 } else {
241 Ok::<Option<_>, Error>(None)
242 }
243 },
244 async {
245 if selected.contains(&RefreshMetaData::CustomFields) {
246 Ok(Some(client.load_items::<CustomField>().await?))
247 } else {
248 Ok(None)
249 }
250 },
251 async {
252 if selected.contains(&RefreshMetaData::Correspondents) {
253 Ok(Some(client.load_items::<Correspondent>().await?))
254 } else {
255 Ok(None)
256 }
257 },
258 async {
259 if selected.contains(&RefreshMetaData::DocumentTypes) {
260 Ok(Some(client.load_items::<DocumentType>().await?))
261 } else {
262 Ok(None)
263 }
264 },
265 async {
266 if selected.contains(&RefreshMetaData::Groups) {
267 Ok(Some(client.load_items::<Group>().await?))
268 } else {
269 Ok(None)
270 }
271 },
272 async {
273 if selected.contains(&RefreshMetaData::Users) {
274 Ok(Some(client.load_items::<User>().await?))
275 } else {
276 Ok(None)
277 }
278 },
279 async {
280 if selected.contains(&RefreshMetaData::StoragePaths) {
281 Ok(Some(client.load_items::<StoragePath>().await?))
282 } else {
283 Ok(None)
284 }
285 },
286 )?;
287
288 let cached_data = Arc::make_mut(&mut client.cached_data);
289
290 if let Some(value) = custom_fields { cached_data.custom_fields = value; }
291 if let Some(value) = correspondents { cached_data.correspondents = value; }
292 if let Some(value) = document_types { cached_data.document_types = value; }
293 if let Some(value) = groups { cached_data.groups = value; }
294 if let Some(value) = storage_paths { cached_data.storage_paths = value; }
295 if let Some(value) = tags { cached_data.tags = value; }
296 if let Some(value) = users { cached_data.users = value; }
297
298 Ok(())
299 }
300
301 inner(self, &mut data.into_iter()).await
302 }
303
304 pub async fn query_documents(&self, query: DocumentQueryBuilder) -> Result<Vec<Document>> {
306 let full_content = query.full_content;
307 let query_params = query.build();
308 let query_vec: Vec<_> = query_params
309 .query
310 .iter()
311 .map(|(k, v)| (*k, v.as_str()))
312 .collect();
313 let query_slice = query_vec.as_slice();
314
315 let documents: Vec<_> = self
316 .fetch_all_pages::<DocumentData>("/api/documents/", Some(query_slice))
317 .await?
318 .into_iter()
319 .map(|data| Document::new(data, Arc::new(self.clone()), !full_content))
320 .collect();
321
322 Ok(documents)
323 }
324
325 pub fn get_documents_by_tags(
327 &self,
328 tag_ids: &[TagId],
329 ) -> impl Future<Output = Result<Vec<Document>>> {
330 let query = DocumentQueryBuilder::default()
331 .full_content(self.request_full_content)
332 .full_permissions(self.request_full_permissions)
333 .tags_id_in(tag_ids.to_vec());
334
335 self.query_documents(query)
336 }
337
338 pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
339 self.request_json(
340 Method::GET,
341 &format!("/api/documents/{}/", id.0),
342 None,
343 self.default_query_params().as_deref(),
344 )
345 .await
346 }
347
348 pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
350 Ok(Document::new(
351 self.get_document_data_by_id(id).await?,
352 Arc::new(self.clone()),
353 false,
354 ))
355 }
356
357 pub(crate) async fn request_json<T: serde::de::DeserializeOwned>(
359 &self,
360 method: Method,
361 endpoint: &str,
362 body: Option<&serde_json::Value>,
363 query_params: Option<&[(&str, &str)]>,
364 ) -> Result<T> {
365 let resp = self.request(method, endpoint, body, query_params).await?;
366
367 if tracing::enabled!(tracing::Level::TRACE) {
368 let response_text = resp.text().await.unwrap_or_default();
370 trace!(body = %response_text, "Response");
371
372 Ok(serde_json::from_str(&response_text)
373 .map_err(|e| Error::InvalidJson(format!("Failed to parse response body: {e:?}")))?)
374 } else {
375 Ok(resp
376 .json()
377 .await
378 .map_err(|e| Error::InvalidJson(format!("Failed to parse response body: {e:?}")))?)
379 }
380 }
381
382 pub(crate) async fn request(
384 &self,
385 method: Method,
386 endpoint: &str,
387 body: Option<&serde_json::Value>,
388 query_params: Option<&[(&str, &str)]>,
389 ) -> Result<reqwest::Response> {
390 let mut req = self
391 .client
392 .request(method, format!("{}{endpoint}", self.base_url))
393 .header(ACCEPT, "application/json");
394
395 if let Some(params) = query_params {
396 req = req.query(params);
397 }
398
399 if let Some(json_body) = body {
401 req = req.json(json_body);
402 }
403
404 let req = req.build().map_err(|e| Error::Request(e.into()))?;
405 debug!(
406 method = ?req.method(),
407 url = ?req.url(),
408 body = ?req.body(),
409 "Sending request to Paperless API");
410
411 let resp = self
412 .client
413 .execute(req)
414 .await
415 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
416
417 debug!(status = ?resp.status(), "Response");
419
420 if resp.status() == StatusCode::NOT_FOUND {
421 return Err(Error::NotFound);
422 }
423
424 if !resp.status().is_success() {
425 return Err(Error::Response {
426 status_code: resp.status().as_u16(),
427 body: resp.text().await.unwrap_or_default(),
428 });
429 }
430
431 Ok(resp)
432 }
433
434 pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
435 &self,
436 endpoint: &str,
437 query_params: Option<&[(&str, &str)]>,
438 ) -> Result<Vec<T>> {
439 let mut results = vec![];
440 let mut all_query_params = self.default_query_params().unwrap_or_default();
441 all_query_params.extend(query_params.unwrap_or_default());
442 let mut all_query_params = Some(all_query_params);
443
444 let mut current_url = Some(endpoint.to_string());
445
446 while let Some(url) = current_url {
447 debug!("Fetching page: {url}");
448
449 let page: PaginatedResponse<T> = self
450 .request_json(Method::GET, &url, None, all_query_params.as_deref())
451 .await?;
452
453 results.extend(page.results);
454
455 current_url = page.next.and_then(|next_url| {
456 next_url
458 .strip_prefix(&*self.base_url)
459 .unwrap_or(&next_url)
460 .to_string()
461 .into()
462 });
463 all_query_params = None;
464 }
465
466 Ok(results)
467 }
468
469 pub async fn get_task_status(
471 &self,
472 task_id: Option<&TaskId>,
473 task_name: Option<&str>,
474 acknowledged: Option<bool>,
475 ) -> Result<Vec<Task>> {
476 let mut query = Vec::new();
477
478 if let Some(id) = task_id {
479 query.push(("task_id", id.to_string()));
480 }
481
482 if let Some(name) = task_name {
483 query.push(("task_name", name.to_string()));
484 }
485
486 if let Some(ack) = acknowledged {
487 query.push(("acknowledged", ack.to_string()));
488 }
489
490 let resp = self
491 .request(
492 Method::GET,
493 &format!(
494 "/api/tasks/?{}",
495 serde_urlencoded::to_string(&query)
496 .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
497 ),
498 None::<&serde_json::Value>,
499 None,
500 )
501 .await?;
502
503 let body = resp
504 .text()
505 .await
506 .map_err(|e| Error::Other(format!("Failed to read response body: {e:?}")))?;
507
508 trace!("get_task_status response: {:?}", body);
509
510 let tasks: Vec<Task> = match serde_json::from_str(&body) {
511 Ok(t) => t,
512 Err(e) => {
513 return Err(Error::InvalidJson(format!(
514 "Failed to parse response body: {e:?}"
515 )));
516 }
517 };
518
519 if tasks.is_empty() {
520 return Err(Error::NotFound);
521 }
522
523 Ok(tasks)
524 }
525
526 pub fn get_workflows(&self) -> impl Future<Output = Result<Vec<Workflow>>> {
527 self.fetch_all_pages("/api/workflows/", None)
528 }
529
530 pub fn get_saved_views(&self) -> impl Future<Output = Result<Vec<SavedView>>> {
531 self.fetch_all_pages("/api/saved_views/", None)
532 }
533
534 pub fn get_statistics(&self) -> impl Future<Output = Result<util::Statistics>> {
535 self.request_json(Method::GET, "/api/statistics/", None, None)
536 }
537
538 pub fn get_status(&self) -> impl Future<Output = Result<util::Statistics>> {
539 self.request_json(Method::GET, "/api/status/", None, None)
540 }
541
542 pub async fn create<T: Item>(&self, new_item: T::CreateDto) -> Result<T::BaseType> {
548 let url = format!("/api/{}/", T::endpoint());
549 self.request_json(
550 Method::POST,
551 &url,
552 Some(&serde_json::to_value(&new_item).map_err(|e| Error::Other(e.to_string()))?),
553 None,
554 )
555 .await
556 }
557
558 pub async fn update<T: Item>(&self, id: T::Id, update: T::UpdateDto) -> Result<()> {
562 let url = format!("/api/{}/{}/", T::endpoint(), id);
563 self.request_json(
564 Method::PATCH,
565 &url,
566 Some(&serde_json::to_value(&update).map_err(|e| Error::Other(e.to_string()))?),
567 None,
568 )
569 .await
570 }
571
572 pub async fn delete<T: Item>(&self, id: T::Id) -> Result<()> {
574 let url = format!("/api/{}/{}/", T::endpoint(), id);
575 self.request(Method::DELETE, &url, None, None).await?;
576 Ok(())
577 }
578
579 pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
583 let stream = tokio::fs::File::open(file_path)
584 .await
585 .map_err(|e| Error::Other(format!("Failed to open file: {e}")))?;
586
587 let form = multipart::Form::new().part(
588 "document",
589 multipart::Part::stream(stream).file_name(filename.to_string()),
590 );
591
592 let url = format!("{}/api/documents/post_document/", self.base_url);
593
594 let resp = self
595 .client
596 .post(&url)
597 .multipart(form)
598 .send()
599 .await
600 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
601
602 let status = resp.status();
603 if !resp.status().is_success() {
604 return Err(Error::Response {
605 status_code: status.as_u16(),
606 body: resp.text().await.unwrap_or_default(),
607 });
608 }
609
610 let task_id: String = resp
611 .json()
612 .await
613 .map_err(|e| Error::Other(format!("Failed to parse task ID: {e:?}")))?;
614 Ok(TaskId(task_id))
615 }
616
617 #[inline]
619 #[must_use]
620 pub fn tags(&self) -> &HashMap<TagId, Tag> {
621 &self.cached_data.tags
622 }
623
624 #[inline]
626 #[must_use]
627 pub fn storage_paths(&self) -> &HashMap<StoragePathId, StoragePath> {
628 &self.cached_data.storage_paths
629 }
630
631 #[must_use]
632 pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
633 self.cached_data.tags.values().find(|tag| tag.name == name)
634 }
635
636 #[inline]
638 #[must_use]
639 pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
640 &self.cached_data.document_types
641 }
642
643 #[must_use]
644 pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
645 self.cached_data
646 .document_types
647 .values()
648 .find(|dt| dt.name == name)
649 }
650
651 #[inline]
653 #[must_use]
654 pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
655 &self.cached_data.correspondents
656 }
657
658 #[inline]
660 #[must_use]
661 pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
662 &self.cached_data.custom_fields
663 }
664
665 #[must_use]
666 pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
667 self.cached_data
668 .custom_fields
669 .values()
670 .find(|field| field.name == name)
671 }
672
673 #[inline]
675 #[must_use]
676 pub fn users(&self) -> &HashMap<UserId, User> {
677 &self.cached_data.users
678 }
679
680 #[inline]
682 #[must_use]
683 pub fn groups(&self) -> &HashMap<GroupId, Group> {
684 &self.cached_data.groups
685 }
686}