1use std::{collections::HashMap, fmt::Write, path::Path, str::FromStr, sync::Arc};
4
5use enum_iterator::Sequence;
6use reqwest::{
7 Method, StatusCode,
8 header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
9 multipart,
10};
11use serde::Deserialize;
12use tracing::{debug, trace};
13
14use crate::{
15 Error, Group, Result, SavedView, User,
16 document::{Document, DocumentData},
17 dto::CreateDtoObject,
18 id::{
19 CorrespondentId, CustomFieldId, DocumentId, DocumentTypeId, GroupId, StoragePathId, TagId,
20 TaskId, UserId,
21 },
22 metadata::{
23 correspondent::Correspondent, custom_field::CustomField, document_type::DocumentType,
24 storage_path::StoragePath, tag::Tag,
25 },
26 task::Task,
27 util,
28 workflow::Workflow,
29};
30
/// Query parameter sent with the value `"true"` when full permissions are requested.
const QUERY_PARAM_FULL_PERMISSIONS: &str = "full_perms";
/// Query parameter carrying the `truncate_content` flag when documents are listed by tag.
const QUERY_PARAM_TRUNCATE_CONTENT: &str = "truncate_content";
/// Query parameter carrying the comma-separated tag ids when documents are listed by tag.
const QUERY_PARAM_TAGS_ID_IN: &str = "tags__id__in";
34
/// Selects which cached metadata collections [`PaperlessClient::refresh`] reloads.
///
/// The enum derives `Sequence`, which lets [`PaperlessClient::refresh_all`] enumerate every
/// variant.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Sequence)]
#[non_exhaustive]
pub enum RefreshMetaData {
    /// Reload the tag cache from `/api/tags/`.
    Tags,
    /// Reload the custom field cache from `/api/custom_fields/`.
    CustomFields,
    /// Reload the correspondent cache from `/api/correspondents/`.
    Correspondents,
    /// Reload the document type cache from `/api/document_types/`.
    DocumentTypes,
    /// Reload the group cache from `/api/groups/`.
    Groups,
    /// Reload the user cache from `/api/users/`.
    Users,
    /// Reload the storage path cache from `/api/storage_paths/`.
    StoragePaths,
}
50
/// HTTP client for a Paperless server that also holds a cache of server metadata.
#[derive(Debug, Clone)]
pub struct PaperlessClient {
    /// When `true`, the tag, correspondent, document type and storage path loaders add
    /// `full_perms=true` to their requests.
    pub request_full_permissions: bool,

    /// Prefix prepended to every endpoint path when a request URL is built.
    pub(crate) base_url: Arc<str>,

    /// `reqwest` client built with the default headers (custom headers plus the token).
    client: reqwest::Client,
    /// Metadata loaded by `refresh`; kept behind an `Arc`, so cloning the client does not
    /// copy the maps.
    cached_data: Arc<CachedData>,
}
62
/// Metadata collections cached by [`PaperlessClient`], each indexed by the item's id.
///
/// `Clone` is required because `refresh` updates the cache through `Arc::make_mut`.
#[derive(Debug, Clone)]
struct CachedData {
    /// Correspondents keyed by their id.
    correspondents: HashMap<CorrespondentId, Correspondent>,
    /// Custom fields keyed by their id.
    custom_fields: HashMap<CustomFieldId, CustomField>,
    /// Document types keyed by their id.
    document_types: HashMap<DocumentTypeId, DocumentType>,
    /// Groups keyed by their id.
    groups: HashMap<GroupId, Group>,
    /// Storage paths keyed by their id.
    storage_paths: HashMap<StoragePathId, StoragePath>,
    /// Tags keyed by their id.
    tags: HashMap<TagId, Tag>,
    /// Users keyed by their id.
    users: HashMap<UserId, User>,
}
73
/// One page of a paginated list response, as consumed by `fetch_all_pages`.
#[derive(Debug, Deserialize)]
struct PaginatedResponse<T> {
    /// Items contained in this page.
    results: Vec<T>,
    /// URL of the following page; `None` ends the pagination loop.
    next: Option<String>,
}
79
80impl PaperlessClient {
81 pub fn new(
89 base_url: &str,
90 token: &str,
91 headers: Option<&HashMap<String, String>>,
92 ) -> std::result::Result<Self, String> {
93 Self::new_with_client(
94 base_url,
95 token,
96 headers,
97 reqwest::Client::builder().zstd(true),
98 )
99 }
100
101 pub fn new_with_client(
113 base_url: &str,
114 token: &str,
115 headers: Option<&HashMap<String, String>>,
116 client_builder: reqwest::ClientBuilder,
117 ) -> std::result::Result<Self, String> {
118 let mut headers_map = HeaderMap::new();
119
120 if let Some(headers) = headers {
122 for (key, value) in headers {
123 headers_map.insert(
124 HeaderName::from_str(key).map_err(|err| err.to_string())?,
125 value
126 .parse()
127 .map_err(|err: InvalidHeaderValue| err.to_string())?,
128 );
129 }
130 }
131
132 headers_map.insert(
134 HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
135 format!("Token {token}")
136 .parse()
137 .map_err(|err: InvalidHeaderValue| err.to_string())?,
138 );
139
140 Ok(Self {
141 request_full_permissions: false,
142 base_url: base_url.into(),
143 client: client_builder
144 .default_headers(headers_map)
145 .build()
146 .map_err(|err| err.to_string())?,
147 cached_data: Arc::new(CachedData {
148 custom_fields: HashMap::new(),
149 correspondents: HashMap::new(),
150 document_types: HashMap::new(),
151 groups: HashMap::new(),
152 storage_paths: HashMap::new(),
153 tags: HashMap::new(),
154 users: HashMap::new(),
155 }),
156 })
157 }
158
159 #[must_use]
161 pub fn request_full_permissions(mut self, req: bool) -> Self {
162 self.request_full_permissions = req;
163 self
164 }
165
166 async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
167 debug!("loading tags");
168 let tags: Vec<Tag> = self
169 .fetch_all_pages("/api/tags/", self.permissions_query_param())
170 .await?;
171 Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
172 }
173
174 async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
175 debug!("loading custom fields");
176 let custom_fields: Vec<CustomField> =
177 self.fetch_all_pages("/api/custom_fields/", None).await?;
178 Ok(custom_fields
179 .into_iter()
180 .map(|custom_field| (custom_field.id, custom_field))
181 .collect())
182 }
183
184 async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
185 debug!("loading correspondents");
186 let correspondents: Vec<Correspondent> = self
187 .fetch_all_pages("/api/correspondents/", self.permissions_query_param())
188 .await?;
189 Ok(correspondents
190 .into_iter()
191 .map(|correspondent| (correspondent.id, correspondent))
192 .collect())
193 }
194
195 async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
196 debug!("loading document types");
197 let document_types: Vec<DocumentType> = self
198 .fetch_all_pages("/api/document_types/", self.permissions_query_param())
199 .await?;
200 Ok(document_types
201 .into_iter()
202 .map(|document_type| (document_type.id, document_type))
203 .collect())
204 }
205
206 async fn load_groups(&self) -> Result<HashMap<GroupId, Group>> {
207 debug!("loading groups");
208 let groups: Vec<Group> = self.fetch_all_pages("/api/groups/", None).await?;
209 Ok(groups.into_iter().map(|group| (group.id, group)).collect())
210 }
211
212 async fn load_users(&self) -> Result<HashMap<UserId, User>> {
213 debug!("loading users");
214 let users: Vec<User> = self.fetch_all_pages("/api/users/", None).await?;
215 Ok(users.into_iter().map(|user| (user.id, user)).collect())
216 }
217
218 async fn load_storage_paths(&self) -> Result<HashMap<StoragePathId, StoragePath>> {
219 debug!("loading storage paths");
220 let storage_paths: Vec<StoragePath> = self
221 .fetch_all_pages("/api/storage_paths/", self.permissions_query_param())
222 .await?;
223 Ok(storage_paths
224 .into_iter()
225 .map(|storage_path| (storage_path.id, storage_path))
226 .collect())
227 }
228
229 fn permissions_query_param(&self) -> Option<&'static [(&'static str, &'static str)]> {
230 if self.request_full_permissions {
231 Some(&[(QUERY_PARAM_FULL_PERMISSIONS, "true")])
232 } else {
233 None
234 }
235 }
236
237 pub async fn refresh_all(&mut self) -> Result<()> {
245 self.refresh(enum_iterator::all::<RefreshMetaData>()).await
246 }
247
248 pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshMetaData>) -> Result<()> {
257 #[rustfmt::skip]
258 async fn inner(
259 client: &mut PaperlessClient,
260 data: &mut dyn Iterator<Item = RefreshMetaData>,
261 ) -> Result<()> {
262 let selected: std::collections::HashSet<_> = data.into_iter().collect();
263
264 if selected.is_empty() {
265 return Ok(());
266 }
267
268 let (tags, custom_fields, correspondents, document_types, groups, users, storage_paths) =
269 futures_util::try_join!(
270 async {
271 if selected.contains(&RefreshMetaData::Tags) {
272 Ok(Some(client.load_tags().await?))
273 } else {
274 Ok::<Option<_>, Error>(None)
275 }
276 },
277 async {
278 if selected.contains(&RefreshMetaData::CustomFields) {
279 Ok(Some(client.load_custom_fields().await?))
280 } else {
281 Ok(None)
282 }
283 },
284 async {
285 if selected.contains(&RefreshMetaData::Correspondents) {
286 Ok(Some(client.load_correspondents().await?))
287 } else {
288 Ok(None)
289 }
290 },
291 async {
292 if selected.contains(&RefreshMetaData::DocumentTypes) {
293 Ok(Some(client.load_document_types().await?))
294 } else {
295 Ok(None)
296 }
297 },
298 async {
299 if selected.contains(&RefreshMetaData::Groups) {
300 Ok(Some(client.load_groups().await?))
301 } else {
302 Ok(None)
303 }
304 },
305 async {
306 if selected.contains(&RefreshMetaData::Users) {
307 Ok(Some(client.load_users().await?))
308 } else {
309 Ok(None)
310 }
311 },
312 async {
313 if selected.contains(&RefreshMetaData::StoragePaths) {
314 Ok(Some(client.load_storage_paths().await?))
315 } else {
316 Ok(None)
317 }
318 },
319 )?;
320
321 let cached_data = Arc::make_mut(&mut client.cached_data);
322
323 if let Some(value) = correspondents { cached_data.correspondents = value; }
324 if let Some(value) = document_types { cached_data.document_types = value; }
325 if let Some(value) = groups { cached_data.groups = value; }
326 if let Some(value) = tags { cached_data.tags = value; }
327 if let Some(value) = custom_fields { cached_data.custom_fields = value; }
328 if let Some(value) = users { cached_data.users = value; }
329 if let Some(value) = storage_paths { cached_data.storage_paths = value; }
330
331 Ok(())
332 }
333
334 inner(self, &mut data.into_iter()).await
335 }
336
337 pub async fn get_documents_by_tags(
339 &self,
340 tag_ids: &[TagId],
341 truncate_content: bool,
342 ) -> Result<Vec<Document>> {
343 let tag_id_str = tag_ids
344 .iter()
345 .map(|tag_id| tag_id.0.to_string())
346 .collect::<Vec<_>>()
347 .join(",");
348
349 let documents: Vec<_> = self
350 .fetch_all_pages::<DocumentData>(
351 "/api/documents/",
352 Some(&[
353 (QUERY_PARAM_TAGS_ID_IN, &tag_id_str),
354 (QUERY_PARAM_TRUNCATE_CONTENT, &format!("{truncate_content}")),
355 ]),
356 )
357 .await?
358 .into_iter()
359 .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
360 .collect();
361
362 Ok(documents)
363 }
364
365 pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
366 let resp = self
367 .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
368 .await?;
369
370 let document_data: DocumentData = resp
371 .json()
372 .await
373 .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
374
375 Ok(document_data)
376 }
377
378 pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
380 Ok(Document::new(
381 self.get_document_data_by_id(id).await?,
382 Arc::new(self.clone()),
383 false,
384 ))
385 }
386
387 pub(crate) async fn request(
388 &self,
389 method: Method,
390 endpoint: &str,
391 body: Option<&serde_json::Value>,
392 ) -> Result<reqwest::Response> {
393 let mut req = self
394 .client
395 .request(method, format!("{}{endpoint}", self.base_url))
396 .header(ACCEPT, "application/json");
397
398 if let Some(json_body) = body {
400 req = req.json(json_body);
401 }
402
403 let resp = req
404 .send()
405 .await
406 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
407
408 if resp.status() == StatusCode::NOT_FOUND {
409 return Err(Error::NotFound);
410 }
411
412 if !resp.status().is_success() {
413 return Err(Error::Response {
414 status_code: resp.status().as_u16(),
415 body: resp.text().await.unwrap_or_default(),
416 });
417 }
418
419 Ok(resp)
420 }
421
422 pub(crate) async fn request_with_body(
423 &self,
424 method: Method,
425 endpoint: &str,
426 body: &impl serde::Serialize,
427 ) -> Result<reqwest::Response> {
428 let body = serde_json::to_value(body).map_err(|e| Error::Other(e.to_string()))?;
429 self.request(method, endpoint, Some(&body)).await
430 }
431
432 pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
433 &self,
434 endpoint: &str,
435 query_params: Option<&[(&str, &str)]>,
436 ) -> Result<Vec<T>> {
437 let mut results = Vec::new();
438 let mut current_url = endpoint.to_string();
439 let mut first_param = true;
440
441 if let Some(params) = query_params {
442 for param in params {
443 if first_param {
444 current_url.push('?');
445 first_param = false;
446 } else {
447 current_url.push('&');
448 }
449 let _ = write!(current_url, "{}={}", param.0, param.1);
450 }
451 }
452
453 let mut current_url = Some(current_url);
454
455 while let Some(url) = current_url {
456 let resp = self.request(Method::GET, &url, None).await?;
457
458 let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
459 Error::InvalidJson(format!(
460 "Failed to parse paginated response for {endpoint}: {e:?}"
461 ))
462 })?;
463
464 results.extend(page.results);
465
466 current_url = page.next.and_then(|next_url| {
467 next_url
469 .trim_start_matches(&*self.base_url)
470 .to_string()
471 .into()
472 });
473 }
474
475 Ok(results)
476 }
477
478 pub async fn get_task_status(
480 &self,
481 task_id: Option<&TaskId>,
482 task_name: Option<&str>,
483 acknowledged: Option<bool>,
484 ) -> Result<Vec<Task>> {
485 let mut query = Vec::new();
486
487 if let Some(id) = task_id {
488 query.push(("task_id", id.to_string()));
489 }
490
491 if let Some(name) = task_name {
492 query.push(("task_name", name.to_string()));
493 }
494
495 if let Some(ack) = acknowledged {
496 query.push(("acknowledged", ack.to_string()));
497 }
498
499 let resp = self
500 .request(
501 Method::GET,
502 &format!(
503 "/api/tasks/?{}",
504 serde_urlencoded::to_string(&query)
505 .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
506 ),
507 None::<&serde_json::Value>,
508 )
509 .await?;
510
511 trace!("get_task_status response: {:?}", resp);
512
513 let body = resp
514 .text()
515 .await
516 .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
517
518 let tasks: Vec<Task> = match serde_json::from_str(&body) {
519 Ok(t) => t,
520 Err(e) => {
521 return Err(Error::InvalidJson(format!(
522 "Failed to parse response body: {e}"
523 )));
524 }
525 };
526
527 if tasks.is_empty() {
528 return Err(Error::NotFound);
529 }
530
531 Ok(tasks)
532 }
533
534 pub fn get_workflows(&self) -> impl Future<Output = Result<Vec<Workflow>>> {
535 self.fetch_all_pages("/api/workflows/", None)
536 }
537
538 pub fn get_saved_views(&self) -> impl Future<Output = Result<Vec<SavedView>>> {
539 self.fetch_all_pages("/api/saved_views/", None)
540 }
541
542 pub async fn get_statistics(&self) -> Result<util::Statistics> {
543 self.request(Method::GET, "/api/statistics/", None)
544 .await
545 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?
546 .json()
547 .await
548 .map_err(|e| Error::Other(format!("Failed to parse response body: {e:?}")))
549 }
550
551 pub async fn get_status(&self) -> Result<util::ServerStatus> {
552 self.request(Method::GET, "/api/status/", None)
553 .await
554 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?
555 .json()
556 .await
557 .map_err(|e| Error::Other(format!("Failed to parse response body: {e:?}")))
558 }
559
560 pub async fn create<T>(&self, new_item: T) -> Result<T::BaseType>
566 where
567 T: CreateDtoObject,
568 {
569 let url = format!("/api/{}/", T::endpoint());
570 let resp = self
571 .request_with_body(Method::POST, &url, &new_item)
572 .await?;
573
574 resp.json::<T::BaseType>()
575 .await
576 .map_err(|e| Error::Other(format!("Failed to parse response body: {e}")))
577 }
578
579 pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
583 let file_bytes = std::fs::read(file_path)
584 .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
585
586 let form = multipart::Form::new().part(
587 "document",
588 multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
589 );
590
591 let url = format!("{}/api/documents/post_document/", self.base_url);
592
593 let resp = self
594 .client
595 .post(&url)
596 .multipart(form)
597 .send()
598 .await
599 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
600
601 let status = resp.status();
602 if !resp.status().is_success() {
603 return Err(Error::Response {
604 status_code: status.as_u16(),
605 body: resp.text().await.unwrap_or_default(),
606 });
607 }
608
609 let task_id: String = resp
610 .json()
611 .await
612 .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
613 Ok(TaskId(task_id))
614 }
615
616 #[inline]
617 #[must_use]
618 pub fn tags(&self) -> &HashMap<TagId, Tag> {
619 &self.cached_data.tags
620 }
621
622 #[inline]
623 #[must_use]
624 pub fn storage_paths(&self) -> &HashMap<StoragePathId, StoragePath> {
625 &self.cached_data.storage_paths
626 }
627
628 #[must_use]
629 pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
630 self.cached_data.tags.values().find(|tag| tag.name == name)
631 }
632
633 #[inline]
634 #[must_use]
635 pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
636 &self.cached_data.document_types
637 }
638
639 #[must_use]
640 pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
641 self.cached_data
642 .document_types
643 .values()
644 .find(|dt| dt.name == name)
645 }
646
647 #[inline]
648 #[must_use]
649 pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
650 &self.cached_data.correspondents
651 }
652
653 #[inline]
654 #[must_use]
655 pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
656 &self.cached_data.custom_fields
657 }
658
659 #[must_use]
660 pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
661 self.cached_data
662 .custom_fields
663 .values()
664 .find(|field| field.name == name)
665 }
666
667 #[inline]
668 #[must_use]
669 pub fn users(&self) -> &HashMap<UserId, User> {
670 &self.cached_data.users
671 }
672
673 #[inline]
674 #[must_use]
675 pub fn groups(&self) -> &HashMap<GroupId, Group> {
676 &self.cached_data.groups
677 }
678}