1use std::{collections::HashMap, path::Path, str::FromStr, sync::Arc};
2
3use reqwest::{
4 Method, StatusCode,
5 header::{ACCEPT, HeaderMap, HeaderName, InvalidHeaderValue},
6 multipart,
7};
8use serde::Deserialize;
9use tracing::{debug, trace};
10
11use crate::{
12 Error, Result,
13 correspondent::{Correspondent, CorrespondentId},
14 custom_field::{CustomField, CustomFieldId},
15 document::{Document, DocumentData, DocumentId},
16 document_type::{DocumentType, DocumentTypeId},
17 tag::{Tag, TagId},
18 task::{Task, TaskId},
19};
20
/// Selects which locally cached collection a refresh should reload.
///
/// Values are small, `Copy`, and hashable so callers can keep them in sets
/// or use them as map keys when deciding what to refresh.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RefreshData {
    /// Reload the tag cache.
    Tags,
    /// Reload the custom field cache.
    CustomFields,
    /// Reload the correspondent cache.
    Correspondents,
    /// Reload the document type cache.
    DocumentTypes,
}
29
30#[derive(Debug, Clone)]
32pub struct PaperlessClient {
33 client: reqwest::Client,
34 base_url: String,
35
36 correspondents: HashMap<CorrespondentId, Correspondent>,
37 document_types: HashMap<DocumentTypeId, DocumentType>,
38 tags: HashMap<TagId, Tag>,
39 custom_fields: HashMap<CustomFieldId, CustomField>,
40}
41
42#[derive(Debug, Deserialize)]
43struct PaginatedResponse<T> {
44 results: Vec<T>,
45 next: Option<String>,
46}
47
48impl PaperlessClient {
49 pub fn new(
51 base_url: &str,
52 token: &str,
53 headers: Option<&HashMap<String, String>>,
54 ) -> std::result::Result<Self, String> {
55 let mut headers_map = HeaderMap::new();
56
57 if let Some(headers) = headers {
59 for (key, value) in headers {
60 headers_map.insert(
61 HeaderName::from_str(key).map_err(|err| err.to_string())?,
62 value
63 .parse()
64 .map_err(|err: InvalidHeaderValue| err.to_string())?,
65 );
66 }
67 }
68
69 headers_map.insert(
71 HeaderName::from_str("Authorization").map_err(|err| err.to_string())?,
72 format!("Token {token}")
73 .parse()
74 .map_err(|err: InvalidHeaderValue| err.to_string())?,
75 );
76
77 Ok(Self {
78 base_url: base_url.to_string(),
79 client: reqwest::Client::builder()
80 .default_headers(headers_map)
81 .build()
82 .map_err(|err| err.to_string())?,
83 tags: HashMap::new(),
84 custom_fields: HashMap::new(),
85 correspondents: HashMap::new(),
86 document_types: HashMap::new(),
87 })
88 }
89
90 async fn load_tags(&self) -> Result<HashMap<TagId, Tag>> {
91 debug!("loading tags");
92 let tags: Vec<Tag> = self.fetch_all_pages("/api/tags/").await?;
93 Ok(tags.into_iter().map(|tag| (tag.id, tag)).collect())
94 }
95
96 async fn load_custom_fields(&self) -> Result<HashMap<CustomFieldId, CustomField>> {
97 debug!("loading custom fields");
98 let custom_fields: Vec<CustomField> = self.fetch_all_pages("/api/custom_fields/").await?;
99 Ok(custom_fields
100 .into_iter()
101 .map(|custom_field| (custom_field.id, custom_field))
102 .collect())
103 }
104
105 async fn load_correspondents(&self) -> Result<HashMap<CorrespondentId, Correspondent>> {
106 debug!("loading correspondents");
107 let correspondents: Vec<Correspondent> =
108 self.fetch_all_pages("/api/correspondents/").await?;
109 Ok(correspondents
110 .into_iter()
111 .map(|correspondent| (correspondent.id, correspondent))
112 .collect())
113 }
114
115 async fn load_document_types(&self) -> Result<HashMap<DocumentTypeId, DocumentType>> {
116 debug!("loading document types");
117 let document_types: Vec<DocumentType> =
118 self.fetch_all_pages("/api/document_types/").await?;
119 Ok(document_types
120 .into_iter()
121 .map(|document_type| (document_type.id, document_type))
122 .collect())
123 }
124
125 pub async fn refresh(&mut self, data: impl IntoIterator<Item = RefreshData>) -> Result<()> {
127 let mut refresh_tags = false;
128 let mut refresh_custom_fields = false;
129 let mut refresh_correspondents = false;
130 let mut refresh_document_types = false;
131
132 for item in data {
133 match item {
134 RefreshData::Tags => refresh_tags = true,
135 RefreshData::CustomFields => refresh_custom_fields = true,
136 RefreshData::Correspondents => refresh_correspondents = true,
137 RefreshData::DocumentTypes => refresh_document_types = true,
138 }
139 }
140
141 let (tags, custom_fields, correspondents, document_types) = futures_util::try_join!(
142 async {
143 if refresh_tags {
144 Ok::<Option<HashMap<TagId, Tag>>, Error>(Some(self.load_tags().await?))
145 } else {
146 Ok::<Option<HashMap<TagId, Tag>>, Error>(None)
147 }
148 },
149 async {
150 if refresh_custom_fields {
151 Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(Some(
152 self.load_custom_fields().await?,
153 ))
154 } else {
155 Ok::<Option<HashMap<CustomFieldId, CustomField>>, Error>(None)
156 }
157 },
158 async {
159 if refresh_correspondents {
160 Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(Some(
161 self.load_correspondents().await?,
162 ))
163 } else {
164 Ok::<Option<HashMap<CorrespondentId, Correspondent>>, Error>(None)
165 }
166 },
167 async {
168 if refresh_document_types {
169 Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(Some(
170 self.load_document_types().await?,
171 ))
172 } else {
173 Ok::<Option<HashMap<DocumentTypeId, DocumentType>>, Error>(None)
174 }
175 },
176 )?;
177
178 if let Some(tags) = tags {
179 self.tags = tags;
180 }
181
182 if let Some(custom_fields) = custom_fields {
183 self.custom_fields = custom_fields;
184 }
185
186 if let Some(correspondents) = correspondents {
187 self.correspondents = correspondents;
188 }
189
190 if let Some(document_types) = document_types {
191 self.document_types = document_types;
192 }
193
194 Ok(())
195 }
196
197 #[inline]
199 pub async fn refresh_tags(&mut self) -> Result<()> {
200 self.refresh([RefreshData::Tags]).await
201 }
202
203 #[inline]
205 pub async fn refresh_custom_fields(&mut self) -> Result<()> {
206 self.refresh([RefreshData::CustomFields]).await
207 }
208
209 #[inline]
211 pub async fn refresh_correspondents(&mut self) -> Result<()> {
212 self.refresh([RefreshData::Correspondents]).await
213 }
214
215 #[inline]
217 pub async fn refresh_document_types(&mut self) -> Result<()> {
218 self.refresh([RefreshData::DocumentTypes]).await
219 }
220
221 pub async fn get_documents_by_tags(
223 &self,
224 tag_ids: &[TagId],
225 truncate_content: bool,
226 ) -> Result<Vec<Document>> {
227 let tag_id_str = tag_ids
228 .iter()
229 .map(|tag_id| tag_id.0.to_string())
230 .collect::<Vec<_>>()
231 .join(",");
232 let documents: Vec<_> = self
233 .fetch_all_pages::<DocumentData>(&format!(
234 "/api/documents/?truncate_content={truncate_content}&tags__id__in={tag_id_str}"
235 ))
236 .await?
237 .into_iter()
238 .map(|data| Document::new(data, Arc::new(self.clone()), truncate_content))
239 .collect();
240
241 Ok(documents)
242 }
243
244 pub(crate) async fn get_document_data_by_id(&self, id: DocumentId) -> Result<DocumentData> {
245 let resp = self
246 .request(Method::GET, &format!("/api/documents/{}/", id.0), None)
247 .await?;
248
249 let document_data: DocumentData = resp
250 .json()
251 .await
252 .map_err(|e| Error::Other(format!("Failed to parse document: {e}")))?;
253
254 Ok(document_data)
255 }
256
257 pub async fn get_document_by_id(&self, id: DocumentId) -> Result<Document> {
259 Ok(Document::new(
260 self.get_document_data_by_id(id).await?,
261 Arc::new(self.clone()),
262 false,
263 ))
264 }
265
266 pub(crate) async fn request(
267 &self,
268 method: Method,
269 endpoint: &str,
270 body: Option<&serde_json::Value>,
271 ) -> Result<reqwest::Response> {
272 let mut req = self
273 .client
274 .request(method, format!("{}{endpoint}", self.base_url))
275 .header(ACCEPT, "application/json");
276
277 if let Some(json_body) = body {
278 req = req.json(json_body);
279 }
280
281 let resp = req
282 .send()
283 .await
284 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
285
286 if resp.status() == StatusCode::NOT_FOUND {
287 return Err(Error::NotFound);
288 }
289
290 if !resp.status().is_success() {
291 return Err(Error::Response {
292 status_code: resp.status().as_u16(),
293 body: resp.text().await.unwrap_or_default(),
294 });
295 }
296
297 Ok(resp)
298 }
299
300 pub(crate) async fn fetch_all_pages<T: for<'de> Deserialize<'de>>(
301 &self,
302 endpoint: &str,
303 ) -> Result<Vec<T>> {
304 let mut results = Vec::new();
305 let mut current_url = Some(endpoint.to_string());
306
307 while let Some(url) = current_url {
308 let resp = self.request(Method::GET, &url, None).await?;
309
310 let page: PaginatedResponse<T> = resp.json().await.map_err(|e| {
311 Error::Other(format!(
312 "Failed to parse paginated response for {endpoint}: {e}"
313 ))
314 })?;
315
316 results.extend(page.results);
317
318 current_url = page.next.and_then(|next_url| {
319 next_url
321 .trim_start_matches(&self.base_url)
322 .to_string()
323 .into()
324 });
325 }
326
327 Ok(results)
328 }
329
330 pub async fn get_task_status(
332 &self,
333 task_id: Option<&TaskId>,
334 task_name: Option<&str>,
335 acknowledged: Option<bool>,
336 ) -> Result<Vec<Task>> {
337 let mut query = Vec::new();
338
339 if let Some(id) = task_id {
340 query.push(("task_id", id.to_string()));
341 }
342
343 if let Some(name) = task_name {
344 query.push(("task_name", name.to_string()));
345 }
346
347 if let Some(ack) = acknowledged {
348 query.push(("acknowledged", ack.to_string()));
349 }
350
351 let resp = self
352 .request(
353 Method::GET,
354 &format!(
355 "/api/tasks/?{}",
356 serde_urlencoded::to_string(&query)
357 .map_err(|e| Error::Other(format!("Failed to serialize query: {e}")))?
358 ),
359 None,
360 )
361 .await?;
362
363 trace!("get_task_status response: {:?}", resp);
364
365 let body = resp
366 .text()
367 .await
368 .map_err(|e| Error::Other(format!("Failed to read response body: {e}")))?;
369
370 let tasks: Vec<Task> = match serde_json::from_str(&body) {
371 Ok(t) => t,
372 Err(e) => {
373 return Err(Error::InvalidJson(format!(
374 "Failed to parse response body: {e}"
375 )));
376 }
377 };
378
379 if tasks.is_empty() {
380 return Err(Error::NotFound);
381 }
382
383 Ok(tasks)
384 }
385
386 pub async fn upload_document(&self, file_path: &Path, filename: &str) -> Result<TaskId> {
390 let file_bytes = std::fs::read(file_path)
391 .map_err(|e| Error::Other(format!("Failed to read file: {e}")))?;
392
393 let form = multipart::Form::new().part(
394 "document",
395 multipart::Part::bytes(file_bytes).file_name(filename.to_string()),
396 );
397
398 let url = format!("{}/api/documents/post_document/", self.base_url);
399
400 let resp = self
401 .client
402 .post(&url)
403 .multipart(form)
404 .send()
405 .await
406 .map_err(|e| Error::Other(format!("Failed to send request: {e}")))?;
407
408 let status = resp.status();
409 if !resp.status().is_success() {
410 return Err(Error::Response {
411 status_code: status.as_u16(),
412 body: resp.text().await.unwrap_or_default(),
413 });
414 }
415
416 let task_id: String = resp
417 .json()
418 .await
419 .map_err(|e| Error::Other(format!("Failed to parse task ID: {e}")))?;
420 Ok(TaskId(task_id))
421 }
422
423 #[inline]
424 #[must_use]
425 pub fn tags(&self) -> &HashMap<TagId, Tag> {
426 &self.tags
427 }
428
429 #[must_use]
430 pub fn find_tag_by_name(&self, name: &str) -> Option<&Tag> {
431 self.tags.values().find(|tag| tag.name == name)
432 }
433
434 #[inline]
435 #[must_use]
436 pub fn document_types(&self) -> &HashMap<DocumentTypeId, DocumentType> {
437 &self.document_types
438 }
439
440 #[must_use]
441 pub fn find_document_type_by_name(&self, name: &str) -> Option<&DocumentType> {
442 self.document_types.values().find(|dt| dt.name == name)
443 }
444
445 #[inline]
446 #[must_use]
447 pub fn correspondents(&self) -> &HashMap<CorrespondentId, Correspondent> {
448 &self.correspondents
449 }
450
451 #[inline]
452 #[must_use]
453 pub fn custom_fields(&self) -> &HashMap<CustomFieldId, CustomField> {
454 &self.custom_fields
455 }
456
457 #[must_use]
458 pub fn find_custom_field_by_name(&self, name: &str) -> Option<&CustomField> {
459 self.custom_fields.values().find(|field| field.name == name)
460 }
461}