1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use enum_iterator::Sequence;
4use reqwest::{
5 Method, StatusCode,
6 header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
7 multipart,
8};
9use serde::Deserialize;
10use tracing::{debug, trace};
11
12use crate::{
13 Error, Result, User,
14 correspondent::Correspondent,
15 custom_field::CustomField,
16 document::{Document, DocumentData},
17 document_type::DocumentType,
18 id::{
19 CorrespondentId, CustomFieldId, DocumentId, DocumentTypeId, StoragePathId, TagId, TaskId,
20 UserId,
21 },
22 storage_path::StoragePath,
23 tag::Tag,
24 task::Task,
25 workflow::Workflow,
26};
27
28#[derive(Copy, Clone, Debug, PartialEq, Eq, Sequence)]
30#[non_exhaustive]
31pub enum RefreshData {
32 Tags,
33 CustomFields,
34 Correspondents,
35 DocumentTypes,
36 Users,
37 StoragePaths,
38}
39
40#[derive(Debug, Clone)]
42pub struct PaperlessClient {
43 client: reqwest::Client,
44 pub(crate) base_url: Box<str>,
45 cached_data: Arc<CachedData>,
46}
47
48#[derive(Debug, Clone)]
49struct CachedData {
50 correspondents: HashMap<CorrespondentId, Correspondent>,
51 document_types: HashMap<DocumentTypeId, DocumentType>,
52 tags: HashMap<TagId, Tag>,
53 custom_fields: HashMap<CustomFieldId, CustomField>,
54 users: HashMap<UserId, User>,
55 storage_paths: HashMap<StoragePathId, StoragePath>,
56}
57
58#[derive(Debug, Deserialize)]
59struct PaginatedResponse<T> {
60 results: Vec<T>,
61 next: Option<String>,
62}
63
64impl PaperlessClient {
65 pub fn new(
67 base_url: &str,
68 token: &str,
69 headers: Option<&HashMap<String, String>>,
70 ) -> std::result::Result<Self, String> {
71 let mut headers_map = HeaderMap::new();
72
73 if let Some(headers) = headers {
75 for (key, value) in headers {
76 headers_map.insert(
77 HeaderName::from_str(key).map_err(|err| err.to_string())?,
78 value
79 .parse()
80 .map_err(|err: InvalidHeaderValue| err.to_string())?,
81 );
82 }
83 }
84
85 headers_map.insert(
87 HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
88 format!("Token {token}")
89 .parse()
90 .map_err(|err: InvalidHeaderValue| err.to_string())?,
91 );
92
93 Ok(Self {
94 base_url: base_url.into(),
95 client: reqwest::Client::builder()
96 .default_headers(headers_map)
97 .zstd(true)
98 .build()
99 .map_err(|err| err.to_string())?,
100 cached_data: Arc::new(CachedData {
101 tags: HashMap::new(),
102 custom_fields: HashMap::new(),
103 correspondents: HashMap::new(),
104 document_types: HashMap::new(),
105 users: HashMap::new(),
106 storage_paths: HashMap::new(),
107 }),
108 })
109 }
110
111 async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
112 debug!("loading tags");
113 let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
114 Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
115 }
116
117 async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
118 debug!("loading custom fields");
119 let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
120 Ok(custom_fields
121 .into_iter()
122 .map(|custom_field| (custom_field.id, custom_field))
123 .collect())
124 }
125
126 async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
127 debug!("loading correspondents");
128 let correspondents: Vec<Correspondent> =
129 self.fetch_all_pages("/api/correspondents/").await?;
130 Ok(correspondents
131 .into_iter()
132 .map(|correspondent| (correspondent.id, correspondent))
133 .collect())
134 }
135
136 async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
137 debug!("loading document types");
138 let document_types: Vec<DocumentType> =
139 self.fetch_all_pages("/api/document_types/").await?;
140 Ok(document_types
141 .into_iter()
142 .map(|document_type| (document_type.id, document_type))
143 .collect())
144 }
145
146 async fn load_users(&self) -> Result<HashMap<UserId, User>> {
147 debug!("loading users");
148 let users: Vec<User> = self.fetch_all_pages("/api/users/").await?;
149 Ok(users.into_iter().map(|user| (user.id, user)).collect())
150 }
151
152 async fn load_storage_paths(&self) -> Result<HashMap<StoragePathId, StoragePath>> {
153 debug!("loading storage paths");
154 let storage_paths: Vec<StoragePath> = self.fetch_all_pages("/api/storage_paths/").await?;
155 Ok(storage_paths
156 .into_iter()
157 .map(|storage_path| (storage_path.id, storage_path))
158 .collect())
159 }
160
161 pub async fn refresh_all(&mut self) -> Result<()> {
162 self.refresh(enum_iterator::all::<RefreshData>()).await
163 }
164
165 pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
167 let mut refresh_tags = false;
168 let mut refresh_custom_fields = false;
169 let mut refresh_correspondents = false;
170 let mut refresh_document_types = false;
171 let mut refresh_users = false;
172 let mut refresh_storage_paths = false;
173
174 for item in data {
175 match item {
176 RefreshData::Tags => refresh_tags = true,
177 RefreshData::CustomFields => refresh_custom_fields = true,
178 RefreshData::Correspondents => refresh_correspondents = true,
179 RefreshData::DocumentTypes => refresh_document_types = true,
180 RefreshData::Users => refresh_users = true,
181 RefreshData::StoragePaths => refresh_storage_paths = true,
182 }
183 }
184
185 let (tags, custom_fields, correspondents, document_types, users, storage_paths) = futures_util::try_join!(
186 async {
187 if refresh_tags {
188 Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
189 } else {
190 Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
191 }
192 },
193 async {
194 if refresh_custom_fields {
195 Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
196 self.load_custom_fields().await?,
197 ))
198 } else {
199 Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
200 }
201 },
202 async {
203 if refresh_correspondents {
204 Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
205 self.load_correspondents().await?,
206 ))
207 } else {
208 Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
209 }
210 },
211 async {
212 if refresh_document_types {
213 Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
214 self.load_document_types().await?,
215 ))
216 } else {
217 Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
218 }
219 },
220 async {
221 if refresh_users {
222 Ok::<Option<HashMap<UserId, User>>, Error>(Some(self.load_users().await?))
223 } else {
224 Ok::<Option<HashMap<UserId, User>>, Error>(None)
225 }
226 },
227 async {
228 if refresh_storage_paths {
229 Ok::<Option<HashMap<StoragePathId, StoragePath>>, Error>(Some(
230 self.load_storage_paths().await?,
231 ))
232 } else {
233 Ok::<Option<HashMap<StoragePathId, StoragePath>>, Error>(None)
234 }
235 },
236 )?;
237
238 let cached_data = Arc::make_mut(&mut self.cached_data);
241
242 if let Some(correspondents) = correspondents {
243 cached_data.correspondents = correspondents;
244 }
245 if let Some(document_types) = document_types {
246 cached_data.document_types = document_types;
247 }
248 if let Some(tags) = tags {
249 cached_data.tags = tags;
250 }
251 if let Some(custom_fields) = custom_fields {
252 cached_data.custom_fields = custom_fields;
253 }
254 if let Some(users) = users {
255 cached_data.users = users;
256 }
257 if let Some(storage_paths) = storage_paths {
258 cached_data.storage_paths = storage_paths;
259 }
260
261 Ok(())
262 }
263
264 pub async fn get_documents_by_tags(
266 &self,
267 tag_ids: &[TagId],
268 truncate_content: bool,
269 ) -> Result<Vec<Document>> {
270 let tag_id_str = tag_ids
271 .iter()
272 .map(|tag_id| tag_id.0.to_string())
273 .collect::<Vec<_>>()
274 .join(",");
275 let documents: Vec<_> = self
276 .fetch_all_pages::<DocumentData>(&format!(
277 "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
278 ))
279 .await?
280 .into_iter()
281 .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
282 .collect();
283
284 Ok(documents)
285 }
286
287 pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
288 let resp = self
289 .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
290 .await?;
291
292 let document_data: DocumentData = resp
293 .json()
294 .await
295 .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
296
297 Ok(document_data)
298 }
299
300 pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
302 Ok(Document::new(
303 self.get_document_data_by_id(id).await?,
304 Arc::new(self.clone()),
305 false,
306 ))
307 }
308
309 pub(crate) async fn request(
310 &self,
311 method: Method,
312 endpoint: &str,
313 body: Option<&serde_json::Value>,
314 ) -> Result<reqwest::Response> {
315 let mut req = self
316 .client
317 .request(method, format!("{}{endpoint}", self.base_url))
318 .header(ACCEPT, "application/json");
319
320 if let Some(json_body) = body {
322 req = req.json(json_body);
323 }
324
325 let resp = req
326 .send()
327 .await
328 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
329
330 if resp.status() == StatusCode::NOT_FOUND {
331 return Err(Error::NotFound);
332 }
333
334 if !resp.status().is_success() {
335 return Err(Error::Response {
336 status_code: resp.status().as_u16(),
337 body: resp.text().await.unwrap_or_default(),
338 });
339 }
340
341 Ok(resp)
342 }
343
344 pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
345 &self,
346 endpoint: &str,
347 ) -> Result<Vec<T>> {
348 let mut results = Vec::new();
349 let mut current_url = Some(endpoint.to_string());
350
351 while let Some(url) = current_url {
352 let resp = self.request(Method::GET, &url, None).await?;
353
354 let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
355 Error::InvalidJson(format!(
356 "Failed to parse paginated response for {endpoint}: {e}"
357 ))
358 })?;
359
360 results.extend(page.results);
361
362 current_url = page.next.and_then(|next_url| {
363 next_url
365 .trim_start_matches(&*self.base_url)
366 .to_string()
367 .into()
368 });
369 }
370
371 Ok(results)
372 }
373
374 pub async fn get_task_status(
376 &self,
377 task_id: Option<&TaskId>,
378 task_name: Option<&str>,
379 acknowledged: Option<bool>,
380 ) -> Result<Vec<Task>> {
381 let mut query = Vec::new();
382
383 if let Some(id) = task_id {
384 query.push(("task_id", id.to_string()));
385 }
386
387 if let Some(name) = task_name {
388 query.push(("task_name", name.to_string()));
389 }
390
391 if let Some(ack) = acknowledged {
392 query.push(("acknowledged", ack.to_string()));
393 }
394
395 let resp = self
396 .request(
397 Method::GET,
398 &format!(
399 "/api/tasks/?{}",
400 serde_urlencoded::to_string(&query)
401 .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
402 ),
403 None,
404 )
405 .await?;
406
407 trace!("get_task_status response: {:?}", resp);
408
409 let body = resp
410 .text()
411 .await
412 .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
413
414 let tasks: Vec<Task> = match serde_json::from_str(&body) {
415 Ok(t) => t,
416 Err(e) => {
417 return Err(Error::InvalidJson(format!(
418 "Failed to parse response body: {e}"
419 )));
420 }
421 };
422
423 if tasks.is_empty() {
424 return Err(Error::NotFound);
425 }
426
427 Ok(tasks)
428 }
429
430 pub fn get_workflows(&self) -> impl Future<Output = Result<Vec<Workflow>>> {
431 self.fetch_all_pages("/api/workflows/")
432 }
433
434 pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
438 let file_bytes = std::fs::read(file_path)
439 .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
440
441 let form = multipart::Form::new().part(
442 "document",
443 multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
444 );
445
446 let url = format!("{}/api/documents/post_document/", self.base_url);
447
448 let resp = self
449 .client
450 .post(&url)
451 .multipart(form)
452 .send()
453 .await
454 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
455
456 let status = resp.status();
457 if !resp.status().is_success() {
458 return Err(Error::Response {
459 status_code: status.as_u16(),
460 body: resp.text().await.unwrap_or_default(),
461 });
462 }
463
464 let task_id: String = resp
465 .json()
466 .await
467 .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
468 Ok(TaskId(task_id))
469 }
470
471 #[inline]
472 #[must_use]
473 pub fn tags(&self) -> &HashMap<TagId, Tag> {
474 &self.cached_data.tags
475 }
476
477 #[inline]
478 #[must_use]
479 pub fn storage_paths(&self) -> &HashMap<StoragePathId, StoragePath> {
480 &self.cached_data.storage_paths
481 }
482
483 #[must_use]
484 pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
485 self.cached_data.tags.values().find(|tag| tag.name == name)
486 }
487
488 #[inline]
489 #[must_use]
490 pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
491 &self.cached_data.document_types
492 }
493
494 #[must_use]
495 pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
496 self.cached_data
497 .document_types
498 .values()
499 .find(|dt| dt.name == name)
500 }
501
502 #[inline]
503 #[must_use]
504 pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
505 &self.cached_data.correspondents
506 }
507
508 #[inline]
509 #[must_use]
510 pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
511 &self.cached_data.custom_fields
512 }
513
514 #[must_use]
515 pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
516 self.cached_data
517 .custom_fields
518 .values()
519 .find(|field| field.name == name)
520 }
521
522 #[inline]
523 #[must_use]
524 pub fn users(&self) -> &HashMap<UserId, User> {
525 &self.cached_data.users
526 }
527}