1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use enum_iterator::Sequence;
4use reqwest::{
5 Method, StatusCode,
6 header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
7 multipart,
8};
9use serde::Deserialize;
10use tracing::{debug, trace};
11
12use crate::{
13 Error, Result, User,
14 correspondent::Correspondent,
15 custom_field::CustomField,
16 document::{Document, DocumentData},
17 document_type::DocumentType,
18 id::{
19 CorrespondentId, CustomFieldId, DocumentId, DocumentTypeId, StoragePathId, TagId, TaskId,
20 UserId,
21 },
22 storage_path::StoragePath,
23 tag::Tag,
24 task::Task,
25 workflow::Workflow,
26};
27
28#[derive(Copy, Clone, Debug, PartialEq, Eq, Sequence)]
30#[non_exhaustive]
31pub enum RefreshData {
32 Tags,
33 CustomFields,
34 Correspondents,
35 DocumentTypes,
36 Users,
37 StoragePaths,
38}
39
40#[derive(Debug, Clone)]
42pub struct PaperlessClient {
43 client: reqwest::Client,
44 pub(crate) base_url: Box<str>,
45 cached_data: Arc<CachedData>,
46}
47
48#[derive(Debug)]
49struct CachedData {
50 correspondents: HashMap<CorrespondentId, Correspondent>,
51 document_types: HashMap<DocumentTypeId, DocumentType>,
52 tags: HashMap<TagId, Tag>,
53 custom_fields: HashMap<CustomFieldId, CustomField>,
54 users: HashMap<UserId, User>,
55 storage_paths: HashMap<StoragePathId, StoragePath>,
56}
57
58#[derive(Debug, Deserialize)]
59struct PaginatedResponse<T> {
60 results: Vec<T>,
61 next: Option<String>,
62}
63
64impl PaperlessClient {
65 pub fn new(
67 base_url: &str,
68 token: &str,
69 headers: Option<&HashMap<String, String>>,
70 ) -> std::result::Result<Self, String> {
71 let mut headers_map = HeaderMap::new();
72
73 if let Some(headers) = headers {
75 for (key, value) in headers {
76 headers_map.insert(
77 HeaderName::from_str(key).map_err(|err| err.to_string())?,
78 value
79 .parse()
80 .map_err(|err: InvalidHeaderValue| err.to_string())?,
81 );
82 }
83 }
84
85 headers_map.insert(
87 HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
88 format!("Token {token}")
89 .parse()
90 .map_err(|err: InvalidHeaderValue| err.to_string())?,
91 );
92
93 Ok(Self {
94 base_url: base_url.into(),
95 client: reqwest::Client::builder()
96 .default_headers(headers_map)
97 .zstd(true)
98 .build()
99 .map_err(|err| err.to_string())?,
100 cached_data: Arc::new(CachedData {
101 tags: HashMap::new(),
102 custom_fields: HashMap::new(),
103 correspondents: HashMap::new(),
104 document_types: HashMap::new(),
105 users: HashMap::new(),
106 storage_paths: HashMap::new(),
107 }),
108 })
109 }
110
111 async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
112 debug!("loading tags");
113 let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
114 Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
115 }
116
117 async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
118 debug!("loading custom fields");
119 let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
120 Ok(custom_fields
121 .into_iter()
122 .map(|custom_field| (custom_field.id, custom_field))
123 .collect())
124 }
125
126 async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
127 debug!("loading correspondents");
128 let correspondents: Vec<Correspondent> =
129 self.fetch_all_pages("/api/correspondents/").await?;
130 Ok(correspondents
131 .into_iter()
132 .map(|correspondent| (correspondent.id, correspondent))
133 .collect())
134 }
135
136 async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
137 debug!("loading document types");
138 let document_types: Vec<DocumentType> =
139 self.fetch_all_pages("/api/document_types/").await?;
140 Ok(document_types
141 .into_iter()
142 .map(|document_type| (document_type.id, document_type))
143 .collect())
144 }
145
146 async fn load_users(&self) -> Result<HashMap<UserId, User>> {
147 debug!("loading users");
148 let users: Vec<User> = self.fetch_all_pages("/api/users/").await?;
149 Ok(users.into_iter().map(|user| (user.id, user)).collect())
150 }
151
152 async fn load_storage_paths(&self) -> Result<HashMap<StoragePathId, StoragePath>> {
153 debug!("loading storage paths");
154 let storage_paths: Vec<StoragePath> = self.fetch_all_pages("/api/storage_paths/").await?;
155 Ok(storage_paths
156 .into_iter()
157 .map(|storage_path| (storage_path.id, storage_path))
158 .collect())
159 }
160
161 pub async fn refresh_all(&mut self) -> Result<()> {
162 self.refresh(enum_iterator::all::<RefreshData>()).await
163 }
164
165 pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
167 let mut refresh_tags = false;
168 let mut refresh_custom_fields = false;
169 let mut refresh_correspondents = false;
170 let mut refresh_document_types = false;
171 let mut refresh_users = false;
172 let mut refresh_storage_paths = false;
173
174 for item in data {
175 match item {
176 RefreshData::Tags => refresh_tags = true,
177 RefreshData::CustomFields => refresh_custom_fields = true,
178 RefreshData::Correspondents => refresh_correspondents = true,
179 RefreshData::DocumentTypes => refresh_document_types = true,
180 RefreshData::Users => refresh_users = true,
181 RefreshData::StoragePaths => refresh_storage_paths = true,
182 }
183 }
184
185 let (tags, custom_fields, correspondents, document_types, users, storage_paths) = futures_util::try_join!(
186 async {
187 if refresh_tags {
188 Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
189 } else {
190 Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
191 }
192 },
193 async {
194 if refresh_custom_fields {
195 Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
196 self.load_custom_fields().await?,
197 ))
198 } else {
199 Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
200 }
201 },
202 async {
203 if refresh_correspondents {
204 Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
205 self.load_correspondents().await?,
206 ))
207 } else {
208 Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
209 }
210 },
211 async {
212 if refresh_document_types {
213 Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
214 self.load_document_types().await?,
215 ))
216 } else {
217 Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
218 }
219 },
220 async {
221 if refresh_users {
222 Ok::<Option<HashMap<UserId, User>>, Error>(Some(self.load_users().await?))
223 } else {
224 Ok::<Option<HashMap<UserId, User>>, Error>(None)
225 }
226 },
227 async {
228 if refresh_storage_paths {
229 Ok::<Option<HashMap<StoragePathId, StoragePath>>, Error>(Some(
230 self.load_storage_paths().await?,
231 ))
232 } else {
233 Ok::<Option<HashMap<StoragePathId, StoragePath>>, Error>(None)
234 }
235 },
236 )?;
237
238 self.cached_data = Arc::new(CachedData {
239 correspondents: correspondents.unwrap_or_default(),
240 document_types: document_types.unwrap_or_default(),
241 tags: tags.unwrap_or_default(),
242 custom_fields: custom_fields.unwrap_or_default(),
243 users: users.unwrap_or_default(),
244 storage_paths: storage_paths.unwrap_or_default(),
245 });
246
247 Ok(())
248 }
249
250 pub async fn get_documents_by_tags(
252 &self,
253 tag_ids: &[TagId],
254 truncate_content: bool,
255 ) -> Result<Vec<Document>> {
256 let tag_id_str = tag_ids
257 .iter()
258 .map(|tag_id| tag_id.0.to_string())
259 .collect::<Vec<_>>()
260 .join(",");
261 let documents: Vec<_> = self
262 .fetch_all_pages::<DocumentData>(&format!(
263 "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
264 ))
265 .await?
266 .into_iter()
267 .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
268 .collect();
269
270 Ok(documents)
271 }
272
273 pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
274 let resp = self
275 .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
276 .await?;
277
278 let document_data: DocumentData = resp
279 .json()
280 .await
281 .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
282
283 Ok(document_data)
284 }
285
286 pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
288 Ok(Document::new(
289 self.get_document_data_by_id(id).await?,
290 Arc::new(self.clone()),
291 false,
292 ))
293 }
294
295 pub(crate) async fn request(
296 &self,
297 method: Method,
298 endpoint: &str,
299 body: Option<&serde_json::Value>,
300 ) -> Result<reqwest::Response> {
301 let mut req = self
302 .client
303 .request(method, format!("{}{endpoint}", self.base_url))
304 .header(ACCEPT, "application/json");
305
306 if let Some(json_body) = body {
308 req = req.json(json_body);
309 }
310
311 let resp = req
312 .send()
313 .await
314 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
315
316 if resp.status() == StatusCode::NOT_FOUND {
317 return Err(Error::NotFound);
318 }
319
320 if !resp.status().is_success() {
321 return Err(Error::Response {
322 status_code: resp.status().as_u16(),
323 body: resp.text().await.unwrap_or_default(),
324 });
325 }
326
327 Ok(resp)
328 }
329
330 pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
331 &self,
332 endpoint: &str,
333 ) -> Result<Vec<T>> {
334 let mut results = Vec::new();
335 let mut current_url = Some(endpoint.to_string());
336
337 while let Some(url) = current_url {
338 let resp = self.request(Method::GET, &url, None).await?;
339
340 let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
341 Error::InvalidJson(format!(
342 "Failed to parse paginated response for {endpoint}: {e}"
343 ))
344 })?;
345
346 results.extend(page.results);
347
348 current_url = page.next.and_then(|next_url| {
349 next_url
351 .trim_start_matches(&*self.base_url)
352 .to_string()
353 .into()
354 });
355 }
356
357 Ok(results)
358 }
359
360 pub async fn get_task_status(
362 &self,
363 task_id: Option<&TaskId>,
364 task_name: Option<&str>,
365 acknowledged: Option<bool>,
366 ) -> Result<Vec<Task>> {
367 let mut query = Vec::new();
368
369 if let Some(id) = task_id {
370 query.push(("task_id", id.to_string()));
371 }
372
373 if let Some(name) = task_name {
374 query.push(("task_name", name.to_string()));
375 }
376
377 if let Some(ack) = acknowledged {
378 query.push(("acknowledged", ack.to_string()));
379 }
380
381 let resp = self
382 .request(
383 Method::GET,
384 &format!(
385 "/api/tasks/?{}",
386 serde_urlencoded::to_string(&query)
387 .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
388 ),
389 None,
390 )
391 .await?;
392
393 trace!("get_task_status response: {:?}", resp);
394
395 let body = resp
396 .text()
397 .await
398 .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
399
400 let tasks: Vec<Task> = match serde_json::from_str(&body) {
401 Ok(t) => t,
402 Err(e) => {
403 return Err(Error::InvalidJson(format!(
404 "Failed to parse response body: {e}"
405 )));
406 }
407 };
408
409 if tasks.is_empty() {
410 return Err(Error::NotFound);
411 }
412
413 Ok(tasks)
414 }
415
416 pub fn get_workflows(&self) -> impl Future<Output = Result<Vec<Workflow>>> {
417 self.fetch_all_pages("/api/workflows/")
418 }
419
420 pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
424 let file_bytes = std::fs::read(file_path)
425 .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
426
427 let form = multipart::Form::new().part(
428 "document",
429 multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
430 );
431
432 let url = format!("{}/api/documents/post_document/", self.base_url);
433
434 let resp = self
435 .client
436 .post(&url)
437 .multipart(form)
438 .send()
439 .await
440 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
441
442 let status = resp.status();
443 if !resp.status().is_success() {
444 return Err(Error::Response {
445 status_code: status.as_u16(),
446 body: resp.text().await.unwrap_or_default(),
447 });
448 }
449
450 let task_id: String = resp
451 .json()
452 .await
453 .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
454 Ok(TaskId(task_id))
455 }
456
457 #[inline]
458 #[must_use]
459 pub fn tags(&self) -> &HashMap<TagId, Tag> {
460 &self.cached_data.tags
461 }
462
463 #[inline]
464 #[must_use]
465 pub fn storage_paths(&self) -> &HashMap<StoragePathId, StoragePath> {
466 &self.cached_data.storage_paths
467 }
468
469 #[must_use]
470 pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
471 self.cached_data.tags.values().find(|tag| tag.name == name)
472 }
473
474 #[inline]
475 #[must_use]
476 pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
477 &self.cached_data.document_types
478 }
479
480 #[must_use]
481 pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
482 self.cached_data
483 .document_types
484 .values()
485 .find(|dt| dt.name == name)
486 }
487
488 #[inline]
489 #[must_use]
490 pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
491 &self.cached_data.correspondents
492 }
493
494 #[inline]
495 #[must_use]
496 pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
497 &self.cached_data.custom_fields
498 }
499
500 #[must_use]
501 pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
502 self.cached_data
503 .custom_fields
504 .values()
505 .find(|field| field.name == name)
506 }
507
508 #[inline]
509 #[must_use]
510 pub fn users(&self) -> &HashMap<UserId, User> {
511 &self.cached_data.users
512 }
513}