1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
4
5use enum_iterator::Sequence;
6use reqwest::{
7 Method, StatusCode,
8 header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
9 multipart,
10};
11use serde::{Deserialize, Serialize, de::DeserializeOwned};
12use tracing::{debug, trace};
13
14use crate::{
15 Error, Group, Result, SavedView, User,
16 document::{Document, DocumentData},
17 document_query::DocumentQueryBuilder,
18 dto::{CreateDto, Item, UpdateDto},
19 id::{
20 CorrespondentId, CustomFieldId, DocumentId, DocumentTypeId, GroupId, ItemId, StoragePathId,
21 TagId, TaskId, UserId,
22 },
23 metadata::{
24 correspondent::Correspondent, custom_field::CustomField, document_type::DocumentType,
25 storage_path::StoragePath, tag::Tag,
26 },
27 task::Task,
28 util,
29 workflow::Workflow,
30};
31
32#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Sequence)]
37#[non_exhaustive]
38pub enum RefreshMetaData {
39 Tags,
40 CustomFields,
41 Correspondents,
42 DocumentTypes,
43 Groups,
44 Users,
45 StoragePaths,
46}
47
48#[derive(Debug, Clone)]
50pub struct PaperlessClient {
51 pub request_full_permissions: bool,
53
54 pub request_full_content: bool,
56
57 pub(crate) base_url: Arc<str>,
58
59 client: reqwest::Client,
60 cached_data: Arc<CachedData>,
61}
62
63#[derive(Debug, Clone)]
64struct CachedData {
65 correspondents: HashMap<CorrespondentId, Correspondent>,
66 custom_fields: HashMap<CustomFieldId, CustomField>,
67 document_types: HashMap<DocumentTypeId, DocumentType>,
68 groups: HashMap<GroupId, Group>,
69 storage_paths: HashMap<StoragePathId, StoragePath>,
70 tags: HashMap<TagId, Tag>,
71 users: HashMap<UserId, User>,
72}
73
74#[derive(Debug, Deserialize)]
75struct PaginatedResponse<T> {
76 results: Vec<T>,
77 next: Option<String>,
78}
79
80impl PaperlessClient {
81 pub fn new(
89 base_url: &str,
90 token: &str,
91 headers: Option<&HashMap<String, String>>,
92 ) -> std::result::Result<Self, String> {
93 Self::new_with_client(
94 base_url,
95 token,
96 headers,
97 reqwest::Client::builder().zstd(true),
98 )
99 }
100
101 pub fn new_with_client(
113 base_url: &str,
114 token: &str,
115 headers: Option<&HashMap<String, String>>,
116 client_builder: reqwest::ClientBuilder,
117 ) -> std::result::Result<Self, String> {
118 let mut headers_map = HeaderMap::new();
119
120 if let Some(headers) = headers {
122 for (key, value) in headers {
123 headers_map.insert(
124 HeaderName::from_str(key).map_err(|err| err.to_string())?,
125 value
126 .parse()
127 .map_err(|err: InvalidHeaderValue| err.to_string())?,
128 );
129 }
130 }
131
132 headers_map.insert(
134 HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
135 format!("Token {token}")
136 .parse()
137 .map_err(|err: InvalidHeaderValue| err.to_string())?,
138 );
139
140 Ok(Self {
141 request_full_permissions: false,
142 request_full_content: false,
143 base_url: base_url.into(),
144 client: client_builder
145 .default_headers(headers_map)
146 .build()
147 .map_err(|err| err.to_string())?,
148 cached_data: Arc::new(CachedData {
149 custom_fields: HashMap::new(),
150 correspondents: HashMap::new(),
151 document_types: HashMap::new(),
152 groups: HashMap::new(),
153 storage_paths: HashMap::new(),
154 tags: HashMap::new(),
155 users: HashMap::new(),
156 }),
157 })
158 }
159
160 #[must_use]
165 pub fn with_full_permissions(mut self, req: bool) -> Self {
166 self.request_full_permissions = req;
167 self
168 }
169
170 #[must_use]
171 pub fn with_full_content(mut self, full_content: bool) -> Self {
172 self.request_full_content = full_content;
173 self
174 }
175
176 pub async fn load_items<T: Item + DeserializeOwned>(&self) -> Result<HashMap<T::Id, T>> {
178 let endpoint = format!("/api/{}/", T::endpoint());
179 debug!(endpoint, "Loading");
180
181 let items: Vec<T> = self.fetch_all_pages(&endpoint, None).await?;
182 Ok(items.into_iter().map(|item| (item.id(), item)).collect())
183 }
184
185 fn default_query_params(&self) -> Option<Vec<(&'static str, &'static str)>> {
186 let mut params = Vec::with_capacity(2);
187
188 if self.request_full_permissions {
189 params.push((crate::document_query::QUERY_PARAM_FULL_PERMISSIONS, "true"));
190 }
191 if !self.request_full_content {
192 params.push((crate::document_query::QUERY_PARAM_TRUNCATE_CONTENT, "true"));
193 }
194
195 if params.is_empty() {
196 None
197 } else {
198 Some(params)
199 }
200 }
201
202 pub async fn refresh_all(&mut self) -> Result<()> {
206 self.refresh(enum_iterator::all::<RefreshMetaData>()).await
207 }
208
209 pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshMetaData>) -> Result<()> {
218 #[rustfmt::skip]
219 async fn inner(
220 client: &mut PaperlessClient,
221 data: &mut dyn Iterator<Item = RefreshMetaData>,
222 ) -> Result<()> {
223 let selected: std::collections::HashSet<_> = data.into_iter().collect();
224
225 if selected.is_empty() {
226 return Ok(());
227 }
228
229 let (tags, custom_fields, correspondents, document_types, groups, users, storage_paths) =
230 futures_util::try_join!(
231 async {
232 if selected.contains(&RefreshMetaData::Tags) {
233 Ok(Some(client.load_items::<Tag>().await?))
234 } else {
235 Ok::<Option<_>, Error>(None)
236 }
237 },
238 async {
239 if selected.contains(&RefreshMetaData::CustomFields) {
240 Ok(Some(client.load_items::<CustomField>().await?))
241 } else {
242 Ok(None)
243 }
244 },
245 async {
246 if selected.contains(&RefreshMetaData::Correspondents) {
247 Ok(Some(client.load_items::<Correspondent>().await?))
248 } else {
249 Ok(None)
250 }
251 },
252 async {
253 if selected.contains(&RefreshMetaData::DocumentTypes) {
254 Ok(Some(client.load_items::<DocumentType>().await?))
255 } else {
256 Ok(None)
257 }
258 },
259 async {
260 if selected.contains(&RefreshMetaData::Groups) {
261 Ok(Some(client.load_items::<Group>().await?))
262 } else {
263 Ok(None)
264 }
265 },
266 async {
267 if selected.contains(&RefreshMetaData::Users) {
268 Ok(Some(client.load_items::<User>().await?))
269 } else {
270 Ok(None)
271 }
272 },
273 async {
274 if selected.contains(&RefreshMetaData::StoragePaths) {
275 Ok(Some(client.load_items::<StoragePath>().await?))
276 } else {
277 Ok(None)
278 }
279 },
280 )?;
281
282 let cached_data = Arc::make_mut(&mut client.cached_data);
283
284 if let Some(value) = custom_fields { cached_data.custom_fields = value; }
285 if let Some(value) = correspondents { cached_data.correspondents = value; }
286 if let Some(value) = document_types { cached_data.document_types = value; }
287 if let Some(value) = groups { cached_data.groups = value; }
288 if let Some(value) = storage_paths { cached_data.storage_paths = value; }
289 if let Some(value) = tags { cached_data.tags = value; }
290 if let Some(value) = users { cached_data.users = value; }
291
292 Ok(())
293 }
294
295 inner(self, &mut data.into_iter()).await
296 }
297
298 pub async fn query_documents(&self, query: DocumentQueryBuilder) -> Result<Vec<Document>> {
300 let full_content = query.full_content;
301 let query_params = query.build();
302 let query_vec: Vec<_> = query_params
303 .query
304 .iter()
305 .map(|(k, v)| (*k, v.as_str()))
306 .collect();
307 let query_slice = query_vec.as_slice();
308
309 let documents: Vec<_> = self
310 .fetch_all_pages::<DocumentData>("/api/documents/", Some(query_slice))
311 .await?
312 .into_iter()
313 .map(|data| Document::new(data, Arc::new(self.clone()), !full_content))
314 .collect();
315
316 Ok(documents)
317 }
318
319 pub fn get_documents_by_tags(
321 &self,
322 tag_ids: &[TagId],
323 ) -> impl Future<Output = Result<Vec<Document>>> {
324 let query = DocumentQueryBuilder::default()
325 .full_content(self.request_full_content)
326 .full_permissions(self.request_full_permissions)
327 .tags_id_in(tag_ids.to_vec());
328
329 self.query_documents(query)
330 }
331
332 pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
333 self.request_json_no_body(
334 Method::GET,
335 &format!("/api/documents/{}/", id.0),
336 self.default_query_params().as_deref(),
337 )
338 .await
339 }
340
341 pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
343 Ok(Document::new(
344 self.get_document_data_by_id(id).await?,
345 Arc::new(self.clone()),
346 false,
347 ))
348 }
349
350 pub(crate) fn request_json_no_body<T: serde::de::DeserializeOwned>(
352 &self,
353 method: Method,
354 endpoint: &str,
355 query_params: Option<&[(&str, &str)]>,
356 ) -> impl Future<Output = Result<T>> {
357 self.request_json(method, endpoint, None::<&()>, query_params)
358 }
359
360 pub(crate) async fn request_json<T: serde::de::DeserializeOwned>(
362 &self,
363 method: Method,
364 endpoint: &str,
365 body: Option<&impl Serialize>,
366 query_params: Option<&[(&str, &str)]>,
367 ) -> Result<T> {
368 let resp = self.request(method, endpoint, body, query_params).await?;
369
370 if tracing::enabled!(tracing::Level::TRACE) {
371 let response_text = resp.text().await.unwrap_or_default();
373 trace!(body = %response_text, "Response");
374
375 Ok(serde_json::from_str(&response_text)
376 .map_err(|e| Error::InvalidJson(format!("Failed to parse response body: {e:?}")))?)
377 } else {
378 Ok(resp
379 .json()
380 .await
381 .map_err(|e| Error::InvalidJson(format!("Failed to parse response body: {e:?}")))?)
382 }
383 }
384
385 pub(crate) fn request_no_body(
387 &self,
388 method: Method,
389 endpoint: &str,
390 query_params: Option<&[(&str, &str)]>,
391 ) -> impl Future<Output = Result<reqwest::Response>> {
392 self.request(method, endpoint, None::<&()>, query_params)
393 }
394
395 pub(crate) async fn request(
397 &self,
398 method: Method,
399 endpoint: &str,
400 body: Option<&impl Serialize>,
401 query_params: Option<&[(&str, &str)]>,
402 ) -> Result<reqwest::Response> {
403 let mut req = self
404 .client
405 .request(method, format!("{}{endpoint}", self.base_url))
406 .header(ACCEPT, "application/json");
407
408 if let Some(params) = query_params {
409 req = req.query(params);
410 }
411
412 if let Some(json_body) = body {
414 req = req.json(json_body);
415 }
416
417 let req = req.build().map_err(|e| Error::Request(e.into()))?;
418
419 if tracing::enabled!(tracing::Level::TRACE)
420 && let Some(body) = req.body().and_then(|b| b.as_bytes())
421 {
422 trace!(
423 method = ?req.method(),
424 url = ?req.url(),
425 body = %String::from_utf8_lossy(body),
426 "Sending request to Paperless API");
427 } else {
428 debug!(
429 method = ?req.method(),
430 url = ?req.url(),
431 "Sending request to Paperless API");
432 }
433
434 let resp = self
435 .client
436 .execute(req)
437 .await
438 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
439
440 debug!(status = ?resp.status(), "Response");
442
443 if resp.status() == StatusCode::NOT_FOUND {
444 return Err(Error::NotFound);
445 }
446
447 if !resp.status().is_success() {
448 return Err(Error::Response {
449 status_code: resp.status().as_u16(),
450 body: resp.text().await.unwrap_or_default(),
451 });
452 }
453
454 Ok(resp)
455 }
456
457 pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
458 &self,
459 endpoint: &str,
460 query_params: Option<&[(&str, &str)]>,
461 ) -> Result<Vec<T>> {
462 let mut results = vec![];
463 let mut all_query_params = self.default_query_params().unwrap_or_default();
464 all_query_params.extend(query_params.unwrap_or_default());
465 let mut all_query_params = Some(all_query_params);
466
467 let mut current_url = Some(endpoint.to_string());
468
469 while let Some(url) = current_url {
470 debug!("Fetching page: {url}");
471
472 let page: PaginatedResponse<T> = self
473 .request_json_no_body(Method::GET, &url, all_query_params.as_deref())
474 .await?;
475
476 results.extend(page.results);
477
478 current_url = page.next.and_then(|next_url| {
479 next_url
481 .strip_prefix(&*self.base_url)
482 .unwrap_or(&next_url)
483 .to_string()
484 .into()
485 });
486 all_query_params = None;
487 }
488
489 Ok(results)
490 }
491
492 pub async fn get_task_status(
494 &self,
495 task_id: Option<&TaskId>,
496 task_name: Option<&str>,
497 acknowledged: Option<bool>,
498 ) -> Result<Vec<Task>> {
499 let mut query = Vec::new();
500
501 if let Some(id) = task_id {
502 query.push(("task_id", id.to_string()));
503 }
504
505 if let Some(name) = task_name {
506 query.push(("task_name", name.to_string()));
507 }
508
509 if let Some(ack) = acknowledged {
510 query.push(("acknowledged", ack.to_string()));
511 }
512
513 let resp = self
514 .request_no_body(
515 Method::GET,
516 &format!(
517 "/api/tasks/?{}",
518 serde_urlencoded::to_string(&query)
519 .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
520 ),
521 None,
522 )
523 .await?;
524
525 let body = resp
526 .text()
527 .await
528 .map_err(|e| Error::Other(format!("Failed to read response body: {e:?}")))?;
529
530 trace!("get_task_status response: {:?}", body);
531
532 let tasks: Vec<Task> = match serde_json::from_str(&body) {
533 Ok(t) => t,
534 Err(e) => {
535 return Err(Error::InvalidJson(format!(
536 "Failed to parse response body: {e:?}"
537 )));
538 }
539 };
540
541 if tasks.is_empty() {
542 return Err(Error::NotFound);
543 }
544
545 Ok(tasks)
546 }
547
548 pub fn get_workflows(&self) -> impl Future<Output = Result<Vec<Workflow>>> {
550 self.fetch_all_pages("/api/workflows/", None)
551 }
552
553 pub fn get_saved_views(&self) -> impl Future<Output = Result<Vec<SavedView>>> {
555 self.fetch_all_pages("/api/saved_views/", None)
556 }
557
558 pub fn get_statistics(&self) -> impl Future<Output = Result<util::Statistics>> {
560 self.request_json_no_body(Method::GET, "/api/statistics/", None)
561 }
562
563 pub fn get_status(&self) -> impl Future<Output = Result<util::ServerStatus>> {
565 self.request_json_no_body(Method::GET, "/api/status/", None)
566 }
567
568 pub async fn create<T>(&self, new_item: &T) -> Result<T::BaseType>
574 where
575 T: CreateDto,
576 T::BaseType: Item,
577 {
578 let url = format!("/api/{}/", <T::BaseType as Item>::endpoint());
579 self.request_json(Method::POST, &url, Some(&new_item), None)
580 .await
581 }
582
583 pub async fn update<T>(&self, id: T::Id, update: &T) -> Result<T::BaseType>
589 where
590 T: UpdateDto,
591 T::BaseType: Item,
592 {
593 let url = format!("/api/{}/{}/", <T::BaseType as Item>::endpoint(), id);
594 self.request_json::<T::BaseType>(Method::PATCH, &url, Some(&update), None)
595 .await
596 }
597
598 pub async fn delete<T: ItemId>(&self, id: T) -> Result<()> {
602 let url = format!("/api/{}/{}/", T::endpoint(), id);
603 self.request_no_body(Method::DELETE, &url, None).await?;
604 Ok(())
605 }
606
607 pub async fn load_by_id<T: Item>(&self, id: T::Id) -> Result<Option<T>> {
611 let url = format!("/api/{}/{}/", T::endpoint(), id);
612 match self.request_json_no_body(Method::GET, &url, None).await {
613 found_item @ Ok(_) => found_item,
614 Err(Error::NotFound) => Ok(None),
615 err @ Err(_) => err,
616 }
617 }
618
619 pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
623 let stream = tokio::fs::File::open(file_path)
624 .await
625 .map_err(|e| Error::Other(format!("Failed to open file: {e}")))?;
626
627 let form = multipart::Form::new().part(
628 "document",
629 multipart::Part::stream(stream).file_name(filename.to_string()),
630 );
631
632 let url = format!("{}/api/documents/post_document/", self.base_url);
633
634 let resp = self
635 .client
636 .post(&url)
637 .multipart(form)
638 .send()
639 .await
640 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
641
642 let status = resp.status();
643 if !resp.status().is_success() {
644 return Err(Error::Response {
645 status_code: status.as_u16(),
646 body: resp.text().await.unwrap_or_default(),
647 });
648 }
649
650 let task_id: String = resp
651 .json()
652 .await
653 .map_err(|e| Error::Other(format!("Failed to parse task ID: {e:?}")))?;
654 Ok(TaskId(task_id))
655 }
656
657 #[inline]
659 #[must_use]
660 pub fn tags(&self) -> &HashMap<TagId, Tag> {
661 &self.cached_data.tags
662 }
663
664 #[inline]
666 #[must_use]
667 pub fn storage_paths(&self) -> &HashMap<StoragePathId, StoragePath> {
668 &self.cached_data.storage_paths
669 }
670
671 #[must_use]
673 pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
674 self.cached_data.tags.values().find(|tag| tag.name == name)
675 }
676
677 #[inline]
679 #[must_use]
680 pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
681 &self.cached_data.document_types
682 }
683
684 #[must_use]
686 pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
687 self.cached_data
688 .document_types
689 .values()
690 .find(|dt| dt.name == name)
691 }
692
693 #[inline]
695 #[must_use]
696 pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
697 &self.cached_data.correspondents
698 }
699
700 #[inline]
702 #[must_use]
703 pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
704 &self.cached_data.custom_fields
705 }
706
707 #[must_use]
709 pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
710 self.cached_data
711 .custom_fields
712 .values()
713 .find(|field| field.name == name)
714 }
715
716 #[inline]
718 #[must_use]
719 pub fn users(&self) -> &HashMap<UserId, User> {
720 &self.cached_data.users
721 }
722
723 #[inline]
725 #[must_use]
726 pub fn groups(&self) -> &HashMap<GroupId, Group> {
727 &self.cached_data.groups
728 }
729}